#### Setup


In [1]:
import torch
from lib.data_handlers import Load_ImageNet100
from overcomplete.models import DinoV2, ViT, ResNet
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
import os
import matplotlib.pyplot as plt
from einops import rearrange

# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

## ImageNet

#### Train SAE

Set Parameters

In [2]:
# Backbone whose patch activations are analysed. It is built on the shared
# `device` (instead of a hard-coded "cuda") so it lives on the same device the
# image batches are moved to and the notebook still starts on CPU-only hosts.
model = ViT(device=device)

# Number of dictionary atoms (concepts) the SAE learns.
# NOTE(review): the extracted activations are 768-wide, so 786 looks like a
# transposition of 768 (i.e. an intended 8x over-complete dictionary). The
# value is left untouched because changing it alters the results — confirm.
concepts = 786 * 8

# Number of SAE training epochs.
epochs = 50

# Directory that receives the per-concept figures.
save_dir = "results/ViT_ImageNet100"

# ImageNet-100 batches, preprocessed with the backbone's own transform.
dataloader = Load_ImageNet100(transform=model.preprocess, batch_size=2000, shuffle=True)

  from .autonotebook import tqdm as notebook_tqdm


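Execute Flow: extract ViT patch activations from the first batch, train the SAE on them, then visualize the strongest concepts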

In [3]:
# Only the first batch is used; its patch activations are what the rest of the
# notebook works on. `image` is kept because the visualisation needs it later.
image, _ = next(iter(dataloader))
image = image.to(device)

# Pure feature extraction: no autograd graph is needed, and recording one for
# a whole batch of backbone activations would only waste memory and leave the
# activations attached to the backbone.
with torch.no_grad():
    Activations = model.forward_features(image)
print("Raw Activation Shape: ", Activations.shape)

# Merge the image and token axes so that every patch token is one sample.
Activations = rearrange(Activations, 'n t d -> (n t) d')
print("Rearranged Activation Shape: ", Activations.shape)

Raw Activation Shape:  torch.Size([2000, 196, 768])
Rearranged Activation Shape:  torch.Size([392000, 768])


In [None]:
# Top-k sparse autoencoder over the patch activations; `top_k=3` is passed so
# that (per the library's top-k scheme) only a few concepts stay active per
# token. It is created on the shared `device` rather than a hard-coded 'cuda'.
sae = TopKSAE(Activations.shape[-1], nb_concepts=concepts, top_k=3, device=device)

# A dedicated name is used for the activation loader: re-binding `dataloader`
# here would clobber the ImageNet loader, and re-running the extraction cell
# afterwards would iterate over activations instead of (image, label) pairs.
sae_dataloader = DataLoader(TensorDataset(Activations), batch_size=1024, shuffle=True)
optimizer = torch.optim.Adam(sae.parameters(), lr=5e-4)


def criterion(x, x_hat, pre_codes, codes, dictionary):
  """Return the mean squared reconstruction error between `x` and `x_hat`.

  `pre_codes`, `codes` and `dictionary` are part of the signature but are not
  used: the loss is pure reconstruction, with no extra sparsity penalty.
  """
  return (x - x_hat).square().mean()


logs = train_sae(sae, sae_dataloader, criterion, optimizer, nb_epochs=epochs, device=device)

In [None]:
# Now the fun part: encode every patch token with the trained SAE, then fold
# the flat token axis back into a 14x14 grid per image (196 tokens per image)
# so the codes can be inspected spatially.

# Switch to evaluation mode before encoding.
sae = sae.eval()

# Encoding is inference only, so gradient tracking is disabled.
with torch.no_grad():
  pre_codes, codes = sae.encode(Activations)
print("Raw Code Shape: ", codes.shape)

# (n * w * h, d) -> (n, w, h, d)
codes = rearrange(
  codes,
  '(n w h) d -> n w h d',
  w=14,
  h=14,
)
print("Rearranged Code Shape: ", codes.shape)

In [None]:
# Rank the concepts by their total L1 activation: sum |code| over the spatial
# grid of each image, then over all images in the batch.
codes_flat = codes.abs().sum(dim=(1, 2))        # (N, D): per-image strength
concept_strength = codes_flat.sum(dim=0)        # (D,): strength over the batch

# Keep as many concepts as the last logged `dead_features` value leaves over.
# NOTE(review): this presumes `dead_features` is a fraction in [0, 1] — confirm
# against what `train_sae` records.
topk = int((1 - logs['dead_features'][-1]) * concepts)
top_concepts = torch.argsort(concept_strength, descending=True)[:topk].to(device)

print(f"Top-{topk} Concepts by L1 Norm")

# The output directory does not change between iterations: create it once.
os.makedirs(save_dir, exist_ok=True)

# First, use the simple overlay to get a broad sense of what each concept
# responds to. `.tolist()` converts the indices to Python ints in one go,
# instead of calling `.item()` several times per iteration.
for concept_id in top_concepts.tolist():
  print('Concept', concept_id)
  overlay_top_heatmaps(image, codes, concept_id=concept_id)
  filepath = os.path.join(save_dir, f"concept_{concept_id}.png")
  plt.savefig(filepath, bbox_inches='tight', dpi=300)
  # Close the figure so memory does not grow with the number of concepts.
  plt.close()