In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import numpy as np
import random
from tqdm import tqdm
import os

In [None]:
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)


batch_size = 64
learning_rate = 0.03
num_epochs = 80

epsilon = 8 / 255
alpha = 2 / 255

beta = 1.0

# gradually increasing attack strength of PGD attacks
pgd_schedule = {
    1: 1,
    5: 2,
    10: 3,
    20: 5
}

# AWP parameters to ensure robustness
awp_gamma = 0.002
awp_eps   = 0.005
awp_start_epoch = 10

transform_train = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])

Using device: cuda


In [None]:
CLEAN_ROOT = '/kaggle/input/imagenet100/train.X1'
VAL_ROOT   = '/kaggle/input/imagenet100/val.X'

train_dataset = torchvision.datasets.ImageFolder(
    root=CLEAN_ROOT,
    transform=transform_train
)

class_mapping = train_dataset.class_to_idx

# custom dataset to ensure consistent class mapping
class AlignedImageFolder(torchvision.datasets.ImageFolder):
    def find_classes(self, directory):
        return list(class_mapping.keys()), class_mapping

test_dataset = AlignedImageFolder(
    root=VAL_ROOT,
    transform=transform_test
)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)


In [None]:
# loading folder of images and giving binary labels
class Normalize(nn.Module):
    def __init__(self, mean, std):
        super().__init__()
        self.register_buffer("mean", torch.tensor(mean).view(1,3,1,1))
        self.register_buffer("std", torch.tensor(std).view(1,3,1,1))
    def forward(self, x):
        return (x - self.mean) / self.std

base_model = resnet18(weights="IMAGENET1K_V1")
base_model.fc = nn.Linear(base_model.fc.in_features, 100)

model = nn.Sequential(
    Normalize([0.485, 0.456, 0.406],
              [0.229, 0.224, 0.225]),
    base_model
).to(device)


In [None]:

optimizer = optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=5e-4
)

scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[20, 35],
    gamma=0.1
)

# MART loss function
def mart_loss(model, x, y, num_steps):
    x_adv = x.detach() + 0.001 * torch.randn_like(x)

    for _ in range(num_steps):
        x_adv.requires_grad_()
        logits_adv = model(x_adv)
        loss = F.cross_entropy(logits_adv, y)
        grad = torch.autograd.grad(loss, x_adv)[0]
        x_adv = x_adv + alpha * torch.sign(grad)
        x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
        x_adv = torch.clamp(x_adv, 0, 1).detach()

    logits_clean = model(x)
    logits_adv = model(x_adv)

    prob_clean = F.softmax(logits_clean, dim=1)
    true_prob = prob_clean.gather(1, y.unsqueeze(1)).squeeze()

    ce_adv = F.cross_entropy(logits_adv, y)

    kl = F.kl_div(
        F.log_softmax(logits_adv, dim=1),
        prob_clean,
        reduction="none"
    ).sum(dim=1)

    mart = torch.mean(kl * (1 - true_prob))
    return ce_adv + beta * mart

# AWP implementation
class AWP:
    def __init__(self, model, gamma, eps):
        self.model = model
        self.gamma = gamma
        self.eps = eps
        self.backup = {}

    @torch.no_grad()
    def perturb(self):
        for n, p in self.model.named_parameters():
            if p.requires_grad and p.grad is not None and "weight" in n:
                norm = torch.norm(p.grad)
                if norm > 0:
                    self.backup[n] = p.data.clone()
                    p.data.add_(self.gamma * p.grad / (norm + 1e-8))
                    p.data.clamp_(
                        self.backup[n] - self.eps,
                        self.backup[n] + self.eps
                    )

    @torch.no_grad()
    def restore(self):
        for n, p in self.model.named_parameters():
            if n in self.backup:
                p.data.copy_(self.backup[n])
        self.backup = {}

awp = AWP(model, awp_gamma, awp_eps)


In [6]:
def evaluate(model, loader, adversarial=False, steps=10):
    model.eval()
    correct, total = 0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        if adversarial:
            x_adv = x.clone()
            for _ in range(steps):
                x_adv.requires_grad_()
                loss = F.cross_entropy(model(x_adv), y)
                grad = torch.autograd.grad(loss, x_adv)[0]
                x_adv = x_adv + alpha * torch.sign(grad)
                x_adv = torch.min(torch.max(x_adv, x - epsilon), x + epsilon)
                x_adv = torch.clamp(x_adv, 0, 1).detach()
            out = model(x_adv)
        else:
            out = model(x)

        pred = out.argmax(1)
        correct += pred.eq(y).sum().item()
        total += y.size(0)

    return 100 * correct / total


In [7]:
print("Starting MART training (corrected)...")

for epoch in range(1, num_epochs + 1):
    model.train()

    num_steps = max(v for k, v in pgd_schedule.items() if epoch >= k)
    total_loss = 0

    loop = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}")

    for x, y in loop:
        x, y = x.to(device), y.to(device)

        optimizer.zero_grad()
        loss = mart_loss(model, x, y, num_steps)
        loss.backward()

        if epoch >= awp_start_epoch:
            awp.perturb()
            loss_awp = mart_loss(model, x, y, num_steps)
            loss_awp.backward()
            awp.restore()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item() * x.size(0)
        loop.set_postfix(loss=loss.item())

    scheduler.step()

    clean_acc = evaluate(model, test_loader, False)
    adv_acc = evaluate(model, test_loader, True, steps=num_steps)

    print(
        f"Epoch {epoch:03d} | "
        f"Loss {total_loss / len(train_loader.dataset):.4f} | "
        f"Clean {clean_acc:.2f}% | "
        f"Adv {adv_acc:.2f}%"
    )

Starting MART training (corrected)...


Epoch 1/80: 100%|██████████| 508/508 [02:33<00:00,  3.30it/s, loss=0.814]


Epoch 001 | Loss 1.6140 | Clean 79.84% | Adv 63.20%


Epoch 2/80: 100%|██████████| 508/508 [02:33<00:00,  3.31it/s, loss=0.864]


Epoch 002 | Loss 1.0641 | Clean 83.12% | Adv 68.40%


Epoch 3/80: 100%|██████████| 508/508 [02:35<00:00,  3.26it/s, loss=0.897]


Epoch 003 | Loss 0.9292 | Clean 82.72% | Adv 68.48%


Epoch 4/80: 100%|██████████| 508/508 [02:33<00:00,  3.31it/s, loss=0.788]


Epoch 004 | Loss 0.8444 | Clean 84.56% | Adv 70.00%


Epoch 5/80: 100%|██████████| 508/508 [03:18<00:00,  2.56it/s, loss=1.42] 


Epoch 005 | Loss 1.3060 | Clean 79.20% | Adv 59.52%


Epoch 6/80: 100%|██████████| 508/508 [03:15<00:00,  2.59it/s, loss=1.17] 


Epoch 006 | Loss 1.2346 | Clean 79.84% | Adv 58.32%


Epoch 7/80: 100%|██████████| 508/508 [03:15<00:00,  2.59it/s, loss=1.08] 


Epoch 007 | Loss 1.1860 | Clean 80.16% | Adv 61.28%


Epoch 8/80: 100%|██████████| 508/508 [03:15<00:00,  2.59it/s, loss=1.44] 


Epoch 008 | Loss 1.1441 | Clean 82.08% | Adv 60.24%


Epoch 9/80: 100%|██████████| 508/508 [03:15<00:00,  2.59it/s, loss=1.52] 


Epoch 009 | Loss 1.1091 | Clean 81.52% | Adv 60.40%


Epoch 10/80: 100%|██████████| 508/508 [07:57<00:00,  1.06it/s, loss=1.5]  


Epoch 010 | Loss 1.5037 | Clean 75.92% | Adv 50.96%


Epoch 11/80: 100%|██████████| 508/508 [07:57<00:00,  1.06it/s, loss=1.42]


Epoch 011 | Loss 1.4585 | Clean 77.60% | Adv 53.92%


Epoch 12/80: 100%|██████████| 508/508 [07:57<00:00,  1.06it/s, loss=1.26]


Epoch 012 | Loss 1.4391 | Clean 78.24% | Adv 53.04%


Epoch 13/80: 100%|██████████| 508/508 [07:57<00:00,  1.06it/s, loss=1.83] 


Epoch 013 | Loss 1.4266 | Clean 76.24% | Adv 52.88%


Epoch 14/80: 100%|██████████| 508/508 [07:57<00:00,  1.06it/s, loss=1.63] 


Epoch 014 | Loss 1.4032 | Clean 76.24% | Adv 52.16%


Epoch 15/80: 100%|██████████| 508/508 [07:57<00:00,  1.06it/s, loss=1.39]


Epoch 015 | Loss 1.3961 | Clean 75.76% | Adv 51.52%


Epoch 16/80: 100%|██████████| 508/508 [07:57<00:00,  1.06it/s, loss=1.55] 


Epoch 016 | Loss 1.3873 | Clean 78.16% | Adv 52.80%


Epoch 17/80: 100%|██████████| 508/508 [07:57<00:00,  1.06it/s, loss=1.39]


Epoch 017 | Loss 1.3778 | Clean 75.84% | Adv 53.52%


Epoch 18/80: 100%|██████████| 508/508 [07:57<00:00,  1.06it/s, loss=1.61] 


Epoch 018 | Loss 1.3620 | Clean 77.44% | Adv 54.56%


Epoch 19/80: 100%|██████████| 508/508 [07:57<00:00,  1.06it/s, loss=1.56] 


Epoch 019 | Loss 1.3647 | Clean 74.88% | Adv 53.20%


Epoch 20/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.63]


Epoch 020 | Loss 1.7356 | Clean 71.52% | Adv 44.00%


Epoch 21/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.43]


Epoch 021 | Loss 1.4765 | Clean 78.32% | Adv 51.68%


Epoch 22/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.41] 


Epoch 022 | Loss 1.3769 | Clean 77.84% | Adv 48.96%


Epoch 23/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.38]


Epoch 023 | Loss 1.3401 | Clean 78.32% | Adv 50.72%


Epoch 24/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.32] 


Epoch 024 | Loss 1.3179 | Clean 79.12% | Adv 50.88%


Epoch 25/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.43] 


Epoch 025 | Loss 1.2915 | Clean 78.16% | Adv 51.36%


Epoch 26/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.41] 


Epoch 026 | Loss 1.2670 | Clean 80.32% | Adv 50.00%


Epoch 27/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.13] 


Epoch 027 | Loss 1.2531 | Clean 80.56% | Adv 50.40%


Epoch 28/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.22] 


Epoch 028 | Loss 1.2401 | Clean 78.88% | Adv 49.84%


Epoch 29/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.36] 


Epoch 029 | Loss 1.2252 | Clean 78.88% | Adv 50.64%


Epoch 30/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.11] 


Epoch 030 | Loss 1.2138 | Clean 79.84% | Adv 49.52%


Epoch 31/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.46] 


Epoch 031 | Loss 1.1984 | Clean 80.48% | Adv 50.80%


Epoch 32/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.05] 


Epoch 032 | Loss 1.1887 | Clean 80.00% | Adv 49.76%


Epoch 33/80: 100%|██████████| 508/508 [10:48<00:00,  1.28s/it, loss=1.24] 


Epoch 033 | Loss 1.1792 | Clean 79.52% | Adv 50.64%


Epoch 34/80:   6%|▌         | 31/508 [00:40<10:19,  1.30s/it, loss=1.23] 


KeyboardInterrupt: 

In [8]:
torch.save({
    "epoch": epoch,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
    "scheduler_state": scheduler.state_dict()
}, "mart_awp_resnet18_imagenet100.pth")

print("Training completed and model saved!")

Training completed and model saved!
