In [1]:
# trades_cifar10_fixed.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from torchvision.models import resnet18
import numpy as np
import random
from tqdm import tqdm

# --------------------------
# 1. Reproducibility
# --------------------------
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# --------------------------
# 2. Hyperparameters
# --------------------------
batch_size = 128
learning_rate = 0.01  # smaller for stability
num_epochs = 80
epsilon = 8/255  # L_inf perturbation
alpha = 2/255    # PGD step size
num_steps = 10   # PGD steps
beta = 6.0       # TRADES trade-off

# --------------------------
# 3. CIFAR-10 Data
# --------------------------
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# --------------------------
# 4. Model
# --------------------------
model = resnet18(num_classes=10).to(device)

# --------------------------
# 5. TRADES Loss Function
# --------------------------
def trades_loss(model, x_natural, y):
    model.train()
    batch_size = len(x_natural)

    # Initialize adversarial example
    x_adv = x_natural.detach() + 0.001 * torch.randn_like(x_natural)
    for _ in range(num_steps):
        x_adv.requires_grad_()
        logits_adv = model(x_adv)
        logits_nat = model(x_natural)
        loss_kl = F.kl_div(F.log_softmax(logits_adv, dim=1),
                           F.softmax(logits_nat, dim=1),
                           reduction='batchmean')
        grad = torch.autograd.grad(loss_kl, [x_adv])[0]
        x_adv = x_adv + alpha * torch.sign(grad)
        x_adv = torch.min(torch.max(x_adv, x_natural - epsilon), x_natural + epsilon)
        x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()

    # Compute TRADES loss
    logits = model(x_natural)
    logits_adv = model(x_adv)
    loss_ce = F.cross_entropy(logits, y)
    loss_kl = F.kl_div(F.log_softmax(logits_adv, dim=1),
                       F.softmax(logits, dim=1),
                       reduction='batchmean')
    loss = loss_ce + beta * loss_kl
    return loss

# --------------------------
# 6. Optimizer & Scheduler
# --------------------------
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 25], gamma=0.1)

# --------------------------
# 7. Evaluation Function
# --------------------------
def evaluate(model, loader, adversarial=False):
    model.eval()
    correct = 0
    total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        if adversarial:
            x_adv = images.detach() + 0.001 * torch.randn_like(images)
            for _ in range(num_steps):
                x_adv.requires_grad_()
                logits = model(x_adv)
                loss = F.cross_entropy(logits, labels)
                grad = torch.autograd.grad(loss, [x_adv])[0]
                x_adv = x_adv + alpha * torch.sign(grad)
                x_adv = torch.min(torch.max(x_adv, images - epsilon), images + epsilon)
                x_adv = torch.clamp(x_adv, 0.0, 1.0).detach()
            outputs = model(x_adv)
        else:
            outputs = model(images)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    acc = 100.0 * correct / total
    return acc

# --------------------------
# 8. Training Loop
# --------------------------
print("Starting TRADES training...")
for epoch in range(1, num_epochs+1):
    model.train()
    running_loss = 0.0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = trades_loss(model, images, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
    scheduler.step()
    avg_loss = running_loss / len(train_loader.dataset)
    clean_acc = evaluate(model, test_loader, adversarial=False)
    adv_acc = evaluate(model, test_loader, adversarial=True)
    print(f"Epoch [{epoch}/{num_epochs}] | Loss: {avg_loss:.4f} | Clean Acc: {clean_acc:.2f}% | Adv Acc: {adv_acc:.2f}%")

# --------------------------
# 9. Save Model
# --------------------------
torch.save(model.state_dict(), "trades_resnet18_cifar10_fixed.pth")
print("Training completed and model saved!")


Using device: cuda
Starting TRADES training...


Epoch 1/80: 100%|██████████| 391/391 [02:20<00:00,  2.77it/s]


Epoch [1/80] | Loss: 2.2856 | Clean Acc: 34.00% | Adv Acc: 19.72%


Epoch 2/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [2/80] | Loss: 2.0289 | Clean Acc: 36.84% | Adv Acc: 18.69%


Epoch 3/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [3/80] | Loss: 1.9746 | Clean Acc: 39.66% | Adv Acc: 20.85%


Epoch 4/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [4/80] | Loss: 1.9338 | Clean Acc: 40.86% | Adv Acc: 20.01%


Epoch 5/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [5/80] | Loss: 1.9055 | Clean Acc: 44.12% | Adv Acc: 23.62%


Epoch 6/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [6/80] | Loss: 1.8855 | Clean Acc: 45.87% | Adv Acc: 24.65%


Epoch 7/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [7/80] | Loss: 1.8555 | Clean Acc: 45.23% | Adv Acc: 24.76%


Epoch 8/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [8/80] | Loss: 1.8447 | Clean Acc: 46.61% | Adv Acc: 24.17%


Epoch 9/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [9/80] | Loss: 1.8334 | Clean Acc: 46.54% | Adv Acc: 23.72%


Epoch 10/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [10/80] | Loss: 1.8143 | Clean Acc: 46.15% | Adv Acc: 23.39%


Epoch 11/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [11/80] | Loss: 1.8093 | Clean Acc: 46.60% | Adv Acc: 24.78%


Epoch 12/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [12/80] | Loss: 1.7909 | Clean Acc: 50.05% | Adv Acc: 25.71%


Epoch 13/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [13/80] | Loss: 1.7823 | Clean Acc: 47.64% | Adv Acc: 27.46%


Epoch 14/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [14/80] | Loss: 1.7694 | Clean Acc: 50.97% | Adv Acc: 25.90%


Epoch 15/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [15/80] | Loss: 1.7609 | Clean Acc: 51.21% | Adv Acc: 27.57%


Epoch 16/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [16/80] | Loss: 1.7217 | Clean Acc: 53.09% | Adv Acc: 28.33%


Epoch 17/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [17/80] | Loss: 1.7149 | Clean Acc: 53.83% | Adv Acc: 28.07%


Epoch 18/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [18/80] | Loss: 1.7111 | Clean Acc: 52.25% | Adv Acc: 28.23%


Epoch 19/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [19/80] | Loss: 1.7095 | Clean Acc: 53.79% | Adv Acc: 27.40%


Epoch 20/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [20/80] | Loss: 1.7079 | Clean Acc: 54.52% | Adv Acc: 27.69%


Epoch 21/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [21/80] | Loss: 1.7054 | Clean Acc: 52.73% | Adv Acc: 27.20%


Epoch 22/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [22/80] | Loss: 1.7044 | Clean Acc: 54.08% | Adv Acc: 28.24%


Epoch 23/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [23/80] | Loss: 1.7016 | Clean Acc: 54.32% | Adv Acc: 28.84%


Epoch 24/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [24/80] | Loss: 1.6990 | Clean Acc: 52.65% | Adv Acc: 26.95%


Epoch 25/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [25/80] | Loss: 1.6984 | Clean Acc: 53.87% | Adv Acc: 27.99%


Epoch 26/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [26/80] | Loss: 1.6964 | Clean Acc: 53.83% | Adv Acc: 27.81%


Epoch 27/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [27/80] | Loss: 1.6919 | Clean Acc: 53.09% | Adv Acc: 28.43%


Epoch 28/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [28/80] | Loss: 1.6931 | Clean Acc: 54.13% | Adv Acc: 28.29%


Epoch 29/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [29/80] | Loss: 1.6951 | Clean Acc: 53.18% | Adv Acc: 27.92%


Epoch 30/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [30/80] | Loss: 1.6906 | Clean Acc: 54.20% | Adv Acc: 28.60%


Epoch 31/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [31/80] | Loss: 1.6916 | Clean Acc: 53.99% | Adv Acc: 28.00%


Epoch 32/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [32/80] | Loss: 1.6931 | Clean Acc: 54.23% | Adv Acc: 28.31%


Epoch 33/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [33/80] | Loss: 1.6934 | Clean Acc: 53.81% | Adv Acc: 27.27%


Epoch 34/80: 100%|██████████| 391/391 [02:19<00:00,  2.81it/s]


Epoch [34/80] | Loss: 1.6913 | Clean Acc: 54.78% | Adv Acc: 28.76%


Epoch 35/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [35/80] | Loss: 1.6920 | Clean Acc: 53.86% | Adv Acc: 28.39%


Epoch 36/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [36/80] | Loss: 1.6909 | Clean Acc: 53.51% | Adv Acc: 28.09%


Epoch 37/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [37/80] | Loss: 1.6904 | Clean Acc: 54.27% | Adv Acc: 27.75%


Epoch 38/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [38/80] | Loss: 1.6923 | Clean Acc: 53.87% | Adv Acc: 28.13%


Epoch 39/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [39/80] | Loss: 1.6921 | Clean Acc: 54.90% | Adv Acc: 28.02%


Epoch 40/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [40/80] | Loss: 1.6920 | Clean Acc: 53.93% | Adv Acc: 28.28%


Epoch 41/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [41/80] | Loss: 1.6914 | Clean Acc: 54.64% | Adv Acc: 28.78%


Epoch 42/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [42/80] | Loss: 1.6919 | Clean Acc: 53.90% | Adv Acc: 28.75%


Epoch 43/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [43/80] | Loss: 1.6901 | Clean Acc: 55.32% | Adv Acc: 28.81%


Epoch 44/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [44/80] | Loss: 1.6891 | Clean Acc: 53.62% | Adv Acc: 28.51%


Epoch 45/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [45/80] | Loss: 1.6896 | Clean Acc: 55.02% | Adv Acc: 28.56%


Epoch 46/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [46/80] | Loss: 1.6903 | Clean Acc: 54.49% | Adv Acc: 28.33%


Epoch 47/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [47/80] | Loss: 1.6897 | Clean Acc: 54.04% | Adv Acc: 28.84%


Epoch 48/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [48/80] | Loss: 1.6892 | Clean Acc: 53.83% | Adv Acc: 27.49%


Epoch 49/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [49/80] | Loss: 1.6912 | Clean Acc: 54.18% | Adv Acc: 28.30%


Epoch 50/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [50/80] | Loss: 1.6900 | Clean Acc: 53.99% | Adv Acc: 28.36%


Epoch 51/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [51/80] | Loss: 1.6890 | Clean Acc: 54.83% | Adv Acc: 28.45%


Epoch 52/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [52/80] | Loss: 1.6889 | Clean Acc: 53.96% | Adv Acc: 27.98%


Epoch 53/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [53/80] | Loss: 1.6891 | Clean Acc: 53.78% | Adv Acc: 28.54%


Epoch 54/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [54/80] | Loss: 1.6882 | Clean Acc: 55.25% | Adv Acc: 28.94%


Epoch 55/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [55/80] | Loss: 1.6895 | Clean Acc: 52.84% | Adv Acc: 28.01%


Epoch 56/80: 100%|██████████| 391/391 [02:19<00:00,  2.80it/s]


Epoch [56/80] | Loss: 1.6887 | Clean Acc: 54.35% | Adv Acc: 28.63%


Epoch 57/80:  53%|█████▎    | 207/391 [01:14<01:05,  2.79it/s]


KeyboardInterrupt: 

In [2]:
torch.save(model.state_dict(), "trades_resnet18_cifar10_fixed.pth")
print("Training completed and model saved!")

Training completed and model saved!
