# Seismic Interpretation Analysis v2

This notebook implements a modular approach to seismic interpretation using various deep learning models. It supports an ablation study to compare different model architectures and combinations.

## Overview

The workflow consists of the following steps:
1. Load preprocessed seismic data
2. Configure models and training parameters
3. Run ablation study with different model combinations
4. Analyze and visualize results

This modular approach allows for easy experimentation with different model architectures and hyperparameters.

## 1. Import Libraries

In [1]:
import os

# Directory the notebook switches into before anything else runs.
# Later cells import local modules and use relative data paths, so the
# working directory matters. Set the ABLATION_PROJECT_DIR environment
# variable to run on another machine; the default keeps the original location.
PROJECT_DIR = os.environ.get('ABLATION_PROJECT_DIR') or '/home/fuller_m/ablationstudy'

print(os.getcwd())
os.chdir(PROJECT_DIR)  # raises FileNotFoundError if the directory does not exist
print(os.getcwd())

/home/fuller_m
/home/fuller_m/ablationstudy


In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
import pandas as pd
from tqdm.auto import tqdm
import json
import importlib

# Import custom modules
import config
import model_utils
from utils import extract_traces_from_patches, extract_random_trace_from_patch

# Seed NumPy and PyTorch with one shared value so reruns are repeatable
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
cuda_available = torch.cuda.is_available()
if cuda_available:
    torch.cuda.manual_seed_all(SEED)

# Pick the compute device: GPU when CUDA is usable, CPU otherwise
device = torch.device('cuda') if cuda_available else torch.device('cpu')
print(f"Using device: {device}")

Using device: cpu


## 2. Load Preprocessed Data

Instead of preprocessing the data in this notebook, we load preprocessed data that was created using the `preprocess_data.py` script.

In [3]:
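# One-off preprocessing script, kept for reference only. It is wrapped in a string literal so it does not run; remove the surrounding quotes to regenerate the preprocessed .npz file.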
'''import os
import numpy as np
import warnings

# Suppress specific NumPy warnings if needed
warnings.filterwarnings('ignore', category=RuntimeWarning, message='.*invalid value.*')
warnings.filterwarnings('ignore', category=RuntimeWarning, message='.*divide by zero.*')

from preprocess_data import load_and_preprocess_data

# Set up parameters
base_data_dir = "F3_Demo_2020"
segy_filename = "Rawdata/Seismic_data.sgy"
horizon_subdir = "Rawdata/Surface_data"
horizon_filenames = [
    "F3-Horizon-FS4.xyt.bz2",
    "F3-Horizon-MFS4.xyt",
    "F3-Horizon-FS6.xyt",
    "F3-Horizon-FS7.xyt",
    "F3-Horizon-FS8.xyt",
    "F3-Horizon-Shallow.xyt",
    "F3-Horizon-Top-Foresets.xyt"
]
patch_size = 32
stride = 16
max_patches = 50000  # Consider reducing this if memory issues occur

# Construct full paths
segy_path = os.path.join(base_data_dir, segy_filename)
horizon_dir = os.path.join(base_data_dir, horizon_subdir)
horizon_paths = [os.path.join(horizon_dir, hf) for hf in horizon_filenames]

try:
    # Call the function with the correct parameters
    # Wrap in try/except to catch and report errors
    X_patches, y_labels, num_classes = load_and_preprocess_data(
        segy_path, 
        horizon_paths, 
        patch_size, 
        stride, 
        max_patches
    )
    
    # Now you can use X_patches, y_labels, and num_classes directly in your notebook
    print(f"X_patches shape: {X_patches.shape}")
    print(f"y_labels shape: {y_labels.shape}")
    print(f"Number of classes: {num_classes}")
    
    # Save the results if needed
    output_dir = "preprocessed_data"
    output_filename = "preprocessed_seismic_data.npz"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, output_filename)
    np.savez_compressed(
        output_path,
        X_patches=X_patches,
        y_labels=y_labels,
        num_classes=num_classes
    )
    print("Preprocessing and saving completed successfully!")
    
except Exception as e:
    print(f"Error during preprocessing: {type(e).__name__}: {e}")
    
    # If it's a memory error, suggest reducing max_patches
    if isinstance(e, MemoryError):
        print("Memory error encountered. Try reducing max_patches or processing in smaller batches.")'''



In [4]:
def load_preprocessed_data(data_path, allow_dummy=True):
    """Load preprocessed seismic data from a .npz file.

    Args:
        data_path (str): Path to the preprocessed data file.
        allow_dummy (bool): When True (the default, matching the original
            behaviour), any loading error is printed and random dummy data
            is returned instead. Pass False to re-raise the error so that
            a study cannot silently run on random data.

    Returns:
        tuple: (X_patches, y_labels, num_classes, class_names).
            class_names is a list of str, or None when the file has no
            'class_names' entry.

    Raises:
        Exception: Whatever the load raised, but only when allow_dummy is False.
    """
    print(f"Loading preprocessed data from {data_path}...")

    try:
        # np.load on a .npz returns an NpzFile that holds the archive open;
        # the context manager closes it. Each array is read into memory on
        # item access, so the values stay valid after the file is closed.
        with np.load(data_path) as data:
            X_patches = data['X_patches']
            y_labels = data['y_labels']
            num_classes = int(data['num_classes'])

            # Load class names if available. They are normalised to a plain
            # list so both return paths of this function yield the same type.
            class_names = None
            if 'class_names' in data:
                class_names = [str(name) for name in np.atleast_1d(data['class_names']).tolist()]

        print(f"Loaded {X_patches.shape[0]} samples with {num_classes} classes.")
        print(f"X_patches shape: {X_patches.shape}")
        print(f"y_labels shape: {y_labels.shape}")

        return X_patches, y_labels, num_classes, class_names

    except Exception as e:
        print(f"Error loading preprocessed data: {e}")
        if not allow_dummy:
            raise

        # Create dummy data for demonstration if real data is not available
        print("Creating dummy data for demonstration...")
        num_samples = 100
        patch_depth = 32
        patch_height = 32
        patch_width = 32
        num_classes = 8

        X_patches = np.random.rand(num_samples, 1, patch_depth, patch_height, patch_width).astype(np.float32)
        y_labels = np.random.randint(0, num_classes, size=num_samples).astype(np.int64)
        class_names = [f"Class_{i}" for i in range(num_classes)]

        print(f"Created {num_samples} dummy samples with {num_classes} classes.")
        print(f"X_patches shape: {X_patches.shape}")
        print(f"y_labels shape: {y_labels.shape}")

        return X_patches, y_labels, num_classes, class_names

# Load preprocessed data
# The file path is read from config.DATA_CONFIG. The four returned values are
# kept as notebook-level variables because later cells reference them by name.
data_path = config.DATA_CONFIG['preprocessed_data_path']
X_patches, y_labels, num_classes, class_names = load_preprocessed_data(data_path)

# Extract traces for sequence models
# NOTE(review): num_traces_per_patch=5 and seed=42 are hard-coded here rather
# than read from config — confirm this is intended.
X_traces = extract_traces_from_patches(X_patches, num_traces_per_patch=5, seed=42)
print(f"X_traces shape: {X_traces.shape}")

Loading preprocessed data from preprocessed_data/preprocessed_seismic_data.npz...
Loaded 50000 samples with 8 classes.
X_patches shape: (50000, 1, 32, 32, 32)
y_labels shape: (50000,)
X_traces shape: (50000, 32)


## 3. Create DataLoaders

Create DataLoaders for training, validation, and testing using the utility function.

In [5]:
# Create DataLoaders
# Patches, labels and traces are passed together; batch size, split fractions,
# seed and stratification flag are all read from config rather than set here.
train_loader, val_loader, test_loader = model_utils.create_dataloaders(
    X_patches, y_labels, X_traces,
    batch_size=config.TRAINING_PARAMS['batch_size'],
    train_split=config.DATA_CONFIG['train_split'],
    val_split=config.DATA_CONFIG['val_split'],
    test_split=config.DATA_CONFIG['test_split'],
    random_seed=config.DATA_CONFIG['random_seed'],
    stratify=config.DATA_CONFIG['stratify']
)

# Report the number of samples behind each loader as a quick check of the split
print(f"Train loader: {len(train_loader.dataset)} samples")
print(f"Validation loader: {len(val_loader.dataset)} samples")
print(f"Test loader: {len(test_loader.dataset)} samples")

Train loader: 35000 samples
Validation loader: 7500 samples
Test loader: 7500 samples


## 4. Model Selection and Configuration

The configuration module (`config.py`) contains definitions for all available models and their parameters. Here we can review and modify the configuration if needed.

In [6]:
# List every CNN registered in the configuration, with its description
print("Available CNN models:")
for cnn_name, cnn_entry in config.CNN_MODELS.items():
    print(f"  - {cnn_name}: {cnn_entry['description']}")

# Same listing for the sequence models
print("\nAvailable sequence models:")
for seq_name, seq_entry in config.SEQ_MODELS.items():
    print(f"  - {seq_name}: {seq_entry['description']}")

# The hybrid model is a single entry rather than a registry
hybrid_entry = config.HYBRID_MODEL
print("\nHybrid model:")
print(f"  - {hybrid_entry['class_name']}: {hybrid_entry['description']}")

# Show which models the ablation study will combine
ablation_cfg = config.ABLATION_CONFIG
print("\nAblation configuration:")
print(f"  - CNN models: {ablation_cfg['cnn_models']}")
print(f"  - Sequence models: {ablation_cfg['seq_models']}")
print(f"  - Run all combinations: {ablation_cfg['run_all_combinations']}")
# The explicit pair list is only printed when not every combination is run
if not ablation_cfg['run_all_combinations']:
    print(f"  - Specific combinations: {ablation_cfg['specific_combinations']}")

Available CNN models:
  - cnn3d: Original 3D CNN with two convolutional layers
  - resnet3d: 3D ResNet with residual connections for deeper feature extraction
  - attention_unet3d: 3D U-Net with attention gates for focused feature extraction
  - patchnet3d: Multi-scale 3D CNN with dilated convolutions for capturing features at different scales

Available sequence models:
  - bilstm: Bidirectional LSTM for processing seismic traces
  - lstm: Unidirectional LSTM for processing seismic traces
  - transformer: Transformer model with self-attention for processing seismic traces
  - wide_deep_transformer: Transformer with parallel wide (shallow) and deep paths for capturing different patterns
  - custom_transformer: Custom transformer implementation with flexible attention mechanisms

Hybrid model:
  - HybridModel: Hybrid model combining CNN and sequence models

Ablation configuration:
  - CNN models: ['cnn3d', 'resnet3d', 'attention_unet3d', 'patchnet3d']
  - Sequence models: ['bilstm', 'ls

## 5. Single Model Testing (Optional)

Before running the full ablation study, we can test individual models to ensure they work correctly.

In [7]:
# Test a single CNN model
def test_cnn_model(cnn_key):
    """Smoke-test one CNN from ``config.CNN_MODELS`` on a single training batch.

    Builds the model, runs one forward pass without gradients on the first
    batch of ``train_loader`` and prints the input and output shapes.

    Args:
        cnn_key (str): Key into ``config.CNN_MODELS``.

    Returns:
        The model object returned by ``model_utils.load_model_from_config``.
    """
    print(f"Testing CNN model: {cnn_key}")
    model = model_utils.load_model_from_config(config.CNN_MODELS[cnn_key], num_classes=num_classes)
    print(model)

    # A batch is either (patches, traces, labels) or (patches, labels);
    # only the patches are needed here.
    first_batch = next(iter(train_loader))
    if len(first_batch) != 3:
        patch_batch, _ = first_batch
    else:
        patch_batch, _, _ = first_batch

    # Shape check only, so no gradients are tracked
    with torch.no_grad():
        logits = model(patch_batch)

    print(f"Input shape: {patch_batch.shape}")
    print(f"Output shape: {logits.shape}")
    print(f"Expected output shape: [batch_size, {num_classes}]")

    return model

# Test a single sequence model
def test_seq_model(seq_key):
    """Smoke-test one sequence model from ``config.SEQ_MODELS`` on a single batch.

    Builds the model, runs one forward pass without gradients on the traces
    of the first ``train_loader`` batch and prints the input and output shapes.

    Args:
        seq_key (str): Key into ``config.SEQ_MODELS``.

    Returns:
        The model object returned by ``model_utils.load_model_from_config``.
    """
    print(f"Testing sequence model: {seq_key}")
    model = model_utils.load_model_from_config(config.SEQ_MODELS[seq_key], num_classes=num_classes)
    print(model)

    # A batch is either (patches, traces, labels) or (patches, labels)
    first_batch = next(iter(train_loader))
    if len(first_batch) != 3:
        # No traces in the batch: sample one trace from each patch instead
        patch_batch, _ = first_batch
        sampled_traces = [extract_random_trace_from_patch(patch.numpy()) for patch in patch_batch]
        trace_batch = torch.tensor(np.array(sampled_traces), dtype=torch.float32)
    else:
        _, trace_batch, _ = first_batch

    # Shape check only, so no gradients are tracked
    with torch.no_grad():
        logits = model(trace_batch)

    print(f"Input shape: {trace_batch.shape}")
    print(f"Output shape: {logits.shape}")
    print(f"Expected output shape: [batch_size, {num_classes}]")

    return model

# Test a hybrid model
def test_hybrid_model(cnn_key, seq_key):
    """Smoke-test a hybrid CNN + sequence model on a single training batch.

    Builds the hybrid model from the two configuration entries, runs one
    forward pass without gradients and prints the input and output shapes.
    The stand-alone CNN and sequence models that were previously built here
    were never used, so they are no longer constructed.

    Args:
        cnn_key (str): Key into ``config.CNN_MODELS``.
        seq_key (str): Key into ``config.SEQ_MODELS``.

    Returns:
        The model object returned by ``model_utils.load_hybrid_model``.
    """
    print(f"Testing hybrid model: {cnn_key} + {seq_key}")
    cnn_config = config.CNN_MODELS[cnn_key]
    seq_config = config.SEQ_MODELS[seq_key]

    # Load hybrid model
    hybrid_model = model_utils.load_hybrid_model(
        cnn_config, seq_config, config.HYBRID_MODEL, num_classes=num_classes
    )
    print(hybrid_model)

    # Test with a batch from the dataloader
    batch = next(iter(train_loader))
    if len(batch) == 3:  # X_patches, X_traces, y
        X_patches, X_traces, _ = batch
    else:  # X_patches, y
        X_patches, _ = batch
        # No traces in the batch: sample one trace from each patch instead
        X_traces = torch.tensor(
            np.array([extract_random_trace_from_patch(patch.numpy()) for patch in X_patches]),
            dtype=torch.float32,
        )

    # Shape check only, so no gradients are tracked
    with torch.no_grad():
        outputs = hybrid_model(X_patches, X_traces)

    print(f"CNN input shape: {X_patches.shape}")
    print(f"Sequence input shape: {X_traces.shape}")
    print(f"Output shape: {outputs.shape}")
    print(f"Expected output shape: [batch_size, {num_classes}]")

    return hybrid_model

# Uncomment to test individual models
# cnn_model = test_cnn_model('cnn3d')
# seq_model = test_seq_model('bilstm')
# hybrid_model = test_hybrid_model('cnn3d', 'bilstm')

## 6. Run Ablation Study

Run the ablation study with all specified model combinations.

In [None]:
# Run ablation study
# Long-running cell: the captured log shows up to 50 epochs per model combination.
# The whole config module is passed in, together with the data arrays loaded above.
results_df = model_utils.run_ablation_study(
    config, X_patches, y_labels, X_traces,
    device=device,
    class_names=class_names,
    results_dir=config.ABLATION_CONFIG['results_dir']
)

# Display results
# results_df is indexed with a list of column names, so it is expected to be a
# pandas DataFrame holding at least these metric columns.
print("\nAblation Study Results:")
display(results_df[['model_name', 'accuracy', 'precision', 'recall', 'f1', 'cohen_kappa', 'balanced_accuracy']])


=== Training cnn3d_bilstm ===

Initializing HybridModel...
  CNN model: SeismicCNN3D
  Sequence model: SeismicBiLSTM
  CNN feature size: 16384
  Sequence feature size: 128
  Combined feature size: 16512
  Fusion hidden size: 128
  Number of classes: 8


Epoch 1/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 1/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 1/50: Train Loss: 1.2669, Train Acc: 0.5369, Val Loss: 0.8280, Val Acc: 0.6744


Epoch 2/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 2/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 2/50: Train Loss: 0.9590, Train Acc: 0.6146, Val Loss: 0.7069, Val Acc: 0.7241


Epoch 3/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 3/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 3/50: Train Loss: 0.8814, Train Acc: 0.6408, Val Loss: 0.6774, Val Acc: 0.7251


Epoch 4/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 4/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 4/50: Train Loss: 0.8260, Train Acc: 0.6549, Val Loss: 0.7125, Val Acc: 0.6909


Epoch 5/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 5/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 5/50: Train Loss: 0.8016, Train Acc: 0.6623, Val Loss: 0.6319, Val Acc: 0.7321


Epoch 6/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 6/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 6/50: Train Loss: 0.7610, Train Acc: 0.6763, Val Loss: 0.5484, Val Acc: 0.7529


Epoch 7/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 7/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 7/50: Train Loss: 0.7523, Train Acc: 0.6801, Val Loss: 0.5841, Val Acc: 0.7339


Epoch 8/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 8/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 8/50: Train Loss: 0.7419, Train Acc: 0.6863, Val Loss: 0.5353, Val Acc: 0.7703


Epoch 9/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 9/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 9/50: Train Loss: 0.7142, Train Acc: 0.6987, Val Loss: 0.5542, Val Acc: 0.7613


Epoch 10/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 10/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 10/50: Train Loss: 0.6845, Train Acc: 0.7096, Val Loss: 0.4986, Val Acc: 0.7755


Epoch 11/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 11/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 11/50: Train Loss: 0.6680, Train Acc: 0.7164, Val Loss: 0.4743, Val Acc: 0.7871


Epoch 12/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 12/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 12/50: Train Loss: 0.6371, Train Acc: 0.7292, Val Loss: 0.4659, Val Acc: 0.7944


Epoch 13/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 13/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 13/50: Train Loss: 0.6275, Train Acc: 0.7312, Val Loss: 0.4670, Val Acc: 0.7937


Epoch 14/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 14/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 14/50: Train Loss: 0.5900, Train Acc: 0.7478, Val Loss: 0.4278, Val Acc: 0.8144


Epoch 15/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 15/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 15/50: Train Loss: 0.5709, Train Acc: 0.7570, Val Loss: 0.5054, Val Acc: 0.7907


Epoch 16/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 16/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 16/50: Train Loss: 0.5394, Train Acc: 0.7707, Val Loss: 0.4677, Val Acc: 0.8033


Epoch 17/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 17/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 17/50: Train Loss: 0.4957, Train Acc: 0.7911, Val Loss: 0.4234, Val Acc: 0.8253


Epoch 18/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 18/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 18/50: Train Loss: 0.4699, Train Acc: 0.8005, Val Loss: 0.6307, Val Acc: 0.7963


Epoch 19/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 19/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 19/50: Train Loss: 0.4425, Train Acc: 0.8141, Val Loss: 0.3866, Val Acc: 0.8349


Epoch 20/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 20/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 20/50: Train Loss: 0.4259, Train Acc: 0.8196, Val Loss: 0.3682, Val Acc: 0.8469


Epoch 21/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 21/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 21/50: Train Loss: 0.4048, Train Acc: 0.8296, Val Loss: 0.3892, Val Acc: 0.8328


Epoch 22/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 22/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 22/50: Train Loss: 0.3849, Train Acc: 0.8366, Val Loss: 0.4382, Val Acc: 0.8243


Epoch 23/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 23/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 23/50: Train Loss: 0.3640, Train Acc: 0.8461, Val Loss: 0.3860, Val Acc: 0.8428


Epoch 24/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 24/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 24/50: Train Loss: 0.3412, Train Acc: 0.8556, Val Loss: 0.3784, Val Acc: 0.8505


Epoch 25/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 25/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 25/50: Train Loss: 0.3145, Train Acc: 0.8671, Val Loss: 0.3858, Val Acc: 0.8496


Epoch 26/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

Epoch 26/50 [Val]:   0%|          | 0/235 [00:00<?, ?it/s]

Epoch 26/50: Train Loss: 0.2966, Train Acc: 0.8741, Val Loss: 0.4840, Val Acc: 0.8323


Epoch 27/50 [Train]:   0%|          | 0/1094 [00:00<?, ?it/s]

## 7. Analyze Results

Analyze the results of the ablation study.

In [None]:
def _plot_metric_ranking(metric, title, xlabel):
    """Draw a horizontal bar chart of all models ranked by ``metric``, best first."""
    plt.figure(figsize=config.VISUALIZATION_CONFIG['figsize'])
    sns.barplot(x=metric, y='model_name', data=results_df.sort_values(metric, ascending=False))
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel('Model')
    plt.tight_layout()
    plt.show()


# One chart per metric: accuracy first, then F1 score
for metric_column, chart_title, axis_label in (
    ('accuracy', 'Model Comparison - Accuracy', 'Accuracy'),
    ('f1', 'Model Comparison - F1 Score', 'F1 Score'),
):
    _plot_metric_ranking(metric_column, chart_title, axis_label)

## 8. Detailed Analysis of Best Model

Analyze the best performing model in more detail.

In [None]:
# Find the best model based on accuracy
# idxmax returns the index label of the highest accuracy; .loc fetches that row.
best_model_row = results_df.loc[results_df['accuracy'].idxmax()]
best_model_name = best_model_row['model_name']
best_cnn = best_model_row['cnn_model']
best_seq = best_model_row['seq_model']

print(f"Best model: {best_model_name}")
print(f"CNN model: {best_cnn} - {config.CNN_MODELS[best_cnn]['description']}")
print(f"Sequence model: {best_seq} - {config.SEQ_MODELS[best_seq]['description']}")
print(f"Accuracy: {best_model_row['accuracy']:.4f}")
print(f"F1 Score: {best_model_row['f1']:.4f}")

# Load confusion matrix for the best model
# results_dir is reused by later cells, so it stays a notebook-level variable.
results_dir = config.ABLATION_CONFIG['results_dir']
cm_path = os.path.join(results_dir, f"{best_model_name}_confusion_matrix.{config.VISUALIZATION_CONFIG['save_format']}")

if os.path.exists(cm_path):
    # A saved confusion-matrix image exists: show it as-is
    plt.figure(figsize=config.VISUALIZATION_CONFIG['figsize'])
    img = plt.imread(cm_path)
    plt.imshow(img)
    plt.axis('off')
    plt.title(f"Confusion Matrix - {best_model_name}")
    plt.show()
else:
    print(f"Confusion matrix image not found at {cm_path}")

    # Fall back to rebuilding the matrix from the saved checkpoint
    cnn_config = config.CNN_MODELS[best_cnn]
    seq_config = config.SEQ_MODELS[best_seq]

    # Load best model checkpoint if available
    best_model_path = os.path.join(results_dir, f"{best_model_name}_best.pth")
    if os.path.exists(best_model_path):
        # Load hybrid model
        best_model = model_utils.load_hybrid_model(
            cnn_config, seq_config, config.HYBRID_MODEL, num_classes=num_classes
        )
        # map_location remaps the stored tensors onto the current device, so a
        # checkpoint written on a GPU can still be opened on a CPU-only machine.
        best_model.load_state_dict(torch.load(best_model_path, map_location=device))
        best_model = best_model.to(device)

        # Evaluate on test set
        results = model_utils.evaluate_model(
            best_model, test_loader, config.EVALUATION_METRICS,
            device=device, class_names=class_names
        )

        # Plot confusion matrix
        model_utils.plot_confusion_matrix(
            results['confusion_matrix'],
            class_names=class_names,
            figsize=config.VISUALIZATION_CONFIG['figsize'],
            cmap=config.VISUALIZATION_CONFIG['cmap'],
            normalize=True
        )
        plt.show()
    else:
        print(f"Best model checkpoint not found at {best_model_path}")

## 9. Training Curves Analysis

Plot the training curves of the best-performing model from its saved training history.

In [None]:
# Training history of the best model is looked up as a JSON file in results_dir
history_path = os.path.join(results_dir, f"{best_model_name}_history.json")

if not os.path.exists(history_path):
    print(f"Training history not found at {history_path}")
else:
    with open(history_path, 'r') as history_file:
        history = json.load(history_file)

    # Draw the curves, then add a figure-level title naming the model
    model_utils.plot_training_history(history, figsize=config.VISUALIZATION_CONFIG['figsize'])
    plt.suptitle(f"Training Curves - {best_model_name}")
    plt.tight_layout()
    plt.show()

## 10. Per-Class Performance Analysis

Analyze the performance of the best model for each class.

In [None]:
# Extract per-class metrics from results_df
# Per-class columns are matched by the prefix "<class_name>_".
if class_names is not None:
    # Rows of the results table that belong to the best model
    best_model_rows = results_df.loc[results_df['model_name'] == best_model_name]

    per_class_metrics = {}
    per_class_records = []
    for class_name in class_names:
        column_prefix = f"{class_name}_"
        class_columns = [col for col in results_df.columns if col.startswith(column_prefix)]
        if not class_columns:
            continue

        metrics = best_model_rows[class_columns].iloc[0].to_dict()
        per_class_metrics[class_name] = metrics

        # Strip the class prefix to get the metric name. Unlike splitting on
        # the last underscore, this keeps metric names that contain underscores.
        record = {'class': class_name}
        for column_name, value in metrics.items():
            record[column_name[len(column_prefix):]] = value
        per_class_records.append(record)

    if per_class_records:
        # Build the frame once from all records instead of growing it with
        # pd.concat inside the loop.
        per_class_df = pd.DataFrame(per_class_records)

        # Display per-class metrics
        print(f"Per-class performance for {best_model_name}:")
        display(per_class_df)

        # Plot per-class F1 scores, but only when that metric was recorded
        if 'f1-score' in per_class_df.columns:
            plt.figure(figsize=config.VISUALIZATION_CONFIG['figsize'])
            sns.barplot(x='f1-score', y='class', data=per_class_df.sort_values('f1-score', ascending=False))
            plt.title(f'Per-Class F1 Scores - {best_model_name}')
            plt.xlabel('F1 Score')
            plt.ylabel('Class')
            plt.tight_layout()
            plt.show()
        else:
            print("No 'f1-score' column found in per-class metrics; skipping F1 plot.")
    else:
        print("Per-class metrics not found in results.")
else:
    print("Class names not available for per-class analysis.")

## 11. CNN vs. Sequence Model Impact Analysis

Analyze the impact of different CNN and sequence model choices on performance.

In [None]:
# Mean accuracy and F1 per CNN model, best accuracy first
cnn_impact = (
    results_df.groupby('cnn_model')[['accuracy', 'f1']]
    .mean()
    .reset_index()
    .sort_values('accuracy', ascending=False)
)

# Same aggregation per sequence model
seq_impact = (
    results_df.groupby('seq_model')[['accuracy', 'f1']]
    .mean()
    .reset_index()
    .sort_values('accuracy', ascending=False)
)

# One bar chart per grouping: CNN models first, then sequence models
for impact_table, group_column, chart_title, axis_label in (
    (cnn_impact, 'cnn_model', 'CNN Model Impact on Accuracy (Averaged Across Sequence Models)', 'CNN Model'),
    (seq_impact, 'seq_model', 'Sequence Model Impact on Accuracy (Averaged Across CNN Models)', 'Sequence Model'),
):
    plt.figure(figsize=(10, 6))
    sns.barplot(x='accuracy', y=group_column, data=impact_table)
    plt.title(chart_title)
    plt.xlabel('Mean Accuracy')
    plt.ylabel(axis_label)
    plt.tight_layout()
    plt.show()

## 12. Conclusion

Summarize the findings of the ablation study.

In [None]:
def _to_jsonable(value):
    """Turn a NumPy scalar into the matching built-in Python value; pass anything else through."""
    if isinstance(value, np.generic):
        return value.item()
    return value


# Generate summary statistics
# cnn_impact and seq_impact are sorted by accuracy, so row 0 is the best group.
summary = {
    'best_model': best_model_name,
    'best_accuracy': best_model_row['accuracy'],
    'best_f1': best_model_row['f1'],
    'mean_accuracy': results_df['accuracy'].mean(),
    'std_accuracy': results_df['accuracy'].std(),
    'best_cnn': cnn_impact.iloc[0]['cnn_model'],
    'best_cnn_mean_accuracy': cnn_impact.iloc[0]['accuracy'],
    'best_seq': seq_impact.iloc[0]['seq_model'],
    'best_seq_mean_accuracy': seq_impact.iloc[0]['accuracy'],
    'num_models_tested': len(results_df),
    'accuracy_range': results_df['accuracy'].max() - results_df['accuracy'].min()
}

print("Ablation Study Summary:")
print(f"Number of models tested: {summary['num_models_tested']}")
print(f"Best model: {summary['best_model']}")
print(f"Best accuracy: {summary['best_accuracy']:.4f}")
print(f"Best F1 score: {summary['best_f1']:.4f}")
print(f"Mean accuracy across all models: {summary['mean_accuracy']:.4f} ± {summary['std_accuracy']:.4f}")
print(f"Accuracy range: {summary['accuracy_range']:.4f}")
print(f"Best CNN architecture: {summary['best_cnn']} (mean accuracy: {summary['best_cnn_mean_accuracy']:.4f})")
print(f"Best sequence model: {summary['best_seq']} (mean accuracy: {summary['best_seq_mean_accuracy']:.4f})")

# Save summary to file
# Create the directory if needed so the write cannot fail on a missing folder.
os.makedirs(results_dir, exist_ok=True)
summary_path = os.path.join(results_dir, 'ablation_summary.json')
with open(summary_path, 'w') as f:
    # json cannot serialise NumPy scalars, so every value is converted first.
    json.dump({k: _to_jsonable(v) for k, v in summary.items()}, f, indent=2)

print(f"\nSummary saved to {summary_path}")

## 13. Save Results

Save the detailed results and a simplified summary table as CSV files in the results directory.

In [None]:
# Write the full results table, every column included
results_csv_path = os.path.join(results_dir, 'ablation_results_detailed.csv')
results_df.to_csv(results_csv_path, index=False)
print(f"Detailed results saved to {results_csv_path}")

# Write a reduced table for reporting: headline columns only, best accuracy first
report_columns = ['model_name', 'cnn_model', 'seq_model', 'accuracy', 'precision', 'recall', 'f1']
simple_results = results_df[report_columns].sort_values('accuracy', ascending=False)
simple_csv_path = os.path.join(results_dir, 'ablation_results_simple.csv')
simple_results.to_csv(simple_csv_path, index=False)
print(f"Simplified results saved to {simple_csv_path}")