#### Setup


In [1]:
import torch
from data_handlers import Load_ImageNet100
from overcomplete.models import DinoV2, ViT, ResNet
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
import os
import matplotlib.pyplot as plt
from einops import rearrange

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## ImageNet

#### Train SAE

Set Parameters

In [2]:
model = ResNet(device="cuda")
concepts = 2048*5
epochs = 20
save_dir = "results/ResNet_ImageNet100_BatchCollapse"
dataloader = Load_ImageNet100(transform=model.preprocess, batch_size=2000, shuffle=True)

  from .autonotebook import tqdm as notebook_tqdm


Execute Flow

In [3]:
for image, _ in dataloader:
    image = image.to(device)
    Activations = model.forward_features(image)
    print("Raw Activation Shape: ", Activations.shape)
    Activations = rearrange(Activations, 'n c w h -> (n w h) c')
    print("Rearranged Activation Shape: ", Activations.shape)
    break

Raw Activation Shape:  torch.Size([2000, 2048, 7, 7])
Rearranged Activation Shape:  torch.Size([98000, 2048])


In [4]:
sae = TopKSAE(Activations.shape[-1], nb_concepts=concepts, top_k=5, device='cuda')

dataloader = torch.utils.data.DataLoader(TensorDataset(Activations), batch_size=1024, shuffle=True)
optimizer = torch.optim.Adam(sae.parameters(), lr=5e-4)

def criterion(x, x_hat, pre_codes, codes, dictionary):
  mse = (x - x_hat).square().mean()
  return mse

logs = train_sae(sae, dataloader, criterion, optimizer, nb_epochs=epochs, device='cuda')

Epoch[1/20], Loss: 0.1700, R2: 0.1113, L0: 5.0000, Dead Features: 0.7%, Time: 2.2167 seconds
Epoch[2/20], Loss: 0.1179, R2: 0.3830, L0: 5.0000, Dead Features: 98.7%, Time: 2.0666 seconds
Epoch[3/20], Loss: 0.0971, R2: 0.4914, L0: 5.0000, Dead Features: 96.7%, Time: 1.9005 seconds
Epoch[4/20], Loss: 0.0908, R2: 0.5244, L0: 5.0000, Dead Features: 97.0%, Time: 1.9680 seconds
Epoch[5/20], Loss: 0.0877, R2: 0.5407, L0: 5.0000, Dead Features: 97.9%, Time: 1.9141 seconds
Epoch[6/20], Loss: 0.0852, R2: 0.5544, L0: 5.0000, Dead Features: 98.1%, Time: 2.0335 seconds
Epoch[7/20], Loss: 0.0828, R2: 0.5661, L0: 5.0000, Dead Features: 98.1%, Time: 1.8890 seconds
Epoch[8/20], Loss: 0.0808, R2: 0.5771, L0: 5.0000, Dead Features: 98.1%, Time: 1.8947 seconds
Epoch[9/20], Loss: 0.0792, R2: 0.5850, L0: 5.0000, Dead Features: 97.8%, Time: 1.9052 seconds
Epoch[10/20], Loss: 0.0779, R2: 0.5919, L0: 5.0000, Dead Features: 97.7%, Time: 2.0414 seconds
Epoch[11/20], Loss: 0.0768, R2: 0.5973, L0: 5.0000, Dead Fea

In [5]:
# now the funny part, we have access to 4 functions that allow us to inspect the concepts,
# let use them to understand a bit more the top 3 concepts !

sae = sae.eval()

with torch.no_grad():
  pre_codes, codes = sae.encode(Activations)


print("Raw Code Shape: ", codes.shape)
codes = rearrange(codes, '(n w h) c -> n w h c', w=7, h=7)
print("Rearranged Code Shape: ", codes.shape)


Raw Code Shape:  torch.Size([98000, 10240])
Rearranged Code Shape:  torch.Size([2000, 7, 7, 10240])


In [6]:
# Compute per-concept L1 norm across all spatial positions and batch
codes_flat = codes.abs().sum(dim=(1, 2))        # (32, D)
concept_strength = codes_flat.sum(dim=0)        # (D,)

topk = int((1 - logs['dead_features'][-1]) * concepts)
top_concepts = torch.argsort(concept_strength, descending=True)[:topk].to(device)

print(f"Top-{topk} Concepts by L1 Norm")

# first, lets use the simple overlay to get a broad sense of what's going on
for concept_id in top_concepts:
  print('Concept', concept_id.item())
  overlay_top_heatmaps(image, codes, concept_id=concept_id.item())
  os.makedirs(save_dir, exist_ok=True)
  filename = f"concept_{concept_id.item()}.png"
  filepath = os.path.join(save_dir, filename)
  plt.savefig(filepath, bbox_inches='tight', dpi=300)
  plt.close()  

Top-394 Concepts by L1 Norm
Concept 3473
Concept 8566
Concept 7744
Concept 10084
Concept 5120
Concept 7208
Concept 8460
Concept 6873
Concept 9229
Concept 3277
Concept 4669
Concept 5855
Concept 641
Concept 9758
Concept 8416
Concept 4609
Concept 3515
Concept 2735
Concept 515
Concept 7374
Concept 1116
Concept 1473
Concept 9447
Concept 907
Concept 3199
Concept 7279
Concept 6887
Concept 2395
Concept 3014
Concept 431
Concept 7201
Concept 2048
Concept 9267
Concept 797
Concept 1853
Concept 3165
Concept 9129
Concept 5575
Concept 1543
Concept 6476
Concept 980
Concept 10039
Concept 3044
Concept 3459
Concept 9660
Concept 6490
Concept 330
Concept 10220
Concept 9375
Concept 10169
Concept 4010
Concept 6421
Concept 1393
Concept 4049
Concept 2454
Concept 1585
Concept 7946
Concept 7179
Concept 7700
Concept 7380
Concept 216
Concept 2937
Concept 9600
Concept 5831
Concept 1145
Concept 5194
Concept 1725
Concept 2106
Concept 106
Concept 618
Concept 997
Concept 9224
Concept 9870
Concept 3830
Concept 8443
Conc