In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader
from resnet_wrapper import ResNetWrapper
import os

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_classes = 10

# --- Hyperparams ---
batch_size = 128
m_replay = 4            # number of replay steps
epsilon = 8 / 255       # max perturbation
alpha = 2 / 255         # step size
epochs = 30

In [3]:
# --- Data ---
transform = T.Compose([
    T.RandomCrop(32, padding=4),
    T.RandomHorizontalFlip(),
    T.ToTensor()
])

trainset = torchvision.datasets.CIFAR10(root='/content/data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)

100%|██████████| 170M/170M [00:03<00:00, 43.2MB/s]


In [4]:
model = ResNetWrapper('resnet18', num_classes).to(device)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(opt, milestones=[15, 25], gamma=0.1)



In [5]:
global_delta = torch.zeros((batch_size, 3, 32, 32), device=device)

for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch_idx, (images, labels) in enumerate(trainloader):
        images, labels = images.to(device), labels.to(device)
        delta = global_delta[:images.size(0)]

        for _ in range(m_replay):
            adv_images = images + delta
            adv_images = torch.clamp(adv_images, 0, 1)

            opt.zero_grad()
            outputs = model(adv_images)
            loss = F.cross_entropy(outputs, labels)
            loss.backward()

            # update model + delta
            grad_sign = delta.grad.detach().sign() if delta.grad is not None else torch.zeros_like(delta)
            delta.data = torch.clamp(delta + alpha * grad_sign, -epsilon, epsilon)
            delta.grad = None

            opt.step()

        total_loss += loss.item()

    scheduler.step()
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {total_loss/len(trainloader):.4f}")



Epoch [1/30], Loss: 1.6299




Epoch [2/30], Loss: 1.2358




Epoch [3/30], Loss: 1.0555




Epoch [4/30], Loss: 0.9665




Epoch [5/30], Loss: 0.9102




Epoch [6/30], Loss: 0.8847




Epoch [7/30], Loss: 0.8550




Epoch [8/30], Loss: 0.8315




Epoch [9/30], Loss: 0.8210
Epoch [10/30], Loss: 0.8061




Epoch [11/30], Loss: 0.8048
Epoch [12/30], Loss: 0.7891




Epoch [13/30], Loss: 0.7870
Epoch [14/30], Loss: 0.7736




Epoch [15/30], Loss: 0.7789
Epoch [16/30], Loss: 0.7624
Epoch [17/30], Loss: 0.6585




Epoch [18/30], Loss: 0.6042




Epoch [19/30], Loss: 0.5694




Epoch [20/30], Loss: 0.5388
Epoch [21/30], Loss: 0.5100
Epoch [22/30], Loss: 0.4946




Epoch [23/30], Loss: 0.4739
Epoch [24/30], Loss: 0.4658




Epoch [25/30], Loss: 0.4502




Epoch [26/30], Loss: 0.4686




Epoch [27/30], Loss: 0.4284




Epoch [28/30], Loss: 0.4093




Epoch [29/30], Loss: 0.3939




Epoch [30/30], Loss: 0.3847


In [6]:
os.makedirs('/content/saved_models', exist_ok=True)
torch.save(model.state_dict(), '/content/saved_models/resnet18_free.pth')