In [None]:
import sys
# Put the sibling ``../code`` directory at the front of the module search path.
# Presumably this is where the project packages imported by this notebook
# (models, datamodules, attributions) live -- verify against the repo layout.
# The path is relative, so it depends on the kernel's working directory.
sys.path.insert(0,'../code')

In [None]:
import torch
import time
from transformers import ViTFeatureExtractor, ViTForImageClassification
from models.interpretation import ImageInterpretationNet
from datamodules.image_classification import CIFAR10DataModule
from datamodules.transformations import UnNest
from attributions.grad_cam import grad_cam
from attributions.attention_rollout import attention_rollout

# Load model and data

In [None]:
# Hugging Face Hub id of the fine-tuned ViT checkpoint. The model and its
# feature extractor must be loaded from the same id, so it is named once here
# instead of being repeated as a string literal.
MODEL_NAME = "tanlq/vit-base-patch16-224-in21k-finetuned-cifar10"
# Number of test images in the single batch that every method is timed on.
BATCH_SIZE = 10

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Vision Transformer classifier that the attribution methods are applied to.
vit = ViTForImageClassification.from_pretrained(MODEL_NAME).to(device)

# Pre-processing pipeline matching the checkpoint, wrapped in the project's
# UnNest transformation before being handed to the data module.
feature_extractor = UnNest(
    ViTFeatureExtractor.from_pretrained(MODEL_NAME, return_tensors="pt")
)

dm = CIFAR10DataModule(feature_extractor=feature_extractor, batch_size=BATCH_SIZE)
dm.prepare_data()
dm.setup()
dataloader = iter(dm.test_dataloader())

# Keep only the first element of the first test batch (presumably the image
# tensor -- confirm against CIFAR10DataModule) and move it to the chosen device.
images = next(dataloader)[0].to(device)

# Inference Time

In [None]:
# Time one Grad-CAM pass over the batch.
# CUDA kernels are launched asynchronously, so the device is synchronized
# before starting and before stopping the clock; otherwise the measurement may
# cover only the kernel launches rather than the actual computation.
# NOTE(review): this is a single, un-warmed run; the first GPU call may also
# include one-time initialization cost.
if device == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()  # monotonic, high-resolution clock for intervals
# The third positional argument is True only when running on CUDA.
grad_cam(images, vit, device == 'cuda')
if device == 'cuda':
    torch.cuda.synchronize()
print(f"Inference time for Grad-CAM {time.perf_counter() - start}")

In [None]:
# Time one Attention Rollout pass over the batch.
# CUDA kernels are launched asynchronously, so the device is synchronized
# before starting and before stopping the clock; otherwise the measurement may
# cover only the kernel launches rather than the actual computation.
if device == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()  # monotonic, high-resolution clock for intervals
attention_rollout(images=images, vit=vit, device=device)
if device == 'cuda':
    torch.cuda.synchronize()
print(f"Inference time for Attention Rollout {time.perf_counter() - start}")

In [None]:
# Load the DiffMask interpretation network from the local checkpoint file and
# attach the ViT to it. This happens before the clock starts, so checkpoint
# loading is not part of the reported time.
diffmask = ImageInterpretationNet.load_from_checkpoint('diffmask.ckpt').to(device)
diffmask.set_vision_transformer(vit)

# Time one DiffMask mask computation over the batch.
# CUDA kernels are launched asynchronously, so the device is synchronized
# before starting and before stopping the clock; otherwise the measurement may
# cover only the kernel launches rather than the actual computation.
if device == 'cuda':
    torch.cuda.synchronize()
start = time.perf_counter()  # monotonic, high-resolution clock for intervals
diffmask.get_mask(images)["mask"]
if device == 'cuda':
    torch.cuda.synchronize()
print(f"Inference time for DiffMask {time.perf_counter() - start}")