#### Setup


In [1]:
import torch
import torchvision
from lib.data_handlers import Load_ImageNet100, Load_PACS
from overcomplete.models import DinoV2, ViT, ResNet, ViT_Large, SigLIP
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
import os
import matplotlib.pyplot as plt
from einops import rearrange
from lib.universal_trainer import train_usae
from lib.activation_generator import Load_activation_dataloader
import torch.nn as nn
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

# Compute device for the rest of the notebook: the GPU when PyTorch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Hand cached-but-unused GPU memory back before the backbones are loaded; nothing to do without a GPU.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

### Train a USAE

Define Models and Generate Their Activations

In [None]:
import torch.nn.functional as F

# Backbones whose activations the universal SAE is trained on. The keys are reused
# later as the suffix of the "activations_<key>" fields in each activation batch.
# Use the device chosen in the setup cell (falls back to CPU) instead of forcing CUDA.
models = {
    "ViT": ViT(device=str(device)),
    "SigLIP": SigLIP(device=str(device)),
}

activation_dir = "activations/ImageNet100_ViT_SigLIP"
image_loader = Load_ImageNet100(transform=None, batch_size=256, shuffle=True)

# Run every backbone over the ImageNet-100 images and build a loader over their
# token activations, flattened from (n, t, d) to (n*t, d).
# NOTE(review): max_seq_len=196 presumably keeps a 14x14 patch grid per image (the
# visualization cell reshapes codes with w=h=14) — confirm for each backbone.
# NOTE(review): generate=True presumably recomputes and re-saves the activations to
# `activation_dir` on every run, which is expensive — confirm against the helper.
activations_dataloader = Load_activation_dataloader(
    models=models,
    image_dataloader=image_loader,
    max_seq_len=196,
    save_dir=activation_dir,
    generate=True,
    rearrange_string='n t d -> (n t) d',
)

  from .autonotebook import tqdm as notebook_tqdm
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Processing Batches:  45%|████▌     | 231/508 [06:48<07:59,  1.73s/it]

USAE Training Hyperparameters

In [None]:
# Dictionary size shared by every SAE: 8 x 768 concepts
# (768 presumably matching the backbones' base-size embedding width — confirm).
concepts = 8 * 768
epochs = 40
lr = 3e-4

# A single activation batch; used to read each backbone's activation width.
sample = next(iter(activations_dataloader))

In [None]:
# Build one TopK SAE per backbone (same dictionary size for all), each with its own
# Adam optimizer and a linear-warmup -> cosine-decay LR schedule, then train them
# jointly with train_usae.

WARMUP_STEPS = 25   # scheduler steps spent ramping the LR up from MIN_LR to lr
MIN_LR = 1e-6       # LR at the start of warmup and at the end of the cosine decay
# NOTE(review): assumes train_usae calls scheduler.step() once per batch, as
# overcomplete.sae.train_sae does. If it steps once per epoch, set this to 1.
SCHEDULER_STEPS_PER_EPOCH = len(activations_dataloader)
TOTAL_SCHEDULER_STEPS = epochs * SCHEDULER_STEPS_PER_EPOCH

SAEs = {}
optimizers = {}
schedulers = {}

for key in models:
  # Each SAE's input width is the activation width of its backbone.
  SAEs[key] = TopKSAE(sample[f"activations_{key}"].shape[-1], nb_concepts=concepts, top_k=64, device=str(device))
  optimizers[key] = torch.optim.Adam(SAEs[key].parameters(), lr=lr)

  # Linear warmup from MIN_LR to lr; start_factor is derived from `lr` so it stays
  # correct if lr is changed. The cosine phase then spans exactly the remaining
  # steps (T_max = total - warmup), so the LR decays once to MIN_LR instead of
  # cycling or stopping part-way.
  warmup_scheduler = LinearLR(optimizers[key], start_factor=MIN_LR / lr, end_factor=1.0, total_iters=WARMUP_STEPS)
  cosine_scheduler = CosineAnnealingLR(optimizers[key], T_max=max(1, TOTAL_SCHEDULER_STEPS - WARMUP_STEPS), eta_min=MIN_LR)
  schedulers[key] = SequentialLR(optimizers[key], schedulers=[warmup_scheduler, cosine_scheduler], milestones=[WARMUP_STEPS])


# L1 loss averaged over all elements; passed to train_usae as its training criterion.
criterion = nn.L1Loss(reduction="mean")

train_usae(names=list(models.keys()),
           models=SAEs,
           dataloader=activations_dataloader,
           criterion=criterion,
           nb_epochs=epochs,
           optimizers=optimizers,
           schedulers=schedulers,
           device=str(device))

Epoch 1/20: 100%|██████████| 508/508 [07:01<00:00,  1.21it/s, loss=1.31]



[Epoch 1] Loss: 694.6747 | Time: 421.47s | Dead Features: 0.0%
ViT Loss: [308.1587703227997]
DinoV2 Loss: [386.51594799757004]


Epoch 2/20: 100%|██████████| 508/508 [13:42<00:00,  1.62s/it, loss=1.27]



[Epoch 2] Loss: 676.8938 | Time: 822.66s | Dead Features: 71.3%
ViT Loss: [308.1587703227997, 293.40135046839714]
DinoV2 Loss: [386.51594799757004, 383.49242997169495]


Epoch 3/20: 100%|██████████| 508/508 [13:30<00:00,  1.60s/it, loss=1.25]



[Epoch 3] Loss: 671.6713 | Time: 810.70s | Dead Features: 87.6%
ViT Loss: [308.1587703227997, 293.40135046839714, 291.0020546913147]
DinoV2 Loss: [386.51594799757004, 383.49242997169495, 380.669264793396]


Epoch 4/20: 100%|██████████| 508/508 [13:29<00:00,  1.59s/it, loss=1.24]



[Epoch 4] Loss: 670.7984 | Time: 809.93s | Dead Features: 90.7%
ViT Loss: [308.1587703227997, 293.40135046839714, 291.0020546913147, 290.23292818665504]
DinoV2 Loss: [386.51594799757004, 383.49242997169495, 380.669264793396, 380.56543242931366]


Epoch 5/20: 100%|██████████| 508/508 [13:30<00:00,  1.59s/it, loss=1.25]



[Epoch 5] Loss: 671.3460 | Time: 810.23s | Dead Features: 91.5%
ViT Loss: [308.1587703227997, 293.40135046839714, 291.0020546913147, 290.23292818665504, 289.9391912519932]
DinoV2 Loss: [386.51594799757004, 383.49242997169495, 380.669264793396, 380.56543242931366, 381.40680783987045]


Epoch 6/20: 100%|██████████| 508/508 [13:28<00:00,  1.59s/it, loss=1.25]



[Epoch 6] Loss: 671.7623 | Time: 808.74s | Dead Features: 90.3%
ViT Loss: [308.1587703227997, 293.40135046839714, 291.0020546913147, 290.23292818665504, 289.9391912519932, 289.6635777056217]
DinoV2 Loss: [386.51594799757004, 383.49242997169495, 380.669264793396, 380.56543242931366, 381.40680783987045, 382.09873074293137]


Epoch 7/20:  36%|███▋      | 185/508 [04:53<07:47,  1.45s/it, loss=1.4] 

In [None]:
# Save every trained SAE's weights in one file, keyed by backbone name.
# The filename records the configured epoch count instead of a stale hardcoded value.
# NOTE(review): if training was interrupted early, this is the configured count,
# not the number of epochs actually completed.
os.makedirs("./models", exist_ok=True)  # torch.save does not create missing directories
model_state_dicts = {name: sae.state_dict() for name, sae in SAEs.items()}
torch.save(model_state_dicts, f"./models/USAE_ViT_SigLIP_{epochs}epoch.pt")

In [None]:
# Inspect the learned concepts: rank each model's concepts by total activation on one
# batch, take the union of every model's top concepts, and for each selected concept
# save one heatmap overlay per model so the same universal concept can be compared
# across backbones.

# Concepts to visualize per model (heuristic kept from the earlier version; = 430).
TOP_CONCEPTS_PER_MODEL = int(0.08 * 768 * 7)
# Tokens per image were capped at max_seq_len=196 when generating activations,
# i.e. a 14x14 patch grid.
GRID_SIDE = 14

activations = next(iter(activations_dataloader))
images = activations["images"].squeeze()

# Encode the batch once per model and cache the spatial codes, rather than
# re-running the SAE for every (concept, model) pair.
spatial_codes = {}
ranked_concepts = []
for key in models:
  acts = activations[f"activations_{key}"].to(device)
  with torch.no_grad():
    _, codes = SAEs[key].encode(acts.squeeze())

  # Total absolute activation of each concept over every token of the batch.
  concept_strength = codes.abs().sum(dim=0)
  order = torch.argsort(concept_strength, descending=True)
  # Drop concepts that never fire on this batch: their heatmaps would be empty.
  live_order = order[concept_strength[order] > 0]
  ranked_concepts.extend(live_order[:TOP_CONCEPTS_PER_MODEL].tolist())

  # Kept on CPU to avoid holding every model's full code tensor on the GPU.
  spatial_codes[key] = rearrange(codes, '(n w h) d -> n w h d', w=GRID_SIDE, h=GRID_SIDE).cpu()

# Union of every model's top concepts, de-duplicated while keeping rank order
# (each model gets its own slots; no model overwrites another's selection).
selected_concepts = list(dict.fromkeys(ranked_concepts))

for concept_id in selected_concepts:
  for key, codes in spatial_codes.items():
    save_dir = f"results/usae_run10_vit_siglip/{key}_concepts"
    os.makedirs(save_dir, exist_ok=True)

    overlay_top_heatmaps(images, codes, concept_id=concept_id)
    filepath = os.path.join(save_dir, f"concept_{concept_id}_{key}.png")
    plt.savefig(filepath, bbox_inches='tight', dpi=300)
    plt.close()