# Measure MAMBA and CNN

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
import os

In [3]:
from tqdm import tqdm

import numpy as np
import keras
from tensorflow import keras
from keras.datasets import cifar10
from __future__ import print_function
from keras.models import Sequential
from keras.models import save_model, load_model
from keras.layers import Dense, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D

import keras.backend as K
K.clear_session()

# Import necessary libraries
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score
# from model import Mamba, ModelArgs  # Import your custom Mamba implementation
# Assuming the model classes are defined in `model.py`
import sys
sys.path.append('/path/to/your/model/directory')
from model import ImageMamba, ModelArgs

ModuleNotFoundError: No module named 'model'

In [None]:
def evaluate_checkpoint(model, loader, device):
    """Evaluate a model on every batch of ``loader`` without gradient tracking.

    Args:
        model: Module whose forward pass returns a ``(logits, probabilities)``
            pair. It is switched to eval mode and left in that mode.
        loader: Sized iterable of ``(inputs, labels)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        dict with the keys
        ``accuracy`` (percentage of samples whose arg-max logit equals the
        label), ``loss`` (mean of the per-batch cross-entropy losses, so every
        batch weighs the same regardless of its size), ``avg_confidence``
        (mean of the per-sample maximum probability) and ``confidences``
        (list of the per-sample maximum probabilities).

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    total = 0
    correct = 0
    running_loss = 0
    confidences = []

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            # Flatten to (batch,) explicitly. A bare squeeze() would also drop
            # the batch axis when a batch holds a single sample (e.g. the last
            # partial batch), leaving a 0-dim target that no longer lines up
            # with the (1, num_classes) logits.
            # Assumes one class index per sample - TODO confirm for new loaders.
            targets = labels.reshape(-1)

            logits, probabilities = model(inputs)
            loss = criterion(logits, targets)

            _, predicted = torch.max(logits, 1)
            confidence, _ = torch.max(probabilities, 1)

            total += targets.size(0)
            correct += (predicted == targets).sum().item()
            running_loss += loss.item()
            confidences.extend(confidence.cpu().numpy())

    accuracy = (correct / total) * 100
    avg_loss = running_loss / len(loader)
    avg_confidence = np.mean(confidences)

    return {
        'accuracy': accuracy,
        'loss': avg_loss,
        'avg_confidence': avg_confidence,
        'confidences': confidences
    }

def _confidence_stats(dist):
    """Summarize one confidence distribution as plain, JSON-friendly floats."""
    return {
        'mean': float(np.mean(dist)),
        'std': float(np.std(dist)),
        'min': float(np.min(dist)),
        'max': float(np.max(dist)),
        'median': float(np.median(dist)),
        'q1': float(np.percentile(dist, 25)),
        'q3': float(np.percentile(dist, 75))
    }


def reconstruct_metrics_from_checkpoints(model_class, checkpoint_base_path,
                                       train_loader, test_loader,
                                       json_save_path, device='cuda',
                                       epochs=None):
    """Re-evaluate saved checkpoints and collect their metrics.

    For each requested epoch the file ``model_epoch_{epoch}.pt`` inside
    ``checkpoint_base_path`` is loaded into a single model instance and
    evaluated on both loaders with ``evaluate_checkpoint``. Epochs whose
    checkpoint file does not exist are reported and skipped.

    Args:
        model_class: Callable taking no arguments that builds the model.
        checkpoint_base_path: Directory holding the checkpoint files.
        train_loader: Loader passed to ``evaluate_checkpoint`` for train metrics.
        test_loader: Loader passed to ``evaluate_checkpoint`` for test metrics.
        json_save_path: File the JSON summary is written to.
        device: Device used for the model and as ``map_location`` when loading.
        epochs: Iterable of epoch numbers to look for. Defaults to every 100th
            epoch from 100 to 1500 inclusive.

    Returns:
        dict of lists, one entry per evaluated checkpoint, including the full
        per-sample confidence distributions. The JSON file stores only summary
        statistics of those distributions.
    """
    if epochs is None:
        epochs = range(100, 1501, 100)

    model = model_class().to(device)
    metrics = {
        'train_accuracies': [],
        'test_accuracies': [],
        'train_losses': [],
        'test_losses': [],
        'train_confidences': [],
        'test_confidences': [],
        'epochs': [],
        'train_confidence_distributions': [],
        'test_confidence_distributions': []
    }

    for epoch in epochs:
        checkpoint_path = os.path.join(checkpoint_base_path, f'model_epoch_{epoch}.pt')
        if not os.path.exists(checkpoint_path):
            print(f"Skipping epoch {epoch} - checkpoint not found")
            continue

        print(f"Processing epoch {epoch}")
        # NOTE: torch.load unpickles the file; only load checkpoints that come
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        # Evaluate on train and test sets
        train_metrics = evaluate_checkpoint(model, train_loader, device)
        test_metrics = evaluate_checkpoint(model, test_loader, device)

        # Store metrics
        metrics['epochs'].append(epoch)
        for split, split_metrics in (('train', train_metrics), ('test', test_metrics)):
            metrics[f'{split}_accuracies'].append(split_metrics['accuracy'])
            metrics[f'{split}_losses'].append(split_metrics['loss'])
            metrics[f'{split}_confidences'].append(split_metrics['avg_confidence'])
            metrics[f'{split}_confidence_distributions'].append(split_metrics['confidences'])

    # Build the whole payload before touching the output file so that a failure
    # while summarizing cannot leave a truncated JSON file behind.
    scalar_keys = (
        'train_accuracies', 'test_accuracies',
        'train_losses', 'test_losses',
        'train_confidences', 'test_confidences',
    )
    # Cast to float: the values may be numpy scalars, which json cannot encode.
    json_metrics = {key: [float(x) for x in metrics[key]] for key in scalar_keys}
    json_metrics['epochs'] = metrics['epochs']
    # Store only summary statistics for confidence distributions to keep JSON size manageable
    json_metrics['train_confidence_distributions_stats'] = [
        _confidence_stats(dist) for dist in metrics['train_confidence_distributions']
    ]
    json_metrics['test_confidence_distributions_stats'] = [
        _confidence_stats(dist) for dist in metrics['test_confidence_distributions']
    ]

    with open(json_save_path, 'w') as f:
        json.dump(json_metrics, f, indent=4)

    return metrics

def plot_comparative_metrics(cnn_metrics, mamba_metrics, save_dir='comparison_plots'):
    """Save a 2x2 figure comparing CNN and MAMBA metrics across epochs.

    The panels show accuracy, loss, average confidence and the train-test
    accuracy gap. The figure is written to ``model_comparisons.png`` inside
    ``save_dir``, which is created if it does not exist.

    Args:
        cnn_metrics: Metrics dict for the CNN with the keys ``epochs``,
            ``train_/test_accuracies``, ``train_/test_losses`` and
            ``train_/test_confidences``.
        mamba_metrics: Metrics dict for MAMBA with the same keys.
        save_dir: Output directory for the PNG.

    NOTE(review): ``plt`` is not imported in the visible cells;
    ``import matplotlib.pyplot as plt`` must have run before this is called.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Recent Matplotlib releases ship the seaborn look as 'seaborn-v0_8';
    # older ones only know the bare 'seaborn' name. Use whichever is
    # available and fall back to the default style if neither is.
    for style_name in ('seaborn-v0_8', 'seaborn'):
        if style_name in plt.style.available:
            plt.style.use(style_name)
            break

    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(15, 12))

    # Each model gets its own x-axis: the two runs may have skipped different
    # checkpoints, so their metric lists need not have the same length.
    cnn_epochs = cnn_metrics['epochs']
    mamba_epochs = mamba_metrics['epochs']

    # Panels 1-3 share the same layout: train/test curves for both models.
    panels = (
        (ax1, 'accuracies', 'Accuracy Comparison', 'Accuracy (%)'),
        (ax2, 'losses', 'Loss Comparison', 'Loss'),
        (ax3, 'confidences', 'Average Confidence Comparison', 'Confidence'),
    )
    for ax, suffix, title, ylabel in panels:
        ax.plot(cnn_epochs, cnn_metrics[f'train_{suffix}'], 'b-', label='CNN Train')
        ax.plot(cnn_epochs, cnn_metrics[f'test_{suffix}'], 'b--', label='CNN Test')
        ax.plot(mamba_epochs, mamba_metrics[f'train_{suffix}'], 'r-', label='MAMBA Train')
        ax.plot(mamba_epochs, mamba_metrics[f'test_{suffix}'], 'r--', label='MAMBA Test')
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    # Panel 4: overfitting analysis (train accuracy minus test accuracy).
    cnn_gap = [t - v for t, v in zip(cnn_metrics['train_accuracies'], cnn_metrics['test_accuracies'])]
    mamba_gap = [t - v for t, v in zip(mamba_metrics['train_accuracies'], mamba_metrics['test_accuracies'])]

    ax4.plot(cnn_epochs, cnn_gap, 'b-', label='CNN')
    ax4.plot(mamba_epochs, mamba_gap, 'r-', label='MAMBA')
    ax4.set_title('Overfitting Analysis (Train-Test Accuracy Gap)')
    ax4.set_xlabel('Epochs')
    ax4.set_ylabel('Accuracy Gap (%)')
    ax4.legend()
    ax4.grid(True)

    fig.tight_layout()
    fig.savefig(os.path.join(save_dir, 'model_comparisons.png'))
    # Close this specific figure so repeated calls do not accumulate figures.
    plt.close(fig)

In [None]:
# Rebuild the per-checkpoint metrics of the CNN baseline; a JSON summary is
# written to New_CNN_Metrics.json.
cnn_metrics = reconstruct_metrics_from_checkpoints(
    SmallerComparableCNN,
    'smaller_cnn_checkpoints',
    train_loader,
    test_loader,
    json_save_path='New_CNN_Metrics.json',
    device=device,
)

# Same procedure for the MAMBA model, using the same loaders and device.
mamba_metrics = reconstruct_metrics_from_checkpoints(
    ImageMamba,
    'model_checkpoints_extended',
    train_loader,
    test_loader,
    json_save_path='New_MAMBA_Metrics.json',
    device=device,
)

# Render the side-by-side comparison figure (default output directory).
plot_comparative_metrics(cnn_metrics=cnn_metrics, mamba_metrics=mamba_metrics)