## Load dataset

In [1]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torch.utils.data import DataLoader

BATCH_SIZE = 256

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    # transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
])

train_dataset = datasets.CIFAR10(root='/home/ksas/Public/datasets/cifar10_concept_bank', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)

test_dataset = datasets.CIFAR10(root='/home/ksas/Public/datasets/cifar10_concept_bank', train=False, transform=transform, download=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)



Files already downloaded and verified
Files already downloaded and verified


## Load model

In [2]:
import os
from tqdm import tqdm

DEVICE = "cuda"
WEIGHTS_PATH = "./resnet18_weights.pth"
have_loaded_weights = False

def load_model():
    model = resnet18(weights=models.ResNet18_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, 10)
    model = model.to(DEVICE)

    have_loaded_weights = False
    if os.path.exists(WEIGHTS_PATH):
        model.load_state_dict(torch.load(WEIGHTS_PATH))
        have_loaded_weights = True
        print(f"Successfully load weights from \"{WEIGHTS_PATH}\"")
    return model, have_loaded_weights

class ComposedModel(nn.Module):
    def __init__(self, model: nn.Module, compose: nn.Module):
        super().__init__()
        self.model = model
        self.compose = compose

    def forward(self, x):
        x = self.compose(x)
        return self.model(x)

## Fine-tune model

In [3]:
import torch.optim as optim
import numpy as np

model, have_loaded_weights = load_model()
model = ComposedModel(model, 
                      transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)))
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 20
if not have_loaded_weights:
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for idx, data in tqdm(enumerate(train_loader), 
                          total=train_loader.__len__()):
            batch_X, batch_Y = data
            batch_X:torch.Tensor = batch_X.to(DEVICE)
            batch_Y:torch.Tensor = batch_Y.to(DEVICE)
            
            outputs = model(batch_X)
            loss = criterion(outputs, batch_Y)
            
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")
        torch.save(model.model.state_dict(), WEIGHTS_PATH)
        
        model.eval()
        totall_accuracy = []
        with torch.no_grad():
            for idx, data in tqdm(enumerate(test_loader), 
                          total=test_loader.__len__()):
                batch_X, batch_Y = data
                batch_X:torch.Tensor = batch_X.to(DEVICE)
                batch_Y:torch.Tensor = batch_Y.to(DEVICE)
                
                outputs = model(batch_X)
                predicted = outputs.argmax(1)
                totall_accuracy.append((predicted == batch_Y).float().mean().item())

        totall_accuracy = np.array(totall_accuracy).mean()
        print(f"Epoch [{epoch + 1}/{num_epochs}], Accuracy: {100 * totall_accuracy:.2f}")
        
else:
    print("Already loaded pretrained wweights.")


100%|██████████| 196/196 [02:09<00:00,  1.52it/s]


Epoch [1/20], Loss: 0.4041


100%|██████████| 40/40 [00:30<00:00,  1.32it/s]


Epoch [1/20], Accuracy: 93.57


100%|██████████| 196/196 [02:53<00:00,  1.13it/s]


Epoch [2/20], Loss: 0.0902


100%|██████████| 40/40 [00:35<00:00,  1.13it/s]


Epoch [2/20], Accuracy: 94.13


100%|██████████| 196/196 [02:56<00:00,  1.11it/s]


Epoch [3/20], Loss: 0.0203


100%|██████████| 40/40 [00:25<00:00,  1.59it/s]


Epoch [3/20], Accuracy: 94.42


100%|██████████| 196/196 [02:29<00:00,  1.31it/s]


Epoch [4/20], Loss: 0.0054


100%|██████████| 40/40 [00:27<00:00,  1.45it/s]


Epoch [4/20], Accuracy: 94.64


100%|██████████| 196/196 [02:49<00:00,  1.15it/s]


Epoch [5/20], Loss: 0.0022


100%|██████████| 40/40 [00:26<00:00,  1.51it/s]


Epoch [5/20], Accuracy: 94.63


100%|██████████| 196/196 [02:30<00:00,  1.30it/s]


Epoch [6/20], Loss: 0.0013


100%|██████████| 40/40 [00:34<00:00,  1.16it/s]


Epoch [6/20], Accuracy: 94.78


100%|██████████| 196/196 [02:35<00:00,  1.26it/s]


Epoch [7/20], Loss: 0.0009


100%|██████████| 40/40 [00:29<00:00,  1.34it/s]


Epoch [7/20], Accuracy: 94.80


100%|██████████| 196/196 [02:25<00:00,  1.35it/s]


Epoch [8/20], Loss: 0.0007


100%|██████████| 40/40 [00:32<00:00,  1.25it/s]


Epoch [8/20], Accuracy: 94.86


100%|██████████| 196/196 [02:38<00:00,  1.24it/s]


Epoch [9/20], Loss: 0.0005


100%|██████████| 40/40 [00:31<00:00,  1.28it/s]


Epoch [9/20], Accuracy: 94.87


100%|██████████| 196/196 [02:45<00:00,  1.18it/s]


Epoch [10/20], Loss: 0.0004


100%|██████████| 40/40 [00:32<00:00,  1.23it/s]


Epoch [10/20], Accuracy: 94.88


100%|██████████| 196/196 [02:33<00:00,  1.28it/s]


Epoch [11/20], Loss: 0.0003


100%|██████████| 40/40 [00:28<00:00,  1.41it/s]


Epoch [11/20], Accuracy: 94.91


100%|██████████| 196/196 [02:42<00:00,  1.21it/s]


Epoch [12/20], Loss: 0.0003


100%|██████████| 40/40 [00:20<00:00,  1.99it/s]


Epoch [12/20], Accuracy: 94.89


100%|██████████| 196/196 [02:18<00:00,  1.41it/s]


Epoch [13/20], Loss: 0.0002


100%|██████████| 40/40 [00:17<00:00,  2.30it/s]


Epoch [13/20], Accuracy: 94.87


100%|██████████| 196/196 [02:35<00:00,  1.26it/s]


Epoch [14/20], Loss: 0.0002


100%|██████████| 40/40 [00:21<00:00,  1.86it/s]


Epoch [14/20], Accuracy: 94.86


100%|██████████| 196/196 [02:37<00:00,  1.24it/s]


Epoch [15/20], Loss: 0.0002


100%|██████████| 40/40 [00:30<00:00,  1.31it/s]


Epoch [15/20], Accuracy: 94.86


100%|██████████| 196/196 [02:25<00:00,  1.35it/s]


Epoch [16/20], Loss: 0.0001


100%|██████████| 40/40 [00:30<00:00,  1.30it/s]


Epoch [16/20], Accuracy: 94.84


100%|██████████| 196/196 [03:10<00:00,  1.03it/s]


Epoch [17/20], Loss: 0.0001


100%|██████████| 40/40 [00:31<00:00,  1.26it/s]


Epoch [17/20], Accuracy: 94.84


100%|██████████| 196/196 [02:27<00:00,  1.33it/s]


Epoch [18/20], Loss: 0.0001


100%|██████████| 40/40 [00:34<00:00,  1.16it/s]


Epoch [18/20], Accuracy: 94.85


100%|██████████| 196/196 [02:37<00:00,  1.24it/s]


Epoch [19/20], Loss: 0.0001


100%|██████████| 40/40 [00:24<00:00,  1.63it/s]


Epoch [19/20], Accuracy: 94.82


100%|██████████| 196/196 [02:33<00:00,  1.28it/s]


Epoch [20/20], Loss: 0.0001


100%|██████████| 40/40 [00:36<00:00,  1.10it/s]

Epoch [20/20], Accuracy: 94.80





## ASGT

In [None]:
from asgt import *
from captum.attr import Saliency

LEARNING_RATE = 1e-4
EPS = [0.025]
K = int((3 * 224 * 224) * 0.1)
LAMBDA = 1.0


for eps in EPS:
    print(f"ASGT on {eps}")
    model, have_loaded_weights = load_model()
    model = ComposedModel(model, 
                      transforms.Normalize((0.48145466, 0.4578275, 0.40821073), 
                                           (0.26862954, 0.26130258, 0.27577711)))
    saliency = Saliency(model)
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE)

    def training_forward_func(loss:torch.Tensor):
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    asgt = ASGT(
        model=model,
        training_forward_func = training_forward_func,
        loss_func=nn.CrossEntropyLoss(),
        attak_func="FGSM",
        explain_func=saliency.attribute,
        eps=eps,
        k=K,
        lam=LAMBDA,
        feature_range=[0.0, 1.0],
        device=torch.device(DEVICE)
    )

    asgt.evaluate_model(train_loader)
    asgt.evaluate_model(test_loader)
    robustness = asgt.evaluate_model_robustness(test_loader)
    
    num_epochs = 20
    for epoch in range(num_epochs):
        running_loss = asgt.train_one_epoch(train_loader)
        print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(train_loader):.4f}")
        # torch.save(model.state_dict(), f"./robust_alexnet_{epoch + 1:2d}.pth")           
        asgt.evaluate_model(train_loader)
        asgt.evaluate_model(test_loader)
        robustness = asgt.evaluate_model_robustness(test_loader)

ASGT on 0.025
Successfully load weights from "./resnet18_weights.pth"


100%|██████████| 196/196 [02:40<00:00,  1.22it/s]


Accuracy: 100.00%


100%|██████████| 40/40 [00:32<00:00,  1.21it/s]


Accuracy: 94.80%


100%|██████████| 40/40 [00:39<00:00,  1.01it/s]


Robustness accuracy: 34.17%


100%|██████████| 196/196 [04:28<00:00,  1.37s/it]


Epoch [1/20], Loss: 10.9098


100%|██████████| 196/196 [02:16<00:00,  1.43it/s]


Accuracy: 13.25%


100%|██████████| 40/40 [00:24<00:00,  1.61it/s]


Accuracy: 12.73%


100%|██████████| 40/40 [00:32<00:00,  1.22it/s]


Robustness accuracy: 13.21%


100%|██████████| 196/196 [04:56<00:00,  1.51s/it]


Epoch [2/20], Loss: 8.4720


100%|██████████| 196/196 [02:50<00:00,  1.15it/s]


Accuracy: 12.85%


100%|██████████| 40/40 [00:31<00:00,  1.28it/s]


Accuracy: 12.85%


100%|██████████| 40/40 [00:43<00:00,  1.08s/it]


Robustness accuracy: 12.93%


 93%|█████████▎| 183/196 [04:50<00:15,  1.16s/it]