In [17]:
from obspy.clients.fdsn import Client
import numpy as np
import obspy
import matplotlib.pyplot as plt
from obspy.clients.fdsn import Client
from datetime import datetime
import pandas as pd
import dask
from dask.diagnostics import ProgressBar

from obspy.clients.fdsn.client import Client
from pnwstore.mseed import WaveformClient
import torch
import numpy as np
from tqdm import tqdm
import time 
import pandas as pd
import gc
import seisbench.models as sbm
from ELEP.elep.ensemble_statistics import ensemble_statistics
from ELEP.elep.ensemble_coherence import ensemble_semblance 
from ELEP.elep.trigger_func import picks_summary_simple

In [18]:
device = torch.device("cpu")

In [19]:
# Define clients
client_inventory = Client('IRIS')
client_waveform = WaveformClient()

In [20]:
# Read in whatever you need to start - likely a list of station codes
station_list = ['J57A', 'J41A', 'M09B']

In [21]:
# Specify some parameters - you can change what you specify here vs. within the large function, this is just an example.
# Depending on whether the pertained models take a long time to load every time, you may want to load those outside the function and just feed them to the function rather than loading them every time in parallel.
twin = 6000     # length of time window
step = 3000     # step length
l_blnd, r_blnd = 500, 500
filepath = 'https://cascadia.ess.washington.edu/jhub/user/hbito/notebooks/elep-test/surface_events/src'

In [22]:
# download models
pretrain_list = ["pnw","ethz","instance","scedc","stead","geofon"]
pn_pnw_model = sbm.EQTransformer.from_pretrained('pnw')
pn_ethz_model = sbm.EQTransformer.from_pretrained("ethz")
pn_instance_model = sbm.EQTransformer.from_pretrained("instance")
pn_scedc_model = sbm.EQTransformer.from_pretrained("scedc")
pn_stead_model = sbm.EQTransformer.from_pretrained("stead")
pn_geofon_model = sbm.EQTransformer.from_pretrained("geofon")

In [23]:
# Define the function for stacking the segmented time windows after prediction
def stacking(data, npts, l_blnd, r_blnd):
    _data = data.copy()
    stack = np.full(npts, np.nan, dtype = np.float32)
    _data[:, :l_blnd] = np.nan; _data[:, -r_blnd:] = np.nan
    stack[:twin] = _data[0, :]
    for iseg in range(nseg-1):
        idx = step*(iseg+1)
        stack[idx:idx + twin] = \
                np.nanmax([stack[idx:idx + twin], _data[iseg+1, :]], axis = 0)
    return stack

In [44]:
# Write your function that you want to run in parallel: I recommend you design this to essentially perform your entire workflow on one station for one day, and write a csv file for that station, much the way you already have it.
# This is what will run in parallel!
# So, the only inputs are the station name, the start and end times you want to detect for, the path of the folder you want to write the results to, and the parameters you already specified. Here is where you could also feed in the preloaded models if that becomes important.
def run_detection(station,t1,t2,filepath,twin,step,l_blnd,r_blnd):
	# Load data
	# Reshape data
	# Predict on base models
	# Stack
	# Create and write csv file. Define file name using the station code and the input filepath
    
    # Get the inventory for the stations
    stations = station
    network = '7D'
    channels = '?H?'
    client = client_inventory
    inventory = client.get_stations(network=network, station=stations)
    
    # Get waveforms and filter
    sdata = client_waveform.get_waveforms(network="7D", station=station, channel="BH?", starttime=t1, year=t1.strftime('%Y'), month=t1.strftime('%m'), day=t1.strftime('%d'))
    sdata.filter(type='bandpass',freqmin=4,freqmax=15)

    # Get the necassary information about the station
    delta = sdata[0].stats.delta
    starttime = sdata[0].stats.starttime
    fs = sdata[0].stats.sampling_rate
    dt = 1/fs
    
    # Reshaping data
    sdata = np.array(sdata)
    npts = sdata.shape[1]
    nseg = int(np.ceil((npts - twin) / step)) + 1
    windows = np.zeros(shape=(nseg, 3, twin), dtype= np.float32)
    tap = 0.5 * (1 + np.cos(np.linspace(np.pi, 2 * np.pi, 6)))
    
    # Define the parameters for semblance
    paras_semblance = {'dt':dt, 'semblance_order':2, 'window_flag':True, 
                   'semblance_win':0.5, 'weight_flag':'max'}
    p_thrd, s_thrd = 0.05, 0.05

    windows_std = np.zeros(shape=(nseg, 3, twin), dtype= np.float32)
    windows_max = np.zeros(shape=(nseg, 3, twin), dtype= np.float32)
    windows = np.zeros(shape=(nseg, 3, twin), dtype= np.float32)
    windows_idx = np.zeros(nseg, dtype=np.int32)

    for iseg in range(nseg):
        idx = iseg * step
        windows[iseg, :] = sdata[:, idx:idx + twin]
        windows[iseg, :] -= np.mean(windows[iseg, :], axis=-1, keepdims=True)
        # original use std norm
        windows_std[iseg, :] = windows[iseg, :] / np.std(windows[iseg, :]) + 1e-10
        # others use max norm
        windows_max[iseg, :] = windows[iseg, :] / (np.max(np.abs(windows[iseg, :]), axis=-1, keepdims=True))
        windows_idx[iseg] = idx

    # taper
    windows_std[:, :, :6] *= tap; windows_std[:, :, -6:] *= tap[::-1]; 
    windows_max[:, :, :6] *= tap; windows_max[:, :, -6:] *= tap[::-1];
    del windows

#     print(f"Window data shape: {windows_std.shape}")
    
    # Predict on base models
    
    pretrain_list = ['original', 'ethz', 'instance', 'scedc', 'stead']

    # dim 0: 0 = P, 1 = S
    batch_pred = np.zeros([2, len(pretrain_list), nseg, twin], dtype = np.float32) 

    for ipre, pretrain in enumerate(pretrain_list):
        t0 = time.time()
        eqt = sbm.EQTransformer.from_pretrained(pretrain)
        eqt.to(device);
        eqt._annotate_args['overlap'] = ('Overlap between prediction windows in samples \
                                        (only for window prediction models)', step)
        eqt._annotate_args['blinding'] = ('Number of prediction samples to discard on \
                                         each side of each window prediction', (l_blnd, r_blnd))
        eqt.eval();
        if pretrain == 'original':
            # batch prediction through torch model
            windows_std_tt = torch.Tensor(windows_std)
            _torch_pred = eqt(windows_std_tt.to(device))
        else:
            windows_max_tt = torch.Tensor(windows_max)
            _torch_pred = eqt(windows_max_tt.to(device))
        batch_pred[0, ipre, :] = _torch_pred[1].detach().cpu().numpy()
        batch_pred[1, ipre, :] = _torch_pred[2].detach().cpu().numpy()

#         t1 = time.time()
#         print(f"picking using [{pretrain}] model: %.3f second" % (t1 - t0))

    # clean up memory
    del _torch_pred, windows_max_tt, windows_std_tt
    del windows_std, windows_max
    gc.collect()
    torch.cuda.empty_cache()

    print(f"All prediction shape: {batch_pred.shape}")
    
    pretrain_pred = np.zeros([2, len(pretrain_list), npts], dtype = np.float32)
    for ipre, pretrain in enumerate(pretrain_list):
        # 0 for P-wave
        pretrain_pred[0, ipre, :] = stacking(batch_pred[0, ipre, :], npts, l_blnd, r_blnd)

        # 1 for S-wave
        pretrain_pred[1, ipre, :] = stacking(batch_pred[1, ipre, :], npts, l_blnd, r_blnd)
    
    smb_pred = np.zeros([2, nseg, twin], dtype = np.float32)
    # calculate the semblance
    ## the semblance may takes a while bit to calculate
    for iseg in tqdm(range(nseg)):
        # 0 for P-wave
        smb_pred[0, iseg, :] = ensemble_semblance(batch_pred[0, :, iseg, :], paras_semblance)

        # 1 for P-wave
        smb_pred[1, iseg, :] = ensemble_semblance(batch_pred[1, :, iseg, :], paras_semblance)

    ## ... and stack
    # 0 for P-wave
    smb_p = stacking(smb_pred[0, :], npts, l_blnd, r_blnd)

    # 1 for P-wave
    smb_s = stacking(smb_pred[1, :], npts, l_blnd, r_blnd)

    # clean-up RAM
    del smb_pred, batch_pred

    p_index = picks_summary_simple(smb_p, p_thrd)
    s_index = picks_summary_simple(smb_s, s_thrd)
    print(f"{len(p_index)} P picks\n{len(s_index)} S picks")
    
    # Create lists and a data frame
    event_id = []
    source_type = []
    station_network_code = []
    station_channel_code = []
    station_code = []
    station_location_code = []
    station_latitude_deg= []
    station_longitude_deg = []
    station_elevation_m = []
    trace_name = []
    trace_sampling_rate_hz = []
    trace_start_time = []
    trace_S_arrival_sample = []
    trace_P_arrival_sample = []
    trace_S_onset = []
    trace_P_onset = []
    trace_snr_db = []
    trace_p_arrival = []
    trace_s_arrival = []

    for i, idx in enumerate(p_index):
        event_id.append(' ')
        source_type.append(' ')
        station_network_code.append('7D')
        station_channel_code.append(' ')
        station_code.append(station)
        station_location_code.append(station.stats.location)   
        station_latitude_deg.append(inventory[0][0].latitude)
        station_longitude_deg.append(inventory[0][0].longitude)   
        station_elevation_m.append(inventory[0][0].elevation)
        trace_name.append(' ')
        trace_sampling_rate_hz.append(station.stats.sampling_rate)
        trace_start_time.append(station.stats.starttime)
        trace_S_arrival_sample.append(' ')
        trace_P_arrival_sample.append(' ')
        trace_S_onset.append(' ')
        trace_P_onset.append(' ')
        trace_snr_db.append(' ')
        trace_s_arrival.append(np.nan)
        trace_p_arrival.append(str(starttime  + idx * delta))

    for i, idx in enumerate(s_index):
        event_id.append(' ')
        source_type.append(' ')
        station_network_code.append('7D')
        station_channel_code.append(' ')
        station_code.append(station)
        station_location_code.append(station.stats.location)   
        station_latitude_deg.append(inventory[0][0].latitude)
        station_longitude_deg.append(inventory[0][0].longitude)   
        station_elevation_m.append(inventory[0][0].elevation)
        trace_name.append(' ')
        trace_sampling_rate_hz.append(station.stats.sampling_rate)
        trace_start_time.append(station.stats.starttime)
        trace_S_arrival_sample.append(' ')
        trace_P_arrival_sample.append(' ')
        trace_S_onset.append(' ')
        trace_P_onset.append(' ')
        trace_snr_db.append(' ')
        trace_s_arrival.append(str(starttime  + idx * delta))
        trace_p_arrival.append(np.nan)

    # dictionary of lists
    dict = {'event_id':event_id,'source_type':source_type,'station_network_code':station_network_code,\
            'station_channel_code':station_channel_code,'station_code':station_code,'station_location_code':station_location_code,\
            'station_latitude_deg':station_latitude_deg,'station_longitude_deg':station_longitude_deg, \
            'station_elevation_m':station_elevation_m,'trace_name':trace_name,'trace_sampling_rate_hz':trace_sampling_rate_hz,\
            'trace_start_time':trace_start_time,'trace_S_arrival_sample':trace_S_arrival_sample,\
            'trace_P_arrival_sample':trace_P_arrival_sample, 'trace_S_onset':trace_S_onset,'trace_P_onset':trace_P_onset,\
            'trace_snr_db':trace_snr_db, 'trace_s_arrival':trace_s_arrival, 'trace_p_arrival':trace_p_arrival}

    df = pd.DataFrame(dict)

    # Make the specific day into a string:
    tstring = t1.strftime('%Y%m%d')
    # Build the full file name:
    file_name = file_path+station+'_'+tstring+'.csv'
    # Write to file using that name
    df.to_csv(file_name)

In [26]:
# Now create your list of days to loop over!
t1 = datetime(2012,10,1)
t2 = datetime(2012,10,5)
time_bins = pd.to_datetime(np.arange(t1,t2,pd.Timedelta(1,'days')))


In [27]:
time_bins[0]

Timestamp('2012-10-01 00:00:00')

In [28]:
# Combine that list of days with the list of stations
# We are essentially creating a list of the number of tasks we have to do with the information that is unique to each task; we will do them in parallel
task_list = []
for sta in station_list:
	for t in time_bins:
		task_list.append([sta,t])

In [29]:
arr = np.array(task_list)
arr.shape

(12, 2)

In [34]:
task_list

[['J57A', Timestamp('2012-10-01 00:00:00')],
 ['J57A', Timestamp('2012-10-02 00:00:00')],
 ['J57A', Timestamp('2012-10-03 00:00:00')],
 ['J57A', Timestamp('2012-10-04 00:00:00')],
 ['J41A', Timestamp('2012-10-01 00:00:00')],
 ['J41A', Timestamp('2012-10-02 00:00:00')],
 ['J41A', Timestamp('2012-10-03 00:00:00')],
 ['J41A', Timestamp('2012-10-04 00:00:00')],
 ['M09B', Timestamp('2012-10-01 00:00:00')],
 ['M09B', Timestamp('2012-10-02 00:00:00')],
 ['M09B', Timestamp('2012-10-03 00:00:00')],
 ['M09B', Timestamp('2012-10-04 00:00:00')]]

In [45]:
# Now we start setting up a parallel operation using a package called Dask.

# Start by writing a new a function that is specifically designed to be run in parallel through dask. All it essentially does is define the inputs to the larger run_detection function and then runs the function itself, but because we "decorate" it with @dask.delayed to start, the code will recognize that it should be run in parallel.

@dask.delayed
def loop_tasks(task,filepath,twin,step,l_blnd,r_blnd):

	# Define the parameters that are specific to each task
	t1 = obspy.UTCDateTime(task[1])
	t2 = obspy.UTCDateTime(t1 + pd.Timedelta(1,'days'))
	station = task[0]

	# Call to the function that will perform the operation and write the results to file
	run_detection(station,t1,t2,filepath,twin,step,l_blnd,r_blnd)
	

# Now we set up the parallel operation
# The below builds a framework for the computer to run in parallel. This doesn't actually execute anything.
lazy_results = [loop_tasks(task,filepath,twin,step,l_blnd,r_blnd) for task in task_list]
    

# The below actually executes the parallel operation!
# It's nice to do it with the ProgressBar so you can see how long things are taking.
# Each operation should also write a file so that is another way to check on progress.
with ProgressBar():
	dask.compute(lazy_results)

[                                        ] | 0% Completed | 202.12 ms


ProgrammingError: SQLite objects created in a thread can only be used in that same thread. The object was created in thread id 140062026457856 and this is thread id 140049387599616.