In [None]:
# -*- coding: utf-8 -*-
"""
Title: Enhanced Deeper Architecture for Earthquake Magnitude Estimation

This notebook implements an enhanced deeper architecture building upon the original 
MagNet model. Key improvements include:
1. Additional convolutional layers with residual connections 
2. Dual LSTM layers
3. Integration of log-transformed stream max amplitude information
4. Enhanced uncertainty quantification

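Dependencies:
- torch, torchinfo
- numpy, matplotlib
- tqdm (imported for progress tracking, though no progress bar is currently created)
- seaborn (intended for visualization; not currently imported, plots use matplotlib)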
"""

# Part 1: Setup and Imports 

In [None]:
#------------------------------------------------------------------------------
# Part 1: Setup and Imports 
#------------------------------------------------------------------------------

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
import json
import os
from torchinfo import summary
import time
from tqdm import tqdm
import random

# Configure environment: run on a CUDA GPU when one is visible, else on the CPU.
_device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device_name)
print(f'Using device: {device}')


# Part 2: Dataset Implementation

In [None]:
#------------------------------------------------------------------------------
# Part 2: Dataset Implementation
#------------------------------------------------------------------------------

class EarthquakeDataset(Dataset):
    """Indexable collection pairing each waveform with its log stream max and label.

    The three containers are stored exactly as given (no copy, no conversion)
    and are expected to be index-aligned. The dataset length is taken from
    ``labels``.
    """

    def __init__(self, data, labels, log_stream_max):
        self.data, self.labels, self.log_stream_max = data, labels, log_stream_max

    def __len__(self):
        # The label container defines how many samples exist.
        return len(self.labels)

    def __getitem__(self, idx):
        # Returned order is (waveform, log stream max, label); the training and
        # evaluation loops in this file unpack batches in that order.
        waveform = self.data[idx]
        amplitude = self.log_stream_max[idx]
        label = self.labels[idx]
        return waveform, amplitude, label


# Part 3: Model Architecture  

In [None]:
#------------------------------------------------------------------------------
# Part 3: Model Architecture  
#------------------------------------------------------------------------------

class ResidualBlock(nn.Module):
    """Two-convolution residual unit for 1-D feature maps.

    Main path: Conv1d(k=3, pad=1) -> BatchNorm -> ReLU -> Conv1d(k=3, pad=1)
    -> BatchNorm. The result is added to a shortcut copy of the input and the
    sum is passed through a final ReLU. The sequence length is preserved; only
    the channel count may change.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

        # A 1x1 convolution projects the shortcut only when the channel count
        # changes; otherwise the input is added back untouched.
        if in_channels != out_channels:
            self.downsample = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        else:
            self.downsample = None

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = self.bn1(main)
        main = self.relu(main)
        main = self.conv2(main)
        main = self.bn2(main)

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(main + shortcut)

class EarthquakeModel(nn.Module):
    """CNN + residual + stacked bidirectional LSTM regressor with two outputs.

    Five feature-extraction stages (two plain convolutions followed by three
    residual blocks) are each followed by dropout and max pooling. The pooled
    feature sequence is fed through two bidirectional LSTMs, the last time
    step is concatenated with the 3-value log stream max vector, and a linear
    layer maps the 403 features to 2 outputs per sample. Elsewhere in this
    file column 0 is used as the prediction and column 1 as the ``s`` term of
    ``custom_loss``.
    """

    def __init__(self):
        super().__init__()

        # Feature extractors: channels go 3 -> 64 -> 128 -> 128 -> 64 -> 32.
        self.conv1 = nn.Conv1d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(64, 128, kernel_size=3, padding=1)
        self.res3 = ResidualBlock(128, 128)
        self.res4 = ResidualBlock(128, 64)
        self.res5 = ResidualBlock(64, 32)

        # Shared pooling / regularisation applied after every stage.
        # Kernel 2 (stride defaults to the kernel size) roughly halves the length.
        self.maxpool = nn.MaxPool1d(2, padding=1)
        self.dropout = nn.Dropout(0.2)

        # Recurrent stack: 32 -> 2*100 -> 2*200 features per time step.
        self.lstm1 = nn.LSTM(32, 100, batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(200, 200, batch_first=True, bidirectional=True)

        # Regression head: 400 LSTM features + 3 log stream max values -> 2 outputs.
        self.fc = nn.Linear(400 + 3, 2)

    def forward(self, x, log_stream_max):
        """Run the network.

        Args:
            x: Waveform batch laid out as (batch, time, 3); it is transposed to
                channels-first before the convolutions.
            log_stream_max: Tensor of shape (batch, 3) appended to the LSTM
                features.

        Returns:
            Tensor of shape (batch, 2).
        """
        # Conv1d expects (batch, channels, time).
        x = x.transpose(1, 2)

        # Every stage is: feature extractor -> dropout -> max pool.
        for stage in (self.conv1, self.conv2, self.res3, self.res4, self.res5):
            x = self.maxpool(self.dropout(stage(x)))

        # Back to (batch, time, features) for the batch_first LSTMs.
        x = x.transpose(1, 2)
        x, _ = self.lstm1(x)
        x, _ = self.lstm2(x)

        # Keep only the final time step of the bidirectional output.
        # NOTE(review): for the backward direction this is its state after
        # reading a single time step — confirm that this is intended.
        last_step = x[:, -1, :]

        fused = torch.cat((last_step, log_stream_max), dim=1)
        return self.fc(fused)


# Part 4: Training Components

In [None]:
#------------------------------------------------------------------------------
# Part 4: Training Components
#------------------------------------------------------------------------------

class EarlyStopping:
    """Track validation loss, checkpoint the best model and signal when to stop.

    Call the instance once per epoch with the validation loss and the model.
    Whenever the loss improves, the model's ``state_dict`` is written to
    ``path`` and the patience counter is reset. After ``patience`` consecutive
    calls without improvement, ``early_stop`` is set to True.

    A NaN validation loss is always counted as "no improvement", so a diverged
    epoch can never overwrite the best checkpoint.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        verbose: If True, print a message each time a checkpoint is saved.
        delta: Amount added to the best score in the comparison; a new score
            must reach ``best_score + delta`` to count as an improvement.
        path: File the best ``state_dict`` is saved to.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='best_model.pth'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        """Record one epoch's validation loss and update the stopping state."""
        # Scores are negated losses so that "higher is better".
        score = -val_loss

        # NaN is the only value that is not equal to itself. Every ordering
        # comparison against NaN is False, so without this guard a NaN loss
        # would fall through to the "improved" branch and replace the best
        # checkpoint.
        if val_loss != val_loss:
            improved = False
        elif self.best_score is None:
            improved = True
        else:
            improved = score >= self.best_score + self.delta

        if improved:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and remember the loss."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f})')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

def custom_loss(y_pred, y_true):
    """Loss that weights the squared error by a predicted log-variance.

    Per sample the loss is ``0.5 * exp(-s) * (y_true - y_hat)**2 + 0.5 * s``,
    where ``y_hat`` is column 0 and ``s`` is column 1 of ``y_pred``. The mean
    over the batch is returned.

    Args:
        y_pred: Tensor of shape (batch, 2): column 0 is the prediction and
            column 1 is ``s``.
        y_true: Target tensor with one value per sample. Any shape holding
            exactly ``batch`` elements is accepted, e.g. (batch,) or (batch, 1).

    Returns:
        Scalar tensor with the mean loss.

    Raises:
        RuntimeError: If ``y_true`` does not contain exactly one value per row
            of ``y_pred``.
    """
    y_hat = y_pred[:, 0]
    s = y_pred[:, 1]

    # Align the target with y_hat. Without this, a (batch, 1) target would
    # broadcast against the (batch,) prediction into a (batch, batch) matrix
    # and silently produce a wrong loss.
    y_true = y_true.reshape(y_hat.shape)

    return torch.mean(0.5 * torch.exp(-s) * (y_true - y_hat) ** 2 + 0.5 * s)


# Part 5: Training Function

In [None]:
#------------------------------------------------------------------------------
# Part 5: Training Function
#------------------------------------------------------------------------------

def train_model(model, train_loader, val_loader, num_epochs=300, patience=50, model_path='best_model.pth'):
    """Train ``model`` with Adam, plateau LR decay and early stopping.

    Each epoch runs one pass over ``train_loader`` (with gradient-norm clipping
    at 1.0) and one evaluation pass over ``val_loader``. The mean validation
    loss drives both the learning-rate scheduler and the early-stopping logic.
    The best weights are written to ``model_path`` by ``EarlyStopping``; the
    model object itself is left with the weights of the last epoch run.

    Args:
        model: Module called as ``model(data, log_stream_max)``.
        train_loader: Iterable of ``(data, log_stream_max, target)`` batches.
        val_loader: Iterable of ``(data, log_stream_max, target)`` batches.
        num_epochs: Maximum number of epochs.
        patience: Early-stopping patience, in epochs.
        model_path: Where the best ``state_dict`` is saved.

    Returns:
        Tuple ``(train_losses, val_losses)`` holding the mean per-batch loss of
        every completed epoch.

    Raises:
        ValueError: If either loader yields no batches.
    """
    # Fail early with a clear message instead of a ZeroDivisionError after a
    # full epoch of work.
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=0.0005, weight_decay=1e-5)
    # Halve the learning rate after 15 epochs without validation improvement.
    # The scheduler's `verbose` flag is deprecated in recent PyTorch releases,
    # so it is not passed; the current learning rate is logged every epoch below.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15, min_lr=0.5e-6
    )

    early_stopping = EarlyStopping(patience=patience, verbose=True, path=model_path)
    criterion = custom_loss

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        for data, log_stream_max, target in train_loader:
            data = data.to(device)
            log_stream_max = log_stream_max.to(device)
            target = target.to(device)

            optimizer.zero_grad()
            outputs = model(data, log_stream_max)
            loss = criterion(outputs, target)
            loss.backward()

            # Clip the global gradient norm before the optimizer step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item()

        # ---- validation pass (no gradients, dropout/batch-norm in eval mode) ----
        val_loss = 0.0
        model.eval()
        with torch.no_grad():
            for data, log_stream_max, target in val_loader:
                data = data.to(device)
                log_stream_max = log_stream_max.to(device)
                target = target.to(device)
                outputs = model(data, log_stream_max)
                loss = criterion(outputs, target)
                val_loss += loss.item()

        # Convert the summed batch losses into per-batch means.
        val_loss /= len(val_loader)
        running_loss /= len(train_loader)

        scheduler.step(val_loss)
        early_stopping(val_loss, model)

        print(f'Epoch {epoch+1}:')
        print(f'Training Loss: {running_loss:.4f}')
        print(f'Validation Loss: {val_loss:.4f}')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')

        train_losses.append(running_loss)
        val_losses.append(val_loss)

        if early_stopping.early_stop:
            print(f'Early stopping triggered at epoch {epoch+1}')
            break

    return train_losses, val_losses


# Part 6: Uncertainty Estimation

In [None]:
#------------------------------------------------------------------------------
# Part 6: Uncertainty Estimation
#------------------------------------------------------------------------------

def estimate_uncertainty(model, data_loader, num_samples=50):
    """Estimate predictive uncertainty with Monte Carlo dropout.

    The model is put in eval mode with only its ``nn.Dropout`` modules switched
    back to train mode, then ``data_loader`` is traversed ``num_samples`` times
    to collect stochastic forward passes. The train/eval flag of every module
    is restored before returning, so the call leaves the model's mode as it
    found it.

    ``data_loader`` must yield samples in the same order on every pass (i.e. it
    must not shuffle); otherwise the per-sample statistics mix different samples.

    Args:
        model: Module called as ``model(data, log_stream_max)`` returning a
            (batch, 2) tensor: column 0 is the prediction, column 1 the
            log-variance ``s`` used by ``custom_loss``.
        data_loader: Iterable of ``(data, log_stream_max, target)`` batches.
        num_samples: Number of stochastic passes over the loader.

    Returns:
        Tuple of 1-D numpy arrays with one entry per sample:
        ``(mean_prediction, epistemic_uncertainty, aleatoric_uncertainty,
        combined_uncertainty)``. The epistemic term is the standard deviation
        of the sampled predictions; the aleatoric term is the mean predicted
        variance; the combined term is a variance (prediction variance plus
        aleatoric variance).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Remember each module's mode so that it can be restored afterwards.
    saved_modes = [(module, module.training) for module in model.modules()]

    model.eval()
    # Re-enable dropout only; batch norm keeps using its running statistics.
    for module in model.modules():
        if isinstance(module, nn.Dropout):
            module.train()

    predictions = []
    log_variances = []

    try:
        with torch.no_grad():
            for _ in range(num_samples):
                batch_predictions = []
                batch_log_variances = []
                for data, log_stream_max, _ in data_loader:
                    data = data.to(device)
                    log_stream_max = log_stream_max.to(device)
                    output = model(data, log_stream_max)
                    batch_predictions.append(output[:, 0].cpu().numpy())
                    batch_log_variances.append(output[:, 1].cpu().numpy())
                predictions.append(np.concatenate(batch_predictions))
                log_variances.append(np.concatenate(batch_log_variances))
    finally:
        # Do not leave dropout active (or the model forced into eval mode)
        # for the caller.
        for module, was_training in saved_modes:
            module.training = was_training

    # Both arrays have shape (num_samples, n_items).
    predictions = np.array(predictions)
    log_variances = np.array(log_variances)

    mean_prediction = np.mean(predictions, axis=0)
    yhat_squared_mean = np.mean(np.square(predictions), axis=0)

    # custom_loss weights the squared error by exp(-s), i.e. s is a *natural*
    # log-variance, so the variance is exp(s) (not 10**s).
    sigma_squared = np.exp(log_variances)
    aleatoric_uncertainty = np.mean(sigma_squared, axis=0)

    epistemic_uncertainty = np.std(predictions, axis=0)
    # E[y_hat^2] - E[y_hat]^2 is the variance of the sampled predictions.
    combined_uncertainty = yhat_squared_mean - np.square(mean_prediction) + aleatoric_uncertainty

    return mean_prediction, epistemic_uncertainty, aleatoric_uncertainty, combined_uncertainty

# Part 7: Evaluation Function

In [None]:
#------------------------------------------------------------------------------
# Part 7: Evaluation Function
#------------------------------------------------------------------------------

def evaluate_model(model, test_loader, model_path='best_model.pth'):
    """Load the saved weights and score the model on ``test_loader``.

    The ``state_dict`` stored at ``model_path`` is loaded into ``model``,
    Monte Carlo dropout predictions are obtained from ``estimate_uncertainty``
    and the mean absolute error against the loader's targets is computed.

    ``test_loader`` is iterated again here to collect the targets, so it must
    yield samples in a fixed order (no shuffling) for predictions and targets
    to line up.

    Args:
        model: Model whose architecture matches the saved ``state_dict``.
        test_loader: Iterable of ``(data, log_stream_max, target)`` batches
            whose targets are CPU tensors.
        model_path: Path of the checkpoint to load.

    Returns:
        Tuple ``(mae, mean_pred, true_values, epistemic_unc, aleatoric_unc,
        combined_unc)``; all but ``mae`` are 1-D numpy arrays with one entry
        per sample.
    """
    # Map the checkpoint to CPU so that weights saved on a GPU can be loaded on
    # a CPU-only machine; load_state_dict copies them into the model's own
    # parameters, whatever device those live on.
    state_dict = torch.load(model_path, map_location='cpu')
    model.load_state_dict(state_dict)
    mean_pred, epistemic_unc, aleatoric_unc, combined_unc = estimate_uncertainty(model, test_loader)

    true_values = np.concatenate([target.numpy() for _, _, target in test_loader])
    # Align with the 1-D predictions: (N, 1) targets would otherwise broadcast
    # against (N,) predictions into an (N, N) matrix and corrupt the MAE.
    true_values = true_values.reshape(mean_pred.shape)

    mae = np.mean(np.abs(mean_pred - true_values))

    return mae, mean_pred, true_values, epistemic_unc, aleatoric_unc, combined_unc


# Part 8: Visualization Functions

In [None]:
#------------------------------------------------------------------------------
# Part 8: Visualization Functions
#------------------------------------------------------------------------------

def plot_uncertainties(true_values, predictions, epistemic_unc, aleatoric_unc, combined_unc, seed):
    """Save four diagnostic figures as PNG files in the working directory.

    Three figures (aleatoric, epistemic, combined) plot the uncertainty against
    the predicted magnitude, with horizontal error bars of the same size as the
    uncertainty, and overlay the true magnitudes as red dots. Samples whose
    uncertainty is at or above the 99.995th percentile are left out. The fourth
    figure is a predicted-vs-true scatter with a dashed identity line.

    Args:
        true_values: 1-D numpy array of reference magnitudes.
        predictions: 1-D numpy array of predicted magnitudes.
        epistemic_unc: 1-D numpy array of epistemic uncertainties.
        aleatoric_unc: 1-D numpy array of aleatoric uncertainties.
        combined_unc: 1-D numpy array of combined uncertainties.
        seed: Value embedded in every output file name.
    """
    label_style = dict(fontsize=14, fontweight='bold')

    # (y-axis label, file-name stem, values) for each uncertainty figure.
    panels = [
        ('Aleatoric Uncertainty', 'aleatoric_uncertainty', aleatoric_unc),
        ('Epistemic Uncertainty', 'epistemic_uncertainty', epistemic_unc),
        ('Combined Uncertainty', 'combined_uncertainty', combined_unc),
    ]

    # All outlier cut-offs are computed before any figure is drawn.
    cutoffs = [np.percentile(values, 99.995) for _, _, values in panels]

    for (axis_label, file_stem, values), cutoff in zip(panels, cutoffs):
        plt.figure(figsize=(10, 8))
        keep = values < cutoff
        shown = values[keep]
        plt.errorbar(predictions[keep], shown,
                     xerr=shown, fmt='o', alpha=0.4,
                     ecolor='g', capthick=2)
        plt.plot(true_values[keep], shown, 'ro', alpha=0.4)
        plt.xlabel('Predicted Magnitude', **label_style)
        plt.ylabel(axis_label, **label_style)
        plt.ylim(0, cutoff)
        plt.tick_params(axis='both', which='major', labelsize=14)
        plt.tight_layout()
        plt.savefig(f'{file_stem}_seed_{seed}.png')
        plt.close()

    # Predicted vs True, with the y = x reference line across the true range.
    lo, hi = true_values.min(), true_values.max()
    plt.figure(figsize=(10, 8))
    plt.scatter(true_values, predictions, alpha=0.4, facecolors='none', edgecolors='r')
    plt.plot([lo, hi], [lo, hi], 'k--', alpha=0.4, lw=2)
    plt.xlabel('True Magnitude', **label_style)
    plt.ylabel('Predicted Magnitude', **label_style)
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.tight_layout()
    plt.savefig(f'predicted_vs_true_seed_{seed}.png')
    plt.close()


# Part 9: Main Execution

In [None]:
#------------------------------------------------------------------------------
# Part 9: Main Execution
#------------------------------------------------------------------------------

if __name__ == "__main__":
    # End-to-end run: seed everything, load the preprocessed arrays, split them
    # 80/10/10 in file order, train, evaluate with MC dropout, then write the
    # figures and a JSON summary to the working directory.

    # Start timing
    start_time = time.time()

    # Set random seed for reproducibility
    seed = 1024
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    # Load preprocessed data
    output_data_file = "pre_processed_data.npy"
    output_labels_file = "pre_processed_labels.npy"
    log_stream_max_file = "log_max_values.npy"

    # Verify file existence. Explicit exceptions are used instead of `assert`
    # so the check still runs when Python is started with -O.
    required_files = (
        ("Data", output_data_file),
        ("Labels", output_labels_file),
        ("Log stream max", log_stream_max_file),
    )
    for description, file_path in required_files:
        if not os.path.isfile(file_path):
            raise FileNotFoundError(f"{description} file not found at {file_path}")

    # Load data
    print("Loading data...")
    data = torch.tensor(np.load(output_data_file), dtype=torch.float32)
    labels = torch.tensor(np.load(output_labels_file), dtype=torch.float32)
    log_stream_max = torch.tensor(np.load(log_stream_max_file), dtype=torch.float32)

    print(f"Data shapes:")
    print(f"Waveforms: {data.shape}")
    print(f"Labels: {labels.shape}")
    print(f"Log stream max: {log_stream_max.shape}")

    # The three arrays are sliced with the same indices below, so they must
    # describe the same number of samples.
    if not (len(data) == len(labels) == len(log_stream_max)):
        raise ValueError(
            "Sample count mismatch: "
            f"{len(data)} waveforms, {len(labels)} labels, "
            f"{len(log_stream_max)} log stream max rows"
        )

    # Split data into train/validation/test (80/10/10).
    # The split is positional: samples are taken in file order, not shuffled.
    total_samples = len(data)
    train_size = int(0.8 * total_samples)
    val_size = int(0.1 * total_samples)
    test_size = total_samples - train_size - val_size

    print(f"\nData splitting:")
    print(f"Training samples: {train_size}")
    print(f"Validation samples: {val_size}")
    print(f"Testing samples: {test_size}")

    val_end = train_size + val_size

    train_data = data[:train_size]
    val_data = data[train_size:val_end]
    test_data = data[val_end:]

    train_labels = labels[:train_size]
    val_labels = labels[train_size:val_end]
    test_labels = labels[val_end:]

    train_log_stream_max = log_stream_max[:train_size]
    val_log_stream_max = log_stream_max[train_size:val_end]
    test_log_stream_max = log_stream_max[val_end:]

    # Create datasets
    train_dataset = EarthquakeDataset(train_data, train_labels, train_log_stream_max)
    val_dataset = EarthquakeDataset(val_data, val_labels, val_log_stream_max)
    test_dataset = EarthquakeDataset(test_data, test_labels, test_log_stream_max)

    # Create dataloaders
    batch_size = 256
    # Training batches are reshuffled every epoch so the optimizer does not see
    # the same batch composition each time; the seed set above keeps the run
    # repeatable. Validation and test loaders must keep a fixed order because
    # evaluation iterates the test loader several times and aligns the results
    # by position.
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

    # Initialize model
    model = EarthquakeModel()
    print("\nModel architecture:")
    # Derive the summary's input shapes from the loaded arrays so that the
    # report reflects the data actually used.
    summary_input_sizes = [
        (batch_size, *data.shape[1:]),
        (batch_size, *log_stream_max.shape[1:]),
    ]
    print(summary(model, input_size=summary_input_sizes))

    # Train model
    print("\nStarting model training...")
    model_path = f'best_model_seed_{seed}.pth'
    train_losses, val_losses = train_model(model, train_loader, val_loader,
                                         num_epochs=300, patience=50,
                                         model_path=model_path)

    # Evaluate model (reloads the best checkpoint written during training)
    print("\nEvaluating model...")
    mae, mean_pred, true_values, epistemic_unc, aleatoric_unc, combined_unc = \
        evaluate_model(model, test_loader, model_path=model_path)

    # Print results
    print(f'\nFinal Results:')
    print(f'Mean Absolute Error (MAE): {mae:.4f}')
    print(f'Mean Aleatoric Uncertainty: {np.mean(aleatoric_unc):.4f}')
    print(f'Mean Epistemic Uncertainty: {np.mean(epistemic_unc):.4f}')
    print(f'Mean Combined Uncertainty: {np.mean(combined_unc):.4f}')

    # Plot results
    print("\nGenerating plots...")
    plot_uncertainties(true_values, mean_pred, epistemic_unc, aleatoric_unc, combined_unc, seed)

    # Plot loss curves
    plt.figure(figsize=(10, 5))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.xlabel('Epochs', fontsize=14, fontweight='bold')
    plt.ylabel('Loss', fontsize=14, fontweight='bold')
    plt.legend(fontsize=14)
    plt.grid(True)
    plt.tick_params(axis='both', which='major', labelsize=14)
    plt.tight_layout()
    plt.savefig(f'loss_curves_seed_{seed}.png')
    plt.close()

    # Save results to JSON (numpy scalars are converted to plain floats first)
    results = {
        "seed": seed,
        "mae": float(mae),
        "mean_aleatoric_uncertainty": float(np.mean(aleatoric_unc)),
        "mean_epistemic_uncertainty": float(np.mean(epistemic_unc)),
        "mean_combined_uncertainty": float(np.mean(combined_unc))
    }

    with open(f'model_results_seed_{seed}.json', 'w') as f:
        json.dump(results, f, indent=4)

    # End timing
    end_time = time.time()
    print(f"\nTotal execution time: {(end_time - start_time)/60:.2f} minutes")