In [31]:
import torch
import torchvision
import torch.nn as nn
import os
import matplotlib.pyplot as plt
from einops import rearrange
from tqdm import tqdm
from PIL import Image
from torchvision import transforms

from overcomplete.models import DinoV2, ViT, ResNet, ViT_Large, SigLIP
from overcomplete.sae import TopKSAE
from overcomplete.visualization.plot_utils import (interpolate_cv2, show)
from overcomplete.visualization.cmaps import VIRIDIS_ALPHA


# Prefer the GPU when one is visible to torch; otherwise run everything on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Release cached allocator blocks that may linger from earlier runs in this kernel.
torch.cuda.empty_cache()


In [32]:
# Load the backbone on the device selected in the setup cell, so the notebook
# also runs (slowly) on CPU-only machines instead of crashing on "cuda".
model = ViT(device=str(device))

# Load the TopK SAE trained on this backbone's features.
# 768 is the SAE input width and must match the backbone's token dimension;
# the dictionary is 8x that width (6144 concepts), with top_k=16 active codes.
sae = TopKSAE(768, nb_concepts=768 * 8, top_k=16, device=str(device))
# map_location lets a checkpoint that was saved from GPU tensors load on CPU too.
checkpoint = torch.load("./models/ViT_MLP.pt", map_location=device)
sae.load_state_dict(checkpoint["ViT"])

# Dataset root per domain; each root holds one subfolder per class.
domain_roots = {
    "inet": "./datasets/imagenet100/train.X1",
    "sketch": "./datasets/imagenetsketch/sketch",
}


In [33]:
def visualize_class_on_concept(concept, class_idx, model, sae, domain_roots):
    transform = transforms.ToTensor()
    

    for domain, dir in domain_roots.items(): 
        classes = os.listdir(dir)
        class_dir = os.path.join(dir, classes[class_idx])

        images = [Image.open(os.path.join(path)) for path in os.listdir(class_dir)]

        for img in images:
            img = model.preprocess(img).unsqueeze(dim=0)
            x = model.forward_features(img.to(device))
            x = rearrange(x, 'n t d -> (n t) d')
            _, z = sae.encode(x)
            z = rearrange(z, '(n w h) d -> n w h d', w=14, h=14)
            width, height = img.shape[-1], img.shape[-2]
            heatmap = interpolate_cv2(z[:, :, :, concept], (width, height))
            show(x)
            show(heatmap, cmap=VIRIDIS_ALPHA, alpha=1.0)
            plt.show()

## Testing: visualize one SAE concept on one class in each domain

In [None]:
# concept: SAE unit to inspect; it must be < nb_concepts (768 * 8).
# class_idx: position in the sorted class folders of each domain root.
# TODO: choose the class to inspect.
visualize_class_on_concept(
    concept=606,
    class_idx=0,
    model=model,
    sae=sae,
    domain_roots=domain_roots,
)