In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# FGSM Attack Function

def fgsm_attack(model, loss_fn, image, label, epsilon):
    image.requires_grad = True
    output = model(image)
    loss = loss_fn(output, label)
    model.zero_grad()
    loss.backward()
    grad = image.grad.data
    perturbed_image = image + epsilon * grad.sign()
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image


#  FGSM + Gaussian Noise
def fgsm_gaussian_attack(model, loss_fn, image, label, epsilon, sigma=0.1):
    image.requires_grad = True
    output = model(image)
    loss = loss_fn(output, label)
    model.zero_grad()
    loss.backward()
    noise = torch.randn_like(image) * sigma
    perturbed_image = image + epsilon * noise
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image


# Load Pretrained ResNet18 Adapted for MNIST

model = models.resnet18(pretrained=True)

# Replace first and last layers to match MNIST input size (1 channel) and 10 classes
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
model.fc = nn.Linear(model.fc.in_features, 10)

model = model.to(device)

transform = transforms.Compose([
    transforms.Resize((224, 224)),    
    transforms.ToTensor(),
])

test_loader = DataLoader(
    datasets.MNIST('.', download=True, train=False, transform=transform),
    batch_size=1,
    shuffle=True
)


# MNIST Train Data for Fine-Tuning

train_loader = DataLoader(
    datasets.MNIST('.', download=True, train=True, transform=transform),
    batch_size=64,
    shuffle=True
)


#Fine-tune on MNIST

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

print("Fine-tuning pretrained ResNet18 on MNIST for 1 epoch...")
model.train()
for epoch in range(1):  
    running_loss = 0.0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    
    print(f"Epoch {epoch+1}, Loss: {running_loss / len(train_loader)}")

print("Fine-tuning complete.")


torch.save(model.state_dict(), 'finetuned_resnet18_mnist.pth')
print("Model saved as 'finetuned_resnet18_mnist.pth'.")

# Evaluate
model.eval()
epsilon = 0.25
clean_correct = 0
fgsm_correct = 0
gaussian_correct = 0
total_samples = 0  

for i, (data, target) in enumerate(test_loader):
    data, target = data.to(device), target.to(device)

    # Clean accuracy
    output = model(data)
    pred = output.argmax(dim=1)
    clean_correct += pred.eq(target).sum().item()

    # FGSM
    adv_data = fgsm_attack(model, loss_fn, data.clone(), target, epsilon)
    adv_output = model(adv_data)
    fgsm_pred = adv_output.argmax(dim=1)
    fgsm_correct += fgsm_pred.eq(target).sum().item()

    # FGSM + Gaussian
    adv_data_gauss = fgsm_gaussian_attack(model, loss_fn, data.clone(), target, epsilon)
    adv_output_gauss = model(adv_data_gauss)
    gauss_pred = adv_output_gauss.argmax(dim=1)
    gaussian_correct += gauss_pred.eq(target).sum().item()

    total_samples += target.size(0)


print(f"\nEvaluation on {total_samples} MNIST samples (ResNet18):")
print(f"Clean Accuracy         : {clean_correct}/{total_samples} = {clean_correct/total_samples*100:.2f}%")
print(f"FGSM Accuracy          : {fgsm_correct}/{total_samples} = {fgsm_correct/total_samples*100:.2f}%")
print(f"FGSM + Gaussian Accuracy: {gaussian_correct}/{total_samples} = {gaussian_correct/total_samples*100:.2f}%")


Fine-tuning pretrained ResNet18 on MNIST for 1 epoch...
Epoch 1, Loss: 0.06693653437562896
Fine-tuning complete.
Model saved as 'finetuned_resnet18_mnist.pth'.

Evaluation on 10000 MNIST samples (ResNet18):
Clean Accuracy         : 9791/10000 = 97.91%
FGSM Accuracy          : 1251/10000 = 12.51%
FGSM + Gaussian Accuracy: 9792/10000 = 97.92%
