In [34]:
#!/usr/bin/env python
# coding: utf-8

"""
Knowledge Distillation Training for XiaoNet
Teacher: PhaseNet (from STEAD)
Student: XiaoNet (v2, v3, v4, or v5)
Dataset: OKLA regional seismic data
"""

# Standard library
import os
import sys
import json
import random
from pathlib import Path

# Scientific computing
import numpy as np
import pandas as pd
from scipy import signal

# Deep learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau, CosineAnnealingLR

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Seismology & SeisBench
import obspy
import seisbench.data as sbd
import seisbench.generate as sbg
import seisbench.models as sbm
from seisbench.util import worker_seeding

# Progress bars
from tqdm.notebook import tqdm

# Add parent directory to path for importing local modules
sys.path.append(str(Path.cwd().parent))

# XiaoNet modules
from models.xn_xiao_net_v2 import XiaoNet as XiaoNetV2
from models.xn_xiao_net_v3 import XiaoNet as XiaoNetV3
from models.xn_xiao_net_v4 import XiaoNetFast as XiaoNetV4
from models.xn_xiao_net_v5 import XiaoNetEdge as XiaoNetV5
from loss.xn_distillation_loss import DistillationLoss
from xn_utils import set_seed, setup_device
from xn_early_stopping import EarlyStopping

print("✓ All packages loaded successfully!")

✓ All packages loaded successfully!


In [35]:
# Set random seed for reproducibility
SEED = 0
set_seed(SEED)

# Set device
device = setup_device('cuda')
print(f"Using device: {device}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")
    print(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Using device: cpu
CUDA available: False


In [36]:
# Load configuration from config.json
config_path = Path.cwd().parent / "config.json"

if not config_path.exists():
    raise FileNotFoundError(f"Config file not found: {config_path}")

with open(config_path, "r") as f:
    config = json.load(f)

print(f"Loaded configuration from: {config_path}")
print(json.dumps(config, indent=2))

Loaded configuration from: /Users/hongyuxiao/Hongyu_File/xiao_net/config.json
{
  "peak_detection": {
    "sampling_rate": 100,
    "height": 0.5,
    "distance": 100
  },
  "data": {
    "dataset_name": "OKLA_1Mil_120s_Ver_3",
    "sampling_rate": 100,
    "window_len": 3001,
    "samples_before": 3000,
    "windowlen_large": 6000,
    "sample_fraction": 0.1
  },
  "data_filter": {
    "min_magnitude": 1.0,
    "max_magnitude": 2.0
  },
  "training": {
    "batch_size": 64,
    "num_workers": 4,
    "learning_rate": 0.01,
    "epochs": 50,
    "patience": 5,
    "loss_weights": [
      0.01,
      0.4,
      0.59
    ],
    "optimization": {
      "mixed_precision": true,
      "gradient_accumulation_steps": 1,
      "pin_memory": true,
      "prefetch_factor": 2,
      "persistent_workers": true
    }
  },
  "device": {
    "use_cuda": true,
    "device_id": 0
  }
}


In [37]:
# Load PhaseNet teacher model (pretrained on STEAD)
print("Available PhaseNet pretrained models:")
sbm.PhaseNet.list_pretrained()

Available PhaseNet pretrained models:


['diting',
 'ethz',
 'geofon',
 'instance',
 'iquique',
 'jma',
 'jma_wc',
 'lendb',
 'neic',
 'obs',
 'original',
 'phasenet_sn',
 'pisdl',
 'scedc',
 'stead',
 'volpick']

In [38]:
print("\nLoading PhaseNet teacher model...")
model = sbm.PhaseNet.from_pretrained("stead")
model.to(device)
model.eval()  # Set to evaluation mode for teacher

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n✓ PhaseNet teacher loaded successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Model on device: {next(model.parameters()).device}")


Loading PhaseNet teacher model...

✓ PhaseNet teacher loaded successfully!
Total parameters: 268,443
Trainable parameters: 268,443
Model on device: cpu


In [39]:
# Load OKLA dataset
print("Loading OKLA regional seismic dataset...")
data = sbd.OKLA_1Mil_120s_Ver_3(sampling_rate=100, force=True, component_order="ENZ")

# Optional: Use subset for faster experimentation
sample_fraction = config.get('data', {}).get('sample_fraction', 0.1)
if sample_fraction < 1.0:
    print(f"Sampling {sample_fraction*100}% of data for faster training...")
    # Create a random mask for sampling
    mask = np.random.random(len(data)) < sample_fraction
    data.filter(mask, inplace=True)
    print(f"Sampled dataset size: {len(data):,}")

# Split into train/dev/test
train, dev, test = data.train_dev_test()

print(f"\n✓ Dataset loaded successfully!")
print(f"Training samples: {len(train):,}")
print(f"Validation samples: {len(dev):,}")
print(f"Test samples: {len(test):,}")
print(f"Total samples: {len(data):,}")

Loading OKLA regional seismic dataset...
Sampling 10.0% of data for faster training...
Sampled dataset size: 113,884

✓ Dataset loaded successfully!
Training samples: 79,670
Validation samples: 17,090
Test samples: 17,124
Total samples: 113,884


In [40]:
# Magnitude filtering (with defaults)
min_magnitude = config.get('data_filter', {}).get('min_magnitude', 1.0)
max_magnitude = config.get('data_filter', {}).get('max_magnitude', 2.0)

print(f"Applying magnitude filters: {min_magnitude} < M < {max_magnitude}")

try:
    # Filter events with magnitude above the minimum
    print(f"✓ [Data Filter]: Start - magnitude > {min_magnitude}")
    mask = data.metadata["source_magnitude"] > min_magnitude
    data.filter(mask, inplace=True)
    print(f"✓ [Data Filter]: Applied - magnitude > {min_magnitude}, remaining samples: {len(data):,}")
except Exception as exc:
    print("✗ [Data Filter]: Error - Failed to apply minimum magnitude filter.")
    print(f"  Details: {exc}")
    raise

try:
    # Filter events with magnitude below the maximum
    print(f"✓ [Data Filter]: Start - magnitude < {max_magnitude}")
    mask = data.metadata["source_magnitude"] < max_magnitude
    data.filter(mask, inplace=True)
    print(f"✓ [Data Filter]: Applied - magnitude < {max_magnitude}, remaining samples: {len(data):,}")
except Exception as exc:
    print("✗ [Data Filter]: Error - Failed to apply maximum magnitude filter.")
    print(f"  Details: {exc}")
    raise

print(f"\n✓ Magnitude filtering complete: {len(data):,} traces in range [{min_magnitude}, {max_magnitude}]")

Applying magnitude filters: 1.0 < M < 2.0
✓ [Data Filter]: Start - magnitude > 1.0
✓ [Data Filter]: Applied - magnitude > 1.0, remaining samples: 108,944
✓ [Data Filter]: Start - magnitude < 2.0
✓ [Data Filter]: Applied - magnitude < 2.0, remaining samples: 36,880

✓ Magnitude filtering complete: 36,880 traces in range [1.0, 2.0]


In [41]:
# Dataset summary for training
print("\n" + "=" * 60)
print("DATASET SUMMARY")
print("=" * 60)

# Core sizes
print(f"Total dataset size: {len(data):,}")
print(f"Train size: {len(train):,}")
print(f"Validation size: {len(dev):,}")
print(f"Test size: {len(test):,}")

# Sampling configuration
sampling_rate = config.get('data', {}).get('sampling_rate', 'unknown')
window_len = config.get('data', {}).get('window_len', 'unknown')
print(f"Sampling rate: {sampling_rate} Hz")
print(f"Window length: {window_len} samples")

# Metadata summary (if available)
if hasattr(data, 'metadata') and data.metadata is not None:
    if 'source_magnitude' in data.metadata:
        mags = data.metadata['source_magnitude']
        print(f"Magnitude stats: min={mags.min():.2f}, max={mags.max():.2f}, mean={mags.mean():.2f}")
    print(f"Metadata columns: {list(data.metadata.columns)}")

print("=" * 60)


DATASET SUMMARY
Total dataset size: 36,880
Train size: 79,670
Validation size: 17,090
Test size: 17,124
Sampling rate: 100 Hz
Window length: 3001 samples
Magnitude stats: min=1.00, max=2.00, mean=1.53
Metadata columns: ['index', 'station_network_code', 'station_code', 'trace_channel', 'station_latitude_deg', 'station_longitude_deg', 'station_elevation_m', 'trace_p_arrival_sample', 'trace_p_status', 'trace_p_weight', 'path_p_travel_sec', 'trace_s_arrival_sample', 'trace_s_status', 'trace_s_weight', 'source_id', 'source_origin_time', 'source_origin_uncertainty_sec', 'source_latitude_deg', 'source_longitude_deg', 'source_error_sec', 'source_gap_deg', 'source_horizontal_uncertainty_km', 'source_depth_km', 'source_depth_uncertainty_km', 'source_magnitude', 'source_magnitude_type', 'source_magnitude_author', 'source_mechanism_strike_dip_rake', 'source_distance_deg', 'source_distance_km', 'path_back_azimuth_deg', 'trace_snr_db', 'trace_coda_end_sample', 'trace_start_time', 'trace_category', 