<div style="display: flex; justify-content: center; margin-bottom: 20px;">
  <img src="../docs/_static/seispolarity_logo_title.svg">
</div>
7
---

## Training API Usage Example


## Import Necessary Libraries

In [14]:
# Import necessary libraries for training
import sys
from pathlib import Path

import torch
import numpy as np

sys.path.append(str(Path.cwd().parent))

from seispolarity.data.base import WaveformDataset
from seispolarity.models import SCSN
from seispolarity.training import Trainer, TrainingConfig


## Device Configuration

In [15]:
# Pick the compute backend in priority order: CUDA GPU, then Apple MPS, then CPU.
# The MPS check is only evaluated when CUDA is unavailable.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"

print(f"Using device: {DEVICE}")

# Report a few details about the selected backend.
if DEVICE == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    print(f"CUDA device: {gpu_name}")
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    # Reported in decimal gigabytes (1e9 bytes), not GiB.
    print(f"CUDA memory: {gpu_total_bytes / 1e9:.2f} GB")
elif DEVICE == "mps":
    print("MPS device available")
else:
    print("CPU device")

Using device: cuda
CUDA device: NVIDIA GeForce RTX 4060 Ti
CUDA memory: 8.59 GB


## Dataset Configuration

Configure the dataset path and parameters for loading seismic waveform data.

In [16]:
# Directory that receives the checkpoints written during training.
OUT_DIR = "./checkpoints_ross_scsn"

# Location of the SCSN training HDF5 file. Two ways to provide it:
#
#   Option 1 - point at a local copy:
#       DATA_PATH = "/path/to/your/dataset.hdf5"
#
#   Option 2 - use the automatic download (uncomment to use):
#       from seispolarity.data import SCSNData
#       scsn_processor = SCSNData(output_dir="./datasets")
#       DATA_PATH = scsn_processor.download()
#       print(DATA_PATH)
#
# The value below is the machine-specific path this notebook was last run
# with; replace it with the location of the dataset on your system.
DATA_PATH = r"/mnt/f/AI_Seismic_Data/scsn/scsn_p_2000_2017_6sec_0.5r_fm_train.hdf5"

# Echo both settings so the rendered notebook records what was used.
print(
    f"Dataset path: {DATA_PATH}",
    f"Output directory: {OUT_DIR}",
    sep="\n",
)

Dataset path: /mnt/f/AI_Seismic_Data/scsn/scsn_p_2000_2017_6sec_0.5r_fm_train.hdf5
Output directory: ./checkpoints_ross_scsn


## Data Parameters

In [17]:
# Data loading parameters
PRELOAD = True               # load all samples into RAM when the dataset is created
ALLOWED_LABELS = [0, 1, 2]   # label values passed to the dataset as `allowed_labels`
P_PICK_POSITION = 300        # fixed P-wave sample index in each SCSN waveform;
                             # keep in sync with `p_pick_position` given to WaveformDataset
CROP_LEFT = 200              # number of samples to keep before the P-wave
CROP_RIGHT = 200             # number of samples to keep after the P-wave

# Window actually handed to the model, derived from the values above so the
# summary below can never drift out of sync with the crop settings.
# With the defaults this is a 400-point window centered on the P-wave (100-500).
WINDOW_START = P_PICK_POSITION - CROP_LEFT
WINDOW_END = P_PICK_POSITION + CROP_RIGHT
WINDOW_LENGTH = CROP_LEFT + CROP_RIGHT

print("Data parameters:")
print(f"  Preload data: {PRELOAD}")
print(f"  Allowed labels: {ALLOWED_LABELS}")
print(f"  P-wave position: {P_PICK_POSITION}")
print(f"  Crop left: {CROP_LEFT} samples")
print(f"  Crop right: {CROP_RIGHT} samples")
print(f"  Output window: {WINDOW_LENGTH} samples ({WINDOW_START}-{WINDOW_END})")

Data parameters:
  Preload data: True
  Allowed labels: [0, 1, 2]
  P-wave position: 300
  Crop left: 200 samples
  Crop right: 200 samples
  Output window: 400 samples (100-500)


## Create Dataset

In [9]:
# Assemble the constructor arguments first so they can be inspected on their
# own, then build the waveform dataset from them.
dataset_kwargs = dict(
    path=DATA_PATH,
    name="SCSN_Train",
    preload=PRELOAD,
    allowed_labels=ALLOWED_LABELS,
    # HDF5 keys for waveforms and labels; no clarity, pick or metadata keys are used.
    data_key="X",
    label_key="Y",
    clarity_key=None,
    pick_key=None,
    metadata_keys=[],
    # Fixed P-wave sample index and the number of samples kept on each side of it.
    p_pick_position=300,
    crop_left=CROP_LEFT,
    crop_right=CROP_RIGHT,
)
dataset = WaveformDataset(**dataset_kwargs)

# Quick sanity summary of what was loaded.
print(f"Dataset created: {dataset._name}")
print(f"Dataset length: {len(dataset)}")
print(f"Allowed labels: {dataset.allowed_labels}")

2026-02-07 10:27:37,587 - seispolarity.data - INFO - Initialized Flat Dataset from scsn_p_2000_2017_6sec_0.5r_fm_train.hdf5 with 2494194 samples.
2026-02-07 10:27:37,587 | seispolarity.data | INFO | Initialized Flat Dataset from scsn_p_2000_2017_6sec_0.5r_fm_train.hdf5 with 2494194 samples.


2026-02-07 10:27:37,607 - seispolarity.data - INFO - Built index with 2494194 samples (filtered from 2494194)
2026-02-07 10:27:37,607 | seispolarity.data | INFO | Built index with 2494194 samples (filtered from 2494194)
2026-02-07 10:27:37,611 - seispolarity.data - INFO - Loading 2494194 samples into RAM...
2026-02-07 10:27:37,611 | seispolarity.data | INFO | Loading 2494194 samples into RAM...
2026-02-07 10:27:37,614 - seispolarity.data - INFO - Loading Metadata...
2026-02-07 10:27:37,614 | seispolarity.data | INFO | Loading Metadata...
Loading RAM: 100%|██████████| 2494194/2494194 [00:16<00:00, 153521.92samples/s]
2026-02-07 10:27:53,872 - seispolarity.data - INFO - RAM Load Complete.
2026-02-07 10:27:53,872 | seispolarity.data | INFO | RAM Load Complete.


Dataset created: SCSN_Train
Dataset length: 2494194
Allowed labels: [0, 1, 2]


## Training Configuration

In [None]:
# Hyper-parameters a reader is most likely to tune.
EPOCHS = 50
BATCH_SIZE = 256
LR = 1e-3
NUM_WORKERS = 4

# Bundle everything the trainer needs into one configuration object.
# NOTE(review): label_key is "label" here while the dataset above was built
# with label_key="Y" - presumably these refer to different layers; confirm.
# NOTE(review): train_val_split=0.8 alongside val_split/test_split of 0.1 each -
# exact semantics are defined by TrainingConfig; confirm against its docs.
config = TrainingConfig(
    device=DEVICE,
    checkpoint_dir=OUT_DIR,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    learning_rate=LR,
    num_workers=NUM_WORKERS,
    label_key="label",
    train_val_split=0.8,
    val_split=0.1,
    test_split=0.1,
    patience=5,
    random_seed=36,
)

# Summarise the run settings in the notebook output.
print("Training configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Learning rate: {LR}")
print(f"  Num workers: {NUM_WORKERS}")
print(f"  Device: {DEVICE}")
split_summary = f"{config.train_val_split}/{config.val_split}/{config.test_split}"
print(f"  Train/Val/Test split: {split_summary}")
print(f"  Early stopping patience: {config.patience}")
print(f"  Random seed: {config.random_seed}")

Training configuration:
  Epochs: 50
  Batch size: 256
  Learning rate: 0.001
  Num workers: 4
  Device: cuda
  Train/Val/Test split: 0.8/0.1/0.1
  Early stopping patience: 5
  Random seed: 36


## Create Model

In [19]:
# Instantiate the Ross (SCSN) polarity model with three output classes
# (Up, Down, Unknown) and place it on the selected device.
model = SCSN(num_fm_classes=3, sample_rate=100.0)
model.to(DEVICE)

print(f"Model created: {model.name}")
print(f"Model device: {model.device}")
print("\nModel architecture:")
print(model)

# Tally parameter counts in one pass: (element count, is trainable) per tensor.
param_stats = [(param.numel(), param.requires_grad) for param in model.parameters()]
total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, is_trainable in param_stats if is_trainable)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Make sure the checkpoint directory exists before training writes to it.
Path(OUT_DIR).mkdir(parents=True, exist_ok=True)

print(f"Output directory created: {OUT_DIR}")

Model created: SCSN
Model device: cuda

Model architecture:
SCSN(
  (shared_backbone): SharedBackbone(
    (sequential): Sequential(
      (0): Conv1d(1, 32, kernel_size=(21,), stride=(1,), padding=same)
      (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU()
      (3): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (4): Conv1d(32, 64, kernel_size=(15,), stride=(1,), padding=same)
      (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU()
      (7): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (8): Conv1d(64, 128, kernel_size=(11,), stride=(1,), padding=same)
      (9): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): ReLU()
      (11): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (12): Flatten(start_dim=1, end_dim=-1)
      (13): Linear

## Initialize Trainer

In [20]:
# Wire the model, dataset and configuration together.
# No separate validation/test datasets are supplied (both are None).
trainer = Trainer(
    config=config,
    model=model,
    dataset=dataset,
    val_dataset=None,
    test_dataset=None,
)

# Confirm what the trainer was set up with.
print("Trainer initialized")
print(f"Model: {trainer.model.name}")
print(f"Dataset: {trainer.dataset._name}")
print(f"Device: {trainer.config.device}")

Trainer initialized
Model: SCSN
Dataset: SCSN_Train
Device: cuda


## Start Training

The trainer will:
1. Split the dataset into train/val/test sets
2. Train the model for the specified number of epochs
3. Save checkpoints periodically
4. Apply early stopping if validation accuracy doesn't improve
5. Return the best validation accuracy and final test accuracy

In [None]:
# Run the full training loop; the banner lines frame the trainer's own output.
banner = "=" * 50

print("Starting training...")
print(banner)

# train() returns the best validation accuracy and the final test accuracy.
best_val_acc, final_test_acc = trainer.train()

print(banner)
print("\nTraining completed!")
print(f"Best validation accuracy: {best_val_acc:.4f}")
print(f"Final test accuracy: {final_test_acc:.4f}")

## Save Final Model

In [23]:
# Persist the model weights (state dict only) next to the training checkpoints.
final_model_path = Path(OUT_DIR, "final_model.pth")
final_weights = model.state_dict()
torch.save(final_weights, final_model_path)

print(f"Final model saved to: {final_model_path}")

Final model saved to: checkpoints_ross_scsn/final_model.pth


## Cleanup

In [22]:
# Release the underlying data handle, but only when the data was not preloaded
# and the wrapped object actually exposes a close() method.
if PRELOAD or not hasattr(dataset, 'dataset') or not hasattr(dataset.dataset, 'close'):
    # Nothing to close in this configuration.
    print("Cleanup completed")
else:
    dataset.dataset.close()
    print("Dataset closed")

Cleanup completed
