In [45]:
#!/usr/bin/env python
# coding: utf-8

import obspy
from obspy.clients.fdsn import Client
from obspy import UTCDateTime
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch

import json
import time

import seisbench.models as sbm
import seisbench.data as sbd
import seisbench.generate as sbg

from seisbench.util import worker_seeding
from torch.utils.data import DataLoader
from obspy.clients.fdsn import Client
from scipy.signal import find_peaks



In [46]:
# Read the run settings from config.json into a plain dict
with open('config.json', 'r') as config_file:
    config = json.load(config_file)

In [47]:
config

{'peak_detection': {'sampling_rate': 100, 'height': 0.5, 'distance': 100},
 'training': {'batch_size': 64,
  'num_workers': 4,
  'learning_rate': 0.01,
  'epochs': 50,
  'patience': 5,
  'loss_weights': [0.01, 0.4, 0.59],
  'optimization': {'mixed_precision': True,
   'gradient_accumulation_steps': 1,
   'pin_memory': True,
   'prefetch_factor': 2,
   'persistent_workers': True}},
 'device': {'use_cuda': True, 'device_id': 0}}

In [48]:
# Seed the stdlib, NumPy and PyTorch RNGs with the same value, then make cuDNN
# deterministic (and disable its autotuner) for reproducibility
import random

for _seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    _seed_rng(0)

torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


In [49]:
import os

# Load the picker: PhaseNet initialised from the pretrained "stead" weights.
# The commented-out line below is the EQTransformer alternative.
#model = sbm.EQTransformer.from_pretrained("original")
model = sbm.PhaseNet.from_pretrained("stead")

In [50]:
# Pick the compute device: the configured CUDA device when CUDA is available
# and enabled in the config, otherwise the CPU
if torch.cuda.is_available() and config['device']['use_cuda']:
    device = torch.device(f"cuda:{config['device']['device_id']}")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
model.to(device)


Using device: cpu


PhaseNet(
  (inc): Conv1d(3, 8, kernel_size=(7,), stride=(1,), padding=same)
  (in_bn): BatchNorm1d(8, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  (down_branch): ModuleList(
    (0): ModuleList(
      (0): Conv1d(8, 8, kernel_size=(7,), stride=(1,), padding=same, bias=False)
      (1): BatchNorm1d(8, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): Conv1d(8, 8, kernel_size=(7,), stride=(4,), padding=(3,), bias=False)
      (3): BatchNorm1d(8, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): ModuleList(
      (0): Conv1d(8, 16, kernel_size=(7,), stride=(1,), padding=same, bias=False)
      (1): BatchNorm1d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
      (2): Conv1d(16, 16, kernel_size=(7,), stride=(4,), bias=False)
      (3): BatchNorm1d(16, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): ModuleList(
      (0): Conv1d(16, 32, kernel_size=(7,), stride=(1,

In [51]:
# Load the data, requesting sampling_rate=100 and component order "ENZ"
# NOTE(review): force=True presumably bypasses any cached copy on every run — confirm this is intended
# NOTE(review): confirm that the "ENZ" channel order is what the pretrained model expects
print("Loading data...")
data = sbd.OKLA_1Mil_120s_Ver_3(sampling_rate=100, force=True, component_order="ENZ")



Loading data...


In [52]:
# Create a random sample
sample_fraction = 0.02  # Keep roughly 2% of the data (0.02 is a fraction, not a percentage)
print(f"Creating random sample of {sample_fraction*100}% of the data...")

Creating random sample of 2.0% of the data...


In [53]:
# Create a random mask for sampling: each trace is kept independently with
# probability sample_fraction, so the sampled size is only approximately
# sample_fraction * len(data).
# NOTE: this re-seeds NumPy's global RNG with 42, replacing the seed of 0 set earlier.
np.random.seed(42)  # For reproducibility

mask = np.random.random(len(data)) < sample_fraction
# NOTE(review): filter() appears to modify `data` in place (its return value is
# unused and len(data) is read right after) — re-running this cell would then
# subsample the already-subsampled dataset. Confirm before re-executing.
data.filter(mask)

print(f"Sampled dataset size: {len(data)}")


Sampled dataset size: 22782


In [54]:
print("Sample metadata:")
data.metadata.head()

Sample metadata:


Unnamed: 0,index,station_network_code,station_code,trace_channel,station_latitude_deg,station_longitude_deg,station_elevation_m,trace_p_arrival_sample,trace_p_status,trace_p_weight,...,trace_snr_db,trace_coda_end_sample,trace_start_time,trace_category,trace_name,split,trace_name_original,trace_chunk,trace_sampling_rate_hz,trace_component_order
72,72,2V,TG11,EHE,35.2689,-97.8146,407.0,6000.0,manual,1.0,...,,,2023-09-11T10:04:26.195000Z,earthquake_local,"bucket0$53,:3,:12001",train,2V.TG11.EHE.EHN.EHZ.2023-09-11T1004262023-09-1...,,100,ZNE
128,128,2V,TG11,EHE,35.2689,-97.8146,407.0,6000.0,manual,1.0,...,,,2023-10-15T03:02:11.464999Z,earthquake_local,"bucket0$97,:3,:12001",train,2V.TG11.EHE.EHN.EHZ.2023-10-15T0302112023-10-1...,,100,ZNE
171,171,2V,TG11,EHE,35.2689,-97.8146,407.0,6000.0,manual,1.0,...,,,2023-11-02T09:08:31.285000Z,earthquake_local,"bucket0$124,:3,:12001",train,2V.TG11.EHE.EHN.EHZ.2023-11-02T0908312023-11-0...,,100,ZNE
205,205,2V,TG11,EHE,35.2689,-97.8146,407.0,6000.0,manual,1.0,...,,,2023-11-29T04:45:18.075000Z,earthquake_local,"bucket5$28,:3,:12001",test,2V.TG11.EHE.EHN.EHZ.2023-11-29T0445182023-11-2...,,100,ZNE
208,208,2V,TG11,EHE,35.2689,-97.8146,407.0,5999.0,manual,1.0,...,,,2023-11-30T00:58:44.934999Z,earthquake_local,"bucket0$153,:3,:12001",train,2V.TG11.EHE.EHN.EHZ.2023-11-30T0058442023-11-3...,,100,ZNE


In [55]:
# Split data into the train / dev / test partitions
# (presumably driven by the `split` metadata column shown above — confirm)
train, dev, test = data.train_dev_test()


In [56]:

# Report each partition on its own line
for split_label, split_data in (("Train:", train), ("Dev:", dev), ("Test:", test)):
    print(split_label, split_data)


Train: OKLA_1Mil_120s_Ver_3 - 15909 traces
Dev: OKLA_1Mil_120s_Ver_3 - 3429 traces
Test: OKLA_1Mil_120s_Ver_3 - 3444 traces


In [57]:

# Set up data augmentation

# Map each pick column name onto its coarse phase class ("P" or "S").
# Insertion order: all P columns first, then all S columns.
phase_dict = {
    f"trace_{phase}_arrival_sample": label
    for label, phases in (
        ("P", ("p", "pP", "P", "P1", "Pg", "Pn", "PmP", "pwP", "pwPm")),
        ("S", ("s", "S", "S1", "Sg", "SmS", "Sn")),
    )
    for phase in phases
}

In [58]:
# Create the data generators for training, validation (dev) and testing.
# NOTE(review): the next cell builds these same three generators again before
# attaching augmentations, so this cell looks redundant.
train_generator = sbg.GenericGenerator(train)
dev_generator = sbg.GenericGenerator(dev)
test_generator = sbg.GenericGenerator(test)


In [59]:
# Define phase lists for labeling
p_phases = [key for key, val in phase_dict.items() if val == "P"]
s_phases = [key for key, val in phase_dict.items() if val == "S"]

# Build the generators in this cell so that re-running it starts from clean
# generators instead of stacking the augmentations a second time.
train_generator = sbg.GenericGenerator(train)
dev_generator = sbg.GenericGenerator(dev)
test_generator = sbg.GenericGenerator(test)

augmentations = [
    # Cut a 6000-sample window around one randomly selected pick column
    sbg.WindowAroundSample(list(phase_dict.keys()), samples_before=3000, windowlen=6000, selection="random", strategy="variable"),
    # Then take a random 3001-sample window, padding when the trace is shorter
    sbg.RandomWindow(windowlen=3001, strategy="pad"),
    sbg.Normalize(demean_axis=-1, detrend_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
    sbg.ChangeDtype(np.float32),
    # Pass the column -> phase mapping explicitly, so the label channels come
    # from phase_dict rather than from whichever pick columns happen to be in
    # the metadata.
    sbg.ProbabilisticLabeller(label_columns=phase_dict, sigma=30, dim=0),
]

# NOTE(review): the same (random) augmentation list is attached to dev and test,
# so validation/test losses presumably vary between runs — confirm this is intended.
train_generator.add_augmentations(augmentations)
dev_generator.add_augmentations(augmentations)
test_generator.add_augmentations(augmentations)


In [60]:
# Pull the peak-detection and batching settings out of the config
peak_cfg = config['peak_detection']
sampling_rate, height, distance = (
    peak_cfg[name] for name in ('sampling_rate', 'height', 'distance')
)

training_cfg = config['training']
batch_size = training_cfg['batch_size']
num_workers = training_cfg['num_workers']

In [61]:
# Load the data for machine learning
# Loader tuning is read from config['training']['optimization']; the fallbacks
# are the values that were previously hard-coded in this cell.
_opt_cfg = config['training'].get('optimization', {})
_loader_kwargs = {
    'batch_size': batch_size,
    'num_workers': num_workers,
    'worker_init_fn': worker_seeding,
    'pin_memory': _opt_cfg.get('pin_memory', True),
}
if num_workers > 0:
    # DataLoader rejects these two options when loading happens in the main
    # process (num_workers == 0), so only pass them with worker processes.
    _loader_kwargs['prefetch_factor'] = _opt_cfg.get('prefetch_factor', 4)
    _loader_kwargs['persistent_workers'] = _opt_cfg.get('persistent_workers', True)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_generator, shuffle=True, **_loader_kwargs)
test_loader = DataLoader(test_generator, shuffle=False, **_loader_kwargs)
val_loader = DataLoader(dev_generator, shuffle=False, **_loader_kwargs)


In [62]:
# Define loss function
def loss_fn(y_pred, y_true, eps=1e-5):
    """Cross-entropy style loss between predicted and target probabilities.

    ``eps`` is added to the predictions before taking the logarithm so that a
    prediction of exactly zero does not produce ``-inf``. The target-weighted
    log-probabilities are averaged over the last axis, summed over the
    next-to-last axis, and the mean of what remains is negated and returned.
    """
    weighted_log_prob = torch.log(y_pred + eps) * y_true
    per_item = weighted_log_prob.mean(dim=-1).sum(dim=-1)
    return -per_item.mean()

In [63]:
# Learning rate and number of epochs, both taken from the "training" section
# of the config
learning_rate = config['training']['learning_rate']
epochs = config['training']['epochs']


In [64]:
# Adam over all model parameters, using the configured learning rate
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)


In [65]:
# EarlyStopping class
class EarlyStopping:
    """Stop training once the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss and the model.
    An epoch counts as an improvement when its negated loss is at least
    ``best_score + delta`` (the first finite loss always counts). Each
    improvement writes the model's ``state_dict`` to ``checkpoint_path`` and
    ``best_model_path`` and resets the counter. Every other epoch increments
    the counter; once it reaches ``patience``, ``early_stop`` is set and the
    current weights are written to ``final_model_path``.

    A non-finite loss (NaN or +/-inf) is never treated as an improvement, so a
    diverged epoch cannot overwrite the best checkpoint.

    Args:
        patience: Non-improving epochs tolerated before stopping.
        verbose: Print a message on every save and counter increment.
        delta: Margin added to the best score when testing for improvement.
        checkpoint_path: File receiving the weights of each improving epoch.
        best_model_path: Second file receiving the same improving weights.
        final_model_path: File receiving the weights at the moment of stopping.
    """

    def __init__(self, patience=7, verbose=False, delta=0, checkpoint_path='checkpoint.pt',
                 best_model_path='best_model.pth', final_model_path='final_model.pth'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0            # consecutive epochs without improvement
        self.best_score = None      # best (negated) validation loss seen so far
        self.early_stop = False     # becomes True once patience is exhausted
        self.val_loss_min = float('inf')
        self.delta = delta
        self.checkpoint_path = checkpoint_path
        self.best_model_path = best_model_path
        self.final_model_path = final_model_path

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and save or count accordingly."""
        import math  # local import keeps this class self-contained

        score = -val_loss
        # NaN compares False against everything, so a plain `<` test would let
        # a diverged epoch through as an "improvement"; reject it up front.
        improved = math.isfinite(val_loss) and (
            self.best_score is None or score >= self.best_score + self.delta
        )

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.save_best_model(model)
            self.counter = 0
            return

        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True
            self.save_final_model(model)

    def save_checkpoint(self, val_loss, model):
        """Write the weights to ``checkpoint_path`` and remember ``val_loss``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...')
        torch.save(model.state_dict(), self.checkpoint_path)
        self.val_loss_min = val_loss

    def save_best_model(self, model):
        """Write the weights to ``best_model_path``."""
        if self.verbose:
            print(f'Saving best model to {self.best_model_path}')
        torch.save(model.state_dict(), self.best_model_path)

    def save_final_model(self, model):
        """Write the current (not necessarily best) weights to ``final_model_path``."""
        if self.verbose:
            print(f'Early stopping triggered. Saving final model to {self.final_model_path}')
        torch.save(model.state_dict(), self.final_model_path)


In [66]:
# Function to train for one epoch
def train_one_epoch(dataloader, model, loss_fn, optimizer, device):
    """Run a single optimisation pass over ``dataloader``.

    Every batch is indexed with ``"X"`` (inputs) and ``"y"`` (targets); both
    are moved to ``device``. The model is put in training mode, and for each
    batch the gradients are reset, back-propagated and applied. Progress is
    printed every fifth batch.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    model.train()
    n_samples = len(dataloader.dataset)
    running_loss = 0

    for step, batch in enumerate(dataloader):
        prediction = model(batch["X"].to(device))
        batch_loss = loss_fn(prediction, batch["y"].to(device))

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        if step % 5 == 0:
            print(f"loss: {loss_value:>7f}  [{step * len(batch['X']):>5d}/{n_samples:>5d}]")
        running_loss += loss_value

    return running_loss / len(dataloader)


In [67]:
# Function to evaluate the model
def evaluate_model(dataloader, model, loss_fn, device):
    """Compute the mean per-batch loss of ``model`` over ``dataloader``.

    The model is switched to evaluation mode and gradients are disabled.
    Every batch is indexed with ``"X"`` (inputs) and ``"y"`` (targets); both
    are moved to ``device``.

    Returns:
        The sum of the per-batch losses divided by the number of batches.
    """
    model.eval()

    with torch.no_grad():
        batch_losses = [
            loss_fn(model(batch["X"].to(device)), batch["y"].to(device)).item()
            for batch in dataloader
        ]

    return sum(batch_losses) / len(dataloader)


In [68]:
# Function to plot training history
def plot_training_history(history):
    """Plot training vs. validation loss and save it as 'training_history.png'.

    Args:
        history: Mapping with ``'train_loss'`` and ``'val_loss'`` sequences,
            one value per epoch.

    The area where the validation loss exceeds the training loss is shaded
    red. The figure is written to disk and then closed, so nothing is shown
    inline.
    """
    train_curve = np.asarray(history['train_loss'], dtype=float)
    val_curve = np.asarray(history['val_loss'], dtype=float)
    epoch_index = np.arange(len(train_curve))

    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(train_curve, label='Training Loss')
        ax.plot(val_curve, label='Validation Loss')
        ax.fill_between(epoch_index, train_curve, val_curve,
                        alpha=0.3, color='red',
                        where=val_curve > train_curve,
                        label='Potential Overfitting Gap')
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.set_title('Training and Validation Loss Over Time')
        ax.grid(True, linestyle='--', alpha=0.6)
        # Build the legend only after every labelled artist exists; calling it
        # before fill_between would leave the shaded region out of the legend.
        ax.legend()
        fig.savefig('training_history.png')
    finally:
        # Always release the figure, even if saving fails
        plt.close(fig)


In [69]:
# Training routine with EarlyStopping and scheduler
def train_model(train_loader, val_loader, model, optimizer, loss_fn, device, num_epochs=25, patience=7):
    """Train ``model`` with early stopping and learning-rate reduction on plateau.

    Each epoch runs one training pass and one validation pass, records both
    losses, steps the scheduler on the validation loss, and updates the early
    stopper. Training ends after ``num_epochs`` epochs or as soon as the early
    stopper triggers.

    Args:
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``evaluate_model``.
        model: Model to train; modified in place.
        optimizer: Optimizer driving the updates.
        loss_fn: Callable ``loss_fn(prediction, target)``.
        device: Device batches are moved to.
        num_epochs: Maximum number of epochs.
        patience: Early-stopping patience in epochs.

    Returns:
        ``(model, history)`` where ``history`` holds the ``'train_loss'`` and
        ``'val_loss'`` lists. If a checkpoint was written during this call,
        the model carries those weights.
    """
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    # The scheduler's `verbose=` argument is deprecated in recent PyTorch, so
    # learning-rate changes are reported manually after each step instead.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

    history = {
        'train_loss': [],
        'val_loss': []
    }

    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        train_loss = train_one_epoch(train_loader, model, loss_fn, optimizer, device)
        val_loss = evaluate_model(val_loader, model, loss_fn, device)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)

        print(f"Epoch {epoch+1} results: Train loss: {train_loss:.6f}, Val loss: {val_loss:.6f}")

        previous_lrs = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_loss)
        for group_id, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
            if group['lr'] != old_lr:
                print(f"Reducing learning rate of group {group_id} from {old_lr:.4e} to {group['lr']:.4e}.")

        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print("Early stopping triggered")
            break

    # Reload the checkpoint only if this run actually wrote one; otherwise a
    # missing file would raise, or a stale file from an earlier run would be
    # loaded silently. The path comes from the early stopper.
    if early_stopping.best_score is not None:
        model.load_state_dict(torch.load(early_stopping.checkpoint_path, map_location=device))
    plot_training_history(history)
    return model, history

if __name__ == "__main__":
    # Use the configured patience, falling back to 7 when the config has none
    patience = config['training'].get('patience', 7)

    # Fine-tune the model, then keep the returned model and loss history
    trained_model, training_history = train_model(
        train_loader,
        val_loader,
        model,
        optimizer,
        loss_fn,
        device,
        num_epochs=epochs,
        patience=patience,
    )

    # Score the returned model once on the held-out test loader
    test_loss = evaluate_model(test_loader, trained_model, loss_fn, device)
    print(f"Final test loss: {test_loss:.6f}")

    print("Training completed!")


Epoch 1/50




loss: 0.214218  [    0/15909]
loss: 0.134593  [  320/15909]
loss: 0.113847  [  640/15909]
loss: 0.100160  [  960/15909]
loss: 0.092996  [ 1280/15909]
loss: 0.105860  [ 1600/15909]
loss: 0.081091  [ 1920/15909]
loss: 0.084028  [ 2240/15909]
loss: 0.077439  [ 2560/15909]
loss: 0.076635  [ 2880/15909]
loss: 0.085372  [ 3200/15909]
loss: 0.074468  [ 3520/15909]
loss: 0.077225  [ 3840/15909]
loss: 0.079956  [ 4160/15909]
loss: 0.075479  [ 4480/15909]
loss: 0.084582  [ 4800/15909]
loss: 0.081816  [ 5120/15909]
loss: 0.074436  [ 5440/15909]
loss: 0.066858  [ 5760/15909]
loss: 0.083428  [ 6080/15909]
loss: 0.081881  [ 6400/15909]
loss: 0.067153  [ 6720/15909]
loss: 0.069370  [ 7040/15909]
loss: 0.074499  [ 7360/15909]
loss: 0.071951  [ 7680/15909]
loss: 0.077865  [ 8000/15909]
loss: 0.071058  [ 8320/15909]
loss: 0.074053  [ 8640/15909]
loss: 0.068093  [ 8960/15909]
loss: 0.062439  [ 9280/15909]
loss: 0.085430  [ 9600/15909]
loss: 0.072582  [ 9920/15909]
loss: 0.061198  [10240/15909]
loss: 0.07