In [None]:
# -*- coding: utf-8 -*-
"""
Title: Original MagNet Implementation for Single-Station Magnitude Estimation

This notebook implements the original MagNet architecture (Mousavi & Beroza, 2020) 
for earthquake magnitude estimation, including:
1. Model architecture implementation
2. Custom loss function with uncertainty quantification 
3. Training and evaluation pipeline
4. Performance analysis and visualization
"""


# Part 1: Setup and Imports 

In [None]:
#------------------------------------------------------------------------------
# Part 1: Setup and Imports 
#------------------------------------------------------------------------------

import json
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from torchinfo import summary
from tqdm import tqdm

# Pick the compute device once: CUDA when torch can see a GPU, CPU otherwise.
# Later cells move models and batches to this `device`.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

# Part 2: Dataset Class

In [None]:
#------------------------------------------------------------------------------
# Part 2: Dataset Class
#------------------------------------------------------------------------------

class EarthquakeDataset(Dataset):
    """Map-style dataset pairing each waveform with its magnitude label.

    Args:
        data: indexable collection of waveforms. The notebook passes a float32
            tensor, presumably shaped (N, 3000, 3) -- TODO confirm.
        labels: indexable collection of magnitude labels, same length as ``data``.

    ``len()`` is taken from ``labels``; item ``i`` is the tuple
    ``(data[i], labels[i])``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        n_items = len(self.labels)
        return n_items

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

# Part 3: Model Architecture

In [None]:
#------------------------------------------------------------------------------
# Part 3: Model Architecture
#------------------------------------------------------------------------------

class EarthquakeModel(nn.Module):
    """Original MagNet architecture (Mousavi & Beroza, 2020).

    Two ``Conv1d -> Dropout -> MaxPool1d`` stages, then a bidirectional LSTM
    and a linear head producing two values per example: column 0 is the
    magnitude estimate, column 1 is the log variance ``s`` consumed by
    ``custom_loss``.

    Input is channels-last, ``(batch, time, 3)``; ``forward`` transposes it to
    ``(batch, 3, time)`` for ``Conv1d``. The notebook summarises it with
    ``time = 3000``.
    """

    def __init__(self):
        super(EarthquakeModel, self).__init__()
        self.conv1 = nn.Conv1d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 32, kernel_size=3, padding=1)
        # Stride defaults to the kernel size (4), so each pooling shrinks the
        # time axis roughly 4x.
        self.maxpool = nn.MaxPool1d(4, padding=1)
        # A single stateless Dropout module, reused after both convolutions.
        self.dropout = nn.Dropout(0.2)
        self.lstm = nn.LSTM(32, 100, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(200, 2)  # Magnitude and log variance

    def forward(self, x):
        """Map ``(batch, time, 3)`` waveforms to ``(batch, 2)`` outputs."""
        x = x.transpose(1, 2)  # (B, T, 3) -> (B, 3, T) for Conv1d
        x = self.maxpool(self.dropout(self.conv1(x)))
        x = self.maxpool(self.dropout(self.conv2(x)))
        x = x.transpose(1, 2)  # (B, 32, T') -> (B, T', 32) for batch_first LSTM
        _, (h_n, _) = self.lstm(x)
        # h_n is (num_layers * 2, B, 100). For this single-layer BiLSTM,
        # h_n[-2] is the forward direction's state after the last timestep and
        # h_n[-1] is the backward direction's state after the first timestep,
        # i.e. each direction's summary of the whole sequence. This matches
        # Keras Bidirectional(LSTM(return_sequences=False)). Reading
        # output[:, -1, :] instead would take the backward half at the last
        # step, where it has seen only a single timestep.
        x = torch.cat((h_n[-2], h_n[-1]), dim=1)
        return self.fc(x)


# Part 4: Training Components

In [None]:
#------------------------------------------------------------------------------
# Part 4: Training Components
#------------------------------------------------------------------------------

class EarlyStopping:
    """Stop training once the validation loss stops improving.

    Each call compares the latest validation loss with the best seen so far.
    A call counts as an improvement when ``-val_loss >= best_score + delta``,
    i.e. the loss dropped by at least ``delta`` (the first finite loss always
    counts). On improvement the model's ``state_dict`` is saved to a checkpoint
    file and the counter resets. Otherwise the counter increments, and
    ``early_stop`` becomes True once it reaches ``patience``. Non-finite losses
    (NaN / +-inf) never count as an improvement.

    Args:
        patience: number of consecutive non-improving calls tolerated.
        verbose: print a message each time a checkpoint is saved.
        delta: minimum decrease in validation loss that counts as improvement.
        run_id, test_seed, model_seed: only used to build the checkpoint name
            ``best_model_Run_{run_id}_test_seed_{test_seed}_model_seed_{model_seed}.pth``.
    """
    def __init__(self, patience=7, verbose=False, delta=0, run_id=None, 
                 test_seed=None, model_seed=None):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None  # negated best validation loss
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.run_id = run_id
        self.test_seed = test_seed
        self.model_seed = model_seed

    def __call__(self, val_loss, model):
        """Record ``val_loss`` for one epoch and checkpoint ``model`` if it improved."""
        score = -val_loss
        # NaN compares False against everything. Without the finiteness guard a
        # NaN loss would be taken as an improvement, overwrite the best
        # checkpoint with diverged weights, and make every later comparison
        # against best_score meaningless.
        improved = bool(np.isfinite(val_loss)) and (
            self.best_score is None or score >= self.best_score + self.delta
        )
        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to this run's checkpoint file."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f})')
        model_filename = f'best_model_Run_{self.run_id}_test_seed_{self.test_seed}_model_seed_{self.model_seed}.pth'
        torch.save(model.state_dict(), model_filename)
        self.val_loss_min = val_loss

def custom_loss(y_pred, y_true):
    """Heteroscedastic Gaussian negative log-likelihood (up to a constant).

    ``y_pred[:, 0]`` is the predicted magnitude and ``y_pred[:, 1]`` is
    ``s = log(sigma^2)`` (natural log). Per example the loss is
    ``0.5 * exp(-s) * (y_true - y_hat)**2 + 0.5 * s``, averaged over the batch.

    Args:
        y_pred: tensor of shape (B, 2).
        y_true: tensor with B elements, e.g. (B,) or (B, 1). It is reshaped to
            (B,) so a column-vector target cannot silently broadcast against
            the (B,) prediction into a (B, B) matrix.

    Returns:
        Scalar tensor.
    """
    y_hat = y_pred[:, 0]
    s = y_pred[:, 1]
    y_true = y_true.reshape(y_hat.shape)
    return torch.mean(0.5 * torch.exp(-s) * (y_true - y_hat)**2 + 0.5 * s)

# Part 5: Training Function

In [None]:
#------------------------------------------------------------------------------
# Part 5: Training Function
#------------------------------------------------------------------------------

def train_model(model, train_loader, val_loader, num_epochs=300, patience=5,
                run_id=None, test_seed=None, model_seed=None):
    """Train ``model`` with Adam, LR-on-plateau scheduling and early stopping.

    Each epoch runs one pass over ``train_loader`` (gradient norm clipped to
    1.0) and one no-grad pass over ``val_loader``. The validation loss drives
    both ``ReduceLROnPlateau`` and ``EarlyStopping``. ``EarlyStopping`` writes
    the best weights to a checkpoint file named from
    ``run_id``/``test_seed``/``model_seed``. The model is NOT restored to those
    weights here; on return it holds the last epoch's weights.

    Args:
        model: module already on ``device`` whose output is (B, 2) as
            expected by ``custom_loss``.
        train_loader, val_loader: iterables of ``(data, target)`` batches.
        num_epochs: maximum number of epochs.
        patience: early-stopping patience (epochs without improvement).
        run_id, test_seed, model_seed: forwarded to ``EarlyStopping`` to name
            the checkpoint file.

    Returns:
        (train_losses, val_losses): per-epoch losses, each the mean loss per
        sample over the corresponding loader.
    """
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # Multiply the LR by sqrt(0.1) (~0.316) after 4 epochs without validation
    # improvement. `verbose` is deprecated in recent torch releases; the
    # current LR is printed every epoch below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=float(np.sqrt(0.1)),
        cooldown=0, patience=4, min_lr=0.5e-6
    )

    early_stopping = EarlyStopping(
        patience=patience, verbose=True,
        run_id=run_id, test_seed=test_seed, model_seed=model_seed
    )

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        train_loss_sum, train_count = 0.0, 0
        for data, target in train_loader:
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = custom_loss(outputs, target)
            loss.backward()
            # exp(-s) in the loss can yield very large gradients when the
            # predicted log variance is strongly negative; clip the global norm.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            # custom_loss is a batch mean: weight by batch size so a smaller
            # final batch does not skew the epoch average.
            train_loss_sum += loss.item() * data.size(0)
            train_count += data.size(0)

        # Validation phase
        model.eval()
        val_loss_sum, val_count = 0.0, 0
        with torch.no_grad():
            for data, target in val_loader:
                data, target = data.to(device), target.to(device)
                outputs = model(data)
                loss = custom_loss(outputs, target)
                val_loss_sum += loss.item() * data.size(0)
                val_count += data.size(0)

        # Per-sample average losses
        running_loss = train_loss_sum / train_count
        val_loss = val_loss_sum / val_count

        # Update LR and check early stopping
        scheduler.step(val_loss)
        early_stopping(val_loss, model)

        print(f'Epoch {epoch+1}:')
        print(f'Training Loss: {running_loss:.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

        train_losses.append(running_loss)
        val_losses.append(val_loss)

        if early_stopping.early_stop:
            print(f'Early stopping triggered at epoch {epoch+1}')
            break

    return train_losses, val_losses


# Part 6: Uncertainty Estimation

In [None]:
#------------------------------------------------------------------------------
# Part 6: Uncertainty Estimation
#------------------------------------------------------------------------------

def estimate_uncertainty(model, data_loader, num_samples=50):
    """Monte-Carlo-dropout prediction with aleatoric and epistemic uncertainty.

    Runs ``num_samples`` stochastic forward passes over the whole of
    ``data_loader``. Only ``nn.Dropout`` layers are active; every other layer
    is in eval mode. The model's training/eval mode is restored on return.
    The loader must yield samples in the same order on every pass (e.g.
    ``shuffle=False``) so that per-sample results line up across passes.

    The model's second output ``s`` is treated as a natural-log variance. This
    matches ``custom_loss``, which weights the squared error by ``exp(-s)``, so
    sigma^2 = exp(s).

    Args:
        model: module returning (B, 2): [prediction, log variance].
        data_loader: iterable of ``(data, target)`` batches; targets are ignored.
        num_samples: number of MC-dropout passes.

    Returns:
        mean_prediction: (N,) mean of the MC predictions.
        epistemic_uncertainty: (N,) standard deviation of the MC predictions.
        aleatoric_uncertainty: (N,) mean over passes of exp(s) (a variance).
        combined_uncertainty: (N,) variance of the MC predictions plus
            aleatoric_uncertainty (a variance; note epistemic_uncertainty is
            reported as a standard deviation).
    """
    was_training = model.training
    # Run on whatever device the model lives on (CPU if it has no parameters).
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    model.eval()
    # Enable dropout at test time
    for m in model.modules():
        if isinstance(m, nn.Dropout):
            m.train()

    predictions = []
    log_variances = []
    try:
        with torch.no_grad():
            for _ in range(num_samples):
                batch_predictions = []
                batch_log_variances = []
                for data, _ in data_loader:
                    data = data.to(run_device)
                    output = model(data)
                    batch_predictions.append(output[:, 0].cpu().numpy())
                    batch_log_variances.append(output[:, 1].cpu().numpy())
                predictions.append(np.concatenate(batch_predictions))
                log_variances.append(np.concatenate(batch_log_variances))
    finally:
        # train(mode) resets every submodule, including the Dropout layers
        # switched on above, so the model is not left in a mixed state.
        model.train(was_training)

    predictions = np.array(predictions)      # (num_samples, N)
    log_variances = np.array(log_variances)  # (num_samples, N)

    mean_prediction = np.mean(predictions, axis=0)
    yhat_squared_mean = np.mean(np.square(predictions), axis=0)

    # s is ln(sigma^2) per custom_loss, so the variance is exp(s), not 10**s.
    sigma_squared = np.exp(log_variances)
    aleatoric_uncertainty = np.mean(sigma_squared, axis=0)

    epistemic_uncertainty = np.std(predictions, axis=0)
    # E[y^2] - E[y]^2 is the (population) variance of the MC predictions.
    combined_uncertainty = yhat_squared_mean - np.square(mean_prediction) + aleatoric_uncertainty

    return mean_prediction, epistemic_uncertainty, aleatoric_uncertainty, combined_uncertainty


# Part 7: Evaluation Function

In [None]:
#------------------------------------------------------------------------------
# Part 7: Evaluation Function
#------------------------------------------------------------------------------

def evaluate_model(model, test_loader, run_id, test_seed, model_seed):
    """Load this run's best checkpoint into ``model`` and evaluate it with MC dropout.

    The checkpoint file name follows the pattern written by ``EarlyStopping``:
    ``best_model_Run_{run_id}_test_seed_{test_seed}_model_seed_{model_seed}.pth``.

    ``test_loader`` is iterated once more to collect labels, so it must yield
    samples in a fixed order (e.g. ``shuffle=False``) for the labels to line
    up with the predictions.

    Returns:
        (mae, mean_pred, true_values, epistemic_unc, aleatoric_unc, combined_unc),
        where ``true_values`` is flattened to 1-D to match ``mean_pred``.

    Raises:
        ValueError: if the number of labels differs from the number of predictions.
    """
    model_filename = f'best_model_Run_{run_id}_test_seed_{test_seed}_model_seed_{model_seed}.pth'
    # map_location lets a checkpoint written on another device (e.g. a GPU)
    # be deserialised onto the current one.
    model.load_state_dict(torch.load(model_filename, map_location=device))

    mean_pred, epistemic_unc, aleatoric_unc, combined_unc = estimate_uncertainty(model, test_loader)

    # Flatten so (B, 1)-shaped labels cannot broadcast against the (N,)
    # predictions into an (N, N) error matrix.
    true_values = np.concatenate([target.numpy().reshape(-1) for _, target in test_loader])
    if true_values.shape != mean_pred.shape:
        raise ValueError(
            f'Got {true_values.shape[0]} labels but {mean_pred.shape[0]} predictions'
        )

    mae = np.mean(np.abs(mean_pred - true_values))

    return mae, mean_pred, true_values, epistemic_unc, aleatoric_unc, combined_unc



# Part 8: Visualization Functions

In [None]:
#------------------------------------------------------------------------------
# Part 8: Visualization Functions
#------------------------------------------------------------------------------

def plot_training_curves(train_losses, val_losses):
    """Draw per-epoch training and validation losses on one set of axes.

    Args:
        train_losses: sequence of per-epoch training losses.
        val_losses: sequence of per-epoch validation losses.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    for series, name in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
        ax.plot(series, label=name)
    ax.set_xlabel('Epochs', fontsize=14, fontweight='bold')
    ax.set_ylabel('Loss', fontsize=14, fontweight='bold')
    ax.set_title('Training and Validation Loss Curves', fontsize=16, fontweight='bold')
    ax.legend(fontsize=12)
    ax.grid(True)
    fig.tight_layout()
    plt.show()

def plot_prediction_scatter(true_values, predictions, mae):
    """Scatter predicted against true magnitudes, with a dashed y = x reference.

    Args:
        true_values: array-like of ground-truth magnitudes.
        predictions: predicted magnitudes, same length as ``true_values``.
        mae: mean absolute error, shown with 3 decimals in the title.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.scatter(true_values, predictions, alpha=0.5)
    # Perfect-prediction diagonal spanning the observed true-magnitude range.
    lo, hi = min(true_values), max(true_values)
    ax.plot([lo, hi], [lo, hi], 'r--')
    ax.set_xlabel('True Magnitude', fontsize=14, fontweight='bold')
    ax.set_ylabel('Predicted Magnitude', fontsize=14, fontweight='bold')
    ax.set_title(f'Prediction vs Ground Truth (MAE: {mae:.3f})',
                 fontsize=16, fontweight='bold')
    ax.grid(True)
    fig.tight_layout()
    plt.show()

def plot_uncertainty_analysis(predictions, aleatoric_unc, epistemic_unc, combined_unc):
    """Three side-by-side scatters of each uncertainty type against the prediction.

    Args:
        predictions: predicted magnitudes (x axis of every panel).
        aleatoric_unc, epistemic_unc, combined_unc: per-sample uncertainty
            values, each the same length as ``predictions``.
    """
    panels = (
        ('Aleatoric Uncertainty', aleatoric_unc),
        ('Epistemic Uncertainty', epistemic_unc),
        ('Combined Uncertainty', combined_unc),
    )
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    for ax, (label, values) in zip(axes, panels):
        ax.scatter(predictions, values, alpha=0.5)
        ax.set_xlabel('Predicted Magnitude', fontsize=12)
        ax.set_ylabel(label, fontsize=12)
        ax.set_title(f'{label} vs Predictions', fontsize=14)
        ax.grid(True)

    fig.tight_layout()
    plt.show()



# Part 9: Main Execution

In [None]:
#------------------------------------------------------------------------------
# Part 9: Main Execution
#------------------------------------------------------------------------------

if __name__ == "__main__":
    # Start timing (`time` is imported in the setup cell)
    start_time = time.time()
    
    # Load preprocessed data
    data_file = "pre_processed_data.npy" # Change to your file path
    labels_file = "pre_processed_labels.npy" # Change to your file path
    
    # Verify files exist
    assert os.path.isfile(data_file), f"Data file not found at {data_file}"
    assert os.path.isfile(labels_file), f"Labels file not found at {labels_file}"
    
    # Load data
    data = torch.tensor(np.load(data_file), dtype=torch.float32)
    labels = torch.tensor(np.load(labels_file), dtype=torch.float32)
    print(f"Data shape: {data.shape}, Labels shape: {labels.shape}")
    
    # Model seeds for multiple runs. Each seed controls weight initialisation,
    # dropout masks and training-batch order for one model on a given test set.
    model_seeds = [42, 123, 256, 789, 1024]
    results = []
    
    # Run experiments
    for run_id in tqdm(range(1, 11)):  # 10 different test sets
        # Split ONCE per test set, seeded by run_id (reported as test_seed), so
        # every model seed below is trained, validated and tested on the same
        # partition. Re-splitting per model seed without a seed would give each
        # model a different test set and make the per-run median meaningless.
        train_val_data, test_data, train_val_labels, test_labels = train_test_split(
            data, labels, test_size=0.2, shuffle=True, random_state=run_id)
        train_data, val_data, train_labels, val_labels = train_test_split(
            train_val_data, train_val_labels, 
            test_size=0.125, shuffle=True, random_state=run_id)
        
        # Create datasets
        train_dataset = EarthquakeDataset(train_data, train_labels)
        val_dataset = EarthquakeDataset(val_data, val_labels)
        test_dataset = EarthquakeDataset(test_data, test_labels)
        
        # Create dataloaders. Val/test must not shuffle: evaluate_model aligns
        # MC-dropout predictions with labels by iteration order.
        train_loader = DataLoader(train_dataset, batch_size=256, 
                                  shuffle=True, num_workers=2)
        val_loader = DataLoader(val_dataset, batch_size=256, 
                                shuffle=False, num_workers=2)
        test_loader = DataLoader(test_dataset, batch_size=256, 
                                 shuffle=False, num_workers=2)
        
        test_results = []
        for model_seed in model_seeds:
            # Seed torch's global RNG (all devices) so weight init, dropout
            # masks and the train loader's shuffle order depend on model_seed.
            # GPU kernels (e.g. cuDNN LSTM) may still be nondeterministic.
            torch.manual_seed(model_seed)
            
            # Initialize model
            model = EarthquakeModel().to(device)
            print(summary(model, input_size=(256, *data.shape[1:])))
            
            # Train model
            train_losses, val_losses = train_model(
                model, train_loader, val_loader,
                run_id=run_id, test_seed=run_id, model_seed=model_seed
            )
            
            # Evaluate model (loads the best checkpoint into `model`)
            mae, mean_pred, true_values, epistemic_unc, aleatoric_unc, combined_unc = \
                evaluate_model(model, test_loader, run_id, run_id, model_seed)
            
            result = {
                "test_seed": run_id,
                "model_seed": model_seed,
                "mae": float(mae),
                "mean_aleatoric_uncertainty": float(np.mean(aleatoric_unc)),
                "mean_epistemic_uncertainty": float(np.mean(epistemic_unc)),
                "mean_combined_uncertainty": float(np.mean(combined_unc))
            }
            
            test_results.append(result)
            
            # Plot results for this run
            plot_training_curves(train_losses, val_losses)
            plot_prediction_scatter(true_values, mean_pred, mae)
            plot_uncertainty_analysis(mean_pred, aleatoric_unc, 
                                      epistemic_unc, combined_unc)
            
            print(f"\nResults for Run {run_id}, Model Seed {model_seed}:")
            print(f"MAE: {mae:.4f}")
            print(f"Mean Aleatoric Uncertainty: {np.mean(aleatoric_unc):.4f}")
            print(f"Mean Epistemic Uncertainty: {np.mean(epistemic_unc):.4f}")
            print(f"Mean Combined Uncertainty: {np.mean(combined_unc):.4f}")
        
        # Median performance for this test set: the middle entry after sorting
        # by MAE (the exact median when the number of seeds is odd).
        sorted_results = sorted(test_results, key=lambda x: x['mae'])
        median_result = sorted_results[len(sorted_results) // 2]
        
        results.append({
            "run_id": run_id,
            "median_mae": median_result['mae'],
            "median_aleatoric_uncertainty": median_result['mean_aleatoric_uncertainty'],
            "median_epistemic_uncertainty": median_result['mean_epistemic_uncertainty'],
            "median_combined_uncertainty": median_result['mean_combined_uncertainty'],
            "all_results": test_results
        })
        
        # Save results after each test seed
        with open('baseline_magnet_results.json', 'w') as f:
            json.dump(results, f, indent=4)
    
    # Calculate overall statistics
    maes = [result["median_mae"] for result in results]
    mean_mae = np.mean(maes)
    std_mae = np.std(maes)
    
    print("\nOverall Results:")
    print(f"Mean MAE: {mean_mae:.4f}")
    print(f"Standard Deviation of MAE: {std_mae:.4f}")
    
    # Plot MAE distribution
    plt.figure(figsize=(10, 6))
    plt.hist(maes, bins=10, alpha=0.7, color='blue', edgecolor='black')
    plt.axvline(mean_mae, color='red', linestyle='dashed', linewidth=2,
                label=f'Mean MAE = {mean_mae:.4f}')
    plt.axvline(mean_mae + std_mae, color='green', linestyle='dotted', linewidth=2,
                label=f'Mean ± Std = {mean_mae + std_mae:.4f}')
    plt.axvline(mean_mae - std_mae, color='green', linestyle='dotted', linewidth=2)
    
    plt.xlabel('MAE', fontsize=14, fontweight='bold')
    plt.ylabel('Frequency', fontsize=14, fontweight='bold')
    plt.title('Distribution of Median MAEs over Different Test Sets',
              fontsize=16, fontweight='bold')
    plt.legend(fontsize=12)
    plt.grid(True)
    plt.tight_layout()
    plt.show()
    
    # Correlations between uncertainties and errors. NOTE: mean_pred,
    # true_values and the *_unc arrays belong to the LAST (run_id, model_seed)
    # pair evaluated above; they are not aggregated over runs.
    errors = np.abs(np.ravel(mean_pred) - np.ravel(true_values))
    aleatoric_corr = np.corrcoef(errors, aleatoric_unc)[0,1]
    epistemic_corr = np.corrcoef(errors, epistemic_unc)[0,1]
    combined_corr = np.corrcoef(errors, combined_unc)[0,1]
    
    print(f"\nCorrelations between Uncertainties and Absolute Errors (Run {run_id}, Model Seed {model_seed}):")
    print(f"Aleatoric Uncertainty: {aleatoric_corr:.4f}")
    print(f"Epistemic Uncertainty: {epistemic_corr:.4f}")
    print(f"Combined Uncertainty: {combined_corr:.4f}")
    
    # Save final model: `model` holds the best checkpoint of the last
    # (run_id, model_seed) pair (evaluate_model loaded it), not a selection
    # across runs.
    final_model_path = 'final_baseline_model.pth'
    torch.save(model.state_dict(), final_model_path)
    print(f"\nFinal model saved to {final_model_path}")
    
    # End timing
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"\nTotal execution time: {elapsed_time/60:.2f} minutes")