#### Setup


In [1]:
import torch
import torchvision
from lib.data_handlers import Load_ImageNet100, Load_PACS
from overcomplete.models import DinoV2, ViT, ResNet, ViT_Large, SigLIP
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
import os
import matplotlib.pyplot as plt
from einops import rearrange
from lib.universal_trainer import train_usae
from lib.activation_generator import Load_activation_dataloader
import torch.nn as nn
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from lib.eval import evaluate_models

# Pick the compute device once: the GPU when CUDA is usable, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Hand cached, currently unused GPU blocks back to the driver before the heavy cells run.
torch.cuda.empty_cache()

### Train a USAE

#### Define the models and generate their activations

In [2]:
# Backbones whose activations the universal SAE is trained on. Both entries are
# separate DinoV2 instances; the dictionary keys ("Dino1", "Dino2") are reused
# later to look up each model's activations as `activations_<key>`.
# The models are placed on the device chosen in the setup cell (`device`) rather
# than on a hard-coded "cuda", so the notebook also starts on a CPU-only machine.
models = {
    "Dino1": DinoV2(device=device.type),
    "Dino2": DinoV2(device=device.type),
}

# Directory handed to the activation loader as `save_dir`.
activation_dir = "activations/ImageNet100_Dino_Dino"
image_loader = Load_ImageNet100(transform=None, batch_size=256, shuffle=True)

# NOTE(review): `generate=True` is passed unconditionally, and the captured run
# shows this cell taking ~13 minutes. If `Load_activation_dataloader` can reuse
# what is already stored in `save_dir` when `generate=False` (not visible from
# this notebook), guard the flag so Restart & Run All stays affordable.
activations_dataloader = Load_activation_dataloader(
    models=models,
    image_dataloader=image_loader,
    max_seq_len=196,  # 196 = 14 * 14, the grid the codes are reshaped to when plotting
    save_dir=activation_dir,
    generate=True,
    rearrange_string='n t d -> (n t) d',  # flatten (image, token) into one row axis
)

Using cache found in C:\Users\sproj_ha/.cache\torch\hub\facebookresearch_dinov2_main
Using cache found in C:\Users\sproj_ha/.cache\torch\hub\facebookresearch_dinov2_main
Processing Batches: 100%|██████████| 508/508 [13:31<00:00,  1.60s/it]


#### SAE training

In [3]:
# Training hyper-parameters shared by every SAE below.
lr = 3e-4            # base learning rate handed to Adam
epochs = 50          # number of passes over the activation dataloader
concepts = 8 * 768   # dictionary size (presumably 8x the backbone width — TODO confirm)

# Pull one batch so the SAE input width can be read from the activation shapes.
sample = next(iter(activations_dataloader))

In [4]:
# Build one TopK sparse autoencoder, optimizer and LR schedule per backbone,
# then train them jointly with `train_usae`.

# Schedule constants. They are named once so that the warm-up length used by
# LinearLR and the SequentialLR milestone cannot drift apart, and so that the
# warm-up start factor follows `lr` instead of a second hard-coded 3e-4.
warmup_steps = 25   # scheduler steps spent ramping the LR up linearly
min_lr = 1e-6       # LR at the start of warm-up and floor of the cosine phase

SAEs = {}
optimizers = {}
schedulers = {}

for key in models:
    # Input width is read from the last axis of this model's activations in the
    # sample batch drawn in the previous cell.
    SAEs[key] = TopKSAE(
        sample[f"activations_{key}"].shape[-1],
        nb_concepts=concepts,
        top_k=64,
        device=device.type,
    )
    optimizers[key] = torch.optim.Adam(SAEs[key].parameters(), lr=lr)

    # Linear warm-up from `min_lr` up to `lr`, followed by cosine annealing.
    warmup_scheduler = LinearLR(
        optimizers[key],
        start_factor=min_lr / lr,
        end_factor=1.0,
        total_iters=warmup_steps,
    )
    # NOTE(review): how often `train_usae` calls `scheduler.step()` is not
    # visible here. If it steps once per epoch, only `epochs - warmup_steps`
    # cosine steps ever run, so with T_max=epochs the LR never reaches
    # `min_lr`; if it steps once per batch, both phases finish within the
    # first epoch and the cosine curve keeps cycling. Confirm and adjust T_max.
    cosine_scheduler = CosineAnnealingLR(optimizers[key], T_max=epochs, eta_min=min_lr)
    schedulers[key] = SequentialLR(
        optimizers[key],
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_steps],
    )

# Reconstruction loss: mean absolute error, averaged over all elements.
criterion = nn.L1Loss(reduction="mean")

train_usae(
    names=list(models.keys()),
    models=SAEs,
    dataloader=activations_dataloader,
    criterion=criterion,
    nb_epochs=epochs,
    optimizers=optimizers,
    schedulers=schedulers,
    device=device.type,
    seed=42,
)

Epoch 1/50: 100%|██████████| 508/508 [02:14<00:00,  3.77it/s, loss=1.29]



[Epoch 1] Loss: 723.6816 | Time: 134.71s | Dead Features: 0.0%


Epoch 2/50: 100%|██████████| 508/508 [02:15<00:00,  3.75it/s, loss=1.24]



[Epoch 2] Loss: 637.6887 | Time: 135.56s | Dead Features: 0.0%


Epoch 3/50: 100%|██████████| 508/508 [02:16<00:00,  3.73it/s, loss=1.21]



[Epoch 3] Loss: 617.3701 | Time: 136.08s | Dead Features: 0.0%


Epoch 4/50: 100%|██████████| 508/508 [02:16<00:00,  3.72it/s, loss=1.2] 



[Epoch 4] Loss: 609.2689 | Time: 136.44s | Dead Features: 0.0%


Epoch 5/50: 100%|██████████| 508/508 [02:15<00:00,  3.74it/s, loss=1.19]



[Epoch 5] Loss: 605.3486 | Time: 135.93s | Dead Features: 0.0%


Epoch 6/50: 100%|██████████| 508/508 [02:15<00:00,  3.75it/s, loss=1.19]



[Epoch 6] Loss: 603.2971 | Time: 135.52s | Dead Features: 0.0%


Epoch 7/50: 100%|██████████| 508/508 [02:16<00:00,  3.72it/s, loss=1.19]



[Epoch 7] Loss: 602.1533 | Time: 136.47s | Dead Features: 0.0%


Epoch 8/50: 100%|██████████| 508/508 [02:15<00:00,  3.75it/s, loss=1.19]



[Epoch 8] Loss: 601.5103 | Time: 135.51s | Dead Features: 0.0%


Epoch 9/50: 100%|██████████| 508/508 [02:15<00:00,  3.75it/s, loss=1.19]



[Epoch 9] Loss: 601.2236 | Time: 135.48s | Dead Features: 0.0%


Epoch 10/50: 100%|██████████| 508/508 [02:15<00:00,  3.74it/s, loss=1.19]



[Epoch 10] Loss: 601.1666 | Time: 135.69s | Dead Features: 0.0%


Epoch 11/50: 100%|██████████| 508/508 [02:16<00:00,  3.72it/s, loss=1.19]



[Epoch 11] Loss: 601.2270 | Time: 136.70s | Dead Features: 0.0%


Epoch 12/50: 100%|██████████| 508/508 [02:16<00:00,  3.73it/s, loss=1.19]



[Epoch 12] Loss: 601.3700 | Time: 136.11s | Dead Features: 0.0%


Epoch 13/50: 100%|██████████| 508/508 [02:16<00:00,  3.73it/s, loss=1.19]



[Epoch 13] Loss: 601.4660 | Time: 136.34s | Dead Features: 0.0%


Epoch 14/50: 100%|██████████| 508/508 [02:15<00:00,  3.74it/s, loss=1.19]



[Epoch 14] Loss: 601.6959 | Time: 135.77s | Dead Features: 0.0%


Epoch 15/50: 100%|██████████| 508/508 [02:16<00:00,  3.74it/s, loss=1.19]



[Epoch 15] Loss: 601.8830 | Time: 136.01s | Dead Features: 0.0%


Epoch 16/50: 100%|██████████| 508/508 [02:15<00:00,  3.74it/s, loss=1.19]



[Epoch 16] Loss: 602.0599 | Time: 135.79s | Dead Features: 0.0%


Epoch 17/50: 100%|██████████| 508/508 [02:16<00:00,  3.72it/s, loss=1.19]



[Epoch 17] Loss: 602.1846 | Time: 136.60s | Dead Features: 0.0%


Epoch 18/50: 100%|██████████| 508/508 [02:15<00:00,  3.75it/s, loss=1.19]



[Epoch 18] Loss: 602.3553 | Time: 135.48s | Dead Features: 0.0%


Epoch 19/50: 100%|██████████| 508/508 [02:16<00:00,  3.73it/s, loss=1.19]



[Epoch 19] Loss: 602.4556 | Time: 136.20s | Dead Features: 0.0%


Epoch 20/50: 100%|██████████| 508/508 [02:15<00:00,  3.75it/s, loss=1.19]



[Epoch 20] Loss: 602.5566 | Time: 135.51s | Dead Features: 0.0%


Epoch 21/50: 100%|██████████| 508/508 [02:14<00:00,  3.77it/s, loss=1.19]



[Epoch 21] Loss: 602.6189 | Time: 134.84s | Dead Features: 0.0%


Epoch 22/50: 100%|██████████| 508/508 [02:14<00:00,  3.77it/s, loss=1.19]



[Epoch 22] Loss: 602.7692 | Time: 134.76s | Dead Features: 0.0%


Epoch 23/50: 100%|██████████| 508/508 [02:13<00:00,  3.80it/s, loss=1.19]



[Epoch 23] Loss: 602.8837 | Time: 133.73s | Dead Features: 0.0%


Epoch 24/50: 100%|██████████| 508/508 [02:15<00:00,  3.74it/s, loss=1.19]



[Epoch 24] Loss: 603.0783 | Time: 135.71s | Dead Features: 0.0%


Epoch 25/50: 100%|██████████| 508/508 [02:15<00:00,  3.74it/s, loss=1.19]



[Epoch 25] Loss: 603.1920 | Time: 135.93s | Dead Features: 0.0%


Epoch 26/50: 100%|██████████| 508/508 [02:15<00:00,  3.75it/s, loss=1.19]



[Epoch 26] Loss: 603.3541 | Time: 135.43s | Dead Features: 0.0%


Epoch 27/50: 100%|██████████| 508/508 [02:14<00:00,  3.77it/s, loss=1.19]



[Epoch 27] Loss: 603.4827 | Time: 134.71s | Dead Features: 0.0%


Epoch 28/50: 100%|██████████| 508/508 [02:14<00:00,  3.76it/s, loss=1.19]



[Epoch 28] Loss: 603.5984 | Time: 134.98s | Dead Features: 0.0%


Epoch 29/50: 100%|██████████| 508/508 [02:15<00:00,  3.76it/s, loss=1.2] 



[Epoch 29] Loss: 603.7560 | Time: 135.09s | Dead Features: 0.0%


Epoch 30/50: 100%|██████████| 508/508 [02:14<00:00,  3.77it/s, loss=1.19]



[Epoch 30] Loss: 603.8077 | Time: 134.76s | Dead Features: 0.1%


Epoch 31/50: 100%|██████████| 508/508 [02:14<00:00,  3.78it/s, loss=1.2] 



[Epoch 31] Loss: 603.8772 | Time: 134.47s | Dead Features: 0.1%


Epoch 32/50: 100%|██████████| 508/508 [02:15<00:00,  3.74it/s, loss=1.19]



[Epoch 32] Loss: 603.9247 | Time: 135.70s | Dead Features: 0.1%


Epoch 33/50: 100%|██████████| 508/508 [02:15<00:00,  3.76it/s, loss=1.19]



[Epoch 33] Loss: 603.9277 | Time: 135.14s | Dead Features: 0.1%


Epoch 34/50: 100%|██████████| 508/508 [02:14<00:00,  3.77it/s, loss=1.2] 



[Epoch 34] Loss: 603.9273 | Time: 134.58s | Dead Features: 0.3%


Epoch 35/50: 100%|██████████| 508/508 [02:14<00:00,  3.78it/s, loss=1.19]



[Epoch 35] Loss: 603.8940 | Time: 134.34s | Dead Features: 0.2%


Epoch 36/50: 100%|██████████| 508/508 [02:15<00:00,  3.74it/s, loss=1.19]



[Epoch 36] Loss: 603.9122 | Time: 135.65s | Dead Features: 0.4%


Epoch 37/50: 100%|██████████| 508/508 [02:14<00:00,  3.77it/s, loss=1.19]



[Epoch 37] Loss: 603.9615 | Time: 134.89s | Dead Features: 0.3%


Epoch 38/50: 100%|██████████| 508/508 [02:16<00:00,  3.73it/s, loss=1.19]



[Epoch 38] Loss: 603.9090 | Time: 136.26s | Dead Features: 0.6%


Epoch 39/50: 100%|██████████| 508/508 [02:14<00:00,  3.77it/s, loss=1.19]



[Epoch 39] Loss: 603.8998 | Time: 134.74s | Dead Features: 0.5%


Epoch 40/50: 100%|██████████| 508/508 [02:14<00:00,  3.78it/s, loss=1.2] 



[Epoch 40] Loss: 603.9896 | Time: 134.41s | Dead Features: 0.7%


Epoch 41/50: 100%|██████████| 508/508 [02:15<00:00,  3.76it/s, loss=1.2] 



[Epoch 41] Loss: 603.9084 | Time: 135.18s | Dead Features: 0.9%


Epoch 42/50: 100%|██████████| 508/508 [02:14<00:00,  3.78it/s, loss=1.19]



[Epoch 42] Loss: 604.0214 | Time: 134.41s | Dead Features: 0.9%


Epoch 43/50: 100%|██████████| 508/508 [02:14<00:00,  3.78it/s, loss=1.2] 



[Epoch 43] Loss: 603.9670 | Time: 134.22s | Dead Features: 0.8%


Epoch 44/50: 100%|██████████| 508/508 [02:15<00:00,  3.76it/s, loss=1.19]



[Epoch 44] Loss: 604.0427 | Time: 135.17s | Dead Features: 1.0%


Epoch 45/50: 100%|██████████| 508/508 [02:15<00:00,  3.74it/s, loss=1.19]



[Epoch 45] Loss: 604.0363 | Time: 135.94s | Dead Features: 1.0%


Epoch 46/50: 100%|██████████| 508/508 [02:14<00:00,  3.78it/s, loss=1.19]



[Epoch 46] Loss: 604.1035 | Time: 134.55s | Dead Features: 1.1%


Epoch 47/50: 100%|██████████| 508/508 [02:16<00:00,  3.73it/s, loss=1.19]



[Epoch 47] Loss: 604.2782 | Time: 136.11s | Dead Features: 1.0%


Epoch 48/50: 100%|██████████| 508/508 [02:15<00:00,  3.76it/s, loss=1.2] 



[Epoch 48] Loss: 604.4348 | Time: 135.10s | Dead Features: 1.3%


Epoch 49/50: 100%|██████████| 508/508 [02:15<00:00,  3.75it/s, loss=1.2] 



[Epoch 49] Loss: 604.6152 | Time: 135.58s | Dead Features: 1.2%


Epoch 50/50: 100%|██████████| 508/508 [02:14<00:00,  3.77it/s, loss=1.2] 


[Epoch 50] Loss: 604.8196 | Time: 134.66s | Dead Features: 1.2%





In [5]:
# Persist the trained SAE weights as a single file: a dict mapping each model
# key to that SAE's state_dict.
checkpoint_path = "./models/USAE_Dino_Dino_50epoch.pt"
# torch.save does not create missing parent directories, so make sure the
# target folder exists; otherwise a fresh checkout loses a full training run.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

model_state_dicts = {name: model.state_dict() for name, model in SAEs.items()}
torch.save(model_state_dicts, checkpoint_path)

In [6]:
# Now the fun part: inspect the learned concepts.
# 1) Rank concepts by total activation strength on one batch, per model.
# 2) Save an `overlay_top_heatmaps` figure for every selected concept, once per model.

# Number of strongest concepts kept per model.
# NOTE(review): the dictionary was built with `concepts = 768 * 8`, while this
# uses 768 * 7; the value is kept as-is — confirm whether 8% of `concepts`
# was intended.
topk = int(0.08 * 768 * 7)

# Side of the square token grid the codes are reshaped to for plotting
# (14 * 14 = 196, the `max_seq_len` used when the activations were generated).
GRID_SIDE = 14

activations = next(iter(activations_dataloader))
images = activations["images"].squeeze()


def _encode_batch(key):
    """Return the SAE codes of model `key` for the inspected batch, without gradients."""
    with torch.no_grad():
        _, codes = SAEs[key].encode(activations[f"activations_{key}"].to(device).squeeze())
    return codes


# --- Pass 1: pick the strongest concepts of every model -----------------------
# The previous version wrote each model's ranking into overlapping slices
# (`selected_concepts[i:i + topk]`), so all but the first entry of model 0 were
# overwritten by model 1. Here the rankings are merged into an order-preserving
# union instead, so every model contributes its full top-k.
ordered_ids = []
seen_ids = set()
for key in models:
    codes = _encode_batch(key)
    # Total absolute activation of each concept over every row of the batch.
    # No spatial reshape is needed for a plain sum, which also removes the old
    # 16x16 reshape that disagreed with the 14x14 grid used for plotting.
    concept_strength = codes.abs().sum(dim=0)
    top_concepts = torch.argsort(concept_strength, descending=True)[:topk]
    for concept_id in top_concepts.tolist():
        if concept_id not in seen_ids:
            seen_ids.add(concept_id)
            ordered_ids.append(concept_id)

# Kept as a tensor so code that indexes `selected_concepts` keeps working.
selected_concepts = torch.tensor(ordered_ids, dtype=torch.long)

# --- Pass 2: export one heatmap overlay per (model, concept) ------------------
# Encoding does not depend on the concept, so it is done once per model instead
# of once per (concept, model) pair.
for key in models:
    codes = rearrange(_encode_batch(key), '(n w h) d -> n w h d', w=GRID_SIDE, h=GRID_SIDE)

    save_dir = f"results/usae_run11_dino_dino/{key}_concepts"
    os.makedirs(save_dir, exist_ok=True)

    for concept_id in selected_concepts.tolist():
        overlay_top_heatmaps(images, codes, concept_id=concept_id)
        filename = f"concept_{concept_id}_{key}.png"
        filepath = os.path.join(save_dir, filename)
        plt.savefig(filepath, bbox_inches='tight', dpi=300)
        plt.close()  # release the figure; hundreds are produced in this loop