# EGEAT: Ensemble Gradient-based Ensemble Adversarial Training

This notebook runs the complete EGEAT experiment pipeline, including:
- Ensemble model training
- Adversarial training with geometric regularization
- Comprehensive evaluation metrics
- Visualization of all diagrams and results


## 1. Setup and Installation


In [None]:
# Install required packages
!pip install torch torchvision numpy pandas matplotlib seaborn tqdm scikit-learn imageio -q


In [None]:
# Standard imports
import os
import sys
import json
import logging
import traceback
from pathlib import Path
from datetime import datetime
from typing import List, Dict, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split, TensorDataset
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
from copy import deepcopy

# Colab working directory: every relative path used later (data/, results/,
# src/...) resolves against it after the chdir below.
WORK_DIR = '/content/EGEAT'
os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)

# Mirror the project's package layout, parent directory first.
for _subdir in ('src', 'src/attacks', 'src/evaluation', 'src/models',
                'src/training', 'src/utils'):
    os.makedirs(_subdir, exist_ok=True)

print(f"Working directory: {os.getcwd()}")
print(f"PyTorch version: {torch.__version__}")
_cuda_ok = torch.cuda.is_available()
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")


## 2. Source Code Setup

The following cells contain all of the project's source code, organized by module.


### 2.1 Configuration Module


In [None]:
# src/config.py
from dataclasses import dataclass, asdict, field
from typing import Optional
import argparse

@dataclass
class TrainingConfig:
    """Training-loop and data-loader hyperparameters.

    NOTE(review): the optimizer / epoch loop that consumes most of these
    fields is not part of this section; roles below marked "presumably" are
    inferred from matching parameter names and should be confirmed.
    """
    batch_size: int = 128  # passed as batch_size to the dataset loader factories
    learning_rate: float = 2e-4  # presumably the optimizer learning rate
    epochs: int = 5  # presumably the number of training epochs
    lambda_geom: float = 0.1  # presumably train_egeat_epoch's lambda_geom (gradient-subspace penalty weight)
    lambda_soup: float = 0.05  # presumably train_egeat_epoch's lambda_soup (soup-proximity penalty weight)
    epsilon: float = 8/255  # presumably the training-time perturbation budget (train_egeat_epoch's epsilon)
    use_mixed_precision: bool = False  # presumably train_egeat_epoch's use_mixed_precision flag
    num_workers: int = 2  # passed as num_workers to the dataset loader factories
    pin_memory: bool = True  # passed as pin_memory to the dataset loader factories

@dataclass
class ModelConfig:
    """Architecture and ensemble-size settings."""
    model_type: str = 'SimpleCNN'  # architecture class name, e.g. 'SimpleCNN' or 'DCGAN_CNN'
    input_channels: int = 3  # image channels; the dataset-loading cell overwrites it (3 for CIFAR-10, 1 for MNIST)
    num_classes: int = 10  # number of output logits; also overwritten by the dataset-loading cell
    ensemble_size: int = 5  # number of ensemble members the init cell instantiates

@dataclass
class AttackConfig:
    """Adversarial-attack hyperparameters.

    NOTE(review): no consumer of this config appears in this section; the
    roles below are inferred from field names — confirm against the caller.
    """
    epsilon: float = 8/255  # perturbation budget (presumably in [0, 1] pixel units)
    alpha: float = 2/255  # presumably the per-step size of iterative attacks (PGD)
    pgd_iters: int = 10  # presumably the number of PGD iterations
    cw_iters: int = 50  # presumably the number of Carlini-Wagner optimization steps
    cw_c: float = 1e-2  # presumably the Carlini-Wagner trade-off constant c
    cw_lr: float = 0.01  # presumably the Carlini-Wagner optimizer learning rate

@dataclass
class EvaluationConfig:
    """Settings for the evaluation metrics and loss-landscape scans."""
    max_batches: int = 10  # presumably forwarded as max_batches to the evaluation helpers
    n_bins_ece: int = 15  # presumably ece()'s n_bins (number of confidence bins)
    loss_landscape_grid_n: int = 21  # presumably scan_1d_loss/scan_2d_loss grid_n (points per axis)
    loss_landscape_radius: float = 1.0  # presumably the scans' radius (extent along each direction)

@dataclass
class DataConfig:
    """Dataset selection and preprocessing options."""
    dataset: str = 'cifar10'  # 'cifar10' or 'mnist'; the loading cell compares it case-insensitively
    val_split: float = 0.1  # fraction of the official training split held out for validation
    augment: bool = True  # CIFAR-10 training augmentation (random crop/flip); not forwarded to the MNIST loader
    data_path: str = './data'  # NOTE(review): the loader factories hard-code ./data/<NAME> roots and are not passed this field — confirm whether anything uses it

@dataclass
class ExperimentConfig:
    """Top-level experiment configuration bundling all sub-configs.

    Round-trips through JSON with ``save``/``load`` (built on
    ``to_dict``/``from_dict``).
    """
    experiment_name: str = 'egeat_experiment'
    seed: int = 42
    device: Optional[str] = None  # None -> auto-detect (see get_device)
    save_dir: str = 'results'
    resume: bool = False
    checkpoint_dir: Optional[str] = None
    training: TrainingConfig = field(default_factory=TrainingConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    attack: AttackConfig = field(default_factory=AttackConfig)
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
    data: DataConfig = field(default_factory=DataConfig)

    def to_dict(self):
        """Return the config as nested plain dicts (JSON-serializable)."""
        return asdict(self)

    @classmethod
    def from_dict(cls, d):
        """Build a config from a nested dict such as ``to_dict()`` output.

        The caller's dict is left untouched. Missing (or ``None``) sub-config
        sections fall back to their defaults; unknown keys raise
        ``TypeError`` from the dataclass constructors.
        """
        d = dict(d)  # shallow copy: the pops below must not mutate the caller's dict
        training = TrainingConfig(**(d.pop('training', None) or {}))
        model = ModelConfig(**(d.pop('model', None) or {}))
        attack = AttackConfig(**(d.pop('attack', None) or {}))
        evaluation = EvaluationConfig(**(d.pop('evaluation', None) or {}))
        data = DataConfig(**(d.pop('data', None) or {}))
        return cls(training=training, model=model, attack=attack, evaluation=evaluation, data=data, **d)

    def save(self, path: str):
        """Write the config as indented JSON to ``path``, creating parent dirs."""
        parent = os.path.dirname(path)
        os.makedirs(parent if parent else '.', exist_ok=True)
        with open(path, 'w', encoding='utf-8') as f:
            json.dump(self.to_dict(), f, indent=2)

    @classmethod
    def load(cls, path: str):
        """Read a config previously written by ``save``."""
        with open(path, 'r', encoding='utf-8') as f:
            d = json.load(f)
        return cls.from_dict(d)

def get_device(device_str=None, verbose=False):
    """Resolve a ``torch.device``.

    Args:
        device_str: Explicit device spec such as ``'cpu'`` or ``'cuda:0'``;
            ``None`` selects ``'cuda'`` when available, else ``'cpu'``.
        verbose: Print the chosen device when True.

    Returns:
        torch.device: The resolved device.
    """
    if device_str is not None:
        spec = device_str
    else:
        spec = 'cuda' if torch.cuda.is_available() else 'cpu'
    chosen = torch.device(spec)
    if verbose:
        print(f"Using device: {chosen}")
    return chosen


### 2.2 Model Architectures


In [None]:
# src/models/cnn.py
class SimpleCNN(nn.Module):
    """Three-block CNN classifier.

    Each block is a 3x3 conv (padding 1) + ReLU followed by a 2x spatial
    reduction. The third reduction is an adaptive max-pool to a fixed 4x4
    grid, so the ``128*4*4`` classifier head accepts both 32x32 inputs
    (CIFAR-10, where it is exactly ``max_pool2d(x, 2)`` on the 8x8 map) and
    28x28 inputs (MNIST, whose 7x7 map would otherwise pool to 3x3 and
    mismatch ``fc1``).

    Args:
        input_channels: Number of input image channels.
        num_classes: Number of output logits.
    """

    def __init__(self, input_channels=3, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 32, 3, 1, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1, 1)
        self.conv3 = nn.Conv2d(64, 128, 3, 1, 1)
        self.fc1 = nn.Linear(128*4*4, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # Pin the spatial size to 4x4 regardless of the input resolution.
        x = F.adaptive_max_pool2d(F.relu(self.conv3(x)), (4, 4))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

class DCGAN_CNN(nn.Module):
    """Two-block conv classifier with BatchNorm and LeakyReLU(0.2).

    Each block is conv(3x3, padding 1) -> BatchNorm -> LeakyReLU -> 2x
    spatial reduction. The second reduction is an adaptive max-pool to a
    fixed 8x8 grid, so the ``128*8*8`` head accepts both 32x32 inputs
    (CIFAR-10, where it equals ``max_pool2d(x, 2)`` on the 16x16 map) and
    28x28 inputs (MNIST, which would otherwise end at 7x7 and mismatch
    ``fc1``).

    Args:
        input_channels: Number of input image channels.
        num_classes: Number of output logits.
    """

    def __init__(self, input_channels=3, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, 64, 3, 1, 1)
        self.conv2 = nn.Conv2d(64, 128, 3, 1, 1)
        self.bn1 = nn.BatchNorm2d(64)
        self.bn2 = nn.BatchNorm2d(128)
        self.fc1 = nn.Linear(128*8*8, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        x = F.max_pool2d(F.leaky_relu(self.bn1(self.conv1(x)), 0.2), 2)
        # Pin the spatial size to 8x8 regardless of the input resolution.
        x = F.adaptive_max_pool2d(F.leaky_relu(self.bn2(self.conv2(x)), 0.2), (8, 8))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


### 2.3 Data Loaders


In [None]:
# src/utils/data_loader.py
def seed_everything(seed=123):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Python's ``random`` is included so that any code drawing from it is
    reproducible too, not just NumPy/PyTorch consumers.
    """
    import random  # local import keeps this cell self-contained

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def get_mnist_loaders(batch_size=128, val_split=0.1, num_workers=2, pin_memory=True, seed=123):
    """Build MNIST train/val/test DataLoaders.

    Images go through ``ToTensor`` only (no normalization). The official
    training split is divided into train/val with a generator seeded by
    ``seed``.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader)``; only the training
        loader shuffles.
    """
    seed_everything(seed)
    to_tensor = transforms.Compose([transforms.ToTensor()])
    root = './data/MNIST'
    full_train = datasets.MNIST(root=root, train=True, download=True, transform=to_tensor)
    n_val = int(len(full_train) * val_split)
    n_train = len(full_train) - n_val
    split_gen = torch.Generator().manual_seed(seed)
    train_part, val_part = random_split(full_train, [n_train, n_val], generator=split_gen)
    test_part = datasets.MNIST(root=root, train=False, download=True, transform=to_tensor)

    def _loader(ds, shuffle):
        # Shared loader settings; only the shuffle flag differs per split.
        return DataLoader(ds, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=pin_memory)

    return _loader(train_part, True), _loader(val_part, False), _loader(test_part, False)

def get_cifar10_loaders(batch_size=128, val_split=0.1, num_workers=2, pin_memory=True, augment=True, seed=123, normalize=False):
    """Build CIFAR-10 train/val/test DataLoaders.

    The official training split is divided into train/val subsets with a
    generator seeded by ``seed``. The validation subset reads the same images
    through the deterministic test transform, so the random crop/flip
    enabled by ``augment`` only ever touches training samples.

    Args:
        batch_size: Batch size for all three loaders.
        val_split: Fraction of the training split held out for validation.
        num_workers: DataLoader worker processes.
        pin_memory: DataLoader ``pin_memory`` flag.
        augment: Apply RandomCrop(32, padding=4) + RandomHorizontalFlip to
            the training subset.
        seed: Seed for ``seed_everything`` and for the train/val split.
        normalize: Apply per-channel mean/std normalization. Off by default
            so tensors stay in [0, 1]: the FGSM attack and the EGEAT training
            step clamp inputs to [0, 1] and use budgets like 8/255, which are
            only meaningful in raw pixel space.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader)``; only the training
        loader shuffles.
    """
    from torch.utils.data import Subset  # not among the file's top-level imports

    seed_everything(seed)
    tensor_ops = [transforms.ToTensor()]
    if normalize:
        tensor_ops.append(transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)))
    augment_ops = [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()] if augment else []
    transform_train = transforms.Compose(augment_ops + tensor_ops)
    transform_test = transforms.Compose(list(tensor_ops))

    root = './data/CIFAR10'
    train_view = datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    # Same training images with the deterministic transform; backs the val subset.
    val_view = datasets.CIFAR10(root=root, train=True, download=True, transform=transform_test)
    val_size = int(len(train_view) * val_split)
    train_size = len(train_view) - val_size
    # random_split over an index range of the same length with the same seeded
    # generator draws exactly the indices it would draw for the dataset itself.
    train_idx, val_idx = random_split(range(len(train_view)), [train_size, val_size],
                                      generator=torch.Generator().manual_seed(seed))
    train_dataset = Subset(train_view, train_idx.indices)
    val_dataset = Subset(val_view, val_idx.indices)
    test_dataset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)
    return (
        DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory),
        DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory),
        DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory),
    )


### 2.4 Training Module (EGEAT)


In [None]:
# src/training/train_egeat.py
def exact_delta_star(grad, eps, norm='inf'):
    """Closed-form maximizer of the linearized loss over an l_p ball.

    Returns ``delta`` maximizing ``<grad, delta>`` subject to
    ``||delta||_p <= eps``, independently for each sample along dim 0:

    * ``norm='inf'`` (or ``float('inf')``): ``eps * sign(grad)``.
    * ``norm=2``: ``eps * grad / ||grad||_2``.
    * ``norm=1``: the whole budget on the coordinate with the largest
      ``|grad|``, i.e. ``eps * sign(grad_k) * e_k``.

    Args:
        grad: Gradient tensor with a leading batch dimension (any trailing shape).
        eps: Perturbation budget.
        norm: ``'inf'``, ``2`` or ``1``.

    Returns:
        Tensor with the same shape as ``grad``.

    Raises:
        ValueError: For any other ``norm``.
    """
    if norm == 'inf' or norm == float('inf'):
        return eps * grad.sign()
    if norm not in (1, 2):
        raise ValueError("Unsupported norm type.")
    batch = grad.size(0)
    g_flat = grad.reshape(batch, -1)  # reshape also handles non-contiguous grads
    if norm == 2:
        # Per-sample norms, broadcast back over every trailing dimension.
        g_norm = g_flat.norm(p=2, dim=1).view(batch, *([1] * (grad.dim() - 1))) + 1e-12
        return eps * grad / g_norm
    # norm == 1: the l1 ball's vertices are +-eps * e_k, so the linear objective
    # is maximized by a single coordinate per sample.
    idx = g_flat.abs().argmax(dim=1, keepdim=True)
    delta = torch.zeros_like(g_flat)
    delta.scatter_(1, idx, eps * g_flat.gather(1, idx).sign())
    return delta.view_as(grad)

def gradient_subspace_penalty(model, x, y, ensemble_models):
    """Mean pairwise alignment of input gradients across the ensemble.

    For every pair (i, j) among ``[model] + ensemble_models`` the score is
    ``mean_b <g_i, g_j> / mean_b (||g_i|| * ||g_j||)`` on flattened
    per-sample input gradients of the cross-entropy loss; the result is the
    average over all pairs.

    ``model``'s input gradient is built with ``create_graph=True`` so the
    penalty is differentiable w.r.t. ``model``'s parameters (otherwise it
    would add nothing to the training gradient). Ensemble gradients are
    treated as constants.

    Args:
        model: The model being trained.
        x: Input batch.
        y: Integer class labels.
        ensemble_models: Iterable of members, each an ``nn.Module`` or a
            ``state_dict`` compatible with ``model`` (state_dicts are loaded
            into a copy of ``model``). May be None or empty.

    Returns:
        The float ``0.0`` when there are no members, otherwise a scalar tensor.
    """
    from collections.abc import Mapping
    from itertools import combinations

    if not ensemble_models:
        return 0.0
    batch = x.size(0)

    def _input_grad(net, create_graph):
        # Fresh leaf per network so the caller's tensor is never modified.
        x_req = x.detach().clone().requires_grad_(True)
        loss = F.cross_entropy(net(x_req), y)
        g = torch.autograd.grad(loss, x_req, create_graph=create_graph)[0]
        return g.reshape(batch, -1)

    grads = [_input_grad(model, create_graph=True)]
    for member in ensemble_models:
        if isinstance(member, Mapping):
            # state_dict snapshot: materialize it in a copy of model's architecture.
            net = deepcopy(model)
            net.load_state_dict(member)
        else:
            net = member
        net.eval()
        grads.append(_input_grad(net, create_graph=False).detach())

    pairs = list(combinations(grads, 2))
    penalty = 0.0
    for gi, gj in pairs:
        num = (gi * gj).sum(dim=1).mean()
        den = (gi.norm(p=2, dim=1) * gj.norm(p=2, dim=1)).mean() + 1e-12
        penalty = penalty + num / den
    return penalty / len(pairs)

def compute_theta_soup(ensemble_snapshots):
    """Uniformly average parameter snapshots into a "model soup".

    Args:
        ensemble_snapshots: Sequence of ``state_dict``-like mappings with
            identical keys, or ``nn.Module`` objects (their ``state_dict()``
            is used). May be None or empty.

    Returns:
        A new mapping from each key to the element-wise mean over all
        snapshots, or None when there are no snapshots. Non-floating-point
        tensors (e.g. BatchNorm's integer ``num_batches_tracked``) cannot
        hold a mean and are copied from the first snapshot. Inputs are never
        modified.
    """
    if not ensemble_snapshots:
        return None
    # Accept live modules as well as state_dicts; the same snapshot list is
    # also handed to gradient_subspace_penalty, which runs the entries as models.
    states = [s.state_dict() if isinstance(s, nn.Module) else s for s in ensemble_snapshots]
    n = len(states)
    soup = deepcopy(states[0])  # independent copy: never alias the caller's tensors
    for key in soup:
        acc = soup[key]
        if isinstance(acc, torch.Tensor) and not torch.is_floating_point(acc):
            continue  # integer buffer: keep the first snapshot's value
        for other in states[1:]:
            acc = acc + other[key]
        soup[key] = acc / n
    return soup

def _egeat_batch_loss(model, x, y, theta_soup, lambda_geom, lambda_soup, epsilon, p_norm, ensemble_snapshots):
    """Return ``(L_total, L_adv, L_geom, L_soup)`` for one EGEAT mini-batch.

    ``L_geom`` / ``L_soup`` are the float 0.0 when their term is inactive.
    """
    # Clean-loss input gradient on a private leaf: neither the caller's tensor
    # nor any parameter .grad is touched by autograd.grad.
    x_req = x.detach().requires_grad_(True)
    loss_clean = F.cross_entropy(model(x_req), y)
    grad_x = torch.autograd.grad(loss_clean, x_req)[0]
    delta_star = exact_delta_star(grad_x, epsilon, norm=p_norm)
    # The [0, 1] clamp assumes inputs in raw pixel space.
    x_adv = torch.clamp(x.detach() + delta_star, 0.0, 1.0)
    L_adv = F.cross_entropy(model(x_adv), y)
    L_geom = gradient_subspace_penalty(model, x.detach(), y, ensemble_snapshots)
    L_soup = 0.0
    if theta_soup is not None:
        # Match by name: state_dict keys also include buffers (e.g. BatchNorm
        # running stats), so zipping parameters() against them misaligns pairs.
        for name, p in model.named_parameters():
            if name in theta_soup:
                L_soup = L_soup + ((p - theta_soup[name]) ** 2).sum()
    L_total = L_adv + lambda_geom * L_geom + lambda_soup * L_soup
    return L_total, L_adv, L_geom, L_soup


def train_egeat_epoch(model, dataloader, optimizer, device='cuda', lambda_geom=0.1, lambda_soup=0.05, epsilon=8/255, p_norm='inf', ensemble_snapshots=None, use_mixed_precision=False):
    """Run one epoch of EGEAT training and return the mean loss components.

    Per batch, a one-step adversarial input is built from the clean-loss input
    gradient via ``exact_delta_star`` (clamped to [0, 1]), and the model
    minimizes ``L_adv + lambda_geom * L_geom + lambda_soup * L_soup``:
    ``L_geom`` is ``gradient_subspace_penalty`` against ``ensemble_snapshots``
    and ``L_soup`` is the squared L2 distance from the parameters to the
    snapshot average returned by ``compute_theta_soup``.

    Args:
        model: Model to train (switched to train mode).
        dataloader: Iterable of ``(x, y)`` batches.
        optimizer: Optimizer over ``model``'s parameters.
        device: ``torch.device`` or device string.
        lambda_geom: Weight of the gradient-subspace penalty.
        lambda_soup: Weight of the soup-proximity penalty.
        epsilon: Perturbation budget for the adversarial inputs.
        p_norm: Threat-model norm passed to ``exact_delta_star``.
        ensemble_snapshots: Snapshots (modules or state_dicts) or None.
        use_mixed_precision: Use CUDA AMP (ignored on non-CUDA devices).

    Returns:
        dict with keys 'loss', 'adv_loss', 'geom_loss', 'soup_loss' holding
        per-batch means over the epoch.
    """
    import contextlib  # local: provides the no-AMP null context

    device = torch.device(device)  # accept 'cuda'/'cpu' strings as well as torch.device
    model.train()
    theta_soup = compute_theta_soup(ensemble_snapshots)
    if theta_soup is not None:
        # Move the soup to the device once rather than once per batch.
        theta_soup = {k: v.to(device) for k, v in theta_soup.items()}
    scaler = torch.cuda.amp.GradScaler() if use_mixed_precision and device.type == 'cuda' else None

    total_loss = total_adv = total_geom = total_soup = 0.0
    n_batches = 0
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        amp_ctx = torch.cuda.amp.autocast() if scaler is not None else contextlib.nullcontext()
        with amp_ctx:
            L_total, L_adv, L_geom, L_soup = _egeat_batch_loss(
                model, x, y, theta_soup, lambda_geom, lambda_soup,
                epsilon, p_norm, ensemble_snapshots,
            )

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(L_total).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            L_total.backward()
            optimizer.step()

        total_loss += L_total.item()
        total_adv += L_adv.item()
        total_geom += L_geom if isinstance(L_geom, float) else L_geom.item()
        total_soup += L_soup if isinstance(L_soup, float) else L_soup.item()
        n_batches += 1

    return {
        'loss': total_loss / n_batches,
        'adv_loss': total_adv / n_batches,
        'geom_loss': total_geom / n_batches,
        'soup_loss': total_soup / n_batches
    }


In [None]:
# src/evaluation/metrics.py
def accuracy(model, dataloader, device='cpu'):
    """Top-1 accuracy of ``model`` over every ``(x, y)`` batch in ``dataloader``.

    Runs in eval mode without autograd. An empty loader raises
    ZeroDivisionError.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            hits = model(inputs).argmax(dim=1).eq(targets)
            n_correct += hits.sum().item()
            n_seen += targets.size(0)
    return n_correct / n_seen

def ece(model, dataloader, device='cpu', n_bins=15):
    """Expected Calibration Error of ``model``'s top-1 predictions.

    Confidence is the max softmax probability. Samples are grouped into
    ``n_bins`` equal-width, right-closed bins ``(lo, hi]`` over [0, 1], and
    the result is ``sum_b |acc_b - conf_b| * |b| / N``.
    """
    model.eval()
    conf_chunks, hit_chunks = [], []
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            top_prob, top_cls = F.softmax(model(inputs), dim=1).max(dim=1)
            conf_chunks.append(top_prob.cpu().numpy())
            hit_chunks.append((top_cls == targets).cpu().numpy())
    confs = np.concatenate(conf_chunks) if conf_chunks else np.array([])
    hits = np.concatenate(hit_chunks) if hit_chunks else np.array([])

    edges = np.linspace(0, 1, n_bins + 1)
    weighted_gap = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confs > lo) & (confs <= hi)
        count = in_bin.sum()
        if count > 0:
            weighted_gap += np.abs(hits[in_bin].mean() - confs[in_bin].mean()) * count
    return weighted_gap / len(confs)

# src/evaluation/gradient_similarity.py
def compute_input_gradients(model, dataloader, device=None, max_batches=8):
    """Collect flattened per-sample input gradients of the cross-entropy loss.

    Gradients come from ``torch.autograd.grad`` w.r.t. the inputs only, so
    the model's parameter ``.grad`` fields and the loader's tensors are left
    untouched.

    Args:
        model: Classifier returning logits (switched to eval mode).
        dataloader: Iterable of ``(x, y)`` batches; only the first
            ``max_batches`` are used.
        device: Device for the forward pass; defaults to the model's
            parameter device.
        max_batches: Maximum number of batches to process.

    Returns:
        np.ndarray of shape (num_samples, prod(x.shape[1:])). The loss is the
        batch mean, so each row carries a 1/batch_size factor.

    Raises:
        ValueError: If no batch was processed (from ``np.concatenate``).
    """
    device = next(model.parameters()).device if device is None else torch.device(device)
    model.eval()
    ce = torch.nn.CrossEntropyLoss()
    grads = []
    for i, (x, y) in enumerate(dataloader):
        if i >= max_batches:
            break
        x = x.to(device, non_blocking=True).detach().requires_grad_(True)
        y = y.to(device, non_blocking=True)
        loss = ce(model(x), y)
        (g,) = torch.autograd.grad(loss, x)
        grads.append(g.detach().reshape(x.size(0), -1).cpu().numpy())
    return np.concatenate(grads, axis=0)

def gradient_subspace_similarity(models, dataloader, device=None, max_batches=8, k_max=50):
    """Pairwise overlap of the models' input-gradient subspaces.

    Each model's per-sample input gradients form an N x D matrix (rows in
    input space, see ``compute_input_gradients``). Its top-k right singular
    vectors span that model's gradient subspace. Two subspaces are compared
    with ``||V_i V_j^T||_F^2 / k``, the mean squared cosine of their
    principal angles. The score is 1 for identical subspaces and 0 for
    orthogonal ones, and it does not depend on the arbitrary signs of
    singular vectors.

    Args:
        models: Models to compare.
        dataloader: Loader of ``(x, y)`` batches. It should be deterministic
            so every model sees the same samples.
        device: Forward-pass device; defaults to the first model's device.
        max_batches: Batches used per model.
        k_max: Upper bound on the subspace dimension k.

    Returns:
        Symmetric np.ndarray of shape (n_models, n_models) with values in [0, 1].
    """
    device = next(models[0].parameters()).device if device is None else torch.device(device)

    bases = []
    for model in models:
        G = compute_input_gradients(model, dataloader, device, max_batches)
        # Rows of Vt: orthonormal input-space directions ordered by singular
        # value. Factorized once per model instead of once per pair.
        _, _, Vt = np.linalg.svd(G, full_matrices=False)
        bases.append(Vt)

    n_models = len(models)
    similarity_matrix = np.zeros((n_models, n_models))
    for i in range(n_models):
        for j in range(n_models):
            k = min(bases[i].shape[0], bases[j].shape[0], k_max)
            cosines = bases[i][:k] @ bases[j][:k].T  # (k, k) basis-vector cosines
            similarity_matrix[i, j] = np.sum(cosines ** 2) / k
    return similarity_matrix

# src/evaluation/ensemble_variance.py
def ensemble_variance(models, dataloader, device=None, max_batches=6):
    """Per-sample loss variance across the first k ensemble members, for each k.

    Returns ``(Kvals, variances)`` with ``Kvals = [1, ..., len(models)]``.
    ``variances[k-1]`` is the mean over the first ``max_batches`` batches of
    the batch-averaged per-sample cross-entropy variance (``np.var``) across
    ``models[:k]``. It is always 0 for k=1.

    Each model is evaluated once per batch and the loss matrix is reused for
    every prefix k. The original loop re-ran the loader len(models) times;
    this version also guarantees every k sees the same batches.
    """
    if device is None:
        device = next(models[0].parameters()).device
    else:
        device = torch.device(device)

    Kvals = list(range(1, len(models) + 1))
    if not Kvals:
        return Kvals, []

    ce = torch.nn.CrossEntropyLoss(reduction='none')
    for m in models:
        m.eval()
    per_k = [[] for _ in Kvals]
    for i, (X, y) in enumerate(dataloader):
        if i >= max_batches:
            break
        X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
        with torch.no_grad():
            rows = [ce(m(X), y).cpu().numpy() for m in models]
        losses = np.stack(rows, axis=0)  # (n_models, batch)
        for k in Kvals:
            per_k[k - 1].append(np.mean(np.var(losses[:k], axis=0)))

    variances = [np.mean(batch_vars) for batch_vars in per_k]
    return Kvals, variances

# src/evaluation/loss_landscape.py
def param_vector(model):
    """Snapshot all of ``model``'s parameters as one flat, detached tensor copy."""
    flat = torch.nn.utils.parameters_to_vector(model.parameters())
    snapshot = flat.detach().clone()  # independent of the live parameters and autograd
    return snapshot

def set_param_vector(model, vec):
    """Overwrite ``model``'s parameters in place with consecutive slices of ``vec``."""
    params = model.parameters()
    torch.nn.utils.vector_to_parameters(vec, params)

def loss_on_vector(model, vec, X, y, device='cpu'):
    """Mean cross-entropy on ``(X, y)`` after loading the flat parameters ``vec``.

    Side effects: ``model``'s parameters are replaced by ``vec`` and the model
    is left in eval mode.
    """
    set_param_vector(model, vec)
    model.eval()
    with torch.no_grad():
        logits = model(X.to(device))
        targets = y.to(device)
        return F.cross_entropy(logits, targets).item()

def scan_1d_loss(model, dataloader, device='cpu', grid_n=21, radius=1.0):
    """Evaluate the loss along a random unit direction in parameter space.

    Up to the first three batches of ``dataloader`` form a fixed probe set.
    The loss ``L(theta + a * d)`` is evaluated for ``grid_n`` values of ``a``
    evenly spaced in [-radius, radius]. ``d`` is a random unit-norm direction
    drawn from the global torch RNG.

    The original parameters (values and device) are restored afterwards,
    even if an evaluation raises. ``loss_on_vector`` switches the model to
    eval mode, and it is not switched back.

    Returns:
        (alphas, losses): two np.ndarrays of length ``grid_n``.
    """
    xs, ys = [], []
    for i, (X, y) in enumerate(dataloader):
        xs.append(X)
        ys.append(y)
        if i >= 2:
            break
    # Move the probe set once; the per-point .to(device) is then a no-op.
    X = torch.cat(xs, dim=0).to(device)
    y = torch.cat(ys, dim=0).to(device)
    # Keep a copy on the parameters' own device for the restore: the loaded
    # vector's device becomes the parameters' device.
    home = param_vector(model)
    base = home.to(device)
    direction = torch.randn(base.numel(), device=device)
    direction /= direction.norm()
    alphas = np.linspace(-radius, radius, grid_n)
    losses = np.zeros(grid_n)
    try:
        for i, a in enumerate(alphas):
            losses[i] = loss_on_vector(model, base + a * direction, X, y, device=device)
    finally:
        set_param_vector(model, home)
    return alphas, losses

def scan_2d_loss(model, dataloader, device='cpu', grid_n=21, radius=1.0):
    """Evaluate the loss on a random 2-D slice of parameter space.

    Up to the first three batches of ``dataloader`` form a fixed probe set.
    The loss ``L(theta + a * d_a + b * d_b)`` is evaluated on a
    ``grid_n`` x ``grid_n`` grid with ``a, b`` in [-radius, radius].
    ``d_a`` and ``d_b`` are random unit-norm directions drawn from the
    global torch RNG; ``d_b`` is made orthogonal to ``d_a`` with one
    Gram-Schmidt step.

    The original parameters (values and device) are restored afterwards,
    even if an evaluation raises. ``loss_on_vector`` switches the model to
    eval mode, and it is not switched back.

    Returns:
        (alphas, betas, losses) where ``losses[i, j]`` is the loss at
        ``(alphas[i], betas[j])``.
    """
    xs, ys = [], []
    for i, (X, y) in enumerate(dataloader):
        xs.append(X)
        ys.append(y)
        if i >= 2:
            break
    # Move the probe set once; the per-point .to(device) is then a no-op.
    X = torch.cat(xs, dim=0).to(device)
    y = torch.cat(ys, dim=0).to(device)
    # Keep a copy on the parameters' own device for the restore: the loaded
    # vector's device becomes the parameters' device.
    home = param_vector(model)
    base = home.to(device)
    D = base.numel()
    dir_a = torch.randn(D, device=device)
    dir_a /= dir_a.norm()
    dir_b = torch.randn(D, device=device)
    dir_b -= (dir_b.dot(dir_a)) * dir_a
    dir_b /= dir_b.norm()
    alphas = np.linspace(-radius, radius, grid_n)
    betas = np.linspace(-radius, radius, grid_n)
    losses = np.zeros((grid_n, grid_n))
    try:
        for i, a in enumerate(alphas):
            for j, b in enumerate(betas):
                vec = base + a * dir_a + b * dir_b
                losses[i, j] = loss_on_vector(model, vec, X, y, device=device)
    finally:
        set_param_vector(model, home)
    return alphas, betas, losses

# src/evaluation/transferability.py
def fgsm_attack(model, x, y, epsilon=8/255):
    """One-step FGSM: ``clamp(x + epsilon * sign(grad_x CE(model(x), y)), 0, 1)``.

    Works on a detached copy of ``x``, so the caller's tensor keeps its
    ``requires_grad`` flag. Switching that flag on would break later
    ``.numpy()`` calls on it. The model's parameter ``.grad`` fields are also
    left untouched. The [0, 1] clamp assumes inputs in raw pixel space.

    Returns:
        Detached adversarial batch with the same shape as ``x``.
    """
    x_req = x.detach().clone().requires_grad_(True)
    loss = F.cross_entropy(model(x_req), y)
    (grad,) = torch.autograd.grad(loss, x_req)
    x_adv = torch.clamp(x_req.detach() + epsilon * grad.sign(), 0, 1)
    return x_adv.detach()

def transferability_matrix(models, dataloader, device=None, max_batches=10, epsilon=8/255, alpha=2/255, iters=10):
    """FGSM transfer rates between every (source, target) model pair.

    ``P[s, t] = 1 - acc_t(adv_s) / acc_t(clean)`` over the first
    ``max_batches`` batches. ``adv_s`` are FGSM examples with budget
    ``epsilon`` crafted on ``models[s]``. The entry is 0.0 when target ``t``
    has zero clean accuracy.

    Each source's adversarial batch is crafted once and scored against all
    targets, so the loader is traversed once per source rather than once
    per (source, target) pair.

    ``alpha`` and ``iters`` are accepted for interface compatibility but are
    unused by this FGSM-based implementation.

    Returns:
        np.ndarray of shape (n_models, n_models).
    """
    if device is None:
        device = next(models[0].parameters()).device
    else:
        device = torch.device(device)

    n_models = len(models)
    P = np.zeros((n_models, n_models))
    for m in models:
        m.eval()

    for src_idx, src_model in enumerate(models):
        correct_orig = [0] * n_models
        correct_adv = [0] * n_models
        total = 0
        for i, (x, y) in enumerate(dataloader):
            if i >= max_batches:
                break
            x, y = x.to(device), y.to(device)
            x_adv = fgsm_attack(src_model, x, y, epsilon)
            with torch.no_grad():
                for tgt_idx, tgt_model in enumerate(models):
                    correct_orig[tgt_idx] += (tgt_model(x).argmax(dim=1) == y).sum().item()
                    correct_adv[tgt_idx] += (tgt_model(x_adv).argmax(dim=1) == y).sum().item()
            total += y.size(0)

        for tgt_idx in range(n_models):
            acc_orig = correct_orig[tgt_idx] / total
            acc_adv = correct_adv[tgt_idx] / total
            P[src_idx, tgt_idx] = 1.0 - (acc_adv / acc_orig) if acc_orig > 0 else 0.0

    return P

# src/evaluation/adversarial_images.py
def generate_adv_examples(src_model, tgt_model, dataloader, device=None, n_samples=16, epsilon=8/255, alpha=2/255, iters=10):
    """Collect up to ``n_samples`` (clean, FGSM-adversarial) image pairs.

    Adversarial images are crafted on ``src_model`` with budget ``epsilon``.
    ``tgt_model``, ``alpha`` and ``iters`` are accepted for interface
    compatibility but are not used.

    Returns:
        list of ``(orig_img, adv_img)`` np.ndarray pairs, each with the
        per-sample shape of the loader's inputs.
    """
    if device is None:
        device = next(src_model.parameters()).device
    else:
        device = torch.device(device)

    src_model.eval()
    samples = []
    if n_samples <= 0:
        return samples

    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        x_adv = fgsm_attack(src_model, x, y, epsilon)
        # detach(): the attack may leave requires_grad set on x, and .numpy()
        # refuses tensors that require grad.
        orig_batch = x.detach().cpu().numpy()
        adv_batch = x_adv.detach().cpu().numpy()
        take = min(n_samples - len(samples), x.size(0))
        samples.extend(zip(orig_batch[:take], adv_batch[:take]))
        if len(samples) >= n_samples:
            break

    return samples


In [None]:
# src/utils/viz_utils.py
def ensure_dir(path):
    """Create the parent directory of file ``path`` if it does not exist.

    A bare filename has no directory component, so nothing is created.
    ``os.makedirs('')`` would raise FileNotFoundError in that case.
    """
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

def save_heatmap(mat, path, title="", cmap="viridis"):
    """Render ``mat`` as an annotated heatmap (3-decimal labels) and save it at 300 dpi."""
    ensure_dir(path)
    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(mat, cmap=cmap, square=True, annot=True, fmt='.3f')
    plt.title(title, fontsize=14)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def save_contour(alphas, betas, losses, path, title="Loss surface"):
    """Save a filled contour plot of ``losses`` (rows = alphas, cols = betas) at 300 dpi."""
    ensure_dir(path)
    fig, ax = plt.subplots(figsize=(8, 6))
    filled = ax.contourf(betas, alphas, losses, levels=50, cmap="viridis")
    fig.colorbar(filled, ax=ax)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Beta Direction", fontsize=12)
    ax.set_ylabel("Alpha Direction", fontsize=12)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def save_image_grid(imgs, path, nrow=8, normalize=True, title=None):
    """Save ``imgs`` as a grid figure with ``nrow`` images per row.

    Args:
        imgs: Non-empty sequence of image arrays shaped (3, H, W), (1, H, W)
            or (H, W). Single-channel images are drawn in grayscale.
        path: Output file; its parent directory is created if needed.
        nrow: Number of grid columns.
        normalize: If True, rescale with ``(img + 1) / 2``, which assumes
            values in [-1, 1]. Pass False for images already in [0, 1].
            Values are clipped to [0, 1] before drawing either way.
        title: Optional figure-level title.
    """
    ensure_dir(path)
    n_imgs = len(imgs)
    cols = nrow
    rows = (n_imgs + cols - 1) // cols
    # squeeze=False always yields a 2-D Axes array. With the default, a 1x1
    # grid returns a bare Axes, which has no .reshape/.flatten.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5), squeeze=False)
    for i, ax in enumerate(axes.ravel()):
        ax.axis('off')
        if i >= n_imgs:
            continue
        img = imgs[i]
        if normalize:
            img = (img + 1) / 2
        if len(img.shape) == 3 and img.shape[0] == 3:
            img = np.transpose(img, (1, 2, 0))  # CHW -> HWC for imshow
        elif len(img.shape) == 3 and img.shape[0] == 1:
            img = img[0]
        ax.imshow(np.clip(img, 0, 1), cmap='gray' if len(img.shape) == 2 else None)
    if title:
        fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)

def save_line(xs, ys, path, xlabel="", ylabel="", title=""):
    """Save a marker line plot of ``ys`` against ``xs`` at 300 dpi."""
    ensure_dir(path)
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(xs, ys, marker='o', linewidth=2, markersize=6)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches='tight')
    plt.close(fig)

# Display functions for Colab
def display_heatmap(mat, title="", cmap="viridis"):
    """Show ``mat`` inline as an annotated heatmap with 3-decimal cell labels."""
    fig = plt.figure(figsize=(8, 6))
    sns.heatmap(mat, cmap=cmap, square=True, annot=True, fmt='.3f')
    plt.title(title, fontsize=14)
    fig.tight_layout()
    plt.show()

def display_line(xs, ys, xlabel="", ylabel="", title=""):
    """Show a marker line plot of ``ys`` against ``xs`` inline."""
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(xs, ys, marker='o', linewidth=2, markersize=6)
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14)
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()

def display_contour(alphas, betas, losses, title="Loss surface"):
    """Show a filled contour plot of ``losses`` (rows = alphas, cols = betas) inline."""
    fig, ax = plt.subplots(figsize=(8, 6))
    filled = ax.contourf(betas, alphas, losses, levels=50, cmap="viridis")
    fig.colorbar(filled, ax=ax)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Beta Direction", fontsize=12)
    ax.set_ylabel("Alpha Direction", fontsize=12)
    fig.tight_layout()
    plt.show()

def display_image_grid(imgs, nrow=8, normalize=True, title=None):
    """Show ``imgs`` inline as a grid with ``nrow`` images per row.

    Args:
        imgs: Non-empty sequence of image arrays shaped (3, H, W), (1, H, W)
            or (H, W). Single-channel images are drawn in grayscale.
        nrow: Number of grid columns.
        normalize: If True, rescale with ``(img + 1) / 2``, which assumes
            values in [-1, 1]. Pass False for images already in [0, 1].
            Values are clipped to [0, 1] before drawing either way.
        title: Optional figure-level title.
    """
    n_imgs = len(imgs)
    cols = nrow
    rows = (n_imgs + cols - 1) // cols
    # squeeze=False always yields a 2-D Axes array. With the default, a 1x1
    # grid returns a bare Axes, which has no .reshape/.flatten.
    fig, axes = plt.subplots(rows, cols, figsize=(cols * 1.5, rows * 1.5), squeeze=False)
    for i, ax in enumerate(axes.ravel()):
        ax.axis('off')
        if i >= n_imgs:
            continue
        img = imgs[i]
        if normalize:
            img = (img + 1) / 2
        if len(img.shape) == 3 and img.shape[0] == 3:
            img = np.transpose(img, (1, 2, 0))  # CHW -> HWC for imshow
        elif len(img.shape) == 3 and img.shape[0] == 1:
            img = img[0]
        ax.imshow(np.clip(img, 0, 1), cmap='gray' if len(img.shape) == 2 else None)
    if title:
        fig.suptitle(title, fontsize=14)
    fig.tight_layout()
    plt.show()


## 3. Run Complete Experiment

This section runs the full EGEAT experiment pipeline.


In [None]:
# Configure experiment (epochs and ensemble size are reduced for a Colab demo)
_train_cfg = TrainingConfig(
    batch_size=128,
    learning_rate=2e-4,
    epochs=5,
    lambda_geom=0.1,
    lambda_soup=0.05,
    epsilon=8/255,
    use_mixed_precision=False,
    num_workers=2,
)
_model_cfg = ModelConfig(
    model_type='SimpleCNN',
    input_channels=3,
    num_classes=10,
    ensemble_size=3,
)
_data_cfg = DataConfig(dataset='cifar10', val_split=0.1, augment=True)
_eval_cfg = EvaluationConfig(
    max_batches=10,
    n_bins_ece=15,
    loss_landscape_grid_n=21,
    loss_landscape_radius=1.0,
)
config = ExperimentConfig(
    experiment_name='egeat_colab',
    seed=42,
    device=None,  # None -> get_device picks CUDA when available, else CPU
    save_dir='results',
    training=_train_cfg,
    model=_model_cfg,
    data=_data_cfg,
    evaluation=_eval_cfg,
)

# Output tree: <save_dir>/{figures, adv_images, checkpoints}
os.makedirs(config.save_dir, exist_ok=True)
for _sub in ('figures', 'adv_images', 'checkpoints'):
    os.makedirs(os.path.join(config.save_dir, _sub), exist_ok=True)

# Seed every RNG in play and force deterministic cuDNN kernels.
import random
random.seed(config.seed)
np.random.seed(config.seed)
torch.manual_seed(config.seed)
torch.cuda.manual_seed_all(config.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = get_device(config.device, verbose=True)

print(f"\nExperiment Configuration:")
for _label, _value in (
    ("Dataset", config.data.dataset),
    ("Model", config.model.model_type),
    ("Ensemble Size", config.model.ensemble_size),
    ("Epochs", config.training.epochs),
    ("Batch Size", config.training.batch_size),
    ("Lambda Geom", config.training.lambda_geom),
    ("Lambda Soup", config.training.lambda_soup),
):
    print(f"  {_label}: {_value}")


In [None]:
# Load dataset: pick the loader factory by (case-insensitive) name, then make
# the model config agree with the dataset's channel/class counts.
print(f"\nLoading {config.data.dataset} dataset...")
_dataset_name = config.data.dataset.lower()
_loader_kwargs = dict(
    batch_size=config.training.batch_size,
    val_split=config.data.val_split,
    num_workers=config.training.num_workers,
    pin_memory=config.training.pin_memory,
    seed=config.seed,
)
if _dataset_name == 'cifar10':
    train_loader, val_loader, test_loader = get_cifar10_loaders(augment=config.data.augment, **_loader_kwargs)
    _in_channels = 3
elif _dataset_name == 'mnist':
    train_loader, val_loader, test_loader = get_mnist_loaders(**_loader_kwargs)
    _in_channels = 1
else:
    raise ValueError(f"Unsupported dataset: {config.data.dataset}")
config.model.input_channels = _in_channels
config.model.num_classes = 10  # both supported datasets have 10 classes

print(f"Dataset loaded. Train batches: {len(train_loader)}, Val batches: {len(val_loader)}, Test batches: {len(test_loader)}")


In [None]:
# Initialize ensemble models
print(f"\nInitializing {config.model.ensemble_size} {config.model.model_type} models...")
# Explicit name -> class lookup: an unrecognized model_type must fail loudly
# instead of silently building a different architecture.
_MODEL_CLASSES = {'SimpleCNN': SimpleCNN, 'DCGAN_CNN': DCGAN_CNN}
if config.model.model_type not in _MODEL_CLASSES:
    raise ValueError(
        f"Unsupported model_type: {config.model.model_type!r} "
        f"(expected one of {sorted(_MODEL_CLASSES)})"
    )
model_class = _MODEL_CLASSES[config.model.model_type]

models = []
for i in range(config.model.ensemble_size):
    model = model_class(
        input_channels=config.model.input_channels,
        num_classes=config.model.num_classes,
    )
    model = model.to(device)
    models.append(model)

print(f"Initialized {len(models)} models on {device}")
