In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
import torch.optim as optim
import torchvision.transforms as transforms
import kornia as K
from functools import partial
from torchvision.models import resnet18
from torch.utils.data import DataLoader

import os

BATCH_SIZE = 1

# Where CIFAR-10 is downloaded/cached. Override with the CIFAR10_ROOT environment
# variable instead of editing the notebook.
CIFAR10_ROOT = os.environ.get("CIFAR10_ROOT", "/home/ksas/Public/datasets/cifar10_concept_bank")

# Model-side preprocessing. It is applied inside ComposedModel (see load_model), so
# the datasets below only convert images to [0, 1] tensors. As a result, captum
# attributions are taken w.r.t. the raw 32x32 pixels.
# nn.Sequential (instead of torchvision's Compose) makes this a real nn.Module,
# matching the `nn.Module` annotations on ComposedModel / load_model.
# NOTE(review): these mean/std values differ from torchvision's ImageNet statistics
# ((0.485, 0.456, 0.406) / (0.229, 0.224, 0.225)) used by ResNet18_Weights.DEFAULT.
# Confirm they match how the checkpoints were trained.
transform = nn.Sequential(
    K.geometry.Resize((224, 224)),
    K.enhance.Normalize(mean=torch.tensor((0.48145466, 0.4578275, 0.40821073)),
                        std=torch.tensor((0.26862954, 0.26130258, 0.27577711))),
)


def _make_cifar10_loader(train: bool):
    """Return ``(dataset, loader)`` for a CIFAR-10 split, yielding [0, 1] image tensors."""
    dataset = datasets.CIFAR10(root=CIFAR10_ROOT,
                               train=train,
                               transform=transforms.ToTensor(),
                               download=True)
    # shuffle=False keeps the iteration order identical across runs.
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
    return dataset, loader


train_dataset, train_loader = _make_cifar10_loader(train=True)
test_dataset, test_loader = _make_cifar10_loader(train=False)

In [None]:
import os
from tqdm import tqdm

# Fall back to CPU so the notebook still runs (slowly) on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoints compared below; load_model warns if a file is missing.
VANILLA_WEIGHTS_PATH = "./resnet18_weights.pth"
ROBUST_WEIGHTS_PATH = "./robust_resnet18_0.025.pth"

# Reassigned from load_model's return value in the evaluation cells.
have_loaded_weights = False

class ComposedModel(nn.Module):
    """Apply ``compose`` to the input, then run ``model`` on the result.

    Bundles input preprocessing with a classifier, so tools such as captum's
    ``Saliency`` see one module that takes raw images.

    Args:
        model: network applied to the preprocessed input.
        compose: preprocessing applied first. NOTE(review): annotated as
            ``nn.Module``, but any callable works because it is only called in
            ``forward``. A non-Module is stored as a plain attribute, so
            ``.to()`` / ``.eval()`` will not reach it.
    """

    def __init__(self, model: nn.Module, compose: nn.Module):
        super().__init__()
        # These attribute names define the state_dict key prefixes ("model.*").
        self.model = model
        self.compose = compose

    def forward(self, x):
        return self.model(self.compose(x))

def load_model(path: str, transform: nn.Module = None):
    """Build a 10-class ResNet-18 on ``DEVICE``, optionally loading weights and wrapping it.

    Args:
        path: state_dict checkpoint to load. If the file does not exist, the model
            keeps the ImageNet-pretrained backbone with a freshly initialised 10-way
            ``fc`` head, and a warning is printed.
        transform: optional preprocessing. When given, the network is wrapped as
            ``ComposedModel(model, transform)`` so it accepts raw images.

    Returns:
        ``(model, have_loaded_weights)``. The flag is True only if ``path`` was loaded.
        The model is returned in training mode (the nn.Module default). Call
        ``.eval()`` before inference.
    """
    weights_exist = os.path.exists(path)
    # A full state_dict overwrites every parameter and buffer, so only fetch the
    # ImageNet weights when they will actually be used.
    model = resnet18(weights=None if weights_exist else models.ResNet18_Weights.DEFAULT)
    model.fc = nn.Linear(model.fc.in_features, 10)
    model = model.float().to(DEVICE)

    have_loaded_weights = False
    if weights_exist:
        # map_location lets a checkpoint saved on another device load onto DEVICE.
        # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
        model.load_state_dict(torch.load(path, map_location=DEVICE))
        have_loaded_weights = True
        print(f"Successfully load weights from \"{path}\"")
    else:
        print(f"WARNING: \"{path}\" not found; using the ImageNet backbone with an untrained 10-class head")

    if transform is not None:
        model = ComposedModel(model, transform)

    return model, have_loaded_weights




In [None]:
from asgt import *
from captum.attr import Saliency

model, have_loaded_weights = load_model(VANILLA_WEIGHTS_PATH, transform)
# eval() makes BatchNorm use its running statistics. In training mode every forward
# pass would normalise with per-batch statistics and also overwrite the running
# buffers (even under no_grad), which skews the measured accuracy.
model.eval()

n_correct = 0
n_seen = 0
with torch.no_grad():
    for batch_X, batch_Y in tqdm(test_loader, total=len(test_loader)):
        batch_X = batch_X.to(DEVICE)
        batch_Y = batch_Y.to(DEVICE)
        predicted = model(batch_X).argmax(1)
        # Count samples rather than averaging per-batch means, so a smaller final
        # batch is not over-weighted when BATCH_SIZE > 1.
        n_correct += (predicted == batch_Y).sum().item()
        n_seen += batch_Y.numel()

total_accuracy = n_correct / max(n_seen, 1)
totall_accuracy = total_accuracy  # legacy (misspelled) name kept for any later cells
print(totall_accuracy)
# Explicit import instead of relying on `from asgt import *` to provide it.
from captum.attr import visualization

# Number of images (classified correctly by both models) to visualise. This replaces
# the former pdb breakpoint, which halted the loop after every image.
NUM_EXAMPLES_TO_SHOW = 5

robust_model, have_loaded_weights = load_model(ROBUST_WEIGHTS_PATH, transform)
# Inference mode: BatchNorm must use running stats and must not update them
# during the forward passes below.
model.eval()
robust_model.eval()

saliency = Saliency(model)
robust_saliency = Saliency(robust_model)

_ATTR_METHODS = ["original_image", "heat_map", "blended_heat_map", "masked_image", "alpha_scaling"]
_ATTR_SIGNS = ["all", "positive", "positive", "positive", "positive"]


def _to_hwc_numpy(t: torch.Tensor):
    """Drop the batch dim of a (1, C, H, W) tensor and return an (H, W, C) numpy array."""
    return t.squeeze(0).permute((1, 2, 0)).detach().cpu().numpy()


def show_attribution(attribution: torch.Tensor, image: torch.Tensor, title: str):
    """Plot ``attribution`` over ``image`` in captum's five views; return ``(figure, axes)``."""
    return visualization.visualize_image_attr_multiple(
        _to_hwc_numpy(attribution),
        _to_hwc_numpy(image),
        methods=_ATTR_METHODS,
        signs=_ATTR_SIGNS,
        titles=[None, None, title, None, None],
        use_pyplot=True,
    )


n_shown = 0
for batch_X, batch_Y in tqdm(test_loader, total=len(test_loader)):
    if n_shown >= NUM_EXAMPLES_TO_SHOW:
        break
    # The .item() calls below require exactly one image per batch.
    assert batch_X.shape[0] == 1, "saliency visualisation expects BATCH_SIZE == 1"
    batch_X = batch_X.to(DEVICE)
    batch_Y = batch_Y.to(DEVICE)
    label = batch_Y.item()

    # Predictions only, so there is no need to build an autograd graph.
    with torch.no_grad():
        outputs = model(batch_X).argmax(1).item()
        robust_outputs = robust_model(batch_X).argmax(1).item()

    # Only compare attributions on images that both models classify correctly.
    if outputs != label or robust_outputs != label:
        continue

    attribution = saliency.attribute(batch_X, target=batch_Y)
    robust_attribution = robust_saliency.attribute(batch_X, target=batch_Y)

    figure, axis = show_attribution(attribution, batch_X, "Vanilla attribution")
    figure, axis = show_attribution(robust_attribution, batch_X, "Robust attribution")
    n_shown += 1
