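# Activation Visualization Experiments

Records, for every transformer block of a pretrained ViT-B/16 (`timm` `vit_base_patch16_224`), the token-0 output features at a fixed random subset of `NFEX` channels. The activations are grouped by class label, for the first 1000 images of the CIFAR-10 and SVHN test splits.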



In [35]:
!pip install torch torchvision timm matplotlib pandas seaborn



In [36]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
import timm
import matplotlib.pyplot as plt
import numpy as np
import random
from tqdm import tqdm
import seaborn as sns
import pandas as pd

In [37]:
def seed_everything(seed=42):
    """Seed every RNG the notebook draws from.

    Covers Python's ``random`` module, NumPy's global generator, and PyTorch
    on the CPU and on all CUDA devices, all with the same ``seed``.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


seed_everything()


In [38]:
# Run on the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda:0


In [39]:
activation_dict = {}  # block name -> {"tmp": list of per-sample feature vectors for the current batch, ...}
hooks = []            # handles returned by register_forward_hook, kept so the hooks can be removed
NFEX = 10             # number of feature channels sampled per block
random_indices = {}   # {"x": array of NFEX channel indices}; drawn once and shared by every block

def get_hook(name):
    """Build a forward hook that records a fixed random channel subset of token 0.

    The returned callable has the ``(module, inputs, output)`` signature used by
    ``register_forward_hook``. On each call it appends one vector of length
    ``NFEX`` per sample, in batch order, to ``activation_dict[name]["tmp"]``.
    Entry ``i`` of that buffer therefore lines up with sample ``i`` of the batch.

    Args:
        name: Key into ``activation_dict``. The caller must have created
            ``activation_dict[name] = {"tmp": []}`` before the first forward pass.

    Returns:
        The hook function.
    """
    def hook(module, inputs, output):
        # output is indexed as (batch, token, feature); token 0 is presumably the
        # CLS token for ViT blocks -- TODO confirm for other backbones.
        activations = output[:, 0, :].detach().cpu().numpy()
        if "x" not in random_indices:
            # Draw the channel subset once, on the very first call, so every block
            # and every batch reports the same NFEX channels.
            random_indices["x"] = np.random.choice(activations.shape[1], size=NFEX, replace=False)
        selected = activations[:, random_indices["x"]]
        # Append per sample directly to the buffer. The buffer starts empty, so
        # indexing it (the old `["tmp"][i].append`) raised IndexError.
        activation_dict[name]["tmp"].extend(selected)
    return hook

In [40]:
# Pretrained ViT used only as a frozen feature extractor. Only the outputs of
# model.blocks are recorded (via the hooks below), so the 10-class head is not inspected.
model = (
    timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=10)
    .eval()
    .requires_grad_(False)
)

# One recording hook per transformer block. Keep the handles so they can be removed later.
for idx, block in enumerate(model.blocks):
    name = f"block_{idx}"
    activation_dict[name] = {"tmp": []}
    hooks.append(block.register_forward_hook(get_hook(name)))

model = model.to(device)

In [41]:
# Resize to the 224x224 ViT input, then map pixels from [0, 1] to [-1, 1] (mean = std = 0.5 per channel).
# NOTE(review): assumed to match the pretrained model's normalisation -- confirm via model.pretrained_cfg.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

N_SAMPLES = 1000  # number of images taken from the start of each test split
BATCH_SIZE = 32


def make_subset_loader(dataset, n_samples=N_SAMPLES, batch_size=BATCH_SIZE):
    """Return an in-order DataLoader over the first ``n_samples`` items of ``dataset``.

    If the dataset has fewer than ``n_samples`` items, all of them are used.
    """
    subset = Subset(dataset, range(min(n_samples, len(dataset))))
    return DataLoader(subset, batch_size=batch_size, shuffle=False)


cifar = datasets.CIFAR10(root="./data", train=False, transform=transform, download=True)
svhn = datasets.SVHN(root="./data", split='test', transform=transform, download=True)

cifar_loader = make_subset_loader(cifar)
svhn_loader = make_subset_loader(svhn)


In [42]:
def _stack_class_features(per_sample):
    """Stack one class's per-sample vectors into an (n_features, n_samples) array."""
    if not per_sample:
        # Class absent from the subset: keep the key with a correctly oriented empty array.
        return np.empty((NFEX, 0), dtype=np.float32)
    return np.stack(per_sample, axis=0).T


def extract_features(dataloader, num_classes=10):
    """Run ``model`` over ``dataloader`` and group the hooked block features by label.

    The forward hooks registered on ``model.blocks`` append one feature vector
    per sample to ``activation_dict[block]["tmp"]`` during each forward pass.
    After every batch those vectors are moved into per-class lists keyed by
    the integer target.

    Args:
        dataloader: Iterable of ``(images, targets)`` batches, where ``targets``
            is a tensor of integer labels.
        num_classes: Labels ``0 .. num_classes - 1`` always get an entry, even
            when no sample of that class is seen. Other labels are added as
            encountered. Defaults to 10.

    Returns:
        A new dict ``{block_name: {label: ndarray}}``. Each array is the class's
        stacked feature vectors, transposed to ``(n_features, n_samples)``.
        The result is independent of any previous call's result. As before,
        the same per-class arrays are also left in the global ``activation_dict``
        (without the ``"tmp"`` buffer).

    Raises:
        RuntimeError: if a block's hook did not record exactly one vector per
            sample of a batch (e.g. the hook is missing or registered twice).
    """
    # Fresh, empty hook buffer per block. Per-class lists are kept locally.
    grouped = {}
    for key in activation_dict:
        activation_dict[key] = {"tmp": []}
        grouped[key] = {cls: [] for cls in range(num_classes)}

    for images, targets in tqdm(dataloader, desc="Processing"):
        with torch.no_grad():
            model(images.to(device))  # logits unused; the block hooks fill the "tmp" buffers

        labels = targets.tolist()
        for key, buffers in activation_dict.items():
            batch_features = buffers["tmp"]
            if len(batch_features) != len(labels):
                raise RuntimeError(
                    f"{key}: hook recorded {len(batch_features)} activations "
                    f"for a batch of {len(labels)} samples"
                )
            for label, features in zip(labels, batch_features):
                grouped[key].setdefault(label, []).append(features)
            buffers["tmp"] = []

    # Stack each class exactly once. The old code re-stacked and re-transposed once
    # per sample of the last batch, and also tried to stack the empty "tmp" list.
    result = {}
    for key, per_class in grouped.items():
        stacked = {cls: _stack_class_features(feats) for cls, feats in per_class.items()}
        activation_dict[key] = dict(stacked)  # keep the previous global side effect
        result[key] = stacked                 # independent dict: later calls cannot overwrite it
    return result

In [45]:
# Run both test subsets through the hooked model. Each call returns a dict keyed by
# block name, whose values map class label -> features for that dataset.
# NOTE(review): the two results must not alias -- confirm extract_features returns a
# fresh dict per call; otherwise the SVHN call overwrites cifar_act.
print("CIFAR-10")
cifar_act = extract_features(cifar_loader)
print("SVHN")
svhn_act = extract_features(svhn_loader)


CIFAR-10


Processing: 100%|██████████| 32/32 [00:11<00:00,  2.83it/s]


SVHN


Processing: 100%|██████████| 32/32 [00:11<00:00,  2.78it/s]
