
# EGEAT — Exact Geometric Ensemble Adversarial Training

**Implements Algorithm 1 exactly** with Eq. (11) (exact inner maximization of the first‑order, linearized loss — a single signed‑gradient step of size ε for L∞), Eq. (4) (geometric regularizer), and Eq. (5) (weight‑space smoothing).  
Datasets: **MNIST (ε=0.3)** and **CIFAR‑10 (ε=8/255)**. DCGAN‑inspired CNNs.  
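Paper training defaults (Sec. V‑C): **K=5, λ₁=0.1, λ₂=0.05, batch=128, Adam(2e‑4, β₁=0.5), epochs=100**. For runtime, the notebook's `EGEATConfig` and `PGDCfg` use a shortened schedule (20 epochs, Adam 3e‑4; EGEAT with K=3 snapshots). The ablation uses K=5, lr=2e‑4 and 20 epochs. Set these to the paper values for strict replication.  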
Evaluations: Clean, FGSM, PGD‑20; gradient similarity; loss surface; transferability; ablation.


In [None]:

import os, math, random, time, copy, itertools
from dataclasses import dataclass
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from torchvision import datasets, transforms

seed = 42
# Seed every RNG the notebook touches: Python, NumPy, torch CPU and all CUDA devices.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")
# Prefer Colab's /content when it exists, otherwise the current working directory.
SAVE_DIR = "/content" if os.path.exists("/content") else "."
os.makedirs(SAVE_DIR, exist_ok=True)
print("Device:", device, "| SAVE_DIR:", SAVE_DIR)

# ===== DARK LINKEDIN-STYLE THEME =====
# High-contrast dark theme for professional presentations
DARK_THEME = {
    "figure.facecolor": "#0A0E27",      # Deep dark blue-black background
    "axes.facecolor": "#0F1629",        # Slightly lighter for axes
    "axes.edgecolor": "#4A90E2",        # LinkedIn blue for edges
    "axes.labelcolor": "#E8F0FE",       # Light blue-white for labels
    "xtick.color": "#B8C5D6",          # Soft gray-blue for ticks
    "ytick.color": "#B8C5D6",
    "text.color": "#E8F0FE",            # Light text
    "font.size": 14,
    "font.weight": "medium",
    "axes.linewidth": 1.5,
    "axes.grid": True,
    "grid.color": "#1A2332",
    "grid.alpha": 0.3,
    "grid.linewidth": 0.8,
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "savefig.facecolor": "#0A0E27",
    "savefig.edgecolor": "none",
}

# Apply seaborn's base style FIRST and the dark theme on top of it.
# sns.set_style() rewrites figure.facecolor, text.color, axes.labelcolor,
# xtick/ytick colors, axes.edgecolor and grid.color. If it ran after
# rcParams.update(DARK_THEME), most of the dark theme would be reverted
# (dark-grey text on a dark background, white figure background).
sns.set_style("darkgrid", {"axes.facecolor": "#0F1629"})
mpl.rcParams.update(DARK_THEME)

# Professional color palette (high contrast, colorblind-friendly)
PALETTE = {
    "primary": "#0077B5",      # LinkedIn blue
    "secondary": "#00A0DC",    # Light LinkedIn blue
    "accent1": "#FF6B35",       # Coral orange
    "accent2": "#4ECDC4",       # Turquoise
    "accent3": "#FFE66D",       # Yellow
    "accent4": "#A8E6CF",       # Mint green
    "success": "#28A745",       # Green
    "warning": "#FFC107",       # Amber
    "error": "#DC3545",         # Red
    "text": "#E8F0FE",          # Light text
    "text_secondary": "#B8C5D6", # Secondary text
}

# Cycle of series colours used by the plotting cells (indexed modulo its length).
COLORS = [PALETTE["primary"], PALETTE["accent1"], PALETTE["accent2"],
          PALETTE["accent3"], PALETTE["accent4"], PALETTE["secondary"]]

def savefig(name, dpi=300, bbox_inches="tight"):
    """Write the current pyplot figure to SAVE_DIR/<name> on the dark background colour.

    Args:
        name: file name, joined onto SAVE_DIR.
        dpi: output resolution.
        bbox_inches: forwarded to plt.savefig ("tight" crops to the drawn content).
    """
    out_path = os.path.join(SAVE_DIR, name)
    save_kwargs = dict(
        dpi=dpi,
        bbox_inches=bbox_inches,
        facecolor=DARK_THEME["figure.facecolor"],
        edgecolor="none",
    )
    plt.savefig(out_path, **save_kwargs)
    print(f"✓ Saved: {out_path}")

def apply_dark_style(ax=None):
    """Recolour one Axes (default: the current one) with the dark palette and return it.

    Sets the axes background, all four spines, tick colours, and the colours of the
    x-label, y-label and title text.
    """
    ax = plt.gca() if ax is None else ax
    ax.set_facecolor(DARK_THEME["axes.facecolor"])
    for side in ("bottom", "top", "right", "left"):
        ax.spines[side].set_color(PALETTE["primary"])
    ax.tick_params(colors=PALETTE["text_secondary"])
    for text_artist in (ax.xaxis.label, ax.yaxis.label, ax.title):
        text_artist.set_color(PALETTE["text"])
    return ax

print("✓ Dark LinkedIn-style theme initialized")


In [None]:

from dataclasses import dataclass

@dataclass
class DataConfig:
    """Dataset choice and batch size, mirroring the arguments of get_loaders()."""

    name: str = "CIFAR10"   # "MNIST" or "CIFAR10"
    batch_size: int = 128   # mini-batch size for the DataLoaders

def get_loaders(name="CIFAR10", batch_size=128):
    """Build train/val/test DataLoaders and return dataset metadata.

    The official training split is divided 80/20 into train/validation with a
    generator seeded by the global ``seed``. Training samples use the training
    transform (random crop + horizontal flip for CIFAR-10). Validation and test
    samples use the plain ``ToTensor`` transform.

    Args:
        name: "MNIST" or "CIFAR10" (case-insensitive).
        batch_size: mini-batch size for all three loaders.

    Returns:
        (train_loader, val_loader, test_loader, num_classes, in_ch, img_sz, eps).
        ``eps`` is the L∞ attack budget for the dataset (0.3 MNIST, 8/255 CIFAR-10).

    Raises:
        ValueError: for an unknown dataset name.
    """
    from torch.utils.data import Subset

    key = name.upper()
    if key == "MNIST":
        eps = 0.3
        tf = transforms.ToTensor()
        train = datasets.MNIST(root=SAVE_DIR, train=True, download=True, transform=tf)
        train_eval = train  # no augmentation on MNIST, so the same view serves validation
        test  = datasets.MNIST(root=SAVE_DIR, train=False, download=True, transform=tf)
        num_classes = 10; in_ch = 1; img_sz = 28
    elif key == "CIFAR10":
        eps = 8/255
        tf_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        tf_test = transforms.ToTensor()
        train = datasets.CIFAR10(root=SAVE_DIR, train=True, download=True, transform=tf_train)
        # Second, un-augmented view of the same training images. Validation samples come
        # from the training split but must not be randomly cropped/flipped, or val
        # accuracy becomes a noisy, augmented metric.
        train_eval = datasets.CIFAR10(root=SAVE_DIR, train=True, download=False, transform=tf_test)
        test  = datasets.CIFAR10(root=SAVE_DIR, train=False, download=True, transform=tf_test)
        num_classes = 10; in_ch = 3; img_sz = 32
    else:
        raise ValueError("Unknown dataset")

    val_len = int(0.2 * len(train))
    train_len = len(train) - val_len
    gen = torch.Generator().manual_seed(seed)
    # The seeded split decides which indices go to validation. Those indices are then
    # served from the un-augmented view.
    train_ds, val_split = random_split(train, [train_len, val_len], generator=gen)
    val_ds = Subset(train_eval, val_split.indices)

    # Pinned memory only helps host->GPU copies (and triggers a warning without a GPU).
    pin = torch.cuda.is_available()
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=pin)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin)
    test_loader  = DataLoader(test,     batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin)
    return train_loader, val_loader, test_loader, num_classes, in_ch, img_sz, eps

DATASET = "CIFAR10"  # "MNIST" selects the eps=0.3 setting instead
# Unpack loaders plus the dataset metadata used by the model builders and attacks.
(train_loader, val_loader, test_loader,
 NUM_CLASSES, IN_CH, IMG_SZ, EPS) = get_loaders(DATASET, batch_size=128)
print(f"Dataset: {DATASET} | eps(Linf)={EPS} | classes={NUM_CLASSES} | in_ch={IN_CH} | img={IMG_SZ}x{IMG_SZ}")


In [None]:

def conv_block(in_c, out_c, k=3, s=1, p=1):
    """Conv2d(k, stride s, padding p, no bias) -> BatchNorm2d -> LeakyReLU(0.2, in place)."""
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p, bias=False),
        nn.BatchNorm2d(out_c),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*layers)

class CNN_MNIST(nn.Module):
    """Three-stage conv net: widths 32 -> 64 -> 128 (stride 1, 2, 2), global average
    pooling, then a linear classifier head."""

    def __init__(self, in_ch=1, num_classes=10):
        super().__init__()
        # (input channels, output channels, stride) per stage; all 3x3 with padding 1.
        stages = [(in_ch, 32, 1), (32, 64, 2), (64, 128, 2)]
        self.f = nn.Sequential(
            *(conv_block(c_in, c_out, 3, stride, 1) for c_in, c_out, stride in stages),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.h = nn.Linear(128, num_classes)

    def forward(self, x):
        feats = self.f(x)
        return self.h(torch.flatten(feats, start_dim=1))

class CNN_CIFAR10(nn.Module):
    """Six-stage conv net: widths 64, 64, 128, 128, 256, 256 with stride 2 on every second
    stage, global average pooling, then a linear classifier head."""

    def __init__(self, in_ch=3, num_classes=10):
        super().__init__()
        # (input channels, output channels, stride) per stage; all 3x3 with padding 1.
        plan = [
            (in_ch, 64, 1), (64, 64, 2),
            (64, 128, 1), (128, 128, 2),
            (128, 256, 1), (256, 256, 2),
        ]
        blocks = [conv_block(c_in, c_out, 3, stride, 1) for c_in, c_out, stride in plan]
        self.f = nn.Sequential(*blocks, nn.AdaptiveAvgPool2d((1, 1)))
        self.h = nn.Linear(256, num_classes)

    def forward(self, x):
        pooled = self.f(x)
        return self.h(torch.flatten(pooled, start_dim=1))

def make_model(dataset):
    """Instantiate the architecture for `dataset` (MNIST net, else CIFAR-10 net) on `device`.

    Channel and class counts come from the module-level IN_CH / NUM_CLASSES.
    """
    arch = CNN_MNIST if dataset.upper() == "MNIST" else CNN_CIFAR10
    net = arch(IN_CH, NUM_CLASSES)
    return net.to(device)

def accuracy(model, loader):
    """Top-1 accuracy of `model` over every batch of `loader`.

    Switches the model to eval mode (and leaves it there) and runs without autograd.
    """
    model.eval()
    hits = count = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            hits += (model(inputs).argmax(dim=1) == labels).sum().item()
            count += labels.size(0)
    return hits / count


In [None]:

def exact_perturbation(x, y, model, loss_fn, eps=0.3, p='linf'):
    """Closed-form maximiser of the first-order (linearised) loss inside an eps-ball.

    Uses a single input gradient g = ∇_x loss_fn(model(x), y):
      * p='linf': x + eps * sign(g)
      * p='l2':   x + eps * g / ||g||_2   (per-sample norm, floored at 1e-12)
    The result is clamped to [0, 1] and returned detached.

    Side effects: the model's parameter gradients are reset (set_to_none) and then
    filled by this backward pass. If the model is in train mode, layers such as
    BatchNorm update their running statistics during the forward.

    Raises:
        ValueError: if p is not 'linf' or 'l2' (checked after the backward pass).
    """
    x_in = x.detach().clone().to(device).requires_grad_(True)
    model.zero_grad(set_to_none=True)
    objective = loss_fn(model(x_in), y.to(device))
    objective.backward()
    grad = x_in.grad.detach()

    if p not in ('linf', 'l2'):
        raise ValueError("Use p in {linf,l2}")
    if p == 'linf':
        step = eps * grad.sign()
    else:
        flat = grad.view(grad.size(0), -1)
        norms = flat.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
        step = (eps * (flat / norms)).view_as(grad)
    return (x_in + step).clamp(0, 1).detach()

def fgsm_attack(x, y, model, loss_fn, eps):
    """FGSM: one signed-gradient L∞ step of size eps (exact_perturbation with p='linf')."""
    return exact_perturbation(x, y, model, loss_fn, eps=eps, p='linf')

def pgd_attack(x, y, model, loss_fn, eps, alpha, steps):
    """L∞ PGD with a uniform random start inside the eps-ball.

    Each of `steps` iterations takes a signed-gradient step of size `alpha`. It then
    projects back onto the eps-ball around the clean input and clamps to [0, 1].
    Like exact_perturbation, this resets and refills the model's parameter gradients.
    """
    anchor = x.detach().clone().to(device)
    target = y.to(device)
    noise = torch.empty_like(anchor).uniform_(-eps, eps)
    adv = (anchor + noise).clamp(0, 1).detach()
    for _ in range(steps):
        adv = adv.clone().detach().requires_grad_(True)
        model.zero_grad(set_to_none=True)
        loss_fn(model(adv), target).backward()
        ascent = alpha * adv.grad.detach().sign()
        delta = (adv + ascent - anchor).clamp(-eps, eps)
        adv = (anchor + delta).clamp(0, 1).detach()
    return adv


In [None]:

def batch_input_grads(models, x, y, loss_fn, create_graph=False):
    """Return [∇_x loss_fn(m(x), y) for m in models].

    Each model is evaluated in eval mode and then restored to the mode it was in
    before the call. Callers in a training loop therefore keep their train mode,
    and BatchNorm keeps updating.

    Gradients are taken with torch.autograd.grad w.r.t. the input only, so the
    models' parameter ``.grad`` fields are left untouched.

    Args:
        models: iterable of nn.Module.
        x, y: input batch and targets (moved to ``device``).
        loss_fn: callable (logits, targets) -> scalar loss.
        create_graph: if True, keep the graph so the returned gradients are
            differentiable w.r.t. the model parameters (double backprop). Otherwise
            they are detached.
    """
    x = x.detach().to(device)
    y = y.to(device)
    grads = []
    for m in models:
        was_training = m.training
        m.eval()
        try:
            x_ = x.clone().requires_grad_(True)
            loss = loss_fn(m(x_), y)
            (g,) = torch.autograd.grad(loss, x_, create_graph=create_graph)
        finally:
            m.train(was_training)
        grads.append(g if create_graph else g.detach())
    return grads

def cos_sim(a, b, eps=1e-12):
    """Batch-mean cosine similarity between two same-shaped tensors.

    Both inputs are flattened per sample (dim 0 is the batch). The per-sample cosine
    <a_i, b_i> / (||a_i|| * ||b_i|| + eps) is averaged over the batch and returned as
    a 0-d tensor. ``eps`` keeps all-zero rows at 0 instead of NaN.

    ``reshape`` (not ``view``) is used so non-contiguous inputs (e.g. transposed or
    sliced gradients) are accepted instead of raising.
    """
    a = a.reshape(a.size(0), -1)
    b = b.reshape(b.size(0), -1)
    num = (a * b).sum(dim=1)
    den = a.norm(p=2, dim=1) * b.norm(p=2, dim=1) + eps
    return (num / den).mean()

def geometric_regularizer(models, x, y, loss_fn):
    """Mean pairwise cosine similarity of the models' input gradients (Eq. 4 penalty).

    For every model m, computes g_m = ∇_x loss_fn(m(x), y) and returns the mean of
    cos_sim(g_i, g_j) over all pairs i < j. Adding λ₁ times this value to the loss
    pushes the models' gradient directions apart.

    The gradients are built with ``create_graph=True`` and are not detached. The
    returned value is therefore differentiable w.r.t. the models' parameters (double
    backprop) and actually influences training when backpropagated.

    Each model is evaluated in eval mode and then restored to its previous mode, so
    calling this inside a training loop does not leave the model in eval mode.

    Returns:
        0-d tensor; 0.0 when fewer than two models are given.
    """
    if len(models) < 2:
        return torch.tensor(0.0, device=device)
    x_in = x.detach().to(device)
    y_in = y.to(device)
    grads = []
    for m in models:
        was_training = m.training
        m.eval()
        try:
            x_ = x_in.clone().requires_grad_(True)
            loss = loss_fn(m(x_), y_in)
            (g,) = torch.autograd.grad(loss, x_, create_graph=True)
        finally:
            m.train(was_training)
        grads.append(g)
    sims = [cos_sim(grads[i], grads[j])
            for i, j in itertools.combinations(range(len(grads)), 2)]
    return torch.stack(sims).mean()

def update_soup(snapshots):
    """Return a new model whose weights are the uniform average of `snapshots` (Eq. 5).

    All parameters are averaged, and so are all floating-point buffers (BatchNorm
    running_mean / running_var). Averaging only the parameters would leave the
    soup with snapshot 0's normalisation statistics, which do not match the
    averaged weights. Integer buffers (e.g. num_batches_tracked) are kept from
    snapshot 0.

    Args:
        snapshots: list of models with identical architecture.

    Returns:
        A deep copy of snapshots[0] on ``device`` holding the averages, or None if
        `snapshots` is empty.
    """
    if not snapshots:
        return None
    base = copy.deepcopy(snapshots[0]).to(device)
    with torch.no_grad():
        snap_params = [list(s.parameters()) for s in snapshots]
        for k, pb in enumerate(base.parameters()):
            pb.copy_(torch.stack([sp[k].detach().to(pb.device) for sp in snap_params]).mean(dim=0))
        snap_bufs = [list(s.buffers()) for s in snapshots]
        for k, bb in enumerate(base.buffers()):
            if bb.is_floating_point():
                bb.copy_(torch.stack([sb[k].to(bb.device) for sb in snap_bufs]).mean(dim=0))
    return base


In [None]:

from dataclasses import dataclass

@dataclass
class EGEATConfig:
    """Hyper-parameters for train_egeat.

    These defaults are a shortened run. The paper setting (Sec. V-C) is epochs=100,
    snapshots_k=5, lr=2e-4.
    """

    epochs: int = 20               # training epochs
    eps: float = float(EPS)        # L∞ budget for exact_perturbation
    lambda_geom: float = 0.1       # λ₁, weight of the geometric regulariser
    lambda_soup: float = 0.05      # λ₂, weight of the weight-space smoothing term
    snapshots_k: int = 3           # target number of snapshots averaged into the soup
    lr: float = 3e-4               # Adam learning rate
    beta1: float = 0.5             # Adam β₁

def train_egeat(cfg: EGEATConfig):
    """Train one model with EGEAT and return ``(model, final_soup)``.

    Per mini-batch the objective is
        L = L_adv + λ₁·L_geom + λ₂·L_soup
    The three terms are:
      * L_adv: cross-entropy on exact_perturbation() examples.
      * L_geom: geometric_regularizer() between the model and the current soup
        (0 until a soup exists).
      * L_soup: squared L2 distance between the model's parameters and the soup's.

    A snapshot is taken every ``max(1, epochs // snapshots_k)`` epochs and at the last
    epoch. The soup is the average of all snapshots taken so far.
    """
    model = make_model(DATASET)
    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
    loss_fn = nn.CrossEntropyLoss()
    snapshots = []
    soup = None
    snap_every = max(1, cfg.epochs // cfg.snapshots_k)
    for epoch in range(1, cfg.epochs + 1):
        tot = n = 0
        for xb, yb in train_loader:
            # Re-enter train mode on every step. Helpers such as the gradient
            # regulariser (and accuracy() at epoch end) may switch the model to eval
            # mode, which would otherwise freeze BatchNorm for the rest of training.
            model.train()
            xb, yb = xb.to(device), yb.to(device)
            xb_adv = exact_perturbation(xb, yb, model, loss_fn, eps=cfg.eps, p='linf')
            L_adv = loss_fn(model(xb_adv), yb)
            models_for_geom = [model] + ([soup] if soup is not None else [])
            L_geom = geometric_regularizer(models_for_geom, xb, yb, loss_fn)
            if soup is not None:
                # The soup is a fixed anchor. Detach it so gradients reach only `model`.
                L_soup = sum((p - ps.detach()).pow(2).sum()
                             for p, ps in zip(model.parameters(), soup.parameters()))
            else:
                L_soup = torch.tensor(0.0, device=device)
            loss = L_adv + cfg.lambda_geom*L_geom + cfg.lambda_soup*L_soup
            opt.zero_grad(set_to_none=True); loss.backward(); opt.step()
            tot += loss.item()*xb.size(0); n += xb.size(0)
        if epoch % snap_every == 0 or epoch == cfg.epochs:
            snapshots.append(copy.deepcopy(model).to(device))
            soup = update_soup(snapshots)
        print(f"[EGEAT] {epoch}/{cfg.epochs} loss={tot/max(1,n):.4f} val_acc={accuracy(model,val_loader):.3f} snaps={len(snapshots)}")
    final_soup = update_soup(snapshots) if snapshots else copy.deepcopy(model)
    return model, final_soup

egeat_model, egeat_soup = train_egeat(EGEATConfig())
print("Validation (EGEAT model):", accuracy(egeat_model, val_loader))
print("Validation (EGEAT soup): ", accuracy(egeat_soup,  val_loader))


In [None]:

@dataclass
class PGDCfg:
    """Hyper-parameters for the PGD adversarial-training baseline (train_pgd)."""

    epochs: int = 20                # training epochs (shortened schedule)
    eps: float = float(EPS)         # L∞ budget of the training attack
    alpha: float = float(EPS) / 4   # PGD step size
    steps: int = 7                  # PGD iterations per training batch
    lr: float = 3e-4                # Adam learning rate
    beta1: float = 0.5              # Adam β₁

def train_pgd(cfg: PGDCfg):
    """Madry-style adversarial training: minimise cross-entropy on PGD examples.

    Each batch is attacked with pgd_attack(eps, alpha, steps) while the model is in
    train mode. The epoch's mean loss and validation accuracy are printed, and the
    trained model is returned.
    """
    net = make_model(DATASET)
    optimizer = torch.optim.Adam(net.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
    criterion = nn.CrossEntropyLoss()
    for epoch in range(1, cfg.epochs + 1):
        net.train()
        running, seen = 0, 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            adv = pgd_attack(inputs, targets, net, criterion,
                             eps=cfg.eps, alpha=cfg.alpha, steps=cfg.steps)
            batch_loss = criterion(net(adv), targets)
            optimizer.zero_grad(set_to_none=True)
            batch_loss.backward()
            optimizer.step()
            running += batch_loss.item() * inputs.size(0)
            seen += inputs.size(0)
        print(f"[PGD] {epoch}/{cfg.epochs} loss={running/max(1,seen):.4f} val_acc={accuracy(net,val_loader):.3f}")
    return net

pgd_model = train_pgd(PGDCfg())
print("Validation (PGD):", accuracy(pgd_model, val_loader))


In [None]:

def eval_adv_acc(model, loader, attack='fgsm', eps=EPS, alpha=None, steps=20):
    """Accuracy of `model` on adversarial versions of every batch in `loader`.

    attack='fgsm' takes a single signed-gradient step. Any other value runs
    pgd_attack with `steps` iterations and step size `alpha` (eps/4 when None).
    The model is put into eval mode first and left there.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    hits = seen = 0
    for inputs, labels in loader:
        if attack != 'fgsm':
            step = eps / 4 if alpha is None else alpha
            adv = pgd_attack(inputs, labels, model, criterion, eps=eps, alpha=step, steps=steps)
        else:
            adv = fgsm_attack(inputs, labels, model, criterion, eps=eps)
        with torch.no_grad():
            preds = model(adv.to(device)).argmax(1).cpu()
        hits += (preds == labels).sum().item()
        seen += labels.size(0)
    return hits / seen

# Comprehensive evaluation metrics
# Test-set accuracy of each model under three conditions: clean inputs, single-step
# FGSM at EPS, and 20-step PGD at EPS (step size eps/4, eval_adv_acc's default).
clean_e = accuracy(egeat_model, test_loader)
fgsm_e = eval_adv_acc(egeat_model, test_loader, 'fgsm')
pgd_e = eval_adv_acc(egeat_model, test_loader, 'pgd', steps=20)
clean_s = accuracy(egeat_soup, test_loader)
fgsm_s = eval_adv_acc(egeat_soup, test_loader, 'fgsm')
pgd_s = eval_adv_acc(egeat_soup, test_loader, 'pgd', steps=20)
clean_p = accuracy(pgd_model, test_loader)
fgsm_p = eval_adv_acc(pgd_model, test_loader, 'fgsm')
pgd_p = eval_adv_acc(pgd_model, test_loader, 'pgd', steps=20)

# Display metrics in a professional table format
print("\n" + "="*70)
print("COMPREHENSIVE EVALUATION METRICS")
print("="*70)
print(f"{'Model':<20} {'Clean Acc':<12} {'FGSM Acc':<12} {'PGD-20 Acc':<12}")
print("-"*70)
# Each value is 6 characters plus 6 spaces, matching the 12-wide header columns.
print(f"{'EGEAT Model':<20} {clean_e:.4f}      {fgsm_e:.4f}      {pgd_e:.4f}")
print(f"{'EGEAT Soup':<20} {clean_s:.4f}      {fgsm_s:.4f}      {pgd_s:.4f}")
print(f"{'PGD Model':<20} {clean_p:.4f}      {fgsm_p:.4f}      {pgd_p:.4f}")
print("="*70)

# Visual metrics comparison
# Grouped bar chart: one group per model, one bar per evaluation condition.
fig, ax = plt.subplots(figsize=(12, 7))
# NOTE: rebinds the global `models` to tick-label strings. The next cell reassigns it
# to the model objects before using it.
models = ['EGEAT\nModel', 'EGEAT\nSoup', 'PGD\nModel']
x = np.arange(len(models))
width = 0.25  # three bars of this width per unit-spaced group

clean_accs = [clean_e, clean_s, clean_p]
fgsm_accs = [fgsm_e, fgsm_s, fgsm_p]
pgd_accs = [pgd_e, pgd_s, pgd_p]

# Bars are offset by -width / 0 / +width around each group centre.
bars1 = ax.bar(x - width, clean_accs, width, label='Clean', color=PALETTE["primary"], 
               edgecolor=PALETTE["text"], linewidth=1.5, alpha=0.9)
bars2 = ax.bar(x, fgsm_accs, width, label='FGSM', color=PALETTE["accent1"], 
               edgecolor=PALETTE["text"], linewidth=1.5, alpha=0.9)
bars3 = ax.bar(x + width, pgd_accs, width, label='PGD-20', color=PALETTE["accent2"], 
               edgecolor=PALETTE["text"], linewidth=1.5, alpha=0.9)

# Add value labels
# Each label sits just above its bar (va='bottom' at the bar height).
for bars in [bars1, bars2, bars3]:
    for bar in bars:
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
               f'{height:.3f}', ha='center', va='bottom', 
               color=PALETTE["text"], fontsize=10, fontweight='bold')

ax.set_xlabel('Model', fontsize=14, color=PALETTE["text"], fontweight='medium')
ax.set_ylabel('Accuracy', fontsize=14, color=PALETTE["text"], fontweight='medium')
ax.set_title('Model Performance: Clean vs Adversarial Accuracy', 
            fontsize=16, fontweight='bold', color=PALETTE["text"], pad=20)
ax.set_xticks(x)
ax.set_xticklabels(models, color=PALETTE["text"], fontsize=12)
ax.legend(loc='upper right', frameon=True, facecolor=DARK_THEME["axes.facecolor"], 
         edgecolor=PALETTE["primary"], labelcolor=PALETTE["text"], fontsize=12)
# Fixed [0, 1] accuracy range. Labels on bars close to 1.0 may extend above the axes.
ax.set_ylim(0, 1.0)
apply_dark_style(ax)
plt.tight_layout()
savefig("fig_comprehensive_metrics.png", dpi=300)
plt.show()


In [None]:

# Gradient similarity - Enhanced with dark theme
# Input gradient ∇_x CE(m(x), y) of each model on the first test batch.
# The xb / yb bound here are reused by later cells (loss surface, adversarial grid).
loss_fn = nn.CrossEntropyLoss()
models = [egeat_model.eval(), egeat_soup.eval(), pgd_model.eval()]  # .eval() returns the module, so this also switches each model to eval mode
names  = ["EGEAT Model", "EGEAT Soup", "PGD Model"]
xb, yb = next(iter(test_loader)); xb, yb = xb.to(device), yb.to(device)
grads = []
for m in models:
    m.zero_grad(set_to_none=True)
    x_ = xb.detach().clone().requires_grad_(True)  # fresh leaf per model so input grads don't accumulate across models
    loss = loss_fn(m(x_), yb); loss.backward()
    grads.append(x_.grad.detach())

def pairwise_cos(grads):
    """Return the M×M NumPy matrix with entry [r, c] = cos_sim(grads[r], grads[c])."""
    size = len(grads)
    out = torch.zeros(size, size)
    for row, g_row in enumerate(grads):
        for col, g_col in enumerate(grads):
            out[row, col] = cos_sim(g_row, g_col).detach().cpu()
    return out.numpy()

G = pairwise_cos(grads)
# Cosine similarity lies in [-1, 1]. Extend the colour scale below 0 when any entry is
# negative; with a fixed vmin=0 such entries were drawn with the same colour as 0.
# When all entries are >= 0 this reproduces the original [0, 1] scale exactly.
vmin = min(0.0, float(G.min()))
fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(G, interpolation='nearest', cmap='viridis', vmin=vmin, vmax=1)
ax.set_xticks(range(len(names)))
ax.set_xticklabels(names, rotation=45, ha='right', color=PALETTE["text"])
ax.set_yticks(range(len(names)))
ax.set_yticklabels(names, color=PALETTE["text"])
ax.set_title("Gradient Subspace Similarity Matrix", fontsize=16, fontweight='bold', color=PALETTE["text"], pad=20)

# Add text annotations. viridis is dark at the low end, so use white text on the
# lower half of the colour scale and black on the upper half.
mid = 0.5 * (vmin + 1.0)
for i in range(len(names)):
    for j in range(len(names)):
        ax.text(j, i, f'{G[i, j]:.3f}', ha="center", va="center",
                color="white" if G[i, j] < mid else "black", fontweight='bold')

cbar = plt.colorbar(im, ax=ax)
cbar.set_label('Cosine Similarity', color=PALETTE["text"], fontsize=12)
# Colour both tick marks and tick labels; otherwise the labels inherit the rc colour.
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelcolor=PALETTE["text"])
cbar.outline.set_edgecolor(PALETTE["primary"])

apply_dark_style(ax)
plt.tight_layout()
savefig("fig_gradient_similarity.png", dpi=300)
plt.show()


In [None]:

# Loss surface
def get_vec(m):
    """Flatten every parameter of m (detached, in .parameters() order) into one 1-D tensor."""
    return torch.cat([param.detach().view(-1) for param in m.parameters()])

def set_vec(m, v):
    """Copy consecutive slices of the flat vector v into m's parameters, in place."""
    start = 0
    for param in m.parameters():
        stop = start + param.numel()
        param.data.copy_(v[start:stop].view_as(param))
        start = stop

base = copy.deepcopy(egeat_model).eval().to(device)
w = get_vec(base); d1, d2 = torch.randn_like(w), torch.randn_like(w); d1/=d1.norm()+1e-12; d2/=d2.norm()+1e-12
grid=21; alphas=torch.linspace(-0.5,0.5,grid); betas=torch.linspace(-0.5,0.5,grid); Z=np.zeros((grid,grid),dtype=np.float32)

# Loss on a 2-D slice of parameter space: Z[i, j] = CE(θ + α_i·d1 + β_j·d2) on up to 128
# samples of the cached test batch `xb`. d1 and d2 are random directions normalised to
# unit L2 norm over the whole parameter vector (no per-filter normalisation).
# NOTE(review): confirm |α|, |β| <= 0.5 is wide enough relative to ||θ|| to show curvature.
with torch.no_grad():
    xb_s, yb_s = xb[:min(128, xb.size(0))], yb[:min(128, yb.size(0))]
    # One scratch copy is enough. set_vec overwrites every parameter on each call, and
    # eval-mode forwards under no_grad leave BatchNorm buffers untouched. The result
    # matches deep-copying the model at every grid point (21×21 copies).
    tmp = copy.deepcopy(base).to(device)
    ce = nn.CrossEntropyLoss()
    for i,a in enumerate(alphas):
        for j,b in enumerate(betas):
            set_vec(tmp, w + a*d1 + b*d2)
            Z[i,j] = ce(tmp(xb_s), yb_s).item()

# Loss surface - Enhanced with dark theme
# Z is indexed [alpha, beta]; contourf expects Z with shape (len(y), len(x)), hence Z.T.
fig, ax = plt.subplots(figsize=(10, 8))
cs = ax.contourf(alphas.cpu(), betas.cpu(), Z.T, levels=30, cmap='plasma', alpha=0.9)
# Thin white iso-lines drawn on top of the filled contours, labelled with their loss value.
contours = ax.contour(alphas.cpu(), betas.cpu(), Z.T, levels=15, colors='white', linewidths=1.2, alpha=0.4)
ax.clabel(contours, inline=True, fontsize=9, colors='white', fmt='%.2f')
ax.set_xlabel("α (Direction 1)", fontsize=14, color=PALETTE["text"], fontweight='medium')
ax.set_ylabel("β (Direction 2)", fontsize=14, color=PALETTE["text"], fontweight='medium')
ax.set_title("Loss Landscape Around EGEAT Model", fontsize=16, fontweight='bold', color=PALETTE["text"], pad=20)
cbar = plt.colorbar(cs, ax=ax)
cbar.set_label('Cross-Entropy Loss', color=PALETTE["text"], fontsize=12, fontweight='medium')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"])
cbar.outline.set_edgecolor(PALETTE["primary"])
apply_dark_style(ax)
plt.tight_layout()
savefig("fig_loss_surface.png", dpi=300)
plt.show()


In [None]:

# Adversarial example grid
loss_fn = nn.CrossEntropyLoss()
idxs = torch.arange(min(10, xb.size(0)))  # first (up to) 10 images of the cached batch
xv = xb[idxs].detach().cpu()
yv = yb[idxs].detach().cpu()
# FGSM crafted against the PGD-trained model vs. the exact L∞ step against EGEAT.
x_fgsm_pgd = fgsm_attack(xv, yv, pgd_model, loss_fn, eps=EPS).cpu()
x_exact_e = exact_perturbation(xv, yv, egeat_model, loss_fn, eps=EPS, p='linf').cpu()

def show_triplet(orig, a, b, title_a="FGSM (PGD)", title_b="Exact (EGEAT)"):
    """Plot a 3×N grid: clean images (row 0) above two adversarial versions (rows 1-2).

    Args:
        orig, a, b: CPU tensors where ``t[i]`` is a (C, H, W) image. At least
            ``orig.size(0)`` entries are expected in each. Pixel values are clipped
            to [0, 1] for display.
        title_a, title_b: row labels for rows 1 and 2.

    Saves the figure as fig_adversarial_grid.png via savefig() and shows it.
    """
    n = orig.size(0)
    # squeeze=False keeps `axes` 2-D even when n == 1, so axes[row, col] always works.
    fig, axes = plt.subplots(3, n, figsize=(1.5*n, 4.5), squeeze=False)
    fig.patch.set_facecolor(DARK_THEME["figure.facecolor"])

    def _to_rgb(img):
        """(C, H, W) tensor -> (H, W, 3) array clipped to [0, 1]; grayscale is replicated."""
        arr = img.permute(1, 2, 0).squeeze().numpy()
        if arr.ndim == 2:
            arr = np.stack([arr] * 3, axis=-1)
        return np.clip(arr, 0, 1)

    for i in range(n):
        for row, batch in enumerate((orig, a, b)):
            ax = axes[row, i]
            ax.imshow(_to_rgb(batch[i]), vmin=0, vmax=1)
            # Hide ticks and frame but keep the axis itself. ax.axis("off") would also
            # hide the y-labels that carry the row titles set below.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.set_facecolor(DARK_THEME["axes.facecolor"])

    axes[0,0].set_ylabel("Original", fontsize=12, color=PALETTE["text"], fontweight='bold', rotation=0, ha='right', va='center')
    axes[1,0].set_ylabel(title_a, fontsize=12, color=PALETTE["accent1"], fontweight='bold', rotation=0, ha='right', va='center')
    axes[2,0].set_ylabel(title_b, fontsize=12, color=PALETTE["accent2"], fontweight='bold', rotation=0, ha='right', va='center')

    fig.suptitle("Adversarial Examples Comparison", fontsize=16, fontweight='bold', color=PALETTE["text"], y=0.98)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    savefig("fig_adversarial_grid.png", dpi=300)
    plt.show()

# Row 1: FGSM against the PGD model; row 2: exact L∞ step against the EGEAT model.
show_triplet(orig=xv, a=x_fgsm_pgd, b=x_exact_e)


In [None]:

# Transferability
def transfer_rate(src, tgt, loader, eps=EPS):
    """Fraction of samples that `tgt` misclassifies on FGSM examples crafted against `src`.

    Samples that `tgt` already gets wrong on clean inputs also count as fooled.
    """
    criterion = nn.CrossEntropyLoss()
    fooled = total = 0
    for inputs, labels in loader:
        adv = fgsm_attack(inputs, labels, src, criterion, eps=eps)
        with torch.no_grad():
            preds = tgt(adv.to(device)).argmax(1).cpu()
        fooled += (preds != labels).sum().item()
        total += labels.size(0)
    return fooled / total

models = [egeat_model.eval(), egeat_soup.eval(), pgd_model.eval()]  # .eval() also switches each model to eval mode
names  = ["EGEAT model", "EGEAT soup", "PGD model"]
# One direction per unordered pair: source = lower index, target = higher index.
pairs = [(0,1),(0,2),(1,2)]; rates=[]
for i,j in pairs:
    r = transfer_rate(models[i], models[j], test_loader, eps=EPS); rates.append(r); print(f"{names[i]}→{names[j]}: {r:.3f}")

# Transferability - Enhanced with dark theme
fig, ax = plt.subplots(figsize=(10, 6))
xs = np.arange(len(pairs))
bars = ax.bar(xs, rates, color=[COLORS[i % len(COLORS)] for i in range(len(rates))], 
              edgecolor=PALETTE["primary"], linewidth=2, alpha=0.85)
ax.set_xticks(xs)
ax.set_xticklabels([f"{names[i]} → {names[j]}" for i,j in pairs], 
                   rotation=25, ha='right', color=PALETTE["text"], fontsize=12)
ax.set_ylabel("Transfer Rate $P_T$", fontsize=14, color=PALETTE["text"], fontweight='medium')
ax.set_title("Adversarial Transferability Across Models", fontsize=16, fontweight='bold', 
            color=PALETTE["text"], pad=20)
# 15% headroom above the tallest bar for the value labels.
# NOTE(review): if every rate is 0 the limits collapse to (0, 0).
ax.set_ylim(0, max(rates) * 1.15)

# Add value labels on bars
for i, (bar, rate) in enumerate(zip(bars, rates)):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height,
           f'{rate:.3f}', ha='center', va='bottom', 
           color=PALETTE["text"], fontsize=11, fontweight='bold')

apply_dark_style(ax)
plt.tight_layout()
savefig("fig_transfer.png", dpi=300)
plt.show()


In [None]:

# Ablation
import pandas as pd
# (λ₁, λ₂) pairs: no regularisation, geometric only, both terms, stronger geometric.
ABL_SETTINGS = [
    (0.00, 0.00),
    (0.10, 0.00),
    (0.10, 0.05),
    (0.20, 0.05),
]
ABL_EPOCHS = 20  # set to 100 for strict replication

def quick_egeat(lg, ls, epochs=ABL_EPOCHS):
    """Run one EGEAT ablation setting and return (clean acc, PGD-20 acc, entropy proxy).

    Trains with λ₁=lg, λ₂=ls, K=5 snapshots and lr=2e-4, then evaluates the trained
    model (not the soup) on the test set. The third value is the mean softmax entropy
    of the model's test predictions. It is reported as "ECE_proxy" downstream but is
    not Expected Calibration Error.
    """
    cfg = EGEATConfig(epochs=epochs, eps=EPS, lambda_geom=lg, lambda_soup=ls, snapshots_k=5, lr=2e-4)
    net, _soup = train_egeat(cfg)
    clean = accuracy(net, test_loader)
    pgd20 = eval_adv_acc(net, test_loader, 'pgd', eps=EPS, steps=20)
    net.eval()
    entropies = []
    with torch.no_grad():
        for inputs, _labels in test_loader:
            probs = F.softmax(net(inputs.to(device)), dim=1).cpu().numpy()
            entropies.append(-(probs * np.log(probs + 1e-12)).sum(axis=1))
    mean_entropy = float(np.concatenate(entropies).mean())
    return clean, pgd20, mean_entropy

# Train one EGEAT run per (λ₁, λ₂) setting and collect (clean acc, PGD-20 acc, entropy).
rows=[]
for lg, ls in ABL_SETTINGS:
    print(f"[Ablation] λ1={lg} λ2={ls}")
    c,p20,ece = quick_egeat(lg, ls, epochs=ABL_EPOCHS)
    rows.append((lg, ls, c, p20, ece))

# NOTE: "ECE_proxy" holds quick_egeat's mean softmax entropy, not Expected Calibration Error.
df = pd.DataFrame(rows, columns=["lambda1","lambda2","Acc_clean","Acc_PGD20","ECE_proxy"])
print(df)
csv_path = os.path.join(SAVE_DIR, "table_ablation_results.csv"); df.to_csv(csv_path, index=False); print("Saved:", csv_path)

# Ablation study - Enhanced with dark theme
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

# Plot 1: Effect of λ1 on PGD-20 accuracy
# NOTE(review): points are joined in ABL_SETTINGS order. With the current settings two
# rows share λ₁=0.10 (λ₂=0.00 and 0.05), so this line also reflects changes in λ₂.
ax1.plot(df["lambda1"], df["Acc_PGD20"], marker="o", linewidth=3, markersize=10, 
        color=PALETTE["primary"], markerfacecolor=PALETTE["accent1"], 
        markeredgecolor=PALETTE["primary"], markeredgewidth=2)
ax1.set_xlabel("λ₁ (Geometric Regularization)", fontsize=14, color=PALETTE["text"], fontweight='medium')
ax1.set_ylabel("PGD-20 Accuracy", fontsize=14, color=PALETTE["text"], fontweight='medium')
ax1.set_title("Effect of Geometric Regularization", fontsize=15, fontweight='bold', 
             color=PALETTE["text"], pad=15)
ax1.grid(True, alpha=0.3, color=PALETTE["text_secondary"])
apply_dark_style(ax1)

# Plot 2: Clean vs Robust accuracy trade-off
# One point per ablation row, annotated with its (λ₁, λ₂) setting.
ax2.scatter(df["Acc_clean"], df["Acc_PGD20"], s=200, alpha=0.8, 
           c=[COLORS[i % len(COLORS)] for i in range(len(df))],
           edgecolors=PALETTE["primary"], linewidths=2)
for i, row in df.iterrows():
    ax2.annotate(f'λ₁={row["lambda1"]:.2f}\nλ₂={row["lambda2"]:.2f}', 
                (row["Acc_clean"], row["Acc_PGD20"]),
                xytext=(5, 5), textcoords='offset points',
                fontsize=10, color=PALETTE["text_secondary"])
ax2.set_xlabel("Clean Accuracy", fontsize=14, color=PALETTE["text"], fontweight='medium')
ax2.set_ylabel("PGD-20 Accuracy", fontsize=14, color=PALETTE["text"], fontweight='medium')
ax2.set_title("Clean vs Robust Accuracy Trade-off", fontsize=15, fontweight='bold', 
             color=PALETTE["text"], pad=15)
ax2.grid(True, alpha=0.3, color=PALETTE["text_secondary"])
apply_dark_style(ax2)

plt.suptitle("Ablation Study: Hyperparameter Effects", fontsize=17, fontweight='bold', 
            color=PALETTE["text"], y=1.02)
plt.tight_layout()
savefig("fig_ablation.png", dpi=300)
plt.show()



# === BLOG / LINKEDIN VISUALS — Modern Showcase ===

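The following cells generate **publication-quality** and **social-ready** square visuals with a modern dark theme. They are rendered at 300 dpi: a 10-inch figure is ≈3000 px per side, and `bbox_inches="tight"` may crop it slightly off-square. Downscale them to 1080×1080 for social posts.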
They reuse the trained models (`egeat_model`, `egeat_soup`, `pgd_model`) and the dataset/test loader from above.
All outputs are saved to: `os.path.join(SAVE_DIR, "blog_visuals")`.


In [None]:

# Showcase visualizations directory (uses existing dark theme from cell 1)
BLOG_DIR = os.path.join(SAVE_DIR, "blog_visuals")
os.makedirs(BLOG_DIR, exist_ok=True)

def save_square_png(name, fig=None, size=12, dpi=300, bbox_inches="tight"):
    """Resize a figure to ``size``×``size`` inches and save it to BLOG_DIR/<name>.

    Args:
        name: output file name inside BLOG_DIR.
        fig: figure to save; defaults to the current figure (plt.gcf()).
        size: edge length in inches, applied with fig.set_size_inches.
        dpi: output resolution. The canvas is about size*dpi pixels per side
            (defaults: 12 in × 300 dpi ≈ 3600 px).
        bbox_inches: forwarded to Figure.savefig. The default "tight" crops to the
            drawn content, so the saved image is generally not exactly square.
            Pass None for an exact square canvas (e.g. size=3.6, dpi=300 gives
            1080×1080 px).
    """
    if fig is None:
        fig = plt.gcf()
    fig.set_size_inches(size, size)
    path = os.path.join(BLOG_DIR, name)
    fig.savefig(path, bbox_inches=bbox_inches, dpi=dpi,
                facecolor=DARK_THEME["figure.facecolor"],
                edgecolor="none")
    print(f"✓ Saved showcase: {path}")

print(f"✓ Showcase directory ready: {BLOG_DIR}")


In [None]:

# 1) Loss landscape around EGEAT θ: evaluate the loss on a 2-D grid spanned by two random
# unit-norm parameter-space directions and render it as a single filled-contour figure.
import copy, torch, torch.nn as nn

def get_vec(m):
    """Concatenate m's parameters (detached, flattened, in .parameters() order) into one vector."""
    pieces = [t.detach().view(-1) for t in m.parameters()]
    return torch.cat(pieces)

def set_vec(m, v):
    """Load the flat vector v back into m's parameters in place (inverse of get_vec)."""
    cursor = 0
    for t in m.parameters():
        size = t.numel()
        t.data.copy_(v[cursor:cursor + size].view_as(t))
        cursor += size

base = copy.deepcopy(egeat_model).eval().to(device)
w = get_vec(base)
d1, d2 = torch.randn_like(w), torch.randn_like(w)
d1 /= (d1.norm()+1e-12); d2 /= (d2.norm()+1e-12)

# Z[i, j] = CE(θ + α_i·d1 + β_j·d2) on a 9×9 grid over [-0.6, 0.6]².
grid = 9
alphas = torch.linspace(-0.6, 0.6, grid)
betas  = torch.linspace(-0.6, 0.6, grid)
Z = np.zeros((grid, grid), dtype=np.float32)

xb, yb = next(iter(test_loader))
xb, yb = xb.to(device), yb.to(device)

with torch.no_grad():
    xb_s, yb_s = xb[:min(256, xb.size(0))], yb[:min(256, yb.size(0))]
    # One scratch copy is enough. set_vec overwrites every parameter on each call, and
    # eval-mode forwards under no_grad leave BatchNorm buffers untouched. The result
    # matches deep-copying the model at every grid point.
    tmp = copy.deepcopy(base).to(device)
    ce = nn.CrossEntropyLoss()
    for i,a in enumerate(alphas):
        for j,b in enumerate(betas):
            set_vec(tmp, w + a*d1 + b*d2)
            Z[i,j] = ce(tmp(xb_s), yb_s).item()

# Render as a modern contour with bright strokes - Enhanced dark theme
# Z is indexed [alpha, beta]; contourf expects shape (len(y), len(x)), hence Z.T.
fig, ax = plt.subplots(figsize=(10, 10))
cs = ax.contourf(alphas.cpu(), betas.cpu(), Z.T, levels=30, cmap='plasma', alpha=0.95)
contours = ax.contour(alphas.cpu(), betas.cpu(), Z.T, levels=15, colors='white',
                     linewidths=1.5, alpha=0.5)
# clabel() accepts no `fontweight` keyword (passing it raises TypeError), so make the
# returned label Text objects bold instead.
clabel_texts = ax.clabel(contours, inline=True, fontsize=10, colors='white', fmt='%.2f')
for lbl in clabel_texts:
    lbl.set_fontweight('bold')
ax.set_xlabel("α (Direction 1)", fontsize=16, color=PALETTE["text"], fontweight='bold')
ax.set_ylabel("β (Direction 2)", fontsize=16, color=PALETTE["text"], fontweight='bold')
ax.set_title("Loss Landscape Around EGEAT Model", fontsize=18, fontweight='bold',
            color=PALETTE["text"], pad=25)
cbar = plt.colorbar(cs, ax=ax)
cbar.set_label('Cross-Entropy Loss', color=PALETTE["text"], fontsize=14, fontweight='bold')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelsize=12)
cbar.outline.set_edgecolor(PALETTE["primary"])
cbar.outline.set_linewidth(2)
apply_dark_style(ax)
plt.tight_layout()
save_square_png("loss_landscape_showcase.png", fig=fig, size=10)
plt.show()


In [None]:

# 2) Gradient Flow Constellation: PCA of ∇_x ℓ over test mini-batches for multiple models
from sklearn.decomposition import PCA
import torch

loss_fn = nn.CrossEntropyLoss()
models = [egeat_model.eval(), egeat_soup.eval(), pgd_model.eval()]  # .eval() also switches each model to eval mode
names  = ["EGEAT", "Soup", "PGD"]

# Sample multiple mini-batches to build a gradient cloud
# Each row of the cloud is one sample's flattened input gradient ∇_x CE for one model;
# labels_all records which model produced each row.
grads_all = []
labels_all = []
num_batches = 6
it = iter(test_loader)
for b in range(num_batches):
    try:
        xb, yb = next(it)
    except StopIteration:
        break  # loader had fewer than num_batches batches
    xb, yb = xb.to(device), yb.to(device)
    for idx, m in enumerate(models):
        m.zero_grad(set_to_none=True)
        x_ = xb.detach().clone().requires_grad_(True)
        loss = loss_fn(m(x_), yb); loss.backward()
        g = x_.grad.detach().view(x_.size(0), -1).cpu().numpy()
        grads_all.append(g)
        labels_all += [names[idx]] * g.shape[0]

# A single 2-component PCA is fit on the pooled gradients of all models, so every model
# is projected into the same 2-D basis.
G = np.concatenate(grads_all, axis=0)
pca = PCA(n_components=2, random_state=42).fit(G)
P = pca.transform(G)

# Gradient constellation - Enhanced dark theme
# One scatter series per model, coloured from COLORS in `names` order.
# NOTE(review): the title asserts "Decorrelated Subspaces"; the plot itself only shows
# a shared 2-D PCA projection, so treat the title as an interpretation.
fig, ax = plt.subplots(figsize=(10, 10))
for name, color in zip(names, COLORS[:len(names)]):
    mask = np.array(labels_all) == name
    ax.scatter(P[mask,0], P[mask,1], s=25, alpha=0.7, label=name, c=color, 
              edgecolors=PALETTE["primary"], linewidths=0.5)
ax.legend(frameon=True, facecolor=DARK_THEME["axes.facecolor"], 
         edgecolor=PALETTE["primary"], labelcolor=PALETTE["text"], 
         fontsize=14, loc='best', framealpha=0.9)
ax.set_title("Gradient Constellation (PCA) — Decorrelated Subspaces", 
            fontsize=18, fontweight='bold', color=PALETTE["text"], pad=25)
ax.set_xlabel("Principal Component 1", fontsize=16, color=PALETTE["text"], fontweight='bold')
ax.set_ylabel("Principal Component 2", fontsize=16, color=PALETTE["text"], fontweight='bold')
apply_dark_style(ax)
plt.tight_layout()
save_square_png("gradient_constellation.png", fig=fig, size=10)
plt.show()


In [None]:

# 3) Adversarial Transfer Graph: edges weighted by transfer rate P_T
import networkx as nx

def transfer_rate(source_model, target_model, loader, eps=float(EPS)):
    """Return the fraction of samples the target misclassifies under FGSM crafted on the source.

    For every batch of ``loader``, adversarial inputs are generated against
    ``source_model`` with the notebook's ``fgsm_attack`` helper. They are then
    classified by ``target_model`` without gradients.

    The count includes samples that ``target_model`` would misclassify even without
    the attack. When ``source_model is target_model``, the value is that model's
    white-box FGSM error rate.

    Returns:
        fooled / total as a Python float. An empty ``loader`` raises ZeroDivisionError.
    """
    criterion = nn.CrossEntropyLoss()
    n_fooled, n_seen = 0, 0
    for inputs, targets in loader:
        adv_inputs = fgsm_attack(inputs, targets, source_model, criterion, eps=eps)
        with torch.no_grad():
            preds = target_model(adv_inputs.to(device)).argmax(1).cpu()
        n_fooled += (preds != targets).sum().item()
        n_seen += targets.size(0)
    return n_fooled / n_seen

models = [egeat_model.eval(), egeat_soup.eval(), pgd_model.eval()]
names  = ["EGEAT", "Soup", "PGD"]

# Curvature shared by the drawn arcs and by the edge-label placement below.
EDGE_RAD = 0.15

# Directed graph: edge u -> v carries P_T(u -> v), the fraction of test samples
# that model v misclassifies under FGSM crafted against model u.
Gx = nx.DiGraph()
Gx.add_nodes_from(names)

edges = []
for i, si in enumerate(names):
    for j, tj in enumerate(names):
        if i == j:
            continue
        r = transfer_rate(models[i], models[j], test_loader, eps=float(EPS))
        Gx.add_edge(si, tj, weight=r)
        edges.append((si, tj, r))

# Layout & draw - Enhanced dark theme
fig, ax = plt.subplots(figsize=(10, 10))
pos = nx.circular_layout(Gx)
# The width list follows Gx.edges() order, which draw_networkx_edges also uses by default.
weights = [Gx[u][v]['weight'] for u, v in Gx.edges()]
w_scaled = [(1.0 + 8 * w) for w in weights]  # emphasize thickness
nx.draw_networkx_nodes(Gx, pos, node_size=2800, node_color=COLORS[:len(names)],
                       linewidths=3, edgecolors=PALETTE["primary"], ax=ax)
nx.draw_networkx_labels(Gx, pos, font_size=18, font_color=PALETTE["text"],
                        font_weight='bold', ax=ax)
nx.draw_networkx_edges(Gx, pos, width=w_scaled, edge_color=PALETTE["accent1"],
                       arrows=True, arrowsize=25, connectionstyle=f'arc3,rad={EDGE_RAD}',
                       alpha=0.8, ax=ax)

# Annotate edge weights. u->v and v->u share the same chord, so labels placed at
# the chord midpoint land exactly on top of each other. Matplotlib's arc3 style
# puts its Bezier control point at midpoint + rad*(dy, -dx), so the curve apex
# (t = 0.5) is at midpoint + 0.5*rad*(dy, -dx), and each label is placed there.
# arc3 works in display coordinates, so in data coordinates this is approximate,
# but opposite directions always end up on opposite sides of the chord.
for (u, v, r) in edges:
    (x0, y0), (x1, y1) = pos[u], pos[v]
    dx, dy = x1 - x0, y1 - y0
    x = (x0 + x1) / 2 + 0.5 * EDGE_RAD * dy
    y = (y0 + y1) / 2 - 0.5 * EDGE_RAD * dx
    ax.text(x, y, f"{r:.3f}", ha="center", va="center",
            fontsize=13, color=PALETTE["text"], fontweight='bold',
            bbox=dict(boxstyle='round,pad=0.3', facecolor=DARK_THEME["axes.facecolor"],
                      edgecolor=PALETTE["primary"], alpha=0.8))

ax.set_title("Adversarial Transfer Graph\n(Lower Transfer Rate = Better Robustness)",
             fontsize=18, fontweight='bold', color=PALETTE["text"], pad=25)
ax.axis("off")
apply_dark_style(ax)
plt.tight_layout()
save_square_png("transfer_graph.png", fig=fig, size=10)
plt.show()


In [None]:

# 4) Variance vs Ensemble Size K: empirical variance of adversarial loss across K-model soups
# If snapshot list isn't available from training, we quickly generate extra snapshots by light finetuning copies.
import copy, torch

def collect_snapshots(base_model, k=5, steps=100):
    """Fine-tune ``base_model`` in short bursts and return eval-mode snapshots.

    Snapshot 0 is a copy of ``base_model`` as passed in. Each further snapshot is
    taken after another ``steps`` SGD updates that continue from the previous
    snapshot. The updates use lr=1e-3, momentum=0.9 and a clean cross-entropy loss
    on ``train_loader``. ``train_loader`` is restarted whenever it is exhausted.

    Args:
        base_model: model to start from. It is moved to ``device``, switched to
            train mode and fine-tuned IN PLACE. Pass a copy to keep the original weights.
        k: number of snapshots to return. Values < 1 still return the initial copy.
        steps: optimizer steps between consecutive snapshots.

    Returns:
        list of independent deep copies, each on ``device`` and in eval mode.
    """
    base_model.to(device)
    snaps = [copy.deepcopy(base_model).to(device).eval()]
    opt = torch.optim.SGD(base_model.parameters(), lr=1e-3, momentum=0.9)
    loss_fn = nn.CrossEntropyLoss()
    # Fine-tune with training-mode BatchNorm/Dropout regardless of the mode the
    # caller left the model in (e.g. a model previously put in eval mode). Only the
    # returned snapshot copies are switched to eval.
    base_model.train()
    it = iter(train_loader)
    for _ in range(1, k):
        # light finetune from previous snapshot to diversify
        for _ in range(steps):
            try:
                xb, yb = next(it)
            except StopIteration:
                it = iter(train_loader)
                xb, yb = next(it)
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            loss = loss_fn(base_model(xb), yb)
            loss.backward()
            opt.step()
        snaps.append(copy.deepcopy(base_model).to(device).eval())
    return snaps

def soup(models):
    """Return a uniform "model soup": a new model whose weights are the mean of ``models``.

    Every floating-point state-dict entry is averaged. This covers both trainable
    parameters and floating-point buffers such as BatchNorm
    ``running_mean``/``running_var``. Averaging only ``parameters()`` would silently
    keep ``models[0]``'s normalisation statistics, which do not match the averaged
    weights. Non-floating buffers (e.g. BatchNorm's integer
    ``num_batches_tracked``) are taken from ``models[0]``.

    Args:
        models: non-empty sequence of models with identical architecture (same
            state-dict keys and shapes).

    Returns:
        a deep copy of ``models[0]`` on ``device`` holding the averaged state. The
        input models are not modified, and the copy keeps ``models[0]``'s train/eval mode.

    Raises:
        ValueError: if ``models`` is empty.
    """
    models = list(models)
    if not models:
        raise ValueError("soup() requires at least one model")
    base = copy.deepcopy(models[0]).to(device)
    states = [m.state_dict() for m in models]
    with torch.no_grad():
        # state_dict() tensors share storage with base's parameters/buffers, so
        # copy_ writes the average straight into the model.
        for key, target in base.state_dict().items():
            if not target.is_floating_point():
                continue
            total = torch.zeros_like(target)
            for sd in states:
                total.add_(sd[key].to(device=target.device, dtype=target.dtype))
            target.copy_(total.div_(len(states)))
    return base

def adv_loss_on_loader(model, loader, eps=float(EPS)):
    """Return the per-batch adversarial loss of ``model`` over ``loader``.

    Each batch is attacked with the notebook's ``fgsm_attack`` helper against
    ``model`` itself. The batch-mean cross-entropy (``nn.CrossEntropyLoss`` default
    reduction) of ``model`` on the attacked batch is then recorded.

    Returns:
        1-D NumPy array with one entry per batch, so ``np.var`` of the result is
        the across-batch variance of the adversarial loss.
    """
    criterion = nn.CrossEntropyLoss()
    batch_losses = []
    for inputs, targets in loader:
        adv_inputs = fgsm_attack(inputs, targets, model, criterion, eps=eps)
        with torch.no_grad():
            logits = model(adv_inputs.to(device))
            batch_losses.append(criterion(logits, targets.to(device)).item())
    return np.array(batch_losses)

# Snapshots are taken from a throw-away copy of the EGEAT model, so egeat_model
# itself is not fine-tuned.
snaps = collect_snapshots(copy.deepcopy(egeat_model).to(device).train(), k=5, steps=50)
Ks = [1, 2, 3, 4, 5]
# For each K, soup the first K snapshots and record the across-batch variance of
# their FGSM adversarial loss on the test set.
variances = [
    float(np.var(adv_loss_on_loader(soup(snaps[:k]), test_loader, eps=float(EPS))))
    for k in Ks
]

# Variance vs ensemble size - Enhanced dark theme
axis_label_style = dict(fontsize=16, color=PALETTE["text"], fontweight='bold')
fig, ax = plt.subplots(figsize=(10, 8))
ax.plot(Ks, variances, marker="o", linewidth=4, markersize=14,
        color=PALETTE["primary"], markerfacecolor=PALETTE["accent1"],
        markeredgecolor=PALETTE["primary"], markeredgewidth=3, alpha=0.9)
ax.fill_between(Ks, variances, alpha=0.3, color=PALETTE["primary"])
ax.set_xlabel("Ensemble Size K", **axis_label_style)
ax.set_ylabel("Var[ Adversarial Loss ]", **axis_label_style)
ax.set_title("Variance vs Ensemble Size (EGEAT Parameter Soups)",
             fontsize=18, fontweight='bold', color=PALETTE["text"], pad=25)

# Print each variance just above its marker.
for k, var in zip(Ks, variances):
    ax.text(k, var, f'{var:.4f}', ha='center', va='bottom',
            color=PALETTE["text"], fontsize=12, fontweight='bold')

apply_dark_style(ax)
plt.tight_layout()
save_square_png("variance_vs_k.png", fig=fig, size=10)
plt.show()


In [None]:

# 5) 3D λ1–λ2 trade-off surface (clean vs robust accuracy)
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import torch

# 4x4 sweep of the two regulariser strengths.
grid_l1 = np.linspace(0.0, 0.2, 4)    # λ1 ∈ {0.00, 0.067, 0.133, 0.2}
grid_l2 = np.linspace(0.00, 0.08, 4)  # λ2 ∈ {0.00, 0.027, 0.053, 0.08}

# Result matrices: entry [i, j] belongs to (grid_l1[i], grid_l2[j]).
surface_shape = (grid_l1.size, grid_l2.size)
Acc_clean = np.zeros(surface_shape)
Acc_pgd20 = np.zeros(surface_shape)

def train_quick(l1, l2, epochs=15):
    """Run a short EGEAT training at (λ1=l1, λ2=l2) and report its test accuracies.

    The run uses 3 snapshots and lr=2e-4. Only the first model returned by
    ``train_egeat`` is evaluated.

    Returns:
        ``(clean_accuracy, pgd20_accuracy)`` on ``test_loader``.
    """
    config = EGEATConfig(
        epochs=epochs,
        eps=float(EPS),
        lambda_geom=float(l1),
        lambda_soup=float(l2),
        snapshots_k=3,
        lr=2e-4,
    )
    model, _ = train_egeat(config)
    clean_acc = accuracy(model, test_loader)
    robust_acc = eval_adv_acc(model, test_loader, attack='pgd', eps=float(EPS), steps=20)
    return clean_acc, robust_acc

# Fill both accuracy matrices, one short training run per (λ1, λ2) grid point.
for (i, l1), (j, l2) in itertools.product(enumerate(grid_l1), enumerate(grid_l2)):
    print(f"[λ-surface] λ1={l1:.3f}, λ2={l2:.3f}")
    Acc_clean[i, j], Acc_pgd20[i, j] = train_quick(l1, l2, epochs=12)


def _plot_lambda_surface(X, Y, Z, cmap, zlabel, title, filename):
    """Draw, save and show one dark-themed 3-D λ1–λ2 surface; return (fig, ax, surf, cbar)."""
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(X, Y, Z, linewidth=0, antialiased=True,
                           alpha=0.95, cmap=cmap, edgecolor='none')
    axis_label = dict(fontsize=14, color=PALETTE["text"], fontweight='bold', labelpad=10)
    ax.set_xlabel("λ₁ (Geometric)", **axis_label)
    ax.set_ylabel("λ₂ (Soup)", **axis_label)
    ax.set_zlabel(zlabel, **axis_label)
    ax.set_title(title, fontsize=16, fontweight='bold', color=PALETTE["text"], pad=20)
    # Transparent panes with subtle edges so the surface stands out on the dark background.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
        axis.pane.set_edgecolor(PALETTE["text_secondary"])
    ax.tick_params(colors=PALETTE["text_secondary"])
    cbar = fig.colorbar(surf, ax=ax, shrink=0.6)
    cbar.set_label(zlabel, color=PALETTE["text"], fontsize=12, fontweight='bold')
    cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"])
    save_square_png(filename, fig=fig, size=10)
    plt.show()
    return fig, ax, surf, cbar


# 'ij' indexing makes L1[i, j] == grid_l1[i], matching the accuracy matrices.
L1, L2 = np.meshgrid(grid_l1, grid_l2, indexing='ij')

# Clean surface - Enhanced dark theme
fig, ax, surf, cbar = _plot_lambda_surface(
    L1, L2, Acc_clean, "viridis", "Clean Accuracy",
    "Hyperparameter Surface: Clean Accuracy", "lambda_surface_clean.png")

# Robust surface - Enhanced dark theme
fig, ax, surf, cbar = _plot_lambda_surface(
    L1, L2, Acc_pgd20, "plasma", "PGD-20 Accuracy",
    "Hyperparameter Surface: Robust Accuracy (PGD-20)", "lambda_surface_robust.png")


## Summary: All Results and Metrics

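This section collects the evaluation results in a single dashboard. It shows clean, FGSM and PGD-20 accuracy per model, the PGD-20 change relative to the PGD-trained baseline, and the FGSM transferability matrix between models.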


In [None]:
# ===== COMPREHENSIVE RESULTS SUMMARY =====
# Display all metrics in a professional dashboard format

print("\n" + "="*80)
print(" " * 20 + "EGEAT EXPERIMENT SUMMARY")
print("="*80)

# Display name -> model. This order fixes the row/column order of every panel below.
dashboard_models = {
    'EGEAT Model': egeat_model,
    'EGEAT Soup': egeat_soup,
    'PGD Model': pgd_model,
}

# Recompute all metrics for the summary: clean, FGSM and PGD-20 accuracy per model.
all_metrics = {
    name: {
        'Clean': accuracy(model, test_loader),
        'FGSM': eval_adv_acc(model, test_loader, 'fgsm'),
        'PGD-20': eval_adv_acc(model, test_loader, 'pgd', steps=20),
    }
    for name, model in dashboard_models.items()
}

# Print formatted table
print(f"\n{'Model':<20} {'Clean Acc':<15} {'FGSM Acc':<15} {'PGD-20 Acc':<15}")
print("-"*80)
for model_name, metrics in all_metrics.items():
    print(f"{model_name:<20} {metrics['Clean']:<15.4f} {metrics['FGSM']:<15.4f} {metrics['PGD-20']:<15.4f}")
print("="*80)

# Create comprehensive dashboard visualization
fig = plt.figure(figsize=(16, 10))
fig.patch.set_facecolor(DARK_THEME["figure.facecolor"])

# Create a 2x3 grid of subplots
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

# 1. Accuracy comparison (grouped bar chart)
ax1 = fig.add_subplot(gs[0, 0])
models_list = list(all_metrics.keys())
clean_vals = [all_metrics[m]['Clean'] for m in models_list]
fgsm_vals = [all_metrics[m]['FGSM'] for m in models_list]
pgd_vals = [all_metrics[m]['PGD-20'] for m in models_list]
x = np.arange(len(models_list))
width = 0.25
ax1.bar(x - width, clean_vals, width, label='Clean', color=PALETTE["primary"],
        edgecolor=PALETTE["text"], linewidth=1.5, alpha=0.9)
ax1.bar(x, fgsm_vals, width, label='FGSM', color=PALETTE["accent1"],
        edgecolor=PALETTE["text"], linewidth=1.5, alpha=0.9)
ax1.bar(x + width, pgd_vals, width, label='PGD-20', color=PALETTE["accent2"],
        edgecolor=PALETTE["text"], linewidth=1.5, alpha=0.9)
ax1.set_ylabel('Accuracy', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax1.set_title('Accuracy Comparison', fontsize=14, fontweight='bold', color=PALETTE["text"])
ax1.set_xticks(x)
ax1.set_xticklabels([m.replace(' ', '\n') for m in models_list], color=PALETTE["text"], fontsize=10)
ax1.legend(frameon=True, facecolor=DARK_THEME["axes.facecolor"], edgecolor=PALETTE["primary"],
           labelcolor=PALETTE["text"], fontsize=10)
ax1.set_ylim(0, 1.0)
apply_dark_style(ax1)

# 2. Robustness ranking (horizontal bar chart of PGD-20 accuracy)
ax2 = fig.add_subplot(gs[0, 1])
robustness = [all_metrics[m]['PGD-20'] for m in models_list]
ax2.barh(models_list, robustness, color=[COLORS[i % len(COLORS)] for i in range(len(models_list))],
         edgecolor=PALETTE["primary"], linewidth=2, alpha=0.85)
for i, val in enumerate(robustness):
    ax2.text(val + 0.01, i, f'{val:.3f}', va='center', color=PALETTE["text"],
             fontsize=11, fontweight='bold')
ax2.set_xlabel('PGD-20 Accuracy', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax2.set_title('Robustness Ranking', fontsize=14, fontweight='bold', color=PALETTE["text"])
ax2.set_xlim(0, 1.0)
apply_dark_style(ax2)

# 3. PGD-20 change over the PGD-trained baseline. The accuracies are plotted on a
# 0-1 axis above, so (difference) * 100 is a change in percentage points, not a
# relative percentage.
ax3 = fig.add_subplot(gs[0, 2])
baseline_pgd = all_metrics['PGD Model']['PGD-20']
improvements = [
    (all_metrics['EGEAT Model']['PGD-20'] - baseline_pgd) * 100,
    (all_metrics['EGEAT Soup']['PGD-20'] - baseline_pgd) * 100,
]
improvement_names = ['EGEAT\nModel', 'EGEAT\nSoup']
colors_imp = [PALETTE["success"] if imp > 0 else PALETTE["error"] for imp in improvements]
ax3.barh(improvement_names, improvements, color=colors_imp,
         edgecolor=PALETTE["primary"], linewidth=2, alpha=0.85)
for i, imp in enumerate(improvements):
    ax3.text(imp + (0.5 if imp > 0 else -0.5), i, f'{imp:+.2f} pp', va='center',
             ha='left' if imp > 0 else 'right', color=PALETTE["text"],
             fontsize=11, fontweight='bold')
ax3.axvline(0, color=PALETTE["text_secondary"], linestyle='--', linewidth=1.5)
ax3.set_xlabel('Improvement (percentage points)', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax3.set_title('Improvement over PGD Baseline', fontsize=14, fontweight='bold', color=PALETTE["text"])
apply_dark_style(ax3)

# 4. Transferability heatmap: entry [i, j] is the fraction of test samples that
# target j misclassifies under FGSM crafted on source i. The diagonal is each
# model's white-box FGSM error rate. It is measured with the same attack and eps
# as the off-diagonal entries rather than assumed to be 1.0, because a
# single-step attack on an adversarially trained model fools far fewer than
# 100% of inputs.
ax4 = fig.add_subplot(gs[1, :2])
model_objs = [dashboard_models[m] for m in models_list]
transfer_matrix = np.zeros((len(models_list), len(models_list)))
for i, src_model in enumerate(model_objs):
    for j, tgt_model in enumerate(model_objs):
        transfer_matrix[i, j] = transfer_rate(src_model, tgt_model, test_loader, eps=float(EPS))

im = ax4.imshow(transfer_matrix, cmap='RdYlGn_r', vmin=0, vmax=1, aspect='auto')
ax4.set_xticks(range(len(models_list)))
ax4.set_xticklabels([m.replace(' ', '\n') for m in models_list], color=PALETTE["text"], fontsize=10)
ax4.set_yticks(range(len(models_list)))
ax4.set_yticklabels([m.replace(' ', '\n') for m in models_list], color=PALETTE["text"], fontsize=10)
ax4.set_xlabel('Target Model', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax4.set_ylabel('Source Model', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax4.set_title('Adversarial Transferability Matrix (diagonal = white-box FGSM)',
              fontsize=14, fontweight='bold', color=PALETTE["text"])

# Annotate every cell with its rate
for i in range(len(models_list)):
    for j in range(len(models_list)):
        ax4.text(j, i, f'{transfer_matrix[i, j]:.2f}', ha="center", va="center",
                 color="white" if transfer_matrix[i, j] < 0.5 else "black",
                 fontweight='bold', fontsize=10)

cbar = plt.colorbar(im, ax=ax4)
cbar.set_label('Transfer Rate', color=PALETTE["text"], fontsize=11, fontweight='bold')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"])
cbar.outline.set_edgecolor(PALETTE["primary"])
apply_dark_style(ax4)

# 5. Key statistics text box
ax5 = fig.add_subplot(gs[1, 2])
ax5.axis('off')
stats_text = f"""
KEY STATISTICS
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Best Clean Accuracy:
  {max(clean_vals):.4f} ({models_list[np.argmax(clean_vals)]})

Best FGSM Accuracy:
  {max(fgsm_vals):.4f} ({models_list[np.argmax(fgsm_vals)]})

Best PGD-20 Accuracy:
  {max(pgd_vals):.4f} ({models_list[np.argmax(pgd_vals)]})

Average Robustness Gain:
  {np.mean(improvements):+.2f} pp

Dataset: {DATASET}
Epsilon: {EPS:.4f}
"""
ax5.text(0.1, 0.5, stats_text, transform=ax5.transAxes, fontsize=11,
         color=PALETTE["text"], verticalalignment='center',
         family='monospace', fontweight='bold',
         bbox=dict(boxstyle='round', facecolor=DARK_THEME["axes.facecolor"],
                   edgecolor=PALETTE["primary"], linewidth=2, alpha=0.9))

fig.suptitle('EGEAT: Complete Evaluation Dashboard', fontsize=20, fontweight='bold',
             color=PALETTE["text"], y=0.98)
plt.tight_layout(rect=[0, 0, 1, 0.96])
savefig("fig_complete_dashboard.png", dpi=300)
plt.show()

print("\n✓ All visualizations generated successfully!")
print(f"✓ All figures saved to: {SAVE_DIR}")
print(f"✓ Showcase figures saved to: {BLOG_DIR}")
