## Notebook to build a function
This notebook is for making a function that takes in the lists of stations and days and returns CSV files
Reference:
- Marine's notebook: https://github.com/Denolle-Lab/surface_events/blob/marine_cleanup_branch/src/SU_seismic_infrasound_MtRainier.ipynb 
- Yiyu's notebook: https://github.com/congcy/ELEP/blob/main/docs/tutorials/example_BB_continuous_data_PB_B204.ipynb

In [168]:
from obspy.clients.fdsn.client import Client
from pnwstore.mseed import WaveformClient
import torch
import numpy as np
from tqdm import tqdm
import time 
import pandas as pd
import gc
import seisbench.models as sbm
from ELEP.elep.ensemble_statistics import ensemble_statistics
from ELEP.elep.ensemble_coherence import ensemble_semblance 
from ELEP.elep.trigger_func import picks_summary_simple

In [169]:
device = torch.device("cpu")

# 1. Set up the job

* Make a list of stations
* make a list pf days
* set up parallel job using Dask (ask Zoe&Yiyu)

In [177]:
# Make a list of stations
list_sta = ['J57A']

In [178]:
# Make a list of days
list_days = []

In [179]:
twin = 6000     # length of time window
step = 3000     # step length
l_blnd, r_blnd = 500, 500

## 2. Load data

In [180]:
# Define clients
client_inventory = Client('IRIS')
client_waveform = WaveformClient()

In [181]:
# Get the inventory for the stations
stations = list_sta[0]
network = '7D'
channels = '?H?'
client = client_inventory
inventory = client.get_stations(network=network, station=stations)

In [182]:
# Get waveforms and filter
s_J57A = client_waveform.get_waveforms(network="7D", station='J57A', channel="?H?", year=2012, month=7, day=10)
s_J57A.filter(type='bandpass',freqmin=4,freqmax=15)
s_J57A

3 Trace(s) in Stream:
7D.J57A..BH1 | 2012-07-10T00:00:00.010700Z - 2012-07-10T23:59:59.990700Z | 50.0 Hz, 4320000 samples
7D.J57A..BH2 | 2012-07-10T00:00:00.010700Z - 2012-07-10T23:59:59.990700Z | 50.0 Hz, 4320000 samples
7D.J57A..BHZ | 2012-07-10T00:00:00.010700Z - 2012-07-10T23:59:59.990700Z | 50.0 Hz, 4320000 samples

In [183]:
# Get the necassary information about the station
delta = s_J57A[0].stats.delta
starttime = s_J57A[0].stats.starttime
fs = s_J57A[0].stats.sampling_rate
dt = 1/fs

In [184]:
# download models
pretrain_list = ["pnw","ethz","instance","scedc","stead","geofon"]
pn_pnw_model = sbm.EQTransformer.from_pretrained('pnw')
pn_ethz_model = sbm.EQTransformer.from_pretrained("ethz")
pn_instance_model = sbm.EQTransformer.from_pretrained("instance")
pn_scedc_model = sbm.EQTransformer.from_pretrained("scedc")
pn_stead_model = sbm.EQTransformer.from_pretrained("stead")
pn_geofon_model = sbm.EQTransformer.from_pretrained("geofon")

## Reshaping data 
We cut the continuous data into small time windows with overlap. Then, we pre-process the stream and the windowed waveform: switch channel order, demean, normalization, and taper.

Note that `original` pretrain model use std normalization, while others use maximum normalization. Thus, we create two different windowed data: `windows_std` and `windows_max`.

In [185]:

sdata = np.array(s_J57A)
npts = sdata.shape[1]
nseg = int(np.ceil((npts - twin) / step)) + 1
windows = np.zeros(shape=(nseg, 3, twin), dtype= np.float32)
tap = 0.5 * (1 + np.cos(np.linspace(np.pi, 2 * np.pi, 6)))


windows_std = np.zeros(shape=(nseg, 3, twin), dtype= np.float32)
windows_max = np.zeros(shape=(nseg, 3, twin), dtype= np.float32)
windows = np.zeros(shape=(nseg, 3, twin), dtype= np.float32)
windows_idx = np.zeros(nseg, dtype=np.int32)

for iseg in range(nseg):
    idx = iseg * step
    windows[iseg, :] = sdata[:, idx:idx + twin]
    windows[iseg, :] -= np.mean(windows[iseg, :], axis=-1, keepdims=True)
    # original use std norm
    windows_std[iseg, :] = windows[iseg, :] / np.std(windows[iseg, :]) + 1e-10
    # others use max norm
    windows_max[iseg, :] = windows[iseg, :] / (np.max(np.abs(windows[iseg, :]), axis=-1, keepdims=True))
    windows_idx[iseg] = idx

# taper
windows_std[:, :, :6] *= tap; windows_std[:, :, -6:] *= tap[::-1]; 
windows_max[:, :, :6] *= tap; windows_max[:, :, -6:] *= tap[::-1];
del windows

print(f"Window data shape: {windows_std.shape}")

Window data shape: (1439, 3, 6000)


## Predict on base models
Then we use EqTransformer model to perform prediction on each of the time window. We use six different pre-trained model, listed in `pretrain_list`. We loop over these pretrain weights, and save the results in `batch_pred` variable of shape [2, 6, 479, 6000]. Understanding the shape of this varibale is important:
- 2: one for P-wave and another for S-wave. We don't save detection branch.
- 5: five pretrained weights
- 1439: the number of time windows
- 6000: the number of points in each time window (60 sec@100 Hz)

In [186]:
pretrain_list = ['original', 'ethz', 'instance', 'scedc', 'stead']

# dim 0: 0 = P, 1 = S
batch_pred = np.zeros([2, len(pretrain_list), nseg, twin], dtype = np.float32) 

for ipre, pretrain in enumerate(pretrain_list):
    t0 = time.time()
    eqt = sbm.EQTransformer.from_pretrained(pretrain)
    eqt.to(device);
    eqt._annotate_args['overlap'] = ('Overlap between prediction windows in samples \
                                    (only for window prediction models)', step)
    eqt._annotate_args['blinding'] = ('Number of prediction samples to discard on \
                                     each side of each window prediction', (l_blnd, r_blnd))
    eqt.eval();
    if pretrain == 'original':
        # batch prediction through torch model
        windows_std_tt = torch.Tensor(windows_std)
        _torch_pred = eqt(windows_std_tt.to(device))
    else:
        windows_max_tt = torch.Tensor(windows_max)
        _torch_pred = eqt(windows_max_tt.to(device))
    batch_pred[0, ipre, :] = _torch_pred[1].detach().cpu().numpy()
    batch_pred[1, ipre, :] = _torch_pred[2].detach().cpu().numpy()
    
    t1 = time.time()
    print(f"picking using [{pretrain}] model: %.3f second" % (t1 - t0))
    
# clean up memory
del _torch_pred, windows_max_tt, windows_std_tt
del windows_std, windows_max
gc.collect()
torch.cuda.empty_cache()

print(f"All prediction shape: {batch_pred.shape}")

picking using [original] model: 101.236 second
picking using [ethz] model: 56.532 second
picking using [instance] model: 59.654 second
picking using [scedc] model: 56.654 second
picking using [stead] model: 58.972 second
All prediction shape: (2, 5, 1439, 6000)


## Stacking 
This section we merge all time windows into the continuous prediction. We use the `stacking` function: it takes matrix of shape [nseg, ntrace] and merge it into an 1D array.

In [187]:
def stacking(data, npts, l_blnd, r_blnd):
    _data = data.copy()
    stack = np.full(npts, np.nan, dtype = np.float32)
    _data[:, :l_blnd] = np.nan; _data[:, -r_blnd:] = np.nan
    stack[:twin] = _data[0, :]
    for iseg in range(nseg-1):
        idx = step*(iseg+1)
        stack[idx:idx + twin] = \
                np.nanmax([stack[idx:idx + twin], _data[iseg+1, :]], axis = 0)
    return stack

In [188]:
pretrain_pred = np.zeros([2, len(pretrain_list), npts], dtype = np.float32)
for ipre, pretrain in enumerate(pretrain_list):
    # 0 for P-wave
    pretrain_pred[0, ipre, :] = stacking(batch_pred[0, ipre, :], npts, l_blnd, r_blnd)
    
    # 1 for S-wave
    pretrain_pred[1, ipre, :] = stacking(batch_pred[1, ipre, :], npts, l_blnd, r_blnd)

  np.nanmax([stack[idx:idx + twin], _data[iseg+1, :]], axis = 0)


In [189]:
paras_semblance = {'dt':dt, 'semblance_order':2, 'window_flag':True, 
                   'semblance_win':0.5, 'weight_flag':'max'}
p_thrd, s_thrd = 0.05, 0.05

smb_pred = np.zeros([2, nseg, twin], dtype = np.float32)

In [190]:
# calculate the semblance
## the semblance may takes a while bit to calculate
for iseg in tqdm(range(nseg)):
    # 0 for P-wave
    smb_pred[0, iseg, :] = ensemble_semblance(batch_pred[0, :, iseg, :], paras_semblance)
    
    # 1 for P-wave
    smb_pred[1, iseg, :] = ensemble_semblance(batch_pred[1, :, iseg, :], paras_semblance)

## ... and stack
# 0 for P-wave
smb_p = stacking(smb_pred[0, :], npts, l_blnd, r_blnd)

# 1 for P-wave
smb_s = stacking(smb_pred[1, :], npts, l_blnd, r_blnd)

# clean-up RAM
del smb_pred, batch_pred

100%|██████████| 1439/1439 [03:16<00:00,  7.31it/s]
  np.nanmax([stack[idx:idx + twin], _data[iseg+1, :]], axis = 0)


## Create a csv file
- Create a dictionary and the keys for event_id, source_type, station_network_code,station_channel_code,station_code,station_location_code,station_latitude_deg,station_longitude_deg,station_elevation_m,trace_name,trace_sampling_rate_hz,trace_start_time,trace_S_arrival_sample,trace_P_arrival_sample,trace_S_onset,trace_P_onset,trace_snr_db,trace_p_arrival, and trace_s_arrival
- The keys used in the CamCat dataset in the seisbench format:  event_id,source_origin_time,source_latitude_deg,source_longitude_deg,source_type,source_depth_km,preferred_source_magnitude,preferred_source_magnitude_type,preferred_source_magnitude_uncertainty,source_depth_uncertainty_km,source_horizontal_uncertainty_km,station_network_code,station_channel_code,station_code,station_location_code,station_latitude_deg,station_longitude_deg,station_elevation_m,trace_name,trace_sampling_rate_hz,trace_start_time,trace_S_arrival_sample,trace_P_arrival_sample,trace_S_arrival_uncertainty_s,trace_P_arrival_uncertainty_s,trace_P_polarity,trace_S_onset,trace_P_onset,trace_snr_db,source_type_pnsn_label,source_local_magnitude,source_local_magnitude_uncertainty,source_duration_magnitude,source_duration_magnitude_uncertainty,source_hand_magnitude,trace_missing_channel,trace_has_offset

In [191]:
p_index = picks_summary_simple(smb_p, p_thrd)
s_index = picks_summary_simple(smb_s, s_thrd)
print(f"{len(p_index)} P picks\n{len(s_index)} S picks")

48 P picks
3 S picks


In [192]:
# Create lists and a data frame
event_id = []
source_type = []
station_network_code = []
station_channel_code = []
station_code = []
station_location_code = []
station_latitude_deg= []
station_longitude_deg = []
station_elevation_m = []
trace_name = []
trace_sampling_rate_hz = []
trace_start_time = []
trace_S_arrival_sample = []
trace_P_arrival_sample = []
trace_S_onset = []
trace_P_onset = []
trace_snr_db = []
trace_p_arrival = []
trace_s_arrival = []

for i, idx in enumerate(p_index):
    event_id.append(' ')
    source_type.append(' ')
    station_network_code.append('7D')
    station_channel_code.append(' ')
    station_code.append('J57D')
    station_location_code.append(s_J57A[0].stats.location)   
    station_latitude_deg.append(inventory[0][0].latitude)
    station_longitude_deg.append(inventory[0][0].longitude)   
    station_elevation_m.append(inventory[0][0].elevation)
    trace_name.append(' ')
    trace_sampling_rate_hz.append(s_J57A[0].stats.sampling_rate)
    trace_start_time.append(s_J57A[0].stats.starttime)
    trace_S_arrival_sample.append(' ')
    trace_P_arrival_sample.append(' ')
    trace_S_onset.append(' ')
    trace_P_onset.append(' ')
    trace_snr_db.append(' ')
    trace_s_arrival.append(np.nan)
    trace_p_arrival.append(str(starttime  + idx * delta))
    
for i, idx in enumerate(s_index):
    event_id.append(' ')
    source_type.append(' ')
    station_network_code.append('7D')
    station_channel_code.append(' ')
    station_code.append('J57D')
    station_location_code.append(s_J57A[0].stats.location)   
    station_latitude_deg.append(inventory[0][0].latitude)
    station_longitude_deg.append(inventory[0][0].longitude)   
    station_elevation_m.append(inventory[0][0].elevation)
    trace_name.append(' ')
    trace_sampling_rate_hz.append(s_J57A[0].stats.sampling_rate)
    trace_start_time.append(s_J57A[0].stats.starttime)
    trace_S_arrival_sample.append(' ')
    trace_P_arrival_sample.append(' ')
    trace_S_onset.append(' ')
    trace_P_onset.append(' ')
    trace_snr_db.append(' ')
    trace_s_arrival.append(str(starttime  + idx * delta))
    trace_p_arrival.append(np.nan)

# dictionary of lists
dict = {'event_id':event_id,'source_type':source_type,'station_network_code':station_network_code,\
        'station_channel_code':station_channel_code,'station_code':station_code,'station_location_code':station_location_code,\
        'station_latitude_deg':station_latitude_deg,'station_longitude_deg':station_longitude_deg, \
        'station_elevation_m':station_elevation_m,'trace_name':trace_name,'trace_sampling_rate_hz':trace_sampling_rate_hz,\
        'trace_start_time':trace_start_time,'trace_S_arrival_sample':trace_S_arrival_sample,\
        'trace_P_arrival_sample':trace_P_arrival_sample, 'trace_S_onset':trace_S_onset,'trace_P_onset':trace_P_onset,\
        'trace_snr_db':trace_snr_db, 'trace_s_arrival':trace_s_arrival, 'trace_p_arrival':trace_p_arrival}
     
df = pd.DataFrame(dict)
 
print(df)

   event_id source_type station_network_code station_channel_code  \
0                                         7D                        
1                                         7D                        
2                                         7D                        
3                                         7D                        
4                                         7D                        
5                                         7D                        
6                                         7D                        
7                                         7D                        
8                                         7D                        
9                                         7D                        
10                                        7D                        
11                                        7D                        
12                                        7D                        
13                                

In [193]:
# Create the CSV file
df.to_csv('cat_elep.csv')

In [194]:
# Read the CSV file
cat_elep = pd.read_csv('cat_elep.csv')
cat_elep

Unnamed: 0.1,Unnamed: 0,event_id,source_type,station_network_code,station_channel_code,station_code,station_location_code,station_latitude_deg,station_longitude_deg,station_elevation_m,trace_name,trace_sampling_rate_hz,trace_start_time,trace_S_arrival_sample,trace_P_arrival_sample,trace_S_onset,trace_P_onset,trace_snr_db,trace_s_arrival,trace_p_arrival
0,0,,,7D,,J57D,,47.080101,-124.4505,-55.8,,50.0,2012-07-10T00:00:00.010700Z,,,,,,,2012-07-10T00:12:18.490700Z
1,1,,,7D,,J57D,,47.080101,-124.4505,-55.8,,50.0,2012-07-10T00:00:00.010700Z,,,,,,,2012-07-10T00:21:37.330700Z
2,2,,,7D,,J57D,,47.080101,-124.4505,-55.8,,50.0,2012-07-10T00:00:00.010700Z,,,,,,,2012-07-10T02:42:49.830700Z
3,3,,,7D,,J57D,,47.080101,-124.4505,-55.8,,50.0,2012-07-10T00:00:00.010700Z,,,,,,,2012-07-10T02:49:03.390700Z
4,4,,,7D,,J57D,,47.080101,-124.4505,-55.8,,50.0,2012-07-10T00:00:00.010700Z,,,,,,,2012-07-10T02:53:31.610700Z
5,5,,,7D,,J57D,,47.080101,-124.4505,-55.8,,50.0,2012-07-10T00:00:00.010700Z,,,,,,,2012-07-10T03:01:53.010700Z
6,6,,,7D,,J57D,,47.080101,-124.4505,-55.8,,50.0,2012-07-10T00:00:00.010700Z,,,,,,,2012-07-10T03:09:46.210700Z
7,7,,,7D,,J57D,,47.080101,-124.4505,-55.8,,50.0,2012-07-10T00:00:00.010700Z,,,,,,,2012-07-10T03:27:11.530700Z
8,8,,,7D,,J57D,,47.080101,-124.4505,-55.8,,50.0,2012-07-10T00:00:00.010700Z,,,,,,,2012-07-10T03:31:09.830700Z
9,9,,,7D,,J57D,,47.080101,-124.4505,-55.8,,50.0,2012-07-10T00:00:00.010700Z,,,,,,,2012-07-10T03:35:38.990700Z


## Calculating the metrics that can be added to the CSV file

In [56]:
# Calculate the SNR
snr = np.zeros(len(p_index))
for i, idx in enumerate(p_index):
    snr[i] = np.mean(sdata[:, idx-int(3/delta):idx])/np.mean(sdata[:, idx:idx+int(3/delta)])

snr 

array([ -0.61720848,  -0.09768351,   1.74035625,   0.62378588,
         0.67595104,  -1.36076867,  -1.01128689,   0.14048542,
        -0.31960256,   0.19071466,  -0.69516703,   0.05569674,
        -1.32098761,   0.08101457, -12.80655422,  -3.51066439,
        -0.51928966,   0.07429897,   3.01106875,  -0.27501024,
        -4.18768781,   4.72571781,   0.84652496,  -1.11724439,
        -0.36222411,  31.20533436,   1.32077671, -40.10909755,
        -1.46579317,   1.32295763,  -0.64379084,  -0.63322123,
        -0.13161624,  -0.37482107,   0.3458798 ,  -4.00910733,
        -0.70278175,   1.01139866,   0.40466517,  -1.39533298,
        -0.52807273,  -0.59767938,   1.47946156,   0.29501883,
         0.36193009,  -2.81448793,  -2.21202827,   1.37509341])