# 🌋 Foreshock-Aftershock Classification with xLSTM-UNet

This notebook evaluates a fine-tuned xLSTM-UNet model on the Norcia earthquake foreshock/aftershock classification task.

**Task**: Classify seismic events into 9 temporal classes:
- 4 foreshock classes (FEQ1-FEQ4)
- 1 Visso event class
- 4 aftershock classes (AEQ1-AEQ4)

**Model**: xLSTM-UNet fine-tuned from a contrastively pretrained checkpoint

**Methodology**: Matches SeisLM's approach (temporal splitting, frozen encoder, etc.)

## 📦 Setup

In [None]:
import os
import sys
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, accuracy_score, ConfusionMatrixDisplay
from omegaconf import OmegaConf

# Add project to path
# sys.path.insert(0, '/path/to/this/repo')

from dataloaders.foreshock_aftershock_lit import ForeshockAftershockLitDataset
from simple_train import SimpleSeqModel

# Set device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 🔧 Configuration

In [None]:
# === CHECKPOINT PATH ===
CKPT_PATH = '/path/to/your/checkpoint_or_data'

# === DATASET CONFIG ===
DATA_DIR = '/path/to/your/checkpoint_or_data'
NUM_CLASSES = 9
BATCH_SIZE = 32

# === DISPLAY LABELS ===
DISPLAY_LABELS = [
    "FEQ1",  # Foreshock class 1 (earliest)
    "FEQ2",  # Foreshock class 2
    "FEQ3",  # Foreshock class 3
    "FEQ4",  # Foreshock class 4 (latest before main)
    "Visso", # Visso event
    "AEQ1",  # Aftershock class 1 (earliest after main)
    "AEQ2",  # Aftershock class 2
    "AEQ3",  # Aftershock class 3
    "AEQ4",  # Aftershock class 4 (latest)
]

assert len(DISPLAY_LABELS) == NUM_CLASSES, "DISPLAY_LABELS must have one entry per class"

print(f"✅ Checkpoint: {CKPT_PATH}")
print(f"✅ Data directory: {DATA_DIR}")
print(f"✅ Number of classes: {NUM_CLASSES}")
print(f"✅ Batch size: {BATCH_SIZE}")

## 📊 Load Dataset

The test set is built with the same configuration used for training:
- **Temporal splitting**: events are split temporally, which prevents data leakage
- **Event-level split**: each earthquake appears in only one of train/val/test
- **Normalization**: per-channel demeaning and std scaling (SeisLM-style)
- **Component order**: ZNE
- **Dimension order**: NWC (batch, width = time samples, channels)
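
The following is a minimal NumPy sketch of the normalization listed above. It assumes that `demean_axis=1` and `amp_norm_axis=1` reduce over the time (W) axis of an NWC batch, so that each channel of each waveform gets its own mean and std. The function name `normalize_nwc` and the `eps` guard are illustrative and are not taken from `ForeshockAftershockLitDataset`:

```python
import numpy as np


def normalize_nwc(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Demean and std-scale each channel of an (N, W, C) batch.

    Reducing over axis=1 (the time/width axis) yields one mean and one std
    per (sample, channel) pair, i.e. per-channel normalization.
    """
    x = x - x.mean(axis=1, keepdims=True)   # assumed meaning of demean_axis=1
    std = x.std(axis=1, keepdims=True)      # assumed meaning of amp_norm_axis=1 with 'std'
    return x / (std + eps)                  # eps avoids dividing by zero on a flat channel


# Toy batch: 2 waveforms, 3000 time samples, 3 components (Z, N, E) with a
# shared DC offset and very different amplitudes per component.
rng = np.random.default_rng(0)
batch = rng.normal(loc=5.0, scale=[1.0, 10.0, 100.0], size=(2, 3000, 3))
out = normalize_nwc(batch)
print(out.shape)                         # (2, 3000, 3)
print(np.round(out.std(axis=1), 3))      # every entry is close to 1.0
```

Normalizing per channel keeps a loud horizontal component from dominating a quieter vertical one, which a single global scale would not do.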

In [None]:
# Create dataset with same config as training
dataset = ForeshockAftershockLitDataset(
    data_dir=DATA_DIR,
    num_classes=NUM_CLASSES,
    batch_size=BATCH_SIZE,
    event_split_method='temporal',  # Match seisLM
    component_order='ZNE',          # Match seisLM
    seed=42,
    remove_class_overlapping_dates=False,
    train_frac=0.7,
    val_frac=0.10,
    test_frac=0.20,
    dimension_order='NWC',          # Match seisLM
    demean_axis=1,                  # Match seisLM (per channel)
    amp_norm_axis=1,                # Match seisLM (per channel)
    amp_norm_type='std',            # Match seisLM
    num_workers=0,
    collator=None,
)

test_loader = dataset.test_loader
print(f"✅ Test set loaded: {len(test_loader.dataset)} samples")
print(f"✅ Number of batches: {len(test_loader)}")
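
If the event ID of each sample is available, the event-level split can be checked directly. This sketch takes a hypothetical mapping from split name to per-sample event IDs; how to extract those IDs from `ForeshockAftershockLitDataset` is not shown here. It also compares each split's share of distinct events with the 70/10/20 fractions passed to the constructor, assuming those fractions apply to events:

```python
from itertools import combinations


def check_event_split(split_events, expected=None, tol=0.02):
    """Verify an event-level split: no event ID may appear in two splits.

    split_events maps a split name to the event IDs of its samples (an event
    contributes one ID per waveform, so repeats within a split are fine).
    Returns each split's share of distinct events.
    """
    event_sets = {name: set(ids) for name, ids in split_events.items()}
    for a, b in combinations(event_sets, 2):
        shared = event_sets[a] & event_sets[b]
        if shared:
            raise ValueError(f"events leaked between {a} and {b}: {sorted(shared)[:5]}")
    total = sum(len(s) for s in event_sets.values())
    shares = {name: len(s) / total for name, s in event_sets.items()}
    if expected is not None:
        for name, frac in expected.items():
            if abs(shares[name] - frac) > tol:
                print(f"warning: {name} holds {shares[name]:.1%} of events, expected {frac:.0%}")
    return shares


# Toy event IDs (not the Norcia data): 10 distinct events split 7/1/2.
toy = {
    "train": ["ev0", "ev0", "ev1", "ev2", "ev3", "ev4", "ev5", "ev6"],
    "val": ["ev7"],
    "test": ["ev8", "ev9", "ev9"],
}
print(check_event_split(toy, expected={"train": 0.7, "val": 0.1, "test": 0.2}))
# {'train': 0.7, 'val': 0.1, 'test': 0.2}
```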

## 🧠 Load Model

Load the fine-tuned xLSTM-UNet model from the checkpoint

In [None]:
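# Load the model config directly from the checkpoint, as SeisLM does: the
# hyperparameters saved at training time (not the experiment files) define
# the architecture to rebuild.
print("Loading checkpoint...")
# weights_only=False is needed here because the checkpoint stores pickled
# config objects; only use it with checkpoints from a trusted source.
state = torch.load(CKPT_PATH, map_location=device, weights_only=False)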

# Extract hyperparameters that were used during training
cfg = state['hyper_parameters']
print(f"Checkpoint model d_model: {cfg.model.d_model}")

# Disable struct mode to modify config
OmegaConf.set_struct(cfg, False)

# Disable pretrained loading (we already have the trained weights)
if 'pretrained' in cfg.model:
    cfg.model.pretrained = None

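# Add the full encoder config. The checkpoint only stores
# encoder.pretrained=true, so the architecture is rebuilt here; these
# hard-coded values must match the encoder used during pretraining,
# or its weights will not line up with the checkpoint.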
if 'encoder' not in cfg or '_name_' not in cfg.encoder:
    print("Adding encoder config from pretrained checkpoint...")
    cfg.encoder = OmegaConf.create({
        '_name_': 'conv-down-encoder-contrastive',
        'kernel_size': 3,
        'n_layers': 2,
        'dim': 256,
        'stride': 2,
        'pretrained': False,  # Don't reload, we'll load from checkpoint
    })

OmegaConf.set_struct(cfg, True)

# Instantiate model with the SAME config as training
print("Instantiating model...")
model = SimpleSeqModel(cfg, d_data=3).to(device)

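# Load the fine-tuned weights. strict=False tolerates key mismatches, so
# report them: missing keys would silently keep their random initialization.
print("Loading trained weights...")
load_result = model.load_state_dict(state['state_dict'], strict=False)
if load_result.missing_keys:
    print(f"⚠️ Missing keys ({len(load_result.missing_keys)}): {load_result.missing_keys[:10]}")
if load_result.unexpected_keys:
    print(f"⚠️ Unexpected keys ({len(load_result.unexpected_keys)}): {load_result.unexpected_keys[:10]}")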

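# Ensure classification mode (not pretraining). Set the flag on each
# submodule independently so that a missing one does not skip the other.
for submodule_name in ('model', 'encoder'):
    submodule = getattr(model, submodule_name, None)
    if submodule is not None:
        submodule.pretraining = False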

model.eval()
print("✅ Model loaded successfully")
print(f"✅ Model in eval mode: {not model.training}")
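
Because `strict=False` skips mismatched keys without raising, it can help to compare the key sets before loading. PyTorch's `load_state_dict` already returns `missing_keys` and `unexpected_keys`. The hypothetical helper below reproduces that set logic in plain Python, so it can run on `model.state_dict().keys()` and `state['state_dict'].keys()` directly. The key names in the demo are made up:

```python
def diff_state_dict_keys(model_keys, checkpoint_keys):
    """Report keys that load_state_dict(strict=False) would silently skip.

    missing: parameters the model expects but the checkpoint lacks (they keep
    their random initialization). unexpected: checkpoint entries with no
    matching parameter in the model (they are ignored).
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    return {
        "missing": sorted(model_keys - checkpoint_keys),
        "unexpected": sorted(checkpoint_keys - model_keys),
    }


# Hypothetical key names, for illustration only: a renamed head module.
model_keys = ["encoder.conv.weight", "model.blocks.0.weight", "decoder.head.weight"]
ckpt_keys = ["encoder.conv.weight", "model.blocks.0.weight", "head.weight"]
print(diff_state_dict_keys(model_keys, ckpt_keys))
# {'missing': ['decoder.head.weight'], 'unexpected': ['head.weight']}
```

A non-empty `missing` list that contains classifier or encoder weights usually means the evaluation would run on partly untrained parameters.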

## 🔬 Evaluate Model

Run inference on the test set and collect predictions

In [None]:
all_preds = []
all_targets = []
total_loss = 0.0
n_samples = 0

print("Running inference...")
with torch.no_grad():
    for batch_idx, (x, y) in enumerate(test_loader):
        x = x.to(device)
        y = y.to(device)
        
        # Forward pass
        logits, targets = model.forward((x, y), batch_idx)
        
        # Compute loss
        loss = F.cross_entropy(logits, targets)
        total_loss += loss.item() * targets.shape[0]
        n_samples += targets.shape[0]
        
        # Get predictions
        preds = torch.argmax(logits, dim=1)
        all_preds.append(preds.detach().cpu().numpy())
        all_targets.append(targets.detach().cpu().numpy())
        
        if (batch_idx + 1) % 10 == 0:
            print(f"  Processed {batch_idx + 1}/{len(test_loader)} batches")

# Concatenate all predictions
all_preds = np.concatenate(all_preds, axis=0)
all_targets = np.concatenate(all_targets, axis=0)

# Compute metrics
test_loss = total_loss / max(1, n_samples)
test_acc = accuracy_score(all_targets, all_preds)

print(f"\n{'='*50}")
print("📊 EVALUATION RESULTS")
print(f"{'='*50}")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc*100:.2f}%")
print(f"Total Samples: {n_samples}")
print(f"{'='*50}\n")
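
The loop multiplies each batch loss by its batch size because `F.cross_entropy` returns the batch *mean* by default (`reduction='mean'`), and the last batch is usually smaller. The NumPy sketch below uses random toy logits and a hand-written `cross_entropy` that stands in for the PyTorch function. It shows that size weighting recovers the exact per-sample mean, while a plain average of batch means does not in general:

```python
import numpy as np


def cross_entropy(logits, targets):
    """Per-sample cross-entropy from raw logits (numerically stable log-softmax)."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets]


rng = np.random.default_rng(0)
logits = rng.normal(size=(70, 9))            # 70 toy samples, 9 classes
targets = rng.integers(0, 9, size=70)
batches = [slice(0, 32), slice(32, 64), slice(64, 70)]  # last batch has only 6 samples

batch_means = [cross_entropy(logits[b], targets[b]).mean() for b in batches]
batch_sizes = [b.stop - b.start for b in batches]

exact = cross_entropy(logits, targets).mean()
weighted = sum(m * n for m, n in zip(batch_means, batch_sizes)) / sum(batch_sizes)
naive = np.mean(batch_means)                 # over-weights the small final batch

print(f"exact={exact:.4f} weighted={weighted:.4f} naive={naive:.4f}")
```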

## 📈 Confusion Matrix

Visualize the confusion matrix to see per-class performance

In [None]:
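# Compute confusion matrix (rows = true class, columns = predicted class)
cm = confusion_matrix(all_targets, all_preds, labels=list(range(NUM_CLASSES)))

# Convert to percentages (per-row normalization); a class absent from the
# test set gets a row of zeros instead of NaN
row_totals = cm.sum(axis=1, keepdims=True)
cm_percentage = np.divide(
    100 * cm.astype(float), row_totals,
    out=np.zeros(cm.shape, dtype=float), where=row_totals > 0,
)

# Round to integer percentages for display
cm_display = np.rint(cm_percentage).astype(int)

print("Confusion Matrix (Percentages):")
print(cm_display)
print(f"\nDiagonal (Per-class accuracy): {cm_display.diagonal()}")
# Average the unrounded values: this is the mean per-class recall (balanced
# accuracy), which differs from the overall accuracy when classes are imbalanced
print(f"Mean per-class accuracy: {cm_percentage.diagonal().mean():.2f}%")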
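
Row normalization turns each row into per-class recall. The mean of the diagonal is therefore the balanced accuracy, not the overall accuracy reported earlier, and the two diverge when classes are imbalanced. The small NumPy sketch below uses a made-up 3-class matrix in which the third class has no test samples. scikit-learn's `balanced_accuracy_score` computes the same mean of recalls from raw labels:

```python
import numpy as np

# Toy confusion matrix (rows = true class, columns = predicted class).
cm = np.array([
    [90, 10, 0],   # class 0: 100 samples
    [5, 5, 0],     # class 1: 10 samples
    [0, 0, 0],     # class 2: absent from this toy test set
])

row_totals = cm.sum(axis=1, keepdims=True)
cm_pct = np.divide(100.0 * cm, row_totals,
                   out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)

present = row_totals.ravel() > 0
overall_acc = 100.0 * np.trace(cm) / cm.sum()        # dominated by the large class
balanced_acc = cm_pct.diagonal()[present].mean()     # each present class counts equally

print(cm_pct)
print(f"overall={overall_acc:.2f}% balanced={balanced_acc:.2f}%")
# overall=86.36% balanced=70.00%
```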

In [None]:
# Plot confusion matrix (SeisLM style)
fig, ax = plt.subplots(figsize=(12, 10))

disp = ConfusionMatrixDisplay(
    confusion_matrix=cm_display,
    display_labels=DISPLAY_LABELS,
)

disp.plot(ax=ax, xticks_rotation=45, colorbar=False, cmap="Reds")
ax.set_title(
    f"Confusion Matrix (xLSTM-UNet) | Accuracy: {test_acc*100:.2f}%",
    fontsize=16,
    fontweight='bold'
)
ax.set_xlabel('Predicted Class', fontsize=12)
ax.set_ylabel('True Class', fontsize=12)

plt.tight_layout()
plt.show()

# Save figure
save_path = '/path/to/your/checkpoint_or_data'
save_dir = os.path.dirname(save_path)
if save_dir:  # os.makedirs('') fails for a bare filename
    os.makedirs(save_dir, exist_ok=True)
fig.savefig(save_path, dpi=300, bbox_inches='tight')
print(f"\n✅ Confusion matrix saved to: {save_path}")

## 📊 Per-Class Metrics

In [None]:
# Per-class accuracy (recall): diagonal of the row-normalized confusion matrix
per_class_acc = cm_display.diagonal()

print("\n" + "="*60)
print("PER-CLASS ACCURACY")
print("="*60)
for label, acc in zip(DISPLAY_LABELS, per_class_acc):
    bar = '█' * int(acc / 5)  # one block per 5 percentage points
    print(f"{label:8s} | {acc:3d}% {bar}")
print("="*60)
print(f"MEAN     | {cm_percentage.diagonal().mean():.2f}%")
print("="*60)
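
The diagonal above reports per-class recall only. The sketch below derives per-class precision and F1 from the same count matrix `cm`, with zero division guarded. `per_class_prf` is a hypothetical helper. scikit-learn's `classification_report(all_targets, all_preds, labels=list(range(NUM_CLASSES)), target_names=DISPLAY_LABELS)` reports the same quantities:

```python
import numpy as np


def per_class_prf(cm):
    """Precision, recall and F1 per class from a count confusion matrix.

    Rows are true classes, columns predicted classes (sklearn convention).
    Classes with no predictions or no samples get 0 instead of NaN.
    """
    cm = np.asarray(cm, dtype=float)
    tp = cm.diagonal()
    pred_totals = cm.sum(axis=0)   # column sums: how often each class was predicted
    true_totals = cm.sum(axis=1)   # row sums: how many samples each class has
    precision = np.divide(tp, pred_totals, out=np.zeros_like(tp), where=pred_totals > 0)
    recall = np.divide(tp, true_totals, out=np.zeros_like(tp), where=true_totals > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1


toy_cm = [[8, 2], [4, 6]]   # made-up 2-class counts
p, r, f = per_class_prf(toy_cm)
print(np.round(p, 3), np.round(r, 3), np.round(f, 3))
```

A class with high recall but low precision is one the model over-predicts, which often happens between temporally adjacent classes.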

## 📋 Summary Statistics

In [None]:
# Count samples per class
unique, counts = np.unique(all_targets, return_counts=True)
class_distribution = dict(zip(unique, counts))

print("\n" + "="*60)
print("TEST SET CLASS DISTRIBUTION")
print("="*60)
for i, label in enumerate(DISPLAY_LABELS):
    count = class_distribution.get(i, 0)
    print(f"{label:8s} | {count:4d} samples")
print("="*60)
print(f"TOTAL    | {n_samples:4d} samples")
print("="*60)
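
Class counts also give a floor for interpreting the test accuracy. The sketch below computes two trivial baselines. `reference_baselines` is a hypothetical helper, and the labels in the demo are made up. In this notebook it would be called as `reference_baselines(all_targets, NUM_CLASSES)`:

```python
import numpy as np


def reference_baselines(targets, num_classes):
    """Accuracy of two trivial classifiers, as a floor for the model's score.

    majority: always predict the most frequent test class.
    uniform:  guess uniformly at random (expected accuracy 1/num_classes).
    """
    counts = np.bincount(targets, minlength=num_classes)
    return {
        "majority": float(counts.max() / counts.sum()),
        "uniform": 1.0 / num_classes,
    }


toy_targets = np.array([0, 0, 0, 1, 2, 2, 3, 4, 5, 6, 7, 8])  # toy labels, not real results
print(reference_baselines(toy_targets, num_classes=9))
```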

## 🔍 Comparison with SeisLM Methodology

### ✅ Implementation Checklist

Your xLSTM-UNet implementation follows SeisLM's methodology:

| **Aspect** | **SeisLM** | **Your xLSTM-UNet** | **Match?** |
|------------|------------|---------------------|------------|
| Dataset | Norcia foreshock/aftershock | Norcia foreshock/aftershock | ✅ |
| Num Classes | 9 | 9 | ✅ |
| Split Method | Temporal | Temporal | ✅ |
| Component Order | ZNE | ZNE | ✅ |
| Dimension Order | NWC | NWC | ✅ |
| Normalization | std per channel | std per channel | ✅ |
| Train/Val/Test | 70/10/20 | 70/10/20 | ✅ |
| Frozen Encoder | Yes | Yes | ✅ |
| Head Type | DoubleConv | DoubleConv | ✅ |
| Optimizer | AdamW (4e-4, wd=0.1) | AdamW (4e-4, wd=0.1) | ✅ |
| Max Epochs | 15 | 15 | ✅ |
| Pretrain Method | Contrastive | Contrastive | ✅ |

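**Conclusion**: Your implementation matches SeisLM's foreshock fine-tuning methodology on every aspect listed in the table! 🎉 The data-side rows are set directly in this notebook. The training-side rows (frozen encoder, head, optimizer, epochs, pretraining) describe the run that produced the checkpoint.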
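
The checklist can also be expressed as a config comparison. This sketch assumes hypothetical flat key names. The reference values are transcribed from the table above. The run values cover only what this notebook's `ForeshockAftershockLitDataset` call sets, so the training-side entries show up as unverified (`None`) until they are read from the training config:

```python
# Reference values transcribed from the checklist table; key names are hypothetical.
SEISLM_REFERENCE = {
    "num_classes": 9,
    "event_split_method": "temporal",
    "component_order": "ZNE",
    "dimension_order": "NWC",
    "amp_norm_type": "std",
    "split_fracs": (0.7, 0.1, 0.2),
    "max_epochs": 15,
    "lr": 4e-4,
    "weight_decay": 0.1,
}


def config_mismatches(run_config, reference=SEISLM_REFERENCE):
    """Return {key: (expected, actual)} for every reference key the run does not match."""
    return {
        key: (expected, run_config.get(key))
        for key, expected in reference.items()
        if run_config.get(key) != expected
    }


# Values from this notebook's dataset call only; training-side keys are absent.
this_run = {
    "num_classes": 9,
    "event_split_method": "temporal",
    "component_order": "ZNE",
    "dimension_order": "NWC",
    "amp_norm_type": "std",
    "split_fracs": (0.7, 0.10, 0.20),
}
print(config_mismatches(this_run))
```

Running it reports `max_epochs`, `lr`, and `weight_decay` as mismatches with an actual value of `None`. Those are exactly the rows that this notebook cannot confirm on its own.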

## 💾 Save Results

In [None]:
import json

# Save results to file
results = {
    'checkpoint': CKPT_PATH,
    'test_loss': float(test_loss),
    'test_accuracy': float(test_acc),
    'num_samples': int(n_samples),
    'num_classes': NUM_CLASSES,
    'per_class_accuracy': per_class_acc.tolist(),
    'class_labels': DISPLAY_LABELS,
    'confusion_matrix': cm.tolist(),
    'confusion_matrix_percentage': cm_display.tolist(),
}

results_path = '/path/to/your/checkpoint_or_data'
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"\n✅ Results saved to: {results_path}")
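
Before comparing runs, a saved results file can be sanity-checked for internal consistency. The sketch below defines a hypothetical `check_results` helper and exercises it on a made-up two-class example, round-tripped through JSON:

```python
import json

import numpy as np


def check_results(results):
    """Sanity-check a results dict shaped like the one saved above."""
    cm = np.asarray(results["confusion_matrix"])
    n = results["num_classes"]
    assert cm.shape == (n, n), "confusion matrix must be num_classes x num_classes"
    assert cm.sum() == results["num_samples"], "counts must add up to the test set size"
    acc_from_cm = float(np.trace(cm) / cm.sum())
    assert np.isclose(acc_from_cm, results["test_accuracy"]), "accuracy must match the diagonal"
    assert len(results["class_labels"]) == n, "one label per class"
    return acc_from_cm


# Toy round trip through JSON (made-up numbers, not real results).
toy = {
    "num_classes": 2,
    "num_samples": 20,
    "test_accuracy": 0.7,
    "class_labels": ["A", "B"],
    "confusion_matrix": [[8, 2], [4, 6]],
}
reloaded = json.loads(json.dumps(toy))
print(check_results(reloaded))
```

On real output it would be called as `with open(results_path) as f: check_results(json.load(f))`.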

## 🎯 Done!

**Evaluation Complete** ✅

Your xLSTM-UNet model has been evaluated on the foreshock-aftershock classification task using the same data configuration and fine-tuning methodology as SeisLM.

### Next Steps:
1. Compare your accuracy with SeisLM's baseline
2. Try evaluating other checkpoints (e.g., epoch=14)
3. Analyze which classes are harder to classify
4. Re-run with a different `NUM_CLASSES` (2, 4, or 8) for comparison, updating `DISPLAY_LABELS` to match
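
For steps 2 and 3, saved results files can be ranked and compared class by class. The sketch below defines hypothetical helpers and runs them on made-up runs; the checkpoint names are placeholders. Real entries would come from `json.load` on each results file:

```python
def rank_checkpoints(runs):
    """Sort result dicts (shaped like the saved results) by test accuracy, best first."""
    return sorted(runs, key=lambda run: run["test_accuracy"], reverse=True)


def per_class_delta(run_a, run_b):
    """Per-class accuracy change from run_a to run_b, keyed by class label."""
    return {
        label: acc_b - acc_a
        for label, acc_a, acc_b in zip(
            run_a["class_labels"], run_a["per_class_accuracy"], run_b["per_class_accuracy"]
        )
    }


# Made-up runs with two classes, for illustration only.
toy_runs = [
    {"checkpoint": "run_a.ckpt", "test_accuracy": 0.61,
     "class_labels": ["FEQ1", "FEQ2"], "per_class_accuracy": [70, 40]},
    {"checkpoint": "run_b.ckpt", "test_accuracy": 0.68,
     "class_labels": ["FEQ1", "FEQ2"], "per_class_accuracy": [75, 55]},
]
for run in rank_checkpoints(toy_runs):
    print(f"{run['checkpoint']:12s} acc={run['test_accuracy']:.2%}")
print(per_class_delta(toy_runs[0], toy_runs[1]))
# {'FEQ1': 5, 'FEQ2': 15}
```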