In [None]:
# Ascending class indices (each in range(1000)) that make up the sub-list.
# The name suggests these are the ImageNet-1k ids of the ImageNet-R classes.
IMAGENET_R_CLASS_SUBLIST = [
    1, 2, 4, 6, 8, 9, 11, 13, 22, 23,
    26, 29, 31, 39, 47, 63, 71, 76, 79, 84,
    90, 94, 96, 97, 99, 100, 105, 107, 113, 122,
    125, 130, 132, 144, 145, 147, 148, 150, 151, 155,
    160, 161, 162, 163, 171, 172, 178, 187, 195, 199,
    203, 207, 208, 219, 231, 232, 234, 235, 242, 245,
    247, 250, 251, 254, 259, 260, 263, 265, 267, 269,
    276, 277, 281, 288, 289, 291, 292, 293, 296, 299,
    301, 308, 309, 310, 311, 314, 315, 319, 323, 327,
    330, 334, 335, 337, 338, 340, 341, 344, 347, 353,
    355, 361, 362, 365, 366, 367, 368, 372, 388, 390,
    393, 397, 401, 407, 413, 414, 425, 428, 430, 435,
    437, 441, 447, 448, 457, 462, 463, 469, 470, 471,
    472, 476, 483, 487, 515, 546, 555, 558, 570, 579,
    583, 587, 593, 594, 596, 609, 613, 617, 621, 629,
    637, 657, 658, 701, 717, 724, 763, 768, 774, 776,
    779, 780, 787, 805, 812, 815, 820, 824, 833, 847,
    852, 866, 875, 883, 889, 895, 907, 928, 931, 932,
    933, 934, 936, 937, 943, 945, 947, 948, 949, 951,
    953, 954, 957, 963, 965, 967, 980, 981, 983, 988,
]

# Boolean mask of length 1000: entry i is True exactly when i is in the sub-list.
IMAGENET_R_CLASS_SUBLIST_MASK = list(
    map(frozenset(IMAGENET_R_CLASS_SUBLIST).__contains__, range(1000))
)

In [None]:
import argparse
import copy 
import sys

import numpy as np
import torch 
import torchvision
from torchvision import models
import torchvision.transforms.v2 as v2 
import matplotlib.pyplot as plt

sys.path.insert(0, '../')
import datasets

from methods.tact_utils import get_PCs, remove_PCs


In [None]:
def show_image(cam):
    """Display a channel-first image after per-channel min-max scaling.

    Every entry along the first axis of ``cam`` is shifted so that its minimum
    is 0 and then divided by its resulting maximum. The scaled channels are
    stacked as float32, moved to channel-last order and drawn with matplotlib
    with the axes hidden.

    Args:
        cam: Array-like indexed as (channel, height, width); the transpose
            below requires exactly three dimensions after stacking.

    Returns:
        None. The figure is rendered through ``plt.show()``.
    """
    channels = []
    for channel in cam:
        shifted = channel - np.min(channel)
        peak = np.max(shifted)
        # A constant channel has zero range; dividing by it would fill the
        # channel with NaN, so such a channel is left at 0 instead.
        channels.append(shifted / (peak if peak > 0 else 1))
    plt.imshow(np.transpose(np.float32(channels), (1, 2, 0)))
    plt.axis('off')
    plt.show()

In [None]:
def get_rgb_image(image):
    """Return a channel-last float32 copy of ``image`` scaled per channel.

    Every entry along the first axis of ``image`` is shifted so that its
    minimum is 0 and then divided by its resulting maximum, which maps a
    non-constant channel onto [0, 1].

    Args:
        image: Array-like indexed as (channel, height, width); the transpose
            below requires exactly three dimensions after stacking.

    Returns:
        A float32 numpy array with the first axis moved to the end, i.e.
        indexed as (height, width, channel).
    """
    channels = []
    for channel in image:
        shifted = channel - np.min(channel)
        peak = np.max(shifted)
        # A constant channel has zero range; dividing by it would fill the
        # channel with NaN, so such a channel is left at 0 instead.
        channels.append(shifted / (peak if peak > 0 else 1))
    return np.transpose(np.float32(channels), (1, 2, 0))

In [None]:
# Pretrained torchvision weights, keyed by architecture name.
PRETRAINED_WEIGHT_DICT = {
    'vit_b_32': models.ViT_B_32_Weights.IMAGENET1K_V1,
}

# Per-channel normalisation statistics shared by both pipelines below, kept in
# one place so the two cannot drift apart (these are the values commonly used
# with ImageNet-pretrained torchvision models).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing: normalisation only.
base_augmentation = v2.Normalize(IMAGENET_MEAN, IMAGENET_STD)

# Randomised preprocessing: RandAugment runs on a uint8 version of the input,
# which is then converted back to float32 and normalised exactly as above.
non_causal_augmentation = v2.Compose([
    v2.ToDtype(torch.uint8, scale=True),
    v2.RandAugment(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

device = 'cpu'

In [None]:
# Run configuration, wrapped in a Namespace so it can be handed to the dataset
# loader as if it were a set of parsed command-line arguments.
arguments = dict(
    dataset='imagenet_r',
    data_dir='/path/to/data/dir',  # placeholder path: set to the local data root
    eval_batch_size=8,
)
args = argparse.Namespace(**arguments)

device = torch.device('cpu')

# Resolve the dataset helper named by args.dataset on the imported `datasets` module.
dataset_class = getattr(datasets, args.dataset)

In [None]:
# The model is only used for inference and attribution, so switch it to
# evaluation mode explicitly instead of leaving the default training mode on.
model = models.vit_b_32(weights=PRETRAINED_WEIGHT_DICT['vit_b_32']).to(device).eval()
# Test-split loader for the dataset selected in `args`.
test_loader = dataset_class.getTestLoader(args, device)

In [None]:
def classifier(model, x):
    """Apply ``model.heads`` to ``x`` and return the result."""
    head = model.heads
    return head(x)
        
def featurize(model, x):
    """Return the encoder output at the class-token position for each sample.

    The input is passed through ``model._process_input``, the class token is
    expanded across the batch and prepended along dim 1, the sequence is run
    through ``model.encoder`` and the first position is returned.
    """
    patches = model._process_input(x)
    batch_size = patches.shape[0]
    cls_tokens = model.class_token.expand(batch_size, -1, -1)
    encoded = model.encoder(torch.cat([cls_tokens, patches], dim=1))
    # Index 0 along dim 1 is where the class token was inserted.
    return encoded[:, 0]

# Weight matrix of the first layer in `model.heads`, detached from autograd.
# Later cells use it both as the unmodified linear classifier and, after PC
# removal, as the source of the projected prototypes.
classifier_weight = model.heads[0].weight.detach()

In [None]:
import torch.nn as nn

class AttributeModel(nn.Module):
    """Backbone plus an explicit linear read-out, packaged for attribution.

    The wrapper keeps private deep copies of both the backbone and the weight
    so later changes to the originals do not affect it.

    Args:
        model: Object exposing ``_process_input``, ``class_token`` and
            ``encoder`` (used exactly as in ``forward`` below).
        weight: Either a 2-D tensor shared by all samples, or a 3-D tensor
            holding one weight matrix per sample in the batch.
    """

    def __init__(self, model, weight):
        super().__init__()
        self.model = copy.deepcopy(model)
        self.weight = copy.deepcopy(weight)

    def forward(self, x):
        """Return the class-token feature multiplied by the stored weight(s)."""
        backbone = self.model
        patches = backbone._process_input(x)
        cls_tokens = backbone.class_token.expand(patches.shape[0], -1, -1)
        encoded = backbone.encoder(torch.cat([cls_tokens, patches], dim=1))
        cls_feature = encoded[:, 0]

        if self.weight.dim() != 2:
            # Per-sample weights: batched product of one feature row with that
            # sample's transposed weight matrix, then drop the singleton row axis.
            logits = torch.bmm(cls_feature.unsqueeze(1), self.weight.transpose(1, 2))
            return logits.squeeze(1)
        return torch.mm(cls_feature, self.weight.T)

In [None]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image

In [None]:
import cv2
def scale_cam_image(cam, target_size=None):
    """Min-max scale every entry of ``cam`` and optionally resize it.

    Each entry along the first axis is shifted so that its minimum is 0 and
    divided by its resulting maximum, which maps a non-constant entry onto
    [0, 1].

    Args:
        cam: Iterable of numpy arrays (or an array iterated along axis 0).
        target_size: If given, passed to ``cv2.resize`` as the output size
            for every scaled entry.

    Returns:
        A float32 numpy array stacking the processed entries along axis 0.
    """
    scaled = []
    for img in cam:
        img = img - np.min(img)
        peak = np.max(img)
        # A constant entry has zero range; dividing by it would fill the
        # entry with NaN, so such an entry is left at 0 instead.
        img = img / (peak if peak > 0 else 1)
        if target_size is not None:
            img = cv2.resize(img, target_size)
        scaled.append(img)
    return np.float32(scaled)

def visualise_CAM(image, cam):
    """Build a horizontal strip: scaled image | grey CAM | CAM overlay.

    Args:
        image: Tensor exposing ``.numpy()``; after scaling it is transposed
            from channel-first to channel-last order.
        cam: Activation map(s) handed to ``scale_cam_image`` and resized with
            a hard-coded target size of (224, 224).
            NOTE(review): the expected shape of ``cam`` is not evident from
            this notebook and the function is not called here; confirm it
            before relying on this helper.

    Returns:
        The three panels stacked side by side with ``np.hstack``.
    """
    heatmap = scale_cam_image(cam, target_size=(224, 224))
    rgb = np.transpose(scale_cam_image(image.numpy()), (1, 2, 0))
    overlay = show_cam_on_image(rgb, heatmap, use_rgb=True)

    # Replicate the scaled heatmap into three channels so it can sit next to
    # the colour panels.
    grey = np.uint8(255 * heatmap)
    grey_rgb = cv2.merge([grey, grey, grey])
    return np.hstack((np.uint8(255 * rgb), grey_rgb, overlay))

In [None]:
def reshape_transform(tensor, height=7, width=7):
    """Turn a (batch, 1 + height*width, channels) token tensor into a 2-D map.

    The first position along dim 1 is dropped, the remaining positions are
    laid out on a ``height`` x ``width`` grid and the channel axis is moved
    to dim 1, giving (batch, channels, height, width).
    """
    batch, channels = tensor.size(0), tensor.size(2)
    grid = tensor[:, 1:, :].reshape(batch, height, width, channels)
    return grid.permute(0, 3, 1, 2)

In [None]:
# Number of randomly augmented forward passes per batch; their features (plus
# the un-augmented one) are the input to get_PCs.
max_aug = 256
# Passed to remove_PCs as its last argument -- presumably the number of
# principal components to project out (confirm in methods.tact_utils).
remove = 10

In [None]:
# Only the first test batch is analysed: the loop exits through `break` after
# a single pass.
for batch_num, (x, y) in enumerate(test_loader):
    x, y = x.to(device), y.to(device)

    # Everything in this cell is inference only (the outputs feed the PC
    # estimate and two argmax calls), so autograd bookkeeping is disabled for
    # the 1 + max_aug forward passes.
    with torch.no_grad():
        base_x = [base_augmentation(current_x) for current_x in x]
        base_x = torch.stack(base_x).to(device)
        features = featurize(model, base_x)

        # Index 0 holds the un-augmented features, followed by max_aug
        # augmented variants.
        features_under_augmentation = [features.detach()]
        for _ in range(max_aug):
            # NOTE(review): the random augmentation is applied to the whole
            # batch tensor in one call, whereas base_augmentation is applied
            # sample by sample -- confirm this difference is intended.
            augmented_x = non_causal_augmentation(x)
            feature_under_aug = featurize(model, augmented_x)
            features_under_augmentation.append(feature_under_aug.detach())
        features_under_augmentation = torch.stack(features_under_augmentation)

        # Stacking put the augmentation axis first; swap it with the batch
        # axis so each sample carries its own set of 1 + max_aug features.
        current_features = features_under_augmentation.transpose(0, 1)
        _, V, mean = get_PCs(current_features)

        update_f = remove_PCs(features, mean, V, remove)

        # One copy of the classifier weight per sample, so the same
        # per-sample projection can be applied to the weights.
        model_weight = classifier_weight
        model_weight = torch.stack([model_weight for _ in range(features.size(0))])
        projected_prototype = remove_PCs(model_weight, mean, V, remove)

        # Predictions restricted to the masked subset of classes, before and
        # after PC removal on the features.
        old_pred = classifier(model, features)[:, IMAGENET_R_CLASS_SUBLIST_MASK].argmax(1)
        new_pred = classifier(model, update_f)[:, IMAGENET_R_CLASS_SUBLIST_MASK].argmax(1)
    break

In [None]:
# Show the normalised first batch as a grid, 4 images per row with 2 px padding.
show_image(
    torchvision.utils.make_grid(
        base_x.reshape(-1, 3, 224, 224),
        nrow=4,
        padding=2,
    ).numpy()
)

GradCAM

In [None]:
# Translate each label in `y` (a position within the sub-list) into the class
# index stored at that position.
project_y = [IMAGENET_R_CLASS_SUBLIST[label] for label in y.tolist()]

In [None]:
# One Grad-CAM target per sample: the class index computed in project_y.
targets = [ClassifierOutputTarget(class_index) for class_index in project_y]
# Baseline attribution model: backbone read out with the unmodified classifier weight.
attr_model = AttributeModel(model, classifier_weight)

# https://github.com/jacobgil/pytorch-grad-cam/blob/master/tutorials/vision_transformers.md
gradcam = GradCAM(
    model=attr_model,
    target_layers=[attr_model.model.encoder.layers[-1].ln_1],
    reshape_transform=reshape_transform,
)
grayscale_cam = gradcam(
    input_tensor=base_x,
    targets=targets,
    aug_smooth=True,
)

In [None]:
# Same attribution setup as the baseline, but reading out with the per-sample
# projected prototypes instead of the unmodified classifier weight.
tact_attr_model = AttributeModel(model, projected_prototype)

tact_gradcam = GradCAM(
    model=tact_attr_model,
    target_layers=[tact_attr_model.model.encoder.layers[-1].ln_1],
    reshape_transform=reshape_transform,
)
tact_grayscale_cam = tact_gradcam(
    input_tensor=base_x,
    targets=targets,
    aug_smooth=True,
)

In [None]:
# Overlay every baseline Grad-CAM map on its rescaled input (image_weight=0.5),
# then move the last axis to position 1 so make_grid receives channel-first data.
cams_on_image = np.stack([
    show_cam_on_image(
        get_rgb_image(base_x[idx].numpy()),
        grayscale_cam[idx],
        use_rgb=True,
        image_weight=0.5,
    )
    for idx in range(base_x.size(0))
]).transpose(0, 3, 1, 2)

show_image(
    torchvision.utils.make_grid(
        torch.as_tensor(cams_on_image),
        nrow=4,
        padding=2,
    ).numpy()
)
plt.show()

In [None]:
# Overlay every projected-prototype Grad-CAM map on its rescaled input
# (image_weight=0.5), then move the last axis to position 1 so make_grid
# receives channel-first data.
tact_cams_on_image = np.stack([
    show_cam_on_image(
        get_rgb_image(base_x[idx].numpy()),
        tact_grayscale_cam[idx],
        use_rgb=True,
        image_weight=0.5,
    )
    for idx in range(base_x.size(0))
]).transpose(0, 3, 1, 2)

show_image(
    torchvision.utils.make_grid(
        torch.as_tensor(tact_cams_on_image),
        nrow=4,
        padding=2,
    ).numpy()
)
plt.show()