In [None]:
import sys
# Candidate locations of the SeisRoutine checkout that are added to the
# module search path.
# NOTE(review): these are machine-specific absolute paths; consider installing
# the package or reading its location from configuration instead.
lib_path = [r'C:\Users\ikahbasi\OneDrive\Applications\GitHub\SeisRoutine',
            r'C:\Users\ikahb\OneDrive\Applications\GitHub\SeisRoutine']
for path in lib_path:
    # Skip entries that are already registered so that re-running this cell
    # does not keep growing sys.path with duplicates.
    if path not in sys.path:
        sys.path.append(path)
##########################################################################
import SeisRoutine.catalog as src
import SeisRoutine.waveform as srw
import SeisRoutine.config as srconf
import SeisRoutine.statistics as srs

In [None]:
import seisbench.generate as sbg
import seisbench.models as sbm
import torch
from tqdm import tqdm
from scipy import signal
import os
import seisbench.data as sbd
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import label
import pandas as pd

# Classes and Functions

In [None]:
class Tapering:
    """Augmentation step that multiplies a waveform by a Tukey window.

    Parameters
    ----------
    alpha : float
        Shape parameter passed to ``scipy.signal.windows.tukey``.
    key : str or sequence
        Where to read from and write to in the state dict. A single string
        is used for both; otherwise ``key[0]`` is the input key and
        ``key[1]`` the output key.
    """

    def __init__(self, alpha=0.3, key='X'):
        self.alpha = alpha
        # Normalise to an (input_key, output_key) pair.
        self.key = (key, key) if isinstance(key, str) else key

    def __call__(self, state_dict):
        """Taper the waveform stored in ``state_dict`` along its last axis.

        The entry under the input key is unpacked as a
        ``(waveform, metadata)`` pair. The tapered waveform, together with
        the unchanged metadata, is stored under the output key. Nothing is
        returned; ``state_dict`` is updated instead.
        """
        source_key = self.key[0]
        waveform, metadata = state_dict[source_key]
        n_samples = waveform.shape[-1]
        window = signal.windows.tukey(n_samples, self.alpha)
        target_key = self.key[1]
        state_dict[target_key] = (waveform * window, metadata)

In [None]:
def find_peaks(data, treshold):
    """Return the position of the maximum of every run above a threshold.

    Consecutive samples that are strictly greater than ``treshold`` form one
    segment; for each segment the index of its largest value is reported
    (the first one if the maximum is repeated).

    Parameters
    ----------
    data : array-like, 1-D
        Sequence of values to scan, e.g. a probability trace.
    treshold : float
        Samples must exceed this value to belong to a segment.

    Returns
    -------
    list of int
        One index into ``data`` per segment, in ascending order. Empty when
        no sample exceeds the threshold.

    Raises
    ------
    ValueError
        If ``data`` is not one-dimensional.
    """
    # Imported here because only `label` is imported at file level.
    from scipy.ndimage import find_objects

    data = np.asarray(data)  # also accept plain lists / tuples
    if data.ndim != 1:
        raise ValueError(
            f'find_peaks expects 1-D data, got an array of shape {data.shape}')

    labeled, _ = label(data > treshold)
    peaks = []
    # find_objects yields one bounding slice per labelled segment, which
    # avoids rescanning the whole array for every segment.
    for (segment,) in find_objects(labeled):
        offset = np.argmax(data[segment])
        peaks.append(int(segment.start + offset))
    return peaks

# Loading DataSets

In [None]:
# Two-stage configuration: the bootstrap file names the directory and file
# of the configuration that is actually used by the rest of the notebook.
init_cfg = srconf.load_config('0-init-cfg.yml')
cfg_dir, cfg_name = (init_cfg.target_config_filepath,
                     init_cfg.target_config_filename)
cfg_path = os.path.join(cfg_dir, cfg_name)
cfg = srconf.load_config(cfg_path)

In [None]:
# Open the waveform dataset; its location, sampling rate and component order
# are all taken from the loaded configuration.
dataset_cfg = cfg.training.dataset
dataset = sbd.WaveformDataset(path=cfg.dataset.path,
                              sampling_rate=dataset_cfg.sampling_rate,
                              component_order=dataset_cfg.component_order)

In [None]:
# Presumably samples per second; only referenced by the disabled FixedWindow
# step noted below.
sps = 100

# Active preprocessing steps, applied in list order.
highpass = sbg.Filter(N=4, Wn=[0.5], btype='highpass', forward_backward=True)
peak_normalize = sbg.Normalize(demean_axis=-1,
                               amp_norm_axis=-1,
                               amp_norm_type="peak")
to_float32 = sbg.ChangeDtype(np.float32)

# Disabled steps, kept for reference:
#   Tapering()                                   (before the filter)
#   sbg.FixedWindow(p0=-15*sps,
#                   windowlen=1*60*sps,
#                   strategy="pad",
#                   key='X')                     (before ChangeDtype)
augmentations = [highpass, peak_normalize, to_float32]

generator = sbg.GenericGenerator(dataset)
generator.add_augmentations(augmentations)

# Loading Auto-Pickers

In [None]:
# Use the GPU when one is available and fall back to the CPU otherwise, so
# the notebook also runs on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# `from_pretrained` is called on the class directly, which avoids building a
# throw-away PhaseNet instance first.
# NOTE(review): the code further down reads output channel 1 as P and channel
# 2 as S, i.e. it assumes the pretrained weights use the 'NPS' order — confirm.
models = {'phasenet-original': sbm.PhaseNet.from_pretrained('original')}
for key, model in models.items():
    model.to(device)
    # torch modules are in training mode unless told otherwise; switch to
    # inference behaviour explicitly because the model is called directly.
    model.eval()

# Run

In [None]:
# Work on a copy of the dataset's metadata table; the auto-pick columns are
# added to this copy further down in the notebook.
metadata = dataset.metadata.copy()

In [None]:
# Sanity check: run every sample through the generator and report the ones
# whose 'X' array does not have the expected shape.
expected_shape = (3, 3001)
lst = []            # shape of every sample, in dataset order
shape_counts = {}   # shape -> number of samples with that shape
for ii in tqdm(range(len(metadata))):
    data = generator[ii]
    data_X = data['X']
    lst.append(data_X.shape)
    shape_counts[data_X.shape] = shape_counts.get(data_X.shape, 0) + 1
for index, el in enumerate(lst):
    if el != expected_shape:
        # Look the count up instead of calling lst.count(el) here, which
        # rescans the whole list for every mismatching sample.
        print(f'{el=}\t{index=}\tlst.count(el)={shape_counts[el]}')

In [None]:
# Run every model on every sample and store the picked sample indices in the
# metadata table, then save the table next to the dataset.
treshold = 0.3                # prediction value a sample must exceed to be picked
expected_shape = (3, 3001)    # samples with any other shape are skipped
# Output channel read for each phase.
# NOTE(review): assumes the model output order is (noise, P, S) — confirm.
phase_channels = {'P': 1, 'S': 2}

for key, model in models.items():
    # The models are called directly below, so make sure they are in
    # inference mode (torch modules are in training mode unless told otherwise).
    model.eval()
    for phase in phase_channels:
        column = f'trace_{phase}_{key}-AutoPik'
        # Object dtype so that each cell can hold a list of pick indices;
        # rows that are skipped keep None.
        metadata[column] = None
        metadata[column] = metadata[column].astype(object)

for ii in tqdm(range(len(metadata))):
    data = generator[ii]
    data_X = data['X']
    if data_X.shape != expected_shape:
        continue
    # `ii` is a position; translate it to the row label expected by `.at`
    # so the assignment is also correct when the index is not 0..n-1.
    row = metadata.index[ii]
    for key, model in models.items():
        with torch.no_grad():
            # Build the batch in its own variable: overwriting data_X here
            # would hand an already-batched tensor to the next model.
            batch = torch.tensor(data_X, device=model.device).unsqueeze(0)
            pred = model(batch)
            pred = pred.cpu().detach().numpy().squeeze()
        for phase, channel in phase_channels.items():
            peaks = find_peaks(pred[channel], treshold=treshold)
            metadata.at[row, f'trace_{phase}_{key}-AutoPik'] = peaks

path = os.path.join(cfg.dataset.path, 'metadata-with-AutoPicks.pkl')
metadata.to_pickle(path)