In [None]:
"""
Earthquake vs Background PSD Classification Training Script Using 2D CNN and Stratified K-Fold Cross-Validation
-------------------------------------------------------------------------------------------------------------

This script performs supervised training of a 2D convolutional neural network (CNN) 
to classify Power Spectral Density (PSD) feature windows as earthquake events or background noise.

Main Features:
--------------
- Loads preprocessed PSD datasets (earthquake and background) saved as pickle files.
- Converts PSD data into log-scale and applies per-feature normalization (mean/std).
- Defines a PyTorch Dataset and DataLoader for efficient batch processing.
- Implements class weighting with bias factor to handle class imbalance.
- Trains an EarthquakeCNN2d model using Stratified K-Fold cross-validation for robust evaluation.
- Includes early stopping based on validation loss to prevent overfitting.
- Logs training and validation loss and accuracy per epoch.
- Saves model checkpoints and normalized data splits per fold.
- Computes and prints detailed classification metrics and confusion matrix after cross-validation.

Dependencies:
-------------
- PyTorch and torch_optimizer (RAdam optimizer)
- NumPy
- scikit-learn (for stratified splitting, metrics, and class weights)
- Custom utilities: psd_pickle_utils (data loading), cnn_model (model definition)

Inputs:
-------
- PSD pickle files containing labeled earthquake and background PSD windows.
- Normalization parameters are computed internally.
- Model architecture expects input shape: (windows=11, freq_bins=52).

Outputs:
--------
- Saved normalization statistics (mean.npy, std.npy) for inference normalization.
- Model checkpoints saved per fold.
- Classification reports printed after each fold and overall metrics after all folds.

Author: Ethan Gelfand
Date: 08/07/2025
"""


import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch_optimizer as optim
from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
    confusion_matrix, classification_report
)

from psd_pickle_utils import load_pickle_data, extract_psd_array
from cnn_model import EarthquakeCNN2d

# --- PyTorch Dataset ---
class PSD_Dataset(Dataset):
    """In-memory PyTorch dataset of PSD feature windows and integer labels.

    Features are converted to a float32 tensor and given a singleton channel
    axis at position 1, so an input of shape ``(events, windows, freq_bins)``
    is stored as ``(events, 1, windows, freq_bins)`` -- the layout a 2D CNN
    consumes. Labels are stored as an int64 (``torch.long``) tensor, matching
    the class-index targets fed to ``nn.CrossEntropyLoss`` in this script.

    Args:
        X: Array-like of features, one entry per event.
        y: Array-like of integer class labels, one per event.
    """

    def __init__(self, X, y):
        features = torch.tensor(X, dtype=torch.float32)
        # Insert the channel axis expected by Conv2d layers.
        self.X = features.unsqueeze(dim=1)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        # Number of events == size of the leading axis.
        return self.X.shape[0]

    def __getitem__(self, idx):
        sample, label = self.X[idx], self.y[idx]
        return sample, label



# --- Early stopping ---
class EarlyStopping:
    """Signal when a monitored validation loss stops improving.

    Call the instance once per epoch with the current validation loss. The
    first call only records a baseline. After that, a loss counts as an
    improvement only if it is lower than the best loss seen so far by more
    than ``min_delta``. An improvement updates ``best_loss`` and resets
    ``counter``; any other value increments ``counter``. Once ``counter``
    reaches ``patience``, ``early_stop`` becomes ``True``. Nothing in this
    class ever resets it to ``False``.

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
        min_delta: Minimum decrease in loss that counts as an improvement.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            # First observation: just establish the baseline.
            self.best_loss = val_loss
        elif val_loss < self.best_loss - self.min_delta:
            # Meaningful improvement: remember it and restart the count.
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            self.early_stop = self.early_stop or self.counter >= self.patience


# --- Load the dataset ---
# NOTE: the pickles are trusted local exports; never point these paths at
# untrusted files, since unpickling can execute arbitrary code.
patheq = "../DataCollection_Preprocessing/Exported_Paros_Data/PSD_Windows_Earthquake_100Hz.pkl"
pathbg = "../DataCollection_Preprocessing/Exported_Paros_Data/PSD_Windows_Background_100Hz.pkl"

EarthquakeData = load_pickle_data(patheq)
BackgroundData = load_pickle_data(pathbg)

# Convert to 3D numpy arrays: (events, windows, freq_bins).
# Event counts depend on the exported pickles, e.g. (420, 11, 52) earthquake
# and (694, 11, 52) background windows in the last recorded run.
eq_array = extract_psd_array(EarthquakeData)
bg_array = extract_psd_array(BackgroundData)

print("EQ array shape:", eq_array.shape)
print("BG array shape:", bg_array.shape)

# Assign labels: 1 = earthquake, 0 = background.
eq_labels = np.ones(len(eq_array), dtype=int)
bg_labels = np.zeros(len(bg_array), dtype=int)
X = np.concatenate([eq_array, bg_array], axis=0)  # shape (events, windows, freq_bins)
y = np.concatenate([eq_labels, bg_labels], axis=0)

# Shuffle with a fixed seed. StratifiedKFold below already uses
# random_state=42, but an unseeded shuffle here would still change which
# events land in which fold on every run, defeating that seed and making the
# saved fold checkpoints irreproducible.
shuffle_rng = np.random.default_rng(42)
indices = shuffle_rng.permutation(len(X))
X = X[indices]
y = y[indices]

# --- Normalization ---
# Work in log10 space; the tiny additive offset keeps log10 finite for bins
# with zero power.
X_log = np.log10(X + 1e-10)

# One statistic per (window, freq_bin) cell, pooled over every event in the
# dataset (both classes, training and validation samples alike). The small
# epsilon keeps the division finite for cells that are constant.
mean = np.mean(X_log, axis=0)
std = np.std(X_log, axis=0) + 1e-6
X = (X_log - mean) / std

# Persist the statistics so inference can apply the identical transform.
stats_dir = "../DataCollection_Preprocessing/Exported_Paros_Data"
for stat_name, stat_values in (("mean", mean), ("std", std)):
    np.save(f"{stats_dir}/{stat_name}.npy", stat_values)
print("Saved mean and std to mean.npy and std.npy")


# --- Training with stratified k-fold ---
# For each fold this section:
#   1. normalizes the log-PSD features with statistics from that fold's
#      training events only (no validation leakage);
#   2. trains a fresh EarthquakeCNN2d with class-weighted cross-entropy;
#   3. early-stops on validation loss and restores the weights from the
#      epoch with the lowest validation loss;
#   4. saves the fold's data split, normalization stats and checkpoint, and
#      accumulates validation predictions for the pooled metrics.
K = 5
skf = StratifiedKFold(n_splits=K, shuffle=True, random_state=42)

fold_train_accs = []
fold_val_accs = []
all_val_labels = []
all_val_preds = []
all_val_probs = []

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader``; return ``(mean batch loss, accuracy)``.

    When ``optimizer`` is given, the model is put in training mode and updated
    after every batch. Otherwise it is evaluated in eval mode with gradient
    tracking disabled.
    """
    is_training = optimizer is not None
    model.train(is_training)
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.set_grad_enabled(is_training):
        for data, targets in loader:
            data = data.to(device)
            targets = targets.to(device)
            outputs = model(data)
            loss = criterion(outputs, targets)
            if is_training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += targets.size(0)
            correct += (predicted == targets).sum().item()
    return running_loss / len(loader), correct / total


def _predict(model, loader, device):
    """Return ``(labels, predicted_classes, class1_probabilities)`` as lists."""
    model.eval()
    labels, preds, probs = [], [], []
    with torch.no_grad():
        for data, targets in loader:
            outputs = model(data.to(device))  # logits
            # Probability of class 1 (earthquake), used for ROC AUC.
            probs.extend(torch.softmax(outputs, dim=1)[:, 1].cpu().numpy())
            _, predicted = torch.max(outputs, 1)
            preds.extend(predicted.cpu().numpy())
            labels.extend(targets.cpu().numpy())
    return labels, preds, probs


for fold, (train_idx, val_idx) in enumerate(skf.split(X, y)):
    print(f"Starting Fold {fold+1}/{K}")
    fold_dir = f"fold_outputs/fold_{fold+1}"
    os.makedirs(fold_dir, exist_ok=True)

    # Seed per fold so weight init and batch order are reproducible.
    torch.manual_seed(42 + fold)

    # Normalize with statistics from this fold's training events only, so the
    # validation events never influence the transform they are scored with.
    # NOTE(review): the global mean.npy/std.npy cover all events and differ
    # slightly from these; evaluate a fold checkpoint with the mean/std stored
    # in its data.npz.
    X_log_train = X_log[train_idx]
    fold_mean = X_log_train.mean(axis=0)
    fold_std = X_log_train.std(axis=0) + 1e-6
    X_train = (X_log_train - fold_mean) / fold_std
    X_val = (X_log[val_idx] - fold_mean) / fold_std
    y_train, y_val = y[train_idx], y[val_idx]

    # Compute class weights dynamically, then up-weight the earthquake class
    # (label 1; both labels are present because the split is stratified).
    class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
    bias_factor = 3.0
    class_weights[1] *= bias_factor

    train_dataset = PSD_Dataset(X_train, y_train)
    val_dataset = PSD_Dataset(X_val, y_val)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

    model = EarthquakeCNN2d(input_shape=X.shape[1:]).to(device)

    weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
    criterion = nn.CrossEntropyLoss(weight=weights)
    optimizer = optim.RAdam(model.parameters(), lr=1e-4)

    early_stopping = EarlyStopping(patience=5, min_delta=1e-4)

    # Track the best epoch: early stopping only fires `patience` epochs after
    # it, so the final weights are not the best ones.
    best_val_loss = float("inf")
    best_epoch = None
    best_state = None
    best_train_acc = None

    num_epochs = 70
    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        if (epoch + 1) % 10 == 0 or epoch == 0:
            print(f"Fold {fold+1}, Epoch [{epoch+1}/{num_epochs}] "
                  f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            best_train_acc = train_acc
            # Clone so later optimizer steps do not mutate the snapshot.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}

        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered.")
            break

    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored weights from epoch {best_epoch + 1} (Val Loss: {best_val_loss:.4f}).")
    else:
        # No finite validation loss was ever seen; keep the final weights.
        best_train_acc = train_acc

    val_labels_fold, val_preds_fold, val_probs_fold = _predict(model, val_loader, device)

    # Accuracies describe the restored (best) model, consistent with the
    # classification report and the saved checkpoint.
    fold_train_accs.append(best_train_acc)
    fold_val_accs.append(accuracy_score(val_labels_fold, val_preds_fold))

    all_val_labels.extend(val_labels_fold)
    all_val_preds.extend(val_preds_fold)
    all_val_probs.extend(val_probs_fold)

    np.savez(os.path.join(fold_dir, "data.npz"),
             X_train=X_train, y_train=y_train,
             X_val=X_val, y_val=y_val,
             mean=fold_mean, std=fold_std)

    torch.save(model.state_dict(), os.path.join(fold_dir, "CNNmodel.pth"))

    print(f"\nClassification report for fold {fold+1}:")
    print(classification_report(val_labels_fold, val_preds_fold, digits=4))

print("\n===== Final Cross-Validation Metrics =====")
# Pooled metrics over the validation predictions of every fold. Precision,
# recall and F1 use sklearn's default positive label (1 = earthquake).
y_true_pooled = all_val_labels
y_pred_pooled = all_val_preds
accuracy = accuracy_score(y_true_pooled, y_pred_pooled)
precision = precision_score(y_true_pooled, y_pred_pooled)
recall = recall_score(y_true_pooled, y_pred_pooled)
f1 = f1_score(y_true_pooled, y_pred_pooled)
# ROC AUC is scored on the predicted class-1 probabilities, not hard labels.
roc_auc = roc_auc_score(y_true_pooled, all_val_probs)
cm = confusion_matrix(y_true_pooled, y_pred_pooled)

# Labels are padded so the values line up in one column.
summary_rows = (
    ("Accuracy: ", accuracy),
    ("Precision:", precision),
    ("Recall:   ", recall),
    ("F1 Score: ", f1),
    ("ROC AUC:  ", roc_auc),
)
for metric_label, metric_value in summary_rows:
    print(f"{metric_label} {metric_value:.4f}")
print("Confusion Matrix:")
print(cm)

print(f"\nCross-validation results over {K} folds:")
for split_name, split_accs in (("Train", fold_train_accs), ("Validation", fold_val_accs)):
    print(f"Average {split_name} Accuracy: {np.mean(split_accs):.4f} +/- {np.std(split_accs):.4f}")


EQ array shape: (420, 11, 52)
BG array shape: (694, 11, 52)
Saved mean and std to mean.npy and std.npy
Starting Fold 1/5
Fold 1, Epoch [1/70] Train Loss: 0.6462, Train Acc: 0.3692 Val Loss: 0.6421, Val Acc: 0.3767
Fold 1, Epoch [10/70] Train Loss: 0.2561, Train Acc: 0.9057 Val Loss: 0.3038, Val Acc: 0.9148
Fold 1, Epoch [20/70] Train Loss: 0.1811, Train Acc: 0.9259 Val Loss: 0.1858, Val Acc: 0.9417
Fold 1, Epoch [30/70] Train Loss: 0.1500, Train Acc: 0.9484 Val Loss: 0.1619, Val Acc: 0.9283
Fold 1, Epoch [40/70] Train Loss: 0.1468, Train Acc: 0.9428 Val Loss: 0.1632, Val Acc: 0.9552
Fold 1, Epoch [50/70] Train Loss: 0.1396, Train Acc: 0.9439 Val Loss: 0.1425, Val Acc: 0.9507
Early stopping triggered.

Classification report for fold 1:
              precision    recall  f1-score   support

           0     0.9697    0.9209    0.9446       139
           1     0.8791    0.9524    0.9143        84

    accuracy                         0.9327       223
   macro avg     0.9244    0.9366    