### Model training updates

Update 1 – retraining the model with the clean ESEC data (excluding flows, rockfalls, and earthquakes).

Update 2 (June 20, 2025) – retraining the model with 2,502 near-field (0–50 km) explosion waveforms that had been classified as surface events or earthquakes.

Update 3 – retraining the model with higher penalties whenever explosions and surface events are confused with each other.

In [1]:
%load_ext autoreload
%autoreload 2

# === Standard Libraries ===
import os
import sys
import random
import json
from typing import Any

# === Scientific Libraries ===
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from glob import glob

# === Signal Processing ===
import scipy
from scipy import signal
from scipy.signal import butter, filtfilt, correlate

# === Seismology Libraries ===
import obspy
from obspy import UTCDateTime
from obspy.clients.fdsn import Client

# === Machine Learning Libraries ===
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report

# === PyTorch ===
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# === File Handling ===
import h5py

# === Custom Modules ===
module_path = os.path.abspath(os.path.join('../scripts'))
if module_path not in sys.path:
    sys.path.append(module_path)

from sklearn.model_selection import train_test_split
from neural_network_architectures import (
     QuakeXNet_1d, QuakeXNet_2d, SeismicCNN_1d, SeismicCNN_2d )


# === Seismology Client ===
client = Client('IRIS')

from utils import extract_waveforms
from utils import compute_spectrogram
from utils import normalize_spectrogram_minmax
from utils import return_train_val_loaders
from utils import plot_confusion_matrix_and_cr
from utils import train_model
from utils import WaveformPreprocessor


import json

cuda


## Defining some common parameters

In [14]:
# Whether to use all available data (True) or a capped subset (False).
all_data = False

# Window start relative to the onset, in seconds. When `shifting` is True the
# start point is drawn at random between `start` and a small negative upper
# bound. NOTE(review): earlier comments disagreed on that bound (-4 s vs -5 s);
# confirm against utils.extract_waveforms.
start = -40
shifting = True

# === Training parameters ===
train_split = 70   # percent
val_split = 20     # percent
test_split = 10    # percent
learning_rate = 0.001
batch_size = 128
n_epochs = 60
dropout = 0.4
criterion = nn.CrossEntropyLoss()

# === Waveform parameters ===
num_channels = 3
fs = 50  # new (target) sampling rate, Hz

# === Filtering parameters ===
highcut = 20  # Hz
lowcut = 1    # Hz
input_window_length = 100  # seconds

## 1. Additional surface events per station. 

### I downloaded additional surface-event waveforms from more stations for each event to supplement the existing data. The step below only processes these additional waveforms; they are added to the total dataset later.

In [2]:
def process_surface_events(data_path, ids_path, fs=50, original_fs=100, lowcut=1, highcut=20,
                           window_length=100, taper_alpha=0.1, random_offset=(-40, -5),
                           onset_time=90):
    """
    Slice and preprocess the additional surface-event waveforms.

    Each event record is assumed to have its onset ``onset_time`` seconds into
    the record. A window of ``window_length`` seconds is cut starting ``shift``
    seconds relative to that onset. ``shift`` is an integer drawn with
    ``np.random.randint(random_offset[0], random_offset[1])``, so the upper
    bound is exclusive. The window is clamped to the record bounds.

    The slice is then passed to ``WaveformPreprocessor`` (from ``utils``),
    configured with the input/target sampling rates and band-pass corners. The
    exact processing steps are defined there.

    Args:
        data_path (str): Path to the surface event data (.npy file).
        ids_path (str): Path to the surface event IDs (JSON file).
        fs (int): Target sampling rate after resampling (Hz).
        original_fs (int): Sampling rate of the stored waveforms (Hz). Used
            both for slicing and as the preprocessor's input rate.
        lowcut (float): Low cutoff frequency for the bandpass filter (Hz).
        highcut (float): High cutoff frequency for the bandpass filter (Hz).
        window_length (int): Length of the waveform window (seconds).
        taper_alpha (float): Currently unused; kept for backward
            compatibility. NOTE(review): any tapering would have to happen
            inside WaveformPreprocessor.
        random_offset (tuple): (low, high) range of random onset-relative
            shifts in seconds; ``high`` is exclusive.
        onset_time (float): Assumed onset position in the record (seconds).

    Returns:
        list: Processed waveforms, each with 3 components and
            ``int(window_length * fs)`` samples.
        list: Corresponding event IDs.
    """
    # NOTE(review): allow_pickle=True can execute code embedded in a crafted
    # file; only load trusted, locally produced .npy files here.
    surface_data = np.load(data_path, allow_pickle=True)
    with open(ids_path, "r") as file:
        surface_ids = json.load(file)

    # Sample counts derived from the *input* sampling rate. The original code
    # hard-coded 100 Hz here and ignored `original_fs`.
    window_samples = int(window_length * original_fs)
    expected_samples = int(window_length * fs)  # length after resampling
    onset_idx = int(onset_time * original_fs)

    # Same configuration for every event, so build the preprocessor once.
    processor = WaveformPreprocessor(
        input_fs=original_fs,
        target_fs=fs,
        lowcut=lowcut,
        highcut=highcut)

    processed_data = []
    processed_ids = []

    for i in tqdm(range(len(surface_data)), desc="Processing events"):
        try:
            event_data = surface_data[i]

            # Random onset-relative shift, converted to an integer sample offset
            # (int() keeps slicing valid if original_fs is a float).
            random_shift = int(np.random.randint(random_offset[0], random_offset[1]) * original_fs)
            start_idx = onset_idx + random_shift
            end_idx = start_idx + window_samples

            # Keep the window inside the record. If the record is shorter than
            # the window, the slice comes out short and is presumably rejected
            # by the length check below.
            max_idx = event_data.shape[-1]
            if end_idx > max_idx:
                end_idx = max_idx
                start_idx = end_idx - window_samples
            if start_idx < 0:
                start_idx = 0
                end_idx = window_samples

            sliced_tensor = torch.tensor(event_data[:, start_idx:end_idx], dtype=torch.float32)
            processed = processor(sliced_tensor)  # (C, T)

            if processed.shape[-1] != expected_samples:
                print('error')
                continue

            x = processed.numpy()
            if len(x) == 3:  # keep only three-component events
                processed_data.append(x)
                processed_ids.append(surface_ids[i])

        except Exception as e:
            # Best effort: report and skip events that fail to process.
            print(f"Error processing event {i}: {e}")
            continue

    return processed_data, processed_ids


# Run the surface-event preprocessing on the supplementary downloads:
# waveforms live in the .npy file, their event IDs in the matching .json file.
data_path, ids_path = (
    '../../data/new_curated_surface_event_data.npy',
    '../../data/new_curated_surface_event_ids.json',
)

processed_additional_su, processed_additional_su_id = process_surface_events(
    data_path,
    ids_path,
)

print(f'Length of additional surface event waveforms {len(processed_additional_su)}')

Processing events: 100%|██████████| 6495/6495 [00:11<00:00, 588.60it/s]

Length of additional surface event waveforms 6487





## 2. Original PNW data

In [10]:
# === Waveform data files (HDF5) ===
file_noise = "/data/whd01/yiyu_data/PNWML/noise_waveforms.hdf5"
file_comcat = "/data/whd01/yiyu_data/PNWML/comcat_waveforms.hdf5"
file_exotic = "/data/whd01/yiyu_data/PNWML/exotic_waveforms.hdf5"

# === Metadata tables ===
comcat_metadata = pd.read_csv("/data/whd01/yiyu_data/PNWML/comcat_metadata.csv")
exotic_metadata = pd.read_csv("/data/whd01/yiyu_data/PNWML/exotic_metadata.csv")
metadata_noise = pd.read_csv("/data/whd01/yiyu_data/PNWML/noise_metadata.csv")

# One catalog per class, selected by the 'source_type' column.
cat_exp = comcat_metadata[comcat_metadata['source_type'] == 'explosion']
cat_eq = comcat_metadata[comcat_metadata['source_type'] == 'earthquake']
cat_su = exotic_metadata[exotic_metadata['source_type'] == 'surface event']

# Build a noise 'event_id' from the trace start time. cat_noise is the same
# DataFrame object as metadata_noise (not a copy), so the new column appears
# under both names.
cat_noise = metadata_noise
cat_noise['event_id'] = cat_noise['trace_start_time'] + '_noise'

In [11]:
## defining the threshold
# Traces are kept when their SNR (dB) is strictly above the threshold.
SNR_THR = 1
# Surface events use a looser cut of SNR_THR - 2 (i.e. -1 dB).
SU_SNR_THR = SNR_THR - 2


def _last_trace_snr_db(catalog):
    """Return the last '|'-separated entry of each row's ``trace_snr_db`` as floats.

    NOTE(review): the '|'-separated values are presumably per-component SNRs,
    so only the last component's SNR is used -- confirm against the PNW
    metadata documentation.
    """
    return np.array([float(entry.split("|")[-1]) for entry in catalog['trace_snr_db']],
                    dtype=float)


def _select_by_snr(catalog, threshold):
    """Return the rows of *catalog* whose last-entry SNR is strictly above *threshold*."""
    return catalog[_last_trace_snr_db(catalog) > threshold]


# explosions
df_exp = _select_by_snr(cat_exp, SNR_THR)

# earthquakes
df_eq = _select_by_snr(cat_eq, SNR_THR)

# surface events (looser threshold)
df_su = _select_by_snr(cat_su, SU_SNR_THR)

# noise: no SNR selection
df_noise = cat_noise

## Note that we are only selecting three components from each class

In [15]:
# Keyword arguments shared by every extract_waveforms call in this cell.
_extract_kwargs = dict(
    input_window_length=input_window_length,
    fs=fs,
    start=start,
    num_channels=num_channels,
    shifting=shifting,
    all_data=all_data,
    lowcut=lowcut,
    highcut=highcut,
)

# surface events: request as many traces as df_su holds
number_data_per_class = len(df_su)
d_su, id_su = extract_waveforms(df_su, file_exotic,
                                number_data=number_data_per_class, **_extract_kwargs)
print(d_su.shape)

# noise: capped request
number_data_per_class = 15000
d_noise, id_noise = extract_waveforms(df_noise, file_noise,
                                      number_data=number_data_per_class, **_extract_kwargs)
print(d_noise.shape)

# explosions: request as many traces as df_exp holds
number_data_per_class = len(df_exp)
d_exp, id_exp = extract_waveforms(df_exp, file_comcat,
                                  number_data=number_data_per_class, **_extract_kwargs)
print(d_exp.shape)

# earthquakes: capped request
number_data_per_class = 17000
d_eq, id_eq = extract_waveforms(df_eq, file_comcat,
                                number_data=number_data_per_class, **_extract_kwargs)
print(d_eq.shape)

100%|██████████| 8434/8434 [01:16<00:00, 110.32it/s]
  0%|          | 8/15000 [00:00<03:15, 76.56it/s]

(3778, 3, 5000)


100%|██████████| 15000/15000 [02:34<00:00, 97.04it/s] 
  0%|          | 5/13638 [00:00<04:34, 49.70it/s]

(10583, 3, 5000)


100%|██████████| 13638/13638 [03:10<00:00, 71.62it/s]
  0%|          | 0/17000 [00:00<?, ?it/s]

(8829, 3, 5000)


100%|██████████| 17000/17000 [03:59<00:00, 70.99it/s]


(10506, 3, 5000)


## 3. ESEC waveforms (1866 waveforms)

In [17]:
# Curated ESEC catalog for retraining; the first CSV column is the index.
df = pd.read_csv(
    '../../data/curated_esec_catalog_for_retraining.csv',
    index_col=0,
)
print(df.shape[0])

1866


### In the following cell, we load 270 s of three-component ESEC waveforms (70 s before and 200 s after the event) and resample them to 100 Hz, as required for further processing.

In [18]:
esec_data = []
esec_ids = []

for i in tqdm(range(len(df))):
    try:
        event_id = df['event_id'].iloc[i]
        station = df['station'].iloc[i]

        # All waveform files for this event whose name contains the station
        # code; exactly three (one per component) are required below.
        # glob() order depends on the filesystem, so sort it. Otherwise the
        # component order could differ from event to event.
        # NOTE(review): this sorts alphabetically by file name -- confirm it
        # matches the component order of the PNW waveforms.
        files = sorted(glob(f"../../data/iris_esec_waveforms/waveforms/{event_id}/*{station}*"))

        if len(files) == 3:
            st = obspy.Stream()
            for file in files:
                st += obspy.read(file)

            st.resample(100)

            # Convert to a NumPy array, keeping at most 27000 samples
            # (270 s at 100 Hz) per trace.
            arr = np.stack([tr.data[:27000] for tr in st])
            esec_data.append(arr)
            esec_ids.append(event_id)

    except Exception as e:
        # Best effort: report and skip events that fail to load.
        print(f"Error on index {i}, event {df['event_id'].iloc[i]}: {e}")
        continue


100%|██████████| 1866/1866 [00:25<00:00, 74.13it/s]


In [19]:
# Turn the per-event lists into NumPy arrays. NOTE(review): this assumes every
# event array has the same shape; ragged lengths would error or give an
# object array, depending on the NumPy version.
esec_data, esec_ids = np.array(esec_data), np.array(esec_ids)

In [20]:
# Processing parameters for the ESEC waveforms (same values as the defaults of
# process_surface_events).
random_offset = (-40, -5)  # shift range (s) relative to the assumed onset; upper bound exclusive
fs = 50                    # target sampling rate (Hz)
original_fs = 100          # sampling rate of esec_data (resampled to 100 Hz when loaded)
lowcut = 1                 # band-pass corners (Hz)
highcut = 20
window_length = 100        # window length (s)
taper_alpha = 0.1          # NOTE(review): not used below; kept so the name stays defined
orig_fs = original_fs      # alias; previously a second, duplicated hard-coded 100

window_samples = int(window_length * original_fs)  # window length before resampling
expected_samples = int(window_length * fs)         # window length after resampling
# NOTE(review): the onset is assumed at 90 s, but the loading notes describe
# records with 70 s of pre-event data -- confirm which is correct.
onset_idx = int(90 * original_fs)

# The preprocessor configuration is identical for every event, so build it once.
processor = WaveformPreprocessor(
    input_fs=original_fs,
    target_fs=fs,
    lowcut=lowcut,
    highcut=highcut)

processed_esec_data = []
processed_esec_ids = []

for event_data, event_id in tqdm(zip(esec_data, esec_ids), total=len(esec_data)):
    # Random onset-relative shift, converted to a sample offset.
    random_shift = int(np.random.randint(random_offset[0], random_offset[1]) * original_fs)
    start_idx = onset_idx + random_shift
    end_idx = start_idx + window_samples

    # Keep the window inside the record. If the record is shorter than the
    # window, the slice comes out short and is presumably rejected by the
    # length check below.
    max_idx = event_data.shape[-1]
    if end_idx > max_idx:
        end_idx = max_idx
        start_idx = end_idx - window_samples
    if start_idx < 0:
        start_idx = 0
        end_idx = window_samples

    sliced_tensor = torch.tensor(event_data[:, start_idx:end_idx], dtype=torch.float32)
    processed = processor(sliced_tensor)  # (C, T)

    if processed.shape[-1] != expected_samples:
        print('error')
        continue

    x = processed.numpy()
    if len(x) == 3:  # keep only three-component events
        processed_esec_data.append(x)
        processed_esec_ids.append(event_id)
        
    

100%|██████████| 1866/1866 [00:02<00:00, 624.46it/s]


## 4. New near-field explosion data