In [None]:
import sys
lib_path = [r'C:\Users\ikahbasi\OneDrive\Applications\GitHub\SeisRoutine',
            r'C:\Users\ikahb\OneDrive\Applications\GitHub\SeisRoutine']
for path in lib_path:
    sys.path.append(path)
##########################################################################
import SeisRoutine.catalog as src
import SeisRoutine.waveform as srw
import SeisRoutine.config as srconf
import SeisRoutine.statistics as srs

In [None]:
import seisbench.generate as sbg
import seisbench.models as sbm
import torch
from tqdm import tqdm
from scipy import signal
import os
import seisbench.data as sbd
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import label
import pandas as pd

# Functions and Classes

In [None]:
def find_ps_pairs(metadata, require_both=False):
    """Flag metadata rows by the availability of their P and S picks.

    A row "has a P pick" when at least one column whose name starts with
    ``trace_P`` and ends with ``_arrival_sample`` (case-insensitive) is
    non-null; likewise for S with the ``trace_S`` prefix.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Table whose columns are searched for the arrival-sample columns.
    require_both : bool, default False
        If False (the historical behaviour), a row is flagged when its P and
        S availability *agree*: both picks present, or both absent.
        If True, a row is flagged only when both a P and an S pick are
        present.

    Returns
    -------
    pandas.Series
        Boolean Series indexed like ``metadata``.
    """
    suffix = '_arrival_sample'.upper()

    def _has_pick(phase):
        # True for rows with at least one non-null arrival column of `phase`.
        prefix = f'trace_{phase}'.upper()
        columns = [key for key in metadata.keys()
                   if str(key).upper().startswith(prefix)
                   and str(key).upper().endswith(suffix)]
        return metadata[columns].notna().any(axis=1)

    p_condition = _has_pick('P')
    s_condition = _has_pick('S')
    if require_both:
        return p_condition & s_condition
    # Equality is also True where neither phase has a pick.
    return s_condition == p_condition

# Run

In [None]:
# Load the merged dataset's metadata table from a hard-coded local CSV path.
# NOTE(review): the absolute Windows path ties this notebook to one machine.
path = r'F:\DataSets-Local\Merged_All_DataSets_2025-07-10 (Ahar-Ilam-Kaki-Qeshm)\metadata.csv'
df_metadata = pd.read_csv(path, low_memory=False)

In [None]:
# (rows, columns) of the metadata table.
df_metadata.shape

In [None]:
# Boolean mask returned by find_ps_pairs; the builtin sum() counts its True
# entries.
cond_PS_pairs = find_ps_pairs(metadata=df_metadata)
sum(cond_PS_pairs)

In [None]:
# Flag traces whose sample count is exactly 3001 and count them.
# NOTE(review): 3001 is presumably the full expected window length - confirm.
key = 'trace_npts'
cond_data_available = (df_metadata[key] == 3001)
sum(cond_data_available)

In [None]:
# Per-column mask: True where a column ending in '_snr' is at or above the
# threshold (missing values compare as False).
treshold_snr = 2
keys = [key for key in df_metadata.keys() if key.endswith('_snr')]
cond_good_snr_channels = df_metadata[keys] >= treshold_snr
# Number of rows in which exactly three '_snr' columns pass.
# NOTE(review): assumes there are exactly three such columns - confirm.
sum(cond_good_snr_channels.sum(axis=1)==3)

# Skewness

In [None]:
# Load the metadata table that (presumably) carries precomputed skewness
# columns.
# NOTE(review): pd.read_pickle can execute arbitrary code while loading;
# only use it on files you trust.
path = r'F:\DataSets-Local\Merged_All_DataSets_2025-07-10 (Ahar-Ilam-Kaki-Qeshm)\metadata-with-skewness.pkl'
# path = r'F:\DataSets-Local\Merged_All_DataSets_2025-07-10 (Ahar-Ilam-Kaki-Qeshm)\metadata-with-skewness-old.pkl'
df_skewness = pd.read_pickle(path)

In [None]:
# Split the columns ending in 'skewness' into two groups by substring:
# keys1 contains 'no-filter', keys2 contains 'with-filter'.
keys = [key for key in df_skewness.keys() if key.endswith('skewness')]
keys1 = [key for key in keys if 'no-filter' in key]
keys2 = [key for key in keys if 'with-filter' in key]

In [None]:
# A column passes when its absolute skewness is at most the threshold;
# missing values fail the comparison.
treshold_skewness = 5
cond_good_skewness1_channels = (df_skewness[keys1].abs() <= treshold_skewness)
cond_good_skewness2_channels = (df_skewness[keys2].abs() <= treshold_skewness)

# Noisy Data (Frequency)

In [None]:
# Load the metadata table with the spectral ('...fft') columns.
# NOTE(review): pd.read_pickle should only be used on trusted files.
path = r'F:\DataSets-Local\Merged_All_DataSets_2025-07-10 (Ahar-Ilam-Kaki-Qeshm)\metadata-with-frequency.pkl'
df_fft = pd.read_pickle(path)
# NOTE(review): this `keys` list is not read before `keys` is reassigned two
# cells below.
keys = [key for key in df_fft.keys() if key.endswith('fft')]


In [None]:
# For each component add a 'trace_<channel>_noise_level' column holding the
# ratio of the 'max_H-band_fft' column to the 'max_M-band_fft' column.
# This writes the new columns into df_fft itself.
# NOTE(review): H/M are presumably high/mid frequency bands - confirm.
for channel in ['E', 'N', 'Z']:
    m_band = df_fft[f'trace_{channel}_max_M-band_fft']
    h_band = df_fft[f'trace_{channel}_max_H-band_fft']
    df_fft[f'trace_{channel}_noise_level'] = h_band / m_band

In [None]:
keys = [key for key in df_fft.keys() if key.endswith('_noise_level')]

# A channel passes when its H/M ratio is strictly below the threshold
# (missing ratios fail the comparison).
treshold_noisy_level = 1
cond_good_noisy_channels = df_fft[keys] < treshold_noisy_level
# cond_good_noisy_channel

# Flat Signal

In [None]:
# Load the metadata table with the '_std' columns.
# NOTE(review): pd.read_pickle should only be used on trusted files.
path = r'F:\DataSets-Local\Merged_All_DataSets_2025-07-10 (Ahar-Ilam-Kaki-Qeshm)\metadata-with-std.pkl'
df_std = pd.read_pickle(path)

In [None]:
# A channel counts as "not flat" when its '_std' value is at least this
# threshold; missing values fail the comparison.
threshold_flatness = 0.01

keys = [key for key in df_std.keys()
        if key.endswith('_std')]
cond_no_flat_signals_channels = threshold_flatness <= df_std[keys]

# Merge

In [None]:
# Put all per-channel boolean masks side by side (axis=1). pandas aligns the
# frames on their index labels.
# NOTE(review): assumes all five frames share the same row index - confirm.
cond_all = pd.concat([cond_no_flat_signals_channels,
                      cond_good_noisy_channels,
                      cond_good_skewness1_channels, 
                      cond_good_skewness2_channels,
                      cond_good_snr_channels],
                      axis=1)

In [None]:
# Group the mask columns per component by the substrings '_Z_', '_E_' and
# '_N_' in the column name; columns containing none of them are ignored.
keys_z = [key for key in cond_all.keys() if '_Z_' in key]
keys_e = [key for key in cond_all.keys() if '_E_' in key]
keys_n = [key for key in cond_all.keys() if '_N_' in key]

# One verdict per component: True only if every mask of that component is
# True. Column order of the result is z, n, e.
df_channel_condition = pd.DataFrame({
    'z': cond_all[keys_z].all(axis=1),
    'n': cond_all[keys_n].all(axis=1),
    'e': cond_all[keys_e].all(axis=1),
    })

In [None]:
# Spot-check the per-component verdicts of the row at position 32.
df_channel_condition.iloc[32]

In [None]:
# For every row, list the positions of the columns that are True
# (0, 1, 2 correspond to the z, n, e columns of df_channel_condition).
# NOTE(review): these positions are later used to index waveform channels,
# so the z/n/e order must match the dataset's component order - confirm.
func = lambda x: np.where(x)[0].tolist()
cond = df_channel_condition.apply(func, axis=1)

In [None]:
def check(lst, pattern={0, 1, 2}):
    """Pad a list of usable channel indices so every index in `pattern` is covered.

    Each index of `pattern` that is missing from `lst` is filled, at its own
    position, with the first element of `lst`. For example ``[0, 2]`` becomes
    ``[0, 0, 2]`` and ``[1]`` becomes ``[1, 1, 1]``.

    Parameters
    ----------
    lst : sequence of int
        Indices present, in ascending order. It is not modified.
    pattern : iterable of int, default {0, 1, 2}
        Full set of indices the result must cover. It is never mutated.

    Returns
    -------
    list of int
        A new list. An empty input yields an empty list; an input that already
        has as many entries as `pattern` is returned as an unchanged copy.
    """
    # Work on a copy: inserting into the caller's list would silently rewrite
    # the objects stored in the Series this function is applied to.
    output = list(lst)
    if len(output) != len(pattern) and len(output) != 0:
        replacement_element = output[0]
        # Ascending order matters: when index `el` is handled, positions
        # 0..el-1 are already filled, so insert(el, ...) lands at slot `el`.
        for el in sorted(set(pattern) - set(output)):
            output.insert(el, replacement_element)
    return output

In [None]:
# Run check() on each row's list of channel positions.
channel_status = cond.apply(check)

In [None]:
# Two-step configuration load: the init file names the directory and the
# filename of the actual configuration, which is then loaded.
init_cfg = srconf.load_config('0-init-cfg.yml')
cfg_path = os.path.join(init_cfg.target_config_filepath,
                        init_cfg.target_config_filename)
cfg = srconf.load_config(cfg_path)

In [None]:
# Open the SeisBench waveform dataset; path, sampling rate and component
# order all come from the configuration.
dataset = sbd.WaveformDataset(
    path=cfg.dataset.path,
    sampling_rate=cfg.training.dataset.sampling_rate,
    component_order=cfg.training.dataset.component_order,
          )

In [None]:
# Attach the channel lists as a metadata column. pandas aligns a Series
# assignment on index labels.
# NOTE(review): assumes dataset.metadata has the same row index as the tables
# used to build channel_status - confirm.
dataset.metadata['channel_status'] = channel_status

In [None]:
class channel_condition:
    """Augmentation that re-indexes the waveform by ``metadata['channel_status']``.

    Parameters
    ----------
    alpha : float, default 0.3
        Stored on the instance as a tapering coefficient; `__call__` does not
        read it.
    key : str or pair, default 'X'
        A string reads from and writes to the same state_dict entry. Otherwise
        element 0 is the entry to read and element 1 the entry to write.
    """

    def __init__(self, alpha=0.3, key='X'):
        self.alpha = alpha  # Tapering coefficient (unused in this class)
        self.key = (key, key) if isinstance(key, str) else key

    def __call__(self, state_dict):
        """Index the first axis of the waveform with the channel list, in place in `state_dict`."""
        waveforms, metadata = state_dict[self.key[0]]
        # Indexing with the list picks (and possibly repeats) rows of the first axis.
        selected = waveforms[metadata['channel_status']]
        state_dict[self.key[1]] = (selected, metadata)

In [None]:
# Pipeline WITH channel substitution, applied in this order:
# 4th-order 0.5 high-pass filter (forward-backward), channel_condition,
# demean + peak normalisation along the last axis, cast to float32.
# NOTE(review): `sps` is not referenced anywhere in this cell.
sps = 100
augmentations = [
    # Tapering(),
    sbg.Filter(N=4,
               Wn=[0.5],
               btype='highpass',
               forward_backward=True,
               ),
    channel_condition(),
    sbg.Normalize(
        demean_axis=-1,
        amp_norm_axis=-1,
        amp_norm_type="peak"),
    sbg.ChangeDtype(np.float32),
]
generator = sbg.GenericGenerator(dataset)
generator.add_augmentations(augmentations)

In [None]:
# Baseline pipeline for comparison: the same filter, normalisation and dtype
# cast as `augmentations`, but without the channel_condition step.
# NOTE(review): `sps` is not referenced anywhere in this cell.
sps = 100
augmentations0 = [
    # Tapering(),
    sbg.Filter(N=4,
               Wn=[0.5],
               btype='highpass',
               forward_backward=True,
               ),
    sbg.Normalize(
        demean_axis=-1,
        amp_norm_axis=-1,
        amp_norm_type="peak"),
    sbg.ChangeDtype(np.float32),
]
generator0 = sbg.GenericGenerator(dataset)
generator0.add_augmentations(augmentations0)

In [None]:
# Inspect the trace at position 32: left panel is the pipeline with
# channel_condition, right panel is the baseline pipeline.
ii = 32
metadata = dataset.metadata.iloc[ii]
print(ii, metadata['channel_status'])
data = generator[ii]
data_X = data['X']
###
data0 = generator0[ii]
data_X0 = data0['X']
###
fig, axes = plt.subplots(1, 2,
    figsize=(10, 3))
# NOTE(review): `label` rebinds the name imported from scipy.ndimage at the
# top of the notebook.
label = [_ for _ in dataset.component_order]
# The offsets -1/0/1 shift the three traces apart vertically.
axes[0].plot(data_X.T +[-1, 0, 1], label=label)
axes[1].plot(data_X0.T+[-1, 0, 1], label=label)
# NOTE(review): plt.legend acts on the current axes only, presumably the
# right-hand panel here.
plt.legend(loc=1)
plt.show()
# Show the individual masks and the per-component verdicts for this row.
# NOTE(review): `cond[ii]` is a label lookup while the rest uses .iloc.
print(cond_all.iloc[ii])
print(df_channel_condition.iloc[ii], cond[ii])

In [None]:
# Plot the first ten traces for which the channel list differs from
# [0, 1, 2]: left panel = pipeline with channel_condition, right = baseline.
n_plots = 0
for ii in range(len(dataset.metadata)):
    # `ii` is a position: it is used with .iloc and with both generators.
    metadata = dataset.metadata.iloc[ii]
    print(ii, metadata['channel_status'])
    data = generator[ii]
    data_X = data['X']
    ###
    data0 = generator0[ii]
    data_X0 = data0['X']
    ###
    # Skip traces whose processed array is not 3 x 3001.
    if data_X.shape != (3, 3001):
        continue
    # Skip traces whose channels are all kept as they are. The value is read
    # from the row fetched above so the lookup stays positional like the rest
    # of the loop (a label lookup would only agree for a 0..n-1 index).
    if metadata['channel_status'] == [0, 1, 2]:
        continue
    ###
    fig, axes = plt.subplots(1, 2,
        figsize=(10, 3))
    label = [_ for _ in dataset.component_order]
    # The offsets -1/0/1 shift the three traces apart vertically.
    axes[0].plot(data_X.T +[-1, 0, 1], label=label)
    axes[1].plot(data_X0.T+[-1, 0, 1], label=label)
    plt.legend(loc=1)
    plt.show()
    n_plots += 1
    if n_plots == 10:
        break

In [None]:
# Channel list for the last `ii` left over from the loop above.
# NOTE(review): this is a label lookup on the Series, not a positional one.
dataset.metadata['channel_status'][ii]

In [None]:
# Print every metadata field of the row at position 32 without truncation.
metadata = dataset.metadata.iloc[32]

with pd.option_context('display.max_rows', None):
    print(metadata, type(metadata))

In [None]:
# Display the most recent `data` produced by `generator` in an earlier cell.
data

In [None]:
# Hand-picked trace positions to compare: left panel = pipeline with
# channel_condition, right panel = baseline pipeline.
n = [18, 30,37,76,87,97,105,109,126,137,149,158,161,173,179,196,204,209, 217,
     222, 230,235,251,254,257,261,260,265,266,267,268,273,275,290,302,310,
     312,323,331,335,355,367,376,388,391,393]

for ii in n:
    metadata = dataset.metadata.iloc[ii]
    print(ii, metadata['channel_status'])
    data = generator[ii]
    data_X = data['X']
    ###
    data0 = generator0[ii]
    data_X0 = data0['X']
    ###
    # Only the baseline array is checked here, so `data_X` may have a
    # different shape and fail to plot below.
    if data_X0.shape != (3, 3001):
        continue
#     if dataset.metadata['channel_status'][ii] == [0, 1, 2]:
#         continue
    ###
    fig, axes = plt.subplots(1, 2,
        figsize=(10, 3))
    label = [_ for _ in dataset.component_order]
    try:
        # The offsets -1/0/1 shift the three traces apart vertically.
        axes[0].plot(data_X.T + [-1, 0, 1], label=label)
        axes[1].plot(data_X0.T + [-1, 0, 1], label=label)
    except Exception as err:
        # Still best-effort (the loop continues), but the failure is reported
        # and KeyboardInterrupt is no longer swallowed.
        print(f'{ii}: plotting failed ({err!r})')
    plt.legend(loc=1)
    plt.show()

In [None]:
# Histogram of the number of entries in each channel_status list.
# NOTE(review): this rebinds `cond`, which earlier held the per-row lists of
# channel positions.
# NOTE(review): check() pads every non-empty list up to three entries, so the
# lengths here are presumably only 0 or 3 - confirm that is what the later
# `cond == 3` comparisons are meant to select.
cond = dataset.metadata['channel_status'].apply(len)
cond.plot(kind='hist')

In [None]:
def find_ps_pairs(metadata, require_both=False):
    """Flag metadata rows by the availability of their P and S picks.

    A row "has a P pick" when at least one column whose name starts with
    ``trace_P`` and ends with ``_arrival_sample`` (case-insensitive) is
    non-null; likewise for S with the ``trace_S`` prefix.

    NOTE(review): ``find_ps_pairs`` is also defined earlier in this notebook;
    this later definition replaces it once the cell runs. Keep only one.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Table whose columns are searched for the arrival-sample columns.
    require_both : bool, default False
        If False (the historical behaviour), a row is flagged when its P and
        S availability *agree*: both picks present, or both absent.
        If True, a row is flagged only when both a P and an S pick are
        present.

    Returns
    -------
    pandas.Series
        Boolean Series indexed like ``metadata``.
    """
    suffix = '_arrival_sample'.upper()

    def _has_pick(phase):
        # True for rows with at least one non-null arrival column of `phase`.
        prefix = f'trace_{phase}'.upper()
        columns = [key for key in metadata.keys()
                   if str(key).upper().startswith(prefix)
                   and str(key).upper().endswith(suffix)]
        return metadata[columns].notna().any(axis=1)

    p_condition = _has_pick('P')
    s_condition = _has_pick('S')
    if require_both:
        return p_condition & s_condition
    # Equality is also True where neither phase has a pick.
    return s_condition == p_condition

In [None]:
# Recompute the P/S mask on the CSV metadata table.
cond_PS_pairs = find_ps_pairs(metadata=df_metadata)


In [None]:
# Count rows whose channel list has three entries AND that are flagged by
# find_ps_pairs. The `&` aligns the two Series on index labels.
# NOTE(review): `cond` comes from dataset.metadata and `cond_PS_pairs` from
# df_metadata; this assumes both share the same row index - confirm.
sum((cond==3) & cond_PS_pairs)

In [None]:
# The two individual counts next to the combined count.
sum(cond==3), sum(cond_PS_pairs), sum((cond==3) & cond_PS_pairs)

In [None]:
# Total number of rows minus a hard-coded count.
# NOTE(review): 23432 is presumably a pasted result of the expression in the
# trailing comment; it will go stale if the data or thresholds change.
cond_PS_pairs.size - 23432#sum((cond==3) & cond_PS_pairs)

In [None]:
# List all column names of the CSV metadata table.
list(df_metadata.keys())