In [34]:
#!/usr/bin/env python
# coding: utf-8

"""
Knowledge Distillation Training for XiaoNet
Teacher: PhaseNet (from STEAD)
Student: XiaoNet (v2, v3, v4, or v5)
Dataset: OKLA regional seismic data
"""

# Standard library
import os
import sys
import json
import random
from pathlib import Path

# Scientific computing
import numpy as np
import pandas as pd
from scipy import signal

# Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Seismology & SeisBench
import obspy
import seisbench.data as sbd
import seisbench.generate as sbg
import seisbench.models as sbm
from seisbench.util import worker_seeding

# Progress bars
from tqdm.notebook import tqdm

# Add parent directory to path for importing local modules
sys.path.append(str(Path.cwd().parent))

# XiaoNet modules
from models.xn_xiao_net_v2 import XiaoNet as XiaoNetV2
from models.xn_xiao_net_v3 import XiaoNet as XiaoNetV3
from models.xn_xiao_net_v4 import XiaoNetFast as XiaoNetV4
from models.xn_xiao_net_v5 import XiaoNetEdge as XiaoNetV5
from loss.xn_distillation_loss import DistillationLoss
from xn_utils import set_seed, setup_device
from xn_early_stopping import EarlyStopping

print("✓ All packages loaded successfully!")

✓ All packages loaded successfully!


In [35]:
# Set random seed for reproducibility
SEED = 0
set_seed(SEED)

# Set device
device = setup_device('cuda')
print(f"Using device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Using device: cpu
CUDA available: False


In [36]:
# Load configuration from config.json
config_path = Path.cwd().parent / "config.json"

if not config_path.exists():
    raise FileNotFoundError(f"Config file not found: {config_path}")

with open(config_path, "r") as f:
    config = json.load(f)

print(f"Loaded configuration from: {config_path}")
print(json.dumps(config, indent=2))

Loaded configuration from: /Users/hongyuxiao/Hongyu_File/xiao_net/config.json
{
  "peak_detection": {
    "sampling_rate": 100,
    "height": 0.5,
    "distance": 100
  },
  "data": {
    "dataset_name": "OKLA_1Mil_120s_Ver_3",
    "sampling_rate": 100,
    "window_len": 3001,
    "samples_before": 3000,
    "windowlen_large": 6000,
    "sample_fraction": 0.1
  },
  "data_filter": {
    "min_magnitude": 1.0,
    "max_magnitude": 2.0
  },
  "training": {
    "batch_size": 64,
    "num_workers": 4,
    "learning_rate": 0.01,
    "epochs": 50,
    "patience": 5,
    "loss_weights": [
      0.01,
      0.4,
      0.59
    ],
    "optimization": {
      "mixed_precision": true,
      "gradient_accumulation_steps": 1,
      "pin_memory": true,
      "prefetch_factor": 2,
      "persistent_workers": true
    }
  },
  "device": {
    "use_cuda": true,
    "device_id": 0
  }
}


In [37]:
# Load PhaseNet teacher model (pretrained on STEAD)
print("Available PhaseNet pretrained models:")
sbm.PhaseNet.list_pretrained()

Available PhaseNet pretrained models:


['diting',
 'ethz',
 'geofon',
 'instance',
 'iquique',
 'jma',
 'jma_wc',
 'lendb',
 'neic',
 'obs',
 'original',
 'phasenet_sn',
 'pisdl',
 'scedc',
 'stead',
 'volpick']

In [38]:
print("\nLoading PhaseNet teacher model...")
model = sbm.PhaseNet.from_pretrained("stead")
model.to(device)
model.eval()  # Set to evaluation mode for teacher

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n✓ PhaseNet teacher loaded successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model on device: {next(model.parameters()).device}")


Loading PhaseNet teacher model...

✓ PhaseNet teacher loaded successfully!
Total parameters: 268,443
Trainable parameters: 268,443
Model on device: cpu


In [39]:
# Load OKLA dataset
print("Loading OKLA regional seismic dataset...")
data = sbd.OKLA_1Mil_120s_Ver_3(sampling_rate=100, force=True, component_order="ENZ")

# Optional: Use subset for faster experimentation
sample_fraction = config.get('data', {}).get('sample_fraction', 0.1)
if sample_fraction < 1.0:
    print(f"Sampling {sample_fraction*100}% of data for faster training...")
    # Create a random mask for sampling
    mask = np.random.random(len(data)) < sample_fraction
    data.filter(mask, inplace=True)
    print(f"Sampled dataset size: {len(data):,}")

# Split into train/dev/test
train, dev, test = data.train_dev_test()

print(f"\n✓ Dataset loaded successfully!")
print(f"Training samples: {len(train):,}")
print(f"Validation samples: {len(dev):,}")
print(f"Test samples: {len(test):,}")
print(f"Total samples: {len(data):,}")

Loading OKLA regional seismic dataset...
Sampling 10.0% of data for faster training...
Sampled dataset size: 113,884

✓ Dataset loaded successfully!
Training samples: 79,670
Validation samples: 17,090
Test samples: 17,124
Total samples: 113,884


In [40]:
# Magnitude filtering (with defaults)
min_magnitude = config.get('data_filter', {}).get('min_magnitude', 1.0)
max_magnitude = config.get('data_filter', {}).get('max_magnitude', 2.0)

print(f"Applying magnitude filters: {min_magnitude} < M < {max_magnitude}")

try:
    # Filter events with magnitude above the minimum
    print(f"✓ [Data Filter]: Start - magnitude > {min_magnitude}")
    mask = data.metadata["source_magnitude"] > min_magnitude
    data.filter(mask, inplace=True)
    print(f"✓ [Data Filter]: Applied - magnitude > {min_magnitude}, remaining samples: {len(data):,}")
except Exception as exc:
    print("✗ [Data Filter]: Error - Failed to apply minimum magnitude filter.")
    print(f"  Details: {exc}")
    raise

try:
    # Filter events with magnitude below the maximum
    print(f"✓ [Data Filter]: Start - magnitude < {max_magnitude}")
    mask = data.metadata["source_magnitude"] < max_magnitude
    data.filter(mask, inplace=True)
    print(f"✓ [Data Filter]: Applied - magnitude < {max_magnitude}, remaining samples: {len(data):,}")
except Exception as exc:
    print("✗ [Data Filter]: Error - Failed to apply maximum magnitude filter.")
    print(f"  Details: {exc}")
    raise

print(f"\n✓ Magnitude filtering complete: {len(data):,} traces in range [{min_magnitude}, {max_magnitude}]")

Applying magnitude filters: 1.0 < M < 2.0
✓ [Data Filter]: Start - magnitude > 1.0
✓ [Data Filter]: Applied - magnitude > 1.0, remaining samples: 108,944
✓ [Data Filter]: Start - magnitude < 2.0
✓ [Data Filter]: Applied - magnitude < 2.0, remaining samples: 36,880

✓ Magnitude filtering complete: 36,880 traces in range [1.0, 2.0]


In [41]:
# Dataset summary for training
print("\n" + "=" * 60)
print("DATASET SUMMARY")
print("=" * 60)

# Core sizes
print(f"Total dataset size: {len(data):,}")
print(f"Train size: {len(train):,}")
print(f"Validation size: {len(dev):,}")
print(f"Test size: {len(test):,}")

# Sampling configuration
sampling_rate = config.get('data', {}).get('sampling_rate', 'unknown')
window_len = config.get('data', {}).get('window_len', 'unknown')
print(f"Sampling rate: {sampling_rate} Hz")
print(f"Window length: {window_len} samples")

# Metadata summary (if available)
if hasattr(data, 'metadata') and data.metadata is not None:
    if 'source_magnitude' in data.metadata:
        mags = data.metadata['source_magnitude']
        print(f"Magnitude stats: min={mags.min():.2f}, max={mags.max():.2f}, mean={mags.mean():.2f}")
    print(f"Metadata columns: {list(data.metadata.columns)}")

print("=" * 60)


DATASET SUMMARY
Total dataset size: 36,880
Train size: 79,670
Validation size: 17,090
Test size: 17,124
Sampling rate: 100 Hz
Window length: 3001 samples
Magnitude stats: min=1.00, max=2.00, mean=1.53
Metadata columns: ['index', 'station_network_code', 'station_code', 'trace_channel', 'station_latitude_deg', 'station_longitude_deg', 'station_elevation_m', 'trace_p_arrival_sample', 'trace_p_status', 'trace_p_weight', 'path_p_travel_sec', 'trace_s_arrival_sample', 'trace_s_status', 'trace_s_weight', 'source_id', 'source_origin_time', 'source_origin_uncertainty_sec', 'source_latitude_deg', 'source_longitude_deg', 'source_error_sec', 'source_gap_deg', 'source_horizontal_uncertainty_km', 'source_depth_km', 'source_depth_uncertainty_km', 'source_magnitude', 'source_magnitude_type', 'source_magnitude_author', 'source_mechanism_strike_dip_rake', 'source_distance_deg', 'source_distance_km', 'path_back_azimuth_deg', 'trace_snr_db', 'trace_coda_end_sample', 'trace_start_time', 'trace_category', 

In [44]:
# Split data into train/dev/test after filtering
train, dev, test = data.train_dev_test()

print("\n✓ Dataset split after filtering")
print(f"Train size: {len(train):,}")
print(f"Validation size: {len(dev):,}")
print(f"Test size: {len(test):,}")

# Split ratios
n_total = len(train) + len(dev) + len(test)
if n_total > 0:
    print(f"Split ratios: train={len(train)/n_total:.2%}, dev={len(dev)/n_total:.2%}, test={len(test)/n_total:.2%}")


✓ Dataset split after filtering
Train size: 25,901
Validation size: 5,518
Test size: 5,461
Split ratios: train=70.23%, dev=14.96%, test=14.81%


In [51]:
# Dataset objects (compact summary)
print("\n" + "=" * 60)
print("DATASET OBJECTS")
print("=" * 60)
print(f"Train dataset: {train}")
print(f"Dev dataset:   {dev}")
print(f"Test dataset:  {test}")
print("=" * 60)


DATASET OBJECTS
Train dataset: OKLA_1Mil_120s_Ver_3 - 25901 traces
Dev dataset:   OKLA_1Mil_120s_Ver_3 - 5518 traces
Test dataset:  OKLA_1Mil_120s_Ver_3 - 5461 traces


In [46]:
# Set up data augmentation

phase_dict = {
    "trace_p_arrival_sample": "P",
    "trace_pP_arrival_sample": "P",
    "trace_P_arrival_sample": "P",
    "trace_P1_arrival_sample": "P",
    "trace_Pg_arrival_sample": "P",
    "trace_Pn_arrival_sample": "P",
    "trace_PmP_arrival_sample": "P",
    "trace_pwP_arrival_sample": "P",
    "trace_pwPm_arrival_sample": "P",
    "trace_s_arrival_sample": "S",
    "trace_S_arrival_sample": "S",
    "trace_S1_arrival_sample": "S",
    "trace_Sg_arrival_sample": "S",
    "trace_SmS_arrival_sample": "S",
    "trace_Sn_arrival_sample": "S",
}

In [47]:
# Create the data generators for training and validation
train_generator = sbg.GenericGenerator(train)
dev_generator = sbg.GenericGenerator(dev)
test_generator = sbg.GenericGenerator(test)

In [48]:
# Define phase lists for labeling
p_phases = [key for key, val in phase_dict.items() if val == "P"]
s_phases = [key for key, val in phase_dict.items() if val == "S"]

train_generator = sbg.GenericGenerator(train)
dev_generator = sbg.GenericGenerator(dev)
test_generator = sbg.GenericGenerator(test)

augmentations = [
    sbg.WindowAroundSample(list(phase_dict.keys()), samples_before=3000, windowlen=6000, selection="random", strategy="variable"),
    sbg.RandomWindow(windowlen=3001, strategy="pad"),
    sbg.Normalize(demean_axis=-1, detrend_axis=-1, amp_norm_axis=-1, amp_norm_type="peak"),
    sbg.ChangeDtype(np.float32),
    sbg.ProbabilisticLabeller(sigma=30, dim=0),
]

train_generator.add_augmentations(augmentations)
dev_generator.add_augmentations(augmentations)
test_generator.add_augmentations(augmentations)


In [50]:
# Parameters for peak detection (with defaults)
sampling_rate = config.get('peak_detection', {}).get('sampling_rate', 100)
height = config.get('peak_detection', {}).get('height', 0.5)
distance = config.get('peak_detection', {}).get('distance', 100)

print("\n" + "=" * 60)
print("PEAK DETECTION SETTINGS")
print("=" * 60)
print(f"Sampling rate: {sampling_rate} Hz")
print(f"Height threshold: {height}")
print(f"Minimum peak distance: {distance} samples")
print("=" * 60)


PEAK DETECTION SETTINGS
Sampling rate: 100 Hz
Height threshold: 0.5
Minimum peak distance: 100 samples


In [53]:
# Parameters for peak detection
batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']
print(f"✓ [DataLoader]: batch_size={batch_size}, num_workers={num_workers}")

✓ [DataLoader]: batch_size=64, num_workers=4


In [62]:
print("\n" + "=" * 60)
print("TRAINING CONFIGURATION SUMMARY")
print("=" * 60)

# Dataset info
print("[Dataset]")
print(f"  Total samples:      {len(data):,}")
print(f"  Train/Validation/Test:     {len(train):,} / {len(dev):,} / {len(test):,}")
print(f"  Sample fraction:    {sample_fraction*100:.1f}%")

# Device
print("\n[Device]")
print(f"  Device:             {device}")

# Training hyperparameters
print("\n[Training]")
print(f"  Batch size:         {batch_size}")
print(f"  Num workers:        {num_workers}")
print(f"  Learning rate:      {config['training']['learning_rate']}")
print(f"  Epochs:             {config['training']['epochs']}")
print(f"  Patience:           {config['training']['patience']}")

# Peak detection
print("\n[Peak Detection]")
print(f"  Sampling rate:      {sampling_rate} Hz")
print(f"  Height threshold:   {height}")
print(f"  Min peak distance:  {distance} samples")

print("=" * 60)
print("Ready to start training!")
print("=" * 60)


TRAINING CONFIGURATION SUMMARY
[Dataset]
  Total samples:      36,880
  Train/Validation/Test:     25,901 / 5,518 / 5,461
  Sample fraction:    10.0%

[Device]
  Device:             cpu

[Training]
  Batch size:         64
  Num workers:        4
  Learning rate:      0.01
  Epochs:             50
  Patience:           5

[Peak Detection]
  Sampling rate:      100 Hz
  Height threshold:   0.5
  Min peak distance:  100 samples
Ready to start training!


In [55]:
# Load the data for machine learning

train_loader = DataLoader(train_generator,batch_size=batch_size, shuffle=True, num_workers=num_workers, worker_init_fn=worker_seeding,pin_memory=True,prefetch_factor=4,persistent_workers=True)
test_loader = DataLoader(test_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers, worker_init_fn=worker_seeding,pin_memory=True,prefetch_factor=4,persistent_workers=True)
val_loader = DataLoader(dev_generator, batch_size=batch_size, shuffle=False, num_workers=num_workers, worker_init_fn=worker_seeding,pin_memory=True,prefetch_factor=4,persistent_workers=True)


In [56]:
# Define loss function
def loss_fn(y_pred, y_true, eps=1e-8):
    h = y_true * torch.log(y_pred + eps)
    h = h.mean(-1).sum(-1)
    h = h.mean()
    return -h

In [63]:
# Learning rate and number of epochs
learning_rate = config['training']['learning_rate']
epochs = config['training']['epochs']

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

print("\n" + "=" * 60)
print("OPTIMIZER SETTINGS")
print("=" * 60)
print(f"Optimizer: {optimizer.__class__.__name__}")
print("=" * 60)


OPTIMIZER SETTINGS
Optimizer: Adam


In [64]:
# Early stopping and checkpoint setup
checkpoint_dir = Path.cwd().parent / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

best_model_path = checkpoint_dir / "best_model.pth"
final_model_path = checkpoint_dir / "final_model.pth"
history_path = checkpoint_dir / "loss_history.json"

patience = config.get('training', {}).get('patience', 5)
min_delta = config.get('training', {}).get('min_delta', 0.0)

early_stopping = EarlyStopping(
    patience=patience,
    min_delta=min_delta,
    checkpoint_dir=checkpoint_dir,
    verbose=True,
)

# Loss history container
history = {
    "train_loss": [],
    "val_loss": []
}

# Helper functions for saving
def save_loss_history(history_dict, path):
    with open(path, "w") as f:
        json.dump(history_dict, f, indent=2)
    print(f"✓ Loss history saved to {path}")


def save_final_model(model, path):
    torch.save({
        "model_state_dict": model.state_dict(),
        "config": config
    }, path)
    print(f"✓ Final model saved to {path}")

print("\n" + "=" * 60)
print("EARLY STOPPING & CHECKPOINTS")
print("=" * 60)
print(f"Checkpoint dir: {checkpoint_dir}")
print(f"Best model:     {best_model_path}")
print(f"Final model:    {final_model_path}")
print(f"History file:   {history_path}")
print(f"Patience:       {patience}")
print(f"Min delta:      {min_delta}")
print("=" * 60)


EARLY STOPPING & CHECKPOINTS
Checkpoint dir: /Users/hongyuxiao/Hongyu_File/xiao_net/checkpoints
Best model:     /Users/hongyuxiao/Hongyu_File/xiao_net/checkpoints/best_model.pth
Final model:    /Users/hongyuxiao/Hongyu_File/xiao_net/checkpoints/final_model.pth
History file:   /Users/hongyuxiao/Hongyu_File/xiao_net/checkpoints/loss_history.json
Patience:       5
Min delta:      0.0


In [66]:
# Training loop with early stopping
print("\n" + "=" * 60)
print("TRAINING")
print("=" * 60)

for epoch in range(epochs):
    # Training phase
    model.train()
    train_loss = 0.0
    dataset_size = len(train_loader.dataset)
    
    for batch_id, batch in enumerate(train_loader):
        # Forward pass
        pred = model(batch["X"].to(device))
        loss = loss_fn(pred, batch["y"].to(device))
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # Progress tracking
        if batch_id % 5 == 0:
            current = batch_id * len(batch["X"])
            print(f"  loss: {loss.item():>7f}  [{current:>5d}/{dataset_size:>5d}]")
        
        train_loss += loss.item()
    
    avg_train_loss = train_loss / len(train_loader)
    history["train_loss"].append(avg_train_loss)
    
    # Validation phase
    model.eval()
    val_loss = 0.0
    
    with torch.no_grad():
        for batch in val_loader:
            pred = model(batch["X"].to(device))
            val_loss += loss_fn(pred, batch["y"].to(device)).item()
    
    avg_val_loss = val_loss / len(val_loader)
    history["val_loss"].append(avg_val_loss)
    
    # Print epoch summary
    print(f"\nEpoch {epoch+1}/{epochs} Summary:")
    print(f"  Train Loss: {avg_train_loss:.4f}")
    print(f"  Val Loss:   {avg_val_loss:.4f}")
    
    # Check early stopping
    early_stopping(avg_val_loss, model)
    
    if early_stopping.early_stop:
        print(f"\nEarly stopping triggered at epoch {epoch+1}")
        break
    
    print("-" * 60)

# Save final model and history
save_final_model(model, final_model_path)
save_loss_history(history, history_path)

print("\n" + "=" * 60)
print("TRAINING COMPLETE")
print("=" * 60)
print(f"Best model saved to: {best_model_path}")
print(f"Final model saved to: {final_model_path}")
print(f"Loss history saved to: {history_path}")
print("=" * 60)


TRAINING


RuntimeError: DataLoader worker (pid(s) 8578, 8586, 8590, 8595) exited unexpectedly