In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.models import resnet18
import numpy as np
import random
from tqdm import tqdm

In [None]:
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

batch_size = 128
learning_rate = 0.01
num_epochs = 80
epsilon = 8 / 255
alpha = 2 / 255
num_steps = 10
beta = 6.0
awp_gamma = 0.005
awp_eps = 0.01

transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

model = resnet18(weights="IMAGENET1K_V1")
model.fc = nn.Linear(model.fc.in_features, 10)
model = model.to(device)
print("Loaded pretrained ResNet-18 with ImageNet weights")

def mart_loss(model, x_natural, y):
    model.train()
    batch_size = len(x_natural)
    x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural)
    for _ in range(num_steps):
        x_adv.requires_grad_()
        logits_adv = model(x_adv)
        loss_adv = F.cross_entropy(logits_adv, y)
        grad = torch.autograd.grad(loss_adv, [x_adv])[0]
        x_adv = x_adv + alpha * torch.sign(grad)
        x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
    logits_nat = model(x_natural)
    logits_adv = model(x_adv)
    nat_probs = F.softmax(logits_nat, dim=1)
    true_probs = nat_probs.gather(1, y.unsqueeze(1)).squeeze()
    loss_adv = F.cross_entropy(logits_adv, y)
    loss_robust = F.kl_div(F.log_softmax(logits_adv, dim=1), nat_probs, reduction='none').sum(dim=1)
    loss_robust = torch.mean(loss_robust * (1.0 - true_probs))
    return loss_adv + beta * loss_robust

optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

class AWP:
    def __init__(self, model, optimizer, gamma=0.005, eps=0.01):
        self.model = model
        self.optimizer = optimizer
        self.gamma = gamma
        self.eps = eps
        self.backup = {}

    @torch.no_grad()
    def perturb(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad and 'weight' in name:
                grad_norm = torch.norm(param.grad)
                if grad_norm != 0:
                    r_at = self.gamma * param.grad / (grad_norm + 1e-8)
                    self.backup[name] = param.data.clone()
                    param.data.add_(r_at)
                    param.data.clamp_(self.backup[name] - self.eps, self.backup[name] + self.eps)

    @torch.no_grad()
    def restore(self):
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}

awp = AWP(model, optimizer, gamma=awp_gamma, eps=awp_eps)

def evaluate(model, loader, adversarial=False):
    model.eval()
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        if adversarial:
            x_adv = images.detach() + 0.001 * torch.randn_like(images)
            for _ in range(num_steps):
                x_adv.requires_grad_()
                logits = model(x_adv)
                loss = F.cross_entropy(logits, labels)
                grad = torch.autograd.grad(loss, [x_adv])[0]
                x_adv = x_adv + alpha * torch.sign(grad)
                x_adv = torch.min(torch.max(x_adv, images - epsilon), images + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
            outputs = model(x_adv)
        else:
            outputs = model(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    return 100.0 * correct / total

print("Starting MART + AWP training...")
for epoch in range(1, num_epochs + 1):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = mart_loss(model, images, labels)
        loss.backward()
        awp.perturb()
        optimizer.step()
        awp.restore()
        running_loss += loss.item() * images.size(0)
    scheduler.step()
    avg_loss = running_loss / len(train_loader.dataset)
    clean_acc = evaluate(model, test_loader, adversarial=False)
    adv_acc = evaluate(model, test_loader, adversarial=True)
    print(f"Epoch [{epoch}/{num_epochs}] | Loss: {avg_loss:.4f} | Clean Acc: {clean_acc:.2f}% | Adv Acc: {adv_acc:.2f}%")

torch.save(model.state_dict(), "mart_awp_resnet18_cifar10.pth")
print("Training completed and model saved (MART + AWP)!")

Using device: cuda


100%|██████████| 170M/170M [00:02<00:00, 57.6MB/s] 
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 120MB/s] 


Loaded pretrained ResNet-18 with ImageNet weights
Starting MART training...


Epoch 1/80: 100%|██████████| 391/391 [01:53<00:00,  3.45it/s]


Epoch [1/80] | Loss: 2.4393 | Clean Acc: 10.76% | Adv Acc: 9.83%


Epoch 2/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [2/80] | Loss: 2.3463 | Clean Acc: 13.03% | Adv Acc: 11.85%


Epoch 3/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [3/80] | Loss: 2.2807 | Clean Acc: 19.97% | Adv Acc: 17.15%


Epoch 4/80: 100%|██████████| 391/391 [01:52<00:00,  3.49it/s]


Epoch [4/80] | Loss: 2.2271 | Clean Acc: 23.82% | Adv Acc: 18.35%


Epoch 5/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [5/80] | Loss: 2.1772 | Clean Acc: 31.14% | Adv Acc: 23.69%


Epoch 7/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [7/80] | Loss: 2.1100 | Clean Acc: 32.35% | Adv Acc: 24.54%


Epoch 8/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [8/80] | Loss: 2.0854 | Clean Acc: 39.63% | Adv Acc: 27.99%


Epoch 9/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [9/80] | Loss: 2.0902 | Clean Acc: 36.05% | Adv Acc: 26.07%


Epoch 10/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [10/80] | Loss: 2.0657 | Clean Acc: 39.38% | Adv Acc: 28.68%


Epoch 11/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [11/80] | Loss: 2.0435 | Clean Acc: 39.41% | Adv Acc: 29.13%


Epoch 12/80: 100%|██████████| 391/391 [01:52<00:00,  3.49it/s]


Epoch [12/80] | Loss: 2.0307 | Clean Acc: 41.41% | Adv Acc: 29.63%


Epoch 13/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [13/80] | Loss: 2.0233 | Clean Acc: 42.93% | Adv Acc: 30.42%


Epoch 14/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [14/80] | Loss: 2.0093 | Clean Acc: 43.27% | Adv Acc: 30.10%


Epoch 15/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [15/80] | Loss: 1.9978 | Clean Acc: 44.07% | Adv Acc: 30.94%


Epoch 16/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [16/80] | Loss: 1.9886 | Clean Acc: 44.34% | Adv Acc: 31.11%


Epoch 17/80: 100%|██████████| 391/391 [01:52<00:00,  3.49it/s]


Epoch [17/80] | Loss: 1.9810 | Clean Acc: 45.55% | Adv Acc: 32.40%


Epoch 18/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [18/80] | Loss: 1.9748 | Clean Acc: 43.70% | Adv Acc: 30.90%


Epoch 19/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [19/80] | Loss: 1.9659 | Clean Acc: 46.29% | Adv Acc: 33.46%


Epoch 20/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [20/80] | Loss: 1.9590 | Clean Acc: 46.00% | Adv Acc: 31.98%


Epoch 21/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [21/80] | Loss: 1.9501 | Clean Acc: 47.52% | Adv Acc: 33.59%


Epoch 22/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [22/80] | Loss: 1.9407 | Clean Acc: 46.76% | Adv Acc: 33.24%


Epoch 23/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [23/80] | Loss: 1.9369 | Clean Acc: 48.12% | Adv Acc: 32.93%


Epoch 24/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [24/80] | Loss: 1.9299 | Clean Acc: 49.06% | Adv Acc: 34.29%


Epoch 25/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [25/80] | Loss: 1.9243 | Clean Acc: 46.94% | Adv Acc: 33.04%


Epoch 26/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [26/80] | Loss: 1.9187 | Clean Acc: 48.90% | Adv Acc: 33.57%


Epoch 27/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [27/80] | Loss: 1.9108 | Clean Acc: 46.75% | Adv Acc: 33.91%


Epoch 28/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [28/80] | Loss: 1.9071 | Clean Acc: 49.43% | Adv Acc: 33.82%


Epoch 29/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [29/80] | Loss: 1.9007 | Clean Acc: 48.83% | Adv Acc: 33.78%


Epoch 30/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [30/80] | Loss: 1.8970 | Clean Acc: 50.45% | Adv Acc: 35.33%


Epoch 31/80: 100%|██████████| 391/391 [01:52<00:00,  3.49it/s]


Epoch [31/80] | Loss: 1.9187 | Clean Acc: 48.46% | Adv Acc: 33.99%


Epoch 32/80: 100%|██████████| 391/391 [01:52<00:00,  3.49it/s]


Epoch [32/80] | Loss: 1.9097 | Clean Acc: 49.27% | Adv Acc: 35.31%


Epoch 33/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [33/80] | Loss: 1.8931 | Clean Acc: 50.78% | Adv Acc: 36.08%


Epoch 34/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [34/80] | Loss: 1.8830 | Clean Acc: 51.22% | Adv Acc: 35.58%


Epoch 35/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [35/80] | Loss: 1.8768 | Clean Acc: 52.18% | Adv Acc: 35.58%


Epoch 36/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [36/80] | Loss: 1.8707 | Clean Acc: 52.29% | Adv Acc: 36.96%


Epoch 37/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [37/80] | Loss: 1.8686 | Clean Acc: 49.36% | Adv Acc: 35.70%


Epoch 38/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [38/80] | Loss: 1.8602 | Clean Acc: 53.28% | Adv Acc: 36.94%


Epoch 39/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [39/80] | Loss: 1.8535 | Clean Acc: 52.47% | Adv Acc: 36.79%


Epoch 40/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [40/80] | Loss: 1.8475 | Clean Acc: 53.47% | Adv Acc: 37.19%


Epoch 41/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [41/80] | Loss: 1.8434 | Clean Acc: 53.38% | Adv Acc: 36.12%


Epoch 42/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [42/80] | Loss: 1.8400 | Clean Acc: 54.27% | Adv Acc: 37.05%


Epoch 43/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [43/80] | Loss: 1.8333 | Clean Acc: 54.66% | Adv Acc: 37.59%


Epoch 44/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [44/80] | Loss: 1.8306 | Clean Acc: 52.51% | Adv Acc: 37.88%


Epoch 45/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [45/80] | Loss: 1.8242 | Clean Acc: 52.18% | Adv Acc: 36.34%


Epoch 46/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [46/80] | Loss: 1.8218 | Clean Acc: 54.19% | Adv Acc: 36.24%


Epoch 47/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [47/80] | Loss: 1.8175 | Clean Acc: 55.08% | Adv Acc: 38.22%


Epoch 48/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [48/80] | Loss: 1.8107 | Clean Acc: 55.29% | Adv Acc: 38.45%


Epoch 49/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [49/80] | Loss: 1.8069 | Clean Acc: 57.04% | Adv Acc: 38.57%


Epoch 50/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [50/80] | Loss: 1.8131 | Clean Acc: 52.51% | Adv Acc: 36.82%


Epoch 51/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [51/80] | Loss: 1.8217 | Clean Acc: 55.50% | Adv Acc: 38.49%


Epoch 52/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [52/80] | Loss: 1.8008 | Clean Acc: 55.90% | Adv Acc: 36.88%


Epoch 53/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [53/80] | Loss: 1.7942 | Clean Acc: 56.59% | Adv Acc: 39.32%


Epoch 54/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [54/80] | Loss: 1.7881 | Clean Acc: 56.08% | Adv Acc: 38.61%


Epoch 55/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [55/80] | Loss: 1.7827 | Clean Acc: 55.50% | Adv Acc: 38.53%


Epoch 56/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [56/80] | Loss: 1.7788 | Clean Acc: 56.49% | Adv Acc: 38.74%


Epoch 57/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [57/80] | Loss: 1.7718 | Clean Acc: 55.58% | Adv Acc: 37.53%


Epoch 58/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [58/80] | Loss: 1.7690 | Clean Acc: 56.76% | Adv Acc: 39.70%


Epoch 59/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [59/80] | Loss: 1.7647 | Clean Acc: 58.18% | Adv Acc: 39.32%


Epoch 60/80: 100%|██████████| 391/391 [01:51<00:00,  3.49it/s]


Epoch [60/80] | Loss: 1.7607 | Clean Acc: 57.73% | Adv Acc: 39.50%


Epoch 61/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [61/80] | Loss: 1.7542 | Clean Acc: 57.74% | Adv Acc: 39.88%


Epoch 62/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [62/80] | Loss: 1.7510 | Clean Acc: 56.05% | Adv Acc: 39.78%


Epoch 63/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [63/80] | Loss: 1.7490 | Clean Acc: 57.34% | Adv Acc: 39.77%


Epoch 64/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [64/80] | Loss: 1.7441 | Clean Acc: 57.31% | Adv Acc: 39.87%


Epoch 65/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [65/80] | Loss: 1.7395 | Clean Acc: 57.04% | Adv Acc: 39.33%


Epoch 66/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [66/80] | Loss: 1.7379 | Clean Acc: 58.19% | Adv Acc: 39.47%


Epoch 67/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [67/80] | Loss: 1.7342 | Clean Acc: 57.12% | Adv Acc: 39.79%


Epoch 68/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [68/80] | Loss: 1.7318 | Clean Acc: 58.11% | Adv Acc: 39.28%


Epoch 69/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [69/80] | Loss: 1.7255 | Clean Acc: 56.59% | Adv Acc: 39.64%


Epoch 70/80: 100%|██████████| 391/391 [01:51<00:00,  3.50it/s]


Epoch [70/80] | Loss: 1.7235 | Clean Acc: 57.52% | Adv Acc: 38.91%


Epoch 71/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [71/80] | Loss: 1.7215 | Clean Acc: 56.85% | Adv Acc: 39.51%


Epoch 72/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [72/80] | Loss: 1.7181 | Clean Acc: 58.29% | Adv Acc: 39.98%


Epoch 73/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [73/80] | Loss: 1.7166 | Clean Acc: 58.08% | Adv Acc: 39.57%


Epoch 74/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [74/80] | Loss: 1.7150 | Clean Acc: 58.09% | Adv Acc: 40.14%


Epoch 75/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [75/80] | Loss: 1.7159 | Clean Acc: 57.81% | Adv Acc: 39.27%


Epoch 76/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [76/80] | Loss: 1.7133 | Clean Acc: 58.10% | Adv Acc: 40.12%


Epoch 77/80:  16%|█▌        | 62/391 [00:17<01:33,  3.51it/s]

Epoch [77/80] | Loss: 1.7100 | Clean Acc: 57.94% | Adv Acc: 39.53%


Epoch 78/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [78/80] | Loss: 1.7093 | Clean Acc: 57.85% | Adv Acc: 40.14%


Epoch 79/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [79/80] | Loss: 1.7122 | Clean Acc: 58.29% | Adv Acc: 39.93%


Epoch 80/80: 100%|██████████| 391/391 [01:51<00:00,  3.51it/s]


Epoch [80/80] | Loss: 1.7068 | Clean Acc: 58.40% | Adv Acc: 39.81%
Training completed and model saved!


In [2]:
torch.save(model.state_dict(), "mart_resnet18_cifar10.pth")
print("Training completed and model saved!")

Training completed and model saved!
