In [3]:
import datetime
import gc
import logging
import os
import time
from datetime import timedelta

import dask
import matplotlib.pyplot as plt
import numpy as np
import obspy
import pandas as pd
import seisbench.models as sbm
import torch
from dask.diagnostics import ProgressBar
from obspy import Stream
from obspy.clients.fdsn import Client
from obspy.clients.fdsn.client import Client
from obspy.core.utcdatetime import UTCDateTime
from obspy.geodetics import degrees2kilometers, locations2degrees
from pnwstore.mseed import WaveformClient
from tqdm import tqdm

from ELEP.elep.ensemble_coherence import ensemble_semblance
from ELEP.elep.ensemble_statistics import ensemble_statistics
from ELEP.elep.trigger_func import picks_summary_simple

In [4]:
# All models in this notebook are evaluated on the CPU.
device = torch.device("cpu")

# Data sources.
#   client_inventory : FDSN client for 'IRIS'
#   client_waveform  : pnwstore WaveformClient
#   client_ncedc     : FDSN client for 'NCEDC' (run_detection uses it for networks NC and BK)
client_inventory = Client('IRIS')
client_waveform = WaveformClient()
client_ncedc = Client('NCEDC')

# Windowing parameters.
twin, step = 6000, 3000   # length of time window, step length
l_blnd = r_blnd = 500     # prediction samples discarded on the left / right side of each window

In [None]:
def plot_pred(batch_pred, iwin=None):
    """Plot the P-phase prediction of every pretrained model for one window.

    Parameters
    ----------
    batch_pred : numpy.ndarray
        Indexed here as ``batch_pred[0, ipre, iwin, :]``; index 0 of the first
        axis is plotted as the P-wave prediction, ``ipre`` runs over the
        pretrained models and ``iwin`` over the windows.
    iwin : int, optional
        Window to plot. When omitted, a random window is drawn from the first
        100 windows (capped at the number of windows actually available).

    Notes
    -----
    The function reads the notebook-level names ``metadata``,
    ``pretrain_peak``, ``p_thrd`` and ``twin``; they must exist in the kernel
    before it is called.
    NOTE(review): the division by 100 in the legend presumably converts
    samples to seconds at 100 Hz -- confirm against the data in use.
    """
    pretrain_list = ['original', 'ethz', 'instance', 'scedc', 'stead']

    if iwin is None:
        # Never draw an index beyond the windows that exist in batch_pred.
        iwin = np.random.randint(min(100, np.shape(batch_pred)[2]))
    cat_p = metadata.loc[iwin, 'trace_P_arrival_sample']

    # A single figure; the original opened an extra, empty one first.
    plt.figure(figsize = (10, 8))
    plt.subplots_adjust(hspace = 0.1)
    plt.subplot(3, 1, 1)
    plt.title(f"P-wave (index {iwin})")
    for ipre, pretrain in enumerate(pretrain_list):
        if not np.isnan(pretrain_peak[0, iwin, ipre]):
            # Legend shows the offset between this model's peak and the catalog P arrival.
            plt.plot(batch_pred[0, ipre, iwin, :],
                 label = f"{pretrain} ($\\Delta t = ${np.abs(pretrain_peak[0, iwin, ipre] - cat_p)/100} sec)")
        else:
            # No peak recorded for this model in this window.
            plt.plot(batch_pred[0, ipre, iwin, :],
                 label = f"{pretrain} (N/A)")
    plt.ylabel("Prediction", fontsize = 15)
    plt.ylim([0, 1]); plt.xlim([0, twin]); plt.xticks([])
    # Dashed line marks the P threshold.
    plt.hlines(p_thrd, 0, twin, linestyle = '--', color = 'k')
    plt.legend(ncols = 2, loc = 'upper right')

In [None]:
def plot_waveforms(idx, mycatalog, mycatalog_picks, network,
                   outfile="event_offshore_OR_Z_010.pdf"):
    """Plot a vertical-component record section (moveout) for one catalog event.

    Stations that have picks are located through ``client_inventory``, sorted
    by epicentral distance, and the 11 closest ones are drawn at their
    distance with the first P (red) and S (blue) pick overlaid.

    Parameters
    ----------
    idx : int
        Positional row of the event inside ``mycatalog``.
    mycatalog : pandas.DataFrame
        Event table; the columns 'datetime', 'latitude' and 'longitude' are read.
    mycatalog_picks : pandas.DataFrame
        Pick table; the columns 'station', 'phase' and 'time_pick' are read.
    network : str
        Network code used for the station-metadata query.
    outfile : str, optional
        Path of the PDF written by ``plt.savefig``; defaults to the file name
        that used to be hard-coded.

    Notes
    -----
    Uses the notebook-level clients ``client_inventory`` (station metadata)
    and ``client_waveform`` (waveforms).
    NOTE(review): the original referred to ``client2`` / ``client``, which no
    visible cell defines; the mapping to the two clients above is assumed
    from the kind of request each one serves -- confirm.
    """
    event = mycatalog
    picks = mycatalog_picks

    # Origin time and epicentre all come from the same catalog row.
    otime = UTCDateTime(event["datetime"].iloc[idx])
    olat = event["latitude"].iloc[idx]
    olon = event["longitude"].iloc[idx]

    # Epicentral distance (km) of every station that has at least one pick.
    distances = []
    for station in np.unique(picks['station']):
        try:
            sta_inv = client_inventory.get_stations(network=network,
                                                    station=station, channel="?H?",
                                                    starttime=otime - 1e4, endtime=otime + 1e4)[0][0]
        except Exception as e:
            # Best effort: a station without metadata is skipped, not fatal.
            print(f"Failed to fetch for {network} {station} {otime}: {e}")
            continue

        dist_deg = locations2degrees(olat, olon, sta_inv.latitude, sta_inv.longitude)
        distances.append((station, degrees2kilometers(dist_deg)))

    # Keep the 11 closest stations.
    distances = sorted(distances, key=lambda item: item[-1])[:11]

    plt.figure(dpi=150)
    for station, dist in distances:
        st = client_waveform.get_waveforms(network="*",
                                           station=station, channel="?HZ",
                                           starttime=otime-30, endtime=otime+120)
        # Discard traces sampled at 10 Hz or less.
        st = obspy.Stream(filter(lambda tr: tr.stats.sampling_rate > 10, st))
        if len(st) == 0:
            continue

        st.filter(type='bandpass', freqmin=4, freqmax=15)
        st.normalize()
        trace = st[0]

        # Scale by the largest positive sample and draw the trace at its distance.
        wave = trace.data / np.nanmax(trace.data, axis=-1, keepdims=True)
        plt.plot(trace.times(), wave * 30 + dist,
                 color='black', alpha=0.7, lw=0.5)
        # Station label just left of the trace start (a plain string, not a list).
        plt.text(trace.times()[0] - 5, trace.data[0] * 10 + dist - 2,
                 station, fontsize=8, verticalalignment='bottom')

        sta_picks = picks.loc[picks['station'] == station]
        p_picks = sta_picks.loc[sta_picks['phase'] == 'P']
        s_picks = sta_picks.loc[sta_picks['phase'] == 'S']

        # Traces start 30 s before the origin time, hence the +30 shift.
        if len(p_picks) > 0:
            plt.vlines(UTCDateTime(p_picks.iloc[0]['time_pick']) - otime + 30,
                       dist - 5, dist + 5, color='r')
        if len(s_picks) > 0:
            plt.vlines(UTCDateTime(s_picks.iloc[0]['time_pick']) - otime + 30,
                       dist - 5, dist + 5, color='b')

    plt.title(f"Event Offshore Oregon (Z component): Origin Time={otime}, Latitude={round(olat,2)}, Longtitude={round(olon,2)}")
    plt.xlabel('Time [sec]')
    plt.ylabel('Distance [km]')
    plt.ylim(0, 420)
    plt.xlim(20, 150)

    plt.grid(alpha=0.5)

    plt.savefig(outfile, format="pdf", bbox_inches="tight")
    plt.show()

In [None]:
def _make_windows(arr_sdata, twin, step, eps=1e-10):
    """Cut continuous data into overlapping, demeaned, normalised, tapered windows.

    Parameters
    ----------
    arr_sdata : numpy.ndarray
        2-D array of shape (n_channels, npts).
    twin, step : int
        Window length and hop between consecutive windows, in samples.
    eps : float, optional
        Added to every normalisation denominator so a flat window gives
        zeros instead of inf/NaN.

    Returns
    -------
    windows_std : numpy.ndarray, float32, shape (nseg, 3, twin)
        Windows divided by the standard deviation of the whole window.
    windows_max : numpy.ndarray, float32, shape (nseg, 3, twin)
        Windows divided by the per-channel absolute maximum.
    windows_idx : numpy.ndarray, int32, shape (nseg,)
        Start sample of every window.
    """
    npts = arr_sdata.shape[1]
    # floor (not ceil) so the last window never runs past the end of the data
    nseg = max(int(np.floor((npts - twin) / step)) + 1, 0)

    windows_std = np.zeros(shape=(nseg, 3, twin), dtype=np.float32)
    windows_max = np.zeros(shape=(nseg, 3, twin), dtype=np.float32)
    windows_idx = np.zeros(nseg, dtype=np.int32)

    for iseg in range(nseg):
        idx = iseg * step
        win = np.asarray(arr_sdata[:, idx:idx + twin], dtype=np.float32)
        win = win - np.mean(win, axis=-1, keepdims=True)
        # eps belongs in the denominator; the original added it to the quotient.
        windows_std[iseg, :] = win / (np.std(win) + eps)
        windows_max[iseg, :] = win / (np.max(np.abs(win), axis=-1, keepdims=True) + eps)
        windows_idx[iseg] = idx

    # 6-sample cosine taper on both ends of every window
    tap = 0.5 * (1 + np.cos(np.linspace(np.pi, 2 * np.pi, 6)))
    windows_std[:, :, :6] *= tap; windows_std[:, :, -6:] *= tap[::-1]
    windows_max[:, :, :6] *= tap; windows_max[:, :, -6:] *= tap[::-1]

    return windows_std, windows_max, windows_idx


def run_detection(network,station,t1,t2,filepath,twin,step,l_blnd,r_blnd,lat,lon,elev):
    """Run the five pretrained EQTransformer models on one station and time span.

    Waveforms are fetched, band-passed (4-15 Hz), merged, cut into
    overlapping windows and passed through every pretrained model. The P
    predictions of the strongest window are plotted as a quick diagnostic.

    Parameters
    ----------
    network, station : str
        Codes of the station to process; networks 'NC' and 'BK' are requested
        from ``client_ncedc``, all others from ``client_waveform``.
    t1, t2 : datetime-like
        Start and end of the requested data; ``t1`` also provides the
        YYYYMMDD part of the output file name.
    filepath : str
        Prefix of the result file ``filepath + station + '_' + date + '.csv'``.
        If that file exists the station/day is skipped. The active code does
        not write the file; that stage is still to be implemented.
    twin, step : int
        Window length and hop, in samples.
    l_blnd, r_blnd : int
        Samples to discard on the left / right of each window prediction.
    lat, lon, elev :
        Not used by the active code.

    Returns
    -------
    None

    Notes
    -----
    Uses the notebook-level names ``client_ncedc``, ``client_waveform`` and
    ``device``.
    NOTE(review): ``twin``/``step`` are sample counts and the data are not
    resampled here; confirm that every station is delivered at the sampling
    rate the models expect.
    """
    tstring = t1.strftime('%Y%m%d')
    outfile = filepath+station+'_'+tstring+'.csv'

    if os.path.exists(outfile):
        print('File '+outfile+' already exists')
        return
    print(outfile)

    channels = '?H?'

    # Get waveforms
    try:
        if network in ['NC', 'BK']:
            _sdata = client_ncedc.get_waveforms(network=network, station=station, location="*", channel=channels,
                                                starttime=UTCDateTime(t1), endtime=UTCDateTime(t2))
        else:
            _sdata = client_waveform.get_waveforms(network=network, station=station, channel=channels,
                                                   starttime=UTCDateTime(t1), endtime=UTCDateTime(t2))
    except obspy.clients.fdsn.header.FDSNNoDataException:
        print(f"WARNING: No data for {network}.{station}.{channels} on {t1}.")
        return

    # Prefer HH channels; fall back to BH only when no HH channel is present.
    sdata = Stream()
    for band in ("HH?", "BH?"):
        selected = _sdata.select(channel=band)
        if len(selected) > 0:
            sdata += selected
            break

    # If no data returned, skipping
    if len(sdata) == 0:
        logging.warning("No stream returned. Skipping.")
        return

    sdata.filter(type='bandpass',freqmin=4,freqmax=15)
    sdata.merge(fill_value='interpolate')  # fill gaps if there are any.

    # Get the necessary information about the station
    fs = sdata[0].stats.sampling_rate
    dt = 1/fs

    # Make all the traces in the stream have the same lengths
    max_starttime = max(tr.stats.starttime for tr in sdata)
    min_endtime = min(tr.stats.endtime for tr in sdata)
    for tr in sdata:
        tr.trim(starttime=max_starttime, endtime=min_endtime, nearest_sample=True)

    # Reshaping data
    arr_sdata = np.array(sdata)
    npts = arr_sdata.shape[1]
    if npts < twin:
        # Not even one full window: nothing to predict on.
        logging.warning(f"Only {npts} samples for {network}.{station} on {t1}; "
                        f"need at least {twin}. Skipping.")
        return

    # Kept for the semblance / picking stage that is currently commented out
    # after this function.
    paras_semblance = {'dt':dt, 'semblance_order':2, 'window_flag':True,
                       'semblance_win':0.5, 'weight_flag':'max'}
    p_thrd, s_thrd = 0.05, 0.05

    windows_std, windows_max, windows_idx = _make_windows(arr_sdata, twin, step)
    nseg = windows_std.shape[0]
    print(f"Window data shape: {windows_std.shape}")

    # Predict on base models
    pretrain_list = ['original', 'ethz', 'instance', 'scedc', 'stead']

    # Build the input tensors once instead of once per model.
    windows_std_tt = torch.from_numpy(windows_std)
    windows_max_tt = torch.from_numpy(windows_max)

    # dim 0: 0 = P, 1 = S
    batch_pred = np.zeros([2, len(pretrain_list), nseg, twin], dtype = np.float32)
    for ipre, pretrain in enumerate(pretrain_list):
        eqt = sbm.EQTransformer.from_pretrained(pretrain)
        eqt.to(device)
        eqt._annotate_args['overlap'] = ('Overlap between prediction windows in samples \
                                        (only for window prediction models)', step)
        eqt._annotate_args['blinding'] = ('Number of prediction samples to discard on \
                                         each side of each window prediction', (l_blnd, r_blnd))
        eqt.eval()
        # 'original' takes std-normalised input, the other weights max-normalised input.
        model_input = windows_std_tt if pretrain == 'original' else windows_max_tt
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            _torch_pred = eqt(model_input.to(device))
        batch_pred[0, ipre, :] = _torch_pred[1].detach().cpu().numpy()
        batch_pred[1, ipre, :] = _torch_pred[2].detach().cpu().numpy()

    # Diagnostic figure: P prediction of every model for a single window.
    # NOTE(review): the original loop here was not valid Python and used
    # `d`, `iwin` and `metadata`, none of which exist in this function (there
    # is no catalog in this workflow). Showing the window with the highest P
    # prediction is an assumption about the intent -- confirm.
    iwin = int(np.argmax(batch_pred[0].max(axis=(0, 2))))
    plt.figure(figsize = (10, 3))
    for ipre, pretrain in enumerate(pretrain_list):
        plt.plot(batch_pred[0, ipre, iwin, :], label = pretrain)
    plt.title(f"P-wave (index {iwin})")
    plt.legend(loc = 'upper right')
    plt.xlabel("Sample", fontsize = 15)
    plt.ylabel("Prediction", fontsize = 15)
    plt.xlim([0, twin])
    plt.ylim([0, 1])

#     # clean up memory
#     del _torch_pred, windows_max_tt, windows_std_tt
#     del windows_std, windows_max
#     gc.collect()
#     torch.cuda.empty_cache()

#     print(f"All prediction shape: {batch_pred.shape}")
    
#     ####################### You don't need this
# #     pretrain_pred = np.zeros([2, len(pretrain_list), npts], dtype = np.float32)
# #     for ipre, pretrain in enumerate(pretrain_list):
# #        # 0 for P-wave
# #         pretrain_pred[0, ipre, :] = stacking(batch_pred[0, ipre, :], npts, l_blnd, r_blnd)
# # 
# #        # 1 for S-wave
# #        pretrain_pred[1, ipre, :] = stacking(batch_pred[1, ipre, :], npts, l_blnd, r_blnd)
#     ####################### You don't need this
    
#     smb_pred = np.zeros([2, nseg, twin], dtype = np.float32)
#     # calculate the semblance
#     ## the semblance may take a little while to calculate
    
#     ############################# remove tqdm (extra progress bar)
# #     for iseg in tqdm(range(nseg)):
#     for iseg in range(nseg):
#     #############################
#         # 0 for P-wave
#         smb_pred[0, iseg, :] = ensemble_semblance(batch_pred[0, :, iseg, :], paras_semblance)

#         # 1 for S-wave
#         smb_pred[1, iseg, :] = ensemble_semblance(batch_pred[1, :, iseg, :], paras_semblance)

#     ## ... and stack
#     # 0 for P-wave
#     ####################### add a nseg argument here
#     #smb_p = stacking(smb_pred[0, :], npts, l_blnd, r_blnd)
#     smb_p = stacking(smb_pred[0, :], npts, l_blnd, r_blnd, nseg)

#     # 1 for S-wave
#     #smb_s = stacking(smb_pred[1, :], npts, l_blnd, r_blnd)
#     smb_s = stacking(smb_pred[1, :], npts, l_blnd, r_blnd, nseg)
#     #######################
#     # clean-up RAM
#     del smb_pred, batch_pred

#     p_index = picks_summary_simple(smb_p, p_thrd)
#     s_index = picks_summary_simple(smb_s, s_thrd)
#     print(f"{len(p_index)} P picks\n{len(s_index)} S picks")
    