
# EGEAT Reproduction Notebook (Lightweight)

**Exact Geometric Ensemble Adversarial Training (EGEAT)** — lightweight, self-contained reproduction on `sklearn.digits` (8×8 grayscale) so it runs offline.

> This notebook implements the core ideas of the EGEAT paper: exact inner maximization via the dual norm (exact for the first-order linearization of the loss), gradient-space geometric regularization, and ensemble/weight-space smoothing. It also produces figures for gradient similarity, loss-landscape slices, adversarial examples, transferability, and an ablation.

**Caveat:** The original paper reports results on MNIST, CIFAR-10, and DREBIN. Because this environment has no internet access, we use `sklearn.digits`, which ships with scikit-learn, so the code runs fully offline. To switch to MNIST or CIFAR-10, replace the dataset loader with one that reads local copies, and adjust the model's input size accordingly.


In [None]:

# Environment & config
import math, random, time, copy, itertools
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

# Reproducibility: seed every RNG the notebook draws from (Python, NumPy, PyTorch).
seed = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)


In [None]:

# Data: sklearn digits (1797 samples, 10 classes, 8x8 images)
digits = load_digits()
X = digits.images.astype(np.float32)   # shape (N, 8, 8)
y = digits.target.astype(np.int64)     # shape (N,)

# Min-max normalize to [0, 1] using the global min/max over all pixels.
X = (X - X.min()) / (X.max() - X.min())

# Stratified train/val/test split of roughly 70% / 15% / 15%.
# (`train_test_split` is imported in the environment cell.)
TEST_FRAC = 0.15
VAL_FRAC = 0.15
X_trainval, X_test, y_trainval, y_test = train_test_split(
    X, y, test_size=TEST_FRAC, random_state=seed, stratify=y)
# The second split only sees the train+val pool (1 - TEST_FRAC of the data),
# so VAL_FRAC is rescaled to be a fraction of that pool.
X_train, X_val, y_train, y_val = train_test_split(
    X_trainval, y_trainval, test_size=VAL_FRAC / (1.0 - TEST_FRAC),
    random_state=seed, stratify=y_trainval)

# Torch tensors; images get an explicit channel dimension.
X_train_t = torch.from_numpy(X_train).unsqueeze(1)  # (N,1,8,8)
X_val_t   = torch.from_numpy(X_val).unsqueeze(1)
X_test_t  = torch.from_numpy(X_test).unsqueeze(1)
y_train_t = torch.from_numpy(y_train)
y_val_t   = torch.from_numpy(y_val)
y_test_t  = torch.from_numpy(y_test)

train_ds = TensorDataset(X_train_t, y_train_t)
val_ds   = TensorDataset(X_val_t, y_val_t)
test_ds  = TensorDataset(X_test_t, y_test_t)

train_loader = DataLoader(train_ds, batch_size=128, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=256, shuffle=False)
test_loader  = DataLoader(test_ds,  batch_size=256, shuffle=False)

len_train, len_val, len_test = len(train_ds), len(val_ds), len(test_ds)
print(len_train, len_val, len_test)


In [None]:

# Small MLP for 8x8 images
class SmallMLP(nn.Module):
    """Three-layer LeakyReLU MLP classifier for small images.

    Inputs are flattened, so any shape whose trailing dimensions multiply to
    ``in_features`` works. The defaults (64 inputs, 10 classes) match the
    8x8 ``sklearn.digits`` data; override them for other datasets
    (e.g. ``in_features=784`` for 28x28 MNIST).

    Args:
        hidden: width of both hidden layers.
        in_features: number of input values per example after flattening.
        num_classes: number of output logits.
    """

    def __init__(self, hidden=128, in_features=64, num_classes=10):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, hidden)
        self.fc3 = nn.Linear(hidden, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (batch, num_classes)."""
        x = self.flatten(x)
        x = F.leaky_relu(self.fc1(x), negative_slope=0.1)
        x = F.leaky_relu(self.fc2(x), negative_slope=0.1)
        return self.fc3(x)

def accuracy(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, so calling this mid-training has
    no lasting side effect. Batches are moved to the device holding the
    model's parameters (left where they are for a parameter-free model).

    Raises:
        ValueError: if ``loader`` yields no examples.
    """
    param = next(model.parameters(), None)
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    try:
        with torch.no_grad():
            for xb, yb in loader:
                if param is not None:
                    xb, yb = xb.to(param.device), yb.to(param.device)
                pred = model(xb).argmax(1)
                correct += (pred == yb).sum().item()
                total += yb.size(0)
    finally:
        model.train(was_training)
    if total == 0:
        raise ValueError("accuracy() received an empty loader")
    return correct / total


In [None]:

# Dual-norm exact adversary (first-order exact solution) and attacks
def exact_perturbation(x, y, model, loss_fn, eps=0.3, p='linf'):
    """Return the steepest-ascent adversarial example inside an ``eps`` ball.

    For the loss linearized around ``x`` with input gradient ``g``, the
    maximizer of ``g . delta`` subject to ``||delta||_p <= eps`` is ``eps``
    times the dual-norm direction of ``g``:

    * ``'linf'``: ``eps * sign(g)`` (the FGSM step);
    * ``'l2'``:   ``eps * g / ||g||_2``;
    * ``'l1'``:   the whole budget on the coordinate of largest ``|g|``,
      i.e. ``eps * sign(g_i) * e_i`` with ``i = argmax |g_i|``.

    The perturbed batch is clipped to [0, 1] and returned detached, on the
    device holding the model's parameters (``x``'s own device for a
    parameter-free model). The gradient is taken w.r.t. the input only, so
    the model's parameter ``.grad`` fields are not modified.

    Raises:
        ValueError: if ``p`` is not ``'linf'``, ``'l2'`` or ``'l1'``.
    """
    if p not in ('linf', 'l2', 'l1'):
        raise ValueError("p must be 'linf', 'l2', or 'l1'")
    dev = next(model.parameters(), x).device
    x = x.detach().clone().to(dev).requires_grad_(True)
    loss = loss_fn(model(x), y.to(dev))
    (g,) = torch.autograd.grad(loss, x)

    g_flat = g.reshape(g.size(0), -1)
    if p == 'linf':
        # Dual of L-inf is L1: every coordinate moves by eps along sign(g).
        delta_flat = eps * g_flat.sign()
    elif p == 'l2':
        # L2 is self-dual: move eps along the unit gradient direction.
        norm = g_flat.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
        delta_flat = eps * (g_flat / norm)
    else:
        # Dual of L1 is L-inf: spend the entire budget on the single
        # coordinate with the largest |g| (the first one on ties). Using
        # eps * sign(g) everywhere would have L1 norm eps * dim, far outside
        # the eps ball.
        top = g_flat.abs().argmax(dim=1, keepdim=True)
        delta_flat = torch.zeros_like(g_flat).scatter_(
            1, top, eps * g_flat.gather(1, top).sign())
    delta = delta_flat.view_as(g)

    return (x.detach() + delta).clamp(0.0, 1.0)

def fgsm_attack(x, y, model, loss_fn, eps=0.3):
    """Fast Gradient Sign Method: one L-inf step of size ``eps``.

    Delegates to ``exact_perturbation`` with the ``'linf'`` threat model.
    """
    threat_norm = 'linf'
    return exact_perturbation(x, y, model, loss_fn, p=threat_norm, eps=eps)

def pgd_attack(x, y, model, loss_fn, eps=0.3, alpha=0.1, steps=20, p='linf'):
    """Projected gradient ascent on the loss, from a random start in the eps ball.

    Starts from a random point inside the ``eps`` ball around ``x``. It then
    takes ``steps`` steepest-ascent steps of size ``alpha``. After each step
    it projects back onto the ``eps`` ball (in norm ``p``) and onto the pixel
    range [0, 1].

    Args:
        p: ``'linf'`` (signed-gradient steps, box projection) or ``'l2'``
           (unit-gradient steps, radial projection).

    Returns:
        Detached adversarial batch on the device holding the model's
        parameters (``x``'s device for a parameter-free model). Parameter
        ``.grad`` fields are not modified.

    Raises:
        ValueError: for an unsupported ``p``. This is checked before any
        work, so it is raised even when ``steps == 0``.
    """
    if p not in ('linf', 'l2'):
        raise ValueError("p must be 'linf' or 'l2' in this PGD")
    dev = next(model.parameters(), x).device
    x0 = x.detach().clone().to(dev)
    y = y.to(dev)

    def project(x_candidate):
        """Project onto the eps ball around x0 (norm p), then onto [0, 1]."""
        delta = x_candidate - x0
        if p == 'linf':
            delta = torch.clamp(delta, min=-eps, max=eps)
        else:
            d = delta.reshape(delta.size(0), -1)
            d_norm = d.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
            delta = (d * (eps / d_norm).clamp(max=1.0)).view_as(x0)
        return (x0 + delta).clamp(0, 1).detach()

    # Random start: uniform noise from the L-inf cube, projected so it also
    # lies inside the L2 ball when p == 'l2' (a raw cube sample may not).
    x_adv = project(x0 + torch.empty_like(x0).uniform_(-eps, eps))

    for _ in range(steps):
        x_adv.requires_grad_(True)
        loss = loss_fn(model(x_adv), y)
        (g,) = torch.autograd.grad(loss, x_adv)
        if p == 'linf':
            step = alpha * g.sign()
        else:
            g_flat = g.reshape(g.size(0), -1)
            norm = g_flat.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
            step = (alpha * g_flat / norm).view_as(g)
        x_adv = project(x_adv.detach() + step)
    return x_adv


In [None]:

# Geometric regularization & ensemble smoothing
def batch_input_gradients(models, x, y, loss_fn, create_graph=False):
    """Return, for each model, the gradient of the loss w.r.t. the input batch.

    Each gradient is computed with the model temporarily in eval mode. The
    model's previous train/eval mode is restored afterwards, so a model that
    is being trained is not silently left in eval mode. Gradients are taken
    with ``torch.autograd.grad``, so parameter ``.grad`` fields are left
    untouched.

    Args:
        models: models evaluated on the same batch.
        x, y: input batch and targets; moved to each model's parameter device
            (``x``'s device for a parameter-free model).
        loss_fn: loss taking ``(logits, targets)``.
        create_graph: if True, keep the autograd graph so the returned
            gradients can be differentiated again (e.g. for a gradient-space
            penalty). If False (default), they are detached.

    Returns:
        list of tensors shaped like ``x``, one per model, in ``models`` order.
    """
    grads = []
    for m in models:
        dev = next(m.parameters(), x).device
        was_training = m.training
        m.eval()
        try:
            x_ = x.detach().clone().to(dev).requires_grad_(True)
            loss = loss_fn(m(x_), y.to(dev))
            (g,) = torch.autograd.grad(loss, x_, create_graph=create_graph)
        finally:
            m.train(was_training)
        grads.append(g if create_graph else g.detach())
    return grads

def cosine_similarity(a, b, eps=1e-12):
    """Batch-mean of per-example cosine similarities between ``a`` and ``b``.

    Both tensors are flattened to ``(a.size(0), -1)``. Row ``i`` contributes
    ``<a_i, b_i> / (||a_i|| * ||b_i|| + eps)``, and the mean over rows is
    returned as a 0-dim tensor (differentiable if the inputs are).
    ``reshape`` is used rather than ``view`` so non-contiguous inputs
    (e.g. permuted or sliced tensors) are accepted.
    """
    a = a.reshape(a.size(0), -1)
    b = b.reshape(a.size(0), -1)
    num = (a * b).sum(dim=1)
    den = a.norm(p=2, dim=1) * b.norm(p=2, dim=1) + eps
    return (num / den).mean()

def geometric_regularizer(models, x, y, loss_fn):
    """Mean pairwise cosine similarity of the models' input gradients.

    ``train_egeat`` adds this term (times ``lambda_geom``) to its loss, so
    minimizing it pushes the models' input-gradient directions apart. The
    input gradients are computed with
    ``torch.autograd.grad(..., create_graph=True)`` so the returned value is
    differentiable w.r.t. the model parameters. With detached gradients the
    term would be a constant and would contribute nothing to
    ``loss.backward()``.

    Each model is evaluated in eval mode, and its previous train/eval mode is
    restored afterwards. Parameter ``.grad`` fields are not modified here.

    Returns:
        0-dim tensor on ``device``; ``0.0`` when fewer than two models are given.
    """
    if len(models) < 2:
        return torch.tensor(0.0, device=device)
    grads = []
    for m in models:
        was_training = m.training
        m.eval()
        try:
            x_ = x.detach().clone().to(device).requires_grad_(True)
            loss = loss_fn(m(x_), y.to(device))
            # Keep the graph: the penalty must backpropagate into the params.
            (g,) = torch.autograd.grad(loss, x_, create_graph=True)
        finally:
            m.train(was_training)
        grads.append(g)
    sims = [cosine_similarity(grads[i], grads[j])
            for i, j in itertools.combinations(range(len(grads)), 2)]
    return torch.stack(sims).mean()

def update_soup(snapshots):
    """Return a uniform "model soup": the element-wise mean of the snapshots.

    The result is a deep copy of the first snapshot. Every floating-point
    entry of its ``state_dict`` (parameters and floating-point buffers such
    as BatchNorm running statistics) is replaced by the mean over all
    snapshots. Non-floating entries (e.g. integer counters) are copied from
    the first snapshot. The snapshots themselves are not modified.

    Returns:
        the averaged model, or ``None`` if ``snapshots`` is empty.
    """
    if not snapshots:
        return None
    soup = copy.deepcopy(snapshots[0])
    states = [s.state_dict() for s in snapshots]
    averaged = {}
    for key, first in states[0].items():
        if torch.is_floating_point(first):
            # Sum then divide, in snapshot order (0 + s0 + s1 + ...).
            total = torch.zeros_like(first)
            for state in states:
                total.add_(state[key])
            averaged[key] = total.div_(len(states))
        else:
            averaged[key] = first.clone()
    soup.load_state_dict(averaged)
    return soup


In [None]:

# Training loop: EGEAT
from dataclasses import dataclass

@dataclass
class EGEATConfig:
    """Hyper-parameters for ``train_egeat``.

    Fields:
        epochs: number of passes over the training loader.
        eps: radius of the adversarial ball for the exact perturbation.
        norm: threat-model norm passed to ``exact_perturbation``
            (``'linf'``, ``'l2'`` or ``'l1'``).
        lambda_geom: weight of the gradient-cosine geometric regularizer.
        lambda_soup: weight of the squared-L2 pull toward the weight soup.
        snapshots_k: controls snapshot frequency; ``train_egeat`` snapshots
            every ``max(1, epochs // snapshots_k)`` epochs and at the end.
        lr: Adam learning rate.

    Raises:
        ValueError: on construction, for settings that would otherwise fail
            deep inside the training loop.
    """
    epochs: int = 8
    eps: float = 0.3
    norm: str = "linf"
    lambda_geom: float = 0.1
    lambda_soup: float = 0.001
    snapshots_k: int = 4
    lr: float = 2e-3

    def __post_init__(self):
        # Fail fast rather than deep inside the training loop.
        if self.norm not in ("linf", "l2", "l1"):
            raise ValueError(f"norm must be 'linf', 'l2' or 'l1', got {self.norm!r}")
        if self.snapshots_k < 1:
            raise ValueError(f"snapshots_k must be >= 1, got {self.snapshots_k}")
        if self.epochs < 0:
            raise ValueError(f"epochs must be >= 0, got {self.epochs}")
        if self.eps < 0:
            raise ValueError(f"eps must be >= 0, got {self.eps}")

def train_egeat(config: EGEATConfig):
    """Train a ``SmallMLP`` with the EGEAT objective; return it and its soup.

    Per batch the minimized loss is
    ``L_adv + lambda_geom * L_geom + lambda_soup * L_soup`` where:

    * ``L_adv``: cross-entropy on ``exact_perturbation`` adversarial examples;
    * ``L_geom``: ``geometric_regularizer`` between the current model and the
      current soup on the clean batch (0 until the first snapshot exists);
    * ``L_soup``: squared L2 distance from the model's parameters to the
      (fixed) soup's parameters (0 until the first snapshot exists).

    A snapshot is taken every ``max(1, epochs // snapshots_k)`` epochs and at
    the final epoch. The soup is the uniform average of all snapshots so far.
    Uses the notebook globals ``train_loader``, ``train_ds``, ``val_loader``
    and ``device``.

    Returns:
        ``(model, final_soup)``. ``final_soup`` is a fresh average of the
        snapshots, or a copy of ``model`` if no snapshot was taken.
    """
    model = SmallMLP().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=config.lr)
    loss_fn = nn.CrossEntropyLoss()

    snapshots = []
    soup_model = None
    # Guard snapshots_k == 0, which would otherwise raise ZeroDivisionError.
    snapshot_every = max(1, config.epochs // max(1, config.snapshots_k))

    for epoch in range(1, config.epochs+1):
        total_loss = 0.0
        for xb, yb in train_loader:
            # Re-enter train mode every batch: helpers that compute input
            # gradients may switch the model's mode.
            model.train()
            xb, yb = xb.to(device), yb.to(device)

            xb_adv = exact_perturbation(xb, yb, model, loss_fn, eps=config.eps, p=config.norm)
            L_adv = loss_fn(model(xb_adv), yb)

            models_for_geom = [model] if soup_model is None else [model, soup_model]
            L_geom = geometric_regularizer(models_for_geom, xb, yb, loss_fn)

            if soup_model is None:
                L_soup = torch.tensor(0.0, device=device)
            else:
                # The soup is a fixed target: detach it so no gradient flows into it.
                L_soup = sum(
                    ((p - ps.detach()).pow(2).sum()
                     for p, ps in zip(model.parameters(), soup_model.parameters())),
                    torch.tensor(0.0, device=device),
                )

            loss = L_adv + config.lambda_geom * L_geom + config.lambda_soup * L_soup

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            total_loss += loss.item() * xb.size(0)

        if epoch % snapshot_every == 0 or epoch == config.epochs:
            snapshots.append(copy.deepcopy(model).to(device))
            soup_model = update_soup(snapshots)
            # The running soup is only read during training; freeze it so no
            # gradients accumulate on its parameters.
            soup_model.requires_grad_(False)

        val_acc = accuracy(model, val_loader)
        print(f"Epoch {epoch}/{config.epochs}  loss={total_loss/len(train_ds):.4f}  val_acc={val_acc:.3f}  snaps={len(snapshots)}")

    final_soup = update_soup(snapshots) if snapshots else copy.deepcopy(model)
    return model, final_soup

# Train EGEAT with the default configuration. `egeat_model` (final weights) and
# `egeat_soup` (uniform average of the snapshots) are used by later cells.
config = EGEATConfig()
egeat_model, egeat_soup = train_egeat(config)
print("Validation acc (model):", accuracy(egeat_model, val_loader))
print("Validation acc (soup): ", accuracy(egeat_soup,  val_loader))


In [None]:

# Baseline: PGD adversarial training
from dataclasses import dataclass

@dataclass
class PGDConfig:
    """Hyper-parameters for the PGD adversarial-training baseline.

    Fields:
        epochs: number of passes over the training loader.
        eps: L-inf radius of the PGD perturbation.
        alpha: PGD step size.
        steps: PGD iterations per training batch.
        lr: Adam learning rate.

    Raises:
        ValueError: on construction, for negative settings.
    """
    epochs: int = 8
    eps: float = 0.3
    alpha: float = 0.1
    steps: int = 5
    lr: float = 2e-3

    def __post_init__(self):
        # Fail fast rather than deep inside the training loop.
        if self.epochs < 0:
            raise ValueError(f"epochs must be >= 0, got {self.epochs}")
        if self.eps < 0:
            raise ValueError(f"eps must be >= 0, got {self.eps}")
        if self.alpha < 0:
            raise ValueError(f"alpha must be >= 0, got {self.alpha}")
        if self.steps < 0:
            raise ValueError(f"steps must be >= 0, got {self.steps}")

def train_pgd_adv(config: PGDConfig):
    """Baseline: train a fresh ``SmallMLP`` only on L-inf PGD adversarial examples.

    Every batch is replaced by ``pgd_attack`` examples crafted against the
    current model, and the model minimizes cross-entropy on them with Adam.
    Validation accuracy is printed after each epoch. Uses the notebook
    globals ``train_loader``, ``train_ds``, ``val_loader`` and ``device``.
    """
    net = SmallMLP().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=config.lr)
    criterion = nn.CrossEntropyLoss()
    attack_kwargs = dict(eps=config.eps, alpha=config.alpha, steps=config.steps, p='linf')

    for epoch in range(1, config.epochs + 1):
        net.train()
        running_loss = 0.0
        for inputs, targets in train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            adv_inputs = pgd_attack(inputs, targets, net, criterion, **attack_kwargs)
            batch_loss = criterion(net(adv_inputs), targets)

            optimizer.zero_grad(set_to_none=True)
            batch_loss.backward()
            optimizer.step()

            running_loss += batch_loss.item() * inputs.size(0)

        val_acc = accuracy(net, val_loader)
        print(f"[PGD] Epoch {epoch}/{config.epochs} loss={running_loss/len(train_ds):.4f} val_acc={val_acc:.3f}")
    return net

# Train the PGD adversarial-training baseline with default settings;
# `pgd_model` is compared against EGEAT in the cells below.
pgd_model = train_pgd_adv(PGDConfig())
print("Validation acc (PGD):", accuracy(pgd_model, val_loader))


In [None]:

# Evaluation: clean and adversarial accs
def eval_adv_acc(model, loader, attack='fgsm', eps=0.3, alpha=0.1, steps=20):
    """Accuracy of ``model`` on adversarial examples crafted against itself.

    Args:
        model: classifier under attack; switched to eval mode.
        loader: iterable of ``(inputs, labels)`` batches.
        attack: ``'fgsm'`` (one L-inf step of size ``eps``) or ``'pgd'``
            (``steps`` L-inf PGD steps of size ``alpha`` within radius ``eps``).

    Returns:
        fraction of examples whose adversarial version is still classified
        correctly.

    Raises:
        ValueError: if ``attack`` is unknown (checked before any batch is
            processed), or if ``loader`` yields no examples.
    """
    if attack not in ('fgsm', 'pgd'):
        raise ValueError("attack must be 'fgsm' or 'pgd'")
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    correct, total = 0, 0
    for xb, yb in loader:
        if attack == 'fgsm':
            xb_adv = fgsm_attack(xb, yb, model, loss_fn, eps=eps)
        else:
            xb_adv = pgd_attack(xb, yb, model, loss_fn, eps=eps, alpha=alpha, steps=steps, p='linf')
        with torch.no_grad():
            pred = model(xb_adv.to(device)).argmax(1).cpu()
        # Compare on the CPU so label batches from any device work.
        correct += (pred == yb.cpu()).sum().item()
        total += yb.size(0)
    if total == 0:
        raise ValueError("eval_adv_acc() received an empty loader")
    return correct / total

# Test-set accuracy of both models on clean inputs and under L-inf attacks with
# eps=0.3: single-step FGSM and 20-step PGD with step size 0.1.
clean_e = accuracy(egeat_model, test_loader)
fgsm_e  = eval_adv_acc(egeat_model, test_loader, attack='fgsm', eps=0.3)
pgd_e   = eval_adv_acc(egeat_model, test_loader, attack='pgd',  eps=0.3, alpha=0.1, steps=20)

clean_p = accuracy(pgd_model, test_loader)
fgsm_p  = eval_adv_acc(pgd_model, test_loader, attack='fgsm', eps=0.3)
pgd_p   = eval_adv_acc(pgd_model, test_loader, attack='pgd',  eps=0.3, alpha=0.1, steps=20)

print("EGEAT — clean/FGSM/PGD:", round(clean_e,3), round(fgsm_e,3), round(pgd_e,3))
print("PGD   — clean/FGSM/PGD:", round(clean_p,3), round(fgsm_p,3), round(pgd_p,3))


In [None]:

# Gradient similarity heatmap
loss_fn = nn.CrossEntropyLoss()
# Models compared in the heatmap (each moved to `device` and put in eval mode)
# and their display names, in the same order.
models = [egeat_model.to(device).eval(), egeat_soup.to(device).eval(), pgd_model.to(device).eval()]
names  = ["EGEAT model", "EGEAT soup", "PGD model"]

# First test batch; `xb`/`yb` are also reused by the loss-surface cell below.
xb, yb = next(iter(test_loader))
xb, yb = xb.to(device), yb.to(device)
# Input gradient of each model's loss on that batch, in `models` order.
grads = []
for m in models:
    m.zero_grad(set_to_none=True)
    x_ = xb.detach().clone().requires_grad_(True)
    loss = loss_fn(m(x_), yb)
    loss.backward()
    grads.append(x_.grad.detach())

def pairwise_grad_cosine(grads):
    """Return an (M, M) CPU tensor of pairwise ``cosine_similarity`` values.

    Entry ``[i, j]`` is the batch-mean cosine between ``grads[i]`` and
    ``grads[j]``; every ordered pair, including the diagonal, is computed.
    """
    size = len(grads)
    table = torch.zeros(size, size)
    for (row, g_row), (col, g_col) in itertools.product(enumerate(grads), repeat=2):
        table[row, col] = cosine_similarity(g_row, g_col).detach().cpu()
    return table

# Each cell is the batch-mean per-example cosine between two models' input
# gradients (see `cosine_similarity`); the diagonal is each model with itself.
G = pairwise_grad_cosine(grads).numpy()

plt.figure()
plt.imshow(G, interpolation='nearest')
plt.xticks(range(len(names)), names, rotation=45, ha='right')
plt.yticks(range(len(names)), names)
plt.title("Gradient Subspace Cosine Similarity")
plt.colorbar()
plt.tight_layout()
plt.savefig("/mnt/data/fig_gradient_similarity.png")
plt.show()


In [None]:

# Loss surface slice near EGEAT parameters
def get_params_vector(model):
    """Concatenate all of ``model``'s parameters into one detached 1-D tensor.

    Parameters appear in ``model.parameters()`` order, each flattened.
    """
    pieces = []
    for param in model.parameters():
        pieces.append(param.detach().view(-1))
    return torch.cat(pieces)

def set_params_vector(model, vec):
    """Overwrite ``model``'s parameters in place from the flat vector ``vec``.

    Parameters are filled in ``model.parameters()`` order, each from the next
    ``numel()`` entries of ``vec``. This is the layout produced by
    ``get_params_vector``. Entries beyond the total parameter count are
    ignored.
    """
    params = list(model.parameters())
    ends = itertools.accumulate(param.numel() for param in params)
    start = 0
    for param, end in zip(params, ends):
        param.data.copy_(vec[start:end].view_as(param))
        start = end

base = copy.deepcopy(egeat_model).to(device).eval()
w = get_params_vector(base)
# Two random directions in parameter space, each scaled to unit L2 norm.
d1 = torch.randn_like(w); d2 = torch.randn_like(w)
d1 = d1 / d1.norm(); d2 = d2 / d2.norm()

grid = 15
alphas = torch.linspace(-0.5, 0.5, grid)
betas  = torch.linspace(-0.5, 0.5, grid)
Z = np.zeros((grid, grid), dtype=np.float32)  # Z[i, j]: loss at (alphas[i], betas[j])

# One scratch model suffices: set_params_vector overwrites every parameter at
# each grid point, so there is no need to deep-copy `base` grid*grid times.
tmp = copy.deepcopy(base).to(device)
with torch.no_grad():
    # `xb`, `yb` are the first test batch loaded by the gradient-similarity cell.
    xb_small, yb_small = xb[:128], yb[:128]
    for i, a in enumerate(alphas):
        for j, b in enumerate(betas):
            set_params_vector(tmp, w + a*d1 + b*d2)
            z = nn.CrossEntropyLoss()(tmp(xb_small), yb_small).item()
            Z[i,j] = z

plt.figure()
# contourf expects Z indexed [y, x]; Z is indexed [alpha, beta], hence Z.T.
plt.contourf(alphas.cpu(), betas.cpu(), Z.T, levels=20)
plt.title("Loss Surface Slice near EGEAT parameters")
plt.xlabel("alpha (dir 1)")
plt.ylabel("beta (dir 2)")
plt.tight_layout()
plt.savefig("/mnt/data/fig_loss_surface.png")
plt.show()


In [None]:

# Visualize adversarial examples
# First 10 test images and labels (CPU tensors).
idxs = torch.arange(0, 10)
xv = X_test_t[idxs]
yv = y_test_t[idxs]

loss_fn = nn.CrossEntropyLoss()
# FGSM examples crafted against the PGD-trained model, and L-inf exact
# (dual-norm) examples crafted against the EGEAT model, both with eps=0.3.
x_fgsm_pgd = fgsm_attack(xv, yv, pgd_model, loss_fn, eps=0.3).cpu()
x_exact_e  = exact_perturbation(xv, yv, egeat_model, loss_fn, eps=0.3, p='linf').cpu()

def show_triplet(orig, a, b, title_a="FGSM(PGD)", title_b="Exact(EGEAT)"):
    """Plot clean images above two adversarial versions of them, then save and show.

    Row 1 shows ``orig``, row 2 shows ``a`` and row 3 shows ``b``; one column
    per image. Each image is ``tensor[i, 0]`` displayed with the color range
    fixed to [0, 1] (presumably (n, 1, H, W) tensors on the CPU — confirm for
    new callers). ``a`` and ``b`` must hold at least ``orig.size(0)`` images.
    The figure is saved to /mnt/data/fig_adversarial_grid.png.

    Args:
        title_a, title_b: row labels for ``a`` and ``b``.
    """
    n = orig.size(0)
    # squeeze=False keeps `axes` 2-D even when n == 1.
    fig, axes = plt.subplots(3, n, figsize=(1.2*n, 3.6), squeeze=False)
    rows = [(orig, "Original"), (a, title_a), (b, title_b)]
    for r, (imgs, label) in enumerate(rows):
        for i in range(n):
            ax = axes[r, i]
            ax.imshow(imgs[i, 0].detach().cpu().numpy(), vmin=0, vmax=1)
            # Hide the ticks only: axis("off") would also hide the row label
            # set with set_ylabel below, so it would never be drawn.
            ax.set_xticks([])
            ax.set_yticks([])
        axes[r, 0].set_ylabel(label)
    plt.tight_layout()
    plt.savefig("/mnt/data/fig_adversarial_grid.png")
    plt.show()

# Draw and save the original / FGSM(PGD) / Exact(EGEAT) comparison grid.
show_triplet(xv, x_fgsm_pgd, x_exact_e)


In [None]:

# Transferability experiment: craft on source, test on target
def transfer_rate(source_model, target_model, loader, eps=0.3):
    """Fraction of FGSM examples (crafted on ``source_model``) misclassified by ``target_model``.

    Every target error on the adversarial batch is counted. This includes
    examples the target would already misclassify when clean, and examples
    that do not fool the source model.
    """
    criterion = nn.CrossEntropyLoss()
    n_seen = 0
    n_fooled = 0
    for inputs, labels in loader:
        adv = fgsm_attack(inputs, labels, source_model, criterion, eps=eps)
        with torch.no_grad():
            target_pred = target_model(adv.to(device)).argmax(1).cpu()
        n_fooled += (target_pred != labels).sum().item()
        n_seen += labels.size(0)
    return n_fooled / n_seen

# Models (eval mode, on `device`) and display names for the transfer study.
models = [egeat_model.to(device).eval(), egeat_soup.to(device).eval(), pgd_model.to(device).eval()]
names  = ["EGEAT model", "EGEAT soup", "PGD model"]
# Only one direction per pair is measured: source models[i] -> target models[j].
pairs = [(0,1), (0,2), (1,2)]
rates = []
for i,j in pairs:
    r = transfer_rate(models[i], models[j], test_loader, eps=0.3)
    rates.append(r)
    print(f"Transfer {names[i]} -> {names[j]}: {r:.3f}")

plt.figure()
xs = np.arange(len(pairs))
plt.bar(xs, rates)
plt.xticks(xs, [f"{names[i]}→{names[j]}" for i,j in pairs], rotation=30, ha='right')
plt.ylabel("Transfer Rate")
plt.title("Adversarial Transferability (FGSM)")
plt.tight_layout()
plt.savefig("/mnt/data/fig_transfer.png")
plt.show()


In [None]:

# Ablation (tiny epochs for speed): effect of lambda_geom and lambda_soup
import pandas as pd

def quick_egeat(epochs, lam_geom, lam_soup):
    """Run a short EGEAT training and return ``(clean_acc, pgd20_acc, entropy)``.

    Args:
        epochs: number of training epochs.
        lam_geom: ``lambda_geom`` weight of the geometric regularizer.
        lam_soup: ``lambda_soup`` weight of the soup penalty.

    Returns:
        test accuracy on clean inputs, and under 20-step L-inf PGD
        (eps=0.3, alpha=0.1), both for the trained model (not the soup). The
        third value is the mean softmax entropy over the test set, a cheap
        calibration proxy that is *not* true ECE.
    """
    cfg = EGEATConfig(epochs=epochs, eps=0.3, norm="linf",
                      lambda_geom=lam_geom, lambda_soup=lam_soup,
                      snapshots_k=2, lr=2e-3)
    m, _soup = train_egeat(cfg)
    clean = accuracy(m, test_loader)
    pgd20 = eval_adv_acc(m, test_loader, attack='pgd', eps=0.3, alpha=0.1, steps=20)

    m.eval()
    ent = []
    with torch.no_grad():
        for xb, _ in test_loader:
            probs = F.softmax(m(xb.to(device)), dim=1).cpu().numpy()
            ent.append(-(probs * np.log(probs + 1e-12)).sum(axis=1))
    # Batches differ in size (the last one is smaller), so concatenate the
    # per-example entropies before averaging; np.mean on the ragged list fails.
    ece_proxy = float(np.mean(np.concatenate(ent)))
    return clean, pgd20, ece_proxy

# (lambda_geom, lambda_soup) settings compared in the ablation.
settings = [(0.0,0.0), (0.1,0.0), (0.1,0.001)]
results = []
for lam_g, lam_s in settings:
    print(f"ABLT: lambda_geom={lam_g}, lambda_soup={lam_s}")
    c, p20, ece = quick_egeat(epochs=4, lam_geom=lam_g, lam_soup=lam_s)
    results.append((lam_g, lam_s, c, p20, ece))

df = pd.DataFrame(results, columns=["lambda_geom","lambda_soup","clean_acc","pgd20_acc","ece_proxy"])
print(df)
df.to_csv("/mnt/data/ablation_results.csv", index=False)

# Two settings share lambda_geom=0.1 and differ only in lambda_soup, so a line
# plot against lambda_geom alone would stack them at one x position. Draw one
# labeled bar per (lambda_geom, lambda_soup) setting instead.
setting_labels = [f"geom={g:g}\nsoup={s:g}" for g, s in zip(df["lambda_geom"], df["lambda_soup"])]
setting_pos = np.arange(len(df))
plt.figure()
plt.bar(setting_pos, df["pgd20_acc"])
plt.xticks(setting_pos, setting_labels)
plt.xlabel("setting (lambda_geom, lambda_soup)")
plt.ylabel("PGD-20 Accuracy")
plt.title("Ablation: Effect of geometric and soup regularization")
plt.tight_layout()
plt.savefig("/mnt/data/fig_ablation.png")
plt.show()


In [None]:

# Final summary: list the artifacts written by the cells above.
print("=== SUMMARY (Digits dataset) ===")
print("Saved figures to /mnt/data/: fig_gradient_similarity.png, fig_loss_surface.png, fig_adversarial_grid.png, fig_transfer.png, fig_ablation.png")
print("Ablation CSV: /mnt/data/ablation_results.csv")
