## Load dataset

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader

# Preprocessing applied to every CIFAR-10 image: upscale to 224x224, convert to
# grayscale replicated over 3 channels, then convert to a float tensor
# (torchvision's ToTensor produces values in [0, 1]).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
])

# Shared mini-batch size for both loaders.
BATCH_SIZE = 128

train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
# Training batches are shuffled so that successive epochs (and successive
# training runs that reuse this loader) do not always see the samples in the
# same fixed order. The explicitly seeded generator keeps the shuffling
# reproducible after "Restart Kernel -> Run All".
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(0),
)

test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
# Evaluation order does not matter, so the test loader stays unshuffled.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  1%|          | 1114112/170498071 [00:25<2:06:43, 22275.79it/s]

## Load model

In [2]:
import os
from tqdm import tqdm

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(DEVICE)` call.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Checkpoint file for the fine-tuned AlexNet: load_model() reads it when it
# exists, and the fine-tuning loop writes it with torch.save().
WEIGHTS_PATH = "./alex_weights.pth"
# Module-level default; replaced by the flag returned from load_model().
have_loaded_weights = False

def load_model():
    """Build an AlexNet with a 10-way output layer and restore a checkpoint if present.

    The model is created from torchvision's default AlexNet weights, its last
    classifier layer is replaced by a fresh ``nn.Linear(..., 10)``, and it is
    moved to ``DEVICE``. If ``WEIGHTS_PATH`` exists, its state dict is loaded
    on top of that.

    Returns:
        tuple: ``(model, have_loaded_weights)`` where ``have_loaded_weights``
        is ``True`` only when a checkpoint was found at ``WEIGHTS_PATH`` and
        loaded, ``False`` otherwise.
    """
    # Must be initialised here: assigning the name anywhere in this function
    # makes it local, so without this line the `return` below raises
    # UnboundLocalError whenever the checkpoint file does not exist.
    have_loaded_weights = False

    model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
    model.classifier[6] = nn.Linear(model.classifier[6].in_features, 10)
    model = model.to(DEVICE)

    if os.path.exists(WEIGHTS_PATH):
        # map_location remaps the stored tensors onto DEVICE, so a checkpoint
        # written on a GPU can still be restored when DEVICE is the CPU.
        model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=DEVICE))
        have_loaded_weights = True
        print(f"Successfully load weights from \"{WEIGHTS_PATH}\"")
    return model, have_loaded_weights

## Fine-tune model

In [3]:
import torch.optim as optim
import numpy as np

# Build the model; `have_loaded_weights` tells us whether a fine-tuned
# checkpoint was restored, in which case fine-tuning is skipped entirely.
model, have_loaded_weights = load_model()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 1
if not have_loaded_weights:
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        for batch_X, batch_Y in tqdm(train_loader, total=len(train_loader)):
            batch_X = batch_X.to(DEVICE)
            batch_Y = batch_Y.to(DEVICE)

            outputs = model(batch_X)
            loss = criterion(outputs, batch_Y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")
        # Persist after every epoch so a later run can skip fine-tuning.
        torch.save(model.state_dict(), WEIGHTS_PATH)

        # ---- evaluation pass ----
        model.eval()
        num_correct = 0
        num_seen = 0
        with torch.no_grad():
            for batch_X, batch_Y in tqdm(test_loader, total=len(test_loader)):
                batch_X = batch_X.to(DEVICE)
                batch_Y = batch_Y.to(DEVICE)

                outputs = model(batch_X)
                predicted = outputs.argmax(1)
                # Count samples instead of averaging per-batch accuracies: the
                # final batch can be smaller than the others, and a mean of
                # batch means would give its samples too much weight.
                num_correct += (predicted == batch_Y).sum().item()
                num_seen += batch_Y.size(0)

        # Fraction of correctly classified test samples (NaN if the loader is empty).
        totall_accuracy = num_correct / num_seen if num_seen else float("nan")
        print(f"Epoch [{epoch + 1}/{num_epochs}], Accuracy: {100 * totall_accuracy:.2f}")

else:
    print("Already loaded pretrained wweights.")


Successfully load weights from "./alex_weights.pth"
Already loaded pretrained wweights.


## AGST

In [None]:
from agst import *
from captum.attr import Saliency

# Hyper-parameters for the AGST runs below.
# Step size passed to the SGD optimizer.
LEARNING_RATE = 1e-3 # 1e-2 in the paper but does not work here.
# Values swept over in the loop below; each one is passed to AGST as `eps`.
EPS = [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4]
# 30% of 3 * 224 * 224, i.e. of the element count of one preprocessed input image.
# NOTE(review): what AGST does with `k` is defined in the `agst` package — confirm there.
K = int((3 * 224 * 224) * 0.3)
# Passed to AGST as `lam`; presumably a loss-weighting coefficient — TODO confirm against `agst`.
LAMBDA = 1.0


# One independent run per eps value: the model, explainer and optimizer are
# all rebuilt at the top of every iteration.
for eps in EPS:
    print(f"AGST on {eps}")
    # load_model() restores WEIGHTS_PATH when that file exists, so each run
    # then starts from the same fine-tuned weights.
    model, have_loaded_weights = load_model()
    saliency = Saliency(model)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

    def training_forward_func(loss:torch.Tensor):
        """Run one optimizer update for `loss`: zero grads, backpropagate, step.

        Uses the module-level `optimizer`, which is rebound on every iteration
        of the enclosing `for eps in EPS` loop.
        """
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    agst = AGST(
        model=model,
        training_forward_func = training_forward_func,
        loss_func=nn.CrossEntropyLoss(),
        # Keyword is spelled `attak_func` in this call; presumably that matches
        # the parameter name in the `agst` package — verify before renaming.
        attak_func="FGSM",
        explain_func=saliency.attribute,
        eps=eps,
        k=K,
        lam=LAMBDA,
        # Matches the [0, 1] pixel range produced by ToTensor in the data pipeline.
        feature_range=[0.0, 1.0],
        device=torch.device(DEVICE)
    )

    num_epochs = 1
    for epoch in range(num_epochs):
        # The returned value is divided by the number of batches below, so it
        # is treated here as a loss summed over batches — TODO confirm in `agst`.
        running_loss = agst.train_one_epoch(train_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")
        # Disabled checkpointing / accuracy evaluation, kept for reference:
        # torch.save(model.state_dict(), f"./robust_alexnet_{epoch + 1:2d}.pth")
        # agst.evaluate_model(train_loader)
        # agst.evaluate_model(test_loader)
        # The result is stored but not printed or otherwise used in this cell.
        # NOTE(review): presumably evaluate_model_robustness reports its own results — confirm.
        robustness = agst.evaluate_model_robustness(test_loader)

AGST on 0.0
Successfully load weights from "./alex_weights.pth"


 88%|████████▊ | 413/469 [03:02<00:45,  1.24it/s]