In [None]:
import sys
sys.path.insert(0,'../code')

In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from datamodules.transformations import UnNest
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from transformers import ViTFeatureExtractor, ViTForImageClassification
from torchvision.transforms import Compose
from tqdm.auto import tqdm
from datamodules.image_classification import CIFAR10DataModule
from models.interpretation import ImageInterpretationNet
from utils.getters_setters import vit_getter, vit_setter
from attributions.grad_cam import grad_cam
from attributions.attention_rollout import attention_rollout

### Load CIFAR-10 Test Split and Model

In [None]:
# Run on the GPU when one is available; the evaluation cells below reuse `device`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Hugging Face Hub id shared by the classifier weights and the preprocessing config,
# kept in one place so the two can never drift apart.
MODEL_NAME = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"

vit = ViTForImageClassification.from_pretrained(MODEL_NAME).to(device)

# `map_location` lets a checkpoint that was saved from a GPU run be restored on a
# CPU-only machine; without it the load fails before `.to(device)` is reached.
diffmask = ImageInterpretationNet.load_from_checkpoint(
    'diffmask.ckpt', map_location=device
).to(device)

# Wrap the Hugging Face feature extractor with the project's UnNest transformation
# before handing it to the data module.
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME, return_tensors="pt")
feature_extractor = UnNest(feature_extractor)

# Build the CIFAR-10 data module and keep only the test split loader.
dm = CIFAR10DataModule(feature_extractor=feature_extractor, batch_size=16)
dm.prepare_data()
dm.setup()
dataloader = dm.test_dataloader()

### Function to measure KL-divergence between masked & unmasked input

In [None]:
@torch.no_grad()
def kl_divergence(model, input, mask, b, n_hidden = 14, patch_size=16):
    """KL divergence between the model's predictions on unmasked and masked input.

    The pixel-level ``mask`` is downsampled by ``patch_size`` and applied to the
    first hidden state returned by ``vit_getter``: positions where the mask is 1
    keep their hidden state, positions where it is 0 are replaced by ``b``.

    Args:
        model: classifier exposing ``model.vit.embeddings.cls_token``; it is passed
            through to ``vit_getter`` / ``vit_setter``.
        input: batch of model inputs.
        mask: 3-D tensor ``(batch, height, width)``.
            NOTE(review): assumed to be a float tensor on the same device as
            ``input`` -- confirm against the callers.
        b: baseline value substituted where the mask is 0.
        n_hidden: length of the hidden-state list handed to ``vit_setter``; only
            the first entry is filled, the rest are ``None``.
        patch_size: downsampling factor from pixel space to patch space.

    Returns:
        ``KL(original || masked)`` computed from the two sets of logits
        (presumably one value per batch element -- verify against the logits' shape).
    """
    batch_size, _height, _width = mask.shape

    # Bring the pixel mask down to the patch grid, then flatten it into one
    # column per image so it broadcasts over the hidden dimension.
    patch_mask = F.interpolate(mask.unsqueeze(1), scale_factor=1/patch_size)
    patch_mask = patch_mask.flatten(1).unsqueeze(-1)

    # Reference forward pass on the unmasked input.
    logits_orig, hidden_states = vit_getter(model, input)

    # Blend the first hidden state with the baseline according to the mask.
    masked_patches = hidden_states[0] * patch_mask + b * (1 - patch_mask)

    # The CLS token is prepended to the masked patch states.
    cls_tokens = model.vit.embeddings.cls_token.expand(batch_size, -1, -1)
    first_state = torch.cat((cls_tokens, masked_patches), dim=1)

    # Only the first hidden state is overridden; the remaining slots stay None.
    new_hidden_states = [first_state] + [None] * (n_hidden - 1)

    # Forward pass with the overridden hidden state (masked input).
    logits, _ = vit_setter(model, input, new_hidden_states)

    original_dist = torch.distributions.Categorical(logits=logits_orig)
    masked_dist = torch.distributions.Categorical(logits=logits)
    return torch.distributions.kl_divergence(original_dist, masked_dist)

### KL-Divergence for Grad-CAM

In [None]:
# Per-batch KL divergences and per-image masked-pixel percentages for Grad-CAM.
klds = []
masks_percentage = []

for images, _ in tqdm(dataloader):
    # Follow the configured device instead of hard-coding CUDA so the cell also
    # runs on a CPU-only machine.
    images = images.to(device)
    # Third argument mirrors the original `device == 'cuda'` flag
    # (presumably a "use CUDA" switch -- confirm against grad_cam's signature).
    gradcam_masks = grad_cam(images, vit, device == 'cuda')
    # Masked percentage per image: 100 * (1 - mean mask value). One `.tolist()`
    # transfers the whole batch instead of one `.item()` sync per mask.
    mean_mask_values = gradcam_masks.mean(-1).mean(-1).tolist()
    masks_percentage.extend(100 * (1 - value) for value in mean_mask_values)
    klds.append(kl_divergence(vit, images, gradcam_masks.to(device), diffmask.gate.placeholder))

klds = torch.cat(klds)
print(f"Grad-CAM mean KL-Divergence: {klds.mean()}")
print(f"Masking percentage: {np.mean(masks_percentage)}")

### KL-Divergence for Attention Rollout

In [None]:
# Per-batch KL divergences and per-image masked-pixel percentages for Attention Rollout.
klds = []
masks_percentage = []

for images, _ in tqdm(dataloader):
    # Follow the configured device instead of hard-coding CUDA so the cell also
    # runs on a CPU-only machine.
    images = images.to(device)
    rollout_masks = attention_rollout(images=images, vit=vit, device=device)
    # Masked percentage per image: 100 * (1 - mean mask value). One `.tolist()`
    # transfers the whole batch instead of one `.item()` sync per mask.
    mean_mask_values = rollout_masks.mean(-1).mean(-1).tolist()
    masks_percentage.extend(100 * (1 - value) for value in mean_mask_values)
    klds.append(kl_divergence(vit, images, rollout_masks, diffmask.gate.placeholder))

klds = torch.cat(klds)
print(f"Rollout mean KL-Divergence: {klds.mean()}")
print(f"Masking percentage: {np.mean(masks_percentage)}")

### KL-Divergence for DiffMask

In [None]:
# Attach the classifier to the DiffMask interpretation network before asking it for masks.
diffmask.set_vision_transformer(vit)

# Per-batch KL divergences and per-image masked-pixel percentages for DiffMask.
klds = []
masks_percentage = []

for images, _ in tqdm(dataloader):
    # Follow the configured device instead of hard-coding CUDA so the cell also
    # runs on a CPU-only machine.
    images = images.to(device)
    # Detach so the masks carry no autograd history into the evaluation.
    diff_masks = diffmask.get_mask(images)["mask"].detach()
    # Masked percentage per image: 100 * (1 - mean mask value). One `.tolist()`
    # transfers the whole batch instead of one `.item()` sync per mask.
    mean_mask_values = diff_masks.mean(-1).mean(-1).tolist()
    masks_percentage.extend(100 * (1 - value) for value in mean_mask_values)
    klds.append(kl_divergence(vit, images, diff_masks, diffmask.gate.placeholder))

klds = torch.cat(klds)
print(f"Diffmask mean KL-Divergence: {klds.mean()}")
print(f"Masking percentage: {np.mean(masks_percentage)}")