## Load dataset

In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader

# Number of samples per mini-batch for both loaders below.
BATCH_SIZE = 256

# Batch-level preprocessing: upscale to 224 and normalize each channel.
# It is NOT attached to the datasets below (they only apply ToTensor); it is
# passed to the model-loading helpers elsewhere in this file instead.
# NOTE(review): the mean/std values look like the commonly published CLIP
# preprocessing statistics -- confirm that this is intended for ResNet-18.
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

# Disabled alternative pipeline (`K` is presumably kornia, which is not
# imported in the visible code):
# transform = transforms.Compose([
#     K.geometry.Resize((224, 224)),
#     K.enhance.Normalize(mean=torch.tensor((0.48145466, 0.4578275, 0.40821073)), 
#                         std=torch.tensor((0.26862954, 0.26130258, 0.27577711)))
# ])

# CIFAR-10 train/test splits, converted to tensors only; the train split is
# constructed first, then the test split.
train_dataset, test_dataset = (
    datasets.CIFAR10(
        root='/home/ksas/Public/datasets/cifar10_concept_bank',
        train=is_train_split,
        transform=transforms.ToTensor(),
        download=True,
    )
    for is_train_split in (True, False)
)

# NOTE(review): neither loader shuffles, including the training loader --
# confirm that a fixed sample order during training is intentional.
train_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=False)
    for split in (train_dataset, test_dataset)
)

# Size of a single (un-batched) training image tensor.
DATA_SIZE = train_dataset[0][0].shape

## Load model

In [None]:
import os
from tqdm import tqdm

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first `.to(DEVICE)` call.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint file holding the fine-tuned ResNet-18 state dict.
WEIGHTS_PATH = "./resnet18_weights.pth"
# Module-level default only; load_model() tracks and returns its own local flag
# and never modifies this global.
have_loaded_weights = False

def load_model(transform:nn.Module=None):
    """Build a 10-class ResNet-18 and restore fine-tuned weights if present.

    The final fully connected layer is replaced by a fresh 10-way linear
    layer and the network is moved to ``DEVICE``. If ``WEIGHTS_PATH`` exists,
    its state dict is loaded into the bare network (before any wrapping).

    Args:
        transform: Optional preprocessing callable. When given, the network
            is wrapped in ``ComposedModel`` so ``transform`` runs on every
            input before the network. Any callable works; it does not have
            to be an ``nn.Module``.

    Returns:
        A ``(model, have_loaded_weights)`` tuple, where the flag is ``True``
        only if a checkpoint was found at ``WEIGHTS_PATH`` and loaded.
    """
    checkpoint_exists = os.path.exists(WEIGHTS_PATH)

    # Only fetch the ImageNet-pretrained weights when they will actually be
    # used: a local checkpoint is loaded strictly below and therefore
    # overwrites every parameter and buffer anyway.
    initial_weights = None if checkpoint_exists else models.ResNet18_Weights.DEFAULT
    model = resnet18(weights=initial_weights)
    model.fc = nn.Linear(model.fc.in_features, 10)
    model = model.float().to(DEVICE)

    # Local flag; the module-level variable of the same name is not touched.
    have_loaded_weights = False
    if checkpoint_exists:
        # map_location keeps loading robust when the checkpoint was saved on a
        # different device than the one currently in use.
        state_dict = torch.load(WEIGHTS_PATH, map_location=DEVICE)
        model.load_state_dict(state_dict)
        have_loaded_weights = True
        print(f"Successfully load weights from \"{WEIGHTS_PATH}\"")

    if transform is not None:
        model = ComposedModel(model, transform)

    return model, have_loaded_weights

class ComposedModel(nn.Module):
    """Network wrapper that preprocesses its input before the forward pass.

    The wrapped network stays reachable as ``self.model`` and the
    preprocessing step as ``self.compose``; other code in this file checks
    for the ``model`` attribute to save the bare network's state dict.
    """

    def __init__(self, model: nn.Module, compose: nn.Module):
        """Store the network and the preprocessing callable.

        Args:
            model: Network that receives the preprocessed input.
            compose: Callable applied to the raw input first.
        """
        super().__init__()
        self.model = model
        self.compose = compose

    def forward(self, x):
        """Return ``model(compose(x))``."""
        return self.model(self.compose(x))

## Fine-tune model

In [None]:
import torch.optim as optim
import numpy as np

def _evaluate_accuracy(model, loader):
    """Return the fraction of samples in ``loader`` that ``model`` classifies correctly.

    Correct predictions and sample counts are accumulated over all batches,
    so a smaller final batch does not get the same weight as a full one.
    Returns 0.0 for an empty loader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_X, batch_Y in tqdm(loader, total=len(loader)):
            batch_X = batch_X.to(DEVICE)
            batch_Y = batch_Y.to(DEVICE)

            predicted = model(batch_X).argmax(1)
            correct += (predicted == batch_Y).sum().item()
            total += batch_Y.size(0)
    return correct / total if total else 0.0


def finetune_model():
    """Fine-tune the ResNet-18 on the training loader unless a checkpoint exists.

    Loads the model through ``load_model(transform)``. If weights were
    restored from ``WEIGHTS_PATH``, only a message is printed. Otherwise the
    model is trained for 20 epochs with Adam (lr=1e-4) and cross-entropy
    loss; after every epoch the bare network's state dict is written to
    ``WEIGHTS_PATH`` and the test-set accuracy is printed.

    Returns:
        None.
    """
    model, have_loaded_weights = load_model(transform)

    if have_loaded_weights:
        print("Already loaded pretrained wweights.")
        return

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    num_epochs = 20

    # When load_model wrapped the network (ComposedModel exposes it as
    # `.model`), save the inner network only: load_model restores the
    # checkpoint into the bare network before wrapping it.
    network = model.model if hasattr(model, "model") else model

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for batch_X, batch_Y in tqdm(train_loader, total=len(train_loader)):
            batch_X = batch_X.to(DEVICE)
            batch_Y = batch_Y.to(DEVICE)

            outputs = model(batch_X)
            loss = criterion(outputs, batch_Y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")

        # Checkpoint after every epoch, overwriting the previous one.
        torch.save(network.state_dict(), WEIGHTS_PATH)

        accuracy = _evaluate_accuracy(model, test_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Accuracy: {100 * accuracy:.2f}")

# Run fine-tuning now; training is skipped when load_model() reports that an
# existing checkpoint was loaded.
finetune_model()

## ASGT (Adversarial Saliency Guided Training)

In [None]:
from asgt import *
import asgt.model_utils as model_utils
import asgt.attack_utils as attack_utils
from captum.attr import Saliency

# ASGT sweep: for every perturbation budget in EPS, reload the fine-tuned
# model, train it with `robust_training` for a fixed number of epochs, keep
# only the newest checkpoint on disk, and record the last robustness value.
# `robust_training` is not imported by name here -- presumably it comes from
# `from asgt import *`; its exact semantics are not visible in this file.

# Adam learning rate used for robust training.
LEARNING_RATE = 1e-4
# Full sweep (disabled); only a single budget is currently run.
# EPS = [0.005, 0.015, 0.020, 0.025, 0.030, 0.035, 0.040]
EPS = [0.025]

# 10% of the product of the last two dimensions of one training sample
# (presumably the pixel count H*W); passed to robust_training as `k`.
K = int(DATA_SIZE[-2] * DATA_SIZE[-1] * 0.1)
# Passed to robust_training as `lam`.
LAMBDA = 1.0

# One entry is appended per eps value after its training finishes.
robustness_list = []



# The slice skips eps values that already have an entry in robustness_list.
# NOTE(review): robustness_list is reset to [] just above, so as written the
# slice always covers all of EPS -- confirm whether resuming was intended.
eps_tqdm = tqdm(EPS[robustness_list.__len__():])
for eps in eps_tqdm:
    num_epochs = 20
    eps_tqdm.set_description(f"Eps {eps},Epoch [0/{num_epochs}], Loss: NaN")
    
    # Uses the project's model_utils.load_model (not the load_model defined
    # earlier in this file); presumably restores weights from WEIGHTS_PATH.
    model, have_loaded_weights = model_utils.load_model("resnet18", 
                                                    WEIGHTS_PATH, 
                                                    DEVICE, 
                                                    transform)
    
    # Captum gradient-saliency explainer over the (possibly wrapped) model.
    saliency = Saliency(model)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    optimizer.zero_grad()
    # Optimization step handed to robust_training: clear gradients,
    # back-propagate the given loss, update the parameters.
    def training_forward_func(loss:torch.Tensor):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    loss_func=nn.CrossEntropyLoss()
    # Disabled alternative attack:
    # attak_func = attack_utils.FGSM(model, loss_func, eps = 0.25)
    # PGD attack from the project's attack_utils; `alpha` and `epoch` are
    # presumably the step size and the number of iterations -- confirm.
    attak_func = attack_utils.PGD(model, loss_func, alpha=0.01, eps=eps, epoch=30)

    # NOTE(review): `partial` is not imported explicitly in the visible code;
    # presumably it is re-exported by `from asgt import *` -- confirm.
    # abs=False asks Captum for signed gradients instead of absolute values.
    asgt = robust_training(
        model=model,
        training_forward_func = training_forward_func,
        loss_func=loss_func,
        attak_func=attak_func,
        explain_func=partial(saliency.attribute, abs=False),
        eps=eps,
        k=K,
        lam=LAMBDA,
        feature_range=[0.0, 1.0],
        robust_loss_func = "adversarial_saliency_guided_training",
        device=torch.device(DEVICE)
    )
    # Baseline evaluation before any robust training; this `robustness`
    # value is overwritten inside the epoch loop below.
    # robustness = asgt.evaluate_model_robustness(test_loader)
    # asgt.evaluate_model(train_loader)
    asgt.evaluate_model(test_loader)
    robustness = asgt.evaluate_model_robustness(test_loader)
    
    
    for epoch in range(num_epochs):
        running_loss = asgt.train_one_epoch(train_loader, use_tqdm=False)
        eps_tqdm.set_description(f"Eps {eps},Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")
        
        # Save the bare network when the model is a wrapper exposing `.model`.
        # NOTE(review): the file name does not contain eps, so checkpoints
        # from different eps values overwrite each other.
        if hasattr(model, "model"):
            torch.save(model.model.state_dict(), f"./robust_resnet18_{epoch + 1:02d}.pth")
        else: 
            torch.save(model.state_dict(), f"./robust_resnet18_{epoch + 1:02d}.pth")
        
        # Remove the previous epoch's checkpoint so only the newest remains.
        if os.path.exists(f"./robust_resnet18_{epoch:02d}.pth"):
            os.remove(f"./robust_resnet18_{epoch:02d}.pth")
    
        asgt.evaluate_model(train_loader)
        asgt.evaluate_model(test_loader)
        robustness = asgt.evaluate_model_robustness(test_loader)
        
    # Record the value measured after the final epoch for this eps.
    robustness_list.append(robustness)