#### Setup


In [7]:
import torch
import torchvision
from lib.data_handlers import Load_ImageNet100, Load_PACS
from overcomplete.models import DinoV2, ViT, ResNet, ViT_Large, SigLIP
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
import os
import matplotlib.pyplot as plt
from einops import rearrange
from lib.universal_trainer import train_usae
from lib.activation_generator import Load_activation_dataloader
import torch.nn as nn
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from lib.eval import evaluate_models
from lib.visualizer import visualize_concept
from tqdm import tqdm

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Start from a clean CUDA allocator cache (relevant when the cell is re-run in a live kernel).
torch.cuda.empty_cache()

### Train a USAE

Define Models and Generate Their Activations

In [8]:
# Backbones whose activations the SAEs are trained on. The dict keys are reused
# later to look up each model's activations as `activations_{key}`.
# `device.type` is "cuda" or "cpu", so the notebook honours the fallback chosen
# in the setup cell instead of assuming a GPU is present.
models = {
    "Dino1": DinoV2(device=device.type),
    "Dino2": DinoV2(device=device.type),
}

activation_dir = "activations/ImageNet100_Dino_Dino"
image_loader = Load_ImageNet100(transform=None, batch_size=256, shuffle=True)

# NOTE(review): generate=False presumably reuses activations already stored in
# `activation_dir` rather than recomputing them — confirm against
# lib.activation_generator before running on a fresh checkout.
activations_dataloader = Load_activation_dataloader(
    models=models,
    image_dataloader=image_loader,
    max_seq_len=196,  # 196 = 14 * 14 tokens per image
    save_dir=activation_dir,
    generate=False,
    rearrange_string='n t d -> (n t) d',  # flatten (image, token) into one row axis
)

Using cache found in C:\Users\sproj_ha/.cache\torch\hub\facebookresearch_dinov2_main
Using cache found in C:\Users\sproj_ha/.cache\torch\hub\facebookresearch_dinov2_main


SAE Training

In [9]:
# --- Training hyper-parameters -------------------------------------------
# Dictionary size shared by every SAE.
# NOTE(review): 786 may be a transposition of 768 — confirm the intended size.
concepts = 786 * 8
# Number of passes over the activation dataloader.
epochs = 50
# Peak learning rate handed to Adam.
lr = 3e-4

# Pull a single batch so later cells can read the activation width from it.
sample = next(iter(activations_dataloader))

In [10]:
WARMUP_STEPS = 25  # scheduler steps spent in the linear warm-up phase
MIN_LR = 1e-6      # LR at the start of warm-up and floor of the cosine decay

SAEs = {}
optimizers = {}
schedulers = {}

# Build one SAE, one Adam optimizer and one LR schedule per model key.
for key in models:
    # The SAE input width is the last dimension of that model's activations.
    SAEs[key] = TopKSAE(
        sample[f"activations_{key}"].shape[-1],
        nb_concepts=concepts,
        top_k=10,
        device=device.type,  # follow the CUDA/CPU choice made in the setup cell
    )
    optimizers[key] = torch.optim.Adam(SAEs[key].parameters(), lr=lr)

    # Linear warm-up from MIN_LR to `lr`. The start factor is derived from `lr`
    # so that changing the learning rate keeps the warm-up starting at MIN_LR.
    warmup_scheduler = LinearLR(
        optimizers[key],
        start_factor=MIN_LR / lr,
        end_factor=1.0,
        total_iters=WARMUP_STEPS,
    )
    # NOTE(review): T_max=epochs is kept as-is. If train_usae steps the
    # scheduler once per epoch, only `epochs - WARMUP_STEPS` steps remain after
    # warm-up and the cosine phase will not reach MIN_LR — confirm the stepping
    # cadence in lib.universal_trainer.
    cosine_scheduler = CosineAnnealingLR(optimizers[key], T_max=epochs, eta_min=MIN_LR)
    schedulers[key] = SequentialLR(
        optimizers[key],
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[WARMUP_STEPS],
    )

# Reconstruction loss: mean absolute error.
criterion = nn.L1Loss(reduction="mean")

In [6]:
# Run the USAE training loop (lib.universal_trainer) over the cached
# activations, with one SAE / optimizer / scheduler per model name.
train_usae(
    names=list(models.keys()),
    models=SAEs,
    dataloader=activations_dataloader,
    criterion=criterion,
    nb_epochs=epochs,
    optimizers=optimizers,
    schedulers=schedulers,
    device=device.type,  # "cuda" or "cpu", matching the setup cell's fallback
    seed=42,
)

Epoch 1/50: 100%|██████████| 508/508 [02:06<00:00,  4.02it/s, loss=1.42]



[Epoch 1] Loss: 747.4242 | Time: 126.28s | Dead Features: 0.0%


Epoch 2/50: 100%|██████████| 508/508 [02:02<00:00,  4.15it/s, loss=1.39]



[Epoch 2] Loss: 709.8987 | Time: 122.32s | Dead Features: 0.0%


Epoch 3/50:  62%|██████▏   | 313/508 [01:16<00:47,  4.10it/s, loss=1.37]


KeyboardInterrupt: 

In [None]:
# Save every SAE's weights in one checkpoint file, keyed by model name.
checkpoint_path = "./models/USAE_Dino_Dino_undercomplete_50epoch.pt"

# torch.save does not create missing parent directories, so create it first.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

model_state_dicts = {name: model.state_dict() for name, model in SAEs.items()}
torch.save(model_state_dicts, checkpoint_path)

In [None]:
# Inspect the learned concepts: rank them by total activation on one batch,
# keep the strongest ones of every model, and save a top-image heatmap overlay
# for each (model, concept) pair.

PATCH_GRID = 14  # 14 * 14 = 196 tokens per image, matching max_seq_len=196
topk = int(0.08 * 500 * 7)  # number of top concepts kept per model

activations = next(iter(activations_dataloader))
images = activations["images"].squeeze()


def encode_batch(key):
    """Encode the inspected batch with the SAE for model `key` and return its codes.

    The codes have one row per token, flattened over images and tokens, and one
    column per concept (the layout implied by the '(n w h) d' pattern below).
    """
    batch = activations[f'activations_{key}'].to(device)
    with torch.no_grad():
        _, codes = SAEs[key].encode(batch.squeeze())
    return codes


# --- 1. Pick the top-k concepts of every model ------------------------------
top_per_model = []
for key in models:
    codes = encode_batch(key)
    # Total absolute activation of each concept over every token of the batch.
    # Summing the flat token axis needs no spatial reshape, so the ranking no
    # longer depends on a grid size.
    concept_strength = codes.abs().sum(dim=0)
    top_per_model.append(torch.argsort(concept_strength, descending=True)[:topk].cpu())
    del codes  # release the large code matrix before encoding the next model

# Order-preserving union of all models' picks, stored as integer indices.
# (Previously each model overwrote the previous model's slice, and indices were
# kept in a float tensor.)
selected_concepts = torch.tensor(
    list(dict.fromkeys(torch.cat(top_per_model).tolist())), dtype=torch.long
)

# --- 2. Save one overlay figure per (model, concept) ------------------------
for key in models:
    # Encode once per model instead of once per (concept, model) pair.
    codes = rearrange(
        encode_batch(key), '(n w h) d -> n w h d', w=PATCH_GRID, h=PATCH_GRID
    )

    # NOTE(review): the directory name says "vit_siglip" while the models in
    # this notebook are DINOv2 — confirm the intended output location.
    save_dir = f"results/usae_run10_vit_siglip/{key}_concepts"
    os.makedirs(save_dir, exist_ok=True)

    for concept_id in selected_concepts.tolist():
        overlay_top_heatmaps(images, codes, concept_id=concept_id)
        filepath = os.path.join(save_dir, f"concept_{concept_id}_{key}.png")
        plt.savefig(filepath, bbox_inches='tight', dpi=300)
        plt.close()  # avoid accumulating open figures across iterations

    del codes