In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
import torch.optim as optim
import torchvision.transforms as transforms
import kornia as K
from functools import partial
from torchvision.models import resnet18
from torch.utils.data import DataLoader

import os

# One image per batch: the saliency cell later in this notebook calls `.item()` on the
# predictions and labels, which only works on single-element tensors.
BATCH_SIZE = 1
# Where CIFAR-10 is downloaded/cached. Set the CIFAR10_ROOT environment variable to use a
# different location instead of editing this machine-specific default.
CIFAR10_ROOT = os.environ.get("CIFAR10_ROOT", "/home/ksas/Public/datasets/cifar10_concept_bank")

# Model-side preprocessing: resize to 224x224, then normalize with the mean/std used by
# OpenAI CLIP's image preprocessing. The datasets below do NOT apply this. It is passed to
# `model_utils.load_model` in the next cell, presumably so the model applies it in its
# forward pass (confirm in asgt.model_utils).
transform = transforms.Compose([
    K.geometry.Resize((224, 224)),
    K.enhance.Normalize(mean=torch.tensor((0.48145466, 0.4578275, 0.40821073)),
                        std=torch.tensor((0.26862954, 0.26130258, 0.27577711))),
])


def _load_cifar10(train: bool):
    """Return the CIFAR-10 train (``train=True``) or test split.

    The only dataset transform is ``ToTensor``, so images come back as unnormalized float
    tensors. Normalization is left to the model-side ``transform`` above.
    """
    return datasets.CIFAR10(root=CIFAR10_ROOT,
                            train=train,
                            transform=transforms.ToTensor(),
                            download=True)


train_dataset = _load_cifar10(train=True)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=False)

test_dataset = _load_cifar10(train=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
import os
import asgt
from asgt import model_utils
from asgt import attack_utils
from tqdm import tqdm

import warnings

# Use the GPU when one is present; otherwise fall back to CPU so the notebook still runs.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VANILLA_WEIGHTS_PATH = "./resnet18_weights.pth"
ROBUST_WEIGHTS_PATH = "./robust_resnet18_20.pth"
# Perturbation size passed to attack_utils.FGSM. Shared so both models face the same attack.
FGSM_EPS = 0.025

model, vanilla_weights_loaded = model_utils.load_model("resnet18",
                                                       VANILLA_WEIGHTS_PATH,
                                                       DEVICE,
                                                       transform)
robust_model, robust_weights_loaded = model_utils.load_model("resnet18",
                                                             ROBUST_WEIGHTS_PATH,
                                                             DEVICE,
                                                             transform)
# Keep both load flags (previously the vanilla one was overwritten) and surface a failed load.
# NOTE(review): assumes load_model's second return value is falsy when the checkpoint could
# not be loaded; confirm in asgt.model_utils.
for model_name, weights_path, loaded in (("vanilla", VANILLA_WEIGHTS_PATH, vanilla_weights_loaded),
                                         ("robust", ROBUST_WEIGHTS_PATH, robust_weights_loaded)):
    if not loaded:
        warnings.warn(f"{model_name} model: weights were not loaded from {weights_path}")
have_loaded_weights = vanilla_weights_loaded and robust_weights_loaded

# This notebook only does inference (evaluation, attacks, saliency). Force eval mode so
# BatchNorm uses running statistics, whatever mode load_model leaves the models in.
# NOTE(review): assumes load_model returns torch.nn.Module instances.
model.eval()
robust_model.eval()

loss_func = nn.CrossEntropyLoss()
FGSM_model = attack_utils.FGSM(model, loss_func, eps=FGSM_EPS)
FGSM_robust_model = attack_utils.FGSM(robust_model, loss_func, eps=FGSM_EPS)

# Clean accuracy, then accuracy under the FGSM attack, for each model.
attack_utils.evaluate_model(model, test_loader, DEVICE)
attack_utils.evaluate_model(robust_model, test_loader, DEVICE)

attack_utils.evaluate_model_robustness(model, test_loader, FGSM_model, DEVICE)
attack_utils.evaluate_model_robustness(robust_model, test_loader, FGSM_robust_model, DEVICE)

In [None]:
from asgt import *
from captum.attr import Saliency


# Saliency maps for the vanilla vs. robust ResNet-18. For each of the first N_VIS_SAMPLES
# test images that BOTH models classify correctly, compute captum `Saliency` attributions
# with each model and save one 5-panel figure per model.
from itertools import islice

N_VIS_SAMPLES = 101  # test indices 0..100, the same cutoff as the former `idx > 100` break
IMG_DIR = "./imgs"
VIS_METHODS = ["original_image", "heat_map", "blended_heat_map", "masked_image", "alpha_scaling"]
VIS_SIGNS = ["all", "positive", "positive", "positive", "positive"]


def _to_hwc_numpy(tensor: torch.Tensor):
    """Convert a (1, C, H, W) tensor into an (H, W, C) numpy array for plotting."""
    return tensor.squeeze(0).permute((1, 2, 0)).detach().cpu().numpy()


def save_attribution_figure(attribution: torch.Tensor, image: torch.Tensor, title: str, path: str):
    """Render `attribution` over `image` with every panel in VIS_METHODS and save it as a JPEG.

    Args:
        attribution: attribution map for `image`, in the same (1, C, H, W) layout.
        image: the single-image batch that was fed to the model.
        title: title for the blended heat-map panel; the other panels are untitled.
        path: output file path.

    Returns:
        The figure returned by `visualization.visualize_image_attr_multiple`.
    """
    # `visualization` is expected to come from `from asgt import *` in this cell.
    figure, _ = visualization.visualize_image_attr_multiple(
        _to_hwc_numpy(attribution),
        _to_hwc_numpy(image),
        signs=VIS_SIGNS,
        titles=[None, None, title, None, None],
        use_pyplot=False,
        methods=VIS_METHODS,
    )
    figure.savefig(path, format="jpg", dpi=300)
    return figure


# Predictions and labels are compared via `.item()`, which requires one sample per batch.
assert test_loader.batch_size == 1, "saliency loop requires BATCH_SIZE == 1"
# savefig does not create missing directories.
os.makedirs(IMG_DIR, exist_ok=True)

saliency = Saliency(model)
robust_saliency = Saliency(robust_model)

n_batches = min(N_VIS_SAMPLES, len(test_loader))
# islice stops the loader after n_batches, so tqdm's `total` matches the work actually done.
for idx, (batch_X, batch_Y) in tqdm(enumerate(islice(test_loader, n_batches)), total=n_batches):
    batch_X = batch_X.to(DEVICE)
    batch_Y = batch_Y.to(DEVICE)
    label = batch_Y.item()

    # Plain forward passes for the correctness filter; no gradients are needed here.
    with torch.no_grad():
        pred = model(batch_X).argmax(1).item()
        robust_pred = robust_model(batch_X).argmax(1).item()

    # Only compare attributions on images that both models classify correctly.
    if pred != label or robust_pred != label:
        continue

    attribution = saliency.attribute(batch_X, batch_Y)
    robust_attribution = robust_saliency.attribute(batch_X, batch_Y)

    # Zero-pad to 3 digits so indices 0..100 sort correctly by filename.
    save_attribution_figure(attribution, batch_X, "Vanilla attribution",
                            os.path.join(IMG_DIR, f"CIFAR10_{idx:03d}_vanilla.jpg"))
    save_attribution_figure(robust_attribution, batch_X, "Robust attribution",
                            os.path.join(IMG_DIR, f"CIFAR10_{idx:03d}_robust.jpg"))
