In [85]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

import torchvision
import torchvision.transforms as transforms

import time
import datetime
import os
import torchvision.models as models


# Load the ResNet-18 model pretrained on ImageNet
resnet18 = models.resnet18(pretrained=True)

In [86]:
!pip install wandb
import wandb
wandb.login()





True

In [87]:
# Load in dataset: CIFAR-10

%matplotlib inline

import torchvision
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


train_transforms = [
    transforms.RandomCrop(32, padding=4),  # Data Augmentation: Random crop with padding
    transforms.RandomHorizontalFlip(),    # Data Augmentation: Random horizontal flip
    transforms.ToTensor(),                # Convert PIL.Image to PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize
]

val_transforms = [
    transforms.ToTensor(),                # Convert PIL.Image to PyTorch tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Normalize
]

def get_datasets(train_transforms=(), val_transforms=()):
    r"""
    Returns the CIFAR-10 training and validation datasets with corresponding
    transforms.

    `*_transforms` represent optional transformations, e.g., conversion to
    PyTorch tensors, preprocessing, etc.
    """
    train_set = torchvision.datasets.CIFAR10(
        './data', train=True, download=True,
        transform=torchvision.transforms.Compose(train_transforms))
    val_set = torchvision.datasets.CIFAR10(
        './data', train=False, download=True,
        transform=torchvision.transforms.Compose(val_transforms))
    return train_set, val_set

train_set, val_set = get_datasets(train_transforms, val_transforms)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=100, shuffle=False, num_workers=4)

print(f"Training set size: {len(train_set)}")
print(f"Validation set size: {len(val_set)}")

class_names = train_set.classes

print(f'CIFAR-10 classes: {class_names}')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

Files already downloaded and verified
Files already downloaded and verified
Training set size: 50000
Validation set size: 10000
CIFAR-10 classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [88]:
from google.colab import drive
drive.mount('/content/drive')

model_path = '/content/drive/MyDrive/MIT/6.7960 Deep Learning/models/resnet18_basic_trained.pt'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [92]:
# Modify the Fully Connected Layer for CIFAR-10 (10 classes)
num_classes = 10
resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)

# Move Model to Device
resnet18 = resnet18.to(device)

# Define Loss Function and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet18.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Set up a Wandb instance
wandb.init(project="layer-freezing-adversarial-training", name="resnet18-basic-training")
table = wandb.Table(columns=["epoch", "train_loss", "val_loss", "train_accuracy", "val_accuracy"])

# Training Function
def train(model, loader, criterion, optimizer, epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    loss = running_loss / len(loader)
    accuracy = 100. * correct / total
    print(f"Epoch {epoch}: Loss = {loss:.4f}, "
          f"Accuracy = {accuracy:.2f}%")
    return loss, accuracy

# Validation Function
def evaluate(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    loss = running_loss / len(loader)
    accuracy = 100. * correct / total
    print(f"Validation Loss = {loss:.4f}, "
          f"Accuracy = {accuracy:.2f}%")
    return loss, accuracy

# Training Loop: takes about 9 minutes with 20 epochs
num_epochs = 20
best_accuracy = 0

for epoch in range(1, num_epochs + 1):
    train_loss, train_accuracy = train(resnet18, train_loader, criterion, optimizer, epoch)
    val_loss, val_accuracy = evaluate(resnet18, val_loader, criterion)

    # Save the best model
    if val_accuracy > best_accuracy:
        best_accuracy = val_accuracy
        torch.save(resnet18, model_path)
        print("Model Saved!")

    wandb.log({
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "train_accuracy": train_accuracy,
        "val_accuracy": val_accuracy
    })
    table.add_data(epoch, train_loss, val_loss, train_accuracy, val_accuracy)

print(f"Best Validation Accuracy: {best_accuracy:.2f}%")
wandb.log({"best_accuracy": best_accuracy})
wandb.log({"metrics_table": table})
print(table)

wandb.finish()

Epoch 1: Loss = 1.0572, Accuracy = 63.45%
Validation Loss = 0.7645, Accuracy = 73.48%
Model Saved!
Epoch 2: Loss = 0.7068, Accuracy = 75.76%
Validation Loss = 0.6469, Accuracy = 77.83%
Model Saved!
Epoch 3: Loss = 0.6228, Accuracy = 78.61%
Validation Loss = 0.6290, Accuracy = 78.82%
Model Saved!
Epoch 4: Loss = 0.5659, Accuracy = 80.45%
Validation Loss = 0.5738, Accuracy = 80.39%
Model Saved!
Epoch 5: Loss = 0.5314, Accuracy = 81.54%
Validation Loss = 0.7124, Accuracy = 77.50%
Epoch 6: Loss = 0.4954, Accuracy = 82.75%
Validation Loss = 0.5308, Accuracy = 82.51%
Model Saved!
Epoch 7: Loss = 0.4792, Accuracy = 83.24%
Validation Loss = 0.6052, Accuracy = 79.86%
Epoch 8: Loss = 0.4493, Accuracy = 84.50%
Validation Loss = 0.5158, Accuracy = 83.26%
Model Saved!
Epoch 9: Loss = 0.4354, Accuracy = 84.86%
Validation Loss = 0.4930, Accuracy = 83.28%
Model Saved!
Epoch 10: Loss = 0.4218, Accuracy = 85.33%
Validation Loss = 0.5217, Accuracy = 82.61%
Epoch 11: Loss = 0.4017, Accuracy = 86.09%
Valid

0,1
best_accuracy,▁
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
train_accuracy,▁▄▅▆▆▆▆▇▇▇▇▇▇▇██████
train_loss,█▅▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁
val_accuracy,▁▄▄▅▄▇▅▇▇▇▇█▇██▇██▆▇
val_loss,█▅▅▃▇▂▄▂▁▂▂▁▃▁▁▂▁▁▃▂

0,1
best_accuracy,84.45
epoch,20.0
train_accuracy,88.67
train_loss,0.32405
val_accuracy,82.69
val_loss,0.5289


In [None]:
!pip install torchattacks

import torchattacks
from torchattacks import PGD



In [89]:
# Methods for Training and Validating with an Adversarial Training scheme; Consistent across all experiments

# Number of Batches: 391

# Training function with adversarial examples
def adversarial_train(epoch, model):
    print(f'\n[ Train epoch: {epoch} ]')
    model.train()
    train_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        inputs, targets = inputs.to(device), targets.to(device)

        # Generate adversarial examples
        adv_inputs = adversary(inputs, targets)

        # Forward pass on adversarial examples
        optimizer.zero_grad()
        outputs = model(adv_inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Statistics
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        if batch_idx % 10 == 0:
            print(f'Batch {batch_idx}: Loss = {loss.item():.4f}, Accuracy = {100. * correct / total:.2f}%')

    total_loss = train_loss / len(train_loader)
    total_accuracy = 100. * correct / total
    print(f'Epoch {epoch}: Total Loss = {total_loss:.4f}, Total Accuracy = {total_accuracy:.2f}%')
    return total_loss, total_accuracy

# Testing function for adversarial and clean examples
def adversarial_test(epoch, model):
    print(f'\n[ Test epoch: {epoch} ]')
    model.eval()
    benign_loss = 0.0
    adv_loss = 0.0
    benign_correct = 0
    adv_correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(val_loader):
        with torch.no_grad():
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)

            # Test on benign examples
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            benign_loss += loss.item()
            _, predicted = outputs.max(1)
            benign_correct += predicted.eq(targets).sum().item()

        adv_inputs = inputs.clone().detach().requires_grad_(True)
        adv_inputs = adversary(adv_inputs, targets)
        with torch.no_grad():
            # Test on adversarial examples
            adv_outputs = model(adv_inputs)
            loss = criterion(adv_outputs, targets)
            adv_loss += loss.item()
            _, predicted = adv_outputs.max(1)
            adv_correct += predicted.eq(targets).sum().item()

            if batch_idx % 10 == 0:
                print(f'Batch {batch_idx}: Benign Loss = {loss.item():.4f}, Adversarial Loss = {loss.item():.4f}')

    benign_accuracy = 100. * benign_correct / total
    adv_accuracy = 100. * adv_correct / total
    benign_loss = benign_loss / len(val_loader)
    adv_loss = adv_loss / len(val_loader)
    print(f'Epoch {epoch}: Benign Accuracy = {benign_accuracy:.2f}%, Adversarial Accuracy = {adv_accuracy:.2f}%')
    print(f'Benign Loss = {benign_loss:.4f}, Adversarial Loss = {adv_loss:.4f}')
    return benign_accuracy, adv_accuracy, benign_loss, adv_loss


In [95]:
# Methods for Evaluating Benign and Adversarial Accuracy after Training

def evaluate(model):
    # Initialize PGD attack
    attack = torchattacks.PGD(model, eps=0.03, alpha=0.01, steps=40)

    # Generate adversarial examples
    # Get one batch of test data
    inputs, labels = next(iter(val_loader))

    # Move to the same device as the model (e.g., GPU if available)
    inputs, labels = inputs.to(device), labels.to(device)

    with torch.no_grad():
        benign_outputs = model(inputs)  # Forward pass on clean inputs
        _, benign_predicted = benign_outputs.max(1)  # Get predictions
        benign_accuracy = (benign_predicted == labels).float().mean().item() * 100
        print(f"Benign Accuracy: {benign_accuracy:.2f}%")

    adv_inputs = attack(inputs, labels)

    # Evaluate the model on adversarial examples
    with torch.no_grad():
      outputs = model(adv_inputs)
      _, predicted = outputs.max(1)
      adv_accuracy = (predicted == labels).float().mean().item() * 100
      print(f"Adversarial Accuracy: {adv_accuracy:.2f}%")

    return benign_accuracy, adv_accuracy

In [None]:
# EXPERIMENT 1: CONTROL

wandb.init(project="layer-freezing-adversarial-training", name="resnet18-control-training")
table = wandb.Table(columns=["epoch", "train_adv_accuracy", "train_adv_loss", "test_benign_accuracy", "test_adv_accuracy", "test_benign_loss", "test_adv_loss"])

# Model and device setup
resnet18_control = torch.load(model_path)

# Reinitialize to ensure no conflicts when model reloaded in
for param in resnet18_control.parameters():
    param.requires_grad = True

resnet18_control = resnet18_control.to(device)
resnet18_control = torch.nn.DataParallel(resnet18_control)
cudnn.benchmark = True

# Define adversary (PGD Attack)
adversary = PGD(resnet18_control, eps=0.03, alpha=0.01, steps=40)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet18_control.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

# Adjust learning rate
def adjust_learning_rate(optimizer, epoch):
    lr = 0.01
    if epoch >= 30:
        lr /= 10
    if epoch >= 40:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Training and testing loop
start_time = time.time()
for epoch in range(0, 25):  # Train for 25 epochs- 2 epochs takes around 8 minutes
    adjust_learning_rate(optimizer, epoch)
    train_adv_accuracy, train_adv_loss = adversarial_train(epoch, resnet18_control) # trained on adversarial examples
    test_benign_accuracy, test_adv_accuracy, test_benign_loss, test_adv_loss = adversarial_test(epoch, resnet18_control)
    wandb.log({
        "epoch": epoch,
        "train_adv_accuracy": train_adv_accuracy,
        "train_adv_loss": train_adv_loss,
        "test_benign_accuracy": test_benign_accuracy,
        "test_adv_accuracy": test_adv_accuracy,
        "test_benign_loss": test_benign_loss,
        "test_adv_loss": test_adv_loss
    })
    table.add_data(epoch, train_adv_accuracy, train_adv_loss, test_benign_accuracy, test_adv_accuracy, test_benign_loss, test_adv_loss)
end_time = time.time()

print(f'Training complete in {end_time - start_time:.2f} seconds')

model_control_path = '/content/drive/MyDrive/MIT/6.7960 Deep Learning/models/resnet18_control_trained.pt'
torch.save(resnet18_control, model_control_path)

wandb.log({"metrics_table": table})

wandb.finish()

  resnet18_control = torch.load(model_path)



[ Train epoch: 0 ]
Batch 0: Loss = 1.7801, Accuracy = 57.81%
Batch 10: Loss = 1.8482, Accuracy = 48.58%
Batch 20: Loss = 1.7701, Accuracy = 45.16%
Batch 30: Loss = 1.5640, Accuracy = 44.88%
Batch 40: Loss = 1.4490, Accuracy = 45.05%
Batch 50: Loss = 1.4562, Accuracy = 45.28%
Batch 60: Loss = 1.5080, Accuracy = 45.39%
Batch 70: Loss = 1.4968, Accuracy = 45.88%
Batch 80: Loss = 1.3792, Accuracy = 46.06%
Batch 90: Loss = 1.3446, Accuracy = 46.69%
Batch 100: Loss = 1.3725, Accuracy = 47.01%
Batch 110: Loss = 1.1665, Accuracy = 47.59%
Batch 120: Loss = 1.3161, Accuracy = 47.82%
Batch 130: Loss = 1.2760, Accuracy = 48.14%
Batch 140: Loss = 1.2977, Accuracy = 48.37%
Batch 150: Loss = 1.1884, Accuracy = 48.72%
Batch 160: Loss = 1.1451, Accuracy = 48.87%
Batch 170: Loss = 1.3081, Accuracy = 49.29%
Batch 180: Loss = 1.2475, Accuracy = 49.77%
Batch 190: Loss = 1.2184, Accuracy = 49.87%
Batch 200: Loss = 1.2804, Accuracy = 50.09%
Batch 210: Loss = 1.1352, Accuracy = 50.27%
Batch 220: Loss = 1.239

0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
test_adv_accuracy,▁▂▄▄▅▅▅▇▆▆▆▇▅▆▆▄▄▆▇██▆▇▇▆
test_adv_loss,▇▆▃▄▂▃▃▁▂▃▃▁▆▄▄█▆▇▂▁▄▆▆▄▆
test_benign_accuracy,█▆▇▆▄▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_benign_loss,▁▁▁▁▂▂▂▂▃▃▂▃▄▃▅▇█▃▅▆▆▄▄▇▅
train_adv_accuracy,█▆▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁
train_adv_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,24.0
test_adv_accuracy,59.51
test_adv_loss,1.18432
test_benign_accuracy,11.02
test_benign_loss,24.33868
train_adv_accuracy,0.78749
train_adv_loss,0.31686


In [None]:
evaluate(resnet18_control)

Benign Accuracy: 10.00%
Adversarial Accuracy: 58.00%


(9.999999403953552, 57.999998331069946)

In [93]:
# EXPERIMENT 2: HALF FROZEN LAYERS

wandb.init(project="layer-freezing-adversarial-training", name="resnet18-half-frozen-training")
table = wandb.Table(columns=["epoch", "train_adv_accuracy", "train_adv_loss", "test_benign_accuracy", "test_adv_accuracy", "test_benign_loss", "test_adv_loss"])

# Model and device setup
resnet18_half_frozen = torch.load(model_path)

for param in resnet18_half_frozen.parameters():
    param.requires_grad = True

resnet18_half_frozen = torch.nn.DataParallel(resnet18_half_frozen)

# Freeze first 9 layers
child_counter = 0
for child in resnet18_half_frozen.module.children():
    child_counter += 1
    if child_counter <= 2:  # Freeze the first 9 layers
        for param in child.parameters():
            param.requires_grad = False

# Check that only the fc layer is trainable
for name, param in resnet18_half_frozen.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

cudnn.benchmark = True

# Define adversary (PGD Attack)
adversary = PGD(resnet18, eps=0.03, alpha=0.01, steps=40)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, resnet18.parameters()),
                      lr=0.01, momentum=0.9, weight_decay=5e-4)

# Adjust learning rate
def adjust_learning_rate(optimizer, epoch):
    lr = 0.01
    if epoch >= 30:
        lr /= 10
    if epoch >= 40:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Training and testing loop
start_time = time.time()
for epoch in range(0, 25):  # Train for 25 epochs
    adjust_learning_rate(optimizer, epoch)
    train_adv_accuracy, train_adv_loss = adversarial_train(epoch, resnet18_half_frozen)
    test_benign_accuracy, test_adv_accuracy, test_benign_loss, test_adv_loss = adversarial_test(epoch, resnet18_half_frozen)
    wandb.log({
        "epoch": epoch,
        "train_adv_accuracy": train_adv_accuracy,
        "train_adv_loss": train_adv_loss,
        "test_benign_accuracy": test_benign_accuracy,
        "test_adv_accuracy": test_adv_accuracy,
        "test_benign_loss": test_benign_loss,
        "test_adv_loss": test_adv_loss
    })
    table.add_data(epoch, train_adv_accuracy, train_adv_loss, test_benign_accuracy, test_adv_accuracy, test_benign_loss, test_adv_loss)
end_time = time.time()

print(f'Training complete in {end_time - start_time:.2f} seconds')

model_frozen_path = '/content/drive/MyDrive/MIT/6.7960 Deep Learning/models/resnet18_half_frozen_trained.pt'
torch.save(resnet18_half_frozen, model_frozen_path)

wandb.log({"metrics_table": table})

wandb.finish()

module.conv1.weight: requires_grad=False
module.bn1.weight: requires_grad=False
module.bn1.bias: requires_grad=False
module.layer1.0.conv1.weight: requires_grad=True
module.layer1.0.bn1.weight: requires_grad=True
module.layer1.0.bn1.bias: requires_grad=True
module.layer1.0.conv2.weight: requires_grad=True
module.layer1.0.bn2.weight: requires_grad=True
module.layer1.0.bn2.bias: requires_grad=True
module.layer1.1.conv1.weight: requires_grad=True
module.layer1.1.bn1.weight: requires_grad=True
module.layer1.1.bn1.bias: requires_grad=True
module.layer1.1.conv2.weight: requires_grad=True
module.layer1.1.bn2.weight: requires_grad=True
module.layer1.1.bn2.bias: requires_grad=True
module.layer2.0.conv1.weight: requires_grad=True
module.layer2.0.bn1.weight: requires_grad=True
module.layer2.0.bn1.bias: requires_grad=True
module.layer2.0.conv2.weight: requires_grad=True
module.layer2.0.bn2.weight: requires_grad=True
module.layer2.0.bn2.bias: requires_grad=True
module.layer2.0.downsample.0.weight: 

  resnet18_half_frozen = torch.load(model_path)


Batch 0: Loss = 1.6823, Accuracy = 51.56%
Batch 10: Loss = 1.8104, Accuracy = 53.27%
Batch 20: Loss = 1.7621, Accuracy = 54.65%
Batch 30: Loss = 1.7359, Accuracy = 55.09%
Batch 40: Loss = 2.2385, Accuracy = 54.55%
Batch 50: Loss = 1.9320, Accuracy = 54.90%
Batch 60: Loss = 1.7048, Accuracy = 55.37%
Batch 70: Loss = 1.7780, Accuracy = 55.09%
Batch 80: Loss = 1.3899, Accuracy = 55.22%
Batch 90: Loss = 1.6629, Accuracy = 55.07%
Batch 100: Loss = 2.0704, Accuracy = 54.92%
Batch 110: Loss = 2.0217, Accuracy = 55.07%
Batch 120: Loss = 1.6926, Accuracy = 55.26%
Batch 130: Loss = 2.2005, Accuracy = 55.19%
Batch 140: Loss = 1.7729, Accuracy = 55.18%
Batch 150: Loss = 2.1875, Accuracy = 55.29%
Batch 160: Loss = 1.7852, Accuracy = 55.34%
Batch 170: Loss = 1.8089, Accuracy = 55.29%
Batch 180: Loss = 1.8891, Accuracy = 55.31%
Batch 190: Loss = 1.9751, Accuracy = 55.32%
Batch 200: Loss = 2.5432, Accuracy = 55.25%
Batch 210: Loss = 1.6100, Accuracy = 55.22%
Batch 220: Loss = 1.6488, Accuracy = 55.25%

0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
test_adv_accuracy,▄▂▄▅▃▅▃▂▅▆▂▃▂▃▅▁▄▆▄▅▄█▃▆▄
test_adv_loss,▆▂▇▄▃▃█▅▅▅▆▄▇▃▇▆▃▅▃▁▅▅▅▃▃
test_benign_accuracy,▂▇▂█▅▄▂▄▂▆▅▅▃▁▃▅▆▃▅▇▄▅▆▅▄
test_benign_loss,▆▂▆▁▃▃▇▄▅▂▃▃▅█▄▃▂▆▃▁▃▃▂▃▃
train_adv_accuracy,▂▄▁▁▆▇▁▂▄▆▂▅▄▃▅▄▃▅▅▄█▅▄▂▂
train_adv_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,24.0
test_adv_accuracy,55.73
test_adv_loss,1.80938
test_benign_accuracy,20.03
test_benign_loss,4.14776
train_adv_accuracy,1.75454
train_adv_loss,0.32405


In [96]:
evaluate(resnet18_half_frozen)

Benign Accuracy: 20.00%
Adversarial Accuracy: 23.00%


(19.999998807907104, 22.999998927116394)

In [97]:
# EXPERIMENT 3: FINETUNING

wandb.init(project="layer-freezing-adversarial-training", name="resnet18-finetuned-training")
table = wandb.Table(columns=["epoch", "train_adv_accuracy", "train_adv_loss", "test_benign_accuracy", "test_adv_accuracy", "test_benign_loss", "test_adv_loss"])

# Model and device setup
resnet18_finetuned = torch.load(model_path)
resnet18_finetuned = torch.nn.DataParallel(resnet18_finetuned)

# Freeze all layers except the fully connected (fc) layer
for name, param in resnet18_finetuned.module.named_parameters():
    if "fc" not in name:  # Freeze all layers except the 'fc' layer
        param.requires_grad = False

# Check that only the fc layer is trainable
for name, param in resnet18_finetuned.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")
cudnn.benchmark = True

# Define adversary (PGD Attack)
adversary = PGD(resnet18, eps=0.03, alpha=0.01, steps=40)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, resnet18_finetuned.parameters()),
                      lr=0.01, momentum=0.9, weight_decay=5e-4)

# Adjust learning rate
def adjust_learning_rate(optimizer, epoch):
    lr = 0.01
    if epoch >= 30:
        lr /= 10
    if epoch >= 40:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Training and testing loop
start_time = time.time()
for epoch in range(0, 25):  # Train for 25 epochs
    adjust_learning_rate(optimizer, epoch)
    train_adv_accuracy, train_adv_loss = adversarial_train(epoch, resnet18_finetuned)
    test_benign_accuracy, test_adv_accuracy, test_benign_loss, test_adv_loss = adversarial_test(epoch, resnet18_finetuned)
    wandb.log({
        "epoch": epoch,
        "train_adv_accuracy": train_adv_accuracy,
        "train_adv_loss": train_adv_loss,
        "test_benign_accuracy": test_benign_accuracy,
        "test_adv_accuracy": test_adv_accuracy,
        "test_benign_loss": test_benign_loss,
        "test_adv_loss": test_adv_loss
    })
    table.add_data(epoch, train_adv_accuracy, train_adv_loss, test_benign_accuracy, test_adv_accuracy, test_benign_loss, test_adv_loss)
end_time = time.time()

print(f'Training complete in {end_time - start_time:.2f} seconds')

model_finetuned_path = '/content/drive/MyDrive/MIT/6.7960 Deep Learning/models/resnet18_finetuned_trained.pt'
torch.save(resnet18_finetuned, model_finetuned_path)

wandb.log({"metrics_table": table})

wandb.finish()

module.conv1.weight: requires_grad=False
module.bn1.weight: requires_grad=False
module.bn1.bias: requires_grad=False
module.layer1.0.conv1.weight: requires_grad=False
module.layer1.0.bn1.weight: requires_grad=False
module.layer1.0.bn1.bias: requires_grad=False
module.layer1.0.conv2.weight: requires_grad=False
module.layer1.0.bn2.weight: requires_grad=False
module.layer1.0.bn2.bias: requires_grad=False
module.layer1.1.conv1.weight: requires_grad=False
module.layer1.1.bn1.weight: requires_grad=False
module.layer1.1.bn1.bias: requires_grad=False
module.layer1.1.conv2.weight: requires_grad=False
module.layer1.1.bn2.weight: requires_grad=False
module.layer1.1.bn2.bias: requires_grad=False
module.layer2.0.conv1.weight: requires_grad=False
module.layer2.0.bn1.weight: requires_grad=False
module.layer2.0.bn1.bias: requires_grad=False
module.layer2.0.conv2.weight: requires_grad=False
module.layer2.0.bn2.weight: requires_grad=False
module.layer2.0.bn2.bias: requires_grad=False
module.layer2.0.dow

  resnet18_finetuned = torch.load(model_path)


Batch 0: Loss = 1.8483, Accuracy = 54.69%
Batch 10: Loss = 1.5559, Accuracy = 55.26%
Batch 20: Loss = 1.6713, Accuracy = 55.28%
Batch 30: Loss = 1.1652, Accuracy = 55.27%
Batch 40: Loss = 1.3751, Accuracy = 55.47%
Batch 50: Loss = 1.0389, Accuracy = 55.42%
Batch 60: Loss = 1.1531, Accuracy = 55.65%
Batch 70: Loss = 1.0491, Accuracy = 55.39%
Batch 80: Loss = 1.1143, Accuracy = 55.73%
Batch 90: Loss = 1.3645, Accuracy = 55.65%
Batch 100: Loss = 1.1071, Accuracy = 55.74%
Batch 110: Loss = 1.0735, Accuracy = 55.83%
Batch 120: Loss = 1.2509, Accuracy = 55.75%
Batch 130: Loss = 1.3187, Accuracy = 55.92%
Batch 140: Loss = 1.2264, Accuracy = 56.08%
Batch 150: Loss = 1.0944, Accuracy = 56.09%
Batch 160: Loss = 1.2988, Accuracy = 56.13%
Batch 170: Loss = 1.3965, Accuracy = 56.11%
Batch 180: Loss = 1.2006, Accuracy = 56.22%
Batch 190: Loss = 1.1142, Accuracy = 56.27%
Batch 200: Loss = 1.1510, Accuracy = 56.33%
Batch 210: Loss = 1.1109, Accuracy = 56.34%
Batch 220: Loss = 1.2594, Accuracy = 56.39%

0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
test_adv_accuracy,▁▁▂▄▄▃▄▅▇▆▅▆▆▃▇▆▅▆▆▇▆▆▆█▇
test_adv_loss,█▇▆▅▅▅▅▃▂▂▃▃▂▅▃▃▃▃▄▁▃▃▃▂▁
test_benign_accuracy,▁▃▄▇▆▇▅▂▄▄▃▃▂▅▄▄▄▃▄▃▅▃▁▃█
test_benign_loss,▃▄▄▁▂▂▄▅▃▃█▅▄▃█▅▃▅▂▄▅▇▇▃▂
train_adv_accuracy,█▄▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁
train_adv_loss,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,24.0
test_adv_accuracy,61.35
test_adv_loss,1.10083
test_benign_accuracy,26.81
test_benign_loss,2.37642
train_adv_accuracy,1.11513
train_adv_loss,0.32405


In [98]:
evaluate(resnet18_finetuned)

Benign Accuracy: 26.00%
Adversarial Accuracy: 26.00%


(25.999999046325684, 25.999999046325684)