# EGEAT: Ensemble Gradient-based Ensemble Adversarial Training

This notebook runs the complete EGEAT experiment pipeline, including:
- Ensemble model training
- Adversarial training with geometric regularization
- Comprehensive evaluation metrics
- Visualization of all diagrams and results


## 1. Setup and Installation


In [None]:
# Install required packages
!pip install torch torchvision numpy pandas matplotlib seaborn tqdm scikit-learn imageio -q


In [None]:
# Standard imports
import os
import sys
import json
import logging
import traceback
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from copy import deepcopy

# Colab workspace: create it if needed and make it the current directory.
WORK_DIR = '/content/EGEAT'
os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)

# Package skeleton, created relative to the working directory set above.
for _pkg_dir in ('src', 'src/attacks', 'src/evaluation', 'src/models', 'src/training', 'src/utils'):
    os.makedirs(_pkg_dir, exist_ok=True)

# Report the runtime environment.
print(f"Working directory: {os.getcwd()}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


## 2. Source Code Setup

The following cells contain all the source code from the project organized by module.


### 2.1 Configuration Module


In [None]:
# src/config.py
from dataclasses import dataclass, asdict, field
from typing import Optional
import argparse

@dataclass
class TrainingConfig:
    """Hyperparameters of the training stage.

    The defaults of ``lambda_geom``, ``lambda_soup`` and ``epsilon`` equal the
    defaults of the same-named arguments of ``train_egeat_epoch``.
    """
    batch_size: int = 128  # forwarded to the data-loader builders
    learning_rate: float = 2e-4  # presumably the optimizer step size; consumer not shown in this section
    epochs: int = 5  # presumably epochs per ensemble member; consumer not shown in this section
    lambda_geom: float = 0.1  # weight of the gradient-subspace (geometric) penalty
    lambda_soup: float = 0.05  # weight of the squared distance to the parameter "soup"
    epsilon: float = 8/255  # perturbation radius used to build training-time adversarial inputs
    use_mixed_precision: bool = False  # train_egeat_epoch only enables AMP on CUDA devices
    num_workers: int = 2  # DataLoader worker processes
    pin_memory: bool = True  # DataLoader pin_memory flag

@dataclass
class ModelConfig:
    """Architecture choice and ensemble size."""
    model_type: str = 'SimpleCNN'  # 'SimpleCNN' selects SimpleCNN; any other value falls back to DCGAN_CNN
    input_channels: int = 3  # overwritten by the dataset-loading cell (3 for CIFAR-10, 1 for MNIST)
    num_classes: int = 10  # overwritten by the dataset-loading cell (10 for both supported datasets)
    ensemble_size: int = 5  # number of models initialised for the ensemble

@dataclass
class AttackConfig:
    """Attack hyperparameters.

    ``epsilon``, ``alpha`` and ``pgd_iters`` carry the same default values as
    the ``epsilon``, ``alpha`` and ``iters`` arguments of ``pgd_attack``.
    """
    epsilon: float = 8/255  # perturbation budget
    alpha: float = 2/255  # PGD step size
    pgd_iters: int = 10  # number of PGD iterations
    # NOTE(review): no Carlini-Wagner attack is defined in this section of the
    # notebook; the cw_* fields are presumably consumed elsewhere - confirm.
    cw_iters: int = 50
    cw_c: float = 1e-2
    cw_lr: float = 0.01

@dataclass
class EvaluationConfig:
    """Settings for the evaluation metrics and loss-landscape scans."""
    max_batches: int = 10  # cap on the number of batches an evaluation routine consumes
    n_bins_ece: int = 15  # same default as the n_bins argument of ece()
    loss_landscape_grid_n: int = 21  # same default as grid_n of scan_1d_loss / scan_2d_loss
    loss_landscape_radius: float = 1.0  # same default as radius of scan_1d_loss / scan_2d_loss

@dataclass
class DataConfig:
    """Dataset selection and split options."""
    dataset: str = 'cifar10'  # the loading cell accepts 'cifar10' or 'mnist' (case-insensitive) and raises otherwise
    val_split: float = 0.1  # fraction of the training set held out for validation
    augment: bool = True  # forwarded to get_cifar10_loaders only
    # NOTE(review): the loader builders in this notebook hard-code
    # './data/MNIST' and './data/CIFAR10' and never read data_path.
    data_path: str = './data'

@dataclass
class ExperimentConfig:
    """Top-level experiment settings plus one nested config section per stage.

    Round-trips through plain dicts (``to_dict`` / ``from_dict``) and JSON
    files (``save`` / ``load``).
    """
    experiment_name: str = 'egeat_experiment'
    seed: int = 42
    device: Optional[str] = None  # None lets get_device() choose CUDA when available
    save_dir: str = 'results'
    resume: bool = False
    checkpoint_dir: Optional[str] = None
    training: TrainingConfig = field(default_factory=TrainingConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    attack: AttackConfig = field(default_factory=AttackConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    data: DataConfig = field(default_factory=DataConfig)

    def to_dict(self):
        """Return the configuration, nested sections included, as plain dicts."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d):
        """Build a config from a (possibly partial) mapping such as ``to_dict()`` output.

        Args:
            d: Mapping of top-level fields. Each section key (``training``,
                ``model``, ``attack``, ``evaluation``, ``data``) may hold a dict
                of that section's fields, an already-built section instance, or
                be missing / ``None`` to use that section's defaults.

        Returns:
            A new ``ExperimentConfig``. ``d`` itself is left unmodified.

        Raises:
            TypeError: If ``d`` or a section contains an unknown field name.
        """
        # Work on a shallow copy: popping the section keys must not strip them
        # from the caller's dict.
        remaining = dict(d)
        section_types = {
            'training': TrainingConfig,
            'model': ModelConfig,
            'attack': AttackConfig,
            'evaluation': EvaluationConfig,
            'data': DataConfig,
        }
        sections = {}
        for name, section_cls in section_types.items():
            raw = remaining.pop(name, None)
            if isinstance(raw, section_cls):
                sections[name] = raw
            else:
                sections[name] = section_cls(**(raw or {}))
        return cls(**sections, **remaining)

    def save(self, path: str):
        """Write the configuration as indented JSON to ``path``, creating its parent directory."""
        # A bare filename has an empty dirname; fall back to the current directory.
        parent = os.path.dirname(path) or '.'
        os.makedirs(parent, exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str):
        """Read a JSON file written by ``save`` and rebuild the configuration."""
        with open(path, 'r', encoding='utf-8') as f:
            d = json.load(f)
        return cls.from_dict(d)

def get_device(device_str=None, verbose=False):
    """Resolve the torch device to run on.

    Args:
        device_str: Explicit device spec such as ``'cpu'``; ``None`` selects
            CUDA when it is available and the CPU otherwise.
        verbose: When true, print the chosen device.

    Returns:
        The resolved ``torch.device``.
    """
    if device_str is not None:
        chosen = torch.device(device_str)
    elif torch.cuda.is_available():
        chosen = torch.device('cuda')
    else:
        chosen = torch.device('cpu')
    if verbose:
        print(f"Using device: {chosen}")
    return chosen


### 2.2 Model Architectures


In [None]:
# src/models/cnn.py
class SimpleCNN(nn.Module):
    """Three conv -> ReLU -> 2x2 max-pool stages followed by a two-layer classifier.

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of output logits.

    ``fc1`` is sized for a 128-channel 4x4 feature map, which is what 32x32
    inputs produce after three poolings. Other input sizes (e.g. 28x28 MNIST,
    which pools down to 3x3) are resampled to 4x4 before the classifier
    instead of failing with a shape mismatch.
    """

    def __init__(self, input_channels=3, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, 1, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1)
        self.conv3 = nn.Conv2d(64, 128, 3, 1, 1)
        self.fc1 = nn.Linear(128*4*4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = F.max_pool2d(F.relu(self.conv3(x)), 2)
        # Exact no-op when the map is already 4x4 (each output cell averages a
        # single input cell); parameter-free, so state dicts are unchanged.
        x = F.adaptive_avg_pool2d(x, (4, 4))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class DCGAN_CNN(nn.Module):
    """Two conv -> BatchNorm -> LeakyReLU(0.2) -> 2x2 max-pool stages plus a two-layer classifier.

    Args:
        input_channels: Number of channels of the input images.
        num_classes: Number of output logits.

    ``fc1`` is sized for a 128-channel 8x8 feature map, which is what 32x32
    inputs produce after two poolings. Other input sizes (e.g. 28x28 MNIST,
    which pools down to 7x7) are resampled to 8x8 before the classifier
    instead of failing with a shape mismatch.
    """

    def __init__(self, input_channels=3, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, 3, 1, 1)
        self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(128*8*8, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        x = F.max_pool2d(F.leaky_relu(self.bn1(self.conv1(x)), 0.2), 2)
        x = F.max_pool2d(F.leaky_relu(self.bn2(self.conv2(x)), 0.2), 2)
        # Exact no-op when the map is already 8x8; parameter-free, so state
        # dicts are unchanged.
        x = F.adaptive_avg_pool2d(x, (8, 8))
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


### 2.3 Data Loaders


In [None]:
# src/utils/data_loader.py
def seed_everything(seed=123):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed: Integer seed applied to all four generators.
    """
    import random  # local import keeps this helper self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def get_mnist_loaders(batch_size=128, val_split=0.1, num_workers=2, pin_memory=True, seed=123):
    """Build train / validation / test ``DataLoader``s for MNIST.

    The official training set is downloaded to ``./data/MNIST`` and split
    with a generator seeded by ``seed``; ``val_split`` is the fraction held
    out for validation. Images are only converted to tensors (no
    normalisation).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.
    """
    seed_everything(seed)
    to_tensor = transforms.Compose([transforms.ToTensor()])
    mnist_root = './data/MNIST'

    full_train = datasets.MNIST(root=mnist_root, train=True, download=True, transform=to_tensor)
    n_val = int(len(full_train) * val_split)
    n_train = len(full_train) - n_val
    split_rng = torch.Generator().manual_seed(seed)
    train_part, val_part = random_split(full_train, [n_train, n_val], generator=split_rng)
    held_out = datasets.MNIST(root=mnist_root, train=False, download=True, transform=to_tensor)

    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    train_loader = DataLoader(train_part, shuffle=True, **shared)
    val_loader = DataLoader(val_part, shuffle=False, **shared)
    test_loader = DataLoader(held_out, shuffle=False, **shared)
    return train_loader, val_loader, test_loader

def get_cifar10_loaders(batch_size=128, val_split=0.1, num_workers=2, pin_memory=True, augment=True, seed=123, normalize=True):
    """Build train / validation / test ``DataLoader``s for CIFAR-10.

    Args:
        batch_size: Batch size of all three loaders.
        val_split: Fraction of the official training set held out for validation.
        num_workers: DataLoader worker processes.
        pin_memory: DataLoader ``pin_memory`` flag.
        augment: Apply random crop (padding 4) and horizontal flip to the
            training split.
        seed: Seeds the global RNGs and the generator used for the split.
        normalize: Apply per-channel mean/std normalisation (default keeps
            the previous behaviour).

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader
        shuffles.

    NOTE(review): ``pgd_attack`` and ``train_egeat_epoch`` clamp adversarial
    inputs to [0, 1], which only makes sense for un-normalised pixels. Pass
    ``normalize=False`` when these loaders feed those routines.
    """
    seed_everything(seed)
    root = './data/CIFAR10'

    base_steps = [transforms.ToTensor()]
    if normalize:
        base_steps.append(transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)))
    augment_steps = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] if augment else []
    transform_train = transforms.Compose(augment_steps + base_steps)
    transform_eval = transforms.Compose(base_steps)

    train_view = datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    # Second view of the same images without augmentation: random_split only
    # wraps indices, so splitting train_view alone would hand the validation
    # loader randomly cropped / flipped images.
    eval_view = datasets.CIFAR10(root=root, train=True, download=False, transform=transform_eval)

    val_size = int(len(train_view) * val_split)
    train_size = len(train_view) - val_size
    train_dataset, val_part = random_split(
        train_view, [train_size, val_size], generator=torch.Generator().manual_seed(seed)
    )
    val_dataset = torch.utils.data.Subset(eval_view, val_part.indices)
    test_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform_eval)

    shared = dict(batch_size=batch_size, num_workers=num_workers, pin_memory=pin_memory)
    return (
        DataLoader(train_dataset, shuffle=True, **shared),
        DataLoader(val_dataset, shuffle=False, **shared),
        DataLoader(test_dataset, shuffle=False, **shared),
    )


### 2.4 Training Module (EGEAT)


In [None]:
# src/training/train_egeat.py
def exact_delta_star(grad, eps, norm='inf'):
    """Closed-form maximiser of the linearised loss ``<grad, delta>`` over an eps-ball.

    Args:
        grad: Input gradients with the batch on dim 0; any number of trailing dims.
        eps: Radius of the ball.
        norm: ``'inf'``, ``2`` or ``1`` - the norm that defines the ball.

    Returns:
        A tensor shaped like ``grad`` holding one perturbation per sample:
        ``eps * sign(grad)`` for ``'inf'``; ``eps * grad / ||grad||_2`` for
        ``2``; for ``1`` the whole budget placed, with the gradient's sign, on
        the sample's largest-magnitude coordinate.

    Raises:
        ValueError: If ``norm`` is not one of the supported values.
    """
    if norm == 'inf':
        return eps * grad.sign()
    if norm not in (1, 2):
        raise ValueError("Unsupported norm type.")

    flat = grad.reshape(grad.size(0), -1)
    if norm == 2:
        scale = flat.norm(p=2, dim=1, keepdim=True) + 1e-12
        return (eps * flat / scale).reshape(grad.shape)

    # L1 ball: a linear objective is maximised at a vertex, i.e. by spending
    # the entire budget on the coordinate with the largest |gradient|.
    # (Scaling the whole gradient by its max-abs value, as before, yields a
    # vector whose L1 norm exceeds eps.)
    peak = flat.abs().argmax(dim=1, keepdim=True)
    delta = torch.zeros_like(flat)
    delta.scatter_(1, peak, eps * flat.gather(1, peak).sign())
    return delta.reshape(grad.shape)

def gradient_subspace_penalty(model, x, y, ensemble_models):
    """Mean pairwise alignment of input-gradients between ``model`` and its ensemble peers.

    For every pair among ``[model, *ensemble_models]`` the score is the
    batch-mean inner product of the flattened input gradients divided by the
    batch-mean product of their L2 norms; the result is the average over all
    pairs and lies in [-1, 1].

    Args:
        model: The network being trained. Its input gradient is built with
            ``create_graph=True`` so the penalty back-propagates into its
            parameters.
        x: Input batch (batch on dim 0).
        y: Class labels for ``x``.
        ensemble_models: Peers, each either an ``nn.Module`` or a state dict
            compatible with ``model``. Peers are switched to eval mode and
            their gradients are treated as constants.

    Returns:
        ``0.0`` (a float) when there are no peers, otherwise a scalar tensor.
    """
    if not ensemble_models:
        return 0.0

    def _flat_input_grad(net, differentiable):
        probe = x.clone().detach().requires_grad_(True)
        loss = F.cross_entropy(net(probe), y)
        (g,) = torch.autograd.grad(loss, probe, create_graph=differentiable)
        return g.reshape(x.size(0), -1)

    # Without create_graph the gradient is a constant w.r.t. the parameters
    # and the penalty would contribute nothing to the optimisation step.
    grads = [_flat_input_grad(model, True)]
    for peer in ensemble_models:
        if isinstance(peer, dict):
            # State-dict snapshot: materialise it in a copy of `model`.
            state = peer
            peer = deepcopy(model)
            peer.load_state_dict(state)
        peer.eval()
        grads.append(_flat_input_grad(peer, False))

    pair_scores = []
    for i, Gi in enumerate(grads):
        for Gj in grads[i + 1:]:
            num = (Gi * Gj).sum(dim=1).mean()
            den = (Gi.norm(p=2, dim=1) * Gj.norm(p=2, dim=1)).mean() + 1e-12
            pair_scores.append(num / den)
    return sum(pair_scores) / len(pair_scores)

def compute_theta_soup(ensemble_snapshots):
    """Average a list of model snapshots entry by entry (a parameter "soup").

    Args:
        ensemble_snapshots: ``None`` or a sequence whose items are state dicts
            or objects exposing ``state_dict()`` (e.g. ``nn.Module``).

    Returns:
        ``None`` for ``None`` / an empty sequence; otherwise a deep copy of the
        first snapshot's state dict in which every floating-point entry is
        replaced by the mean over all snapshots. Integer tensors (such as
        BatchNorm's ``num_batches_tracked``) cannot be averaged in their own
        dtype and keep the first snapshot's value. Inputs are not modified.
    """
    if ensemble_snapshots is None or len(ensemble_snapshots) == 0:
        return None

    # The training loop hands the same list to gradient_subspace_penalty,
    # which needs modules, so accept both modules and raw state dicts here.
    states = [s.state_dict() if hasattr(s, 'state_dict') else s for s in ensemble_snapshots]
    soup = deepcopy(states[0])
    with torch.no_grad():
        for key, first in states[0].items():
            if torch.is_tensor(first) and not (first.is_floating_point() or first.is_complex()):
                continue
            total = soup[key]
            for other in states[1:]:
                total = total + other[key]
            soup[key] = total / len(states)
    return soup

def _pair_params_with_soup(model, theta_soup, device):
    """Pair every parameter of ``model`` with its soup counterpart, moved to ``device`` once."""
    if theta_soup is None:
        return []
    named = list(model.named_parameters())
    if all(name in theta_soup for name, _ in named):
        # Match by name: a state dict also lists buffers (e.g. BatchNorm
        # running statistics), so pairing by position would misalign entries.
        targets = [theta_soup[name] for name, _ in named]
    else:
        # Soup with foreign keys: fall back to pairing by position.
        targets = list(theta_soup.values())
    return [(param, target.to(device)) for (_, param), target in zip(named, targets)]


def train_egeat_epoch(model, dataloader, optimizer, device='cuda', lambda_geom=0.1, lambda_soup=0.05, epsilon=8/255, p_norm='inf', ensemble_snapshots=None, use_mixed_precision=False):
    """Run one EGEAT training epoch over ``dataloader``.

    Per batch the objective is
    ``L_adv + lambda_geom * L_geom + lambda_soup * L_soup`` where

    * ``L_adv`` is the cross-entropy on ``clamp(x + delta*, 0, 1)`` with
      ``delta*`` from ``exact_delta_star`` applied to the clean input gradient,
    * ``L_geom`` is ``gradient_subspace_penalty`` against ``ensemble_snapshots``,
    * ``L_soup`` is the squared distance between the parameters and the
      snapshot average from ``compute_theta_soup``.

    Args:
        model: Network to train (switched to train mode).
        dataloader: Iterable of ``(x, y)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device spec (string or ``torch.device``).
        lambda_geom: Weight of the geometric penalty.
        lambda_soup: Weight of the soup penalty.
        epsilon: Radius passed to ``exact_delta_star``.
        p_norm: Norm passed to ``exact_delta_star``.
        ensemble_snapshots: Peers for both penalties, or ``None``.
        use_mixed_precision: Use CUDA autocast + ``GradScaler``; ignored on
            non-CUDA devices.

    Returns:
        Dict with per-batch means under ``'loss'``, ``'adv_loss'``,
        ``'geom_loss'`` and ``'soup_loss'`` (all ``0.0`` for an empty loader).

    NOTE(review): the [0, 1] clamp assumes un-normalised pixel inputs.
    """
    from contextlib import nullcontext

    # Accept plain strings such as the default 'cuda'; `.type` below needs a torch.device.
    device = torch.device(device)
    model.train()
    soup_pairs = _pair_params_with_soup(model, compute_theta_soup(ensemble_snapshots), device)

    use_amp = bool(use_mixed_precision) and device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    total_loss, total_adv, total_geom, total_soup = 0.0, 0.0, 0.0, 0.0
    n_batches = 0

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)

        with (torch.cuda.amp.autocast() if use_amp else nullcontext()):
            # Clean input gradient -> closed-form perturbation.
            x_probe = x.detach().requires_grad_(True)
            loss_clean = F.cross_entropy(model(x_probe), y)
            grad_x = torch.autograd.grad(loss_clean, x_probe, create_graph=False)[0]
            delta_star = exact_delta_star(grad_x, epsilon, norm=p_norm)
            x_adv = torch.clamp(x.detach() + delta_star, 0.0, 1.0)

            L_adv = F.cross_entropy(model(x_adv), y)
            L_geom = gradient_subspace_penalty(model, x.detach(), y, ensemble_snapshots)
            L_soup = 0.0
            for param, target in soup_pairs:
                L_soup = L_soup + ((param - target) ** 2).sum()
            L_total = L_adv + lambda_geom * L_geom + lambda_soup * L_soup

        optimizer.zero_grad()
        if use_amp:
            scaler.scale(L_total).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            L_total.backward()
            optimizer.step()

        total_loss += L_total.item()
        total_adv += L_adv.item()
        # Both penalties are plain floats when there are no snapshots.
        total_geom += float(L_geom)
        total_soup += float(L_soup)
        n_batches += 1

    denom = max(n_batches, 1)  # avoid ZeroDivisionError on an empty loader
    return {
        'loss': total_loss / denom,
        'adv_loss': total_adv / denom,
        'geom_loss': total_geom / denom,
        'soup_loss': total_soup / denom
    }


### 2.5 Attack Modules


In [None]:
# src/attacks/pgd.py
def pgd_attack(
    model, x, y, epsilon=8/255, alpha=2/255, iters=10, device=None, norm='inf', random_start=True, targeted=False,
    clip_min=0.0, clip_max=1.0
):
    """
    Generate adversarial examples using PGD (Projected Gradient Descent).

    Args:
        model: Classifier under attack; it is switched to eval mode.
        x: Clean inputs, batch on dim 0 (any number of trailing dims).
        y: Labels (the target labels when ``targeted`` is true).
        epsilon: Radius of the perturbation ball around ``x``.
        alpha: Step size per iteration.
        iters: Number of gradient steps.
        device: Device to run on; defaults to ``x.device``.
        norm: ``'inf'`` for an L-infinity ball; any other value selects L2.
        random_start: Start from a random point of the ball instead of ``x``.
        targeted: Descend the loss towards ``y`` instead of ascending it.
        clip_min: Lower bound of the valid input range.
        clip_max: Upper bound of the valid input range.

    Returns:
        Detached adversarial inputs with the same shape as ``x``.

    The gradient is taken with ``torch.autograd.grad`` with respect to the
    input only, so the ``.grad`` fields of the model's parameters are neither
    cleared nor written by the attack.
    """
    device = x.device if device is None else torch.device(device)

    model.eval()
    x_orig = x.clone().detach().to(device)
    y = y.to(device)

    # Shape that broadcasts one scalar per sample over all trailing dims.
    per_sample = (-1,) + (1,) * (x_orig.dim() - 1)

    def _l2_norms(t):
        return t.reshape(t.size(0), -1).norm(p=2, dim=1).reshape(per_sample)

    # Initialize adversarial example
    if random_start:
        if norm == 'inf':
            x_adv = x_orig + torch.empty_like(x_orig).uniform_(-epsilon, epsilon)
        else:  # L2 norm: random direction scaled onto the sphere of radius epsilon
            noise = torch.randn_like(x_orig)
            x_adv = x_orig + noise / (_l2_norms(noise) + 1e-10) * epsilon
        x_adv = torch.clamp(x_adv, clip_min, clip_max)
    else:
        x_adv = x_orig.clone()

    # PGD iterations
    for _ in range(iters):
        x_adv = x_adv.detach().requires_grad_(True)
        loss = torch.nn.functional.cross_entropy(model(x_adv), y)
        if targeted:
            loss = -loss  # Minimize loss for targeted attack
        (grad,) = torch.autograd.grad(loss, x_adv)
        x_adv = x_adv.detach()

        if norm == 'inf':
            # Signed-gradient step, then project onto the epsilon-box around x.
            x_adv = x_adv + alpha * grad.sign()
            x_adv = torch.min(torch.max(x_adv, x_orig - epsilon), x_orig + epsilon)
        else:  # L2 norm
            # Normalised-gradient step, then shrink onto the epsilon-ball.
            x_adv = x_adv + alpha * grad / (_l2_norms(grad) + 1e-10)
            delta = x_adv - x_orig
            shrink = torch.clamp(epsilon / (_l2_norms(delta) + 1e-10), max=1.0)
            x_adv = x_orig + delta * shrink

        # Clip to valid input range
        x_adv = torch.clamp(x_adv, clip_min, clip_max)

    return x_adv.detach()


### 2.6 Evaluation Modules


In [None]:
# src/evaluation/metrics.py
def accuracy(model, dataloader, device=None):
    """Return the top-1 accuracy of ``model`` over ``dataloader``.

    Args:
        model: Classifier; switched to eval mode.
        dataloader: Iterable of ``(x, y)`` batches.
        device: Device to evaluate on; defaults to the device of the model's
            first parameter.

    Returns:
        Fraction of correctly classified samples, or ``0.0`` if no sample was seen.
    """
    device = next(model.parameters()).device if device is None else torch.device(device)

    model.eval()
    n_right = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            guesses = model(inputs).argmax(dim=1)
            n_right += (guesses == targets).sum().item()
            n_seen += targets.size(0)

    if n_seen > 0:
        return n_right / n_seen
    return 0.0

def ece(model, dataloader, device=None, n_bins=15):
    """Expected calibration error of ``model`` over ``dataloader``.

    Samples are grouped by their top softmax probability into ``n_bins``
    equal-width bins ``(lo, hi]`` on [0, 1]; the score is the sum over
    non-empty bins of ``(bin fraction) * |bin accuracy - bin mean confidence|``.

    Args:
        model: Classifier; switched to eval mode.
        dataloader: Iterable of ``(x, y)`` batches.
        device: Device to evaluate on; defaults to the device of the model's
            first parameter.
        n_bins: Number of confidence bins.

    Returns:
        The ECE as a float; ``0.0`` when the loader yields no samples
        (mirroring ``accuracy``).
    """
    device = next(model.parameters()).device if device is None else torch.device(device)

    model.eval()
    conf_chunks, hit_chunks = [], []
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            conf, pred = torch.softmax(model(x), dim=1).max(dim=1)
            conf_chunks.append(conf.cpu())
            hit_chunks.append((pred == y).float().cpu())

    # torch.cat rejects an empty list; an empty evaluation has no calibration gap.
    if not conf_chunks:
        return 0.0
    confidences = torch.cat(conf_chunks)
    hits = torch.cat(hit_chunks)
    if confidences.numel() == 0:
        return 0.0

    edges = torch.linspace(0, 1, n_bins + 1)
    ece_score = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = (hits[in_bin].mean() - confidences[in_bin].mean()).abs().item()
            ece_score += in_bin.float().mean().item() * gap

    return ece_score

# src/evaluation/gradient_similarity.py
def compute_input_gradients(model, dataloader, device=None, max_batches=8):
    """Collect flattened cross-entropy input gradients over the first batches of a loader.

    Args:
        model: Classifier; switched to eval mode.
        dataloader: Iterable of ``(x, y)`` batches.
        device: Device to run on; defaults to the device of the model's first
            parameter.
        max_batches: Maximum number of batches to process.

    Returns:
        NumPy array of shape ``(n_samples, n_input_features)`` with one row
        per sample, or an array of shape ``(0, 0)`` if no batch was processed.

    Gradients are taken with respect to the inputs only
    (``torch.autograd.grad``), so the ``.grad`` fields of the model's
    parameters are neither cleared nor written.
    """
    device = next(model.parameters()).device if device is None else torch.device(device)

    model.eval()
    per_batch = []
    for i, (x, y) in enumerate(dataloader):
        if i >= max_batches:
            break
        x = x.to(device, non_blocking=True).detach().requires_grad_(True)
        y = y.to(device, non_blocking=True)
        loss = torch.nn.functional.cross_entropy(model(x), y)
        (g,) = torch.autograd.grad(loss, x)
        per_batch.append(g.reshape(x.size(0), -1).cpu().numpy())

    if not per_batch:
        return np.zeros((0, 0))
    return np.concatenate(per_batch, axis=0)

def gradient_subspace_similarity(models, dataloader, device=None, max_batches=8):
    """Cosine-similarity matrix between the models' mean input gradients.

    Each model is summarised by the mean, over all collected samples, of its
    flattened input gradient from ``compute_input_gradients`` (a length-1 zero
    vector if nothing was collected).

    Args:
        models: Sequence of classifiers.
        dataloader: Iterable of ``(x, y)`` batches.
        device: Device to run on; defaults to the device of the first model's
            first parameter.
        max_batches: Maximum number of batches per model.

    Returns:
        ``(len(models), len(models))`` NumPy array of pairwise cosine similarities.
    """
    device = next(models[0].parameters()).device if device is None else torch.device(device)

    mean_grads = []
    for net in models:
        per_sample = compute_input_gradients(net, dataloader, device=device, max_batches=max_batches)
        if per_sample.size > 0:
            mean_grads.append(per_sample.mean(axis=0))
        else:
            mean_grads.append(np.zeros(1))

    # The small offset keeps the division finite for all-zero gradients.
    lengths = [np.linalg.norm(vec) + 1e-12 for vec in mean_grads]

    similarity = np.zeros((len(mean_grads), len(mean_grads)))
    for row, (vec_a, len_a) in enumerate(zip(mean_grads, lengths)):
        for col, (vec_b, len_b) in enumerate(zip(mean_grads, lengths)):
            similarity[row, col] = np.dot(vec_a, vec_b) / (len_a * len_b)

    return similarity

# src/evaluation/ensemble_variance.py
def ensemble_variance(models, dataloader, device=None, max_batches=6):
    """Variance of per-sample losses across the first ``k`` models, for every ``k``.

    Args:
        models: Sequence of classifiers; each is switched to eval mode.
        dataloader: Iterable of ``(X, y)`` batches.
        device: Device to run on; defaults to the device of the first model's
            first parameter.
        max_batches: Maximum number of batches to evaluate.

    Returns:
        ``(Kvals, variances)`` where ``Kvals`` is ``[1, ..., len(models)]`` and
        ``variances[k-1]`` is the mean over batches of the batch-mean variance
        (across ``models[:k]``) of the per-sample cross-entropy.
        Entries are NaN when the loader yields no batch.

    Every model is evaluated once per batch and the per-``k`` statistics are
    derived from the stored losses, instead of re-running ``models[:k]`` for
    each ``k``; all ``k`` therefore see the same batches even when the loader
    shuffles.
    """
    device = next(models[0].parameters()).device if device is None else torch.device(device)

    Kvals = list(range(1, len(models) + 1))
    if not models:
        return Kvals, []

    for m in models:
        m.eval()

    # One (n_models, batch_size) array of per-sample losses per batch.
    losses_per_batch = []
    with torch.no_grad():
        for i, (X, y) in enumerate(dataloader):
            if i >= max_batches:
                break
            X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
            losses_per_batch.append(np.stack(
                [torch.nn.functional.cross_entropy(m(X), y, reduction='none').cpu().numpy() for m in models],
                axis=0,
            ))

    variances = [
        np.mean([np.mean(np.var(batch_losses[:k], axis=0)) for batch_losses in losses_per_batch])
        for k in Kvals
    ]
    return Kvals, variances

# src/evaluation/loss_landscape.py
def param_vector(model):
    """Return a detached copy of all parameters of ``model`` flattened into one 1-D tensor."""
    flat = torch.nn.utils.parameters_to_vector(model.parameters())
    snapshot = flat.detach()
    return snapshot.clone()

def set_param_vector(model, vec):
    """Overwrite the parameters of ``model`` with consecutive slices of the flat tensor ``vec``."""
    targets = model.parameters()
    torch.nn.utils.vector_to_parameters(vec, targets)

def loss_on_vector(model, vec, X, y, device='cpu'):
    """Load the flat parameter vector ``vec`` into ``model`` and return its cross-entropy on ``(X, y)``.

    The model is left in eval mode with ``vec`` installed; the loss is
    computed without gradient tracking and returned as a Python float.
    """
    set_param_vector(model, vec)
    model.eval()
    with torch.no_grad():
        inputs = X.to(device)
        targets = y.to(device)
        value = torch.nn.functional.cross_entropy(model(inputs), targets)
    return value.item()

def scan_1d_loss(model, dataloader, device='cpu', grid_n=21, radius=1.0):
    """Evaluate the loss along one random unit direction in parameter space.

    Args:
        model: Network to probe; its parameters are restored to their starting
            values before returning, even if an evaluation raises.
        dataloader: Iterable of ``(X, y)`` batches; the first three are
            concatenated into the evaluation set.
        device: Device for the evaluation data and the parameter vectors.
        grid_n: Number of grid points.
        radius: The grid spans ``[-radius, radius]`` along the direction.

    Returns:
        ``(alphas, losses)``: the grid offsets and the cross-entropy at
        ``theta + alpha * direction`` for each of them.
    """
    Xs, ys = [], []
    for i, (X, y) in enumerate(dataloader):
        Xs.append(X)
        ys.append(y)
        if i >= 2:
            break
    # Move the evaluation set once, not once per grid point.
    X = torch.cat(Xs, dim=0).to(device)
    y = torch.cat(ys, dim=0).to(device)

    base = param_vector(model).to(device)
    direction = torch.randn(base.numel(), device=device)
    direction /= direction.norm()

    alphas = np.linspace(-radius, radius, grid_n)
    losses = np.zeros(grid_n)
    try:
        for i, a in enumerate(alphas):
            losses[i] = loss_on_vector(model, base + float(a) * direction, X, y, device=device)
    finally:
        # Never leave the model at a perturbed point.
        set_param_vector(model, base)
    return alphas, losses

def scan_2d_loss(model, dataloader, device='cpu', grid_n=21, radius=1.0):
    """Evaluate the loss on a 2-D grid spanned by two random orthonormal directions.

    Args:
        model: Network to probe; its parameters are restored to their starting
            values before returning, even if an evaluation raises.
        dataloader: Iterable of ``(X, y)`` batches; the first three are
            concatenated into the evaluation set.
        device: Device for the evaluation data and the parameter vectors.
        grid_n: Number of grid points per axis.
        radius: Each axis spans ``[-radius, radius]``.

    Returns:
        ``(alphas, betas, losses)`` where ``losses[i, j]`` is the cross-entropy
        at ``theta + alphas[i] * dir_a + betas[j] * dir_b``.
    """
    Xs, ys = [], []
    for i, (X, y) in enumerate(dataloader):
        Xs.append(X)
        ys.append(y)
        if i >= 2:
            break
    # Move the evaluation set once, not once per grid point.
    X = torch.cat(Xs, dim=0).to(device)
    y = torch.cat(ys, dim=0).to(device)

    base = param_vector(model).to(device)
    D = base.numel()
    dir_a = torch.randn(D, device=device)
    dir_a /= dir_a.norm()
    # Gram-Schmidt step: remove the dir_a component, then normalise.
    dir_b = torch.randn(D, device=device)
    dir_b -= (dir_b.dot(dir_a)) * dir_a
    dir_b /= dir_b.norm()

    alphas = np.linspace(-radius, radius, grid_n)
    betas = np.linspace(-radius, radius, grid_n)
    losses = np.zeros((grid_n, grid_n))
    try:
        for i, a in enumerate(alphas):
            for j, b in enumerate(betas):
                vec = base + float(a) * dir_a + float(b) * dir_b
                losses[i, j] = loss_on_vector(model, vec, X, y, device=device)
    finally:
        # Never leave the model at a perturbed point.
        set_param_vector(model, base)
    return alphas, betas, losses

# src/evaluation/transferability.py
def transferability_matrix(models, dataloader, device=None, max_batches=10, epsilon=8/255, alpha=2/255, iters=10):
    """Measure how often PGD examples crafted on one model are misclassified by each model.

    Args:
        models: Sequence of classifiers (each acts as source and as target).
        dataloader: Iterable of ``(X, y)`` batches.
        device: Device to run on; defaults to the device of the first model's
            first parameter.
        max_batches: Number of batches attacked per source model.
        epsilon: PGD perturbation budget.
        alpha: PGD step size.
        iters: PGD iterations.

    Returns:
        ``(L, L)`` NumPy array; entry ``[i, j]`` is the fraction of adversarial
        examples generated on model ``i`` that model ``j`` classifies incorrectly.
    """
    device = next(models[0].parameters()).device if device is None else torch.device(device)

    L = len(models)
    counts = np.zeros((L, L))
    totals = np.zeros((L, L))

    for i, src_model in enumerate(models):
        src_model.eval()
        print(f"Generating adversarial examples from model {i+1}/{L}...")

        for b, (X, y) in enumerate(dataloader):
            if b >= max_batches:
                break

            X = X.to(device)
            y = y.to(device)

            # Craft the batch on the source model ...
            X_adv = pgd_attack(src_model, X, y, epsilon=epsilon, alpha=alpha, iters=iters, device=device)

            # ... and score it on every model, the source included.
            batch_size = y.size(0)
            for j, tgt_model in enumerate(models):
                tgt_model.eval()
                with torch.no_grad():
                    fooled = tgt_model(X_adv).argmax(dim=1) != y
                counts[i, j] += fooled.sum().item()
                totals[i, j] += batch_size

    # The offset keeps the division finite when no batch was processed.
    return counts / (totals + 1e-12)

# src/evaluation/adversarial_images.py
def generate_adv_examples(src_model, tgt_model, dataloader, device=None, n_samples=24, epsilon=8/255, alpha=2/255, iters=10):
    """Collect ``(clean, adversarial)`` image pairs crafted with PGD on ``src_model``.

    Args:
        src_model: Model the attack is run against.
        tgt_model: Second model; it is run on each adversarial batch but its
            output is discarded.
        dataloader: Iterable of ``(X, y)`` batches.
        device: Device to run on; defaults to the device of ``src_model``'s
            first parameter.
        n_samples: Collection stops once this many pairs were gathered.
        epsilon: PGD perturbation budget.
        alpha: PGD step size.
        iters: PGD iterations.

    Returns:
        List of ``(clean, adversarial)`` tuples of NumPy arrays.
    """
    device = next(src_model.parameters()).device if device is None else torch.device(device)

    src_model.eval()
    tgt_model.eval()
    pairs = []

    for X, y in dataloader:
        X = X.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        X_adv_src = pgd_attack(src_model, X, y, epsilon=epsilon, alpha=alpha, iters=iters, device=device)

        # Forward pass on the target model only; the prediction is not used.
        with torch.no_grad():
            tgt_model(X_adv_src)

        for clean, perturbed in zip(X.cpu().numpy(), X_adv_src.cpu().numpy()):
            pairs.append((clean, perturbed))
            if len(pairs) >= n_samples:
                return pairs
        if len(pairs) >= n_samples:
            return pairs

    return pairs


In [None]:
# src/utils/viz_utils.py
def ensure_dir(path):
    """Create the parent directory of ``path`` if it does not exist yet.

    A bare filename has no parent component (it lives in the current
    directory), in which case nothing is created.
    """
    parent = os.path.dirname(path)
    if parent:  # os.makedirs('') raises FileNotFoundError
        os.makedirs(parent, exist_ok=True)

def save_heatmap(mat, path, title="", cmap="viridis"):
    """Draw ``mat`` as an annotated square-cell heatmap and save it to ``path`` at 300 dpi."""
    ensure_dir(path)
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(mat, cmap=cmap, square=True, annot=True, fmt='.3f', ax=ax)
    ax.set_title(title, fontsize=14)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def save_contour(alphas, betas, losses, path, title="Loss surface"):
    """Save a filled contour plot of ``losses`` (rows follow ``alphas``, columns ``betas``) to ``path``."""
    ensure_dir(path)
    fig, ax = plt.subplots(figsize=(8, 6))
    filled = ax.contourf(betas, alphas, losses, levels=50, cmap="viridis")
    fig.colorbar(filled, ax=ax)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Beta Direction", fontsize=12)
    ax.set_ylabel("Alpha Direction", fontsize=12)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def save_image_grid(imgs, path, nrow=8, normalize=True, title=None):
    """Save a grid of images to ``path`` at 300 dpi.

    Args:
        imgs: Sequence of images, each either 2-D or channel-first with 1 or 3
            channels.
        path: Output file; its parent directory is created if needed.
        nrow: Number of columns of the grid.
        normalize: Apply ``(img + 1) / 2`` before display, i.e. map [-1, 1]
            onto [0, 1]. NOTE(review): confirm the value range the caller passes.
        title: Optional figure title.

    Values are clipped to [0, 1] for display. An empty ``imgs`` produces one
    row of blank cells.
    """
    ensure_dir(path)
    N = len(imgs)
    cols = nrow
    rows = max((N + cols - 1) // cols, 1)  # plt.subplots rejects zero rows
    # squeeze=False always returns a 2-D array of axes; the default returns a
    # bare Axes for a 1x1 grid, which has no reshape()/flatten().
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5), squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        ax.axis('off')
        if i >= N:
            continue
        img = imgs[i]
        if normalize:
            img = (img + 1) / 2
        if len(img.shape) == 3 and img.shape[0] == 3:
            img = np.transpose(img, (1, 2, 0))  # channel-first -> channel-last for imshow
        elif len(img.shape) == 3 and img.shape[0] == 1:
            img = img[0]
        ax.imshow(np.clip(img, 0, 1), cmap='gray' if len(img.shape) == 2 else None)
    if title:
        fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def save_line(xs, ys, path, xlabel="", ylabel="", title=""):
    """Save a line plot of ``ys`` against ``xs`` with circle markers and a light grid to ``path``."""
    ensure_dir(path)
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(xs, ys, marker='o', linewidth=2, markersize=6)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)

# Display functions for Colab
def display_heatmap(mat, title="", cmap="viridis"):
    """Show ``mat`` inline as an annotated square-cell heatmap."""
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(mat, cmap=cmap, square=True, annot=True, fmt='.3f', ax=ax)
    ax.set_title(title, fontsize=14)
    fig.tight_layout()
    plt.show()

def display_line(xs, ys, xlabel="", ylabel="", title=""):
    """Show a line plot of ``ys`` against ``xs`` with circle markers and a light grid."""
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(xs, ys, marker='o', linewidth=2, markersize=6)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

def display_contour(alphas, betas, losses, title="Loss surface"):
    """Show a filled contour plot of ``losses`` (rows follow ``alphas``, columns ``betas``)."""
    fig, ax = plt.subplots(figsize=(8, 6))
    filled = ax.contourf(betas, alphas, losses, levels=50, cmap="viridis")
    fig.colorbar(filled, ax=ax)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Beta Direction", fontsize=12)
    ax.set_ylabel("Alpha Direction", fontsize=12)
    fig.tight_layout()
    plt.show()

def display_image_grid(imgs, nrow=8, normalize=True, title=None):
    """Show a grid of images inline.

    Args:
        imgs: Sequence of images, each either 2-D or channel-first with 1 or 3
            channels.
        nrow: Number of columns of the grid.
        normalize: Apply ``(img + 1) / 2`` before display, i.e. map [-1, 1]
            onto [0, 1]. NOTE(review): confirm the value range the caller passes.
        title: Optional figure title.

    Values are clipped to [0, 1] for display. An empty ``imgs`` produces one
    row of blank cells.
    """
    N = len(imgs)
    cols = nrow
    rows = max((N + cols - 1) // cols, 1)  # plt.subplots rejects zero rows
    # squeeze=False always returns a 2-D array of axes; the default returns a
    # bare Axes for a 1x1 grid, which has no reshape()/flatten().
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5), squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        ax.axis('off')
        if i >= N:
            continue
        img = imgs[i]
        if normalize:
            img = (img + 1) / 2
        if len(img.shape) == 3 and img.shape[0] == 3:
            img = np.transpose(img, (1, 2, 0))  # channel-first -> channel-last for imshow
        elif len(img.shape) == 3 and img.shape[0] == 1:
            img = img[0]
        ax.imshow(np.clip(img, 0, 1), cmap='gray' if len(img.shape) == 2 else None)
    if title:
        fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    plt.show()


## 3. Run Complete Experiment

This section runs the full EGEAT experiment pipeline.


In [None]:
# Experiment settings for the Colab demo (fewer epochs and ensemble members
# than the dataclass defaults).
config = ExperimentConfig(
    experiment_name='egeat_colab',
    seed=42,
    device=None,  # Auto-detect
    save_dir='results',
    training=TrainingConfig(
        batch_size=128,
        learning_rate=2e-4,
        epochs=5,  # Reduced for Colab demo
        lambda_geom=0.1,
        lambda_soup=0.05,
        epsilon=8/255,
        use_mixed_precision=False,
        num_workers=2
    ),
    model=ModelConfig(
        model_type='SimpleCNN',
        input_channels=3,
        num_classes=10,
        ensemble_size=3  # Reduced for Colab demo
    ),
    data=DataConfig(
        dataset='cifar10',
        val_split=0.1,
        augment=True
    ),
    evaluation=EvaluationConfig(
        max_batches=10,
        n_bins_ece=15,
        loss_landscape_grid_n=21,
        loss_landscape_radius=1.0
    )
)

# Output directories: the results root plus one folder per artifact type.
os.makedirs(config.save_dir, exist_ok=True)
for _artifact_dir in ('figures', 'adv_images', 'checkpoints'):
    os.makedirs(os.path.join(config.save_dir, _artifact_dir), exist_ok=True)

# Seed every RNG in use and force deterministic cuDNN kernels.
import random
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Resolve the compute device (prints the choice).
device = get_device(config.device, verbose=True)

# Summary of the settings that matter for this run.
print(f"\nExperiment Configuration:")
print(f"  Dataset: {config.data.dataset}")
print(f"  Model: {config.model.model_type}")
print(f"  Ensemble Size: {config.model.ensemble_size}")
print(f"  Epochs: {config.training.epochs}")
print(f"  Batch Size: {config.training.batch_size}")
print(f"  Lambda Geom: {config.training.lambda_geom}")
print(f"  Lambda Soup: {config.training.lambda_soup}")


## 4. Training Ensemble Models

This section loads the dataset, initializes the ensemble models, and then trains each model using EGEAT.


In [None]:
# Build the train/val/test loaders for the configured dataset and align the
# model's input channels and class count with that dataset.
print(f"\nLoading {config.data.dataset} dataset...")

dataset_key = config.data.dataset.lower()
if dataset_key not in ('cifar10', 'mnist'):
    raise ValueError(f"Unsupported dataset: {config.data.dataset}")

# Keyword arguments accepted by both loader factories.
shared_loader_args = dict(
    batch_size=config.training.batch_size,
    val_split=config.data.val_split,
    num_workers=config.training.num_workers,
    pin_memory=config.training.pin_memory,
    seed=config.seed
)

if dataset_key == 'cifar10':
    # Only the CIFAR-10 factory is given the augmentation flag.
    loaders = get_cifar10_loaders(augment=config.data.augment, **shared_loader_args)
    n_input_channels = 3
else:
    loaders = get_mnist_loaders(**shared_loader_args)
    n_input_channels = 1

train_loader, val_loader, test_loader = loaders
config.model.input_channels = n_input_channels
config.model.num_classes = 10

print(f"Dataset loaded. Train batches: {len(train_loader)}, Val batches: {len(val_loader)}, Test batches: {len(test_loader)}")


In [None]:
# Build the ensemble: `ensemble_size` separately constructed networks, each
# moved onto the selected device.
print(f"\nInitializing {config.model.ensemble_size} {config.model.model_type} models...")

# Any model_type other than 'SimpleCNN' selects DCGAN_CNN.
if config.model.model_type == 'SimpleCNN':
    model_class = SimpleCNN
else:
    model_class = DCGAN_CNN

models = [
    model_class(
        input_channels=config.model.input_channels,
        num_classes=config.model.num_classes
    ).to(device)
    for _ in range(config.model.ensemble_size)
]

print(f"Initialized {len(models)} models on {device}")


### 4.1 Training Loop

The cell below trains each model in the ensemble sequentially using EGEAT.


In [None]:
# Train ensemble models
#
# Each model is trained in turn with its own Adam optimizer. From the second
# model on (and only when lambda_geom > 0), eval-mode copies of the models
# trained so far are rebuilt from their final checkpoints and handed to
# train_egeat_epoch as `ensemble_snapshots`. Per-epoch and final weights are
# written under <save_dir>/checkpoints, and the loss curves of every model are
# plotted as soon as that model finishes.
print("\n" + "="*60)
print("Starting Ensemble Training")
print("="*60)

# 'models' collects one per-epoch history dict per trained model; 'metrics' is
# created here but not populated by this cell.
training_history = {
    'models': [],
    'metrics': []
}

# Resolved once, before any training, so the final save after the epoch loop
# has a valid target even when config.training.epochs == 0.
checkpoint_dir = os.path.join(config.save_dir, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

for i, model in enumerate(models):
    print(f"\n{'='*60}")
    print(f"Training Model {i+1}/{len(models)}")
    print(f"{'='*60}")

    optimizer = torch.optim.Adam(model.parameters(), lr=config.training.learning_rate)

    model_history = {
        'epoch': [],
        'loss': [],
        'adv_loss': [],
        'geom_loss': [],
        'soup_loss': []
    }

    # Collect snapshots from previously trained models for ensemble regularization.
    # None (rather than an empty list) is passed for the first model or when
    # geometric regularization is disabled.
    ensemble_snapshots = None
    if i > 0 and config.training.lambda_geom > 0:
        ensemble_snapshots = []
        for j in range(i):
            snapshot_path = os.path.join(checkpoint_dir, f'model_{j}_final.pt')
            if not os.path.exists(snapshot_path):
                # Make the omission visible instead of silently training
                # against a smaller set of snapshots.
                print(f"  Warning: missing checkpoint {snapshot_path}; "
                      f"model {j+1} is left out of the ensemble snapshots")
                continue
            # Create a new model instance and load the state_dict
            snapshot_model = model_class(
                input_channels=config.model.input_channels,
                num_classes=config.model.num_classes
            )
            snapshot_model.load_state_dict(torch.load(snapshot_path, map_location=device))
            snapshot_model = snapshot_model.to(device)
            snapshot_model.eval()  # Set to eval mode
            ensemble_snapshots.append(snapshot_model)

    for epoch in range(config.training.epochs):
        print(f"\nEpoch {epoch+1}/{config.training.epochs}")

        # Train one epoch; the returned dict is read for the keys
        # 'loss', 'adv_loss', 'geom_loss' and 'soup_loss'.
        metrics = train_egeat_epoch(
            model=model,
            dataloader=train_loader,
            optimizer=optimizer,
            device=device,
            lambda_geom=config.training.lambda_geom,
            lambda_soup=config.training.lambda_soup,
            epsilon=config.training.epsilon,
            ensemble_snapshots=ensemble_snapshots,
            use_mixed_precision=config.training.use_mixed_precision
        )

        model_history['epoch'].append(epoch + 1)
        model_history['loss'].append(metrics['loss'])
        model_history['adv_loss'].append(metrics['adv_loss'])
        model_history['geom_loss'].append(metrics['geom_loss'])
        model_history['soup_loss'].append(metrics['soup_loss'])

        print(f"  Loss: {metrics['loss']:.4f}, "
              f"Adv: {metrics['adv_loss']:.4f}, "
              f"Geom: {metrics['geom_loss']:.4f}, "
              f"Soup: {metrics['soup_loss']:.4f}")

        # Save a checkpoint after every epoch
        torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'model_{i}_epoch_{epoch+1}.pt'))

    # Save final model; later models reload this file as their snapshot
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, f'model_{i}_final.pt'))

    training_history['models'].append(model_history)

    # Plot training curves for this model:
    # left panel = total and adversarial loss, right panel = regularizers.
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(model_history['epoch'], model_history['loss'], 'o-', label='Total Loss', linewidth=2)
    plt.plot(model_history['epoch'], model_history['adv_loss'], 's-', label='Adversarial Loss', linewidth=2)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title(f'Model {i+1} - Training Losses', fontsize=14)
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 2, 2)
    plt.plot(model_history['epoch'], model_history['geom_loss'], '^-', label='Geometric Loss', linewidth=2)
    plt.plot(model_history['epoch'], model_history['soup_loss'], 'v-', label='Soup Loss', linewidth=2)
    plt.xlabel('Epoch', fontsize=12)
    plt.ylabel('Loss', fontsize=12)
    plt.title(f'Model {i+1} - Regularization Losses', fontsize=14)
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

print("\n" + "="*60)
print("Training Complete!")
print("="*60)


## 5. Evaluation and Metrics

This section computes and displays all evaluation metrics and visualizations.


In [None]:
# Evaluate ensemble models
#
# Computes test-set accuracy and ECE for every ensemble member, stores both
# lists in `results`, and shows them as side-by-side bar charts.
print("\n" + "="*60)
print("Evaluating Ensemble Models")
print("="*60)

# Filled in by this cell and the evaluation cells that follow.
results = {}

# Individual model accuracies and ECE scores
print("\n" + "-"*60)
print("Computing Model Accuracies and ECE Scores...")
print("-"*60)

accuracies = []
ece_scores = []
for i, model in enumerate(models):
    acc = accuracy(model, test_loader, device=device)
    ece_score = ece(model, test_loader, device=device, n_bins=config.evaluation.n_bins_ece)
    accuracies.append(acc)
    ece_scores.append(ece_score)
    print(f"Model {i+1}: Accuracy={acc:.4f}, ECE={ece_score:.4f}")

results['accuracies'] = accuracies
results['ece_scores'] = ece_scores

# Display accuracy and ECE bar charts
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

axes[0].bar(range(1, len(accuracies)+1), accuracies, color='steelblue', alpha=0.7)
axes[0].set_xlabel('Model', fontsize=12)
axes[0].set_ylabel('Accuracy', fontsize=12)
axes[0].set_title('Model Accuracies', fontsize=14)
axes[0].set_xticks(range(1, len(accuracies)+1))
axes[0].grid(True, alpha=0.3, axis='y')
# Annotate each bar with its value, slightly above the bar top.
for i, acc in enumerate(accuracies):
    axes[0].text(i+1, acc + 0.01, f'{acc:.3f}', ha='center', va='bottom')

axes[1].bar(range(1, len(ece_scores)+1), ece_scores, color='coral', alpha=0.7)
axes[1].set_xlabel('Model', fontsize=12)
axes[1].set_ylabel('ECE Score', fontsize=12)
axes[1].set_title('Expected Calibration Error (ECE)', fontsize=14)
axes[1].set_xticks(range(1, len(ece_scores)+1))
axes[1].grid(True, alpha=0.3, axis='y')
# The loop variable must not be called `ece`: that would rebind the global
# `ece` metric function used above and break any re-run of this cell.
for i, ece_val in enumerate(ece_scores):
    axes[1].text(i+1, ece_val + 0.005, f'{ece_val:.3f}', ha='center', va='bottom')

plt.tight_layout()
plt.show()

print(f"\nAverage Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")
print(f"Average ECE: {np.mean(ece_scores):.4f} ± {np.std(ece_scores):.4f}")


In [None]:
# Gradient subspace similarity across the ensemble: compute the matrix, show
# it as a heatmap, and record it in `results`.
print("\n" + "-"*60)
print("Computing Gradient Subspace Similarity...")
print("-"*60)

sim_matrix = gradient_subspace_similarity(
    models,
    test_loader,
    device=device,
    max_batches=config.evaluation.max_batches,
)

display_heatmap(sim_matrix, title="Gradient Subspace Similarity Matrix")

results['grad_similarity'] = sim_matrix.tolist()

# Average over the strict upper triangle only (k=1 excludes the diagonal).
upper_pairs = np.triu_indices(len(sim_matrix), k=1)
mean_pair_similarity = np.mean(sim_matrix[upper_pairs])
print(f"\nAverage similarity: {mean_pair_similarity:.4f}")


In [None]:
# Ensemble variance: run ensemble_variance on the test loader, plot the
# returned values against K, and record both sequences in `results`.
print("\n" + "-"*60)
print("Computing Ensemble Variance...")
print("-"*60)

Kvals, vars_ = ensemble_variance(
    models,
    test_loader,
    device=device,
    max_batches=config.evaluation.max_batches,
)

display_line(
    Kvals,
    vars_,
    title="Ensemble Variance vs Number of Models",
    xlabel="K (Number of Models)",
    ylabel="Variance",
)

results['ensemble_variance'] = dict(K=Kvals, variance=vars_)

# The last entry is reported as the value for the full ensemble.
final_variance = vars_[-1]
print(f"\nVariance with {len(models)} models: {final_variance:.6f}")


In [None]:
# Adversarial transferability between ensemble members: compute the matrix,
# show it as a heatmap, store it, and report the off-diagonal average.
print("\n" + "-"*60)
print("Computing Transferability Matrix...")
print("-"*60)

# Attack hyper-parameters come from the attack section of the config.
attack_settings = dict(
    epsilon=config.attack.epsilon,
    alpha=config.attack.alpha,
    iters=config.attack.pgd_iters,
)

P = transferability_matrix(
    models,
    test_loader,
    device=device,
    max_batches=config.evaluation.max_batches,
    **attack_settings,
)

display_heatmap(P, title="Adversarial Transferability Matrix", cmap="Reds")

results['transferability'] = P.tolist()

print("\nTransferability Matrix:")
print("Rows: Source model (attacker)")
print("Columns: Target model (victim)")
print(f"Matrix:\n{P}")

# Average every entry except the diagonal (same-model source and target).
off_diagonal = ~np.eye(len(P), dtype=bool)
avg_transfer = np.mean(P[off_diagonal])
print(f"\nAverage transferability (off-diagonal): {avg_transfer:.4f}")


In [None]:
# Adversarial examples: generate (original, adversarial) pairs and show the
# two sets as separate image grids.
print("\n" + "-"*60)
print("Generating Adversarial Examples...")
print("-"*60)

n_show = 16

# The second model argument falls back to the first one for a single-model ensemble.
model_a = models[0]
model_b = models[1] if len(models) > 1 else models[0]

adv_samples = generate_adv_examples(
    model_a,
    model_b,
    test_loader,
    device=device,
    n_samples=n_show,
    epsilon=config.attack.epsilon,
    alpha=config.attack.alpha,
    iters=config.attack.pgd_iters,
)

# Each sample is unpacked as an (original, adversarial) pair.
shown_pairs = adv_samples[:n_show]
clean_images = [clean for clean, _ in shown_pairs]
perturbed_images = [perturbed for _, perturbed in shown_pairs]

print("\nOriginal Images:")
display_image_grid(clean_images, nrow=4, title="Original Images")

print("\nAdversarial Examples:")
display_image_grid(perturbed_images, nrow=4, title="Adversarial Examples")


In [None]:
# Loss landscapes (1D): scan the first model's loss, plot it, and store the
# scan in `results`.
print("\n" + "-"*60)
print("Computing Loss Landscapes...")
print("-"*60)

# Grid resolution and radius are taken from the evaluation config.
landscape_args = dict(
    grid_n=config.evaluation.loss_landscape_grid_n,
    radius=config.evaluation.loss_landscape_radius,
)

print("Computing 1D loss landscape...")
alphas1D, losses1D = scan_1d_loss(models[0], test_loader, device=device, **landscape_args)

display_line(
    alphas1D,
    losses1D,
    title="1D Loss Landscape",
    xlabel="Alpha (Direction)",
    ylabel="Loss",
)

results['loss_1d'] = dict(alphas=alphas1D.tolist(), losses=losses1D.tolist())


In [None]:
# Loss landscape (2D): scan the first model's loss over two axes, draw the
# contour plot, and store the three returned arrays in `results`.
print("Computing 2D loss landscape...")

# Grid resolution and radius are taken from the evaluation config.
landscape_args = dict(
    grid_n=config.evaluation.loss_landscape_grid_n,
    radius=config.evaluation.loss_landscape_radius,
)

alphas2D, betas2D, losses2D = scan_2d_loss(models[0], test_loader, device=device, **landscape_args)

display_contour(alphas2D, betas2D, losses2D, title="2D Loss Landscape")

# Converted with .tolist() before being stored in `results`.
results['loss_2d'] = {
    key: grid.tolist()
    for key, grid in (('alphas', alphas2D), ('betas', betas2D), ('losses', losses2D))
}


## 6. Summary and Results

This section provides a summary of all results.


In [None]:
# Save results and display summary
#
# Writes the config, training history and evaluation results to
# <save_dir>/results.json, then prints the headline numbers. The optional
# sections are printed only if the corresponding evaluation cell stored its
# output in `results`.
print("\n" + "="*60)
print("Experiment Summary")
print("="*60)

# Save results to JSON. The payload is serialized before the file is opened so
# that a value json cannot encode raises without truncating an existing
# results.json or leaving a half-written one behind.
results_path = os.path.join(config.save_dir, 'results.json')
results_json = json.dumps({
    'config': config.to_dict(),
    'training_history': training_history,
    'evaluation_results': results
}, indent=2)
with open(results_path, 'w', encoding='utf-8') as f:
    f.write(results_json)
print(f"\nResults saved to {results_path}")

# Display summary
print("\n" + "-"*60)
print("FINAL RESULTS SUMMARY")
print("-"*60)
print(f"\nDataset: {config.data.dataset}")
print(f"Ensemble Size: {len(models)}")
print(f"\nModel Accuracies:")
for i, acc in enumerate(accuracies):
    print(f"  Model {i+1}: {acc:.4f}")
print(f"\nAverage Accuracy: {np.mean(accuracies):.4f} ± {np.std(accuracies):.4f}")

print(f"\nECE Scores:")
# The loop variable must not be called `ece`: that would rebind the global
# `ece` metric function and break a later re-run of the evaluation cell.
for i, ece_val in enumerate(ece_scores):
    print(f"  Model {i+1}: {ece_val:.4f}")
print(f"\nAverage ECE: {np.mean(ece_scores):.4f} ± {np.std(ece_scores):.4f}")

if 'grad_similarity' in results:
    # Stored as nested lists; convert back to an array to mask the diagonal.
    sim_matrix = np.array(results['grad_similarity'])
    mask = ~np.eye(len(sim_matrix), dtype=bool)
    avg_sim = np.mean(sim_matrix[mask])
    print(f"\nAverage Gradient Similarity: {avg_sim:.4f}")

if 'transferability' in results:
    P = np.array(results['transferability'])
    mask = ~np.eye(len(P), dtype=bool)
    avg_transfer = np.mean(P[mask])
    print(f"Average Transferability: {avg_transfer:.4f}")

if 'ensemble_variance' in results:
    vars_ = results['ensemble_variance']['variance']
    print(f"Final Ensemble Variance: {vars_[-1]:.6f}")

print("\n" + "="*60)
print("Experiment Completed Successfully!")
print("="*60)
