# Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from model import ImageMamba, ModelArgs
from data_loader import load_cifar10
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os
import json
from tqdm import tqdm

In [2]:
# TODO: move this class into its own module and import it, as is done for ImageMamba
class SmallerComparableCNN(nn.Module):
    """Compact two-stage convolutional classifier with 10 outputs.

    Each stage is a 3x3 convolution (padding 1) followed by batch
    normalisation, ReLU and 2x2 max pooling; channel widths go 3 -> 32 -> 64.
    The resulting feature map is globally average-pooled to one value per
    channel and passed through a single linear layer with 10 outputs.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3 input channels -> 32 feature maps.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=32)

        # Stage 2: 32 -> 64 feature maps.
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=64)

        # Collapse each feature map to a single value, then classify.
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(in_features=64, out_features=10)  # 64 matches conv2's output channels

    def forward(self, x):
        """Return ``(logits, probabilities)`` for the input batch ``x``.

        ``probabilities`` is the softmax of ``logits`` over the last dimension.
        """
        features = x
        # Both stages share the same conv -> norm -> ReLU -> 2x2 max-pool shape.
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            features = F.max_pool2d(F.relu(norm(conv(features))), 2)

        # Global average pooling leaves (batch, 64, 1, 1); flatten to (batch, 64).
        pooled = self.avg_pool(features)
        flat = pooled.view(pooled.size(0), -1)

        logits = self.fc(flat)
        return logits, F.softmax(logits, dim=-1)

# Checkpoint loading, vulnerability evaluation, and plotting helpers

In [3]:
def load_all_checkpoints(checkpoint_dir, model, device='cuda'):
    """Load every ``*.pt`` checkpoint in a directory and snapshot its weights.

    The epoch number is parsed from the text between the last ``_`` in the
    file name and the ``.pt`` suffix (e.g. ``ckpt_epoch_12.pt`` -> ``12``).
    Files are processed in ascending epoch order.

    Args:
        checkpoint_dir: Directory to scan for files ending in ``.pt``.
        model: Module whose architecture matches the checkpoints. Each
            checkpoint is loaded into it, so on return it holds the weights
            of the last (highest-epoch) checkpoint processed.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        A tuple ``(checkpoints, epochs)``. ``checkpoints`` is a list of dicts
        with keys ``'epoch'``, ``'model_state'`` (an independent copy of the
        model's state dict for that checkpoint) and ``'metrics'`` (the
        checkpoint's ``'metrics'`` entry, or ``None`` if absent), sorted by
        epoch. ``epochs`` is the matching sorted list of epoch numbers.

    Raises:
        ValueError: If a ``.pt`` file name does not end in ``_<int>.pt``.
    """
    # Collect (epoch, filename) pairs first: os.listdir returns entries in
    # arbitrary order, and sorting up front makes the final state of `model`
    # deterministic.
    discovered = []
    for filename in os.listdir(checkpoint_dir):
        if not filename.endswith('.pt'):
            continue
        epoch_text = filename.split('_')[-1].split('.pt')[0]
        try:
            epoch = int(epoch_text)
        except ValueError as exc:
            raise ValueError(
                f"Cannot parse an epoch number from checkpoint file {filename!r}; "
                "expected a name ending in '_<epoch>.pt'"
            ) from exc
        discovered.append((epoch, filename))
    discovered.sort()

    checkpoints = []
    epochs = []
    for epoch, filename in discovered:
        checkpoint_path = os.path.join(checkpoint_dir, filename)
        # NOTE(review): torch.load unpickles the file; only point this at
        # checkpoint directories you trust.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

        # model.state_dict() hands back tensors that share storage with the
        # live parameters, and load_state_dict copies into them in place.
        # Without cloning, every stored snapshot would end up showing the
        # weights of whichever checkpoint was loaded last.
        snapshot = {
            name: tensor.detach().clone()
            for name, tensor in model.state_dict().items()
        }

        checkpoints.append({
            'epoch': epoch,
            'model_state': snapshot,
            'metrics': checkpoint['metrics'] if 'metrics' in checkpoint else None
        })
        epochs.append(epoch)

    return checkpoints, epochs

def evaluate_model_vulnerability(model, train_loader, test_loader, device='cuda'):
    """Compare a model's confidence and accuracy on train versus test data.

    The model is switched to eval mode and run without gradients over both
    loaders. It is expected to return a pair whose second element holds
    per-class probabilities; the per-sample maximum of those is treated as
    the model's confidence.

    Args:
        model: Callable module returning ``(_, probabilities)`` for a batch.
        train_loader: Iterable of ``(data, labels)`` batches seen in training.
        test_loader: Iterable of ``(data, labels)`` held-out batches.
        device: Device that inputs and labels are moved to.

    Returns:
        Dict with mean train/test confidence, train/test accuracy (percent),
        the train-minus-test confidence gap, the train-minus-test accuracy
        gap (percent), and ``confidence_overlap``: the length of the
        intersection of the two ``mean +/- std`` confidence intervals
        (negative when the intervals do not intersect).
    """
    model.eval()

    def summarize(loader):
        """Return (mean confidence, std of confidence, accuracy) for a loader."""
        top_probs, hits = [], []
        with torch.no_grad():
            for inputs, targets in loader:
                _, probs = model(inputs.to(device))
                best_prob, best_class = probs.max(dim=1)
                top_probs.extend(best_prob.cpu().numpy())
                hits.extend((best_class == targets.to(device)).cpu().numpy())
        return np.mean(top_probs), np.std(top_probs), np.mean(hits)

    train_mean, train_std, train_acc = summarize(train_loader)
    test_mean, test_std, test_acc = summarize(test_loader)

    # Intersection of [mean - std, mean + std] for the two splits.
    shared_upper = min(train_mean + train_std, test_mean + test_std)
    shared_lower = max(train_mean - train_std, test_mean - test_std)

    report = {}
    report['train_confidence'] = train_mean
    report['test_confidence'] = test_mean
    report['train_accuracy'] = train_acc * 100
    report['test_accuracy'] = test_acc * 100
    report['confidence_gap'] = train_mean - test_mean
    report['accuracy_gap'] = (train_acc - test_acc) * 100
    report['confidence_overlap'] = shared_upper - shared_lower
    return report

def analyze_models(mamba_model, cnn_model, train_loader, test_loader, mamba_dir, cnn_dir, device='cuda'):
    """Evaluate every saved checkpoint of the MAMBA and CNN models.

    Checkpoints for both models are loaded first; each checkpoint's weights
    are then restored into the corresponding model and scored with
    ``evaluate_model_vulnerability`` on the given loaders.

    Args:
        mamba_model: Model instance the MAMBA checkpoints are loaded into.
        cnn_model: Model instance the CNN checkpoints are loaded into.
        train_loader: Loader passed through as the training split.
        test_loader: Loader passed through as the test split.
        mamba_dir: Directory containing the MAMBA checkpoints.
        cnn_dir: Directory containing the CNN checkpoints.
        device: Device forwarded to checkpoint loading and evaluation.

    Returns:
        ``{'mamba': {...}, 'cnn': {...}}`` where each entry has ``'epochs'``
        (as returned by ``load_all_checkpoints``) and ``'vulnerabilities'``
        (one metrics dict per checkpoint, in checkpoint order).
    """
    print("Loading MAMBA checkpoints...")
    mamba_snapshots, mamba_epochs = load_all_checkpoints(mamba_dir, mamba_model, device)
    print("Loading CNN checkpoints...")
    cnn_snapshots, cnn_epochs = load_all_checkpoints(cnn_dir, cnn_model, device)

    results = {
        'mamba': {'vulnerabilities': [], 'epochs': mamba_epochs},
        'cnn': {'vulnerabilities': [], 'epochs': cnn_epochs}
    }

    # (result key, progress banner, model, its checkpoints) - MAMBA first, then CNN.
    work_items = (
        ('mamba', "\nAnalyzing MAMBA checkpoints...", mamba_model, mamba_snapshots),
        ('cnn', "\nAnalyzing CNN checkpoints...", cnn_model, cnn_snapshots),
    )
    for key, banner, net, snapshots in work_items:
        print(banner)
        collected = results[key]['vulnerabilities']
        for snapshot in tqdm(snapshots):
            net.load_state_dict(snapshot['model_state'])
            collected.append(
                evaluate_model_vulnerability(net, train_loader, test_loader, device)
            )

    return results

def plot_vulnerability_comparison(results, save_dir='comparison_plots'):
    """Draw the MAMBA-vs-CNN vulnerability figure, save it, and print a summary.

    The figure has four panels, each plotted against the model's epoch list:
    train-test accuracy gap, train-test confidence gap, train and test
    accuracy of both models, and a combined vulnerability score computed as
    ``confidence_gap * (1 - confidence_overlap)``.

    Args:
        results: Mapping with ``'mamba'`` and ``'cnn'`` entries, each holding
            an ``'epochs'`` list and a ``'vulnerabilities'`` list of metric
            dicts. The keys read here are ``accuracy_gap``,
            ``confidence_gap``, ``train_accuracy``, ``test_accuracy`` and
            ``confidence_overlap``.
        save_dir: Directory that receives ``vulnerability_analysis.png``;
            created if it does not already exist.

    Returns:
        The matplotlib figure that was drawn.
    """
    os.makedirs(save_dir, exist_ok=True)

    def metric_series(entry, key):
        """Collect one metric across all checkpoints of a model."""
        return [v[key] for v in entry['vulnerabilities']]

    def vulnerability_score(v):
        """Confidence gap, scaled up as the train/test confidence overlap shrinks."""
        return v['confidence_gap'] * (1 - v['confidence_overlap'])

    def decorate(ax, title, ylabel):
        """Apply the title, axis labels, legend and grid shared by every panel."""
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    fig = plt.figure(figsize=(20, 15))
    layout = plt.GridSpec(3, 2)
    mamba, cnn = results['mamba'], results['cnn']

    # Panels 1 and 2: accuracy gap and confidence gap, side by side on the top row.
    top_row = (
        (layout[0, 0], 'accuracy_gap', 'Train-Test Accuracy Gap Evolution', 'Accuracy Gap (%)'),
        (layout[0, 1], 'confidence_gap', 'Confidence Gap Evolution', 'Train-Test Confidence Gap'),
    )
    for cell, key, title, ylabel in top_row:
        ax_gap = fig.add_subplot(cell)
        ax_gap.plot(mamba['epochs'], metric_series(mamba, key), 'b-', label='MAMBA')
        ax_gap.plot(cnn['epochs'], metric_series(cnn, key), 'r-', label='CNN')
        decorate(ax_gap, title, ylabel)

    # Panel 3: train (solid) and test (dashed) accuracy for both models.
    ax_acc = fig.add_subplot(layout[1, :])
    accuracy_lines = (
        (mamba, 'train_accuracy', 'b-', 'MAMBA Train'),
        (mamba, 'test_accuracy', 'b--', 'MAMBA Test'),
        (cnn, 'train_accuracy', 'r-', 'CNN Train'),
        (cnn, 'test_accuracy', 'r--', 'CNN Test'),
    )
    for entry, key, line_style, legend_label in accuracy_lines:
        ax_acc.plot(entry['epochs'], metric_series(entry, key), line_style, label=legend_label)
    decorate(ax_acc, 'Accuracy Comparison', 'Accuracy (%)')

    # Panel 4: combined vulnerability score per checkpoint.
    ax_score = fig.add_subplot(layout[2, :])
    mamba_scores = [vulnerability_score(v) for v in mamba['vulnerabilities']]
    cnn_scores = [vulnerability_score(v) for v in cnn['vulnerabilities']]
    ax_score.plot(mamba['epochs'], mamba_scores, 'b-', label='MAMBA')
    ax_score.plot(cnn['epochs'], cnn_scores, 'r-', label='CNN')
    decorate(ax_score, 'Overall Vulnerability Score Evolution', 'Vulnerability Score')

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'vulnerability_analysis.png'))

    # Text summary based on the last checkpoint of each model.
    print("\nVulnerability Analysis Summary:")
    print("=" * 50)

    final_mamba = mamba['vulnerabilities'][-1]
    final_cnn = cnn['vulnerabilities'][-1]

    print("\nFinal Checkpoint Statistics:")
    for heading, final in (("MAMBA:", final_mamba), ("\nCNN:", final_cnn)):
        print(heading)
        print(f"  - Accuracy Gap: {final['accuracy_gap']:.2f}%")
        print(f"  - Confidence Gap: {final['confidence_gap']:.4f}")
        print(f"  - Train Accuracy: {final['train_accuracy']:.2f}%")
        print(f"  - Test Accuracy: {final['test_accuracy']:.2f}%")

    mamba_vuln_score = vulnerability_score(final_mamba)
    cnn_vuln_score = vulnerability_score(final_cnn)

    print("\nOverall Vulnerability Score (higher = more vulnerable):")
    print(f"MAMBA: {mamba_vuln_score:.4f}")
    print(f"CNN: {cnn_vuln_score:.4f}")

    # MAMBA is named only when its score is strictly higher; ties go to CNN.
    if mamba_vuln_score > cnn_vuln_score:
        conclusion = "MAMBA"
    else:
        conclusion = "CNN"
    print(f"\nConclusion: {conclusion} appears more vulnerable to inference attacks.")

    return fig

# Run

In [None]:
# Model hyperparameters
d_model = 64
n_layer = 4
num_classes = 10

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build both classifiers (MAMBA first, then the CNN) and place them on `device`
model_args = ModelArgs(d_model=d_model, n_layer=n_layer, vocab_size=0)
mamba_model = ImageMamba(model_args, num_classes=num_classes).to(device)
cnn_model = SmallerComparableCNN().to(device)

# Load data; only the first two of the six returned values are used here
train_loader, test_loader, _, _, _, _ = load_cifar10(batch_size=64, seed=42)

In [None]:
# Run analysis with explicitly specified checkpoint directories.
# `device` is forwarded so checkpoint loading and evaluation happen on the
# same device the models were moved to. Without it analyze_models falls back
# to its 'cuda' default, which fails on machines without a GPU.
results = analyze_models(
    mamba_model=mamba_model,
    cnn_model=cnn_model,
    train_loader=train_loader,
    test_loader=test_loader,
    mamba_dir='mamba_checkpoints',
    cnn_dir='cnn_checkpoints',
    device=device,
)

# Create plots
plot_vulnerability_comparison(results)