#### Setup


In [1]:
import torch
import torchvision
from lib.data_handlers import Load_ImageNet100, Load_PACS
from overcomplete.models import DinoV2, ViT, ResNet
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
import os
import matplotlib.pyplot as plt
from einops import rearrange
from lib.universal_trainer import train_usae
from lib.activation_generator import Load_activation_dataloader
import torch.nn as nn
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
import timm
from torchsummary import summary

# Use the current CUDA device when PyTorch can see a GPU; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


### Train a USAE

Define Models and Generate Their Activations

In [2]:
import torch
from domainbed import algorithms

  ''' return proj_{B(h, \delta)}(adv_h), Euclidean projection to Euclidean ball'''


In [None]:

# Directory holding one DomainBed run per sub-folder (each with a `model.pkl` checkpoint).
# NOTE(review): absolute, machine-specific path — adjust for your environment.
DOMAINBED_RESULTS_DIR = r"C:\Users\sproj_ha\Desktop\vision_interp\domainbed\results\PACS"


def load_domainbed_checkpoint(algorithm_cls, run_name, results_dir=DOMAINBED_RESULTS_DIR):
    """Rebuild a DomainBed algorithm from ``<results_dir>/<run_name>/model.pkl``.

    Args:
        algorithm_cls: DomainBed algorithm class to instantiate (e.g. ``algorithms.ERM``).
        run_name: sub-folder of ``results_dir`` containing the checkpoint.
        results_dir: root directory of the DomainBed results.

    Returns:
        ``(model, checkpoint)``: the model with its saved weights loaded and switched to
        eval mode, plus the raw checkpoint dict.
    """
    ckpt_path = os.path.join(results_dir, run_name, "model.pkl")
    # map_location="cpu" lets checkpoints written on a GPU machine load on a CPU-only one;
    # load_state_dict below copies the weights into the (CPU-constructed) model regardless.
    # torch.load unpickles arbitrary objects — only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model = algorithm_cls(
        input_shape=checkpoint["model_input_shape"],
        hparams=checkpoint["model_hparams"],
        num_domains=checkpoint["model_num_domains"],
        num_classes=checkpoint["model_num_classes"],
    )
    model.load_state_dict(checkpoint["model_dict"])
    # The model is only used for inference (activation extraction), so put layers whose
    # behaviour depends on the mode (e.g. BatchNorm/Dropout, if present) into eval mode.
    model.eval()
    return model, checkpoint


model_ERM, checkpoint_ERM = load_domainbed_checkpoint(algorithms.ERM, "ERM_sketch")
print(checkpoint_ERM["model_input_shape"])
model_IRM, checkpoint_IRM = load_domainbed_checkpoint(algorithms.IRM, "IRM_sketch")

In [None]:
# Models whose activations feed the USAE; each key `k` is later read back from the
# activation batches as the field "activations_<k>".
models = dict(ERM=model_ERM, IRM=model_IRM)

activation_dir = "activations/PACS_sketch_ResNet_ERM_IRM"
pacs_root = "C:/Users/sproj_ha/Desktop/vision_interp/domainbed/domainbed/data/PACS"
image_loader = Load_PACS(root_dir=pacs_root, batch_size=32)

# generate=True asks the helper to (re)compute activations; save_dir is where it is pointed.
activations_dataloader = Load_activation_dataloader(
    models=models,
    image_dataloader=image_loader,
    max_seq_len=None,
    save_dir=activation_dir,
    generate=True,
    # Flatten spatial feature maps so each (image, position) pair becomes one row of channels.
    rearrange_string='n c w h -> (n w h) c',
)

USAE Training and Concept Visualization

In [None]:
# Dictionary size (nb_concepts) of every SAE.
# NOTE(review): 786 may be a typo for 768 — confirm the intended expansion factor.
concepts = 8 * 786
epochs = 50
lr = 5e-5

# One batch, used to read each model's activation width when building the SAEs.
sample = next(iter(activations_dataloader))

In [None]:
# One TopK SAE per model, each with its own Adam optimizer and warm-up + cosine LR schedule.
# The dictionaries are keyed by the same names as `models`.
SAEs = {}
optimizers = {}
schedulers = {}

TOP_K = 8               # active concepts kept per activation vector
WARMUP_ITERS = 10       # scheduler steps of linear warm-up before the cosine decay
WARMUP_START_LR = 1e-6  # learning rate at the first warm-up step
MIN_LR = 1e-6           # floor of the cosine decay

# LinearLR multiplies the optimizer's base lr, so derive the factor from `lr` itself: warm-up
# then starts at WARMUP_START_LR whatever base lr is chosen. LinearLR requires
# 0 < start_factor <= 1, hence the clamp when lr <= WARMUP_START_LR.
warmup_start_factor = min(1.0, WARMUP_START_LR / lr)

for key in models:
    # SAE input width = feature dimension of this model's flattened activations.
    SAEs[key] = TopKSAE(sample[f"activations_{key}"].shape[-1], nb_concepts=concepts,
                        top_k=TOP_K, device=device.type)
    optimizers[key] = torch.optim.Adam(SAEs[key].parameters(), lr=lr)

    warmup_scheduler = LinearLR(optimizers[key], start_factor=warmup_start_factor,
                                end_factor=1.0, total_iters=WARMUP_ITERS)
    # NOTE(review): SequentialLR restarts the cosine at the milestone, so if train_usae steps the
    # scheduler once per epoch it only covers epochs - WARMUP_ITERS of its T_max=epochs steps and
    # never reaches MIN_LR — confirm the stepping cadence before tuning T_max.
    cosine_scheduler = CosineAnnealingLR(optimizers[key], T_max=epochs, eta_min=MIN_LR)
    schedulers[key] = SequentialLR(optimizers[key],
                                   schedulers=[warmup_scheduler, cosine_scheduler],
                                   milestones=[WARMUP_ITERS])

criterion = nn.MSELoss(reduction="sum")

# device.type is "cuda" on a GPU machine and "cpu" otherwise, rather than always "cuda".
train_usae(names=list(models.keys()),
           models=SAEs,
           dataloader=activations_dataloader,
           criterion=criterion,
           nb_epochs=epochs,
           optimizers=optimizers,
           schedulers=schedulers,
           device=device.type)

In [None]:
# Persist every trained SAE in one file, keyed by model name (same keys as `SAEs`).
# The filename is derived from the activation run and epoch count so different experiments
# do not overwrite each other's checkpoints.
model_state_dicts = {name: sae.state_dict() for name, sae in SAEs.items()}
os.makedirs("./models", exist_ok=True)  # torch.save does not create missing directories
usae_ckpt_path = f"./models/USAEs_{os.path.basename(activation_dir)}_{epochs}ep.pt"
torch.save(model_state_dicts, usae_ckpt_path)

In [None]:
# Inspect the learned concepts: rank each model's concepts by how strongly they fire on one
# batch, then save a heatmap overlay of every selected concept for every model.

TOP_CONCEPTS_PER_MODEL = 630  # how many of each model's strongest concepts to visualise
# NOTE(review): the activation grid is hard-coded to 16 x 16 positions per image; confirm this
# matches the spatial size the activation dataloader produces for these models.
GRID_W, GRID_H = 16, 16
SAVE_ROOT = "results/usae_run5"

activations = next(iter(activations_dataloader))
images = activations["images"].squeeze()

# Encode the batch once per model. The codes do not depend on which concept is being plotted,
# so they are reused by the plotting loop below instead of being recomputed per concept.
codes_by_model = {}
with torch.no_grad():
    for key in models:
        model_acts = activations[f"activations_{key}"].to(device)
        _, codes = SAEs[key].encode(model_acts.squeeze())
        codes_by_model[key] = rearrange(codes, '(n w h) d -> n w h d', w=GRID_W, h=GRID_H)

# Rank concepts by summed |activation| over images and positions, then concatenate each
# model's top list (model 0's ids, then model 1's, ...). Every id is drawn for every model
# below, so a repeated id would only redraw and overwrite the same files: keep its first
# occurrence only.
ranked_ids = []
for codes in codes_by_model.values():
    concept_strength = codes.abs().sum(dim=(0, 1, 2))
    top_concepts = torch.argsort(concept_strength, descending=True)[:TOP_CONCEPTS_PER_MODEL]
    ranked_ids.extend(top_concepts.tolist())
selected_concepts = torch.tensor(list(dict.fromkeys(ranked_ids)), dtype=torch.long)

# One overlay figure per (model, concept) pair, saved under <SAVE_ROOT>/<model>_concepts/.
for key, codes in codes_by_model.items():
    save_dir = f"{SAVE_ROOT}/{key}_concepts"
    os.makedirs(save_dir, exist_ok=True)
    for concept_id in selected_concepts.tolist():
        overlay_top_heatmaps(images, codes, concept_id=concept_id)
        filepath = os.path.join(save_dir, f"concept_{concept_id}_{key}.png")
        plt.savefig(filepath, bbox_inches='tight', dpi=300)
        plt.close()  # free each figure; hundreds are produced per model