In [1]:
# Colab cell 1: install dependencies
# - faiss-cpu for fast nearest-neighbor search (IndexFlatL2)
# - torchattacks optional (we also include minimal FGSM/PGD implementations)
# Note: use faiss-gpu if available, but faiss-cpu works fine for CIFAR-10.
!pip install faiss-cpu
!pip install torchattacks


Collecting faiss-cpu
  Using cached faiss_cpu-1.12.0-cp310-cp310-win_amd64.whl.metadata (5.2 kB)
Using cached faiss_cpu-1.12.0-cp310-cp310-win_amd64.whl (18.2 MB)
Installing collected packages: faiss-cpu
Successfully installed faiss-cpu-1.12.0
Collecting torchattacks
  Downloading torchattacks-3.5.1-py3-none-any.whl.metadata (927 bytes)
Collecting requests~=2.25.1 (from torchattacks)
  Downloading requests-2.25.1-py2.py3-none-any.whl.metadata (4.2 kB)
Collecting chardet<5,>=3.0.2 (from requests~=2.25.1->torchattacks)
  Downloading chardet-4.0.0-py2.py3-none-any.whl.metadata (3.5 kB)
Collecting idna<3,>=2.5 (from requests~=2.25.1->torchattacks)
  Downloading idna-2.10-py2.py3-none-any.whl.metadata (9.1 kB)
Collecting urllib3<1.27,>=1.21.1 (from requests~=2.25.1->torchattacks)
  Downloading urllib3-1.26.20-py2.py3-none-any.whl.metadata (50 kB)
Downloading torchattacks-3.5.1-py3-none-any.whl (142 kB)
Downloading requests-2.25.1-py2.py3-none-any.whl (61 kB)
Downloading chardet-4.0.0-py2.py

In [3]:
# Colab cell 2
# Standard imports and hyperparameters. Change these to suit your GPU/time budget.
import os
import time
import random
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as T

import faiss  # for retrieval index

# For plotting
import matplotlib.pyplot as plt

# Device (Colab T4)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Reproducibility seeds
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if device.type == "cuda":
    torch.cuda.manual_seed_all(seed)

# Hyperparameters (start values tuned for Colab)
IMG_SIZE = 224            # using ResNet-18 pretrained requires 224x224
BATCH_SIZE = 64
NUM_WORKERS = 2
NUM_CLASSES = 10          # CIFAR-10
LR = 1e-3
EPOCHS = 100               # lower for quick runs; increase if you have time
K = 10                     # default retrieved neighbors
K_OPTIONS = [5, 10]       # experiment with 5 and 10 as in paper
C = 10                    # K' = C*K for randomization if used
MIXUP_WEIGHT = 0.5        # weight for local mixup loss in total loss


Device: cuda


In [4]:
# Colab cell 3
# Dataset pipeline: Resize to 224 for ResNet-18. (Paper used custom small network for CIFAR; ResNet-18 is fine.)
transform_train = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
])
transform_test = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize((0.485,0.456,0.406), (0.229,0.224,0.225)),
])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
testset  = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
test_loader  = torch.utils.data.DataLoader(testset,  batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)

print("Train size:", len(trainset), "Test size:", len(testset))


100%|██████████| 170M/170M [01:32<00:00, 1.84MB/s] 


Train size: 50000 Test size: 10000


In [5]:
# Colab cell 4
# Implements Section 2 (phi' and phi). phi_prime can be pretrained ResNet-18 (frozen) used for retrieval keys.
# phi is the trainable feature extractor used inside RaCNN. For simplicity we use same architecture for both.

def get_resnet18_backbone(pretrained=True, out_dim=None):
    """
    Returns ResNet18 truncated before final fc.
    If out_dim given, add a linear layer to project to out_dim.
    """
    model = torchvision.models.resnet18(pretrained=pretrained)
    modules = list(model.children())[:-1]  # remove fc
    encoder = nn.Sequential(*modules)      # outputs [B, 512, 1, 1]
    class Wrapper(nn.Module):
        def __init__(self, encoder, out_dim=None):
            super().__init__()
            self.encoder = encoder
            self.out_dim = out_dim
            self.pool = nn.AdaptiveAvgPool2d((1,1))
            if out_dim is not None:
                self.proj = nn.Linear(512, out_dim)
            else:
                self.proj = None
        def forward(self, x):
            h = self.encoder(x)
            h = h.view(h.size(0), -1)
            if self.proj is not None:
                return self.proj(h)
            return h
    return Wrapper(encoder, out_dim=out_dim)

# instantiate
# phi_prime: the retrieval key extractor (Section 2.3). We will freeze it by default (simulate hidden retrieval).
phi_prime = get_resnet18_backbone(pretrained=True, out_dim=None).to(device)
for p in phi_prime.parameters():
    p.requires_grad = False  # freeze phi' by default

# phi: trainable feature extractor used by RaCNN (can start from pretrained backbone or random init)
phi = get_resnet18_backbone(pretrained=True, out_dim=None).to(device)
# Option: fine-tune or freeze early layers as you prefer




In [6]:
# ==============================================================
# 📘 Dataset Preparation — CIFAR-10 (Section 6.1)
# ==============================================================
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Data augmentation (same spirit as Sec. 6: "data augmentation on D but not D′")
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)

# Dataloaders (as D and D′)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=2)

print(f"[RaCNN] CIFAR-10 ready — {len(trainset)} train, {len(testset)} test samples.")


[RaCNN] CIFAR-10 ready — 50000 train, 10000 test samples.


In [7]:
# ==============================================================
# 📘 Section 2.3 — Retrieval Engine F
# ==============================================================
# This builds the retrieval database D′ = {φ′(x′i), y′i}.
# In the paper: they use LSH + random projection.
# Here: we use FAISS IndexFlatL2 (dense lookup, faster & GPU compatible).
# Optional random projection (dim reduction) can be added with out_dim in phi′.

import os, faiss, numpy as np
from tqdm import tqdm

SAVE_DIR = '/content/drive/MyDrive/racnn'
os.makedirs(SAVE_DIR, exist_ok=True)

DB_FEATS_PATH = os.path.join(SAVE_DIR, 'db_feats.npy')
DB_LABELS_PATH = os.path.join(SAVE_DIR, 'db_labels.npy')
DB_INDEX_PATH = os.path.join(SAVE_DIR, 'retrieval_index.faiss')

# --------------------------------------------------------------
# 🔹 Helper: build retrieval DB if not cached
# --------------------------------------------------------------
def build_feature_db(dataloader, feature_extractor):
    feature_extractor.eval()
    feats, labels = [], []
    with torch.no_grad():
        for imgs, lbls in tqdm(dataloader, desc="[RaCNN] Building retrieval DB"):
            imgs = imgs.to(device)
            f = feature_extractor(imgs).cpu().numpy()
            feats.append(f)
            labels.append(lbls.numpy())
    feats = np.concatenate(feats, axis=0)
    labels = np.concatenate(labels, axis=0)
    return feats, labels


# --------------------------------------------------------------
# 🔹 Step 1: Try to load saved retrieval DB
# --------------------------------------------------------------
if all(os.path.exists(p) for p in [DB_FEATS_PATH, DB_LABELS_PATH, DB_INDEX_PATH]):
    print("[RaCNN] Loading cached retrieval database from Drive...")
    db_feats = np.load(DB_FEATS_PATH)
    db_labels = np.load(DB_LABELS_PATH)
    index = faiss.read_index(DB_INDEX_PATH)

else:
    print("[RaCNN] Building retrieval database from scratch (first run)...")
    db_feats, db_labels = build_feature_db(trainloader, phi_prime)

    # FAISS index construction (equiv. to 'retrieval engine F' in Sec. 2.3)
    d = db_feats.shape[1]
    index = faiss.IndexFlatL2(d)
    index.add(db_feats)

    # Save to Drive (so we can skip this next time)
    np.save(DB_FEATS_PATH, db_feats)
    np.save(DB_LABELS_PATH, db_labels)
    faiss.write_index(index, DB_INDEX_PATH)
    print(f"[RaCNN] Saved retrieval DB to {SAVE_DIR}")

print(f"[RaCNN] Retrieval DB ready — {len(db_labels)} examples indexed.")


[RaCNN] Building retrieval database from scratch (first run)...


[RaCNN] Building retrieval DB: 100%|██████████| 391/391 [00:09<00:00, 39.52it/s] 

[RaCNN] Saved retrieval DB to /content/drive/MyDrive/racnn
[RaCNN] Retrieval DB ready — 50000 examples indexed.





In [8]:
# ==============================================================
# 📘 Section 2.1 — Trainable Projection via Attention (βₖ, αₖ, P(x))
# ==============================================================

import torch.nn.functional as F

class TrainableProjection(nn.Module):
    """
    Implements Eqns. (βₖ, αₖ, P(x)) from Section 2.1:
      βₖ = φ(x′ₖ)^T U φ(x)
      αₖ = softmax(βₖ)
      P(x) = Σ αₖ φ(x′ₖ)
    """
    def __init__(self, feature_dim):
        super().__init__()
        self.U = nn.Parameter(torch.eye(feature_dim))  # trainable weight matrix U

    def forward(self, phi_x, phi_neighbors):
        """
        phi_x:         [B, D] — feature of input image φ(x)
        phi_neighbors: [B, K, D] — features of K retrieved neighbors φ(x′ₖ)
        returns:
            P(x): projected feature [B, D]
            α: attention weights [B, K]
        """
        # Compute βₖ = φ(x′ₖ)^T U φ(x)
        B, K, D = phi_neighbors.shape
        Uphi = F.linear(phi_x, self.U)  # [B, D]
        beta = torch.bmm(phi_neighbors, Uphi.unsqueeze(2)).squeeze(2)  # [B, K]

        # αₖ = softmax(βₖ)
        alpha = F.softmax(beta, dim=1)  # [B, K]

        # Weighted sum: P(x) = Σ αₖ φ(x′ₖ)
        P = torch.bmm(alpha.unsqueeze(1), phi_neighbors).squeeze(1)  # [B, D]
        return P, alpha


In [9]:
# ==============================================================
# 📘 Section 2.2 — Training with Local Mixup + Classifier g
# ==============================================================

class ClassifierHead(nn.Module):
    """
    Implements g from Section 2.
    In the paper, g is the final classifier after the projection P(x).
    """
    def __init__(self, feature_dim=512, num_classes=10):
        super().__init__()
        self.fc = nn.Linear(feature_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

# --------------------------------------------------------------
# 🔹 Local Mixup implementation
# --------------------------------------------------------------
def local_mixup(phi_neighbors, y_neighbors):
    """
    Implements Eqn. for Local Mixup (Section 2.2):
      (Σ αₖ φ(x′ₖ), Σ αₖ y′ₖ)
    α sampled uniformly on the simplex (Kraemer Algorithm).
    """
    B, K, D = phi_neighbors.shape
    device = phi_neighbors.device
    # sample convex coefficients αₖ ≥ 0, Σ αₖ = 1
    rand = torch.rand(B, K, device=device)
    alpha = rand / rand.sum(dim=1, keepdim=True)
    # new mixed feature and label
    mixed_feat = torch.bmm(alpha.unsqueeze(1), phi_neighbors).squeeze(1)   # [B, D]
    mixed_label = torch.bmm(alpha.unsqueeze(1), y_neighbors).squeeze(1)   # [B, num_classes]
    return mixed_feat, mixed_label

# --------------------------------------------------------------
# 🔹 Full RaCNN Model: φ, F(x), Projection (U), Classifier (g)
# --------------------------------------------------------------
class RaCNN(nn.Module):
    """
    Combines all modules from Section 2:
      - Feature extractor φ
      - Trainable projection (U, αₖ, convex hull)
      - Classifier g
    Retrieval F(x) and φ′ handled externally (FAISS + precomputed db_feats)
    """
    def __init__(self, phi, projection, classifier, index, db_feats, db_labels, K=10):
        super().__init__()
        self.phi = phi
        self.projection = projection
        self.classifier = classifier
        self.index = index
        self.db_feats = db_feats
        self.db_labels = db_labels
        self.K = K

    def forward(self, x):
        """
        Forward pass:
          1. extract φ(x)
          2. retrieve K nearest neighbors via FAISS (F(x))
          3. project x onto convex hull (Eqn. 2.1)
          4. classify g(P(x))
        """
        # Step 1: feature for input x
        phi_x = self.phi(x)                       # [B, D]
        phi_x_np = phi_x.detach().cpu().numpy()

        # Step 2: retrieve neighbors in φ′ space (F(x))
        D, I = self.index.search(phi_x_np, self.K)
        phi_neighbors = torch.tensor(self.db_feats[I], dtype=torch.float32, device=x.device)
        y_neighbors = torch.nn.functional.one_hot(
            torch.tensor(self.db_labels[I], device=x.device), num_classes=10
        ).float()

        # Step 3: project x to convex hull (trainable projection)
        P, alpha = self.projection(phi_x, phi_neighbors)

        # Step 4: classify
        logits = self.classifier(P)
        return logits, P, alpha


In [11]:
# ==============================================================
# 📘 Section 2.2 + Section 6 — Training Loop for RaCNN
# ==============================================================

import torch.optim as optim
from torch.nn import functional as F
from tqdm import tqdm

# instantiate all components
feature_dim = 512
projection = TrainableProjection(feature_dim).to(device)
classifier = ClassifierHead(feature_dim, num_classes=10).to(device)
model = RaCNN(phi, projection, classifier, index, db_feats, db_labels, K=5).to(device)

# optimizer (paper: "we use stochastic gradient descent")
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=5e-4)

# loss criterion
criterion = nn.CrossEntropyLoss()

# training hyperparams
num_epochs = 200          # keep small in Colab
NCE_steps = 1           # classification steps
NMU_steps = 5           # local mixup steps

for epoch in range(num_epochs):
    model.train()
    running_loss, running_mixup_loss = 0.0, 0.0
    pbar = tqdm(trainloader, desc=f"[Epoch {epoch+1}/{num_epochs}]")
    for imgs, labels in pbar:
        imgs, labels = imgs.to(device), labels.to(device)
        onehot = F.one_hot(labels, num_classes=10).float()

        # ----------------------------------------------------------
        # 1️⃣ Classification loss (NCE step)
        # ----------------------------------------------------------
        optimizer.zero_grad()
        logits, P, _ = model(imgs)
        loss_cls = criterion(logits, labels)
        loss_cls.backward()
        optimizer.step()
        running_loss += loss_cls.item()

        # ----------------------------------------------------------
        # 2️⃣ Local Mixup regularization (NMU step)
        # ----------------------------------------------------------
        with torch.no_grad():
            # retrieve neighbors again for mixup
            phi_x = model.phi(imgs)
            phi_x_np = phi_x.detach().cpu().numpy()
            _, I = model.index.search(phi_x_np, model.K)
            phi_neighbors = torch.tensor(model.db_feats[I], dtype=torch.float32, device=device)
            y_neighbors = F.one_hot(
                torch.tensor(model.db_labels[I], device=device), num_classes=10
            ).float()

        mixed_feat, mixed_label = local_mixup(phi_neighbors, y_neighbors)

        optimizer.zero_grad()
        logits_mix = model.classifier(mixed_feat)
        loss_mix = F.cross_entropy(logits_mix, mixed_label.argmax(dim=1))
        loss_mix.backward()
        optimizer.step()
        running_mixup_loss += loss_mix.item()

        pbar.set_postfix({
            "cls_loss": f"{running_loss/len(trainloader):.3f}",
            "mix_loss": f"{running_mixup_loss/len(trainloader):.3f}"
        })

    print(f"Epoch {epoch+1}: Cls Loss {running_loss/len(trainloader):.4f} | Mix Loss {running_mixup_loss/len(trainloader):.4f}")

print("✅ Training complete.")


[Epoch 1/200]: 100%|██████████| 391/391 [00:28<00:00, 13.74it/s, cls_loss=1.803, mix_loss=1.312]


Epoch 1: Cls Loss 1.8029 | Mix Loss 1.3120


[Epoch 2/200]: 100%|██████████| 391/391 [00:28<00:00, 13.95it/s, cls_loss=1.779, mix_loss=1.154]


Epoch 2: Cls Loss 1.7791 | Mix Loss 1.1543


[Epoch 3/200]: 100%|██████████| 391/391 [00:28<00:00, 13.95it/s, cls_loss=1.746, mix_loss=1.153]


Epoch 3: Cls Loss 1.7457 | Mix Loss 1.1531


[Epoch 4/200]: 100%|██████████| 391/391 [00:28<00:00, 13.84it/s, cls_loss=1.738, mix_loss=1.109]


Epoch 4: Cls Loss 1.7381 | Mix Loss 1.1085


[Epoch 5/200]: 100%|██████████| 391/391 [00:28<00:00, 13.76it/s, cls_loss=1.681, mix_loss=1.093]


Epoch 5: Cls Loss 1.6807 | Mix Loss 1.0932


[Epoch 6/200]: 100%|██████████| 391/391 [00:28<00:00, 13.69it/s, cls_loss=1.722, mix_loss=1.039]


Epoch 6: Cls Loss 1.7221 | Mix Loss 1.0388


[Epoch 7/200]: 100%|██████████| 391/391 [00:28<00:00, 13.52it/s, cls_loss=1.762, mix_loss=1.118]


Epoch 7: Cls Loss 1.7618 | Mix Loss 1.1185


[Epoch 8/200]: 100%|██████████| 391/391 [00:28<00:00, 13.74it/s, cls_loss=1.706, mix_loss=1.124]


Epoch 8: Cls Loss 1.7060 | Mix Loss 1.1241


[Epoch 9/200]: 100%|██████████| 391/391 [00:28<00:00, 13.83it/s, cls_loss=1.641, mix_loss=1.069]


Epoch 9: Cls Loss 1.6410 | Mix Loss 1.0686


[Epoch 10/200]: 100%|██████████| 391/391 [00:28<00:00, 13.79it/s, cls_loss=1.564, mix_loss=1.017]


Epoch 10: Cls Loss 1.5639 | Mix Loss 1.0172


[Epoch 11/200]: 100%|██████████| 391/391 [00:28<00:00, 13.73it/s, cls_loss=1.560, mix_loss=1.008]


Epoch 11: Cls Loss 1.5601 | Mix Loss 1.0076


[Epoch 12/200]: 100%|██████████| 391/391 [00:28<00:00, 13.76it/s, cls_loss=1.527, mix_loss=0.968]


Epoch 12: Cls Loss 1.5273 | Mix Loss 0.9679


[Epoch 13/200]: 100%|██████████| 391/391 [00:28<00:00, 13.79it/s, cls_loss=1.475, mix_loss=0.887]


Epoch 13: Cls Loss 1.4747 | Mix Loss 0.8868


[Epoch 14/200]: 100%|██████████| 391/391 [00:28<00:00, 13.75it/s, cls_loss=1.508, mix_loss=0.815]


Epoch 14: Cls Loss 1.5083 | Mix Loss 0.8152


[Epoch 15/200]: 100%|██████████| 391/391 [00:28<00:00, 13.70it/s, cls_loss=1.425, mix_loss=0.796]


Epoch 15: Cls Loss 1.4253 | Mix Loss 0.7960


[Epoch 16/200]: 100%|██████████| 391/391 [00:28<00:00, 13.73it/s, cls_loss=1.363, mix_loss=0.768]


Epoch 16: Cls Loss 1.3631 | Mix Loss 0.7680


[Epoch 17/200]: 100%|██████████| 391/391 [00:29<00:00, 13.36it/s, cls_loss=1.326, mix_loss=0.731]


Epoch 17: Cls Loss 1.3262 | Mix Loss 0.7314


[Epoch 18/200]: 100%|██████████| 391/391 [00:29<00:00, 13.40it/s, cls_loss=1.272, mix_loss=0.712]


Epoch 18: Cls Loss 1.2716 | Mix Loss 0.7122


[Epoch 19/200]: 100%|██████████| 391/391 [00:29<00:00, 13.32it/s, cls_loss=1.232, mix_loss=0.692]


Epoch 19: Cls Loss 1.2318 | Mix Loss 0.6918


[Epoch 20/200]: 100%|██████████| 391/391 [00:29<00:00, 13.34it/s, cls_loss=1.172, mix_loss=0.667]


Epoch 20: Cls Loss 1.1720 | Mix Loss 0.6666


[Epoch 21/200]: 100%|██████████| 391/391 [00:29<00:00, 13.32it/s, cls_loss=1.115, mix_loss=0.652]


Epoch 21: Cls Loss 1.1149 | Mix Loss 0.6518


[Epoch 22/200]: 100%|██████████| 391/391 [00:29<00:00, 13.25it/s, cls_loss=1.082, mix_loss=0.636]


Epoch 22: Cls Loss 1.0820 | Mix Loss 0.6358


[Epoch 23/200]: 100%|██████████| 391/391 [00:28<00:00, 13.64it/s, cls_loss=1.035, mix_loss=0.616]


Epoch 23: Cls Loss 1.0350 | Mix Loss 0.6156


[Epoch 24/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=1.024, mix_loss=0.591]


Epoch 24: Cls Loss 1.0243 | Mix Loss 0.5912


[Epoch 25/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.984, mix_loss=0.575]


Epoch 25: Cls Loss 0.9841 | Mix Loss 0.5748


[Epoch 26/200]: 100%|██████████| 391/391 [00:27<00:00, 14.24it/s, cls_loss=0.933, mix_loss=0.569]


Epoch 26: Cls Loss 0.9325 | Mix Loss 0.5686


[Epoch 27/200]: 100%|██████████| 391/391 [00:27<00:00, 14.22it/s, cls_loss=0.915, mix_loss=0.565]


Epoch 27: Cls Loss 0.9145 | Mix Loss 0.5651


[Epoch 28/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.900, mix_loss=0.564]


Epoch 28: Cls Loss 0.9003 | Mix Loss 0.5643


[Epoch 29/200]: 100%|██████████| 391/391 [00:27<00:00, 14.33it/s, cls_loss=0.894, mix_loss=0.573]


Epoch 29: Cls Loss 0.8938 | Mix Loss 0.5732


[Epoch 30/200]: 100%|██████████| 391/391 [00:27<00:00, 14.31it/s, cls_loss=0.870, mix_loss=0.560]


Epoch 30: Cls Loss 0.8699 | Mix Loss 0.5604


[Epoch 31/200]: 100%|██████████| 391/391 [00:27<00:00, 14.25it/s, cls_loss=0.847, mix_loss=0.556]


Epoch 31: Cls Loss 0.8475 | Mix Loss 0.5562


[Epoch 32/200]: 100%|██████████| 391/391 [00:27<00:00, 14.36it/s, cls_loss=0.836, mix_loss=0.551]


Epoch 32: Cls Loss 0.8363 | Mix Loss 0.5511


[Epoch 33/200]: 100%|██████████| 391/391 [00:27<00:00, 14.31it/s, cls_loss=0.818, mix_loss=0.549]


Epoch 33: Cls Loss 0.8184 | Mix Loss 0.5485


[Epoch 34/200]: 100%|██████████| 391/391 [00:27<00:00, 14.34it/s, cls_loss=0.806, mix_loss=0.541]


Epoch 34: Cls Loss 0.8061 | Mix Loss 0.5414


[Epoch 35/200]: 100%|██████████| 391/391 [00:27<00:00, 14.23it/s, cls_loss=0.797, mix_loss=0.543]


Epoch 35: Cls Loss 0.7969 | Mix Loss 0.5435


[Epoch 36/200]: 100%|██████████| 391/391 [00:27<00:00, 14.37it/s, cls_loss=0.800, mix_loss=0.537]


Epoch 36: Cls Loss 0.7999 | Mix Loss 0.5369


[Epoch 37/200]: 100%|██████████| 391/391 [00:27<00:00, 14.37it/s, cls_loss=0.803, mix_loss=0.537]


Epoch 37: Cls Loss 0.8031 | Mix Loss 0.5368


[Epoch 38/200]: 100%|██████████| 391/391 [00:27<00:00, 14.06it/s, cls_loss=0.817, mix_loss=0.546]


Epoch 38: Cls Loss 0.8166 | Mix Loss 0.5464


[Epoch 39/200]: 100%|██████████| 391/391 [00:27<00:00, 14.38it/s, cls_loss=0.805, mix_loss=0.544]


Epoch 39: Cls Loss 0.8049 | Mix Loss 0.5435


[Epoch 40/200]: 100%|██████████| 391/391 [00:27<00:00, 14.41it/s, cls_loss=0.784, mix_loss=0.537]


Epoch 40: Cls Loss 0.7839 | Mix Loss 0.5370


[Epoch 41/200]: 100%|██████████| 391/391 [00:27<00:00, 14.41it/s, cls_loss=0.764, mix_loss=0.522]


Epoch 41: Cls Loss 0.7641 | Mix Loss 0.5225


[Epoch 42/200]: 100%|██████████| 391/391 [00:27<00:00, 14.36it/s, cls_loss=0.753, mix_loss=0.518]


Epoch 42: Cls Loss 0.7532 | Mix Loss 0.5178


[Epoch 43/200]: 100%|██████████| 391/391 [00:27<00:00, 14.33it/s, cls_loss=0.740, mix_loss=0.521]


Epoch 43: Cls Loss 0.7396 | Mix Loss 0.5207


[Epoch 44/200]: 100%|██████████| 391/391 [00:27<00:00, 14.33it/s, cls_loss=0.721, mix_loss=0.522]


Epoch 44: Cls Loss 0.7214 | Mix Loss 0.5224


[Epoch 45/200]: 100%|██████████| 391/391 [00:27<00:00, 14.36it/s, cls_loss=0.730, mix_loss=0.504]


Epoch 45: Cls Loss 0.7296 | Mix Loss 0.5043


[Epoch 46/200]: 100%|██████████| 391/391 [00:27<00:00, 14.37it/s, cls_loss=0.705, mix_loss=0.519]


Epoch 46: Cls Loss 0.7054 | Mix Loss 0.5192


[Epoch 47/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.687, mix_loss=0.514]


Epoch 47: Cls Loss 0.6871 | Mix Loss 0.5142


[Epoch 48/200]: 100%|██████████| 391/391 [00:27<00:00, 14.46it/s, cls_loss=0.676, mix_loss=0.504]


Epoch 48: Cls Loss 0.6757 | Mix Loss 0.5038


[Epoch 49/200]: 100%|██████████| 391/391 [00:27<00:00, 14.38it/s, cls_loss=0.684, mix_loss=0.490]


Epoch 49: Cls Loss 0.6835 | Mix Loss 0.4901


[Epoch 50/200]: 100%|██████████| 391/391 [00:27<00:00, 14.35it/s, cls_loss=0.662, mix_loss=0.485]


Epoch 50: Cls Loss 0.6615 | Mix Loss 0.4847


[Epoch 51/200]: 100%|██████████| 391/391 [00:27<00:00, 14.34it/s, cls_loss=0.645, mix_loss=0.476]


Epoch 51: Cls Loss 0.6446 | Mix Loss 0.4761


[Epoch 52/200]: 100%|██████████| 391/391 [00:27<00:00, 14.35it/s, cls_loss=0.630, mix_loss=0.475]


Epoch 52: Cls Loss 0.6298 | Mix Loss 0.4747


[Epoch 53/200]: 100%|██████████| 391/391 [00:27<00:00, 14.36it/s, cls_loss=0.628, mix_loss=0.470]


Epoch 53: Cls Loss 0.6275 | Mix Loss 0.4699


[Epoch 54/200]: 100%|██████████| 391/391 [00:27<00:00, 14.35it/s, cls_loss=0.624, mix_loss=0.469]


Epoch 54: Cls Loss 0.6242 | Mix Loss 0.4691


[Epoch 55/200]: 100%|██████████| 391/391 [00:27<00:00, 14.39it/s, cls_loss=0.620, mix_loss=0.473]


Epoch 55: Cls Loss 0.6200 | Mix Loss 0.4730


[Epoch 56/200]: 100%|██████████| 391/391 [00:27<00:00, 14.15it/s, cls_loss=0.606, mix_loss=0.481]


Epoch 56: Cls Loss 0.6062 | Mix Loss 0.4809


[Epoch 57/200]: 100%|██████████| 391/391 [00:27<00:00, 14.33it/s, cls_loss=0.602, mix_loss=0.474]


Epoch 57: Cls Loss 0.6023 | Mix Loss 0.4740


[Epoch 58/200]: 100%|██████████| 391/391 [00:28<00:00, 13.84it/s, cls_loss=0.903, mix_loss=0.465]


Epoch 58: Cls Loss 0.9035 | Mix Loss 0.4646


[Epoch 59/200]: 100%|██████████| 391/391 [00:27<00:00, 14.23it/s, cls_loss=0.712, mix_loss=0.465]


Epoch 59: Cls Loss 0.7119 | Mix Loss 0.4652


[Epoch 60/200]: 100%|██████████| 391/391 [00:27<00:00, 14.31it/s, cls_loss=0.653, mix_loss=0.457]


Epoch 60: Cls Loss 0.6532 | Mix Loss 0.4573


[Epoch 61/200]: 100%|██████████| 391/391 [00:27<00:00, 14.46it/s, cls_loss=0.629, mix_loss=0.460]


Epoch 61: Cls Loss 0.6292 | Mix Loss 0.4597


[Epoch 62/200]: 100%|██████████| 391/391 [00:27<00:00, 14.45it/s, cls_loss=0.615, mix_loss=0.458]


Epoch 62: Cls Loss 0.6154 | Mix Loss 0.4576


[Epoch 63/200]: 100%|██████████| 391/391 [00:27<00:00, 14.42it/s, cls_loss=0.590, mix_loss=0.461]


Epoch 63: Cls Loss 0.5900 | Mix Loss 0.4607


[Epoch 64/200]: 100%|██████████| 391/391 [00:27<00:00, 14.34it/s, cls_loss=0.604, mix_loss=0.457]


Epoch 64: Cls Loss 0.6042 | Mix Loss 0.4570


[Epoch 65/200]: 100%|██████████| 391/391 [00:27<00:00, 14.38it/s, cls_loss=0.651, mix_loss=0.449]


Epoch 65: Cls Loss 0.6509 | Mix Loss 0.4488


[Epoch 66/200]: 100%|██████████| 391/391 [00:27<00:00, 14.37it/s, cls_loss=0.593, mix_loss=0.462]


Epoch 66: Cls Loss 0.5931 | Mix Loss 0.4615


[Epoch 67/200]: 100%|██████████| 391/391 [00:27<00:00, 14.39it/s, cls_loss=0.567, mix_loss=0.454]


Epoch 67: Cls Loss 0.5666 | Mix Loss 0.4538


[Epoch 68/200]: 100%|██████████| 391/391 [00:27<00:00, 14.44it/s, cls_loss=0.611, mix_loss=0.460]


Epoch 68: Cls Loss 0.6111 | Mix Loss 0.4598


[Epoch 69/200]: 100%|██████████| 391/391 [00:27<00:00, 14.41it/s, cls_loss=0.554, mix_loss=0.452]


Epoch 69: Cls Loss 0.5538 | Mix Loss 0.4523


[Epoch 70/200]: 100%|██████████| 391/391 [00:27<00:00, 14.39it/s, cls_loss=0.547, mix_loss=0.447]


Epoch 70: Cls Loss 0.5473 | Mix Loss 0.4468


[Epoch 71/200]: 100%|██████████| 391/391 [00:27<00:00, 14.33it/s, cls_loss=0.539, mix_loss=0.436]


Epoch 71: Cls Loss 0.5389 | Mix Loss 0.4364


[Epoch 72/200]: 100%|██████████| 391/391 [00:27<00:00, 14.42it/s, cls_loss=0.535, mix_loss=0.444]


Epoch 72: Cls Loss 0.5353 | Mix Loss 0.4443


[Epoch 73/200]: 100%|██████████| 391/391 [00:27<00:00, 14.40it/s, cls_loss=0.541, mix_loss=0.445]


Epoch 73: Cls Loss 0.5406 | Mix Loss 0.4452


[Epoch 74/200]: 100%|██████████| 391/391 [00:27<00:00, 14.35it/s, cls_loss=0.511, mix_loss=0.440]


Epoch 74: Cls Loss 0.5115 | Mix Loss 0.4402


[Epoch 75/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.522, mix_loss=0.431]


Epoch 75: Cls Loss 0.5222 | Mix Loss 0.4309


[Epoch 76/200]: 100%|██████████| 391/391 [00:27<00:00, 14.45it/s, cls_loss=0.506, mix_loss=0.430]


Epoch 76: Cls Loss 0.5059 | Mix Loss 0.4298


[Epoch 77/200]: 100%|██████████| 391/391 [00:27<00:00, 14.47it/s, cls_loss=0.503, mix_loss=0.435]


Epoch 77: Cls Loss 0.5034 | Mix Loss 0.4347


[Epoch 78/200]: 100%|██████████| 391/391 [00:27<00:00, 14.46it/s, cls_loss=0.530, mix_loss=0.419]


Epoch 78: Cls Loss 0.5298 | Mix Loss 0.4195


[Epoch 79/200]: 100%|██████████| 391/391 [00:26<00:00, 14.49it/s, cls_loss=0.517, mix_loss=0.425]


Epoch 79: Cls Loss 0.5174 | Mix Loss 0.4248


[Epoch 80/200]: 100%|██████████| 391/391 [00:27<00:00, 14.13it/s, cls_loss=0.502, mix_loss=0.414]


Epoch 80: Cls Loss 0.5021 | Mix Loss 0.4144


[Epoch 81/200]: 100%|██████████| 391/391 [00:27<00:00, 14.24it/s, cls_loss=0.500, mix_loss=0.417]


Epoch 81: Cls Loss 0.5003 | Mix Loss 0.4173


[Epoch 82/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.485, mix_loss=0.424]


Epoch 82: Cls Loss 0.4845 | Mix Loss 0.4236


[Epoch 83/200]: 100%|██████████| 391/391 [00:27<00:00, 14.25it/s, cls_loss=0.480, mix_loss=0.411]


Epoch 83: Cls Loss 0.4800 | Mix Loss 0.4112


[Epoch 84/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.469, mix_loss=0.421]


Epoch 84: Cls Loss 0.4689 | Mix Loss 0.4207


[Epoch 85/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.465, mix_loss=0.413]


Epoch 85: Cls Loss 0.4651 | Mix Loss 0.4129


[Epoch 86/200]: 100%|██████████| 391/391 [00:27<00:00, 14.31it/s, cls_loss=0.469, mix_loss=0.427]


Epoch 86: Cls Loss 0.4690 | Mix Loss 0.4269


[Epoch 87/200]: 100%|██████████| 391/391 [00:27<00:00, 14.30it/s, cls_loss=0.459, mix_loss=0.423]


Epoch 87: Cls Loss 0.4594 | Mix Loss 0.4229


[Epoch 88/200]: 100%|██████████| 391/391 [00:27<00:00, 14.30it/s, cls_loss=0.451, mix_loss=0.425]


Epoch 88: Cls Loss 0.4507 | Mix Loss 0.4255


[Epoch 89/200]: 100%|██████████| 391/391 [00:27<00:00, 14.30it/s, cls_loss=0.436, mix_loss=0.430]


Epoch 89: Cls Loss 0.4362 | Mix Loss 0.4301


[Epoch 90/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.441, mix_loss=0.435]


Epoch 90: Cls Loss 0.4407 | Mix Loss 0.4345


[Epoch 91/200]: 100%|██████████| 391/391 [00:27<00:00, 14.33it/s, cls_loss=0.436, mix_loss=0.425]


Epoch 91: Cls Loss 0.4357 | Mix Loss 0.4254


[Epoch 92/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.429, mix_loss=0.424]


Epoch 92: Cls Loss 0.4285 | Mix Loss 0.4244


[Epoch 93/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.431, mix_loss=0.417]


Epoch 93: Cls Loss 0.4313 | Mix Loss 0.4171


[Epoch 94/200]: 100%|██████████| 391/391 [00:27<00:00, 14.15it/s, cls_loss=0.427, mix_loss=0.412]


Epoch 94: Cls Loss 0.4274 | Mix Loss 0.4123


[Epoch 95/200]: 100%|██████████| 391/391 [00:27<00:00, 14.21it/s, cls_loss=0.421, mix_loss=0.408]


Epoch 95: Cls Loss 0.4213 | Mix Loss 0.4085


[Epoch 96/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.426, mix_loss=0.408]


Epoch 96: Cls Loss 0.4256 | Mix Loss 0.4079


[Epoch 97/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.413, mix_loss=0.418]


Epoch 97: Cls Loss 0.4131 | Mix Loss 0.4178


[Epoch 98/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.409, mix_loss=0.427]


Epoch 98: Cls Loss 0.4086 | Mix Loss 0.4273


[Epoch 99/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.404, mix_loss=0.405]


Epoch 99: Cls Loss 0.4039 | Mix Loss 0.4049


[Epoch 100/200]: 100%|██████████| 391/391 [00:27<00:00, 14.31it/s, cls_loss=0.407, mix_loss=0.400]


Epoch 100: Cls Loss 0.4071 | Mix Loss 0.3996


[Epoch 101/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.407, mix_loss=0.402]


Epoch 101: Cls Loss 0.4068 | Mix Loss 0.4016


[Epoch 102/200]: 100%|██████████| 391/391 [00:27<00:00, 14.24it/s, cls_loss=0.398, mix_loss=0.390]


Epoch 102: Cls Loss 0.3976 | Mix Loss 0.3903


[Epoch 103/200]: 100%|██████████| 391/391 [00:27<00:00, 14.30it/s, cls_loss=0.397, mix_loss=0.383]


Epoch 103: Cls Loss 0.3972 | Mix Loss 0.3825


[Epoch 104/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.403, mix_loss=0.385]


Epoch 104: Cls Loss 0.4030 | Mix Loss 0.3850


[Epoch 105/200]: 100%|██████████| 391/391 [00:27<00:00, 14.22it/s, cls_loss=0.404, mix_loss=0.392]


Epoch 105: Cls Loss 0.4036 | Mix Loss 0.3918


[Epoch 106/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.401, mix_loss=0.386]


Epoch 106: Cls Loss 0.4013 | Mix Loss 0.3856


[Epoch 107/200]: 100%|██████████| 391/391 [00:27<00:00, 14.23it/s, cls_loss=0.392, mix_loss=0.384]


Epoch 107: Cls Loss 0.3920 | Mix Loss 0.3843


[Epoch 108/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.382, mix_loss=0.383]


Epoch 108: Cls Loss 0.3820 | Mix Loss 0.3831


[Epoch 109/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.392, mix_loss=0.380]


Epoch 109: Cls Loss 0.3925 | Mix Loss 0.3795


[Epoch 110/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.406, mix_loss=0.391]


Epoch 110: Cls Loss 0.4058 | Mix Loss 0.3906


[Epoch 111/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.399, mix_loss=0.389]


Epoch 111: Cls Loss 0.3991 | Mix Loss 0.3889


[Epoch 112/200]: 100%|██████████| 391/391 [00:27<00:00, 14.30it/s, cls_loss=0.401, mix_loss=0.385]


Epoch 112: Cls Loss 0.4010 | Mix Loss 0.3851


[Epoch 113/200]: 100%|██████████| 391/391 [00:27<00:00, 14.25it/s, cls_loss=0.381, mix_loss=0.381]


Epoch 113: Cls Loss 0.3815 | Mix Loss 0.3810


[Epoch 114/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.410, mix_loss=0.378]


Epoch 114: Cls Loss 0.4104 | Mix Loss 0.3780


[Epoch 115/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.390, mix_loss=0.380]


Epoch 115: Cls Loss 0.3898 | Mix Loss 0.3800


[Epoch 116/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.386, mix_loss=0.382]


Epoch 116: Cls Loss 0.3859 | Mix Loss 0.3820


[Epoch 117/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.385, mix_loss=0.376]


Epoch 117: Cls Loss 0.3849 | Mix Loss 0.3761


[Epoch 118/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.370, mix_loss=0.378]


Epoch 118: Cls Loss 0.3697 | Mix Loss 0.3780


[Epoch 119/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.401, mix_loss=0.379]


Epoch 119: Cls Loss 0.4013 | Mix Loss 0.3787


[Epoch 120/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.382, mix_loss=0.372]


Epoch 120: Cls Loss 0.3815 | Mix Loss 0.3721


[Epoch 121/200]: 100%|██████████| 391/391 [00:27<00:00, 14.24it/s, cls_loss=0.381, mix_loss=0.380]


Epoch 121: Cls Loss 0.3811 | Mix Loss 0.3803


[Epoch 122/200]: 100%|██████████| 391/391 [00:27<00:00, 14.25it/s, cls_loss=0.357, mix_loss=0.381]


Epoch 122: Cls Loss 0.3567 | Mix Loss 0.3810


[Epoch 123/200]: 100%|██████████| 391/391 [00:27<00:00, 14.20it/s, cls_loss=0.363, mix_loss=0.377]


Epoch 123: Cls Loss 0.3629 | Mix Loss 0.3771


[Epoch 124/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.370, mix_loss=0.369]


Epoch 124: Cls Loss 0.3705 | Mix Loss 0.3686


[Epoch 125/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.376, mix_loss=0.379]


Epoch 125: Cls Loss 0.3762 | Mix Loss 0.3790


[Epoch 126/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.363, mix_loss=0.377]


Epoch 126: Cls Loss 0.3628 | Mix Loss 0.3773


[Epoch 127/200]: 100%|██████████| 391/391 [00:27<00:00, 14.23it/s, cls_loss=0.352, mix_loss=0.380]


Epoch 127: Cls Loss 0.3517 | Mix Loss 0.3804


[Epoch 128/200]: 100%|██████████| 391/391 [00:27<00:00, 14.21it/s, cls_loss=0.341, mix_loss=0.377]


Epoch 128: Cls Loss 0.3410 | Mix Loss 0.3771


[Epoch 129/200]: 100%|██████████| 391/391 [00:27<00:00, 14.31it/s, cls_loss=0.352, mix_loss=0.374]


Epoch 129: Cls Loss 0.3522 | Mix Loss 0.3745


[Epoch 130/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.359, mix_loss=0.378]


Epoch 130: Cls Loss 0.3586 | Mix Loss 0.3783


[Epoch 131/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.388, mix_loss=0.375]


Epoch 131: Cls Loss 0.3882 | Mix Loss 0.3755


[Epoch 132/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.419, mix_loss=0.376]


Epoch 132: Cls Loss 0.4191 | Mix Loss 0.3758


[Epoch 133/200]: 100%|██████████| 391/391 [00:27<00:00, 14.31it/s, cls_loss=0.451, mix_loss=0.392]


Epoch 133: Cls Loss 0.4509 | Mix Loss 0.3915


[Epoch 134/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.428, mix_loss=0.406]


Epoch 134: Cls Loss 0.4285 | Mix Loss 0.4062


[Epoch 135/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.423, mix_loss=0.398]


Epoch 135: Cls Loss 0.4230 | Mix Loss 0.3984


[Epoch 136/200]: 100%|██████████| 391/391 [00:27<00:00, 14.30it/s, cls_loss=0.441, mix_loss=0.389]


Epoch 136: Cls Loss 0.4412 | Mix Loss 0.3889


[Epoch 137/200]: 100%|██████████| 391/391 [00:27<00:00, 14.23it/s, cls_loss=0.416, mix_loss=0.383]


Epoch 137: Cls Loss 0.4160 | Mix Loss 0.3834


[Epoch 138/200]: 100%|██████████| 391/391 [00:27<00:00, 14.22it/s, cls_loss=0.436, mix_loss=0.385]


Epoch 138: Cls Loss 0.4359 | Mix Loss 0.3850


[Epoch 139/200]: 100%|██████████| 391/391 [00:27<00:00, 14.25it/s, cls_loss=0.428, mix_loss=0.382]


Epoch 139: Cls Loss 0.4279 | Mix Loss 0.3815


[Epoch 140/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.418, mix_loss=0.384]


Epoch 140: Cls Loss 0.4183 | Mix Loss 0.3841


[Epoch 141/200]: 100%|██████████| 391/391 [00:27<00:00, 14.25it/s, cls_loss=0.436, mix_loss=0.373]


Epoch 141: Cls Loss 0.4360 | Mix Loss 0.3726


[Epoch 142/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.451, mix_loss=0.376]


Epoch 142: Cls Loss 0.4507 | Mix Loss 0.3764


[Epoch 143/200]: 100%|██████████| 391/391 [00:28<00:00, 13.78it/s, cls_loss=0.428, mix_loss=0.379]


Epoch 143: Cls Loss 0.4279 | Mix Loss 0.3790


[Epoch 144/200]: 100%|██████████| 391/391 [00:27<00:00, 14.17it/s, cls_loss=0.396, mix_loss=0.391]


Epoch 144: Cls Loss 0.3963 | Mix Loss 0.3912


[Epoch 145/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.362, mix_loss=0.412]


Epoch 145: Cls Loss 0.3622 | Mix Loss 0.4116


[Epoch 146/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.352, mix_loss=0.404]


Epoch 146: Cls Loss 0.3524 | Mix Loss 0.4038


[Epoch 147/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.346, mix_loss=0.411]


Epoch 147: Cls Loss 0.3463 | Mix Loss 0.4108


[Epoch 148/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.354, mix_loss=0.404]


Epoch 148: Cls Loss 0.3541 | Mix Loss 0.4036


[Epoch 149/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.352, mix_loss=0.403]


Epoch 149: Cls Loss 0.3515 | Mix Loss 0.4031


[Epoch 150/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.364, mix_loss=0.391]


Epoch 150: Cls Loss 0.3639 | Mix Loss 0.3914


[Epoch 151/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.360, mix_loss=0.388]


Epoch 151: Cls Loss 0.3601 | Mix Loss 0.3879


[Epoch 152/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.381, mix_loss=0.379]


Epoch 152: Cls Loss 0.3812 | Mix Loss 0.3791


[Epoch 153/200]: 100%|██████████| 391/391 [00:27<00:00, 14.21it/s, cls_loss=0.373, mix_loss=0.392]


Epoch 153: Cls Loss 0.3726 | Mix Loss 0.3915


[Epoch 154/200]: 100%|██████████| 391/391 [00:27<00:00, 14.23it/s, cls_loss=0.359, mix_loss=0.387]


Epoch 154: Cls Loss 0.3592 | Mix Loss 0.3872


[Epoch 155/200]: 100%|██████████| 391/391 [00:27<00:00, 14.20it/s, cls_loss=0.340, mix_loss=0.383]


Epoch 155: Cls Loss 0.3397 | Mix Loss 0.3831


[Epoch 156/200]: 100%|██████████| 391/391 [00:27<00:00, 14.23it/s, cls_loss=0.339, mix_loss=0.382]


Epoch 156: Cls Loss 0.3390 | Mix Loss 0.3821


[Epoch 157/200]: 100%|██████████| 391/391 [00:27<00:00, 14.30it/s, cls_loss=0.327, mix_loss=0.399]


Epoch 157: Cls Loss 0.3274 | Mix Loss 0.3994


[Epoch 158/200]: 100%|██████████| 391/391 [00:27<00:00, 14.24it/s, cls_loss=0.337, mix_loss=0.394]


Epoch 158: Cls Loss 0.3365 | Mix Loss 0.3943


[Epoch 159/200]: 100%|██████████| 391/391 [00:27<00:00, 14.22it/s, cls_loss=0.329, mix_loss=0.403]


Epoch 159: Cls Loss 0.3294 | Mix Loss 0.4028


[Epoch 160/200]: 100%|██████████| 391/391 [00:27<00:00, 14.16it/s, cls_loss=0.314, mix_loss=0.400]


Epoch 160: Cls Loss 0.3138 | Mix Loss 0.3996


[Epoch 161/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.316, mix_loss=0.397]


Epoch 161: Cls Loss 0.3160 | Mix Loss 0.3967


[Epoch 162/200]: 100%|██████████| 391/391 [00:27<00:00, 14.23it/s, cls_loss=0.310, mix_loss=0.401]


Epoch 162: Cls Loss 0.3103 | Mix Loss 0.4006


[Epoch 163/200]: 100%|██████████| 391/391 [00:27<00:00, 14.20it/s, cls_loss=0.321, mix_loss=0.395]


Epoch 163: Cls Loss 0.3210 | Mix Loss 0.3952


[Epoch 164/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.340, mix_loss=0.384]


Epoch 164: Cls Loss 0.3403 | Mix Loss 0.3842


[Epoch 165/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.345, mix_loss=0.387]


Epoch 165: Cls Loss 0.3453 | Mix Loss 0.3870


[Epoch 166/200]: 100%|██████████| 391/391 [00:27<00:00, 14.15it/s, cls_loss=0.330, mix_loss=0.386]


Epoch 166: Cls Loss 0.3303 | Mix Loss 0.3858


[Epoch 167/200]: 100%|██████████| 391/391 [00:27<00:00, 14.21it/s, cls_loss=0.325, mix_loss=0.387]


Epoch 167: Cls Loss 0.3248 | Mix Loss 0.3869


[Epoch 168/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.330, mix_loss=0.392]


Epoch 168: Cls Loss 0.3299 | Mix Loss 0.3919


[Epoch 169/200]: 100%|██████████| 391/391 [00:27<00:00, 14.23it/s, cls_loss=0.382, mix_loss=0.395]


Epoch 169: Cls Loss 0.3817 | Mix Loss 0.3951


[Epoch 170/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.372, mix_loss=0.408]


Epoch 170: Cls Loss 0.3717 | Mix Loss 0.4083


[Epoch 171/200]: 100%|██████████| 391/391 [00:27<00:00, 14.25it/s, cls_loss=0.408, mix_loss=0.400]


Epoch 171: Cls Loss 0.4075 | Mix Loss 0.4004


[Epoch 172/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.365, mix_loss=0.419]


Epoch 172: Cls Loss 0.3655 | Mix Loss 0.4191


[Epoch 173/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.379, mix_loss=0.419]


Epoch 173: Cls Loss 0.3786 | Mix Loss 0.4185


[Epoch 174/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.362, mix_loss=0.429]


Epoch 174: Cls Loss 0.3618 | Mix Loss 0.4291


[Epoch 175/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.331, mix_loss=0.405]


Epoch 175: Cls Loss 0.3313 | Mix Loss 0.4047


[Epoch 176/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.336, mix_loss=0.398]


Epoch 176: Cls Loss 0.3364 | Mix Loss 0.3981


[Epoch 177/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.341, mix_loss=0.398]


Epoch 177: Cls Loss 0.3411 | Mix Loss 0.3975


[Epoch 178/200]: 100%|██████████| 391/391 [00:27<00:00, 14.25it/s, cls_loss=0.319, mix_loss=0.404]


Epoch 178: Cls Loss 0.3188 | Mix Loss 0.4037


[Epoch 179/200]: 100%|██████████| 391/391 [00:27<00:00, 14.30it/s, cls_loss=0.315, mix_loss=0.408]


Epoch 179: Cls Loss 0.3145 | Mix Loss 0.4084


[Epoch 180/200]: 100%|██████████| 391/391 [00:27<00:00, 14.34it/s, cls_loss=0.296, mix_loss=0.397]


Epoch 180: Cls Loss 0.2964 | Mix Loss 0.3975


[Epoch 181/200]: 100%|██████████| 391/391 [00:27<00:00, 14.31it/s, cls_loss=0.293, mix_loss=0.406]


Epoch 181: Cls Loss 0.2935 | Mix Loss 0.4055


[Epoch 182/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.287, mix_loss=0.404]


Epoch 182: Cls Loss 0.2867 | Mix Loss 0.4039


[Epoch 183/200]: 100%|██████████| 391/391 [00:27<00:00, 14.25it/s, cls_loss=0.288, mix_loss=0.403]


Epoch 183: Cls Loss 0.2881 | Mix Loss 0.4027


[Epoch 184/200]: 100%|██████████| 391/391 [00:27<00:00, 14.24it/s, cls_loss=0.295, mix_loss=0.386]


Epoch 184: Cls Loss 0.2951 | Mix Loss 0.3857


[Epoch 185/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.278, mix_loss=0.382]


Epoch 185: Cls Loss 0.2781 | Mix Loss 0.3819


[Epoch 186/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.274, mix_loss=0.385]


Epoch 186: Cls Loss 0.2741 | Mix Loss 0.3850


[Epoch 187/200]: 100%|██████████| 391/391 [00:27<00:00, 14.23it/s, cls_loss=0.269, mix_loss=0.380]


Epoch 187: Cls Loss 0.2689 | Mix Loss 0.3795


[Epoch 188/200]: 100%|██████████| 391/391 [00:27<00:00, 14.33it/s, cls_loss=0.261, mix_loss=0.388]


Epoch 188: Cls Loss 0.2608 | Mix Loss 0.3880


[Epoch 189/200]: 100%|██████████| 391/391 [00:27<00:00, 14.25it/s, cls_loss=0.260, mix_loss=0.387]


Epoch 189: Cls Loss 0.2596 | Mix Loss 0.3868


[Epoch 190/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.257, mix_loss=0.395]


Epoch 190: Cls Loss 0.2570 | Mix Loss 0.3947


[Epoch 191/200]: 100%|██████████| 391/391 [00:27<00:00, 14.29it/s, cls_loss=0.258, mix_loss=0.384]


Epoch 191: Cls Loss 0.2578 | Mix Loss 0.3845


[Epoch 192/200]: 100%|██████████| 391/391 [00:27<00:00, 14.32it/s, cls_loss=0.259, mix_loss=0.375]


Epoch 192: Cls Loss 0.2593 | Mix Loss 0.3750


[Epoch 193/200]: 100%|██████████| 391/391 [00:27<00:00, 14.21it/s, cls_loss=0.246, mix_loss=0.377]


Epoch 193: Cls Loss 0.2455 | Mix Loss 0.3767


[Epoch 194/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.247, mix_loss=0.368]


Epoch 194: Cls Loss 0.2475 | Mix Loss 0.3677


[Epoch 195/200]: 100%|██████████| 391/391 [00:27<00:00, 14.30it/s, cls_loss=0.235, mix_loss=0.373]


Epoch 195: Cls Loss 0.2351 | Mix Loss 0.3732


[Epoch 196/200]: 100%|██████████| 391/391 [00:27<00:00, 14.26it/s, cls_loss=0.241, mix_loss=0.373]


Epoch 196: Cls Loss 0.2407 | Mix Loss 0.3728


[Epoch 197/200]: 100%|██████████| 391/391 [00:27<00:00, 14.28it/s, cls_loss=0.235, mix_loss=0.369]


Epoch 197: Cls Loss 0.2351 | Mix Loss 0.3688


[Epoch 198/200]: 100%|██████████| 391/391 [00:27<00:00, 14.21it/s, cls_loss=0.234, mix_loss=0.372]


Epoch 198: Cls Loss 0.2340 | Mix Loss 0.3718


[Epoch 199/200]: 100%|██████████| 391/391 [00:27<00:00, 14.22it/s, cls_loss=0.229, mix_loss=0.375]


Epoch 199: Cls Loss 0.2289 | Mix Loss 0.3746


[Epoch 200/200]: 100%|██████████| 391/391 [00:27<00:00, 14.27it/s, cls_loss=0.237, mix_loss=0.381]

Epoch 200: Cls Loss 0.2374 | Mix Loss 0.3810
✅ Training complete.





In [13]:
# ==============================================================
# 📘 Section 6 — Evaluation under Adversarial Attacks
# ==============================================================

def fgsm_attack(model, images, labels, eps=0.03):
    """Implements FGSM: x_adv = x + eps * sign(∇x L(x, y))."""
    images = images.clone().detach().to(device)
    images.requires_grad = True

    outputs, _, _ = model(images)
    loss = F.cross_entropy(outputs, labels)
    loss.backward()
    adv_images = images + eps * images.grad.sign()
    adv_images = torch.clamp(adv_images, -1, 1)   # keep normalized range
    return adv_images.detach()

def ifgsm_attack(model, images, labels, eps=0.03, alpha=0.005, iters=10):
    """Iterative FGSM (a.k.a. PGD-lite)."""
    images = images.clone().detach().to(device)
    adv_images = images.clone().detach()

    for _ in range(iters):
        adv_images.requires_grad = True
        outputs, _, _ = model(adv_images)
        loss = F.cross_entropy(outputs, labels)
        model.zero_grad()
        loss.backward()
        adv_images = adv_images + alpha * adv_images.grad.sign()
        eta = torch.clamp(adv_images - images, min=-eps, max=eps)
        adv_images = torch.clamp(images + eta, -1, 1).detach()
    return adv_images


# --------------------------------------------------------------
# 🔹 Evaluation function
# --------------------------------------------------------------
def evaluate(model, dataloader, attack=None):
    model.eval()
    correct, total = 0, 0
    for imgs, labels in tqdm(dataloader, desc=f"[Eval {attack or 'Clean'}]"):
        imgs, labels = imgs.to(device), labels.to(device)
        if attack == 'FGSM':
            imgs = fgsm_attack(model, imgs, labels)
        elif attack == 'iFGSM':
            imgs = ifgsm_attack(model, imgs, labels)
        with torch.no_grad():
            outputs, _, _ = model(imgs)
            preds = outputs.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return 100 * correct / total


# --------------------------------------------------------------
# 🔹 Evaluate RaCNN vs. Vanilla CNN
# --------------------------------------------------------------
print("\n🧪 Evaluating robustness (CIFAR-10 test set)...")

# Baseline CNN (vanilla)
vanilla = get_resnet18_backbone(pretrained=False)
vanilla.fc = nn.Linear(512, 10)
vanilla = vanilla.to(device)
# assume it's randomly initialized — used to show difference in robustness

acc_clean_racnn = evaluate(model, testloader, None)
acc_fgsm_racnn  = evaluate(model, testloader, 'FGSM')
acc_ifgsm_racnn = evaluate(model, testloader, 'iFGSM')

print(f"✅ RaCNN accuracy — Clean: {acc_clean_racnn:.2f}% | FGSM: {acc_fgsm_racnn:.2f}% | iFGSM: {acc_ifgsm_racnn:.2f}%")



🧪 Evaluating robustness (CIFAR-10 test set)...


[Eval Clean]: 100%|██████████| 79/79 [00:11<00:00,  7.15it/s]
[Eval FGSM]: 100%|██████████| 79/79 [00:25<00:00,  3.13it/s]
[Eval iFGSM]: 100%|██████████| 79/79 [02:09<00:00,  1.63s/it]

✅ RaCNN accuracy — Clean: 69.02% | FGSM: 29.82% | iFGSM: 20.63%





In [12]:
# ==============================================================
# 📘 Extended Adversarial Evaluation — Full Suite (Section 6)
# ==============================================================
import torchattacks
from tqdm import tqdm
import torch.nn as nn

# --------------------------------------------------------------
# 🔹 Fix for torchattacks — wrapper that only returns logits
# --------------------------------------------------------------
class LogitsWrapper(nn.Module):
    def __init__(self, racnn):
        super().__init__()
        self.racnn = racnn
    def forward(self, x):
        logits, _, _ = self.racnn(x)
        return logits

# use wrapper for attacks (torchattacks expects logits only)
wrapped_model = LogitsWrapper(model)

# --------------------------------------------------------------
# 🔹 Define attacks (matches Section 6 of the paper)
# --------------------------------------------------------------
atk_fgsm     = torchattacks.FGSM(wrapped_model, eps=0.03)
atk_ifgsm    = torchattacks.PGD(wrapped_model, eps=0.03, alpha=0.005, steps=10)   # ≈ iFGSM
atk_pgd      = torchattacks.PGD(wrapped_model, eps=0.03, alpha=0.005, steps=20)   # stronger PGD
atk_deepfool = torchattacks.DeepFool(wrapped_model, steps=50)
atk_cw       = torchattacks.CW(wrapped_model, c=1, kappa=0, steps=50, lr=0.01)
atk_noise    = lambda x, y: torch.clamp(x + 0.03 * torch.randn_like(x), -1, 1)  # random Gaussian noise

# Wrap all attacks in a dict
attacks = {
    "Clean":    None,
    "FGSM":     atk_fgsm,
    "iFGSM":    atk_ifgsm,
    "PGD":      atk_pgd,
    "DeepFool": atk_deepfool,
    "CW":       atk_cw,
    "Noise":    atk_noise,
}

# --------------------------------------------------------------
# 🔹 Evaluation loop
# --------------------------------------------------------------
def evaluate_adv(model, dataloader, attacks):
    """Evaluate model under multiple attacks."""
    model.eval()
    results = {}
    for name, atk in attacks.items():
        correct, total = 0, 0
        for imgs, labels in tqdm(dataloader, desc=f"[Eval {name}]"):
            imgs, labels = imgs.to(device), labels.to(device)
            if atk is not None:
                if callable(atk):  # noise
                    imgs = atk(imgs, labels)
                else:
                    imgs = atk(imgs, labels)
            with torch.no_grad():
                outputs, _, _ = model(imgs)
                preds = outputs.argmax(1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)
        acc = 100 * correct / total
        results[name] = acc
        print(f"{name:10s}: {acc:.2f}%")
    return results

# --------------------------------------------------------------
# 🚀 Run full evaluation
# --------------------------------------------------------------
print("🧪 Evaluating RaCNN robustness under full attack suite (CIFAR-10)...")
results = evaluate_adv(model, testloader, attacks)

print("\n✅ RaCNN Robustness Summary:")
for k, v in results.items():
    print(f"{k:10s}: {v:.2f}%")


🧪 Evaluating RaCNN robustness under full attack suite (CIFAR-10)...


[Eval Clean]: 100%|██████████| 79/79 [00:06<00:00, 11.98it/s]


Clean     : 84.55%


[Eval FGSM]: 100%|██████████| 79/79 [00:08<00:00,  9.31it/s]


FGSM      : 51.84%


[Eval iFGSM]: 100%|██████████| 79/79 [00:25<00:00,  3.13it/s]


iFGSM     : 11.78%


[Eval PGD]: 100%|██████████| 79/79 [00:43<00:00,  1.81it/s]


PGD       : 9.71%


[Eval DeepFool]: 100%|██████████| 79/79 [33:43<00:00, 25.61s/it]


DeepFool  : 10.84%


[Eval CW]: 100%|██████████| 79/79 [01:58<00:00,  1.50s/it]


CW        : 10.00%


[Eval Noise]: 100%|██████████| 79/79 [00:06<00:00, 12.17it/s]

Noise     : 82.48%

✅ RaCNN Robustness Summary:
Clean     : 84.55%
FGSM      : 51.84%
iFGSM     : 11.78%
PGD       : 9.71%
DeepFool  : 10.84%
CW        : 10.00%
Noise     : 82.48%



