In [None]:
import sys
lib_path = [r'C:\Users\ikahbasi\OneDrive\Applications\GitHub\SeisRoutine',
            r'C:\Users\ikahb\OneDrive\Applications\GitHub\SeisRoutine']
for path in lib_path:
    sys.path.append(path)
##########################################################################
import SeisRoutine.catalog as src
import SeisRoutine.waveform as srw
import SeisRoutine.config as srconf
import SeisRoutine.statistics as srs

In [None]:
import seisbench.data as sbd
import seisbench.generate as sbg

import os
import numpy as np
from scipy import signal
from pprint import pprint as pp
import re
from scipy.stats import skew
from tqdm import tqdm

# Functions and Classes

In [None]:
class Tapering:
    """Augmentation that multiplies a waveform by a Tukey window.

    Parameters
    ----------
    alpha : float
        Shape parameter handed to ``scipy.signal.windows.tukey``.
    key : str or sequence of two keys
        A single string is used both as the key to read from and the key to
        write to. Otherwise the first element is the key to read and the
        second element the key to write.
    """

    def __init__(self, alpha=0.3, key='X'):
        self.alpha = alpha
        # Normalise a plain string into a (read, write) pair.
        self.key = (key, key) if isinstance(key, str) else key

    def __call__(self, state_dict):
        """Taper the entry in ``state_dict`` in place; returns ``None``.

        The entry under ``self.key[0]`` is unpacked as ``(array, metadata)``.
        The window spans the last axis of the array and the tapered result is
        stored, together with the untouched metadata, under ``self.key[1]``.
        """
        waveform, meta = state_dict[self.key[0]]
        n_samples = waveform.shape[-1]
        window = signal.windows.tukey(n_samples, self.alpha)
        state_dict[self.key[1]] = (waveform * window, meta)

In [None]:
def get_phase_time(metadata):
    """Return one P and one S arrival sample found in ``metadata``.

    Entries whose key matches ``trace_[PpSs].*_arrival_sample`` (matched from
    the start of the key) and whose value is not NaN are collected into a
    mapping keyed by the lower-cased name. Keys that differ only by case
    therefore share one slot: the later value replaces the earlier one while
    the slot keeps its original position.

    Returns
    -------
    tuple
        ``(p, s)``: the value of the last collected key starting with
        ``trace_p`` and of the last one starting with ``trace_s``. Each is
        ``None`` when no such key was collected.
    """
    pattern = re.compile("trace_[PpSs].*_arrival_sample")
    pick_columns = [name for name in metadata.keys() if pattern.match(name)]

    picks = {}
    for name, sample in metadata.items():
        # NaN is only tested for pick columns, so other values may be any type.
        if name in pick_columns and not np.isnan(sample):
            picks[name.lower()] = sample

    p = s = None
    for name, sample in picks.items():
        if name.startswith('trace_p'):
            p = sample
        if name.startswith('trace_s'):
            s = sample
    return p, s

# Reading CSV

In [None]:
# Load the bootstrap config first: it stores the directory and file name of
# the main configuration file.
init_cfg = srconf.load_config('0-init-cfg.yml')
cfg_path = os.path.join(init_cfg.target_config_filepath,
                        init_cfg.target_config_filename)
# Main configuration; later cells read the dataset and training settings from it.
cfg = srconf.load_config(cfg_path)

In [None]:
# Maps each arrival-sample metadata column to the phase label ("P" or "S")
# it is counted as. Insertion order is all P columns first, then all S columns.
phase_dict = {
    **dict.fromkeys(
        (
            "trace_p_arrival_sample",
            "trace_pP_arrival_sample",
            "trace_P_arrival_sample",
            "trace_P1_arrival_sample",
            "trace_Pg_arrival_sample",
            "trace_PG_arrival_sample",
            "trace_Pn_arrival_sample",
            "trace_PmP_arrival_sample",
            "trace_pwP_arrival_sample",
            "trace_pwPm_arrival_sample",
        ),
        "P",
    ),
    **dict.fromkeys(
        (
            "trace_s_arrival_sample",
            "trace_S_arrival_sample",
            "trace_S1_arrival_sample",
            "trace_Sg_arrival_sample",
            "trace_SG_arrival_sample",
            "trace_SmS_arrival_sample",
            "trace_Sn_arrival_sample",
        ),
        "S",
    ),
}

In [None]:
# Nominal sampling rate. Within this cell it is referenced only by the
# commented-out FixedWindow settings below.
sps = 100
# Augmentation pipeline that is later attached to the generator.
# Active steps, in order: Normalize -> ChangeDtype -> ProbabilisticLabeller.
# The commented-out entries are disabled alternatives kept for experimentation.
augmentations = [
    # Tapering(),
    # sbg.Filter(N=4,
    #         Wn=[1],
    #         btype='highpass',
    #         forward_backward=True,
    #         ),
    # Demean and peak-normalise along the last axis.
    sbg.Normalize(
        demean_axis=-1,
        amp_norm_axis=-1,
        amp_norm_type="peak"),
    # sbg.FixedWindow(
    #     p0=-15*sps,
    #     windowlen=1*60*sps,
    #     strategy="pad",
    #     key='X'),
    # sbg.WindowAroundSample(
    #     metadata_keys=list(phase_dict.keys()),
    #     samples_before=2000,
    #     windowlen=5000,
    #     selection="random",
    #     strategy="variable"),
    # sbg.GaussianNoise(
    #     scale=(0, 0.02),
    #     key='X'),
    # sbg.RandomWindow(
    #     windowlen=3001),
    sbg.ChangeDtype(np.float32),
    # Build label traces from the pick columns listed in phase_dict; the set
    # of output labels comes from the training configuration.
    # NOTE(review): sigma is presumably a label width in samples -- confirm
    # against the SeisBench documentation.
    sbg.ProbabilisticLabeller(
        label_columns=phase_dict,
        model_labels=cfg.training.hyperparameters.phases,
        sigma=30,
        dim=0),
]

In [None]:
# Open the waveform dataset with the path, sampling rate and component order
# taken from the main configuration.
dataset = sbd.WaveformDataset(
    path=cfg.dataset.path,
    sampling_rate=cfg.training.dataset.sampling_rate,
    component_order=cfg.training.dataset.component_order,
   # dimension_order=cfg.training.dataset.dimension_order # must recheck!
   )
# Disabled: drop one specific trace by name.
# dataset.filter(~(dataset.metadata['trace_name'] == "bucket2$268,:3,:3001").values, inplace=True)
# Generator over the dataset; indexing it yields samples run through `augmentations`.
generator = sbg.GenericGenerator(dataset)
generator.add_augmentations(augmentations)
# generator[0]

In [None]:
# df = pd.read_csv('F:\DataSets-Local\Merged_All_DataSets_2025-07-10 (Ahar-Ilam-Kaki-Qeshm)')
# df.shape

# Mark PS-Pairs

In [None]:
def _has_any_pick(metadata, phase):
    """Flag rows that hold at least one non-null arrival sample for ``phase``.

    A column is treated as a pick column when its upper-cased name starts with
    ``TRACE_<PHASE>`` and ends with ``_ARRIVAL_SAMPLE``.
    """
    prefix = f'trace_{phase}'.upper()
    suffix = '_arrival_sample'.upper()
    columns = [key for key in metadata.keys()
               if key.upper().startswith(prefix) and key.upper().endswith(suffix)]
    # With no matching columns this yields an all-False Series on the same index.
    return metadata[columns].notna().any(axis=1)


def find_ps_pairs(metadata, start_key=None, end_key=None):
    """Return a boolean Series marking rows whose P and S pick presence agree.

    Parameters
    ----------
    metadata : pandas.DataFrame
        Table with ``trace_<phase>_arrival_sample`` columns (matched
        case-insensitively).
    start_key, end_key : optional
        Not used by the function. They are kept, now with ``None`` defaults,
        so existing three-argument calls continue to work.

    Returns
    -------
    pandas.Series
        ``True`` where the row has at least one P pick and at least one S
        pick, and also where it has neither.

    NOTE(review): because the two presence flags are compared with ``==``,
    rows without any P or S pick are marked ``True`` as well. If only rows
    holding both phases are wanted, this should be ``&`` -- confirm the intent.
    """
    p_condition = _has_any_pick(metadata, 'P')
    s_condition = _has_any_pick(metadata, 'S')
    return s_condition == p_condition

In [None]:
# Work on a copy of the dataset's metadata table.
metadata = dataset.metadata.copy()
# NOTE: `df` is a second name for the same DataFrame object, not another copy;
# columns added through `df` therefore also appear in `metadata`.
df = metadata

In [None]:
# Flag rows whose P and S pick presence agree. This reuses find_ps_pairs
# rather than repeating its column-selection logic inline. Its two trailing
# parameters are not used by the function, so placeholders are passed.
ps_pairs_condition = find_ps_pairs(df, None, None)
# Stored as a column so the flag travels with the table.
df['PS-pairs'] = ps_pairs_condition
print(ps_pairs_condition.sum())

In [None]:
# Keep only the rows flagged in the 'PS-pairs' column.
# Selecting from `df` (the unfiltered table) keeps this cell re-runnable, and
# the explicit copy makes `metadata` an independent frame, so the per-cell
# writes done on it later do not raise SettingWithCopyWarning.
metadata = df[df['PS-pairs']].copy()

In [None]:
# Compute per-trace, per-component spectral and skewness features and write
# them into new `metadata` columns.
sps = dataset.sampling_rate
# NOTE(review): index labels of `metadata` are passed to get_sample as sample
# indices -- assumes labels equal dataset positions; confirm.
for ii in tqdm(metadata.index):
    # Only the waveform array is used; the second return value is discarded.
    data_3c, _ = dataset.get_sample(ii)
    for data, channel in zip(data_3c, dataset.component_order):
        ########################################################################
        # Spectrum of one component; delta is the sample spacing 1/sps.
        freq, ampl = srw.waveform.fft(array=data, delta=1/sps)
        # Split the spectrum into three bands by frequency value:
        # below 1, from 1 up to (not including) 20, and 20 or above
        # (presumably Hz -- confirm against the fft helper).
        fft_low = ampl[freq<1]
        fft_mid = ampl[(1<=freq) & (freq<20)]
        fft_hig = ampl[20<=freq]
        # Store the peak amplitude of each band, rounded to 3 decimals.
        # NOTE(review): .max() raises if a band is empty (e.g. no frequency
        # bins at or above 20 for a low sampling rate).
        metadata.at[ii, f'trace_{channel}_fft_max_L-band'] = fft_low.max().round(3)
        metadata.at[ii, f'trace_{channel}_fft_max_M-band'] = fft_mid.max().round(3)
        metadata.at[ii, f'trace_{channel}_fft_max_H-band'] = fft_hig.max().round(3)
        ########################################################################
        # Sample skewness of the component's amplitudes.
        metadata.at[ii, f'trace_{channel}_skewness'] = skew(data)
    # break

In [None]:
# A component counts as noisy when its mid-band spectral peak is smaller than
# its high-band peak. A trace is an outlier only if E, N and Z are all noisy.
noisy_by_channel = {}
for channel in 'ENZ':
    mid_peak = metadata[f'trace_{channel}_fft_max_M-band']
    high_peak = metadata[f'trace_{channel}_fft_max_H-band']
    noisy_by_channel[channel] = mid_peak < high_peak
noisy_e, noisy_n, noisy_z = (noisy_by_channel[code] for code in 'ENZ')

outlier_noisy = noisy_e & noisy_n & noisy_z
sum(outlier_noisy)

In [None]:
import matplotlib.pyplot as plt
from SeisRoutine.waveform.waveform import fft

In [None]:
# Plot every trace flagged by `outlier_noisy`: augmented waveforms on the
# left, their spectra on the right.
for ii in outlier_noisy[outlier_noisy].index:
    sample = generator[ii]
    data_3c = sample['X']
    fig, axes = plt.subplots(1, 2, figsize=(15, 4))
    # Offsets -1, 0, +1 keep the three components visually apart.
    for jj, (_x, channel) in enumerate(zip(data_3c, dataset.component_order),
                                       start=-1):
        # Use the dataset's sampling rate so the frequency axis stays correct
        # when the configured rate is not 100 Hz.
        freq, ampl = fft(array=_x, delta=1/dataset.sampling_rate)
        axes[0].plot(_x+jj, label=channel)
        axes[1].semilogx(freq, ampl, label=channel)
    axes[0].set(title=f'Noisy outlier, trace {ii}: waveforms',
                xlabel='Sample', ylabel='Amplitude + offset')
    axes[1].set(title=f'Noisy outlier, trace {ii}: spectra',
                xlabel='Frequency', ylabel='Amplitude')
    # Both panels carry labelled lines, so both get a legend.
    axes[0].legend()
    axes[1].legend()
    plt.show()
    # Release the figure; many figures are created in this loop.
    plt.close(fig)
    # break

In [None]:
# Count, per trace, how many of the E/N/Z skewness values exceed the threshold
# in absolute value, then flag the traces where more than one component does.
skewness_treshold = 3
keys = [f'trace_{cha}_skewness' for cha in "ENZ"]
exceeds_threshold = metadata[keys].abs().gt(skewness_treshold)
outlier = exceeds_threshold.sum(axis=1)
outlier.hist(log=False)
outlier_skewness = outlier.gt(1)

In [None]:
# Plot every trace flagged by `outlier_skewness`: augmented waveforms on the
# left, their spectra on the right.
for ii in outlier_skewness[outlier_skewness].index:
    sample = generator[ii]
    data_3c = sample['X']
    fig, axes = plt.subplots(1, 2, figsize=(15, 4))
    # Offsets -1, 0, +1 keep the three components visually apart.
    for jj, (_x, channel) in enumerate(zip(data_3c, dataset.component_order),
                                       start=-1):
        # Use the dataset's sampling rate so the frequency axis stays correct
        # when the configured rate is not 100 Hz.
        freq, ampl = fft(array=_x, delta=1/dataset.sampling_rate)
        axes[0].plot(_x+jj, label=channel)
        axes[1].semilogx(freq, ampl, label=channel)
    axes[0].set(title=f'Skewness outlier, trace {ii}: waveforms',
                xlabel='Sample', ylabel='Amplitude + offset')
    axes[1].set(title=f'Skewness outlier, trace {ii}: spectra',
                xlabel='Frequency', ylabel='Amplitude')
    # Both panels carry labelled lines, so both get a legend.
    axes[0].legend()
    axes[1].legend()
    plt.show()
    # Release the figure; many figures are created in this loop.
    plt.close(fig)
    # break

In [None]:
# Traces flagged by both the skewness criterion and the noisy-spectrum criterion.
outlier_all = (outlier_skewness & outlier_noisy)

In [None]:
# Plot every trace flagged by `outlier_all`: augmented waveforms on the left,
# their spectra on the right.
for ii in outlier_all[outlier_all].index:
    sample = generator[ii]
    data_3c = sample['X']
    fig, axes = plt.subplots(1, 2, figsize=(15, 4))
    # Offsets -1, 0, +1 keep the three components visually apart.
    for jj, (_x, channel) in enumerate(zip(data_3c, dataset.component_order),
                                       start=-1):
        # Use the dataset's sampling rate so the frequency axis stays correct
        # when the configured rate is not 100 Hz.
        freq, ampl = fft(array=_x, delta=1/dataset.sampling_rate)
        axes[0].plot(_x+jj, label=channel)
        axes[1].semilogx(freq, ampl, label=channel)
    axes[0].set(title=f'Combined outlier, trace {ii}: waveforms',
                xlabel='Sample', ylabel='Amplitude + offset')
    axes[1].set(title=f'Combined outlier, trace {ii}: spectra',
                xlabel='Frequency', ylabel='Amplitude')
    # Both panels carry labelled lines, so both get a legend.
    axes[0].legend()
    axes[1].legend()
    plt.show()
    # Release the figure; many figures are created in this loop.
    plt.close(fig)
    # break