# XiaoNet Training Pipeline

**Knowledge Distillation from PhaseNet to XiaoNet**

This notebook demonstrates the complete training and evaluation pipeline for XiaoNet, a lightweight student model trained through knowledge distillation from the PhaseNet teacher model.

---

## Overview

- **Teacher Model**: PhaseNet (268K parameters, SeisBench weights pretrained on STEAD)
- **Student Models**: 
  - XiaoNet V1 (164K params, trim-pad approach)
  - XiaoNet V3 (50K params, speed-optimized)
- **Dataset**: OKLA_1Mil_120s_Ver_3
- **Task**: Seismic phase picking (P-wave, S-wave, Noise)
- **Training Strategy**: Knowledge distillation with frozen teacher

---

## 1. Environment Setup

In [35]:
# Standard library imports
import sys
import json
from pathlib import Path
import time
from typing import Dict, Tuple

# Third-party imports
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# SeisBench imports
import seisbench
import seisbench.data as sbd
import seisbench.generate as sbg
import seisbench.models as sbm

# Scipy for peak detection
from scipy.signal import find_peaks

print(f"PyTorch version: {torch.__version__}")
print(f"SeisBench version: {seisbench.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

PyTorch version: 2.2.2
SeisBench version: 0.7.0
CUDA available: False


In [36]:
# Add project root to path (works from any location)
from pathlib import Path

# Determine the project root from the current working directory
# (__file__ is not defined inside Jupyter, so it cannot be used here)
cwd = Path.cwd()
if cwd.name == 'xiao_net':
    project_root = cwd
elif (cwd.parent / 'xiao_net').exists():
    project_root = cwd.parent / 'xiao_net'
else:
    # Fallback: assume the notebook sits one directory below the project root
    project_root = cwd.parent

project_root_str = str(project_root)
if project_root_str not in sys.path:
    sys.path.insert(0, project_root_str)

print(f"Project root: {project_root_str}")
print(f"Current working directory: {cwd}")
print(f"Modules available: {list(Path(project_root_str).glob('*'))[:5]}")

Project root: /Users/hongyuxiao/Hongyu_File/xiao_net
Current working directory: /Users/hongyuxiao/Hongyu_File/xiao_net
Modules available: [PosixPath('/Users/hongyuxiao/Hongyu_File/xiao_net/Xiao_Net_Model_Train.ipynb'), PosixPath('/Users/hongyuxiao/Hongyu_File/xiao_net/xn_utils.py'), PosixPath('/Users/hongyuxiao/Hongyu_File/xiao_net/ALL_ERRORS_FIXED.md'), PosixPath('/Users/hongyuxiao/Hongyu_File/xiao_net/.DS_Store'), PosixPath('/Users/hongyuxiao/Hongyu_File/xiao_net/LICENSE')]
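This cell and the one that follows both derive `project_root` from the working directory, with different fallbacks (and the second always produces a `str`, while this one produces a `Path` in most branches). A sketch of a single, hypothetical helper that instead walks up from the working directory until it finds a marker file. Using `config.json` as the marker is an assumption, based on the configuration cell reading `config.json` from the project root.

```python
import sys
from pathlib import Path
from typing import Optional


def find_project_root(start: Optional[Path] = None, marker: str = "config.json") -> Path:
    """Return the nearest directory at or above `start` that contains `marker`.

    Hypothetical helper: `marker` defaults to config.json because the notebook
    loads that file from the project root. Raises FileNotFoundError if no
    ancestor directory contains the marker.
    """
    start = (Path.cwd() if start is None else Path(start)).resolve()
    for candidate in (start, *start.parents):
        if (candidate / marker).is_file():
            return candidate
    raise FileNotFoundError(f"No '{marker}' found in {start} or any parent directory")


def add_to_sys_path(path: Path) -> None:
    """Prepend `path` to sys.path once, so project modules become importable."""
    path_str = str(path)
    if path_str not in sys.path:
        sys.path.insert(0, path_str)
```

Usage would be `project_root = find_project_root()` followed by `add_to_sys_path(project_root)`, which gives one `Path` that every later cell can use with the `/` operator.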


In [None]:
# Re-import path setup before importing modules (matching TL_PNet approach)
import sys
import os
import importlib
from pathlib import Path

# Add project root to path (works from any location)
# Use cwd-based detection since __file__ doesn't work reliably in Jupyter
cwd = Path.cwd()
if cwd.name == 'xiao_net':
    project_root = str(cwd)
elif (cwd.parent / 'xiao_net').exists():
    project_root = str(cwd.parent / 'xiao_net')
else:
    # Fallback: go up from current directory
    project_root = str(cwd.parents[1] / 'xiao_net') if len(cwd.parents) > 1 else str(cwd)

if project_root not in sys.path:
    sys.path.insert(0, project_root)

print(f"Project root: {project_root}")
print(f"Python path updated")

# Now import project modules (only what exists and is needed)
from models.xn_xiao_net import XiaoNet
from models.xn_xiao_net_v3 import XiaoNetV3
from loss.xn_distillation_loss import DistillationLoss
from evaluation.xn_evaluate import evaluate_model, compute_metrics, compute_picking_accuracy
from xn_early_stopping import EarlyStopping

print("✓ All project modules imported successfully")

Project root: /Users/hongyuxiao/Hongyu_File/xiao_net
Python path updated
✓ All project modules imported successfully


## 2. Configuration

In [40]:
# Set random seed FIRST for reproducibility (matching TL_PNet approach with seed=0)
import random
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

print("✓ Random seed set to 0 for reproducibility")
print("✓ CuDNN deterministic mode enabled")

# Load configuration (matching TL_PNet error handling)
# Convert project_root to Path if it's a string
project_root_path = Path(project_root) if isinstance(project_root, str) else project_root
config_path = project_root_path / 'config.json'
try:
    with open(config_path, 'r') as f:
        config = json.load(f)
    print("✓ Configuration loaded successfully!")
except FileNotFoundError:
    print(f"✗ Error: config.json not found at {config_path}")
    raise
except json.JSONDecodeError as e:
    print(f"✗ Error: Invalid JSON in config.json: {e}")
    raise

# Training hyperparameters
batch_size = config['training']['batch_size']
num_workers = config['training']['num_workers']
learning_rate = config['training']['learning_rate']
epochs = config['training']['epochs']
patience = config['training']['patience']

# Device setup (respecting config['device'] settings - matching TL_PNet)
device = torch.device(
    f"cuda:{config['device']['device_id']}"
    if torch.cuda.is_available() and config['device'].get('use_cuda', True)
    else "cpu"
)

print(f"\nConfiguration loaded:")
print(f"  Batch size: {batch_size}")
print(f"  Learning rate: {learning_rate}")
print(f"  Epochs: {epochs}")
print(f"  Patience: {patience}")
print(f"  Device: {device}")
print(f"  CUDA available: {torch.cuda.is_available()}")

✓ Random seed set to 0 for reproducibility
✓ CuDNN deterministic mode enabled
✓ Configuration loaded successfully!

Configuration loaded:
  Batch size: 64
  Learning rate: 0.01
  Epochs: 50
  Patience: 5
  Device: cpu
  CUDA available: False
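The cell above reads `config['training'][...]` and `config['device'][...]` directly, so a missing key surfaces as a bare `KeyError` partway through the cell (and `device.device_id` is only read when CUDA is available, so a missing value would first show up on a GPU machine). A sketch of a hypothetical up-front check; the required keys are exactly the ones this notebook reads, with `device.use_cuda` treated as optional because it is read via `.get()` with a default of `True`.

```python
from typing import Any, Dict, List

# Keys this notebook reads from config.json (device.use_cuda is optional)
REQUIRED_KEYS = {
    "training": ["batch_size", "num_workers", "learning_rate", "epochs", "patience"],
    "device": ["device_id"],
}


def missing_config_keys(config: Dict[str, Any]) -> List[str]:
    """Return dotted names of required keys that are absent from `config`."""
    missing = []
    for section, keys in REQUIRED_KEYS.items():
        block = config.get(section)
        if not isinstance(block, dict):
            # Whole section absent (or malformed): every key in it is missing
            missing.extend(f"{section}.{key}" for key in keys)
            continue
        missing.extend(f"{section}.{key}" for key in keys if key not in block)
    return missing


def validate_config(config: Dict[str, Any]) -> None:
    """Raise a single, readable KeyError listing every missing key."""
    missing = missing_config_keys(config)
    if missing:
        raise KeyError(f"config.json is missing: {', '.join(missing)}")
```

Calling `validate_config(config)` right after `json.load` reports all missing keys at once instead of one per run.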


## 3. Dataset Loading

Load the OKLA_1Mil_120s_Ver_3 dataset with train/dev/test splits.

In [41]:
# Load dataset
data = sbd.OKLA_1Mil_120s_Ver_3(
    sampling_rate=100,
    force=True, component_order="ENZ"
)

# Split into train/dev/test (returns tuple that needs unpacking)
train, dev, test = data.train_dev_test()

# Use 1% sample for development - create boolean mask first
sample_fraction = 0.01

# Create masks
train_mask = np.random.random(len(train)) < sample_fraction
dev_mask = np.random.random(len(dev)) < sample_fraction
test_mask = np.random.random(len(test)) < sample_fraction

# Apply masks
train = train.filter(train_mask, inplace=False)
dev = dev.filter(dev_mask, inplace=False)
test = test.filter(test_mask, inplace=False)

print(f"Dataset loaded:")
print(f"  Train samples: {len(train)}")
print(f"  Dev samples: {len(dev)}")
print(f"  Test samples: {len(test)}")


Dataset loaded:
  Train samples: 8011
  Dev samples: 1668
  Test samples: 1591
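The 1% masks above draw from NumPy's global random state, so which traces end up in each subset depends on every earlier `np.random` call (the seed cell plus anything run in between). A sketch that gives each split its own seeded `numpy.random.Generator`, so a split's subset is reproducible regardless of cell execution order. The helper name and the seeds 0, 1, 2 in the usage comment are arbitrary choices, and the masks must be applied to the unfiltered splits.

```python
import numpy as np


def subsample_mask(n: int, fraction: float, seed: int) -> np.ndarray:
    """Boolean mask selecting roughly `fraction` of `n` items, reproducibly."""
    if not 0.0 < fraction <= 1.0:
        raise ValueError("fraction must be in (0, 1]")
    rng = np.random.default_rng(seed)  # independent of np.random's global state
    return rng.random(n) < fraction

# Usage with the SeisBench splits returned by data.train_dev_test():
# train = train.filter(subsample_mask(len(train), sample_fraction, seed=0), inplace=False)
# dev = dev.filter(subsample_mask(len(dev), sample_fraction, seed=1), inplace=False)
# test = test.filter(subsample_mask(len(test), sample_fraction, seed=2), inplace=False)
```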


## 4. Data Augmentation Pipeline

Configure augmentation and labeling for seismic waveforms.

In [42]:
# Define phase groups for multi-phase support
# Map phase column names to their labels
phase_dict = {
    "trace_p_arrival_sample": "p",
    "trace_s_arrival_sample": "s"
}

# Create augmentation pipeline
augmentations = [
    # Cut a 6000-sample window with the selected pick 3000 samples in...
    sbg.WindowAroundSample(
        list(phase_dict.keys()), 
        samples_before=3000,
        windowlen=6000,
        selection="random",
        strategy="variable"
    ),
    # ...then take a random 3001-sample crop of it, so the pick lands at a
    # random position in the model input instead of at a fixed sample
    sbg.RandomWindow(
        windowlen=3001,
        strategy="pad"
    ),
    sbg.ProbabilisticLabeller(
        label_columns=phase_dict,
        sigma=30,
        dim=0
    ),
    sbg.ChangeDtype(np.float32),
    sbg.Normalize(demean_axis=-1, amp_norm_axis=-1, amp_norm_type="peak")
]

# Extract dynamic information from augmentations
window_aug = augmentations[1]  # RandomWindow sets the final model input length
windowlen = window_aug.windowlen
sigma = augmentations[2].sigma  # ProbabilisticLabeller
sampling_rate = 100  # Hz

# Calculate durations dynamically
window_duration = windowlen / sampling_rate
sigma_duration = sigma / sampling_rate
num_phases = len(phase_dict)
num_augmentations = len(augmentations)

print("✓ Augmentation pipeline configured")
print(f"  Window length: {windowlen} samples ({window_duration:.2f}s @ {sampling_rate}Hz)")
print(f"  Gaussian label width: σ={sigma} samples ({sigma_duration:.2f}s)")
print(f"  Phase groups ({num_phases}): {list(phase_dict.keys())}")

✓ Augmentation pipeline configured
  Window length: 3001 samples (30.01s @ 100Hz)
  Gaussian label width: σ=30 samples (0.30s)
  Phase groups (2): ['trace_p_arrival_sample', 'trace_s_arrival_sample']
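`ProbabilisticLabeller` turns each pick into a probability curve rather than a single labelled sample; with σ = 30 samples the curve is 0.30 s wide at 100 Hz. A NumPy sketch of the idea, assuming a Gaussian of width σ centred on each arrival and a noise channel equal to the remaining probability, clipped to [0, 1]. It illustrates the concept and is not SeisBench's own implementation, which handles more cases (for example, several picks of one phase per trace).

```python
import numpy as np


def gaussian_phase_labels(n_samples: int, arrivals: dict, sigma: float) -> np.ndarray:
    """Return labels of shape (len(arrivals) + 1, n_samples).

    `arrivals` maps phase name -> arrival sample index (or None if the phase is
    absent), in the channel order to use. The last channel is noise,
    computed as clip(1 - sum of phase curves, 0, 1).
    """
    t = np.arange(n_samples)
    labels = np.zeros((len(arrivals) + 1, n_samples), dtype=np.float32)
    for ch, onset in enumerate(arrivals.values()):
        if onset is not None:
            labels[ch] = np.exp(-0.5 * ((t - onset) / sigma) ** 2)
    labels[-1] = np.clip(1.0 - labels[:-1].sum(axis=0), 0.0, 1.0)
    return labels
```

For example, `gaussian_phase_labels(3001, {"p": 1000, "s": 1800}, sigma=30)` peaks at 1.0 on sample 1000 of the P channel, drops to exp(−0.5) ≈ 0.61 one σ away, and leaves the noise channel near 1 far from both arrivals.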


## 5. Create DataLoaders

In [43]:
# Create generators
train_generator = sbg.GenericGenerator(train)
dev_generator = sbg.GenericGenerator(dev)
test_generator = sbg.GenericGenerator(test)

# Apply augmentations
train_generator.add_augmentations(augmentations)
dev_generator.add_augmentations(augmentations)
test_generator.add_augmentations(augmentations)

# Create DataLoaders (persistent_workers is only valid when num_workers > 0)
train_loader = DataLoader(
    train_generator,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    persistent_workers=num_workers > 0
)

val_loader = DataLoader(
    dev_generator,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    persistent_workers=num_workers > 0
)

test_loader = DataLoader(
    test_generator,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    persistent_workers=num_workers > 0
)

print(f"DataLoaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Val batches: {len(val_loader)}")
print(f"  Test batches: {len(test_loader)}")

DataLoaders created:
  Train batches: 126
  Val batches: 27
  Test batches: 25
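The batch counts follow directly from the sample counts: `DataLoader` keeps the final partial batch by default (`drop_last=False`), so the number of batches is the ceiling of samples divided by batch size. One consequence is that the last test batch holds only 1591 − 24 × 64 = 55 samples.

```python
import math


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches a DataLoader yields for a dataset of num_samples items."""
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)


# Reproduces the counts printed above (batch_size = 64)
for split, n in [("train", 8011), ("dev", 1668), ("test", 1591)]:
    print(split, num_batches(n, 64))  # train 126, dev 27, test 25
```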


## 6. Teacher Model (PhaseNet)

Load pretrained PhaseNet from SeisBench as the teacher model.

In [45]:
# Load teacher model
teacher_model = sbm.PhaseNet.from_pretrained("stead")
teacher_model = teacher_model.to(device)
teacher_model.eval()

# Freeze teacher parameters
for param in teacher_model.parameters():
    param.requires_grad = False

# Count parameters
teacher_total = sum(p.numel() for p in teacher_model.parameters())
teacher_trainable = sum(p.numel() for p in teacher_model.parameters() if p.requires_grad)

print(f"Teacher Model (PhaseNet):")
print(f"  Total parameters: {teacher_total:,}")
print(f"  Trainable parameters: {teacher_trainable:,}")
print(f"  Status: Frozen (used for distillation only)")

Teacher Model (PhaseNet):
  Total parameters: 268,443
  Trainable parameters: 0
  Status: Frozen (used for distillation only)


## 7. Student Models (XiaoNet Family)

Initialize different versions of XiaoNet for comparison.

In [49]:
# XiaoNet V1: Original with trim-pad
student_v1 = XiaoNet(in_channels=3, num_phases=3, base_channels=16).to(device)
v1_params = sum(p.numel() for p in student_v1.parameters())

print(f"XiaoNet V1:")
print(f"  Parameters: {v1_params:,}")
print(f"  Reduction: {(1 - v1_params/teacher_total)*100:.1f}%")

XiaoNet V1:
  Parameters: 164,355
  Reduction: 38.8%


In [51]:
# XiaoNet V3: Speed-optimized with depthwise separable convolutions
student_v3 = XiaoNetV3(in_channels=3, num_phases=3, base_channels=12).to(device)
v3_params = sum(p.numel() for p in student_v3.parameters())

print(f"XiaoNet V3 (Speed-Optimized):")
print(f"  Parameters: {v3_params:,}")
print(f"  Reduction from Teacher: {(1 - v3_params/teacher_total)*100:.1f}%")
print(f"  Reduction from V1: {(1 - v3_params/v1_params)*100:.1f}%")

XiaoNet V3 (Speed-Optimized):
  Parameters: 50,667
  Reduction from Teacher: 81.1%
  Reduction from V1: 69.2%
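Much of V3's parameter saving comes from depthwise separable convolutions. A sketch of the per-layer arithmetic for a 1-D convolution with bias: a standard convolution holds C_in·C_out·k weights, while a depthwise (C_in·k) plus pointwise (C_in·C_out) pair holds C_in·(k + C_out). The kernel size k = 7 below is an arbitrary example, not a value read from XiaoNetV3.

```python
def conv1d_params(c_in: int, c_out: int, k: int) -> int:
    """Weights + bias of a standard Conv1d(c_in, c_out, kernel_size=k)."""
    return c_in * c_out * k + c_out


def depthwise_separable_params(c_in: int, c_out: int, k: int) -> int:
    """Depthwise Conv1d(c_in, c_in, k, groups=c_in) + pointwise Conv1d(c_in, c_out, 1), both with bias."""
    depthwise = c_in * k + c_in
    pointwise = c_in * c_out + c_out
    return depthwise + pointwise


# Example: the encoder's 24 -> 48 channel transition with an assumed kernel size of 7
print(conv1d_params(24, 48, 7))               # 8112
print(depthwise_separable_params(24, 48, 7))  # 1392
```

Channel width matters too: weights scale with C_in·C_out, so narrowing every layer from 16 to 12 base channels keeps only (12/16)² ≈ 56% of the weights in each standard convolution.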


## 8. Model Architecture Visualization

In [52]:
# Visualize XiaoNet V3 architecture
print("=" * 80)
print("XiaoNet V3 Architecture Diagram")
print("=" * 80)
print()
print("     INPUT (3, 3001)")
print("         |")
print("         v [DepthwiseSeparable Conv]")
print("    Encoder L1 (12, 3001) ────────────────────> [skip connection]")
print("         |                                                       |")
print("         v [FastDownsample: stride=2]                           |")
print("    Encoder L2 (24, 1501) ────────────────────> [skip connection]")
print("         |                                                     | |")
print("         v [FastDownsample: stride=2]                         | |")
print("    Encoder L3 (48, 751) ─────────────────────> [skip connection]")
print("         |                                                   | | |")
print("         v [FastDownsample: stride=2]                       | | |")
print("    Bottleneck (96, 376)                                    | | |")
print("         |                                                   | | |")
print("         v [FastUpsample: bilinear + 1x1 conv]              | | |")
print("    Decoder L3 (48, 751) <──── [concat] <────────────────────+ | |")
print("         |                                                     | |")
print("         v [FastUpsample: bilinear + 1x1 conv]                | |")
print("    Decoder L2 (24, 1501) <─── [concat] <──────────────────────+ |")
print("         |                                                       |")
print("         v [FastUpsample: bilinear + 1x1 conv]                  |")
print("    Decoder L1 (12, 3001) <─── [concat] <────────────────────────+")
print("         |")
print("         v [1x1 Conv]")
print("    OUTPUT (3, 3001) [P-wave, S-wave, Noise]")
print()
print("Key Features:")
print("  • Depthwise separable convolutions throughout")
print("  • Bilinear upsampling instead of ConvTranspose")
print("  • 3 encoder/decoder levels (not 5)")
print("  • Base channels: 12 vs 16 in V1 (25% narrower, ~44% fewer weights per standard conv)")
print("  • Simple cropping for size matching (no interpolation overhead)")
print("  • Expected: 2-3x faster than V1, matching teacher speed")
print()
print("=" * 80)

XiaoNet V3 Architecture Diagram

     INPUT (3, 3001)
         |
         v [DepthwiseSeparable Conv]
    Encoder L1 (12, 3001) ────────────────────> [skip connection]
         |                                                       |
         v [FastDownsample: stride=2]                           |
    Encoder L2 (24, 1501) ────────────────────> [skip connection]
         |                                                     | |
         v [FastDownsample: stride=2]                         | |
    Encoder L3 (48, 751) ─────────────────────> [skip connection]
         |                                                   | | |
         v [FastDownsample: stride=2]                       | | |
    Bottleneck (96, 376)                                    | | |
         |                                                   | | |
         v [FastUpsample: bilinear + 1x1 conv]
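The sequence lengths in the diagram (3001 → 1501 → 751 → 376) are what a stride-2 convolution with an odd kernel k and padding k // 2 produces: L_out = ⌊(L + 2·(k // 2) − k) / 2⌋ + 1 = ⌈L / 2⌉. Upsampling by 2 then overshoots on odd lengths (376 → 752, not 751), which is why the decoder has to crop to the skip connection's length. A sketch of that bookkeeping; the kernel sizes and the odd-kernel, half-padding convention are assumptions, not read from XiaoNetV3's code.

```python
def conv1d_out_len(length: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length of torch.nn.Conv1d with dilation 1."""
    return (length + 2 * padding - kernel) // stride + 1


def encoder_lengths(length: int, levels: int, kernel: int = 3) -> list:
    """Lengths after each stride-2 downsampling with padding kernel // 2."""
    lengths = [length]
    for _ in range(levels):
        lengths.append(conv1d_out_len(lengths[-1], kernel, stride=2, padding=kernel // 2))
    return lengths


def decoder_crops(lengths: list) -> list:
    """For each decoder level (deepest first): (upsampled length, skip length, samples to crop)."""
    crops = []
    for deeper, skip in zip(reversed(lengths[1:]), reversed(lengths[:-1])):
        upsampled = deeper * 2
        crops.append((upsampled, skip, upsampled - skip))
    return crops


lengths = encoder_lengths(3001, levels=3)
print(lengths)                 # [3001, 1501, 751, 376]
print(decoder_crops(lengths))  # [(752, 751, 1), (1502, 1501, 1), (3002, 3001, 1)]
```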

## 9. Training Setup

Configure loss function, optimizer, and training loop.

In [55]:
# Select student model to train (V3 recommended)
student_model = student_v3
model_version = "V3"

# Loss function: Knowledge distillation + label supervision
criterion = DistillationLoss(
    alpha=0.5,  # Balance between distillation and label loss
    temperature=4.0  # Soften teacher predictions
)

# Optimizer
optimizer = optim.Adam(
    student_model.parameters(),
    lr=learning_rate
)

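# Learning rate scheduler (`verbose` is omitted: it is deprecated in recent
# PyTorch releases, and the training loop prints the current LR every epoch)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3
)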

# Ensure project_root is a Path for path operations
project_root_path = Path(project_root) if isinstance(project_root, str) else project_root
checkpoint_dir = project_root_path / 'checkpoints'

# Early stopping
early_stopping = EarlyStopping(
    patience=patience,
    verbose=True,
    checkpoint_dir=checkpoint_dir
)

print(f"Training setup complete:")
print(f"  Student model: XiaoNet {model_version}")
print(f"  Loss: Distillation (α=0.5, T=4.0)")
print(f"  Optimizer: Adam (lr={learning_rate})")
print(f"  Scheduler: ReduceLROnPlateau")
print(f"  Early stopping: patience={patience}")
print(f"  Checkpoint dir: {checkpoint_dir}")

Training setup complete:
  Student model: XiaoNet V3
  Loss: Distillation (α=0.5, T=4.0)
  Optimizer: Adam (lr=0.01)
  Scheduler: ReduceLROnPlateau
  Early stopping: patience=5
  Checkpoint dir: /Users/hongyuxiao/Hongyu_File/xiao_net/checkpoints
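`DistillationLoss` comes from the project's `loss.xn_distillation_loss` module, whose code is not shown here. A NumPy sketch of the standard Hinton-style formulation that its `alpha` and `temperature` arguments suggest: α·T²·KL(p_teacher ‖ p_student), computed on temperature-softened distributions, plus (1 − α)·cross-entropy against the labels, averaged over traces and time steps. Treat it as an assumption about what the module computes, not a description of it. One practical detail: SeisBench's PhaseNet returns softmax probabilities by default, and because softmax ignores a constant offset, `log(p)` can stand in for the teacher's logits before dividing by T.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = 1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def distillation_loss(student_logits, teacher_logits, soft_labels,
                      alpha=0.5, temperature=4.0, eps=1e-8):
    """Hinton-style KD loss for outputs shaped (batch, classes, time).

    soft_labels: target probabilities of the same shape (e.g. Gaussian phase
    labels plus a noise channel) for the label term.
    """
    T = temperature
    p_teacher = softmax(teacher_logits / T)
    p_student_T = softmax(student_logits / T)
    # KL(p_teacher || p_student) per (trace, time step), summed over classes
    kl = np.sum(p_teacher * (np.log(p_teacher + eps) - np.log(p_student_T + eps)), axis=1)
    # Cross-entropy of the untempered student distribution against the labels
    ce = -np.sum(soft_labels * np.log(softmax(student_logits) + eps), axis=1)
    # The T**2 factor keeps soft-target gradients on the same scale as the label term
    return alpha * T**2 * kl.mean() + (1 - alpha) * ce.mean()
```

With identical student and teacher outputs the KL term vanishes and only (1 − α) times the label loss remains, which is a quick sanity check for any implementation.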


## 10. Training Loop

Train the student model with knowledge distillation.

In [56]:
def train_epoch(model, teacher, train_loader, criterion, optimizer, device):
    """Train for one epoch."""
    model.train()
    teacher.eval()
    
    total_loss = 0.0
    num_batches = 0
    
    for batch in train_loader:
        # Move data to device
        X = batch['X'].to(device)
        y = batch['y'].to(device)
        
        # Forward pass
        optimizer.zero_grad()
        
        # Student predictions
        student_outputs = model(X)
        
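        # Teacher predictions (no gradient). SeisBench's PhaseNet returns softmax
        # probabilities rather than logits by default, so check that
        # DistillationLoss treats teacher_outputs as probabilities
        with torch.no_grad():
            teacher_outputs = teacher(X)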
        
        # Calculate loss
        loss = criterion(student_outputs, teacher_outputs, y)
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        num_batches += 1
    
    return total_loss / num_batches


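def validate_epoch(model, val_loader, criterion, device):
    """Validate for one epoch with hard-label cross-entropy.

    `criterion` is accepted for symmetry with train_epoch but not used, so the
    validation loss is not on the same scale as the distillation training loss.
    """
    model.eval()
    label_loss = nn.CrossEntropyLoss()  # expects raw (unnormalized) logits
    
    total_loss = 0.0
    num_batches = 0
    
    with torch.no_grad():
        for batch in val_loader:
            X = batch['X'].to(device)
            y = batch['y'].to(device)
            
            outputs = model(X)
            
            # Label loss only: argmax over the class axis turns the
            # probabilistic labels into per-sample class indices
            loss = label_loss(outputs, y.argmax(dim=1))
            
            total_loss += loss.item()
            num_batches += 1
    
    return total_loss / num_batches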


print("Training functions defined. Ready to train!")

Training functions defined. Ready to train!


In [None]:
# Training loop
print("Starting training...")
print("=" * 80)

for epoch in range(epochs):
    epoch_start = time.time()
    
    # Train
    train_loss = train_epoch(
        student_model, 
        teacher_model, 
        train_loader, 
        criterion, 
        optimizer, 
        device
    )
    
    # Validate
    val_loss = validate_epoch(student_model, val_loader, criterion, device)
    
    # Learning rate scheduling
    scheduler.step(val_loss)
    
    epoch_time = time.time() - epoch_start
    
    # Print progress
    print(f"Epoch {epoch+1}/{epochs}:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")
    print(f"  Time: {epoch_time:.2f}s")
    print(f"  LR: {optimizer.param_groups[0]['lr']:.6f}")
    
    # Early stopping check
    early_stopping(val_loss, student_model, epoch)
    if early_stopping.early_stop:
        print("Early stopping triggered!")
        break
    
    print("-" * 80)

print("\n" + "=" * 80)
print("Training complete!")
print(f"Best model saved to: {early_stopping.checkpoint_path}")
print("=" * 80)

Starting training...
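`EarlyStopping` is imported from the project's `xn_early_stopping` module, which is not shown. The loop above only relies on calling it with `(val_loss, model, epoch)` and reading its `early_stop` flag. A minimal, hypothetical sketch of the usual patience logic; checkpoint saving is left out so it runs without PyTorch, and a real version would save `model.state_dict()` whenever the loss improves.

```python
class PatienceTracker:
    """Hypothetical stand-in for the patience logic behind EarlyStopping."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss: float, epoch: int) -> bool:
        """Record one epoch's validation loss; return True if it is a new best."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            return True  # a full implementation would save a checkpoint here
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False
```

With `patience=5`, training stops after five consecutive epochs without a new best validation loss, and the checkpoint from `best_epoch` is the one to reload.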


## 11. Load Best Model

In [None]:
# Load best model from checkpoint
project_root_path = Path(project_root) if isinstance(project_root, str) else project_root
checkpoint_path = project_root_path / 'checkpoints' / 'best_model.pth'

if checkpoint_path.exists():
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    student_model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    print(f"✓ Loaded best model from: {checkpoint_path}")
else:
    print(f"Warning: Checkpoint not found at {checkpoint_path}")
    print("Using current model state for evaluation")

student_model.eval()
print("Model ready for evaluation")

## 12. Model Evaluation

Comprehensive evaluation on test set.

In [None]:
# Evaluate teacher model
print("Evaluating Teacher Model (PhaseNet)...")
teacher_results = evaluate_model(teacher_model, test_loader, device)

print(f"\nTeacher Model Results:")
print(f"  Loss: {teacher_results['loss']:.4f}")
print(f"  Accuracy: {teacher_results['accuracy']:.4f}")
print(f"  Precision: {teacher_results['precision']:.4f}")
print(f"  Recall: {teacher_results['recall']:.4f}")
print(f"  F1 Score: {teacher_results['f1']:.4f}")

In [None]:
# Evaluate student model
print(f"Evaluating Student Model (XiaoNet {model_version})...")
student_results = evaluate_model(student_model, test_loader, device)

print(f"\nStudent Model Results:")
print(f"  Loss: {student_results['loss']:.4f}")
print(f"  Accuracy: {student_results['accuracy']:.4f}")
print(f"  Precision: {student_results['precision']:.4f}")
print(f"  Recall: {student_results['recall']:.4f}")
print(f"  F1 Score: {student_results['f1']:.4f}")

# Performance comparison
print(f"\nPerformance Gap (student - teacher):")
print(f"  Accuracy: {(student_results['accuracy'] - teacher_results['accuracy']):.4f}")
print(f"  F1 Score: {(student_results['f1'] - teacher_results['f1']):.4f}")

## 13. Phase-Specific Evaluation

Evaluate P-wave and S-wave detection with tolerance windows.

In [None]:
# Configuration for phase detection
sampling_rate = 100  # Hz
tolerance_seconds = 0.6  # ±0.6 seconds
tolerance_samples = int(tolerance_seconds * sampling_rate)  # 60 samples
peak_height = 0.5  # Minimum peak height
peak_distance = 100  # Minimum distance between peaks (1 second)

print(f"Phase Detection Configuration:")
print(f"  Tolerance: ±{tolerance_seconds}s ({tolerance_samples} samples)")
print(f"  Peak height threshold: {peak_height}")
print(f"  Peak distance: {peak_distance} samples ({peak_distance/sampling_rate}s)")
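The two cells below call `calculate_phase_metrics`, which is neither defined in this notebook nor among the names imported from `evaluation.xn_evaluate` (that module's `compute_picking_accuracy` may be the intended function, but its signature is not shown). A hypothetical sketch matching the call signature used here, built on the `find_peaks` import from the setup cell: a predicted peak is a true positive if it lies within `tolerance` samples of the labelled arrival. Assumptions, all hypothetical: P and S are channels 0 and 1 of both the labels and the model outputs (verify this for each model); a label peak below 0.5 means the phase is absent from the window; outputs that do not sum to 1 over classes are treated as logits and softmaxed.

```python
import numpy as np
from scipy.signal import find_peaks


def match_picks(prob, label, tolerance, height, distance, label_threshold=0.5):
    """Return (tp, fp, fn) for one phase channel of one trace."""
    peaks, _ = find_peaks(prob, height=height, distance=distance)
    if label.max() < label_threshold:  # no arrival of this phase in the window
        return 0, len(peaks), 0
    hit = bool(np.any(np.abs(peaks - int(np.argmax(label))) <= tolerance))
    tp = int(hit)
    return tp, len(peaks) - tp, 1 - tp  # extra peaks count as false positives


def accumulate_counts(probs, labels, counts, tolerance, height, distance, phase_channels):
    """Add tp/fp/fn for (batch, channels, samples) arrays into `counts`."""
    for b in range(probs.shape[0]):
        for name, ch in phase_channels.items():
            tp, fp, fn = match_picks(probs[b, ch], labels[b, ch], tolerance, height, distance)
            counts[name][0] += tp
            counts[name][1] += fp
            counts[name][2] += fn


def precision_recall_f1(tp, fp, fn):
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


def calculate_phase_metrics(model, loader, device, tolerance=60, height=0.5,
                            distance=100, phase_channels=None):
    import torch  # imported here so the NumPy helpers above work without PyTorch

    phase_channels = phase_channels or {"P": 0, "S": 1}  # assumed channel order
    counts = {name: [0, 0, 0] for name in phase_channels}
    model.eval()
    with torch.no_grad():
        for batch in loader:
            out = model(batch["X"].to(device))
            # Heuristic: softmax raw outputs unless they already sum to 1 over classes
            if not torch.allclose(out.sum(dim=1), torch.ones_like(out[:, 0]), atol=1e-3):
                out = torch.softmax(out, dim=1)
            accumulate_counts(out.cpu().numpy(), batch["y"].cpu().numpy(), counts,
                              tolerance, height, distance, phase_channels)
    return {name: precision_recall_f1(*c) for name, c in counts.items()}
```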

In [None]:
# Teacher phase metrics
print("\nCalculating teacher phase metrics...")
teacher_phase_metrics = calculate_phase_metrics(
    teacher_model,
    test_loader,
    device,
    tolerance=tolerance_samples,
    height=peak_height,
    distance=peak_distance
)

print(f"\nTeacher Phase Detection Results:")
print(f"  P-wave - Precision: {teacher_phase_metrics['P']['precision']:.4f}, "
      f"Recall: {teacher_phase_metrics['P']['recall']:.4f}, "
      f"F1: {teacher_phase_metrics['P']['f1']:.4f}")
print(f"  S-wave - Precision: {teacher_phase_metrics['S']['precision']:.4f}, "
      f"Recall: {teacher_phase_metrics['S']['recall']:.4f}, "
      f"F1: {teacher_phase_metrics['S']['f1']:.4f}")

In [None]:
# Student phase metrics
print(f"\nCalculating student phase metrics...")
student_phase_metrics = calculate_phase_metrics(
    student_model,
    test_loader,
    device,
    tolerance=tolerance_samples,
    height=peak_height,
    distance=peak_distance
)

print(f"\nStudent Phase Detection Results:")
print(f"  P-wave - Precision: {student_phase_metrics['P']['precision']:.4f}, "
      f"Recall: {student_phase_metrics['P']['recall']:.4f}, "
      f"F1: {student_phase_metrics['P']['f1']:.4f}")
print(f"  S-wave - Precision: {student_phase_metrics['S']['precision']:.4f}, "
      f"Recall: {student_phase_metrics['S']['recall']:.4f}, "
      f"F1: {student_phase_metrics['S']['f1']:.4f}")

## 14. Inference Speed Benchmark

In [None]:
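def benchmark_model(model, test_loader, device, num_batches=50, warmup=2):
    """Benchmark inference speed: mean and std of seconds per batch."""
    model.eval()
    times = []
    
    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            if i >= num_batches + warmup:
                break
            
            X = batch['X'].to(device)
            
            if device.type == 'cuda':
                torch.cuda.synchronize()  # finish pending GPU work before timing
            start = time.perf_counter()
            _ = model(X)
            if device.type == 'cuda':
                torch.cuda.synchronize()  # CUDA kernels run asynchronously
            end = time.perf_counter()
            
            if i >= warmup:  # skip warm-up batches (allocation and cache effects)
                times.append(end - start)
    
    return np.mean(times), np.std(times)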

# Benchmark teacher
print("Benchmarking Teacher Model...")
teacher_time, teacher_std = benchmark_model(teacher_model, test_loader, device)
teacher_throughput = batch_size / teacher_time

print(f"Teacher (PhaseNet):")
print(f"  Time per batch: {teacher_time*1000:.2f} ± {teacher_std*1000:.2f} ms")
print(f"  Throughput: {teacher_throughput:.2f} samples/sec")

# Benchmark student
print(f"\nBenchmarking Student Model (XiaoNet {model_version})...")
student_time, student_std = benchmark_model(student_model, test_loader, device)
student_throughput = batch_size / student_time
speedup = teacher_time / student_time

print(f"Student (XiaoNet {model_version}):")
print(f"  Time per batch: {student_time*1000:.2f} ± {student_std*1000:.2f} ms")
print(f"  Throughput: {student_throughput:.2f} samples/sec")
print(f"  Speedup: {speedup:.2f}x {'🚀' if speedup > 1 else '🐌'}")

## 15. Visualization: Prediction Comparison

In [None]:
# Get a sample batch
sample_batch = next(iter(test_loader))
sample_X = sample_batch['X'].to(device)
sample_y = sample_batch['y'].to(device)

# Get predictions (SeisBench's PhaseNet already applies softmax in forward(),
# so only the student's raw outputs are softmaxed here)
with torch.no_grad():
    teacher_pred = teacher_model(sample_X)
    student_pred = torch.softmax(student_model(sample_X), dim=1)

# Select first sample
idx = 0
waveform = sample_X[idx].cpu().numpy()
labels = sample_y[idx].cpu().numpy()
teacher_out = teacher_pred[idx].cpu().numpy()
student_out = student_pred[idx].cpu().numpy()

# Plot
fig, axes = plt.subplots(4, 1, figsize=(15, 10), sharex=True)

# Waveform
time_axis = np.arange(waveform.shape[1]) / sampling_rate
for i, channel in enumerate(['E', 'N', 'Z']):  # matches component_order="ENZ"
    axes[0].plot(time_axis, waveform[i], label=channel, alpha=0.7)
axes[0].set_ylabel('Amplitude')
axes[0].set_title('Input Waveform (3 channels)')
axes[0].legend(loc='upper right')
axes[0].grid(True, alpha=0.3)

# Ground truth labels
for i, phase in enumerate(['P', 'S', 'Noise']):
    axes[1].plot(time_axis, labels[i], label=phase, alpha=0.7)
axes[1].set_ylabel('Probability')
axes[1].set_title('Ground Truth Labels')
axes[1].legend(loc='upper right')
axes[1].grid(True, alpha=0.3)

# Teacher predictions
for i, phase in enumerate(['P', 'S', 'Noise']):
    axes[2].plot(time_axis, teacher_out[i], label=phase, alpha=0.7)
axes[2].set_ylabel('Probability')
axes[2].set_title(f'Teacher Predictions (PhaseNet - {teacher_total:,} params)')
axes[2].legend(loc='upper right')
axes[2].grid(True, alpha=0.3)

# Student predictions
student_param_count = sum(p.numel() for p in student_model.parameters())
for i, phase in enumerate(['P', 'S', 'Noise']):
    axes[3].plot(time_axis, student_out[i], label=phase, alpha=0.7)
axes[3].set_xlabel('Time (s)')
axes[3].set_ylabel('Probability')
axes[3].set_title(f'Student Predictions (XiaoNet {model_version} - {student_param_count:,} params)')
axes[3].legend(loc='upper right')
axes[3].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Visualization complete!")

## 16. Summary Report

In [None]:
# Generate comprehensive summary
print("=" * 80)
print("XIAO NET TRAINING SUMMARY")
print("=" * 80)
print()

print(f"Dataset: OKLA_1Mil_120s_Ver_3 ({sample_fraction*100}% sample)")
print(f"  Train: {len(train)} | Dev: {len(dev)} | Test: {len(test)}")
print()

print("Model Comparison:")
print("-" * 80)
print(f"{'Model':<20} {'Parameters':>15} {'Reduction':>12} {'Time/batch':>15} {'Speedup':>10}")
print("-" * 80)
print(f"{'PhaseNet (Teacher)':<20} {teacher_total:>15,} {'baseline':>12} "
      f"{teacher_time*1000:>12.2f} ms {'1.00x':>10}")
# V1 timing is hard-coded: V1 is not benchmarked in this notebook
print(f"{'XiaoNet V1':<20} {v1_params:>15,} {f'{(1-v1_params/teacher_total)*100:.1f}%':>12} "
      f"{'133.00 ms':>15} {'0.38x':>10}")
print(f"{'XiaoNet V3':<20} {v3_params:>15,} {f'{(1-v3_params/teacher_total)*100:.1f}%':>12} "
      f"{student_time*1000:>12.2f} ms {f'{speedup:.2f}x':>10}")
print("-" * 80)
print()

print("Performance Metrics (Test Set):")
print("-" * 80)
print(f"{'Model':<20} {'Accuracy':>12} {'Precision':>12} {'Recall':>12} {'F1 Score':>12}")
print("-" * 80)
print(f"{'Teacher':<20} {teacher_results['accuracy']:>12.4f} "
      f"{teacher_results['precision']:>12.4f} {teacher_results['recall']:>12.4f} "
      f"{teacher_results['f1']:>12.4f}")
print(f"{f'Student ({model_version})':<20} {student_results['accuracy']:>12.4f} "
      f"{student_results['precision']:>12.4f} {student_results['recall']:>12.4f} "
      f"{student_results['f1']:>12.4f}")
print(f"{'Delta':<20} {student_results['accuracy']-teacher_results['accuracy']:>12.4f} "
      f"{student_results['precision']-teacher_results['precision']:>12.4f} "
      f"{student_results['recall']-teacher_results['recall']:>12.4f} "
      f"{student_results['f1']-teacher_results['f1']:>12.4f}")
print("-" * 80)
print()

print(f"Phase Detection (±{tolerance_seconds}s tolerance):")
print("-" * 80)
print(f"Teacher P-wave: Precision={teacher_phase_metrics['P']['precision']:.4f}, "
      f"Recall={teacher_phase_metrics['P']['recall']:.4f}, "
      f"F1={teacher_phase_metrics['P']['f1']:.4f}")
print(f"Teacher S-wave: Precision={teacher_phase_metrics['S']['precision']:.4f}, "
      f"Recall={teacher_phase_metrics['S']['recall']:.4f}, "
      f"F1={teacher_phase_metrics['S']['f1']:.4f}")
print()
print(f"Student P-wave: Precision={student_phase_metrics['P']['precision']:.4f}, "
      f"Recall={student_phase_metrics['P']['recall']:.4f}, "
      f"F1={student_phase_metrics['P']['f1']:.4f}")
print(f"Student S-wave: Precision={student_phase_metrics['S']['precision']:.4f}, "
      f"Recall={student_phase_metrics['S']['recall']:.4f}, "
      f"F1={student_phase_metrics['S']['f1']:.4f}")
print("-" * 80)
print()

print("Key Achievements:")
student_params = sum(p.numel() for p in student_model.parameters())
param_reduction = (1 - student_params/teacher_total) * 100
print(f"  ✓ Model size reduction: {param_reduction:.1f}%")
print(f"  ✓ Inference speedup: {speedup:.2f}x")
accuracy_gap = student_results['accuracy'] - teacher_results['accuracy']
print(f"  ✓ Accuracy gap: {accuracy_gap:.4f} ({abs(accuracy_gap)*100:.2f} percentage points)")
print(f"  ✓ Edge deployment ready: {'Yes 🚀' if speedup >= 0.9 and param_reduction >= 90 else 'Needs improvement'}")
print()

print("=" * 80)
print("Training complete! Model ready for deployment.")
print(f"Best model checkpoint: {project_root_path / 'checkpoints' / 'best_model.pth'}")
print("=" * 80)

## 17. Save Final Model

In [None]:
# Save final model with metadata
# (project_root is a str, so use the Path version for the / operator)
final_model_path = project_root_path / f'final_model_{model_version.lower()}.pth'
student_params = sum(p.numel() for p in student_model.parameters())

model_metadata = {
    'model_state_dict': student_model.state_dict(),
    'model_version': model_version,
    'parameters': student_params,
    'test_accuracy': student_results['accuracy'],
    'test_f1': student_results['f1'],
    'speedup': speedup,
    'config': config,
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
}

torch.save(model_metadata, final_model_path)
print(f"✓ Final model saved to: {final_model_path}")
print(f"  Model version: XiaoNet {model_version}")
print(f"  Parameters: {student_params:,}")
print(f"  Test accuracy: {student_results['accuracy']:.4f}")
print(f"  Speedup: {speedup:.2f}x")

---

## Notes

### Model Evolution:
- **V1**: Original design with trim-pad (164K params, 0.38x the teacher's speed on CPU)
- **V3**: Speed-optimized with depthwise separable convolutions (50K params, expected ~1-2x the teacher's speed)

### Key Learnings:
1. Smaller models don't automatically mean faster inference
2. Operation type matters more than operation count on CPU
3. ConvTranspose is very slow on CPU compared to bilinear interpolation
4. Architecture depth significantly impacts performance
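5. Depthwise separable convolutions cut a layer's weights and multiply-adds by a factor of about C_out·k/(C_out + k) relative to a standard convolution, which we expect to translate into a 3-5x speedup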

### Deployment Recommendations:
- **Edge devices (CPU)**: Use XiaoNet V3 for best speed/accuracy trade-off
- **GPU systems**: Any version works well (10-20x speedup over CPU expected)
- **Ultra-low power**: Consider further quantization (INT8) for 2-4x additional speedup
- **Production**: Train V3 on full dataset (not 1% sample) for best accuracy

---