In [1]:
# Colab cell 1: install dependencies
# - faiss-cpu for fast nearest-neighbor search (IndexFlatL2)
# - torchattacks optional (we also include minimal FGSM/PGD implementations)
# Note: use faiss-gpu if available, but faiss-cpu works fine for CIFAR-10.
!pip install faiss-cpu
!pip install torchattacks


Collecting faiss-cpu
  Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.1 kB)
Downloading faiss_cpu-1.12.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (31.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m57.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: faiss-cpu
Successfully installed faiss-cpu-1.12.0
Collecting torchattacks
  Downloading torchattacks-3.5.1-py3-none-any.whl.metadata (927 bytes)
Collecting requests~=2.25.1 (from torchattacks)
  Downloading requests-2.25.1-py2.py3-none-any.whl.metadata (4.2 kB)
Collecting chardet<5,>=3.0.2 (from requests~=2.25.1->torchattacks)
  Downloading chardet-4.0.0-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting idna<3,>=2.5 (from requests~=2.25.1->torchattacks)
  Downloading idna-2.10-py2.py3-none-any.whl.metadata (9.1 kB)
Collecting urllib3<1.27,>=1.21.1 (from requests~=2.25.1->torchattacks)
  Downloading urllib3-1

In [2]:
# Colab cell 2
# Standard imports and hyperparameters. Change these to suit your GPU/time budget.
import os
import time
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as T

import faiss  # for retrieval index

# For plotting
import matplotlib.pyplot as plt

# Device (Colab T4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Reproducibility seeds
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

# Hyperparameters (start values tuned for Colab)
IMG_SIZE = 224            # using ResNet-18 pretrained requires 224x224
BATCH_SIZE = 64
NUM_WORKERS = 2
NUM_CLASSES = 10          # CIFAR-10
LR = 1e-3
EPOCHS = 20               # lower for quick runs; increase if you have time
K = 10                     # default retrieved neighbors
K_OPTIONS = [5, 10]       # experiment with 5 and 10 as in paper
C = 10                    # K' = C*K for randomization if used
MIXUP_WEIGHT = 0.5        # weight for local mixup loss in total loss


Device: cuda


In [3]:
# Colab cell 3
# Dataset pipeline: Resize to 224 for ResNet-18. (Paper used custom small network for CIFAR; ResNet-18 is fine.)
transform_train = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
])
transform_test = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = torch.utils.data.DataLoader(testset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

print("Train size:", len(trainset), "Test size:", len(testset))


100%|██████████| 170M/170M [00:04<00:00, 41.0MB/s]


Train size: 50000 Test size: 10000


In [4]:
# Colab cell 4
# Implements Section 2 (phi' and phi). phi_prime can be pretrained ResNet-18 (frozen) used for retrieval keys.
# phi is the trainable feature extractor used inside RaCNN. For simplicity we use same architecture for both.

def get_resnet18_backbone(pretrained=True, out_dim=None):
    """
    Returns ResNet18 truncated before final fc.
    If out_dim given, add a linear layer to project to out_dim.
    """
    model = torchvision.models.resnet18(pretrained=pretrained)
    modules = list(model.children())[:-1]  # remove fc
    encoder = nn.Sequential(*modules)      # outputs [B, 512, 1, 1]
    class Wrapper(nn.Module):
        def __init__(self, encoder, out_dim=None):
            super().__init__()
            self.encoder = encoder
            self.out_dim = out_dim
            self.pool = nn.AdaptiveAvgPool2d((1,1))
            if out_dim is not None:
                self.proj = nn.Linear(512, out_dim)
            else:
                self.proj = None
        def forward(self, x):
            h = self.encoder(x)
            h = h.view(h.size(0), -1)
            if self.proj is not None:
                return self.proj(h)
            return h
    return Wrapper(encoder, out_dim=out_dim)

# instantiate
# phi_prime: the retrieval key extractor (Section 2.3). We will freeze it by default (simulate hidden retrieval).
phi_prime = get_resnet18_backbone(pretrained=True, out_dim=None).to(device)
for p in phi_prime.parameters():
    p.requires_grad = False  # freeze phi' by default

# phi: trainable feature extractor used by RaCNN (can start from pretrained backbone or random init)
phi = get_resnet18_backbone(pretrained=True, out_dim=None).to(device)
# Option: fine-tune or freeze early layers as you prefer




Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth


100%|██████████| 44.7M/44.7M [00:00<00:00, 194MB/s]


In [6]:
# ==============================================================
# 📘 Dataset Preparation — CIFAR-10 (Section 6.1)
# ==============================================================
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Data augmentation (same spirit as Sec. 6: "data augmentation on D but not D′")
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)

# Dataloaders (as D and D′)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

print(f"[RaCNN] CIFAR-10 ready — {len(trainset)} train, {len(testset)} test samples.")


[RaCNN] CIFAR-10 ready — 50000 train, 10000 test samples.


In [7]:
# ==============================================================
# 📘 Section 2.3 — Retrieval Engine F
# ==============================================================
# This builds the retrieval database D′ = {φ′(x′i), y′i}.
# In the paper: they use LSH + random projection.
# Here: we use FAISS IndexFlatL2 (dense lookup, faster & GPU compatible).
# Optional random projection (dim reduction) can be added with out_dim in phi′.

import os, faiss, numpy as np
from tqdm import tqdm

SAVE_DIR = '/content/drive/MyDrive/racnn'
os.makedirs(SAVE_DIR, exist_ok=True)

DB_FEATS_PATH = os.path.join(SAVE_DIR, 'db_feats.npy')
DB_LABELS_PATH = os.path.join(SAVE_DIR, 'db_labels.npy')
DB_INDEX_PATH = os.path.join(SAVE_DIR, 'retrieval_index.faiss')

# --------------------------------------------------------------
# 🔹 Helper: build retrieval DB if not cached
# --------------------------------------------------------------
def build_feature_db(dataloader, feature_extractor):
    feature_extractor.eval()
    feats, labels = [], []
    with torch.no_grad():
        for imgs, lbls in tqdm(dataloader, desc="[RaCNN] Building retrieval DB"):
            imgs = imgs.to(device)
            f = feature_extractor(imgs).cpu().numpy()
            feats.append(f)
            labels.append(lbls.numpy())
    feats = np.concatenate(feats, axis=0)
    labels = np.concatenate(labels, axis=0)
    return feats, labels


# --------------------------------------------------------------
# 🔹 Step 1: Try to load saved retrieval DB
# --------------------------------------------------------------
if all(os.path.exists(p) for p in [DB_FEATS_PATH, DB_LABELS_PATH, DB_INDEX_PATH]):
    print("[RaCNN] Loading cached retrieval database from Drive...")
    db_feats = np.load(DB_FEATS_PATH)
    db_labels = np.load(DB_LABELS_PATH)
    index = faiss.read_index(DB_INDEX_PATH)

else:
    print("[RaCNN] Building retrieval database from scratch (first run)...")
    db_feats, db_labels = build_feature_db(trainloader, phi_prime)

    # FAISS index construction (equiv. to 'retrieval engine F' in Sec. 2.3)
    d = db_feats.shape[1]
    index = faiss.IndexFlatL2(d)
    index.add(db_feats)

    # Save to Drive (so we can skip this next time)
    np.save(DB_FEATS_PATH, db_feats)
    np.save(DB_LABELS_PATH, db_labels)
    faiss.write_index(index, DB_INDEX_PATH)
    print(f"[RaCNN] Saved retrieval DB to {SAVE_DIR}")

print(f"[RaCNN] Retrieval DB ready — {len(db_labels)} examples indexed.")


[RaCNN] Building retrieval database from scratch (first run)...


[RaCNN] Building retrieval DB: 100%|██████████| 391/391 [00:20<00:00, 19.33it/s]


[RaCNN] Saved retrieval DB to /content/drive/MyDrive/racnn
[RaCNN] Retrieval DB ready — 50000 examples indexed.


In [8]:
# ==============================================================
# 📘 Section 2.1 — Trainable Projection via Attention (βₖ, αₖ, P(x))
# ==============================================================

import torch.nn.functional as F

class TrainableProjection(nn.Module):
    """
    Implements Eqns. (βₖ, αₖ, P(x)) from Section 2.1:
      βₖ = φ(x′ₖ)^T U φ(x)
      αₖ = softmax(βₖ)
      P(x) = Σ αₖ φ(x′ₖ)
    """
    def __init__(self, feature_dim):
        super().__init__()
        self.U = nn.Parameter(torch.eye(feature_dim))  # trainable weight matrix U

    def forward(self, phi_x, phi_neighbors):
        """
        phi_x:         [B, D] — feature of input image φ(x)
        phi_neighbors: [B, K, D] — features of K retrieved neighbors φ(x′ₖ)
        returns:
            P(x): projected feature [B, D]
            α: attention weights [B, K]
        """
        # Compute βₖ = φ(x′ₖ)^T U φ(x)
        B, K, D = phi_neighbors.shape
        Uphi = F.linear(phi_x, self.U)  # [B, D]
        beta = torch.bmm(phi_neighbors, Uphi.unsqueeze(2)).squeeze(2)  # [B, K]

        # αₖ = softmax(βₖ)
        alpha = F.softmax(beta, dim=1)  # [B, K]

        # Weighted sum: P(x) = Σ αₖ φ(x′ₖ)
        P = torch.bmm(alpha.unsqueeze(1), phi_neighbors).squeeze(1)  # [B, D]
        return P, alpha


In [9]:
# ==============================================================
# 📘 Section 2.2 — Training with Local Mixup + Classifier g
# ==============================================================

class ClassifierHead(nn.Module):
    """
    Implements g from Section 2.
    In the paper, g is the final classifier after the projection P(x).
    """
    def __init__(self, feature_dim=512, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

# --------------------------------------------------------------
# 🔹 Local Mixup implementation
# --------------------------------------------------------------
def local_mixup(phi_neighbors, y_neighbors):
    """
    Implements Eqn. for Local Mixup (Section 2.2):
      (Σ αₖ φ(x′ₖ), Σ αₖ y′ₖ)
    α sampled uniformly on the simplex (Kraemer Algorithm).
    """
    B, K, D = phi_neighbors.shape
    device = phi_neighbors.device
    # sample convex coefficients αₖ ≥ 0, Σ αₖ = 1
    rand = torch.rand(B, K, device=device)
    alpha = rand / rand.sum(dim=1, keepdim=True)
    # new mixed feature and label
    mixed_feat = torch.bmm(alpha.unsqueeze(1), phi_neighbors).squeeze(1)   # [B, D]
    mixed_label = torch.bmm(alpha.unsqueeze(1), y_neighbors).squeeze(1)   # [B, num_classes]
    return mixed_feat, mixed_label

# --------------------------------------------------------------
# 🔹 Full RaCNN Model: φ, F(x), Projection (U), Classifier (g)
# --------------------------------------------------------------
class RaCNN(nn.Module):
    """
    Combines all modules from Section 2:
      - Feature extractor φ
      - Trainable projection (U, αₖ, convex hull)
      - Classifier g
    Retrieval F(x) and φ′ handled externally (FAISS + precomputed db_feats)
    """
    def __init__(self, phi, projection, classifier, index, db_feats, db_labels, K=10):
        super().__init__()
        self.phi = phi
        self.projection = projection
        self.classifier = classifier
        self.index = index
        self.db_feats = db_feats
        self.db_labels = db_labels
        self.K = K

    def forward(self, x):
        """
        Forward pass:
          1. extract φ(x)
          2. retrieve K nearest neighbors via FAISS (F(x))
          3. project x onto convex hull (Eqn. 2.1)
          4. classify g(P(x))
        """
        # Step 1: feature for input x
        phi_x = self.phi(x)                       # [B, D]
        phi_x_np = phi_x.detach().cpu().numpy()

        # Step 2: retrieve neighbors in φ′ space (F(x))
        D, I = self.index.search(phi_x_np, self.K)
        phi_neighbors = torch.tensor(self.db_feats[I], dtype=torch.float32, device=x.device)
        y_neighbors = torch.nn.functional.one_hot(
            torch.tensor(self.db_labels[I], device=x.device), num_classes=10
        ).float()

        # Step 3: project x to convex hull (trainable projection)
        P, alpha = self.projection(phi_x, phi_neighbors)

        # Step 4: classify
        logits = self.classifier(P)
        return logits, P, alpha


In [12]:
# ==============================================================
# 📘 Section 2.2 + Section 6 — Training Loop for RaCNN
# ==============================================================

import torch.optim as optim
from torch.nn import functional as F
from tqdm import tqdm

# instantiate all components
feature_dim = 512
projection = TrainableProjection(feature_dim).to(device)
classifier = ClassifierHead(feature_dim, num_classes=10).to(device)
model = RaCNN(phi, projection, classifier, index, db_feats, db_labels, K=5).to(device)

# optimizer (paper: "we use stochastic gradient descent")
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)

# loss criterion
criterion = nn.CrossEntropyLoss()

# training hyperparams
num_epochs = 20          # keep small in Colab
NCE_steps = 1           # classification steps
NMU_steps = 5           # local mixup steps

for epoch in range(num_epochs):
    model.train()
    running_loss, running_mixup_loss = 0.0, 0.0
    pbar = tqdm(trainloader, desc=f"[Epoch {epoch+1}/{num_epochs}]")
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        onehot = F.one_hot(labels, num_classes=10).float()

        # ----------------------------------------------------------
        # 1️⃣ Classification loss (NCE step)
        # ----------------------------------------------------------
        optimizer.zero_grad()
        logits, P, _ = model(imgs)
        loss_cls = criterion(logits, labels)
        loss_cls.backward()
        optimizer.step()
        running_loss += loss_cls.item()

        # ----------------------------------------------------------
        # 2️⃣ Local Mixup regularization (NMU step)
        # ----------------------------------------------------------
        with torch.no_grad():
            # retrieve neighbors again for mixup
            phi_x = model.phi(imgs)
            phi_x_np = phi_x.detach().cpu().numpy()
            _, I = model.index.search(phi_x_np, model.K)
            phi_neighbors = torch.tensor(model.db_feats[I], dtype=torch.float32, device=device)
            y_neighbors = F.one_hot(
                torch.tensor(model.db_labels[I], device=device), num_classes=10
            ).float()

        mixed_feat, mixed_label = local_mixup(phi_neighbors, y_neighbors)

        optimizer.zero_grad()
        logits_mix = model.classifier(mixed_feat)
        loss_mix = F.cross_entropy(logits_mix, mixed_label.argmax(dim=1))
        loss_mix.backward()
        optimizer.step()
        running_mixup_loss += loss_mix.item()

        pbar.set_postfix({
            "cls_loss": f"{running_loss/len(trainloader):.3f}",
            "mix_loss": f"{running_mixup_loss/len(trainloader):.3f}"
        })

    print(f"Epoch {epoch+1}: Cls Loss {running_loss/len(trainloader):.4f} | Mix Loss {running_mixup_loss/len(trainloader):.4f}")

print("✅ Training complete.")


[Epoch 1/20]: 100%|██████████| 391/391 [02:11<00:00,  2.98it/s, cls_loss=1.765, mix_loss=1.331]


Epoch 1: Cls Loss 1.7652 | Mix Loss 1.3310


[Epoch 2/20]: 100%|██████████| 391/391 [02:11<00:00,  2.98it/s, cls_loss=1.657, mix_loss=1.123]


Epoch 2: Cls Loss 1.6569 | Mix Loss 1.1229


[Epoch 3/20]: 100%|██████████| 391/391 [02:12<00:00,  2.96it/s, cls_loss=1.573, mix_loss=1.042]


Epoch 3: Cls Loss 1.5733 | Mix Loss 1.0415


[Epoch 4/20]: 100%|██████████| 391/391 [02:10<00:00,  3.00it/s, cls_loss=1.527, mix_loss=1.025]


Epoch 4: Cls Loss 1.5268 | Mix Loss 1.0248


[Epoch 5/20]: 100%|██████████| 391/391 [02:11<00:00,  2.97it/s, cls_loss=1.479, mix_loss=0.975]


Epoch 5: Cls Loss 1.4794 | Mix Loss 0.9754


[Epoch 6/20]: 100%|██████████| 391/391 [02:12<00:00,  2.96it/s, cls_loss=1.450, mix_loss=1.036]


Epoch 6: Cls Loss 1.4503 | Mix Loss 1.0362


[Epoch 7/20]: 100%|██████████| 391/391 [02:10<00:00,  2.99it/s, cls_loss=1.415, mix_loss=1.012]


Epoch 7: Cls Loss 1.4154 | Mix Loss 1.0117


[Epoch 8/20]: 100%|██████████| 391/391 [02:11<00:00,  2.97it/s, cls_loss=1.393, mix_loss=0.982]


Epoch 8: Cls Loss 1.3928 | Mix Loss 0.9815


[Epoch 9/20]: 100%|██████████| 391/391 [02:11<00:00,  2.97it/s, cls_loss=1.376, mix_loss=0.964]


Epoch 9: Cls Loss 1.3759 | Mix Loss 0.9642


[Epoch 10/20]: 100%|██████████| 391/391 [02:13<00:00,  2.93it/s, cls_loss=1.314, mix_loss=0.978]


Epoch 10: Cls Loss 1.3139 | Mix Loss 0.9782


[Epoch 11/20]: 100%|██████████| 391/391 [02:11<00:00,  2.97it/s, cls_loss=1.274, mix_loss=0.936]


Epoch 11: Cls Loss 1.2736 | Mix Loss 0.9361


[Epoch 12/20]: 100%|██████████| 391/391 [02:11<00:00,  2.98it/s, cls_loss=1.232, mix_loss=0.905]


Epoch 12: Cls Loss 1.2323 | Mix Loss 0.9050


[Epoch 13/20]: 100%|██████████| 391/391 [02:07<00:00,  3.06it/s, cls_loss=1.204, mix_loss=0.848]


Epoch 13: Cls Loss 1.2041 | Mix Loss 0.8485


[Epoch 14/20]: 100%|██████████| 391/391 [02:10<00:00,  3.01it/s, cls_loss=1.224, mix_loss=0.848]


Epoch 14: Cls Loss 1.2235 | Mix Loss 0.8475


[Epoch 15/20]: 100%|██████████| 391/391 [02:09<00:00,  3.03it/s, cls_loss=1.164, mix_loss=0.813]


Epoch 15: Cls Loss 1.1636 | Mix Loss 0.8129


[Epoch 16/20]: 100%|██████████| 391/391 [02:10<00:00,  3.00it/s, cls_loss=1.132, mix_loss=0.787]


Epoch 16: Cls Loss 1.1322 | Mix Loss 0.7870


[Epoch 17/20]: 100%|██████████| 391/391 [02:09<00:00,  3.03it/s, cls_loss=1.239, mix_loss=0.816]


Epoch 17: Cls Loss 1.2388 | Mix Loss 0.8159


[Epoch 18/20]: 100%|██████████| 391/391 [02:10<00:00,  3.00it/s, cls_loss=1.127, mix_loss=0.761]


Epoch 18: Cls Loss 1.1273 | Mix Loss 0.7613


[Epoch 19/20]: 100%|██████████| 391/391 [02:11<00:00,  2.98it/s, cls_loss=1.063, mix_loss=0.739]


Epoch 19: Cls Loss 1.0632 | Mix Loss 0.7388


[Epoch 20/20]: 100%|██████████| 391/391 [02:10<00:00,  3.00it/s, cls_loss=1.037, mix_loss=0.744]

Epoch 20: Cls Loss 1.0368 | Mix Loss 0.7444
✅ Training complete.





In [13]:
# ==============================================================
# 📘 Section 6 — Evaluation under Adversarial Attacks
# ==============================================================

def fgsm_attack(model, images, labels, eps=0.03):
    """Implements FGSM: x_adv = x + eps * sign(∇x L(x, y))."""
    images = images.clone().detach().to(device)
    images.requires_grad = True

    outputs, _, _ = model(images)
    loss = F.cross_entropy(outputs, labels)
    loss.backward()
    adv_images = images + eps * images.grad.sign()
    adv_images = torch.clamp(adv_images, -1, 1)   # keep normalized range
    return adv_images.detach()

def ifgsm_attack(model, images, labels, eps=0.03, alpha=0.005, iters=10):
    """Iterative FGSM (a.k.a. PGD-lite)."""
    images = images.clone().detach().to(device)
    adv_images = images.clone().detach()

    for _ in range(iters):
        adv_images.requires_grad = True
        outputs, _, _ = model(adv_images)
        loss = F.cross_entropy(outputs, labels)
        model.zero_grad()
        loss.backward()
        adv_images = adv_images + alpha * adv_images.grad.sign()
        eta = torch.clamp(adv_images - images, min=-eps, max=eps)
        adv_images = torch.clamp(images + eta, -1, 1).detach()
    return adv_images


# --------------------------------------------------------------
# 🔹 Evaluation function
# --------------------------------------------------------------
def evaluate(model, dataloader, attack=None):
    model.eval()
    correct, total = 0, 0
    for imgs, labels in tqdm(dataloader, desc=f"[Eval {attack or 'Clean'}]"):
        imgs, labels = imgs.to(device), labels.to(device)
        if attack == 'FGSM':
            imgs = fgsm_attack(model, imgs, labels)
        elif attack == 'iFGSM':
            imgs = ifgsm_attack(model, imgs, labels)
        with torch.no_grad():
            outputs, _, _ = model(imgs)
            preds = outputs.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total


# --------------------------------------------------------------
# 🔹 Evaluate RaCNN vs. Vanilla CNN
# --------------------------------------------------------------
print("\n🧪 Evaluating robustness (CIFAR-10 test set)...")

# Baseline CNN (vanilla)
vanilla = get_resnet18_backbone(pretrained=False)
vanilla.fc = nn.Linear(512, 10)
vanilla = vanilla.to(device)
# assume it's randomly initialized — used to show difference in robustness

acc_clean_racnn = evaluate(model, testloader, None)
acc_fgsm_racnn  = evaluate(model, testloader, 'FGSM')
acc_ifgsm_racnn = evaluate(model, testloader, 'iFGSM')

print(f"✅ RaCNN accuracy — Clean: {acc_clean_racnn:.2f}% | FGSM: {acc_fgsm_racnn:.2f}% | iFGSM: {acc_ifgsm_racnn:.2f}%")



🧪 Evaluating robustness (CIFAR-10 test set)...


[Eval Clean]: 100%|██████████| 79/79 [00:11<00:00,  7.15it/s]
[Eval FGSM]: 100%|██████████| 79/79 [00:25<00:00,  3.13it/s]
[Eval iFGSM]: 100%|██████████| 79/79 [02:09<00:00,  1.63s/it]

✅ RaCNN accuracy — Clean: 69.02% | FGSM: 29.82% | iFGSM: 20.63%





In [18]:
# ==============================================================
# 📘 Extended Adversarial Evaluation — Full Suite (Section 6)
# ==============================================================
import torchattacks
from tqdm import tqdm
import torch.nn as nn

# --------------------------------------------------------------
# 🔹 Fix for torchattacks — wrapper that only returns logits
# --------------------------------------------------------------
class LogitsWrapper(nn.Module):
    def __init__(self, racnn):
        super().__init__()
        self.racnn = racnn
    def forward(self, x):
        logits, _, _ = self.racnn(x)
        return logits

# use wrapper for attacks (torchattacks expects logits only)
wrapped_model = LogitsWrapper(model)

# --------------------------------------------------------------
# 🔹 Define attacks (matches Section 6 of the paper)
# --------------------------------------------------------------
atk_fgsm     = torchattacks.FGSM(wrapped_model, eps=0.03)
atk_ifgsm    = torchattacks.PGD(wrapped_model, eps=0.03, alpha=0.005, steps=10)   # ≈ iFGSM
atk_pgd      = torchattacks.PGD(wrapped_model, eps=0.03, alpha=0.005, steps=20)   # stronger PGD
atk_deepfool = torchattacks.DeepFool(wrapped_model, steps=50)
atk_cw       = torchattacks.CW(wrapped_model, c=1, kappa=0, steps=50, lr=0.01)
atk_noise    = lambda x, y: torch.clamp(x + 0.03 * torch.randn_like(x), -1, 1)  # random Gaussian noise

# Wrap all attacks in a dict
attacks = {
    "Clean":    None,
    "FGSM":     atk_fgsm,
    "iFGSM":    atk_ifgsm,
    "PGD":      atk_pgd,
    "DeepFool": atk_deepfool,
    "CW":       atk_cw,
    "Noise":    atk_noise,
}

# --------------------------------------------------------------
# 🔹 Evaluation loop
# --------------------------------------------------------------
def evaluate_adv(model, dataloader, attacks):
    """Evaluate model under multiple attacks."""
    model.eval()
    results = {}
    for name, atk in attacks.items():
        correct, total = 0, 0
        for imgs, labels in tqdm(dataloader, desc=f"[Eval {name}]"):
            imgs, labels = imgs.to(device), labels.to(device)
            if atk is not None:
                if callable(atk):  # noise
                    imgs = atk(imgs, labels)
                else:
                    imgs = atk(imgs, labels)
            with torch.no_grad():
                outputs, _, _ = model(imgs)
                preds = outputs.argmax(1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        acc = 100 * correct / total
        results[name] = acc
        print(f"{name:10s}: {acc:.2f}%")
    return results

# --------------------------------------------------------------
# 🚀 Run full evaluation
# --------------------------------------------------------------
print("🧪 Evaluating RaCNN robustness under full attack suite (CIFAR-10)...")
results = evaluate_adv(model, testloader, attacks)

print("\n✅ RaCNN Robustness Summary:")
for k, v in results.items():
    print(f"{k:10s}: {v:.2f}%")


🧪 Evaluating RaCNN robustness under full attack suite (CIFAR-10)...


[Eval Clean]: 100%|██████████| 79/79 [00:10<00:00,  7.24it/s]


Clean     : 69.02%


[Eval FGSM]: 100%|██████████| 79/79 [00:24<00:00,  3.22it/s]


FGSM      : 38.89%


[Eval iFGSM]: 100%|██████████| 79/79 [02:11<00:00,  1.67s/it]


iFGSM     : 15.88%


[Eval PGD]: 100%|██████████| 79/79 [04:07<00:00,  3.14s/it]


PGD       : 13.37%


[Eval DeepFool]: 100%|██████████| 79/79 [50:38<00:00, 38.46s/it]


DeepFool  : 10.00%


[Eval CW]: 100%|██████████| 79/79 [09:48<00:00,  7.45s/it]


CW        : 10.00%


[Eval Noise]: 100%|██████████| 79/79 [00:10<00:00,  7.32it/s]

Noise     : 65.77%

✅ RaCNN Robustness Summary:
Clean     : 69.02%
FGSM      : 38.89%
iFGSM     : 15.88%
PGD       : 13.37%
DeepFool  : 10.00%
CW        : 10.00%
Noise     : 65.77%



