#### Setup


In [26]:
import torch
from data_handlers import Load_ImageNet100
from overcomplete.models import DinoV2, ViT, ResNet
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
import os
import matplotlib.pyplot as plt
from einops import rearrange
from universal_trainer import train_usae
from activation_generator import Load_activation_dataloader
import torch.nn as nn

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Train a USAE

Define Models and Generate Their Activations

In [27]:
# Instantiate both backbones on the device chosen at the top of the notebook.
# `device.type` is the plain string "cuda" or "cpu", so this keeps the same
# argument type as before while no longer failing on machines without CUDA.
models = {
    "DinoV2": DinoV2(device=device.type),
    "ViT": ViT(device=device.type),
}

activation_dir = "activations/ImageNet100_Dino_ViT"
image_loader = Load_ImageNet100(transform=None, batch_size=256, shuffle=True)

# NOTE(review): generate=False is passed together with save_dir; presumably the
# loader then reads activations already stored under `activation_dir` instead
# of recomputing them -- confirm against activation_generator.
activations_dataloader = Load_activation_dataloader(
    models=models,
    image_dataloader=image_loader,
    max_seq_len=256,
    save_dir=activation_dir,
    generate=False,
    # Pattern that merges the `n` and `t` axes into a single leading axis.
    rearrange_string='n t d -> (n t) d',
)

Using cache found in C:\Users\sproj_ha/.cache\torch\hub\facebookresearch_dinov2_main


SAE Visualization

In [31]:
# Hyperparameters: dictionary size, number of training epochs, learning rate.
concepts, epochs, lr = 50, 30, 1e-3

# Pull a single batch from the activation loader so its contents can be inspected.
sample = next(iter(activations_dataloader))

In [32]:
# Show the activation tensor shape stored for each model in the sampled batch.
for activation_key in ('activations_DinoV2', 'activations_ViT'):
    print(sample[activation_key].shape)

torch.Size([1, 65536, 384])
torch.Size([1, 65536, 768])


In [33]:
# One TopK sparse autoencoder and one Adam optimizer per backbone, keyed by
# model name. Everything is placed on the notebook-level `device` (string form)
# rather than a hard-coded 'cuda', so the cell also runs without a GPU.
SAEs = {}
optimizers = {}

for key in models:
    # The SAE input width is the last dimension of that model's activations.
    input_dim = sample[f"activations_{key}"].shape[-1]
    SAEs[key] = TopKSAE(input_dim, nb_concepts=concepts, top_k=3, device=device.type)
    optimizers[key] = torch.optim.Adam(SAEs[key].parameters(), lr=lr)

criterion = nn.MSELoss()

train_usae(
    names=list(models.keys()),
    models=SAEs,
    dataloader=activations_dataloader,
    criterion=criterion,
    nb_epochs=epochs,
    optimizers=optimizers,
    device=device.type,
)

Epoch 1/30: 100%|██████████| 127/127 [00:54<00:00,  2.33it/s, loss=9.08]
Epoch 2/30: 100%|██████████| 127/127 [00:55<00:00,  2.27it/s, loss=8.89]
Epoch 3/30: 100%|██████████| 127/127 [00:57<00:00,  2.22it/s, loss=8.88]
Epoch 4/30: 100%|██████████| 127/127 [00:57<00:00,  2.19it/s, loss=8.82]
Epoch 5/30: 100%|██████████| 127/127 [00:57<00:00,  2.20it/s, loss=8.84]
Epoch 6/30: 100%|██████████| 127/127 [00:57<00:00,  2.21it/s, loss=8.86]
Epoch 7/30: 100%|██████████| 127/127 [00:57<00:00,  2.20it/s, loss=8.72]
Epoch 8/30: 100%|██████████| 127/127 [00:53<00:00,  2.38it/s, loss=8.88]
Epoch 9/30: 100%|██████████| 127/127 [00:47<00:00,  2.67it/s, loss=8.78]
Epoch 10/30: 100%|██████████| 127/127 [00:48<00:00,  2.62it/s, loss=8.82]
Epoch 11/30: 100%|██████████| 127/127 [00:50<00:00,  2.53it/s, loss=8.64]
Epoch 12/30: 100%|██████████| 127/127 [00:50<00:00,  2.54it/s, loss=8.69]
Epoch 13/30: 100%|██████████| 127/127 [00:51<00:00,  2.48it/s, loss=8.71]
Epoch 14/30: 100%|██████████| 127/127 [00:51<00

KeyboardInterrupt: 

In [None]:
# Now the fun part: inspect the learned concepts. For every model we rank the
# concepts by their total absolute activation on one batch, keep the strongest
# `topk`, and then save one heatmap overlay per selected concept and per model.

topk = 20
activations = next(iter(activations_dataloader))
images = activations["images"].squeeze()

codes_by_model = {}          # model name -> codes reshaped to (n, w, h, d)
top_concepts_per_model = []  # one tensor of concept ids per model

for key in models:
    sae = SAEs[key]
    Activations = activations[f'activations_{key}'].to(device)
    with torch.no_grad():
        pre_codes, codes = sae.encode(Activations.squeeze())

    # Split the flattened leading axis into (image, 16, 16); the 16 x 16 grid
    # is hard-coded here. Cached so the export loop below does not re-encode.
    codes = rearrange(codes, '(n w h) d -> n w h d', w=16, h=16)
    codes_by_model[key] = codes

    # Total absolute activation of each concept: first over the grid, then
    # over the images.
    codes_flat = codes.abs().sum(dim=(1, 2))
    concept_strength = codes_flat.sum(dim=0)
    top_concepts = torch.argsort(concept_strength, descending=True)[:topk]
    top_concepts_per_model.append(top_concepts.cpu())

# Concatenate the per-model picks so each model keeps its own contiguous slice.
# (Previously the slice started at the model index `i` instead of `i * topk`,
# so the second model overwrote the first and the tail stayed at zero.)
selected_concepts = torch.cat(top_concepts_per_model)

# A concept picked by several models only needs to be rendered once per model;
# dict.fromkeys de-duplicates while keeping first-seen order.
unique_concept_ids = list(dict.fromkeys(selected_concepts.tolist()))


# Save an overlay for every selected concept, for every model.
for key, codes in codes_by_model.items():
    save_dir = f"results/usae_run3/{key}_concepts"
    os.makedirs(save_dir, exist_ok=True)

    for concept_id in unique_concept_ids:
        overlay_top_heatmaps(images, codes, concept_id=concept_id)
        filename = f"concept_{concept_id}_{key}.png"
        filepath = os.path.join(save_dir, filename)
        plt.savefig(filepath, bbox_inches='tight', dpi=300)
        # Close the current figure so figures do not accumulate across the loop.
        plt.close()