In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import numpy as np
import matplotlib.pyplot as plt

In [2]:
# Data augmentation and normalization
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
])

testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=512, shuffle=False, num_workers=4)

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
model = models.resnet18(weights=None)
model.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
model.fc = nn.Linear(model.fc.in_features, 100)  # 100 classes for CIFAR-100

checkpoint = torch.load('best_resnet18_cifar100.pth')
model.load_state_dict(checkpoint['state_dict'])
model = model.to(device)
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [None]:
def pgd_attack(model, images, labels, target_labels, epsilon=16/255, alpha=4/255, iters=40):
    """
    PGD attack with proper gradient tracking.
    
    Args:
        model: The model to attack
        images: Clean images
        labels: True labels
        target_labels: Target labels for the attack
        epsilon: Maximum perturbation
        alpha: Step size
        iters: Number of iterations
        
    Returns:
        Adversarial images
    """
    # Ensure we're working with a device-consistent copy
    device = next(model.parameters()).device
    images = images.clone().detach().to(device)
    labels = labels.to(device)
    target_labels = target_labels.to(device)
    criterion = nn.CrossEntropyLoss()

    # Initialize adversarial images
    adv_images = images.clone().detach()

    for i in range(iters):
        # Important: Create a fresh copy that requires gradients
        adv_images = adv_images.detach().requires_grad_(True)
        
        # Forward pass
        outputs = model(adv_images)
        loss = criterion(outputs, target_labels)
        
        # Backward pass
        model.zero_grad()
        loss.backward()
        
        # Get gradients
        grad = adv_images.grad.detach()
        
        # Update adversarial images
        with torch.no_grad():
            adv_images = adv_images - alpha * grad.sign()  # Perturb toward target
            delta = torch.clamp(adv_images - images, min=-epsilon, max=epsilon)
            adv_images = torch.clamp(images + delta, min=0, max=1)

        # Optional debugging
        if i % 10 == 0:
            print(f"Iteration {i}, Loss: {loss.item():.4f}")

    return adv_images.detach()

In [None]:
def evaluate_full_test(model, testloader, epsilon=16/255, alpha=4/255, iters=20, subset_size=None):
    model.eval()
    clean_correct = 0
    adv_correct = 0
    total = 0

    # If subset_size is specified, limit the evaluation
    if subset_size:
        testloader = torch.utils.data.DataLoader(testset, batch_size=256, shuffle=True, num_workers=4)
        max_batches = subset_size // 256 + 1
    else:
        max_batches = len(testloader)

    with torch.no_grad():
        for i, (images, labels) in enumerate(testloader):
            if i >= max_batches:
                break
            images, labels = images.to(device), labels.to(device)
            total += labels.size(0)

            # Clean accuracy
            outputs = model(images)
            _, predicted = outputs.max(1)
            clean_correct += predicted.eq(labels).sum().item()

            # Generate adversarial examples
            target_labels = (labels + 1) % 100  # Example: shift to next class
            adv_images = pgd_attack(model, images, labels, target_labels, epsilon, alpha, iters)

            # Adversarial accuracy (success if predicts target label)
            adv_outputs = model(adv_images)
            _, adv_predicted = adv_outputs.max(1)
            adv_correct += adv_predicted.eq(target_labels).sum().item()

    clean_acc = 100. * clean_correct / total
    attack_success = 100. * adv_correct / total
    print(f"Clean Accuracy (full test set): {clean_acc:.2f}%")
    print(f"Attack Success Rate: {attack_success:.2f}%")
    return clean_acc, attack_success

In [None]:
clean_acc, attack_success = evaluate_full_test(model, test_loader, epsilon=16/255, alpha=4/255, iters=20, subset_size=1000)

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [None]:
dataiter = iter(test_loader)
images, labels = next(dataiter)
images, labels = images.to(device), labels.to(device)
target_labels = (labels + 1) % 100
adv_images = pgd_attack(model, images, labels, target_labels)

def imshow(img, title):
    img = img.cpu().numpy().transpose((1, 2, 0))
    img = img * np.array([0.2675, 0.2565, 0.2761]) + np.array([0.5071, 0.4867, 0.4408])  # Denormalize
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    plt.title(title)
    plt.show()

# Show one example
imshow(images[0], f"Original: {testset.classes[labels[0]]}")
imshow(adv_images[0], f"Adversarial: Predicted {testset.classes[target_labels[0]]}")