#### Setup


In [1]:
import torch
from data_handlers import Load_ImageNet100
from overcomplete.models import DinoV2, ViT, ResNet
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
import os
import matplotlib.pyplot as plt
from einops import rearrange

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fix the global torch RNG so torch-driven randomness (e.g. DataLoader shuffling,
# and presumably SAE weight initialization) is reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

## ImageNet

#### Train SAE

Set Parameters

In [15]:
# Experiment parameters. The backbone goes on the same `device` chosen in the setup
# cell (instead of a hard-coded "cuda") so the notebook also runs without a GPU.
model = ViT(device=device)
concepts = 384                         # SAE dictionary size (passed as nb_concepts)
epochs = 30                            # number of SAE training epochs
save_dir = "results/ViT_ImageNet100"   # directory where concept heatmaps are saved
dataloader = Load_ImageNet100(transform=model.preprocess, batch_size=2000, shuffle=True)

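Execute Flow: extract ViT patch activations from one ImageNet-100 batch, train the SAE on them, then visualize the learned concepts.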

In [11]:
# Take a single batch from the ImageNet-100 loader (the same first batch the
# original loop-and-break used) and extract the backbone's token features.
image, _ = next(iter(dataloader))
image = image.to(device)

# Feature extraction only: gradients are never needed here, and disabling autograd
# keeps a large batch from retaining a computation graph in memory.
with torch.no_grad():
    Activations = model.forward_features(image)
print("Raw Activation Shape: ", Activations.shape)

# Flatten (images, tokens) into rows so every token becomes one SAE training sample.
Activations = rearrange(Activations, 'n t d -> (n t) d')
print("Rearranged Activation Shape: ", Activations.shape)

Raw Activation Shape:  torch.Size([2000, 196, 768])
Rearranged Activation Shape:  torch.Size([392000, 768])


In [16]:
# Train a TopK sparse autoencoder on the flattened token activations.
# The activation loader gets its own name so the ImageNet `dataloader` from the
# parameters cell is not clobbered; otherwise re-running the extraction cell would
# try to unpack 1-tuples of activations as (image, label).
sae = TopKSAE(Activations.shape[-1], nb_concepts=concepts, top_k=3, device=device)

# detach() guarantees SAE training never backpropagates into the backbone.
activation_loader = DataLoader(TensorDataset(Activations.detach()), batch_size=1024, shuffle=True)
optimizer = torch.optim.Adam(sae.parameters(), lr=5e-4)

def criterion(x, x_hat, pre_codes, codes, dictionary):
  """Reconstruction loss: mean squared error between inputs and reconstructions.

  pre_codes, codes and dictionary are accepted only to match the call signature
  presumably used by train_sae; they are unused because TopK (top_k=3) already
  enforces sparsity, so no extra sparsity penalty is added.
  """
  mse = (x - x_hat).square().mean()
  return mse

logs = train_sae(sae, activation_loader, criterion, optimizer, nb_epochs=epochs, device=device)

Epoch[1/30], Loss: 2.1239, R2: 0.4644, L0: 3.0000, Dead Features: 0.0%, Time: 1.4676 seconds
Epoch[2/30], Loss: 1.6053, R2: 0.5952, L0: 3.0000, Dead Features: 48.2%, Time: 1.5791 seconds
Epoch[3/30], Loss: 1.5564, R2: 0.6075, L0: 3.0000, Dead Features: 37.0%, Time: 1.4667 seconds
Epoch[4/30], Loss: 1.5343, R2: 0.6131, L0: 3.0000, Dead Features: 54.9%, Time: 1.4838 seconds
Epoch[5/30], Loss: 1.5231, R2: 0.6159, L0: 3.0000, Dead Features: 50.0%, Time: 1.4987 seconds
Epoch[6/30], Loss: 1.5171, R2: 0.6175, L0: 3.0000, Dead Features: 54.4%, Time: 1.6841 seconds
Epoch[7/30], Loss: 1.5113, R2: 0.6189, L0: 3.0000, Dead Features: 56.0%, Time: 1.7773 seconds
Epoch[8/30], Loss: 1.5066, R2: 0.6201, L0: 3.0000, Dead Features: 57.0%, Time: 1.7212 seconds
Epoch[9/30], Loss: 1.4991, R2: 0.6220, L0: 3.0000, Dead Features: 64.6%, Time: 1.8111 seconds
Epoch[10/30], Loss: 1.4912, R2: 0.6240, L0: 3.0000, Dead Features: 64.1%, Time: 1.7570 seconds
Epoch[11/30], Loss: 1.4882, R2: 0.6247, L0: 3.0000, Dead Fea

In [17]:
# Now the fun part: the four imported visualization helpers let us inspect the
# learned concepts. First encode every token activation with the trained SAE, then
# fold the flat codes back into a per-image spatial grid so they can be passed to
# the visualization helpers together with `image`.

sae = sae.eval()

with torch.no_grad():
  pre_codes, codes = sae.encode(Activations)

print("Raw Code Shape: ", codes.shape)

# Derive the token grid from the data instead of hard-coding 14x14, so a backbone
# with a different patch count still works. Assumes every token is a patch on a
# square grid (no CLS/register tokens), as in the 196-token output above.
num_images = image.shape[0]
tokens_per_image = codes.shape[0] // num_images
grid_side = int(round(tokens_per_image ** 0.5))
assert num_images * grid_side * grid_side == codes.shape[0], (
  f"expected a square token grid, got {tokens_per_image} tokens per image"
)
codes = rearrange(codes, '(n w h) d -> n w h d', w=grid_side, h=grid_side)
print("Rearranged Code Shape: ", codes.shape)


Raw Code Shape:  torch.Size([392000, 384])
Rearranged Code Shape:  torch.Size([2000, 14, 14, 384])


In [18]:
# Rank concepts by their total activation mass: the L1 norm of each concept's codes
# summed over every image and spatial position.
concept_strength = codes.abs().sum(dim=(0, 1, 2))   # (D,)

# Visualize every concept that actually fires on this batch. Counting nonzero
# strength directly on these codes is exact; the training log's last-epoch
# dead-feature ratio was measured while weights were still changing, and
# int(fraction * concepts) can truncate a float product down by one.
topk = int((concept_strength > 0).sum().item())
top_concepts = torch.argsort(concept_strength, descending=True)[:topk]

print(f"Top-{topk} Concepts by L1 Norm")

# First, use the simple overlay to get a broad sense of what each concept captures,
# saving one figure per concept.
os.makedirs(save_dir, exist_ok=True)
for concept_id in top_concepts.tolist():
  print('Concept', concept_id)
  overlay_top_heatmaps(image, codes, concept_id=concept_id)
  filepath = os.path.join(save_dir, f"concept_{concept_id}.png")
  plt.savefig(filepath, bbox_inches='tight', dpi=300)
  plt.close()  # free the figure; many concepts are plotted in this loop

Top-108 Concepts by L1 Norm
Concept 234
Concept 150
Concept 59
Concept 110
Concept 127
Concept 365
Concept 13
Concept 78
Concept 254
Concept 8
Concept 354
Concept 275
Concept 45
Concept 151
Concept 293
Concept 158
Concept 20
Concept 195
Concept 170
Concept 89
Concept 72
Concept 217
Concept 277
Concept 373
Concept 209
Concept 192
Concept 169
Concept 79
Concept 153
Concept 270
Concept 219
Concept 320
Concept 65
Concept 19
Concept 290
Concept 274
Concept 172
Concept 5
Concept 96
Concept 210
Concept 244
Concept 358
Concept 379
Concept 301
Concept 36
Concept 61
Concept 164
Concept 50
Concept 104
Concept 376
Concept 356
Concept 179
Concept 364
Concept 327
Concept 68
Concept 264
Concept 93
Concept 49
Concept 207
Concept 242
Concept 12
Concept 310
Concept 269
Concept 4
Concept 148
Concept 218
Concept 157
Concept 258
Concept 268
Concept 241
Concept 40
Concept 294
Concept 198
Concept 44
Concept 307
Concept 149
Concept 42
Concept 58
Concept 90
Concept 299
Concept 176
Concept 288
Concept 206
Conce