In [2]:
"""
CNN Training Script for Earthquake Detection Using PSD Features
================================================================

This notebook trains a convolutional neural network (CNN) to classify infrasound events 
(earthquake vs. background noise) using power spectral density (PSD) data.

Key Components:
---------------
1. **Data Loading & Preprocessing**
   - Loads PSD data from pickle files (earthquake and background).
   - Extracts PSD arrays, flattens them, and applies log normalization.
   - Saves normalization statistics (`mean.npy`, `std.npy`).

2. **Dataset & Training Utilities**
   - Custom `Dataset` class for PyTorch DataLoader.
   - Implements early stopping to prevent overfitting.

3. **Cross-Validation Training Loop**
   - 5-fold stratified cross-validation.
   - Applies class balancing with weighted loss.
   - Uses RAdam optimizer and monitors validation performance.
   - Saves model and data for each fold.

4. **Evaluation**
   - Reports standard classification metrics per fold.
   - Computes and prints aggregate cross-validation metrics:
     Accuracy, Precision, Recall, F1 Score, ROC AUC, and Confusion Matrix.

Outputs:
--------
- Trained model checkpoints: `fold_v2_outputs/fold_*/CNNmodel.pth`
- Preprocessed data: `fold_v2_outputs/fold_*/data.npz`
- Normalization stats: `mean.npy`, `std.npy`

Dependencies:
-------------
- torch, numpy, sklearn, torch_optimizer, tqdm
- Custom modules: `psd_pickle_utils`, `cnn_model`

Ethan Gelfand, 08/06/2025
"""


import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch_optimizer as optim
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, classification_report
)

from psd_pickle_utils import load_pickle_data, extract_psd_array
from cnn_model import EarthquakeCNN

# -----------------------------
# Utility Classes
# -----------------------------

class PSD_Dataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None or val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# -----------------------------
# Data Loading and Preprocessing
# -----------------------------

patheq = "../DataCollection_Preprocessing/Exported_Paros_Data/PSD_Windows_Earthquake_100Hz.pkl"
pathbg = "../DataCollection_Preprocessing/Exported_Paros_Data/PSD_Windows_Background_100Hz.pkl"

EarthquakeData = load_pickle_data(patheq)
BackgroundData = load_pickle_data(pathbg)

eq_array = extract_psd_array(EarthquakeData)
bg_array = extract_psd_array(BackgroundData)

eq_labels = np.ones(len(eq_array), dtype=int)
bg_labels = np.zeros(len(bg_array), dtype=int)

X = np.concatenate([eq_array, bg_array], axis=0)
y = np.concatenate([eq_labels, bg_labels], axis=0)

X = X.reshape(X.shape[0], -1)

# Shuffle
indices = np.arange(len(X))
np.random.shuffle(indices)
X, y = X[indices], y[indices]

# Normalize
X_log = np.log10(X + 1e-10)
mean, std = X_log.mean(axis=0), X_log.std(axis=0) + 1e-6
X = (X_log - mean) / std

np.save("../DataCollection_Preprocessing/Exported_Paros_Data/mean.npy", mean)
np.save("../DataCollection_Preprocessing/Exported_Paros_Data/std.npy", std)

# -----------------------------
# Cross-Validation Training Loop
# -----------------------------

K = 5
skf = StratifiedKFold(n_splits=K, shuffle=True, random_state=42)

fold_train_accs, fold_val_accs = [], []
all_val_labels, all_val_preds, all_val_probs = [], [], []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for fold, (train_index, val_index) in enumerate(skf.split(X, y)):
    print(f"Starting Fold {fold+1}/{K}")
    fold_dir = f"fold_outputs/fold_{fold+1}"
    os.makedirs(fold_dir, exist_ok=True)

    X_train, X_val = X[train_index], X[val_index]
    y_train, y_val = y[train_index], y[val_index]

    class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    class_weights[1] *= 3.0  # Emphasize Earthquake class

    train_dataset = PSD_Dataset(X_train, y_train)
    val_dataset = PSD_Dataset(X_val, y_val)
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

    model = EarthquakeCNN(X.shape[1]).to(device)
    criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights, dtype=torch.float32).to(device))
    optimizer = optim.RAdam(model.parameters(), lr=1e-4)
    early_stopping = EarlyStopping(patience=5, min_delta=1e-4)

    for epoch in range(70):
        model.train()
        train_loss, correct, total = 0, 0, 0

        for data, targets in train_loader:
            data, targets = data.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        train_acc = correct / total

        model.eval()
        val_loss, val_correct, val_total = 0, 0, 0
        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(device), targets.to(device)
                outputs = model(data)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                val_total += targets.size(0)
                val_correct += (predicted == targets).sum().item()

        val_acc = val_correct / val_total

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Fold {fold+1}, Epoch {epoch+1}: Train Acc: {train_acc:.4f}, Val Acc: {val_acc:.4f}")

        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    fold_train_accs.append(train_acc)
    fold_val_accs.append(val_acc)

    # Final evaluation
    model.eval()
    val_preds, val_labels, val_probs = [], [], []
    with torch.no_grad():
        for data, targets in val_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = model(data)
            probs = torch.softmax(outputs, dim=1)[:, 1]
            _, predicted = torch.max(outputs, 1)
            val_preds.extend(predicted.cpu().numpy())
            val_labels.extend(targets.cpu().numpy())
            val_probs.extend(probs.cpu().numpy())

    all_val_labels.extend(val_labels)
    all_val_preds.extend(val_preds)
    all_val_probs.extend(val_probs)

    np.savez(os.path.join(fold_dir, "data.npz"), X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val)
    torch.save(model.state_dict(), os.path.join(fold_dir, "CNNmodel.pth"))

    print(f"\nClassification report for fold {fold+1}:")
    print(classification_report(val_labels, val_preds, digits=4))

# -----------------------------
# Final Metrics
# -----------------------------

print("\n===== Final Cross-Validation Metrics =====")
print(f"Accuracy:  {accuracy_score(all_val_labels, all_val_preds):.4f}")
print(f"Precision: {precision_score(all_val_labels, all_val_preds):.4f}")
print(f"Recall:    {recall_score(all_val_labels, all_val_preds):.4f}")
print(f"F1 Score:  {f1_score(all_val_labels, all_val_preds):.4f}")
print(f"ROC AUC:   {roc_auc_score(all_val_labels, all_val_probs):.4f}")
print("Confusion Matrix:")
print(confusion_matrix(all_val_labels, all_val_preds))

print(f"\nAverage Train Accuracy: {np.mean(fold_train_accs):.4f} +/- {np.std(fold_train_accs):.4f}")
print(f"Average Val Accuracy:   {np.mean(fold_val_accs):.4f} +/- {np.std(fold_val_accs):.4f}")


Starting Fold 1/5
Fold 1, Epoch 1: Train Acc: 0.4097, Val Acc: 0.3767
Fold 1, Epoch 10: Train Acc: 0.9091, Val Acc: 0.9283
Fold 1, Epoch 20: Train Acc: 0.9304, Val Acc: 0.9417
Fold 1, Epoch 30: Train Acc: 0.9495, Val Acc: 0.9417
Early stopping triggered.

Classification report for fold 1:
              precision    recall  f1-score   support

           0     0.9437    0.9640    0.9537       139
           1     0.9383    0.9048    0.9212        84

    accuracy                         0.9417       223
   macro avg     0.9410    0.9344    0.9375       223
weighted avg     0.9416    0.9417    0.9415       223

Starting Fold 2/5
Fold 2, Epoch 1: Train Acc: 0.3805, Val Acc: 0.3767
Fold 2, Epoch 10: Train Acc: 0.9282, Val Acc: 0.9372
Fold 2, Epoch 20: Train Acc: 0.9439, Val Acc: 0.9417
Early stopping triggered.

Classification report for fold 2:
              precision    recall  f1-score   support

           0     0.9200    0.9928    0.9550       139
           1     0.9863    0.8571    