# AI & Security Project

**implementing_defensive_techniques.ipynb**: in this notebook we explore several defensive techniques, to make our model less prone to attacks.


In [9]:
import os
import json
import numpy as np
import torch
import torchvision
from torchvision import transforms
from torch.utils import data
from tqdm.notebook import tqdm

## Step 0: Configurations


In [10]:
# Define constants
DATASET_PATH = r"./data/TinyImageNet-sad/"  # Adjust this to the correct path where the dataset is stored
CHECKPOINT_PATH = r"./models/"
NORM_MEAN = np.array([0.485, 0.456, 0.406])
NORM_STD = np.array([0.229, 0.224, 0.225])
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define transformations
plain_transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=NORM_MEAN, std=NORM_STD)]
)

# Construct the path to the dataset
imagenet_path = os.path.join(DATASET_PATH, "TinyImageNet")
assert os.path.isdir(imagenet_path), (
    f'Could not find the ImageNet dataset at the expected path: "{imagenet_path}". '
    "Please make sure the dataset is downloaded and the path is correct."
)

## Step 1: dataset and libraries


In [11]:
# Load the dataset
dataset = torchvision.datasets.ImageFolder(
    root=imagenet_path, transform=plain_transforms
)
data_loader = data.DataLoader(
    dataset, batch_size=32, shuffle=False, drop_last=False, num_workers=8
)

In [12]:
# Load label names
label_list_path = os.path.join(imagenet_path, "label_list.json")
assert os.path.isfile(
    label_list_path
), f'Label list file not found at "{label_list_path}".'

with open(label_list_path, "r") as f:
    label_names = json.load(f)

## Step 2: model functions

### Utility

In [13]:
def load_model(model_func, trainable=False):
    """Load a pretrained model given its name."""
    model = model_func()
    model = model.to(device) # Ensure the model is on the correct device
    model.eval()

    # Control gradient computation
    for p in model.parameters():
        p.requires_grad = trainable

    return model


def eval_model(dataset_loader, model, img_func=None):
    """Evaluate the model on the given dataset loader."""
    tp, tp_5, counter = 0.0, 0.0, 0.0
    for imgs, labels in tqdm(dataset_loader, desc="Validating...", leave=False):
        imgs = imgs.to(device)
        labels = labels.to(device)
        if img_func is not None:
            imgs = img_func(imgs, labels)
        with torch.no_grad():
            preds = model(imgs)
        tp += (preds.argmax(dim=-1) == labels).sum()
        tp_5 += (preds.topk(5, dim=-1)[1] == labels[..., None]).any(dim=-1).sum()
        counter += preds.shape[0]
    acc = tp.float().item() / counter
    top5 = tp_5.float().item() / counter
    print(f"\tTop-1 error: {(100.0 * (1 - acc)):4.2f}%")
    print(f"\tTop-5 error: {(100.0 * (1 - top5)):4.2f}%")
    return acc, top5

### FGSM Attack

In [14]:
# Define FGSM attack
def fgsm_attack(images, labels, model, epsilon):
    images.requires_grad = True
    outputs = model(images)
    loss = torch.nn.CrossEntropyLoss()(outputs, labels)
    model.zero_grad()
    loss.backward()
    perturbations = epsilon * images.grad.sign()
    adv_images = images + perturbations
    adv_images = torch.clamp(adv_images, 0, 1)  # Keep pixel values in range
    return adv_images

### Adversarial training

In [15]:
# Define adversarial training
def adversarial_training(model_func, train_loader, epsilon):
    model = model_func()
    model = model.to(device)  # Ensure model is on the correct device
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    loss_fn = torch.nn.CrossEntropyLoss()
    
    EPOCHS = 10

    for epoch in range(EPOCHS):
        for imgs, labels in tqdm(train_loader, desc=f"Adversarial Training (epoch={epoch+1}/{EPOCHS})", leave=False):
            imgs, labels = imgs.to(device), labels.to(device)
            # Generate adversarial examples
            adv_imgs = fgsm_attack(imgs, labels, model, epsilon)
            # Combine clean and adversarial examples
            combined_imgs = torch.cat([imgs, adv_imgs])
            combined_labels = torch.cat([labels, labels])
            # Train on combined examples
            optimizer.zero_grad()
            preds = model(combined_imgs)
            loss = loss_fn(preds, combined_labels)
            loss.backward()
            optimizer.step()
    return model

### Evaluate all models

In [16]:
from torchvision import models
from torchvision.models import ResNet18_Weights, ResNet50_Weights, ResNet152_Weights, VGG16_Weights, VGG19_Weights, RegNet_Y_128GF_Weights, ViT_H_14_Weights, ViT_L_16_Weights


list_of_models = {
    "ResNet18": lambda: models.resnet18(weights=ResNet18_Weights.DEFAULT),
    # "ResNet50": lambda: models.resnet50(weights=ResNet50_Weights.DEFAULT),
    # "ResNet152": lambda: models.resnet152(weights=ResNet152_Weights.DEFAULT),
    # "VGG16": lambda: models.vgg16(weights=VGG16_Weights.DEFAULT),
    # "VGG19": lambda: models.vgg19(weights=VGG19_Weights.DEFAULT),
}


# Evaluate all models
epsilon = 0.03
for model_name, model_func in list_of_models.items():
    print(f"\nEvaluating {model_name} (No Attack):")
    model = load_model(model_func)
    _ = eval_model(data_loader, model)

    print(f"\nEvaluating {model_name} (With FGSM Attack):")
    model = load_model(model_func)
    _ = eval_model(data_loader, model, img_func=lambda x, y: fgsm_attack(x, y, model, epsilon))

    print(f"\nEvaluating {model_name} (With Defense - Adversarial Training):")
    adv_model = adversarial_training(model_func, data_loader, epsilon)
    _ = eval_model(data_loader, adv_model, img_func=lambda x, y: fgsm_attack(x, y, adv_model, epsilon))



Evaluating ResNet18 (No Attack):


Validating...:   0%|          | 0/157 [00:00<?, ?it/s]

	Top-1 error: 24.00%
	Top-5 error: 6.76%

Evaluating ResNet18 (With FGSM Attack):


Validating...:   0%|          | 0/157 [00:00<?, ?it/s]

	Top-1 error: 84.86%
	Top-5 error: 66.16%

Evaluating ResNet18 (With Defense - Adversarial Training):


Adversarial Training (epoch=1/10):   0%|          | 0/157 [00:00<?, ?it/s]

Adversarial Training (epoch=2/10):   0%|          | 0/157 [00:00<?, ?it/s]

Adversarial Training (epoch=3/10):   0%|          | 0/157 [00:00<?, ?it/s]

Adversarial Training (epoch=4/10):   0%|          | 0/157 [00:00<?, ?it/s]

Adversarial Training (epoch=5/10):   0%|          | 0/157 [00:00<?, ?it/s]

Adversarial Training (epoch=6/10):   0%|          | 0/157 [00:00<?, ?it/s]

Adversarial Training (epoch=7/10):   0%|          | 0/157 [00:00<?, ?it/s]

Adversarial Training (epoch=8/10):   0%|          | 0/157 [00:00<?, ?it/s]

Adversarial Training (epoch=9/10):   0%|          | 0/157 [00:00<?, ?it/s]

Adversarial Training (epoch=10/10):   0%|          | 0/157 [00:00<?, ?it/s]

Validating...:   0%|          | 0/157 [00:00<?, ?it/s]

	Top-1 error: 5.54%
	Top-5 error: 1.68%
