# Imports

In [1]:
import matplotlib.pyplot as plt
import json
import seaborn as sns
from matplotlib.gridspec import GridSpec
import numpy as np

In [2]:
def load_model_metrics(mamba_dir='mamba_checkpoints', cnn_dir='cnn_checkpoints'):
    """Read the ``training_metrics.json`` file of the MAMBA and the CNN run.

    Args:
        mamba_dir: Directory containing the MAMBA run's ``training_metrics.json``.
        cnn_dir: Directory containing the CNN run's ``training_metrics.json``.

    Returns:
        Tuple ``(mamba_metrics, cnn_metrics)`` with the decoded JSON content
        of each file, in that order.

    Raises:
        FileNotFoundError: If either file is missing. A hint listing both
            expected paths is printed before the exception is re-raised.
    """
    # MAMBA first, CNN second: the read order decides which missing file is reported.
    metrics_paths = (
        f'{mamba_dir}/training_metrics.json',
        f'{cnn_dir}/training_metrics.json',
    )

    loaded = []
    try:
        for metrics_path in metrics_paths:
            with open(metrics_path, 'r') as handle:
                loaded.append(json.load(handle))
    except FileNotFoundError as e:
        print(f"Error loading metrics: {e}")
        print("Please ensure the metrics files exist in the correct directories:")
        for metrics_path in metrics_paths:
            print(f"- {metrics_path}")
        raise

    mamba_metrics, cnn_metrics = loaded
    return mamba_metrics, cnn_metrics

# Comparing the models
Comparing the metrics that were tracked during training and saved in each model's `training_metrics.json`.

In [3]:
def _plot_train_test_panel(ax, metric, title, ylabel,
                           mamba_metrics, cnn_metrics,
                           mamba_epochs, cnn_epochs):
    """Draw the train/test curves of one metric for both models on ``ax``.

    ``metric`` is the key suffix shared by the two series, e.g. ``'losses'``
    selects ``'train_losses'`` and ``'test_losses'``.
    """
    ax.plot(mamba_epochs, mamba_metrics[f'train_{metric}'], 'b-', label='MAMBA Train')
    ax.plot(mamba_epochs, mamba_metrics[f'test_{metric}'], 'b--', label='MAMBA Test')
    ax.plot(cnn_epochs, cnn_metrics[f'train_{metric}'], 'r-', label='CNN Train')
    ax.plot(cnn_epochs, cnn_metrics[f'test_{metric}'], 'r--', label='CNN Test')
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)


def plot_metrics_comparison(mamba_dir='mamba_checkpoints',
                          cnn_dir='cnn_checkpoints',
                          save_path=None):
    """Plot a 2x2 comparison of the MAMBA and CNN training metrics.

    The panels show accuracy, loss, average confidence and the train-test
    accuracy gap. Summary values for the last recorded entry of each model
    are printed afterwards.

    Both metrics files must provide the keys ``train_accuracies``,
    ``test_accuracies``, ``train_losses``, ``test_losses``,
    ``train_confidences`` and ``test_confidences``. The MAMBA file must also
    provide ``epochs``. Each model is plotted against its own ``epochs``
    list; if the CNN file has none (or an empty one), the MAMBA epochs are
    used for the CNN curves as well.

    Args:
        mamba_dir: Directory containing the MAMBA ``training_metrics.json``.
        cnn_dir: Directory containing the CNN ``training_metrics.json``.
        save_path: If truthy, the figure is also written to this path.

    Returns:
        Tuple ``(fig, (mamba_metrics, cnn_metrics))``: the matplotlib figure
        and the two loaded metrics dictionaries.

    Raises:
        FileNotFoundError: Propagated from ``load_model_metrics`` when a
            metrics file is missing.
        ValueError: If the MAMBA metrics contain no epochs.
    """
    # Load metrics
    mamba_metrics, cnn_metrics = load_model_metrics(mamba_dir, cnn_dir)

    # Each model gets its own x-axis so runs with different checkpoint
    # schedules or lengths still plot correctly.
    mamba_epochs = mamba_metrics['epochs']
    cnn_epochs = cnn_metrics.get('epochs') or mamba_epochs

    # Fail early, before any figure is created, when there is nothing to plot.
    if not mamba_epochs:
        raise ValueError("MAMBA metrics contain no epochs; nothing to plot")

    final_epoch = max(mamba_epochs)
    print(f"Plotting metrics for epochs {mamba_epochs[0]} to {final_epoch}")

    # Create figure with subplots
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

    # Panels 1-3: train/test curves per metric (axes, key suffix, title, y label)
    panels = (
        (ax1, 'accuracies', 'Accuracy Comparison', 'Accuracy (%)'),
        (ax2, 'losses', 'Loss Comparison', 'Loss'),
        (ax3, 'confidences', 'Average Confidence', 'Confidence'),
    )
    for ax, metric, title, ylabel in panels:
        _plot_train_test_panel(ax, metric, title, ylabel,
                               mamba_metrics, cnn_metrics,
                               mamba_epochs, cnn_epochs)

    # Panel 4: Overfitting Analysis (Train-Test Accuracy Gap)
    mamba_gap = [train - test for train, test in
                 zip(mamba_metrics['train_accuracies'], mamba_metrics['test_accuracies'])]
    cnn_gap = [train - test for train, test in
               zip(cnn_metrics['train_accuracies'], cnn_metrics['test_accuracies'])]

    ax4.plot(mamba_epochs, mamba_gap, 'b-', label='MAMBA')
    ax4.plot(cnn_epochs, cnn_gap, 'r-', label='CNN')
    ax4.set_title('Overfitting Analysis (Train-Test Accuracy Gap)')
    ax4.set_xlabel('Epochs')
    ax4.set_ylabel('Accuracy Gap (%)')
    ax4.legend()
    ax4.grid(True)

    fig.tight_layout()

    if save_path:
        # Save this figure explicitly rather than whatever pyplot considers current.
        fig.savefig(save_path)

    # Print summary statistics for final checkpoint
    print(f"\nFinal Checkpoint (Epoch {final_epoch}) Statistics:")
    print("="*50)
    cnn_final_epoch = max(cnn_epochs)
    if cnn_final_epoch != final_epoch:
        # The header epoch is MAMBA's; make the mismatch visible to the reader.
        print(f"Note: CNN metrics end at epoch {cnn_final_epoch}; "
              "CNN values below are from its last recorded entry.")
    print("\nAccuracies:")
    print(f"MAMBA - Train: {mamba_metrics['train_accuracies'][-1]:.2f}%, Test: {mamba_metrics['test_accuracies'][-1]:.2f}%")
    print(f"CNN    - Train: {cnn_metrics['train_accuracies'][-1]:.2f}%, Test: {cnn_metrics['test_accuracies'][-1]:.2f}%")

    print("\nOverfitting Gap (Train-Test):")
    print(f"MAMBA: {mamba_gap[-1]:.2f}%")
    print(f"CNN: {cnn_gap[-1]:.2f}%")

    print("\nConfidence (Train/Test):")
    print(f"MAMBA - Train: {mamba_metrics['train_confidences'][-1]:.4f}, Test: {mamba_metrics['test_confidences'][-1]:.4f}")
    print(f"CNN    - Train: {cnn_metrics['train_confidences'][-1]:.4f}, Test: {cnn_metrics['test_confidences'][-1]:.4f}")

    return fig, (mamba_metrics, cnn_metrics)

In [None]:
# Build the four-panel MAMBA-vs-CNN comparison from the two checkpoint
# directories and keep the loaded metrics around for later cells.
fig, (mamba_metrics, cnn_metrics) = plot_metrics_comparison(
    'mamba_checkpoints',
    'cnn_checkpoints',
    save_path='model_comparison.png',  # Optional: omit to skip writing the image
)

plt.show()