# Membership Inference Attack Analysis for MAMBA vs CNN Models
This notebook trains membership inference attacks against saved MAMBA and CNN model checkpoints. It compares how vulnerable each model becomes as training progresses.

# Imports 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, TensorDataset
import os
import json

from model import ImageMamba, ModelArgs, SmallerComparableCNN
from data_loader import load_cifar10, get_class_names

# Attack Model Architecture

In [None]:
       
# Attack Model Architecture
class AttackModel(nn.Module):
    """Binary member / non-member classifier over a target model's outputs.

    Each input row holds ``2 * num_classes`` features: the target model's
    logits concatenated with its probabilities. The output is two unnormalised
    scores per row. Index 0 means non-member and index 1 means member, matching
    the labels built by ``MembershipInferenceAttack``.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Logits and probabilities are concatenated, hence two features per class.
        self.input_size = 2 * num_classes

        # Three hidden stages of Linear -> BatchNorm -> ReLU -> Dropout, given as
        # (width, dropout probability), followed by a 2-way output layer.
        layers = []
        fan_in = self.input_size
        for width, drop_p in ((128, 0.3), (64, 0.2), (32, 0.1)):
            layers += [
                nn.Linear(fan_in, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(drop_p),
            ]
            fan_in = width
        layers.append(nn.Linear(fan_in, 2))  # member vs non-member scores
        self.network = nn.Sequential(*layers)

        # Xavier-normal weights and zero biases for every linear layer.
        for lin in (mod for mod in self.modules() if isinstance(mod, nn.Linear)):
            nn.init.xavier_normal_(lin.weight)
            nn.init.zeros_(lin.bias)

    def forward(self, x):
        """Return the two membership scores for each row of ``x``."""
        return self.network(x)


# Membership Inference Attack Implementation

In [None]:
class MembershipInferenceAttack:
    """Membership inference attack against a single target model.

    An ``AttackModel`` is trained directly on the target model's outputs.
    Samples from ``train_loader`` are labelled members (1) and samples from
    ``test_loader`` are labelled non-members (0).
    """

    def __init__(self, target_model, device='cuda', num_classes=10):
        """
        Args:
            target_model: model under attack. Its forward pass must return a
                pair of tensors (unpacked here as ``logits, probs``) that can be
                concatenated along dim 1.
            device: device on which the target and attack models are run.
            num_classes: number of target classes. It sizes the attack model's
                input at ``2 * num_classes`` features. Defaults to 10.
        """
        self.target_model = target_model
        self.attack_model = AttackModel(num_classes=num_classes).to(device)
        self.device = device
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.attack_model.parameters(), lr=0.001, weight_decay=1e-5)

    def prepare_attack_data(self, train_loader, test_loader):
        """Collect target-model outputs as attack features.

        Args:
            train_loader: iterable of ``(data, target)`` batches treated as members.
            test_loader: iterable of ``(data, target)`` batches treated as non-members.

        Returns:
            ``(X, y)``. ``X`` holds one row per sample: the concatenated
            ``[logits, probs]`` features, moved to CPU. ``y`` is a long tensor,
            1 for members and 0 for non-members.

        Raises:
            ValueError: if both loaders are empty.
        """
        features, labels = [], []

        self.target_model.eval()
        with torch.no_grad():
            for loader, membership in ((train_loader, 1), (test_loader, 0)):
                for data, _ in loader:
                    data = data.to(self.device)
                    logits, probs = self.target_model(data)
                    features.append(torch.cat([logits, probs], dim=1).cpu())
                    labels.extend([membership] * data.size(0))

        if not features:
            raise ValueError("Both loaders are empty; no attack data to prepare")
        return torch.cat(features), torch.tensor(labels, dtype=torch.long)

    def train_attack(self, train_loader, test_loader, epochs=10, patience=5, batch_size=64):
        """Train the attack model to separate members from non-members.

        The collected data is balanced to equal member / non-member counts and
        split 80/20 into attack-train and held-out validation sets. Training
        stops early after ``patience`` epochs without validation improvement.
        The weights from the best validation epoch are restored at the end.

        Args:
            train_loader: loader of member samples.
            test_loader: loader of non-member samples.
            epochs: maximum number of training epochs.
            patience: epochs without validation improvement before stopping.
            batch_size: attack-model batch size.

        Returns:
            ``(train_accs, val_accs)``: per-epoch attack accuracy (%) on the
            attack-train split and on the held-out validation split.

        Raises:
            ValueError: if either loader yields no samples.
        """
        X, y = self.prepare_attack_data(train_loader, test_loader)

        # Balance the dataset. ``as_tuple=True`` always yields a 1-D index
        # tensor. The previous ``.nonzero().squeeze()`` became 0-d (and
        # ``len()`` failed) when only one sample of a class was present.
        member_idx = (y == 1).nonzero(as_tuple=True)[0]
        non_member_idx = (y == 0).nonzero(as_tuple=True)[0]
        min_size = min(len(member_idx), len(non_member_idx))
        if min_size == 0:
            raise ValueError(
                "Need at least one member and one non-member sample to train the attack"
            )
        # Random subsample so the kept samples do not depend on loader ordering.
        member_idx = member_idx[torch.randperm(len(member_idx))[:min_size]]
        non_member_idx = non_member_idx[torch.randperm(len(non_member_idx))[:min_size]]

        balanced_idx = torch.cat([member_idx, non_member_idx])
        X, y = X[balanced_idx], y[balanced_idx]

        # 80/20 train/validation split. With >= 2 balanced samples both splits
        # are non-empty.
        perm = torch.randperm(len(X))
        split = int(0.8 * len(X))
        train_idx, val_idx = perm[:split], perm[split:]
        attack_train = TensorDataset(X[train_idx], y[train_idx])
        attack_val = TensorDataset(X[val_idx], y[val_idx])

        # BatchNorm1d cannot run in training mode on a batch of one sample.
        # Drop such a trailing batch, but only when at least one full batch remains.
        drop_last = len(attack_train) > batch_size and len(attack_train) % batch_size == 1
        attack_train_loader = DataLoader(
            attack_train, batch_size=batch_size, shuffle=True, drop_last=drop_last
        )
        attack_val_loader = DataLoader(attack_val, batch_size=batch_size)

        best_val_acc = 0
        best_state = None
        epochs_no_improve = 0
        train_accs, val_accs = [], []

        for epoch in range(epochs):
            train_acc = self._train_one_epoch(attack_train_loader)
            val_acc = self._accuracy(attack_val_loader)
            train_accs.append(train_acc)
            val_accs.append(val_acc)

            print(f'Epoch {epoch+1}/{epochs} | Train Acc: {train_acc:.2f}% | Val Acc: {val_acc:.2f}%')

            # Early stopping on validation accuracy; remember the best weights.
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                best_state = {
                    k: v.detach().clone() for k, v in self.attack_model.state_dict().items()
                }
                epochs_no_improve = 0
            else:
                epochs_no_improve += 1

            if epochs_no_improve == patience:
                print(f'Early stopping triggered at epoch {epoch+1}')
                break

        # Leave the attack model at its best validation epoch, not its last one.
        if best_state is not None:
            self.attack_model.load_state_dict(best_state)

        return train_accs, val_accs

    def _train_one_epoch(self, loader):
        """Run one optimisation pass of the attack model and return its accuracy (%)."""
        self.attack_model.train()
        correct, total = 0, 0
        for inputs, labels in loader:
            inputs, labels = inputs.to(self.device), labels.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.attack_model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        return 100.0 * correct / total

    def _accuracy(self, loader):
        """Return the attack model's accuracy (%) on ``loader``, evaluated in eval mode."""
        self.attack_model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for inputs, labels in loader:
                inputs, labels = inputs.to(self.device), labels.to(self.device)
                _, predicted = self.attack_model(inputs).max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()
        return 100.0 * correct / total

def evaluate_attack(attack_model, target_model, data_loader, is_member=True, device='cuda'):
    """Return the attack's accuracy (%) at labelling every sample in a loader.

    Every sample yielded by ``data_loader`` is given the same ground-truth
    membership label: 1 if ``is_member`` else 0. The target model's outputs
    (unpacked as ``logits, probs``) are concatenated along dim 1 and fed to
    ``attack_model``. Both models are switched to eval mode.

    Args:
        attack_model: classifier producing two scores per row.
        target_model: model whose forward returns a ``(logits, probs)`` pair.
        data_loader: iterable of ``(data, target)`` batches; targets are ignored.
        is_member: ground-truth membership of all samples in ``data_loader``.
        device: device to move each batch to.

    Returns:
        Percentage of samples whose predicted label equals the expected one.
    """
    expected = 1 if is_member else 0
    attack_model.eval()
    target_model.eval()

    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for batch, _ in data_loader:
            batch = batch.to(device)
            logits, probs = target_model(batch)
            scores = attack_model(torch.cat([logits, probs], dim=1))
            guesses = scores.max(1)[1]

            truth = torch.full((batch.size(0),), expected, device=device)
            n_seen += truth.size(0)
            n_correct += guesses.eq(truth).sum().item()

    return 100.0 * n_correct / n_seen

def analyze_checkpoint(model_name, checkpoint_path, model, train_loader, test_loader, device='cuda'):
    """Load one checkpoint into ``model`` and measure a fresh attack against it.

    Args:
        model_name: label used only in the progress message.
        checkpoint_path: file whose loaded object has a ``'model_state_dict'`` entry.
        model: model instance; its weights are overwritten in place.
        train_loader: samples treated as members.
        test_loader: samples treated as non-members.
        device: device for the target and attack models.

    Returns:
        A dict with:
            - the checkpoint path;
            - the attack's per-epoch train and validation accuracy histories;
            - its accuracy on members and on non-members;
            - ``attack_success``, the mean of those two accuracies.
    """
    print(f"\nAnalyzing {model_name} checkpoint: {checkpoint_path}")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    attack = MembershipInferenceAttack(model, device)
    history_train, history_val = attack.train_attack(train_loader, test_loader)

    # NOTE(review): the attack was trained on outputs for (a subset of) the
    # samples in these same loaders, so the two accuracies below are not
    # held-out estimates; ``attack_val_history`` is the held-out signal.
    member_acc = evaluate_attack(attack.attack_model, model, train_loader, True, device)
    non_member_acc = evaluate_attack(attack.attack_model, model, test_loader, False, device)

    result = dict(
        checkpoint=checkpoint_path,
        attack_train_history=history_train,
        attack_val_history=history_val,
        member_accuracy=member_acc,
        non_member_accuracy=non_member_acc,
    )
    result['attack_success'] = (member_acc + non_member_acc) / 2
    return result

def _checkpoint_epoch(checkpoint_path):
    """Return the epoch encoded in a checkpoint filename like ``model_epoch_100.pt``."""
    stem = os.path.splitext(os.path.basename(checkpoint_path))[0]
    return int(stem.rsplit('_', 1)[-1])


def plot_attack_results(mamba_results, cnn_results, save_dir='inference_attack_results'):
    """Plot and save a MAMBA-vs-CNN comparison of membership inference results.

    Writes two files into ``save_dir``, creating the directory if needed:
        - ``membership_inference_analysis.png``: attack success rate vs. epoch,
          plus per-checkpoint member / non-member accuracy bars;
        - ``attack_results.json``: the underlying numbers.

    Args:
        mamba_results: list of ``analyze_checkpoint`` result dicts for MAMBA.
        cnn_results: list of ``analyze_checkpoint`` result dicts for the CNN.
            Epochs are parsed from each dict's ``'checkpoint'`` path, so the two
            lists need not cover the same checkpoints.
        save_dir: output directory.
    """
    os.makedirs(save_dir, exist_ok=True)

    mamba_epochs = [_checkpoint_epoch(r['checkpoint']) for r in mamba_results]
    cnn_epochs = [_checkpoint_epoch(r['checkpoint']) for r in cnn_results]
    mamba_success = [r['attack_success'] for r in mamba_results]
    cnn_success = [r['attack_success'] for r in cnn_results]

    plt.figure(figsize=(15, 10))

    # Plot 1: attack success rate over training progress. Each model is plotted
    # against its own epochs; reusing MAMBA's epochs for the CNN mislabels or
    # breaks the plot when the checkpoint lists differ.
    plt.subplot(2, 1, 1)
    plt.plot(mamba_epochs, mamba_success, 'b-o', label='MAMBA')
    plt.plot(cnn_epochs, cnn_success, 'r-o', label='CNN')
    plt.axhline(y=50, color='gray', linestyle='--', label='Random Guess')

    plt.title('Membership Inference Attack Success Rate vs Training Progress')
    plt.xlabel('Training Epochs')
    plt.ylabel('Attack Success Rate (%)')
    plt.legend()
    plt.grid(True)

    # Plot 2: member vs non-member accuracy as four side-by-side bars per
    # checkpoint. Stacking non-member on top of member accuracy would give bar
    # heights of up to 200%, which is not an accuracy.
    plt.subplot(2, 1, 2)
    all_epochs = sorted(set(mamba_epochs) | set(cnn_epochs))
    slot = {epoch: i for i, epoch in enumerate(all_epochs)}
    width = 0.2
    bar_specs = (
        # (results, epochs, metric key, label, color, alpha, offset in bar widths)
        (mamba_results, mamba_epochs, 'member_accuracy', 'MAMBA Member', 'blue', 0.6, -1.5),
        (mamba_results, mamba_epochs, 'non_member_accuracy', 'MAMBA Non-member', 'blue', 0.3, -0.5),
        (cnn_results, cnn_epochs, 'member_accuracy', 'CNN Member', 'red', 0.6, 0.5),
        (cnn_results, cnn_epochs, 'non_member_accuracy', 'CNN Non-member', 'red', 0.3, 1.5),
    )
    for model_results, epochs, key, label, color, alpha, offset in bar_specs:
        positions = [slot[e] + offset * width for e in epochs]
        plt.bar(positions, [r[key] for r in model_results], width,
                label=label, color=color, alpha=alpha)

    plt.xticks(range(len(all_epochs)), all_epochs, rotation=45)
    plt.title('Attack Performance on Members vs Non-members')
    plt.xlabel('Training Epochs')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'membership_inference_analysis.png'))
    plt.show()

    # Save numerical results. 'epochs' keeps its original meaning (MAMBA's
    # epochs); 'cnn_epochs' records the CNN's own checkpoints.
    results = {
        'epochs': mamba_epochs,
        'cnn_epochs': cnn_epochs,
        'mamba_success_rates': mamba_success,
        'cnn_success_rates': cnn_success,
        'mamba_member_acc': [r['member_accuracy'] for r in mamba_results],
        'mamba_nonmember_acc': [r['non_member_accuracy'] for r in mamba_results],
        'cnn_member_acc': [r['member_accuracy'] for r in cnn_results],
        'cnn_nonmember_acc': [r['non_member_accuracy'] for r in cnn_results]
    }

    with open(os.path.join(save_dir, 'attack_results.json'), 'w') as f:
        json.dump(results, f, indent=4)

# Run

In [None]:
def main():
    """Run the membership inference analysis on every saved MAMBA and CNN checkpoint.

    Loads CIFAR-10 and attacks each checkpoint of both models. It then plots and
    saves the comparison and prints a summary of the final checkpoint of each model.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # load_cifar10 returns six values; only the train/test loaders are used here.
    train_loader, test_loader, _, _, _, _ = load_cifar10(batch_size=64, seed=42)

    # Each architecture is built once; every checkpoint is loaded into it in turn.
    mamba_model = ImageMamba(ModelArgs(d_model=64, n_layer=4, vocab_size=0), num_classes=10)
    cnn_model = SmallerComparableCNN()

    # Checkpoints every 100 epochs, from epoch 100 through 1500.
    checkpoint_epochs = range(100, 1501, 100)

    experiments = (
        ('MAMBA', 'mamba_checkpoints', mamba_model),
        ('CNN', 'cnn_checkpoints', cnn_model),
    )
    results_by_model = {}
    for label, ckpt_dir, model in experiments:
        print(f"\nAnalyzing {label} checkpoints...")
        results_by_model[label] = [
            analyze_checkpoint(label, os.path.join(ckpt_dir, f'model_epoch_{epoch}.pt'),
                               model, train_loader, test_loader, device)
            for epoch in checkpoint_epochs
        ]

    plot_attack_results(results_by_model['MAMBA'], results_by_model['CNN'])

    print("\nFinal Attack Results Summary:")
    print("=" * 50)
    for label, _, _ in experiments:
        final = results_by_model[label][-1]
        print(f"\n{label} Model:")
        print(f"Final checkpoint attack success rate: {final['attack_success']:.2f}%")
        print(f"Member identification accuracy: {final['member_accuracy']:.2f}%")
        print(f"Non-member identification accuracy: {final['non_member_accuracy']:.2f}%")


# Run the analysis when this file/cell is executed directly.
if __name__ == "__main__":
    main()