
# EGEAT — Exact Geometric Ensemble Adversarial Training (Colab-Heavy Reproduction)

**Implements Algorithm 1 exactly** with Eq. (11) (exact inner maximization), Eq. (4) (geometric regularizer), and Eq. (5) (weight‑space smoothing).  
Datasets: **MNIST (ε=0.3)** and **CIFAR‑10 (ε=8/255)**. DCGAN‑inspired CNNs.  
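Paper training defaults (Sec. V‑C): **K=5, λ₁=0.1, λ₂=0.05, batch=128, Adam(2e‑4, β₁=0.5), epochs=100**. Note: `EGEATConfig` below uses a reduced Colab budget (**epochs=20, K=3, lr=3e‑5**); λ₁, λ₂, batch size and β₁ match the paper values.  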
Evaluations: Clean, FGSM, PGD‑20; gradient similarity; loss surface; transferability; ablation.


In [None]:

import os, math, random, time, copy, itertools
from dataclasses import dataclass
import numpy as np
import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
from torchvision import datasets, transforms

# One seed shared by every RNG the notebook uses (Python, NumPy, torch CPU and CUDA).
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Use the GPU when one is visible; models and batches are moved onto `device` below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Colab exposes a writable /content; elsewhere fall back to the working directory.
if os.path.exists("/content"):
    SAVE_DIR = "/content"
else:
    SAVE_DIR = "."
os.makedirs(SAVE_DIR, exist_ok=True)
print("Device:", device, "| SAVE_DIR:", SAVE_DIR)

# ===== DARK LINKEDIN-STYLE THEME =====
# High-contrast dark theme for professional presentations
DARK_THEME = {
    "figure.facecolor": "#0A0E27",      # Deep dark blue-black background
    "axes.facecolor": "#0F1629",        # Slightly lighter for axes
    "axes.edgecolor": "#4A90E2",        # LinkedIn blue for edges
    "axes.labelcolor": "#E8F0FE",       # Light blue-white for labels
    "xtick.color": "#B8C5D6",          # Soft gray-blue for ticks
    "ytick.color": "#B8C5D6",
    "text.color": "#E8F0FE",            # Light text
    "font.size": 14,
    "font.weight": "medium",
    "axes.linewidth": 1.5,
    "axes.grid": True,
    "grid.color": "#1A2332",
    "grid.alpha": 0.3,
    "grid.linewidth": 0.8,
    "figure.dpi": 150,
    "savefig.dpi": 300,
    "savefig.facecolor": "#0A0E27",
    "savefig.edgecolor": "none",
}

# Apply seaborn's base style FIRST. sns.set_style() rewrites many rcParams itself
# (for "darkgrid": text/label/tick colours, figure facecolor, grid and edge colours).
# Calling it after rcParams.update() would reset most of DARK_THEME to seaborn's
# light values. Updating with DARK_THEME afterwards lets the dark values win.
sns.set_style("darkgrid", {"axes.facecolor": "#0F1629"})
mpl.rcParams.update(DARK_THEME)

# Professional color palette (high contrast). Note that "success" and "error" are a
# green/red pair, so do not rely on those two alone to convey meaning.
PALETTE = {
    "primary": "#0077B5",      # LinkedIn blue
    "secondary": "#00A0DC",    # Light LinkedIn blue
    "accent1": "#FF6B35",       # Coral orange
    "accent2": "#4ECDC4",       # Turquoise
    "accent3": "#FFE66D",       # Yellow
    "accent4": "#A8E6CF",       # Mint green
    "success": "#28A745",       # Green
    "warning": "#FFC107",       # Amber
    "error": "#DC3545",         # Red
    "text": "#E8F0FE",          # Light text
    "text_secondary": "#B8C5D6", # Secondary text
}

# Default series colours, in plotting order.
COLORS = [PALETTE["primary"], PALETTE["accent1"], PALETTE["accent2"],
          PALETTE["accent3"], PALETTE["accent4"], PALETTE["secondary"]]

def savefig(name, dpi=300, bbox_inches="tight"):
    """Write the current matplotlib figure to ``SAVE_DIR/name`` on the dark background.

    Args:
        name: output file name, joined onto SAVE_DIR.
        dpi: output resolution.
        bbox_inches: forwarded to ``plt.savefig``.
    """
    target = os.path.join(SAVE_DIR, name)
    save_kwargs = dict(
        dpi=dpi,
        bbox_inches=bbox_inches,
        facecolor=DARK_THEME["figure.facecolor"],
        edgecolor="none",
    )
    plt.savefig(target, **save_kwargs)
    print(f"✓ Saved: {target}")

def apply_dark_style(ax=None):
    """Recolour an Axes (default: the current one) to match the dark theme.

    Sets the axes background, all four spine colours, tick colours, and the
    x-label / y-label / title text colours. Returns the styled Axes.
    """
    target = plt.gca() if ax is None else ax
    target.set_facecolor(DARK_THEME["axes.facecolor"])
    for side in ("bottom", "top", "right", "left"):
        target.spines[side].set_color(PALETTE["primary"])
    target.tick_params(colors=PALETTE["text_secondary"])
    for text_artist in (target.xaxis.label, target.yaxis.label, target.title):
        text_artist.set_color(PALETTE["text"])
    return target


print("✓ Dark LinkedIn-style theme initialized")


In [None]:

from dataclasses import dataclass

@dataclass
class DataConfig:
    """Dataset / loader settings.

    NOTE(review): not referenced by the cells visible here; get_loaders() takes
    name and batch_size as plain arguments instead.
    """
    name: str = "CIFAR10"  # "MNIST" or "CIFAR10"
    batch_size: int = 128  # mini-batch size for the loaders

def get_loaders(name="CIFAR10", batch_size=128):
    """Build train/val/test DataLoaders plus dataset metadata.

    The official training split is divided 80/20 into train/val using a
    torch.Generator seeded with the global ``seed``, so the split is reproducible.
    The validation subset uses the *test-time* transform (no random crop/flip), so
    validation accuracy is measured on un-augmented images.

    Args:
        name: "MNIST" or "CIFAR10" (case-insensitive).
        batch_size: batch size used by all three loaders.

    Returns:
        (train_loader, val_loader, test_loader, num_classes, in_ch, img_sz, eps),
        where eps is the L-inf attack budget for the dataset (0.3 or 8/255).

    Raises:
        ValueError: if ``name`` is not a supported dataset.
    """
    key = name.upper()
    if key == "MNIST":
        eps = 0.3
        tf = transforms.ToTensor()
        train = datasets.MNIST(root=SAVE_DIR, train=True, download=True, transform=tf)
        # MNIST has no augmentation, so the same dataset object serves validation.
        train_eval = train
        test = datasets.MNIST(root=SAVE_DIR, train=False, download=True, transform=tf)
        num_classes, in_ch, img_sz = 10, 1, 28
    elif key == "CIFAR10":
        eps = 8/255
        tf_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])
        tf_test = transforms.ToTensor()
        train = datasets.CIFAR10(root=SAVE_DIR, train=True, download=True, transform=tf_train)
        # Second view of the same training images WITHOUT augmentation, for validation.
        train_eval = datasets.CIFAR10(root=SAVE_DIR, train=True, download=True, transform=tf_test)
        test = datasets.CIFAR10(root=SAVE_DIR, train=False, download=True, transform=tf_test)
        num_classes, in_ch, img_sz = 10, 3, 32
    else:
        raise ValueError("Unknown dataset")

    val_len = int(0.2 * len(train))
    train_len = len(train) - val_len
    gen = torch.Generator().manual_seed(seed)
    train_ds, val_split = random_split(train, [train_len, val_len], generator=gen)
    # Same held-out indices, but drawn from the un-augmented view of the training set.
    val_ds = torch.utils.data.Subset(train_eval, val_split.indices)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    test_loader  = DataLoader(test,     batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, val_loader, test_loader, num_classes, in_ch, img_sz, eps

# Dataset used by every experiment below ("MNIST" or "CIFAR10").
DATASET = "CIFAR10"
(train_loader, val_loader, test_loader,
 NUM_CLASSES, IN_CH, IMG_SZ, EPS) = get_loaders(DATASET, batch_size=128)
print(f"Dataset: {DATASET} | eps(Linf)={EPS} | classes={NUM_CLASSES} | in_ch={IN_CH} | img={IMG_SZ}x{IMG_SZ}")


In [None]:

def conv_block(in_c, out_c, k=3, s=1, p=1):
    """Conv2d (no bias) -> BatchNorm2d -> LeakyReLU(0.2), the unit both CNNs stack.

    The conv bias is omitted because the BatchNorm that follows has its own shift.
    """
    layers = [
        nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p, bias=False),
        nn.BatchNorm2d(out_c),
        nn.LeakyReLU(0.2, inplace=True),
    ]
    return nn.Sequential(*layers)

class CNN_MNIST(nn.Module):
    """Three-stage conv net used for the MNIST setting.

    Two stride-2 stages downsample. Global average pooling then reduces to 128
    features for a linear classifier; the adaptive pool accepts any input size.
    """

    def __init__(self, in_ch=1, num_classes=10):
        super().__init__()
        stages = [
            conv_block(in_ch, 32, 3, 1, 1),
            conv_block(32, 64, 3, 2, 1),
            conv_block(64, 128, 3, 2, 1),
        ]
        self.f = nn.Sequential(*stages, nn.AdaptiveAvgPool2d((1, 1)))
        self.h = nn.Linear(128, num_classes)

    def forward(self, x):
        features = self.f(x)
        return self.h(features.flatten(1))

class CNN_CIFAR10(nn.Module):
    """Six-stage conv net used for the CIFAR-10 setting.

    Widths 64-64-128-128-256-256; every second stage has stride 2. Global average
    pooling feeds a linear classifier over 256 features.
    """

    def __init__(self, in_ch=3, num_classes=10):
        super().__init__()
        # (in_channels, out_channels, stride); every stage uses a 3x3 kernel, padding 1.
        plan = [
            (in_ch, 64, 1),
            (64, 64, 2),
            (64, 128, 1),
            (128, 128, 2),
            (128, 256, 1),
            (256, 256, 2),
        ]
        stages = [conv_block(c_in, c_out, 3, stride, 1) for c_in, c_out, stride in plan]
        self.f = nn.Sequential(*stages, nn.AdaptiveAvgPool2d((1, 1)))
        self.h = nn.Linear(256, num_classes)

    def forward(self, x):
        features = self.f(x)
        return self.h(features.flatten(1))

def make_model(dataset):
    """Instantiate the CNN for `dataset` on `device`.

    "MNIST" (any case) selects CNN_MNIST; any other name falls back to CNN_CIFAR10.
    Input channels and class count come from the globals IN_CH / NUM_CLASSES.
    """
    arch = CNN_MNIST if dataset.upper() == "MNIST" else CNN_CIFAR10
    net = arch(IN_CH, NUM_CLASSES)
    return net.to(device)

def accuracy(model, loader):
    """Top-1 accuracy of `model` over every (inputs, labels) batch in `loader`.

    Switches the model to eval mode and leaves it there. Raises ZeroDivisionError
    if the loader yields no samples.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(1)
            correct += (predictions == targets).sum().item()
            seen += targets.size(0)
    return correct / seen


In [None]:

def exact_perturbation(x, y, model, loss_fn, eps=0.3, p='linf'):
    """Closed-form one-step maximizer of the linearized loss inside an Lp ball.

    p='linf': x + eps * sign(grad), i.e. FGSM.
    p='l2':   x + eps * grad / ||grad||_2, with the norm taken per example.
    The result is clamped to [0, 1] and detached.

    Only the input gradient is computed (torch.autograd.grad), so the model's
    parameter ``.grad`` fields are neither cleared nor overwritten.

    Args:
        x, y: input batch and labels; both are moved to `device`.
        model: network to attack, used in whatever train/eval mode it is in.
        loss_fn: criterion called as loss_fn(logits, y).
        eps: radius of the perturbation ball.
        p: 'linf' or 'l2'.

    Returns:
        Detached adversarial batch on `device`, same shape as x.

    Raises:
        ValueError: if p is not 'linf' or 'l2'.
    """
    if p not in ('linf', 'l2'):
        raise ValueError("Use p in {linf,l2}")
    x = x.detach().clone().to(device).requires_grad_(True)
    loss = loss_fn(model(x), y.to(device))
    (g,) = torch.autograd.grad(loss, x)
    if p == 'linf':
        delta = eps * g.sign()
    else:
        g_flat = g.reshape(g.size(0), -1)
        nrm = g_flat.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
        delta = (eps * (g_flat / nrm)).view_as(g)
    return (x.detach() + delta).clamp(0, 1)

def fgsm_attack(x, y, model, loss_fn, eps):
    """Fast Gradient Sign Method: a single L-inf sign step of size `eps`.

    Delegates to exact_perturbation with p='linf'.
    """
    return exact_perturbation(x, y, model, loss_fn, eps=eps, p='linf')

def pgd_attack(x, y, model, loss_fn, eps, alpha, steps):
    """L-inf projected gradient descent (ascent on the loss) with a uniform random start.

    Starts at x + U(-eps, eps) clamped to [0, 1]. Each step moves by alpha * sign(grad),
    projects back into the eps-ball around x, and clamps to [0, 1].

    Only input gradients are computed (torch.autograd.grad), so the model's
    parameter ``.grad`` fields are left untouched.

    Args:
        x, y: clean batch and labels; both are moved to `device`.
        model: network to attack, used in whatever train/eval mode it is in.
        loss_fn: criterion called as loss_fn(logits, y).
        eps: L-inf radius.
        alpha: step size.
        steps: number of iterations (0 returns the random start).

    Returns:
        Detached adversarial batch on `device`.
    """
    x0 = x.detach().clone().to(device)
    target = y.to(device)  # moved once, not on every step
    x_adv = (x0 + torch.empty_like(x0).uniform_(-eps, eps)).clamp(0, 1).detach()
    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = loss_fn(model(x_adv), target)
        (g,) = torch.autograd.grad(loss, x_adv)
        stepped = x_adv.detach() + alpha * g.sign()
        eta = torch.clamp(stepped - x0, -eps, eps)
        x_adv = (x0 + eta).clamp(0, 1).detach()
    return x_adv


In [None]:

def batch_input_grads(models, x, y, loss_fn, create_graph=False):
    """Input-space loss gradients d loss / d x, one tensor per model.

    Each model is evaluated in eval mode (BatchNorm uses running statistics), and
    its previous train/eval mode is restored afterwards. Calling this on a model
    that is mid-training therefore does not silently leave it in eval mode.
    Parameter ``.grad`` fields are not touched (torch.autograd.grad is used).

    Args:
        models: iterable of nn.Module.
        x, y: input batch and labels; both are moved to `device`.
        loss_fn: criterion called as loss_fn(logits, y).
        create_graph: if True, the returned gradients stay attached to the autograd
            graph, so a loss built from them can be backpropagated into the model
            parameters (double backprop). The default False returns detached tensors.

    Returns:
        List of tensors shaped like x, in the order of `models`.
    """
    x_dev = x.detach().to(device)
    y_dev = y.to(device)
    grads = []
    for m in models:
        was_training = m.training
        m.eval()
        try:
            x_ = x_dev.clone().requires_grad_(True)
            loss = loss_fn(m(x_), y_dev)
            (g,) = torch.autograd.grad(loss, x_, create_graph=create_graph)
        finally:
            m.train(was_training)
        grads.append(g if create_graph else g.detach())
    return grads

def cos_sim(a, b, eps=1e-12):
    """Batch mean of per-example cosine similarity between `a` and `b`.

    Both tensors are flattened to (batch, -1). `eps` is added to the product of
    the two norms so that zero vectors do not divide by zero.
    """
    a_flat = a.view(a.size(0), -1)
    b_flat = b.view(b.size(0), -1)
    dots = (a_flat * b_flat).sum(dim=1)
    norms = a_flat.norm(p=2, dim=1) * b_flat.norm(p=2, dim=1)
    return (dots / (norms + eps)).mean()

def geometric_regularizer(models, x, y, loss_fn):
    """Mean pairwise cosine similarity of input gradients across `models`.

    This is the geometric penalty added to the EGEAT objective. Gradients are
    computed with ``create_graph=True`` so the returned scalar is differentiable
    w.r.t. the parameters of every model whose parameters require grad. With
    detached gradients the penalty would be a constant and have no effect on
    training. Each model is evaluated in eval mode, and its previous train/eval
    mode is restored afterwards.

    Args:
        models: list of nn.Module.
        x, y: input batch and labels; both are moved to `device`.
        loss_fn: criterion called as loss_fn(logits, y).

    Returns:
        0-dim tensor; 0.0 when fewer than two models are given.
    """
    if len(models) < 2:
        return torch.tensor(0.0, device=device)
    x_dev = x.detach().to(device)
    y_dev = y.to(device)
    grads = []
    for m in models:
        was_training = m.training
        m.eval()
        try:
            x_ = x_dev.clone().requires_grad_(True)
            loss = loss_fn(m(x_), y_dev)
            # create_graph=True keeps d(loss)/dx differentiable w.r.t. the parameters.
            (g,) = torch.autograd.grad(loss, x_, create_graph=True)
        finally:
            m.train(was_training)
        grads.append(g)
    sims = [cos_sim(grads[i], grads[j])
            for i, j in itertools.combinations(range(len(grads)), 2)]
    return torch.stack(sims).mean()

def update_soup(snapshots):
    """Uniform weight-space average ("soup") of the snapshot models.

    Returns a new module: a deep copy of snapshots[0], moved to `device`. Its
    floating-point parameters AND floating-point buffers (e.g. BatchNorm
    running_mean / running_var) are set to the arithmetic mean over all snapshots.
    Averaging the buffers keeps the BatchNorm statistics consistent with the
    averaged weights instead of inheriting them from the first snapshot alone.
    Integer buffers (such as num_batches_tracked) are taken from snapshots[0].

    Returns:
        The averaged module, or None if `snapshots` is empty.
    """
    if not snapshots:
        return None
    soup = copy.deepcopy(snapshots[0]).to(device)
    states = [s.state_dict() for s in snapshots]
    averaged = {}
    for key, ref in soup.state_dict().items():
        if ref.is_floating_point():
            total = torch.zeros_like(ref)
            for st in states:
                total.add_(st[key].to(ref.device))
            averaged[key] = total.div_(len(states))
        else:
            averaged[key] = ref
    soup.load_state_dict(averaged)
    return soup


In [None]:

from dataclasses import dataclass

@dataclass
class EGEATConfig:
    """Hyper-parameters for train_egeat().

    NOTE(review): these defaults are a reduced budget compared with the paper
    defaults quoted in the notebook header (epochs=100, K=5, Adam lr=2e-4).
    The lambda values and beta1 match the header.
    """
    epochs: int = 20  # number of training epochs
    eps: float = float(EPS)  # L-inf budget for the inner maximization; EPS is read when the class is defined
    lambda_geom: float = 0.1  # weight of the gradient-similarity (geometric) term
    lambda_soup: float = 0.05  # weight of the squared-L2 pull toward the soup weights
    snapshots_k: int = 3  # K: controls how many weight snapshots train_egeat() collects for the soup
    lr: float = 3e-5  # Adam learning rate
    beta1: float = 0.5  # Adam beta1

def train_egeat(cfg: EGEATConfig):
    """Train one CNN with EGEAT and return (model, soup, history).

    Per batch the objective is
        L_adv + lambda_geom * L_geom + lambda_soup * L_soup
    where L_adv is cross-entropy on exact_perturbation() examples. L_geom is the
    gradient-similarity penalty between the live model and the current soup
    (0 before the first snapshot). L_soup is the squared L2 distance between the
    live weights and the frozen soup weights (0 before the first snapshot).

    Exactly ``cfg.snapshots_k`` snapshots are taken, at evenly spaced epochs that
    end with the final epoch (fewer if epochs < snapshots_k). The soup is their
    uniform average (update_soup).

    Returns:
        model: the final trained network.
        final_soup: average of all snapshots (a copy of the model if none were taken).
        history: per-epoch lists under keys 'epoch', 'loss', 'val_acc',
            'adv_loss', 'geom_loss', 'soup_loss' (losses are per-sample means).
    """
    model = make_model(DATASET)
    opt = torch.optim.Adam(model.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
    loss_fn = nn.CrossEntropyLoss()
    snapshots = []
    soup = None

    # Track training history
    history = {key: [] for key in ('epoch', 'loss', 'val_acc', 'adv_loss', 'geom_loss', 'soup_loss')}

    # Evenly spaced snapshot epochs that always include the last one. The old rule
    # `epoch % (epochs // K) == 0 or epoch == epochs` could collect K+1 snapshots
    # (epochs=20, K=3 -> 6, 12, 18, 20).
    k = max(1, cfg.snapshots_k)
    snap_epochs = {max(1, round(cfg.epochs * (i + 1) / k)) for i in range(k)}

    for epoch in range(1, cfg.epochs + 1):
        model.train()
        tot = adv_tot = geom_tot = soup_tot = 0.0
        n = 0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            xb_adv = exact_perturbation(xb, yb, model, loss_fn, eps=cfg.eps, p='linf')
            L_adv = loss_fn(model(xb_adv), yb)
            models_for_geom = [model] + ([soup] if soup is not None else [])
            L_geom = geometric_regularizer(models_for_geom, xb, yb, loss_fn)
            # The regularizer evaluates models in eval mode. Make sure the live model
            # is back in train mode (batch-stat BatchNorm) for the rest of training.
            model.train()
            if soup is not None:
                # The soup is a fixed target; detach so no gradient accumulates on it.
                L_soup = sum((p - ps.detach()).pow(2).sum()
                             for p, ps in zip(model.parameters(), soup.parameters()))
            else:
                L_soup = torch.tensor(0.0, device=device)
            loss = L_adv + cfg.lambda_geom * L_geom + cfg.lambda_soup * L_soup
            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()
            bs = xb.size(0)
            tot += loss.item() * bs
            n += bs
            adv_tot += L_adv.item() * bs
            geom_tot += float(L_geom) * bs
            soup_tot += float(L_soup) * bs
        if epoch in snap_epochs:
            snapshots.append(copy.deepcopy(model).to(device))
            soup = update_soup(snapshots)
            soup.requires_grad_(False)  # used only as a target, never optimized
        val_acc = accuracy(model, val_loader)
        history['epoch'].append(epoch)
        history['loss'].append(tot/max(1,n))
        history['val_acc'].append(val_acc)
        history['adv_loss'].append(adv_tot/max(1,n))
        history['geom_loss'].append(geom_tot/max(1,n))
        history['soup_loss'].append(soup_tot/max(1,n))
        print(f"[EGEAT] {epoch}/{cfg.epochs} loss={tot/max(1,n):.4f} val_acc={val_acc:.3f} snaps={len(snapshots)}")
    final_soup = update_soup(snapshots) if snapshots else copy.deepcopy(model)
    return model, final_soup, history

# Train with an explicit config object so the loss-breakdown plot below reuses the
# exact lambda weights that training used.
egeat_cfg = EGEATConfig()
egeat_model, egeat_soup, egeat_history = train_egeat(egeat_cfg)
print("Validation (EGEAT model):", accuracy(egeat_model, val_loader))
print("Validation (EGEAT soup): ", accuracy(egeat_soup,  val_loader))

# Visualize training progress
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.patch.set_facecolor(DARK_THEME["figure.facecolor"])

# Loss curves
ax1 = axes[0, 0]
ax1.plot(egeat_history['epoch'], egeat_history['loss'], 'o-', label='Total Loss', 
         color=PALETTE["primary"], linewidth=2, markersize=6)
ax1.plot(egeat_history['epoch'], egeat_history['adv_loss'], 's-', label='Adversarial Loss', 
         color=PALETTE["accent1"], linewidth=2, markersize=5)
ax1.set_xlabel('Epoch', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax1.set_ylabel('Loss', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax1.set_title('Training Losses', fontsize=14, fontweight='bold', color=PALETTE["text"])
ax1.legend(frameon=True, facecolor=DARK_THEME["axes.facecolor"], edgecolor=PALETTE["primary"], 
          labelcolor=PALETTE["text"], fontsize=10)
apply_dark_style(ax1)

# Regularization losses
ax2 = axes[0, 1]
ax2.plot(egeat_history['epoch'], egeat_history['geom_loss'], '^-', label='Geometric Loss', 
         color=PALETTE["accent2"], linewidth=2, markersize=6)
ax2.plot(egeat_history['epoch'], egeat_history['soup_loss'], 'v-', label='Soup Loss', 
         color=PALETTE["accent3"], linewidth=2, markersize=5)
ax2.set_xlabel('Epoch', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax2.set_ylabel('Loss', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax2.set_title('Regularization Losses', fontsize=14, fontweight='bold', color=PALETTE["text"])
ax2.legend(frameon=True, facecolor=DARK_THEME["axes.facecolor"], edgecolor=PALETTE["primary"], 
          labelcolor=PALETTE["text"], fontsize=10)
apply_dark_style(ax2)

# Validation accuracy
ax3 = axes[1, 0]
ax3.plot(egeat_history['epoch'], egeat_history['val_acc'], 'o-', 
         color=PALETTE["success"], linewidth=3, markersize=7, markerfacecolor=PALETTE["accent4"],
         markeredgecolor=PALETTE["success"], markeredgewidth=2)
ax3.fill_between(egeat_history['epoch'], egeat_history['val_acc'], alpha=0.3, color=PALETTE["success"])
ax3.set_xlabel('Epoch', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax3.set_ylabel('Validation Accuracy', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax3.set_title('Validation Accuracy Progress', fontsize=14, fontweight='bold', color=PALETTE["text"])
ax3.set_ylim(0, 1.0)
apply_dark_style(ax3)

# Loss components breakdown (each regularizer scaled by its lambda from the training config)
ax4 = axes[1, 1]
epochs = egeat_history['epoch']
width = 0.6
x = np.array(epochs)
bottom = np.zeros(len(epochs))
cfg = egeat_cfg  # the config actually used for training, not a fresh default
components = [
    (np.array(egeat_history['adv_loss']), 'Adversarial', PALETTE["accent1"]),
    (np.array(egeat_history['geom_loss']) * cfg.lambda_geom, 'Geometric', PALETTE["accent2"]),
    (np.array(egeat_history['soup_loss']) * cfg.lambda_soup, 'Soup', PALETTE["accent3"])
]
for values, label, color in components:
    ax4.bar(x, values, width, bottom=bottom, label=label, color=color, alpha=0.8, 
           edgecolor=PALETTE["primary"], linewidth=1)
    bottom += values
ax4.set_xlabel('Epoch', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax4.set_ylabel('Loss Components', fontsize=12, color=PALETTE["text"], fontweight='bold')
ax4.set_title('Loss Components Breakdown', fontsize=14, fontweight='bold', color=PALETTE["text"])
ax4.legend(frameon=True, facecolor=DARK_THEME["axes.facecolor"], edgecolor=PALETTE["primary"], 
          labelcolor=PALETTE["text"], fontsize=10)
apply_dark_style(ax4)

plt.suptitle('EGEAT Training Progress', fontsize=18, fontweight='bold', color=PALETTE["text"], y=0.995)
plt.tight_layout(rect=[0, 0, 1, 0.97])
savefig("fig_training_progress.png", dpi=300)
plt.show()


In [None]:

@dataclass
class PGDCfg:
    """Hyper-parameters for the PGD adversarial-training baseline (train_pgd)."""
    epochs: int = 20  # number of training epochs
    eps: float = float(EPS)  # L-inf radius; EPS is read when the class is defined
    alpha: float = float(EPS)/4  # PGD step size
    steps: int = 20  # PGD iterations per training batch
    lr: float = 3e-5  # Adam learning rate
    beta1: float = 0.5  # Adam beta1

def train_pgd(cfg: PGDCfg):
    """PGD adversarial-training baseline; returns (model, history).

    Every training batch is replaced by its pgd_attack() version (cfg.steps steps
    of size cfg.alpha in an L-inf ball of radius cfg.eps) before the Adam update.
    history holds per-epoch 'epoch', mean training 'loss' and 'val_acc'.
    """
    net = make_model(DATASET)
    optimizer = torch.optim.Adam(net.parameters(), lr=cfg.lr, betas=(cfg.beta1, 0.999))
    criterion = nn.CrossEntropyLoss()
    history = {'epoch': [], 'loss': [], 'val_acc': []}

    for epoch in range(1, cfg.epochs + 1):
        net.train()
        running, count = 0, 0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            adv_inputs = pgd_attack(inputs, targets, net, criterion,
                                    eps=cfg.eps, alpha=cfg.alpha, steps=cfg.steps)
            batch_loss = criterion(net(adv_inputs), targets)
            optimizer.zero_grad(set_to_none=True)
            batch_loss.backward()
            optimizer.step()
            running += batch_loss.item() * inputs.size(0)
            count += inputs.size(0)
        mean_loss = running / max(1, count)
        val_acc = accuracy(net, val_loader)
        history['epoch'].append(epoch)
        history['loss'].append(mean_loss)
        history['val_acc'].append(val_acc)
        print(f"[PGD] {epoch}/{cfg.epochs} loss={mean_loss:.4f} val_acc={val_acc:.3f}")
    return net, history

pgd_model, pgd_history = train_pgd(PGDCfg())
print("Validation (PGD):", accuracy(pgd_model, val_loader))

# Two panels: training loss (left) and validation accuracy (right).
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
fig.patch.set_facecolor(DARK_THEME["figure.facecolor"])
ax1, ax2 = axes

panel_specs = [
    # (axis, history key, line/fill colour, marker face colour, y-label, title, y-limits)
    (ax1, 'loss', PALETTE["primary"], PALETTE["accent1"],
     'Training Loss', 'PGD Training Loss', None),
    (ax2, 'val_acc', PALETTE["success"], PALETTE["accent4"],
     'Validation Accuracy', 'PGD Validation Accuracy', (0, 1.0)),
]
for panel, key, line_color, face_color, ylabel, title, ylim in panel_specs:
    panel.plot(pgd_history['epoch'], pgd_history[key], 'o-',
               color=line_color, linewidth=3, markersize=6,
               markerfacecolor=face_color, markeredgecolor=line_color, markeredgewidth=2)
    panel.fill_between(pgd_history['epoch'], pgd_history[key], alpha=0.3, color=line_color)
    panel.set_xlabel('Epoch', fontsize=12, color=PALETTE["text"], fontweight='bold')
    panel.set_ylabel(ylabel, fontsize=12, color=PALETTE["text"], fontweight='bold')
    panel.set_title(title, fontsize=14, fontweight='bold', color=PALETTE["text"])
    if ylim is not None:
        panel.set_ylim(*ylim)
    apply_dark_style(panel)

plt.suptitle('PGD Training Progress', fontsize=16, fontweight='bold', color=PALETTE["text"], y=1.02)
plt.tight_layout()
savefig("fig_pgd_training.png", dpi=300)
plt.show()


In [None]:

def eval_adv_acc(model, loader, attack='fgsm', eps=EPS, alpha=None, steps=20):
    """Accuracy of `model` on adversarial versions of every batch in `loader`.

    attack='fgsm' applies fgsm_attack with radius eps. Any other value runs
    pgd_attack with `steps` iterations and step size `alpha` (eps/4 when None).
    The model is switched to eval mode first. Predictions are compared on the CPU.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    correct = total = 0
    step_size = alpha if alpha is not None else eps / 4
    for inputs, labels in loader:
        if attack == 'fgsm':
            adv = fgsm_attack(inputs, labels, model, criterion, eps=eps)
        else:
            adv = pgd_attack(inputs, labels, model, criterion, eps=eps, alpha=step_size, steps=steps)
        with torch.no_grad():
            preds = model(adv.to(device)).argmax(1).cpu()
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total

# Comprehensive evaluation metrics
def _attack_suite(net):
    """(clean, FGSM, PGD-20) test accuracy of `net`, evaluated in that order."""
    return (
        accuracy(net, test_loader),
        eval_adv_acc(net, test_loader, 'fgsm'),
        eval_adv_acc(net, test_loader, 'pgd', steps=20),
    )

clean_e, fgsm_e, pgd_e = _attack_suite(egeat_model)
clean_s, fgsm_s, pgd_s = _attack_suite(egeat_soup)
clean_p, fgsm_p, pgd_p = _attack_suite(pgd_model)

# Display metrics in a professional table format
table_rule = "=" * 70
print("\n" + table_rule)
print("COMPREHENSIVE EVALUATION METRICS")
print(table_rule)
print(f"{'Model':<20} {'Clean Acc':<12} {'FGSM Acc':<12} {'PGD-20 Acc':<12}")
print("-" * 70)
for row_label, acc_clean, acc_fgsm, acc_pgd in (
    ("EGEAT Model", clean_e, fgsm_e, pgd_e),
    ("EGEAT Soup", clean_s, fgsm_s, pgd_s),
    ("PGD Model", clean_p, fgsm_p, pgd_p),
):
    print(f"{row_label:<20} {acc_clean:.4f}      {acc_fgsm:.4f}      {acc_pgd:.4f}")
print(table_rule)

# Publication-quality 3D metrics with smooth, non-overlapping surfaces
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata

fig = plt.figure(figsize=(18, 12))
ax = fig.add_subplot(111, projection='3d')

models = ['EGEAT Model', 'EGEAT Soup', 'PGD Model']
attack_types = ['Clean', 'FGSM', 'PGD-20']
# acc_data[i, j] = accuracy of model i under attack j.
acc_data = np.array([
    [clean_e, fgsm_e, pgd_e],
    [clean_s, fgsm_s, pgd_s],
    [clean_p, fgsm_p, pgd_p]
])

# Coarse grid: x = model index, y = attack index. np.meshgrid is indexed [y, x], hence the .T.
x = np.arange(len(models))
y = np.arange(len(attack_types))
X_coarse, Y_coarse = np.meshgrid(x, y)
Z_coarse = acc_data.T

# Interpolate to a smooth surface. NOTE: only the labelled grid points are measured;
# the surface between them is a visual interpolation.
x_fine = np.linspace(0, len(models)-1, 100)
y_fine = np.linspace(0, len(attack_types)-1, 100)
X_fine, Y_fine = np.meshgrid(x_fine, y_fine)

# Create smooth surface using cubic interpolation
points = np.column_stack([X_coarse.flatten(), Y_coarse.flatten()])
values = Z_coarse.flatten()
Z_fine = griddata(points, values, (X_fine, Y_fine), method='cubic', fill_value=np.nan)

# Fill NaN values with nearest neighbor
mask = np.isnan(Z_fine)
if mask.any():
    Z_fine[mask] = griddata(points, values, (X_fine[mask], Y_fine[mask]), method='nearest')
# Cubic interpolation can overshoot; accuracies live in [0, 1] (matches vmin/vmax and zlim).
Z_fine = np.clip(Z_fine, 0.0, 1.0)

# Create smooth gradient surface (NO overlapping bars)
surf = ax.plot_surface(X_fine, Y_fine, Z_fine, cmap='viridis', alpha=0.95, 
                       edgecolor='none', linewidth=0, antialiased=True, shade=True,
                       rstride=2, cstride=2, vmin=0, vmax=1)

# Add value labels at data points only (no overlap)
for i in range(len(models)):
    for j in range(len(attack_types)):
        ax.text(i, j, acc_data[i, j] + 0.03, f'{acc_data[i, j]:.3f}', 
               ha='center', va='bottom', color=PALETTE["text"], fontsize=11, fontweight='bold',
               bbox=dict(boxstyle='round,pad=0.4', facecolor=DARK_THEME["axes.facecolor"], 
                        edgecolor=PALETTE["primary"], alpha=0.9, linewidth=1.5))

# Professional camera angle
ax.view_init(elev=30, azim=45)
ax.set_xlabel('Model', fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_ylabel('Attack Type', fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_zlabel('Accuracy', fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_title('Model Performance: Clean vs Adversarial Accuracy', 
            fontsize=18, fontweight='bold', color=PALETTE["text"], pad=25)
ax.set_xticks(x)
ax.set_xticklabels(models, color=PALETTE["text"], fontsize=12, rotation=12)
ax.set_yticks(y)
ax.set_yticklabels(attack_types, color=PALETTE["text"], fontsize=12)
ax.set_zlim(0, 1.0)

# Style 3D axes
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
ax.xaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.yaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.zaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.tick_params(colors=PALETTE["text_secondary"], labelsize=11)
ax.grid(True, alpha=0.2, linestyle='--')

cbar = fig.colorbar(surf, ax=ax, shrink=0.7, pad=0.12)
cbar.set_label('Accuracy', color=PALETTE["text"], fontsize=13, fontweight='bold')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelsize=11)

plt.tight_layout()
savefig("fig_comprehensive_metrics_3d.png", dpi=300)
plt.show()


In [None]:

# Gradient similarity - Enhanced with dark theme
# Input gradients of the three trained networks on the first test batch
# (test_loader is not shuffled).
loss_fn = nn.CrossEntropyLoss()
models = [egeat_model.eval(), egeat_soup.eval(), pgd_model.eval()]
names = ["EGEAT Model", "EGEAT Soup", "PGD Model"]
xb, yb = next(iter(test_loader))
xb = xb.to(device)
yb = yb.to(device)

grads = []
for net in models:
    net.zero_grad(set_to_none=True)
    probe = xb.detach().clone().requires_grad_(True)
    loss = loss_fn(net(probe), yb)
    loss.backward()
    grads.append(probe.grad.detach())

def pairwise_cos(grads):
    """M x M NumPy matrix whose (i, j) entry is cos_sim(grads[i], grads[j])."""
    size = len(grads)
    matrix = torch.zeros(size, size)
    for row, col in itertools.product(range(size), repeat=2):
        matrix[row, col] = cos_sim(grads[row], grads[col]).detach().cpu()
    return matrix.numpy()

G = pairwise_cos(grads)

# Smooth 3D Surface plot of gradient similarity (no overlapping elements)
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata

fig = plt.figure(figsize=(16, 12))
ax = fig.add_subplot(111, projection='3d')

x = np.arange(len(names))
y = np.arange(len(names))
X_coarse, Y_coarse = np.meshgrid(x, y)
Z_coarse = G

# Create smooth interpolated surface
x_fine = np.linspace(0, len(names)-1, 150)
y_fine = np.linspace(0, len(names)-1, 150)
X_fine, Y_fine = np.meshgrid(x_fine, y_fine)

points = np.column_stack([X_coarse.flatten(), Y_coarse.flatten()])
values = Z_coarse.flatten()
Z_fine = griddata(points, values, (X_fine, Y_fine), method='cubic', fill_value=np.nan)
if np.isnan(Z_fine).any():
    Z_fine[np.isnan(Z_fine)] = griddata(points, values, 
                                       (X_fine[np.isnan(Z_fine)], Y_fine[np.isnan(Z_fine)]), 
                                       method='nearest')

# Cosine similarity lies in [-1, 1]. Extend the colour/z floor below 0 when any pair
# is anti-aligned, instead of clipping those values out of view.
c_floor = min(0.0, float(np.min(G)))
z_floor = min(0.0, float(np.min(G)) - 0.05)

# Create smooth gradient surface (NO overlapping contours)
surf = ax.plot_surface(X_fine, Y_fine, Z_fine, cmap='plasma', alpha=0.96, 
                      edgecolor='none', linewidth=0, antialiased=True, shade=True,
                      rstride=2, cstride=2, vmin=c_floor, vmax=1)

# Add value labels at data points only
for i in range(len(names)):
    for j in range(len(names)):
        ax.text(i, j, G[i, j] + 0.03, f'{G[i, j]:.3f}', 
               ha='center', va='bottom', color=PALETTE["text"], fontsize=11, fontweight='bold',
               bbox=dict(boxstyle='round,pad=0.4', facecolor=DARK_THEME["axes.facecolor"], 
                        edgecolor=PALETTE["primary"], alpha=0.9, linewidth=1.5))

# Professional camera angle
ax.view_init(elev=30, azim=45)
ax.set_xlabel('Model Index', fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_ylabel('Model Index', fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_zlabel('Cosine Similarity', fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_title("Gradient Subspace Similarity Surface", fontsize=18, fontweight='bold', 
            color=PALETTE["text"], pad=25)
ax.set_xticks(x)
ax.set_xticklabels(names, color=PALETTE["text"], fontsize=12, rotation=12)
ax.set_yticks(y)
ax.set_yticklabels(names, color=PALETTE["text"], fontsize=12)
ax.set_zlim(z_floor, 1.1)

# Style 3D axes
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
ax.xaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.yaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.zaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.tick_params(colors=PALETTE["text_secondary"], labelsize=11)
ax.grid(True, alpha=0.2, linestyle='--')

cbar = fig.colorbar(surf, ax=ax, shrink=0.7, pad=0.12)
cbar.set_label('Cosine Similarity', color=PALETTE["text"], fontsize=13, fontweight='bold')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelsize=11)

plt.tight_layout()
savefig("fig_gradient_similarity_3d.png", dpi=300)
plt.show()


In [None]:

# Loss surface
def get_vec(m):
    """Concatenate all parameters of `m` (in .parameters() order) into one detached 1-D tensor."""
    pieces = [param.detach().view(-1) for param in m.parameters()]
    return torch.cat(pieces)
def set_vec(m, v):
    """Copy consecutive slices of the flat vector `v` into m's parameters, in place.

    Slices are consumed in .parameters() order (the layout produced by get_vec).
    Entries of `v` beyond the total parameter count are ignored.
    """
    offset = 0
    for param in m.parameters():
        count = param.numel()
        param.data.copy_(v[offset:offset + count].view_as(param))
        offset += count

# Random 2-D slice of the loss landscape around the trained EGEAT model:
# Z[i, j] = CE loss of (w + alphas[i]*d1 + betas[j]*d2) on the first test batch, where
# d1, d2 are Gaussian directions scaled to unit norm over the full flattened parameter vector.
base = copy.deepcopy(egeat_model).eval().to(device)
w = get_vec(base)
d1, d2 = torch.randn_like(w), torch.randn_like(w)
d1 /= d1.norm() + 1e-12
d2 /= d2.norm() + 1e-12
grid = 21
alphas = torch.linspace(-0.5, 0.5, grid)
betas = torch.linspace(-0.5, 0.5, grid)
Z = np.zeros((grid, grid), dtype=np.float32)

# One scratch copy suffices: set_vec overwrites every parameter at each grid point, and
# in eval mode the forward pass never modifies the BatchNorm buffers. This gives the same
# values as deep-copying `base` for each of the grid*grid points.
tmp = copy.deepcopy(base).to(device)
criterion = nn.CrossEntropyLoss()
with torch.no_grad():
    xb_s, yb_s = xb[:min(128, xb.size(0))], yb[:min(128, yb.size(0))]
    for i, a in enumerate(alphas):
        for j, b in enumerate(betas):
            set_vec(tmp, w + a * d1 + b * d2)
            Z[i, j] = criterion(tmp(xb_s), yb_s).item()

# Smooth 3D Loss surface with high-quality interpolation
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata

fig = plt.figure(figsize=(16, 12))
ax = fig.add_subplot(111, projection='3d')

# Measured grid. Z[i, j] is the loss at (alphas[i], betas[j]); np.meshgrid returns arrays
# indexed [beta, alpha], so the loss matrix is transposed to line up with them.
A = alphas.cpu().numpy()
B = betas.cpu().numpy()
A_coarse, B_coarse = np.meshgrid(A, B)
Z_coarse = Z.T

# Resample onto a 200 x 200 grid: cubic interpolation, with nearest-neighbour for any NaN gaps.
A_fine = np.linspace(A.min(), A.max(), 200)
B_fine = np.linspace(B.min(), B.max(), 200)
A_mesh, B_mesh = np.meshgrid(A_fine, B_fine)
points = np.column_stack([A_coarse.flatten(), B_coarse.flatten()])
values = Z_coarse.flatten()
Z_fine = griddata(points, values, (A_mesh, B_mesh), method='cubic', fill_value=np.nan)
holes = np.isnan(Z_fine)
if holes.any():
    Z_fine[holes] = griddata(points, values, (A_mesh[holes], B_mesh[holes]), method='nearest')

# Single smooth surface (no wireframe/contour overlays).
surf = ax.plot_surface(A_mesh, B_mesh, Z_fine, cmap='plasma', alpha=0.97,
                       edgecolor='none', linewidth=0, antialiased=True, shade=True,
                       rstride=2, cstride=2)

# Camera and labels.
ax.view_init(elev=30, azim=45)
axis_label_style = dict(fontsize=16, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_xlabel("α (Direction 1)", **axis_label_style)
ax.set_ylabel("β (Direction 2)", **axis_label_style)
ax.set_zlabel("Cross-Entropy Loss", **axis_label_style)
ax.set_title("Loss Landscape Around EGEAT Model", fontsize=18, fontweight='bold',
             color=PALETTE["text"], pad=25)

# Transparent panes with light edges.
for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    pane_axis.pane.fill = False
    pane_axis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.tick_params(colors=PALETTE["text_secondary"], labelsize=11)
ax.grid(True, alpha=0.2, linestyle='--')

cbar = fig.colorbar(surf, ax=ax, shrink=0.7, pad=0.12)
cbar.set_label('Cross-Entropy Loss', color=PALETTE["text"], fontsize=14, fontweight='bold')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelsize=11)

plt.tight_layout()
savefig("fig_loss_surface_3d.png", dpi=300)
plt.show()


In [None]:

# Adversarial example grid
loss_fn = nn.CrossEntropyLoss()
n_show = min(10, xb.size(0))
idxs = torch.arange(0, n_show)
xv = xb[idxs].detach().cpu()
yv = yb[idxs].detach().cpu()
# Row 2: FGSM against the PGD-trained model. Row 3: the same one-step L-inf attack
# (exact_perturbation, p='linf') against the EGEAT model.
x_fgsm_pgd = fgsm_attack(xv, yv, pgd_model, loss_fn, eps=EPS).cpu()
x_exact_e = exact_perturbation(xv, yv, egeat_model, loss_fn, eps=EPS, p='linf').cpu()

def show_triplet(orig, a, b, title_a="FGSM (PGD)", title_b="Exact (EGEAT)"):
    """Plot clean images above two adversarial versions of them, one column per example.

    Row 0 shows ``orig``, row 1 shows ``a`` and row 2 shows ``b``. Each row is
    labelled on its left, and the figure is saved as
    ``fig_adversarial_grid.png`` via ``savefig`` before being shown.

    Args:
        orig, a, b: image batches with the same number of items. Each item is
            permuted with ``(1, 2, 0)``, so presumably ``(N, C, H, W)``
            tensors; TODO confirm. Single-channel images are replicated to
            RGB, and pixel values are clipped to [0, 1] for display.
        title_a, title_b: row labels for rows 1 and 2.
    """
    rows = [
        (orig, "Original", PALETTE["text"]),
        (a, title_a, PALETTE["accent1"]),
        (b, title_b, PALETTE["accent2"]),
    ]
    n = orig.size(0)
    # squeeze=False keeps `axes` 2-D, so axes[r, i] also works when n == 1.
    fig, axes = plt.subplots(3, n, figsize=(1.5*n, 4.5), squeeze=False)
    fig.patch.set_facecolor(DARK_THEME["figure.facecolor"])

    for r, (batch, _, _) in enumerate(rows):
        for i in range(n):
            # detach() so .numpy() also works if an attack returned a grad-tracking tensor.
            img = batch[i].detach().permute(1, 2, 0).squeeze().numpy()
            if img.ndim == 2:  # grayscale -> RGB
                img = np.stack([img] * 3, axis=-1)
            ax = axes[r, i]
            ax.imshow(np.clip(img, 0, 1), vmin=0, vmax=1)
            # Hide ticks, grid and spines instead of calling ax.axis("off"):
            # axis("off") also hides the y-label that names each row.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.grid(False)
            for spine in ax.spines.values():
                spine.set_visible(False)
            ax.set_facecolor(DARK_THEME["axes.facecolor"])

    for r, (_, label, color) in enumerate(rows):
        axes[r, 0].set_ylabel(label, fontsize=12, color=color, fontweight='bold',
                              rotation=0, ha='right', va='center')

    fig.suptitle("Adversarial Examples Comparison", fontsize=16, fontweight='bold', color=PALETTE["text"], y=0.98)
    plt.tight_layout(rect=[0, 0, 1, 0.96])
    savefig("fig_adversarial_grid.png", dpi=300)
    plt.show()

show_triplet(xv, x_fgsm_pgd, x_exact_e)


In [None]:

# Transferability
def transfer_rate(src, tgt, loader, eps=EPS):
    """Return the fraction of examples whose FGSM adversary from ``src`` fools ``tgt``.

    For every batch in ``loader``, an FGSM example is crafted on the source
    model. The example is then classified by the target model without
    gradients, and every prediction that differs from the true label counts
    as "fooled".

    Args:
        src: model the FGSM perturbation is computed against.
        tgt: model evaluated on the adversarial inputs.
        loader: iterable of ``(inputs, labels)`` batches. Labels are compared
            against CPU predictions, so presumably CPU tensors; TODO confirm.
        eps: FGSM step size (default bound at definition time to ``EPS``).
    """
    criterion = nn.CrossEntropyLoss()
    n_fooled, n_seen = 0, 0
    for inputs, labels in loader:
        adversarial = fgsm_attack(inputs, labels, src, criterion, eps=eps)
        with torch.no_grad():
            predictions = tgt(adversarial.to(device)).argmax(1).cpu()
        n_fooled += (predictions != labels).sum().item()
        n_seen += labels.size(0)
    return n_fooled / n_seen

# Evaluate transfer for each unordered pair (source index, target index).
models = [egeat_model.eval(), egeat_soup.eval(), pgd_model.eval()]
names  = ["EGEAT model", "EGEAT soup", "PGD model"]
pairs = [(0, 1), (0, 2), (1, 2)]
rates = []
for src_idx, tgt_idx in pairs:
    r = transfer_rate(models[src_idx], models[tgt_idx], test_loader, eps=EPS)
    rates.append(r)
    print(f"{names[src_idx]}→{names[tgt_idx]}: {r:.3f}")

# 3D Transferability visualization
from mpl_toolkits.mplot3d import Axes3D

# Build full transfer matrix: transfer_matrix[i, j] = P_T(source i -> target j).
transfer_matrix = np.zeros((len(models), len(models)))
for i, src_model in enumerate(models):
    for j, tgt_model in enumerate(models):
        if i != j:
            transfer_matrix[i, j] = transfer_rate(src_model, tgt_model, test_loader, eps=EPS)
        else:
            # NOTE: placeholder, not a measured white-box fooling rate; the
            # diagonal is deliberately left unlabelled below.
            transfer_matrix[i, j] = 1.0  # Self-attack

fig = plt.figure(figsize=(14, 10))
ax = fig.add_subplot(111, projection='3d')

x = np.arange(len(names))
y = np.arange(len(names))

# Create smooth interpolated surface (NO overlapping bars)
from scipy.interpolate import griddata

# meshgrid(x, y) puts x (source index) on columns and y (target index) on rows.
# The height at grid cell [row=j, col=i] must therefore be transfer_matrix[i, j],
# so the matrix is transposed. This makes the surface agree with the axis
# labels and with the value annotations placed at (x=i, y=j).
X_coarse, Y_coarse = np.meshgrid(x, y)
Z_coarse = transfer_matrix.T

# High-resolution smooth interpolation
x_fine = np.linspace(0, len(names)-1, 150)
y_fine = np.linspace(0, len(names)-1, 150)
X_fine, Y_fine = np.meshgrid(x_fine, y_fine)

points = np.column_stack([X_coarse.flatten(), Y_coarse.flatten()])
values = Z_coarse.flatten()
Z_fine = griddata(points, values, (X_fine, Y_fine), method='cubic', fill_value=np.nan)
nan_cells = np.isnan(Z_fine)
if nan_cells.any():
    # Cells cubic interpolation left undefined take the nearest sample's value.
    Z_fine[nan_cells] = griddata(points, values,
                                 (X_fine[nan_cells], Y_fine[nan_cells]),
                                 method='nearest')

# Smooth surface only (no bars)
surf = ax.plot_surface(X_fine, Y_fine, Z_fine, cmap='RdYlGn_r', alpha=0.96,
                       edgecolor='none', linewidth=0, antialiased=True, shade=True,
                       rstride=2, cstride=2, vmin=0, vmax=1)

# Add value labels at off-diagonal data points only (x = source, y = target)
for i in range(len(names)):
    for j in range(len(names)):
        if i != j:
            ax.text(i, j, transfer_matrix[i, j] + 0.03, f'{transfer_matrix[i, j]:.3f}',
                    ha='center', va='bottom', color=PALETTE["text"], fontsize=11, fontweight='bold',
                    bbox=dict(boxstyle='round,pad=0.4', facecolor=DARK_THEME["axes.facecolor"],
                              edgecolor=PALETTE["primary"], alpha=0.9, linewidth=1.5))

# Professional camera angle
ax.view_init(elev=30, azim=45)
ax.set_xlabel('Source Model', fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_ylabel('Target Model', fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_zlabel('Transfer Rate $P_T$', fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_title("Adversarial Transferability Matrix", fontsize=18, fontweight='bold',
             color=PALETTE["text"], pad=25)
ax.set_xticks(x)
ax.set_xticklabels(names, color=PALETTE["text"], fontsize=12, rotation=12)
ax.set_yticks(y)
ax.set_yticklabels(names, color=PALETTE["text"], fontsize=12)
ax.set_zlim(0, 1.1)

# Style 3D axes
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
ax.xaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.yaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.zaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.tick_params(colors=PALETTE["text_secondary"], labelsize=11)
ax.grid(True, alpha=0.2, linestyle='--')

cbar = fig.colorbar(surf, ax=ax, shrink=0.7, pad=0.12)
cbar.set_label('Transfer Rate', color=PALETTE["text"], fontsize=13, fontweight='bold')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelsize=11)

plt.tight_layout()
savefig("fig_transfer_3d.png", dpi=300)
plt.show()


In [None]:

# Ablation
import pandas as pd
# (λ1 = lambda_geom, λ2 = lambda_soup) settings swept by the ablation.
ABL_SETTINGS = [(0.00, 0.00), (0.10, 0.00), (0.10, 0.05), (0.20, 0.05)]
ABL_EPOCHS = 20  # set to 100 for strict replication

def quick_egeat(lg, ls, epochs=ABL_EPOCHS):
    """Train one short EGEAT run and return ``(clean_acc, pgd20_acc, entropy)``.

    The third value is reported downstream as ``ECE_proxy``. It is the mean
    Shannon entropy (natural log, i.e. nats) of the softmax predictions over
    ``test_loader``, not an expected-calibration-error estimate.

    Args:
        lg: lambda_geom (λ1) for ``EGEATConfig``.
        ls: lambda_soup (λ2) for ``EGEATConfig``.
        epochs: training epochs (default bound at definition time to ``ABL_EPOCHS``).
    """
    config = EGEATConfig(epochs=epochs, eps=EPS, lambda_geom=lg, lambda_soup=ls,
                         snapshots_k=5, lr=2e-4)
    model, _soup = train_egeat(config)
    clean = accuracy(model, test_loader)
    pgd20 = eval_adv_acc(model, test_loader, 'pgd', eps=EPS, steps=20)

    model.eval()
    per_batch_entropy = []
    with torch.no_grad():
        for inputs, _labels in test_loader:
            probs = F.softmax(model(inputs.to(device)), dim=1).cpu().numpy()
            # 1e-12 guards log(0) for saturated probabilities.
            per_batch_entropy.append(-(probs * np.log(probs + 1e-12)).sum(axis=1))
    mean_entropy = float(np.concatenate(per_batch_entropy).mean())
    return clean, pgd20, mean_entropy

# Run every ablation setting and tabulate (λ1, λ2, clean acc, PGD-20 acc, entropy).
rows = []
for lg, ls in ABL_SETTINGS:
    print(f"[Ablation] λ1={lg} λ2={ls}")
    metrics = quick_egeat(lg, ls, epochs=ABL_EPOCHS)  # (clean, pgd20, mean entropy)
    rows.append((lg, ls, *metrics))

df = pd.DataFrame(rows, columns=["lambda1", "lambda2", "Acc_clean", "Acc_PGD20", "ECE_proxy"])
print(df)
csv_path = os.path.join(SAVE_DIR, "table_ablation_results.csv")
df.to_csv(csv_path, index=False)
print("Saved:", csv_path)

# Smooth 3D Ablation study with non-overlapping surfaces
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata

fig = plt.figure(figsize=(18, 12))
ax = fig.add_subplot(111, projection='3d')

# High-resolution smooth interpolation over the measured (λ1, λ2) range
l1_grid = np.linspace(df["lambda1"].min(), df["lambda1"].max(), 150)
l2_grid = np.linspace(df["lambda2"].min(), df["lambda2"].max(), 150)
L1_grid, L2_grid = np.meshgrid(l1_grid, l2_grid)
abl_points = (df["lambda1"].to_numpy(), df["lambda2"].to_numpy())
abl_values = df["Acc_PGD20"].to_numpy()
# Cubic interpolation is undefined outside the convex hull of the measured
# settings. Those cells are filled from the nearest measurement, as the other
# surface plots in this notebook do. Filling with the mean accuracy instead
# would draw a flat plateau that no run produced.
Z_grid = griddata(abl_points, abl_values, (L1_grid, L2_grid),
                  method='cubic', fill_value=np.nan)
outside_hull = np.isnan(Z_grid)
if outside_hull.any():
    Z_grid[outside_hull] = griddata(abl_points, abl_values,
                                    (L1_grid[outside_hull], L2_grid[outside_hull]),
                                    method='nearest')

# Shared colour limits so surface and scatter can be read off one colorbar.
z_lo, z_hi = float(np.nanmin(Z_grid)), float(np.nanmax(Z_grid))

# Smooth surface only (NO overlapping wireframes)
surf = ax.plot_surface(L1_grid, L2_grid, Z_grid, cmap='plasma', alpha=0.95,
                       edgecolor='none', linewidth=0, antialiased=True, shade=True,
                       rstride=2, cstride=2, vmin=z_lo, vmax=z_hi)

# Scatter the measured settings using the surface's colormap and limits
# (previously 'viridis', which did not match the colorbar).
scatter = ax.scatter(df["lambda1"], df["lambda2"], df["Acc_PGD20"],
                     s=200, c=df["Acc_PGD20"], cmap='plasma', vmin=z_lo, vmax=z_hi,
                     alpha=0.9, edgecolors=PALETTE["primary"], linewidths=2.5, zorder=5)

# Add annotations at data points
for i, row in df.iterrows():
    ax.text(row["lambda1"], row["lambda2"], row["Acc_PGD20"] + 0.02,
            f'λ₁={row["lambda1"]:.2f}\nλ₂={row["lambda2"]:.2f}\nAcc={row["Acc_PGD20"]:.3f}',
            fontsize=10, color=PALETTE["text"], fontweight='bold',
            bbox=dict(boxstyle='round,pad=0.4', facecolor=DARK_THEME["axes.facecolor"],
                      edgecolor=PALETTE["primary"], alpha=0.9, linewidth=1.5))

# Professional camera angle
ax.view_init(elev=30, azim=45)
ax.set_xlabel("λ₁ (Geometric Regularization)", fontsize=15, color=PALETTE["text"],
              fontweight='bold', labelpad=15)
ax.set_ylabel("λ₂ (Soup Regularization)", fontsize=15, color=PALETTE["text"],
              fontweight='bold', labelpad=15)
ax.set_zlabel("PGD-20 Accuracy", fontsize=15, color=PALETTE["text"],
              fontweight='bold', labelpad=15)
ax.set_title("Ablation Study: Hyperparameter Effects on Robustness",
             fontsize=18, fontweight='bold', color=PALETTE["text"], pad=25)

# Style 3D axes
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
ax.xaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.yaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.zaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.tick_params(colors=PALETTE["text_secondary"], labelsize=11)
ax.grid(True, alpha=0.2, linestyle='--')

cbar = fig.colorbar(surf, ax=ax, shrink=0.7, pad=0.12)
cbar.set_label('PGD-20 Accuracy', color=PALETTE["text"], fontsize=13, fontweight='bold')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelsize=11)

plt.tight_layout()
savefig("fig_ablation_3d.png", dpi=300)
plt.show()



# === BLOG / LINKEDIN VISUALS — Modern Showcase ===

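The following cells generate **publication-quality** and **social-ready** square (1:1) visuals with a modern dark theme. They are saved at 300 dpi, i.e. several thousand pixels per side rather than 1080×1080, so downscale them for social-media posts.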
They reuse the trained models (`egeat_model`, `egeat_soup`, `pgd_model`) and the dataset/test loader from above.
All outputs are saved to: `os.path.join(SAVE_DIR, "blog_visuals")`.


In [None]:

# Showcase visualizations directory (uses existing dark theme from cell 1)
BLOG_DIR = os.path.join(SAVE_DIR, "blog_visuals")
os.makedirs(BLOG_DIR, exist_ok=True)

def save_square_png(name, fig=None, size=12, dpi=300):
    """Resize a figure to a ``size`` x ``size``-inch square and save it to ``BLOG_DIR/name``.

    Args:
        name: file name inside ``BLOG_DIR``.
        fig: figure to save; defaults to the current figure (``plt.gcf()``).
        size: canvas edge length in inches.
        dpi: output resolution. The canvas is ``size * dpi`` pixels per side
            (3600 px for the defaults). For an exact 1080-px canvas pass
            e.g. ``size=10, dpi=108``.

    Note:
        ``bbox_inches="tight"`` crops the output to the drawn content, so the
        saved image can be smaller than the canvas and not exactly square.
    """
    if fig is None:
        fig = plt.gcf()
    fig.set_size_inches(size, size)
    path = os.path.join(BLOG_DIR, name)
    fig.savefig(path, bbox_inches="tight", dpi=dpi,
                facecolor=DARK_THEME["figure.facecolor"],
                edgecolor="none")
    print(f"✓ Saved showcase: {path}")

print(f"✓ Showcase directory ready: {BLOG_DIR}")


In [None]:

# 1) Loss Landscape "Evolution": sweep along two random directions around EGEAT θ
# Produces a strip of frames that can be combined into a GIF (kept simple as a strip for portability).
import copy, torch, torch.nn as nn

def get_vec(m):
    """Flatten all parameters of ``m`` (in ``m.parameters()`` order) into one detached 1-D tensor."""
    flat_chunks = tuple(param.detach().view(-1) for param in m.parameters())
    return torch.cat(flat_chunks)

def set_vec(m, v):
    """Overwrite every parameter of ``m`` (in ``m.parameters()`` order) from flat vector ``v``.

    This is the inverse of ``get_vec``. Consecutive ``p.numel()``-sized slices
    of ``v`` are reshaped to each parameter's shape and copied in. Values are
    copied, so later edits to ``v`` do not affect ``m``.

    Raises:
        ValueError: if ``v`` does not hold exactly as many elements as ``m``
            has parameters. The check runs before anything is written, so
            ``m`` is never left partially updated (previously a too-long
            vector was silently accepted).
    """
    params = list(m.parameters())
    flat = v.reshape(-1)
    expected = sum(p.numel() for p in params)
    if flat.numel() != expected:
        raise ValueError(
            f"set_vec: vector has {flat.numel()} elements but the model has {expected} parameters"
        )
    offset = 0
    with torch.no_grad():
        for p in params:
            n = p.numel()
            p.copy_(flat[offset:offset + n].view_as(p))
            offset += n

base = copy.deepcopy(egeat_model).eval().to(device)
w = get_vec(base)
# Two random directions in parameter space, each scaled to unit L2 norm
# (a single global norm, no per-filter normalization).
d1, d2 = torch.randn_like(w), torch.randn_like(w)
d1 /= (d1.norm()+1e-12); d2 /= (d2.norm()+1e-12)

grid = 9
alphas = torch.linspace(-0.6, 0.6, grid)
betas  = torch.linspace(-0.6, 0.6, grid)
Z = np.zeros((grid, grid), dtype=np.float32)

xb, yb = next(iter(test_loader))
xb, yb = xb.to(device), yb.to(device)

# One scratch copy of `base` is reused for every grid point, since set_vec
# overwrites all of its parameters each time. This avoids a deepcopy per
# point.
# NOTE(review): this assumes the eval-mode forward pass does not mutate module
# buffers, which holds for standard BatchNorm/Dropout layers.
probe = copy.deepcopy(base).to(device)
ce_loss = nn.CrossEntropyLoss()
with torch.no_grad():
    xb_s, yb_s = xb[:min(256, xb.size(0))], yb[:min(256, yb.size(0))]
    for i, a in enumerate(alphas):
        for j, b in enumerate(betas):
            set_vec(probe, w + a*d1 + b*d2)
            Z[i, j] = ce_loss(probe(xb_s), yb_s).item()

# Render as a modern contour with bright strokes - Enhanced dark theme
fig, ax = plt.subplots(figsize=(10, 10))
cs = ax.contourf(alphas.cpu(), betas.cpu(), Z.T, levels=30, cmap='plasma', alpha=0.95)
contours = ax.contour(alphas.cpu(), betas.cpu(), Z.T, levels=15, colors='white',
                      linewidths=1.5, alpha=0.5)
# clabel() has no `fontweight` keyword (passing one raises TypeError), so the
# returned label Text objects are made bold afterwards.
contour_labels = ax.clabel(contours, inline=True, fontsize=10, colors='white', fmt='%.2f')
for label in contour_labels:
    label.set_fontweight('bold')
ax.set_xlabel("α (Direction 1)", fontsize=16, color=PALETTE["text"], fontweight='bold')
ax.set_ylabel("β (Direction 2)", fontsize=16, color=PALETTE["text"], fontweight='bold')
ax.set_title("Loss Landscape Around EGEAT Model", fontsize=18, fontweight='bold',
             color=PALETTE["text"], pad=25)
cbar = plt.colorbar(cs, ax=ax)
cbar.set_label('Cross-Entropy Loss', color=PALETTE["text"], fontsize=14, fontweight='bold')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelsize=12)
cbar.outline.set_edgecolor(PALETTE["primary"])
cbar.outline.set_linewidth(2)
apply_dark_style(ax)
plt.tight_layout()
save_square_png("loss_landscape_showcase.png", fig=fig, size=10)
plt.show()


In [None]:

# 2) Gradient Flow Constellation: PCA of ∇_x ℓ over test mini-batches for multiple models
from sklearn.decomposition import PCA
import torch

loss_fn = nn.CrossEntropyLoss()
models = [egeat_model.eval(), egeat_soup.eval(), pgd_model.eval()]
names  = ["EGEAT", "Soup", "PGD"]

# Sample multiple mini-batches to build a gradient cloud: one flattened
# ∇_x ℓ row per test example, per model.
grads_all = []
labels_all = []
num_batches = 6
it = iter(test_loader)
for b in range(num_batches):
    try:
        xb, yb = next(it)
    except StopIteration:
        break
    xb, yb = xb.to(device), yb.to(device)
    for idx, m in enumerate(models):
        m.zero_grad(set_to_none=True)
        x_ = xb.detach().clone().requires_grad_(True)
        loss = loss_fn(m(x_), yb)
        # Only the input gradient is needed. autograd.grad returns it directly
        # and, unlike loss.backward(), does not accumulate .grad on the
        # model's parameters.
        (grad_x,) = torch.autograd.grad(loss, x_)
        g = grad_x.detach().reshape(x_.size(0), -1).cpu().numpy()
        grads_all.append(g)
        labels_all += [names[idx]] * g.shape[0]

G = np.concatenate(grads_all, axis=0)
# 2-D projection: not used by the 3-D plot below; kept so `pca` / `P` stay defined.
pca = PCA(n_components=2, random_state=42).fit(G)
P = pca.transform(G)

# 3D Gradient constellation with PCA
from mpl_toolkits.mplot3d import Axes3D

pca_3d = PCA(n_components=3, random_state=42).fit(G)
P_3d = pca_3d.transform(G)

fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

# Plot 3D scatter, one colour per model. The label array is built once rather
# than once per model.
label_arr = np.array(labels_all)
for name, color in zip(names, COLORS[:len(names)]):
    mask = label_arr == name
    scatter = ax.scatter(P_3d[mask, 0], P_3d[mask, 1], P_3d[mask, 2],
                         s=50, alpha=0.8, label=name, c=color,
                         edgecolors=PALETTE["primary"], linewidths=1.5)

ax.legend(frameon=True, facecolor=DARK_THEME["axes.facecolor"],
          edgecolor=PALETTE["primary"], labelcolor=PALETTE["text"],
          fontsize=14, loc='best', framealpha=0.9)
ax.set_title("3D Gradient Constellation (PCA) — Decorrelated Subspaces",
             fontsize=18, fontweight='bold', color=PALETTE["text"], pad=25)
ax.set_xlabel("Principal Component 1", fontsize=16, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax.set_ylabel("Principal Component 2", fontsize=16, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax.set_zlabel("Principal Component 3", fontsize=16, color=PALETTE["text"], fontweight='bold', labelpad=10)

# Style 3D axes
ax.xaxis.pane.fill = False
ax.yaxis.pane.fill = False
ax.zaxis.pane.fill = False
ax.xaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.yaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.zaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.tick_params(colors=PALETTE["text_secondary"])

plt.tight_layout()
save_square_png("gradient_constellation_3d.png", fig=fig, size=10)
plt.show()


In [None]:

# 3) Adversarial Transfer Graph: edges weighted by transfer rate P_T
import networkx as nx

def transfer_rate(source_model, target_model, loader, eps=float(EPS)):
    """Return the fraction of examples whose FGSM adversary from ``source_model`` fools ``target_model``.

    For every batch, an FGSM example is built against ``source_model``; the
    example is then classified by ``target_model`` without gradients.
    Predictions that differ from the true label count as fooled.
    ``eps`` defaults to ``float(EPS)``, evaluated when this cell is run.
    """
    criterion = nn.CrossEntropyLoss()
    n_fooled = 0
    n_seen = 0
    for inputs, labels in loader:
        adversarial = fgsm_attack(inputs, labels, source_model, criterion, eps=eps)
        with torch.no_grad():
            target_pred = target_model(adversarial.to(device)).argmax(1).cpu()
        n_fooled += (target_pred != labels).sum().item()
        n_seen += labels.size(0)
    return n_fooled / n_seen

models = [egeat_model.eval(), egeat_soup.eval(), pgd_model.eval()]
names  = ["EGEAT", "Soup", "PGD"]

Gx = nx.DiGraph()
Gx.add_nodes_from(names)

# One directed edge per ordered (source, target) pair, weighted by P_T.
edges = []
for i, si in enumerate(names):
    for j, tj in enumerate(names):
        if i == j:
            continue
        r = transfer_rate(models[i], models[j], test_loader, eps=float(EPS))
        Gx.add_edge(si, tj, weight=r)
        edges.append((si, tj, r))

EDGE_RAD = 0.15     # curvature of the arc3 connection style
NODE_SIZE = 2800

# Layout & draw - Enhanced dark theme
fig, ax = plt.subplots(figsize=(10, 10))
pos = nx.circular_layout(Gx)
weights = [Gx[u][v]['weight'] for u, v in Gx.edges()]
w_scaled = [(1.0 + 8*w) for w in weights]  # emphasize thickness
nx.draw_networkx_nodes(Gx, pos, node_size=NODE_SIZE, node_color=COLORS[:3],
                       linewidths=3, edgecolors=PALETTE["primary"], ax=ax)
nx.draw_networkx_labels(Gx, pos, font_size=18, font_color=PALETTE["text"],
                        font_weight='bold', ax=ax)
# node_size must be passed here as well. Otherwise the edges assume the
# default (small) node size, and the arrowheads end under the large nodes,
# where they are hidden.
nx.draw_networkx_edges(Gx, pos, width=w_scaled, edge_color=PALETTE["accent1"],
                       arrows=True, arrowsize=25, connectionstyle=f'arc3,rad={EDGE_RAD}',
                       node_size=NODE_SIZE, alpha=0.8, ax=ax)

# Annotate edge weights at the middle of each curved edge. An arc3 curve's
# Bezier midpoint is the chord midpoint shifted by rad/2 * (dy, -dx).
# Reverse edges bend to the opposite side, so the two labels of a
# bidirectional pair no longer sit on top of each other (the plain chord
# midpoint is shared by both).
for (u, v, r) in edges:
    (x0, y0), (x1, y1) = pos[u], pos[v]
    dx, dy = x1 - x0, y1 - y0
    x = (x0 + x1) / 2 + 0.5 * EDGE_RAD * dy
    y = (y0 + y1) / 2 - 0.5 * EDGE_RAD * dx
    ax.text(x, y, f"{r:.3f}", ha="center", va="center",
            fontsize=13, color=PALETTE["text"], fontweight='bold',
            bbox=dict(boxstyle='round,pad=0.3', facecolor=DARK_THEME["axes.facecolor"],
                      edgecolor=PALETTE["primary"], alpha=0.8))

ax.set_title("Adversarial Transfer Graph\n(Lower Transfer Rate = Better Robustness)",
             fontsize=18, fontweight='bold', color=PALETTE["text"], pad=25)
ax.axis("off")
apply_dark_style(ax)
plt.tight_layout()
save_square_png("transfer_graph.png", fig=fig, size=10)
plt.show()


In [None]:

# 4) Variance vs Ensemble Size K: empirical variance of adversarial loss across K-model soups
# If snapshot list isn't available from training, we quickly generate extra snapshots by light finetuning copies.
import copy, torch

def collect_snapshots(base_model, k=5, steps=100):
    """Return ``k`` eval-mode deep copies of ``base_model`` taken along a short SGD fine-tune.

    Snapshot 0 is ``base_model`` as passed in. Each later snapshot is taken
    after ``steps`` more clean cross-entropy SGD updates (lr=1e-3,
    momentum=0.9) on ``train_loader``, restarting the loader when it is
    exhausted.

    Note: ``base_model`` itself is trained in place, and a single optimizer,
    with its momentum state, is shared across all segments.
    """
    snapshots = [copy.deepcopy(base_model).to(device).eval()]
    optimizer = torch.optim.SGD(base_model.parameters(), lr=1e-3, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    batches = iter(train_loader)
    for _segment in range(1, k):
        # light finetune from previous snapshot to diversify
        for _step in range(steps):
            batch = next(batches, None)
            if batch is None:
                batches = iter(train_loader)
                batch = next(batches)
            inputs, labels = batch[0].to(device), batch[1].to(device)
            optimizer.zero_grad(set_to_none=True)
            criterion(base_model(inputs), labels).backward()
            optimizer.step()
        snapshots.append(copy.deepcopy(base_model).to(device).eval())
    return snapshots

def soup(models):
    """Uniformly average architecture-identical models into one "parameter soup".

    Returns a deep copy of ``models[0]`` moved to ``device``. Its parameters
    and its floating-point buffers (e.g. BatchNorm ``running_mean`` /
    ``running_var``) are the element-wise mean over ``models``. Averaging
    only the parameters would leave the normalization statistics of
    ``models[0]`` paired with averaged weights. Integer buffers such as
    ``num_batches_tracked`` cannot be meaningfully averaged and are kept
    from ``models[0]``.

    Raises:
        ValueError: if ``models`` is empty.
    """
    if not models:
        raise ValueError("soup() needs at least one model")
    base = copy.deepcopy(models[0]).to(device)

    def _averageable(module):
        # Parameters followed by float buffers, in the module's registration order.
        return list(module.parameters()) + [buf for buf in module.buffers() if buf.is_floating_point()]

    targets = _averageable(base)
    members = [_averageable(m) for m in models]
    with torch.no_grad():
        for k, dst in enumerate(targets):
            acc = torch.zeros_like(dst)
            for member in members:
                acc.add_(member[k].to(dst.device))
            dst.copy_(acc.div_(len(models)))
    return base

def adv_loss_on_loader(model, loader, eps=float(EPS)):
    """Return a NumPy array with the mean FGSM cross-entropy of ``model`` for each batch of ``loader``.

    The adversary is crafted against ``model`` itself. ``eps`` defaults to
    ``float(EPS)``, evaluated when this cell is run.
    """
    criterion = nn.CrossEntropyLoss()
    batch_losses = []
    for inputs, labels in loader:
        adversarial = fgsm_attack(inputs, labels, model, criterion, eps=eps)
        with torch.no_grad():
            batch_loss = criterion(model(adversarial.to(device)), labels.to(device))
        batch_losses.append(batch_loss.item())
    return np.array(batch_losses)

# collect snapshots from EGEAT model
snaps = collect_snapshots(copy.deepcopy(egeat_model).to(device).train(), k=5, steps=50)
Ks = [1, 2, 3, 4, 5]
# For each K: soup the first K snapshots and take the variance of its
# per-batch FGSM loss over the test set.
variances = []
for k in Ks:
    soup_k = soup(snaps[:k])
    batch_losses = adv_loss_on_loader(soup_k, test_loader, eps=float(EPS))
    variances.append(float(np.var(batch_losses)))

# Smooth 3D Variance vs ensemble size (no overlapping elements)
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import interp1d

fig = plt.figure(figsize=(14, 12))
ax = fig.add_subplot(111, projection='3d')

# Cubic interpolation of the measured variances onto a dense K axis
K_grid = np.linspace(min(Ks), max(Ks), 200)
var_interp = interp1d(Ks, variances, kind='cubic', fill_value='extrapolate')(K_grid)

# Extrude the curve along a dummy depth axis in [0, 1] to get a ribbon surface
depth = np.linspace(0, 1, 100)
K_mesh, var_mesh = np.meshgrid(K_grid, depth)
Z_mesh = np.tile(var_interp, (depth.size, 1))

# Smooth surface only (NO overlapping lines)
surf = ax.plot_surface(K_mesh, var_mesh, Z_mesh, cmap='plasma', alpha=0.92,
                       edgecolor='none', linewidth=0, antialiased=True, shade=True,
                       rstride=2, cstride=2)

# Interpolated curve along the front edge (depth = 0)
ax.plot(K_grid, np.zeros_like(K_grid), var_interp, '-',
        color=PALETTE["primary"], linewidth=5, alpha=0.95)

# Measured points
scatter = ax.scatter(Ks, np.zeros(len(Ks)), variances, s=250, c=variances,
                     cmap='plasma', alpha=0.95, edgecolors=PALETTE["primary"], linewidths=3, zorder=5)

# Value labels just above each measured point
label_lift = max(variances) * 0.03
label_box = dict(boxstyle='round,pad=0.4', facecolor=DARK_THEME["axes.facecolor"],
                 edgecolor=PALETTE["primary"], alpha=0.9, linewidth=1.5)
for k, var in zip(Ks, variances):
    ax.text(k, 0, var + label_lift, f'K={k}\nVar={var:.4f}',
            ha='center', va='bottom', color=PALETTE["text"], fontsize=11, fontweight='bold',
            bbox=label_box)

# Professional camera angle
ax.view_init(elev=30, azim=45)
ax.set_xlabel("Ensemble Size K", fontsize=16, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_ylabel("", fontsize=16)
ax.set_zlabel("Var[ Adversarial Loss ]", fontsize=16, color=PALETTE["text"], fontweight='bold', labelpad=15)
ax.set_title("Variance vs Ensemble Size (EGEAT Parameter Soups)",
             fontsize=18, fontweight='bold', color=PALETTE["text"], pad=25)

# Transparent 3-D panes with muted edges
for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
    pane_axis.pane.fill = False
    pane_axis.pane.set_edgecolor(PALETTE["text_secondary"])
ax.tick_params(colors=PALETTE["text_secondary"], labelsize=11)
ax.grid(True, alpha=0.2, linestyle='--')

cbar = fig.colorbar(surf, ax=ax, shrink=0.7, pad=0.12)
cbar.set_label('Variance', color=PALETTE["text"], fontsize=14, fontweight='bold')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelsize=11)

plt.tight_layout()
save_square_png("variance_vs_k_3d.png", fig=fig, size=10)
plt.show()


In [None]:

# 5) 3D λ1–λ2 trade-off surface (clean vs robust accuracy)
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
import torch

grid_l1 = np.linspace(0.0, 0.2, 4)   # [0.00, 0.067, 0.133, 0.2]
grid_l2 = np.linspace(0.00, 0.08, 4) # [0.00, 0.027, 0.053, 0.08]

# Acc_*[i, j] holds the result for (grid_l1[i], grid_l2[j]).
grid_shape = (len(grid_l1), len(grid_l2))
Acc_clean = np.zeros(grid_shape)
Acc_pgd20 = np.zeros(grid_shape)

def train_quick(l1, l2, epochs=15):
    """Train a short EGEAT run (K=3 snapshots) at (λ1, λ2); return (clean acc, PGD-20 acc)."""
    config = EGEATConfig(epochs=epochs, eps=float(EPS), lambda_geom=float(l1),
                         lambda_soup=float(l2), snapshots_k=3, lr=2e-4)
    trained, _ = train_egeat(config)
    clean_acc = accuracy(trained, test_loader)
    robust_acc = eval_adv_acc(trained, test_loader, attack='pgd', eps=float(EPS), steps=20)
    return clean_acc, robust_acc

for i, l1 in enumerate(grid_l1):
    for j, l2 in enumerate(grid_l2):
        print(f"[λ-surface] λ1={l1:.3f}, λ2={l2:.3f}")
        Acc_clean[i, j], Acc_pgd20[i, j] = train_quick(l1, l2, epochs=12)

# Smooth clean / robust λ1–λ2 surfaces with interpolation
from scipy.interpolate import griddata

# High-resolution evaluation grid shared by both surfaces
l1_fine = np.linspace(grid_l1.min(), grid_l1.max(), 150)
l2_fine = np.linspace(grid_l2.min(), grid_l2.max(), 150)
L1_fine, L2_fine = np.meshgrid(l1_fine, l2_fine)

# (λ1, λ2) sample coordinates in the same row-major order as Acc_*.flatten():
# row k is (grid_l1[k // len(grid_l2)], grid_l2[k % len(grid_l2)]).
points = np.column_stack([grid_l1.repeat(len(grid_l2)), np.tile(grid_l2, len(grid_l1))])


def _interp_lambda_surface(acc):
    """Cubic-interpolate a (len(grid_l1), len(grid_l2)) accuracy grid onto (L1_fine, L2_fine).

    Cells left undefined (NaN) by cubic interpolation take the nearest sample's value.
    """
    vals = acc.flatten()
    fine = griddata(points, vals, (L1_fine, L2_fine), method='cubic', fill_value=np.nan)
    holes = np.isnan(fine)
    if holes.any():
        fine[holes] = griddata(points, vals, (L1_fine[holes], L2_fine[holes]), method='nearest')
    return fine


def _plot_lambda_surface(surface, cmap, zlabel, title, filename):
    """Render one λ1–λ2 surface in the dark 3-D style, save it via save_square_png and show it.

    ``zlabel`` is used for both the z-axis and the colorbar. Returns ``(fig, ax, surf)``.
    """
    fig = plt.figure(figsize=(14, 12))
    ax = fig.add_subplot(111, projection='3d')
    surf = ax.plot_surface(L1_fine, L2_fine, surface, linewidth=0, antialiased=True,
                           alpha=0.97, cmap=cmap, edgecolor='none', rstride=2, cstride=2)

    ax.view_init(elev=30, azim=45)
    ax.set_xlabel("λ₁ (Geometric)", fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
    ax.set_ylabel("λ₂ (Soup)", fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
    ax.set_zlabel(zlabel, fontsize=15, color=PALETTE["text"], fontweight='bold', labelpad=15)
    ax.set_title(title, fontsize=17, fontweight='bold', color=PALETTE["text"], pad=25)
    for pane_axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        pane_axis.pane.fill = False
        pane_axis.pane.set_edgecolor(PALETTE["text_secondary"])
    ax.tick_params(colors=PALETTE["text_secondary"], labelsize=11)
    ax.grid(True, alpha=0.2, linestyle='--')

    cbar = fig.colorbar(surf, ax=ax, shrink=0.7, pad=0.12)
    cbar.set_label(zlabel, color=PALETTE["text"], fontsize=13, fontweight='bold')
    cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelsize=11)
    save_square_png(filename, fig=fig, size=10)
    plt.show()
    return fig, ax, surf


Acc_clean_fine = _interp_lambda_surface(Acc_clean)
fig, ax, surf = _plot_lambda_surface(Acc_clean_fine, "viridis", "Clean Accuracy",
                                     "Hyperparameter Surface: Clean Accuracy",
                                     "lambda_surface_clean.png")

Acc_pgd20_fine = _interp_lambda_surface(Acc_pgd20)
fig, ax, surf = _plot_lambda_surface(Acc_pgd20_fine, "plasma", "PGD-20 Accuracy",
                                     "Hyperparameter Surface: Robust Accuracy (PGD-20)",
                                     "lambda_surface_robust.png")


## Summary: All Results and Metrics

This section provides a comprehensive summary of all evaluation results with professional visualizations.


In [None]:
# ===== COMPREHENSIVE RESULTS SUMMARY =====
# Display all metrics in a professional dashboard format

print("\n" + "="*80)
print(" " * 20 + "EGEAT EXPERIMENT SUMMARY")
print("="*80)


def _summary_metrics(model):
    """Return the Clean / FGSM / PGD-20 test accuracies of ``model``.

    ``eps`` is passed explicitly, as the ablation and λ-surface cells do. This
    guarantees the summary is evaluated at the experiment's EPS rather than
    at whatever default ``eval_adv_acc`` happens to have.
    """
    return {
        'Clean': accuracy(model, test_loader),
        'FGSM': eval_adv_acc(model, test_loader, 'fgsm', eps=EPS),
        'PGD-20': eval_adv_acc(model, test_loader, 'pgd', eps=EPS, steps=20),
    }


# Recompute all metrics for summary
all_metrics = {
    'EGEAT Model': _summary_metrics(egeat_model),
    'EGEAT Soup': _summary_metrics(egeat_soup),
    'PGD Model': _summary_metrics(pgd_model),
}

# Print formatted table
print(f"\n{'Model':<20} {'Clean Acc':<15} {'FGSM Acc':<15} {'PGD-20 Acc':<15}")
print("-"*80)
for model_name, metrics in all_metrics.items():
    print(f"{model_name:<20} {metrics['Clean']:<15.4f} {metrics['FGSM']:<15.4f} {metrics['PGD-20']:<15.4f}")
print("="*80)

# Create comprehensive dashboard visualization
# (panel 1 of the summary figure; later panels in this cell reuse `fig`, `gs`
# and `models_list`)
fig = plt.figure(figsize=(16, 10))
fig.patch.set_facecolor(DARK_THEME["figure.facecolor"])

# Create a 2x3 grid of subplots
gs = fig.add_gridspec(2, 3, hspace=0.3, wspace=0.3)

# 1. 3D Accuracy comparison
from mpl_toolkits.mplot3d import Axes3D
from scipy.interpolate import griddata, interp1d

ax1 = fig.add_subplot(gs[0, 0], projection='3d')
# Model order follows the insertion order of `all_metrics`.
models_list = list(all_metrics.keys())
clean_vals = [all_metrics[m]['Clean'] for m in models_list]
fgsm_vals = [all_metrics[m]['FGSM'] for m in models_list]
pgd_vals = [all_metrics[m]['PGD-20'] for m in models_list]

x = np.arange(len(models_list))
attack_types = ['Clean', 'FGSM', 'PGD-20']
y = np.arange(len(attack_types))
X, Y = np.meshgrid(x, y)
# Z[row = attack type, col = model], matching meshgrid(x=models, y=attacks).
Z = np.array([clean_vals, fgsm_vals, pgd_vals])

# Smooth interpolated surface (NO overlapping bars)
from scipy.interpolate import griddata

X_coarse, Y_coarse = np.meshgrid(x, y)
Z_coarse = Z

x_fine = np.linspace(0, len(models_list)-1, 100)
y_fine = np.linspace(0, len(attack_types)-1, 100)
X_fine, Y_fine = np.meshgrid(x_fine, y_fine)

points = np.column_stack([X_coarse.flatten(), Y_coarse.flatten()])
values = Z_coarse.flatten()
Z_fine = griddata(points, values, (X_fine, Y_fine), method='cubic', fill_value=np.nan)
# Cells left NaN by cubic interpolation take the nearest sample's value.
if np.isnan(Z_fine).any():
    Z_fine[np.isnan(Z_fine)] = griddata(points, values, 
                                       (X_fine[np.isnan(Z_fine)], Y_fine[np.isnan(Z_fine)]), 
                                       method='nearest')

# Colour range pinned to [0, 1] because the heights are accuracies.
surf = ax1.plot_surface(X_fine, Y_fine, Z_fine, cmap='viridis', alpha=0.96, 
                       edgecolor='none', linewidth=0, antialiased=True, shade=True,
                       rstride=2, cstride=2, vmin=0, vmax=1)

ax1.view_init(elev=30, azim=45)
ax1.set_xlabel('Model', fontsize=12, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax1.set_ylabel('Attack Type', fontsize=12, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax1.set_zlabel('Accuracy', fontsize=12, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax1.set_title('Accuracy Comparison', fontsize=13, fontweight='bold', color=PALETTE["text"])
ax1.set_xticks(x)
ax1.set_xticklabels([m.replace(' ', '\n') for m in models_list], color=PALETTE["text"], fontsize=9)
ax1.set_yticks(y)
ax1.set_yticklabels(attack_types, color=PALETTE["text"], fontsize=9)
ax1.set_zlim(0, 1.0)

# Style 3D axes
ax1.xaxis.pane.fill = False
ax1.yaxis.pane.fill = False
ax1.zaxis.pane.fill = False
ax1.xaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax1.yaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax1.zaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax1.tick_params(colors=PALETTE["text_secondary"], labelsize=9)
ax1.grid(True, alpha=0.2, linestyle='--')

# 2. 3D Robustness comparison
from mpl_toolkits.mplot3d import Axes3D

ax2 = fig.add_subplot(gs[0, 1], projection='3d')
robustness = [all_metrics[m]['PGD-20'] for m in models_list]

# Smooth surface instead of bars: PGD-20 accuracy is interpolated across the
# model index and extruded along a dummy depth axis.
x_pos = np.linspace(0, len(models_list)-1, 100)
y_pos = np.linspace(0, 0.5, 50)
X_pos, Y_pos = np.meshgrid(x_pos, y_pos)
Z_pos = np.tile(robustness, (len(y_pos), 1))
# interp1d needs at least k+1 samples for a spline of order k. With the three
# models compared here, kind='cubic' raises ValueError, so the highest order
# the number of models supports is used instead.
n_models = len(models_list)
if n_models >= 4:
    interp_kind = 'cubic'
elif n_models == 3:
    interp_kind = 'quadratic'
else:
    interp_kind = 'linear'
f = interp1d(np.arange(n_models), robustness, kind=interp_kind, fill_value='extrapolate')
Z_smooth = np.tile(f(x_pos), (len(y_pos), 1))

surf2 = ax2.plot_surface(X_pos, Y_pos, Z_smooth, cmap='plasma', alpha=0.95,
                         edgecolor='none', linewidth=0, antialiased=True, shade=True,
                         rstride=2, cstride=2)

# Add value labels
for i, (val, name) in enumerate(zip(robustness, models_list)):
    ax2.text(i, 0, val + 0.05, f'{val:.3f}',
             ha='center', va='bottom', color=PALETTE["text"],
             fontsize=11, fontweight='bold',
             bbox=dict(boxstyle='round,pad=0.3', facecolor=DARK_THEME["axes.facecolor"],
                       edgecolor=PALETTE["primary"], alpha=0.8))

ax2.view_init(elev=30, azim=45)
ax2.set_xlabel('Model', fontsize=12, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax2.set_ylabel('', fontsize=12)
ax2.set_zlabel('PGD-20 Accuracy', fontsize=12, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax2.set_title('Robustness Ranking', fontsize=13, fontweight='bold', color=PALETTE["text"])
ax2.set_xticks(np.arange(len(models_list)))
ax2.set_xticklabels([m.replace(' ', '\n') for m in models_list], color=PALETTE["text"], fontsize=9)
ax2.set_yticks([])
ax2.set_zlim(0, 1.0)

# Style 3D axes
ax2.xaxis.pane.fill = False
ax2.yaxis.pane.fill = False
ax2.zaxis.pane.fill = False
ax2.xaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax2.yaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax2.zaxis.pane.set_edgecolor(PALETTE["text_secondary"])
ax2.tick_params(colors=PALETTE["text_secondary"], labelsize=9)
ax2.grid(True, alpha=0.2, linestyle='--')

# 3. 3D Performance improvement visualization
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on older Matplotlib)
from scipy.interpolate import interp1d

ax3 = fig.add_subplot(gs[0, 2], projection='3d')
baseline_pgd = all_metrics['PGD Model']['PGD-20']
improvement_names = ['EGEAT Model', 'EGEAT Soup']
# PGD-20 accuracy gain of each EGEAT variant over the PGD-trained baseline, in percentage points.
improvements = [(all_metrics[name]['PGD-20'] - baseline_pgd) * 100 for name in improvement_names]
n_imp = len(improvements)

# Smooth surface instead of bars: a 1-D profile over variant index, extruded along a dummy depth axis.
x_pos = np.linspace(0, n_imp - 1, 100)
y_pos = np.linspace(0, 0.5, 50)
X_pos, Y_pos = np.meshgrid(x_pos, y_pos)
if n_imp >= 2:
    # interp1d needs order+1 samples for a spline of that order; with only two variants
    # kind='cubic' always raised ValueError, so use the highest order the data supports.
    kind = ('linear', 'quadratic', 'cubic')[min(n_imp, 4) - 2]
    f = interp1d(np.arange(n_imp), improvements, kind=kind, fill_value='extrapolate')
    profile = f(x_pos)
else:
    profile = np.full_like(x_pos, float(improvements[0]))
Z_smooth = np.tile(profile, (len(y_pos), 1))

# Centre the diverging colormap on 0 so gains render green and losses red; auto-scaling to
# the data range would paint the smaller of two positive gains red.
z_lim = max(float(np.max(np.abs(Z_smooth))), 1e-6)
surf3 = ax3.plot_surface(X_pos, Y_pos, Z_smooth, cmap='RdYlGn', alpha=0.95,
                         edgecolor='none', linewidth=0, antialiased=True, shade=True,
                         rstride=2, cstride=2, vmin=-z_lim, vmax=z_lim)

# Add value labels above positive gains and below negative ones.
for i, imp in enumerate(improvements):
    ax3.text(i, 0, imp + (1 if imp > 0 else -1), f'{imp:+.2f}%',
             ha='center', va='bottom' if imp > 0 else 'top', color=PALETTE["text"],
             fontsize=11, fontweight='bold',
             bbox=dict(boxstyle='round,pad=0.3', facecolor=DARK_THEME["axes.facecolor"],
                       edgecolor=PALETTE["primary"], alpha=0.8))

# Add zero plane (the "no improvement" reference level).
xx, yy = np.meshgrid([-0.5, n_imp - 0.5], [-0.5, 0.5])
zz = np.zeros_like(xx)
ax3.plot_surface(xx, yy, zz, alpha=0.3, color=PALETTE["text_secondary"])

ax3.view_init(elev=30, azim=45)
ax3.set_xlabel('Model', fontsize=12, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax3.set_ylabel('', fontsize=12)
ax3.set_zlabel('Improvement (%)', fontsize=12, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax3.set_title('Improvement over PGD Baseline', fontsize=13, fontweight='bold', color=PALETTE["text"])
ax3.set_xticks(np.arange(n_imp))
ax3.set_xticklabels([n.replace(' ', '\n') for n in improvement_names],
                    color=PALETTE["text"], fontsize=10)
ax3.set_yticks([])

# Style 3D axes: transparent panes with a subtle edge, muted ticks, dashed grid.
for _pane_axis in (ax3.xaxis, ax3.yaxis, ax3.zaxis):
    _pane_axis.pane.fill = False
    _pane_axis.pane.set_edgecolor(PALETTE["text_secondary"])
ax3.tick_params(colors=PALETTE["text_secondary"], labelsize=9)
ax3.grid(True, alpha=0.2, linestyle='--')

# 4. 3D Transferability visualization
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on older Matplotlib)
from scipy.interpolate import griddata

ax4 = fig.add_subplot(gs[1, :2], projection='3d')
# Trained networks, assumed to be in the same order as models_list — TODO confirm if models_list changes.
model_objs = [egeat_model, egeat_soup, pgd_model]
n_models = len(models_list)
# transfer_matrix[i, j]: transfer_rate with model i as source and model j as target.
transfer_matrix = np.zeros((n_models, n_models))
for i in range(n_models):
    for j in range(n_models):
        if i != j:
            transfer_matrix[i, j] = transfer_rate(
                model_objs[i],
                model_objs[j],
                test_loader, eps=EPS
            )
        else:
            # Diagonal fixed at 1.0 by convention; the self-attack rate is not measured here.
            transfer_matrix[i, j] = 1.0

x = np.arange(n_models)
y = np.arange(n_models)
X, Y = np.meshgrid(x, y)
Z = transfer_matrix

# Smooth interpolated surface (NO overlapping bars)
X_coarse, Y_coarse = np.meshgrid(x, y)
# meshgrid uses 'xy' indexing (rows follow y), so transpose: the height at (x=i, y=j) must be
# transfer_matrix[i, j] to agree with the Source (x) / Target (y) labels and the text annotations.
Z_coarse = transfer_matrix.T

x_fine = np.linspace(0, n_models - 1, 100)
y_fine = np.linspace(0, n_models - 1, 100)
X_fine, Y_fine = np.meshgrid(x_fine, y_fine)

points = np.column_stack([X_coarse.flatten(), Y_coarse.flatten()])
values = Z_coarse.flatten()
Z_fine = griddata(points, values, (X_fine, Y_fine), method='cubic', fill_value=np.nan)
# Cells the cubic interpolant reports as NaN (fill_value) are back-filled from the nearest sample.
nan_cells = np.isnan(Z_fine)
if nan_cells.any():
    Z_fine[nan_cells] = griddata(points, values,
                                 (X_fine[nan_cells], Y_fine[nan_cells]),
                                 method='nearest')

surf = ax4.plot_surface(X_fine, Y_fine, Z_fine, cmap='RdYlGn_r', alpha=0.96,
                        edgecolor='none', linewidth=0, antialiased=True, shade=True,
                        rstride=2, cstride=2, vmin=0, vmax=1)

# Add value labels for off-diagonal (measured) entries at (source, target).
for i in range(n_models):
    for j in range(n_models):
        if i != j:
            ax4.text(i, j, transfer_matrix[i, j] + 0.05, f'{transfer_matrix[i, j]:.2f}',
                     ha='center', va='bottom', color=PALETTE["text"], fontsize=9, fontweight='bold')

ax4.view_init(elev=30, azim=45)
ax4.set_xlabel('Source Model', fontsize=12, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax4.set_ylabel('Target Model', fontsize=12, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax4.set_zlabel('Transfer Rate', fontsize=12, color=PALETTE["text"], fontweight='bold', labelpad=10)
ax4.set_title('Adversarial Transferability Matrix', fontsize=13, fontweight='bold', color=PALETTE["text"])
ax4.set_xticks(x)
ax4.set_xticklabels([m.replace(' ', '\n') for m in models_list], color=PALETTE["text"], fontsize=10)
ax4.set_yticks(y)
ax4.set_yticklabels([m.replace(' ', '\n') for m in models_list], color=PALETTE["text"], fontsize=10)
ax4.set_zlim(0, 1.1)

# Style 3D axes: transparent panes with a subtle edge, muted ticks, dashed grid.
for _pane_axis in (ax4.xaxis, ax4.yaxis, ax4.zaxis):
    _pane_axis.pane.fill = False
    _pane_axis.pane.set_edgecolor(PALETTE["text_secondary"])
ax4.tick_params(colors=PALETTE["text_secondary"], labelsize=9)
ax4.grid(True, alpha=0.2, linestyle='--')

cbar = fig.colorbar(surf, ax=ax4, shrink=0.6, pad=0.1)
cbar.set_label('Transfer Rate', color=PALETTE["text"], fontsize=11, fontweight='bold')
cbar.ax.yaxis.set_tick_params(color=PALETTE["text_secondary"], labelsize=9)

# 5. Key statistics text box: best score per attack setting, mean EGEAT gain, and run config.
ax5 = fig.add_subplot(gs[1, 2])
ax5.axis('off')

stats_lines = ["", "KEY STATISTICS", "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━", ""]
for _metric, _scores in (("Clean", clean_vals), ("FGSM", fgsm_vals), ("PGD-20", pgd_vals)):
    stats_lines += [
        f"Best {_metric} Accuracy:",
        f"  {max(_scores):.4f} ({models_list[np.argmax(_scores)]})",
        "",
    ]
stats_lines += [
    "Average Robustness Gain:",
    f"  {np.mean(improvements):+.2f}%",
    "",
    f"Dataset: {DATASET}",
    f"Epsilon: {EPS:.4f}",
    "",
]
# Leading/trailing empty entries reproduce the blank first and last lines of the box.
stats_text = "\n".join(stats_lines)

stats_box = dict(boxstyle='round', facecolor=DARK_THEME["axes.facecolor"],
                 edgecolor=PALETTE["primary"], linewidth=2, alpha=0.9)
ax5.text(0.1, 0.5, stats_text, transform=ax5.transAxes, fontsize=11,
         color=PALETTE["text"], verticalalignment='center',
         family='monospace', fontweight='bold', bbox=stats_box)

fig.suptitle('EGEAT: Complete Evaluation Dashboard', fontsize=20, fontweight='bold',
             color=PALETTE["text"], y=0.98)
# Reserve the top 4% of the figure for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.96])
savefig("fig_complete_dashboard.png", dpi=300)
plt.show()

print("\n✓ All visualizations generated successfully!")
print(f"✓ All figures saved to: {SAVE_DIR}")
print(f"✓ Showcase figures saved to: {BLOG_DIR}")
