## Load dataset

In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Resize((224, 224)), 
])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.Compose([
    transforms.Grayscale(num_output_channels=3), 
    transforms.ToTensor()
]), download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=False)

test_dataset = datasets.MNIST(root='./data', train=False, transform=transforms.Compose([
    transforms.Grayscale(num_output_channels=3), 
    transforms.ToTensor()
]), download=True)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False)

DATA_SIZE = train_dataset[0][0].size()


## Load model

In [2]:
import os
import asgt
from asgt import model_utils
from asgt import attack_utils
from tqdm import tqdm

DEVICE = torch.device("cuda")
WEIGHTS_PATH = "./alexnet_weights.pth"

model, have_loaded_weights = model_utils.load_model("alexnet", 
                                                    WEIGHTS_PATH, 
                                                    DEVICE, 
                                                    transform)

if have_loaded_weights:
    loss_func = nn.CrossEntropyLoss()
    FGSM_model = attack_utils.FGSM(model, loss_func, eps = 0.25)

    attack_utils.evaluate_model(model, test_loader, DEVICE)
    attack_utils.evaluate_model_robustness(model, test_loader, FGSM_model, DEVICE)

Successfully load weights from "./alexnet_weights.pth"


100%|██████████| 79/79 [00:01<00:00, 41.24it/s]


Accuracy: 96.24%


100%|██████████| 79/79 [00:02<00:00, 26.88it/s]

Robustness accuracy: 22.78%





## Fine-tune model

In [3]:
import torch.optim as optim
import numpy as np

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 1
if not have_loaded_weights:
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for idx, data in tqdm(enumerate(train_loader), 
                          total=train_loader.__len__()):
            batch_X, batch_Y = data
            batch_X:torch.Tensor = batch_X.to(DEVICE)
            batch_Y:torch.Tensor = batch_Y.to(DEVICE)
            
            outputs = model(batch_X)
            loss = criterion(outputs, batch_Y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")
        torch.save(model.state_dict(), WEIGHTS_PATH)
        
        model.eval()
        totall_accuracy = []
        with torch.no_grad():
            for idx, data in tqdm(enumerate(test_loader), 
                          total=test_loader.__len__()):
                batch_X, batch_Y = data
                batch_X:torch.Tensor = batch_X.to(DEVICE)
                batch_Y:torch.Tensor = batch_Y.to(DEVICE)
                
                outputs = model(batch_X)
                predicted = outputs.argmax(1)
                totall_accuracy.append((predicted == batch_Y).float().mean().item())

        totall_accuracy = np.array(totall_accuracy).mean()
        print(f"Epoch [{epoch + 1}/{num_epochs}], Accuracy: {100 * totall_accuracy:.2f}")
        
else:
    print("Already loaded pretrained wweights.")


Already loaded pretrained wweights.


## ASGT

In [4]:
from asgt import *
from captum.attr import Saliency

LEARNING_RATE = 1e-3 # 1e-2 not work here
EPS = [0.20]

K = int(DATA_SIZE[-1] * DATA_SIZE[-2] * 0.3)
LAMBDA = 1.0

robustness_list = []

eps_tqdm = tqdm(EPS[robustness_list.__len__():])
for eps in eps_tqdm:
    print(f"ASGT on {eps}")
    model, have_loaded_weights = model_utils.load_model("alexnet", 
                                                        WEIGHTS_PATH, 
                                                        DEVICE, 
                                                        transform)
    saliency = Saliency(model)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    optimizer.zero_grad()

    def training_forward_func(loss:torch.Tensor):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    asgt = ASGT(
        model=model,
        training_forward_func = training_forward_func,
        loss_func=nn.CrossEntropyLoss(),
        attak_func="FGSM",
        explain_func=partial(saliency.attribute, abs=False),
        eps=eps,
        k=K,
        lam=LAMBDA,
        feature_range=[0.0, 1.0],
        device=torch.device(DEVICE)
    )
    # robustness = asgt.evaluate_model_robustness(test_loader)
    # asgt.evaluate_model(train_loader)
    asgt.evaluate_model(test_loader)
    robustness = asgt.evaluate_model_robustness(test_loader)
    
    num_epochs = 1
    for epoch in range(num_epochs):
        running_loss = asgt.train_one_epoch(train_loader, use_tqdm=False)
        eps_tqdm.set_description(f"Eps {eps},Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")
        
        if hasattr(model, "model"):
            torch.save(model.model.state_dict(), f"./robust_alexnet_{epoch + 1:02d}.pth")
        else: 
            torch.save(model.state_dict(), f"./robust_alexnet_{epoch + 1:02d}.pth")
        
        if os.path.exists(f"./robust_alexnet_{epoch:02d}.pth"):
            os.remove(f"./robust_alexnet_{epoch:02d}.pth")
    
    asgt.evaluate_model(train_loader)
    asgt.evaluate_model(test_loader)
    robustness = asgt.evaluate_model_robustness(test_loader)
        
    robustness_list.append(robustness)

  0%|          | 0/1 [00:00<?, ?it/s]

ASGT on 0.2
Successfully load weights from "./alexnet_weights.pth"


100%|██████████| 79/79 [00:01<00:00, 48.88it/s]


Accuracy: 96.24%


100%|██████████| 79/79 [00:02<00:00, 29.12it/s]


Robustness accuracy: 27.43%


100%|██████████| 469/469 [00:09<00:00, 47.13it/s] | 0/1 [00:43<?, ?it/s]


Accuracy: 97.97%


100%|██████████| 79/79 [00:01<00:00, 41.71it/s]


Accuracy: 98.09%


100%|██████████| 79/79 [00:02<00:00, 28.92it/s]
Eps 0.2,Epoch [1/1], Loss: 0.8815: 100%|██████████| 1/1 [00:58<00:00, 58.43s/it]

Robustness accuracy: 97.13%



