In [1]:
import torch
import torchvision
from lib.data_handlers import Load_ImageNet100, Load_PACS
from overcomplete.models import DinoV2, ViT, ResNet
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
import os
import matplotlib.pyplot as plt
from einops import rearrange
from lib.universal_trainer import train_usae
from lib.activation_generator import Load_activation_dataloader
import torch.nn as nn
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
import timm
from torchsummary import summary
from lib.visualizer import visualize_concepts
from domainbed import algorithms

# Use the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _load_domainbed_algorithm(algorithm_cls, checkpoint_path):
    """Rebuild a DomainBed algorithm from a checkpoint file and restore its weights.

    Args:
        algorithm_cls: Algorithm class to instantiate (e.g. ``algorithms.ERM``).
        checkpoint_path: Path of the ``model.pkl`` checkpoint to read.

    Returns:
        A ``(checkpoint, model)`` tuple: the raw checkpoint dict and the restored
        model, moved to ``device`` and switched to evaluation mode.
    """
    # NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
    # map_location keeps the load working when `device` resolved to the CPU but the
    # checkpoint was written from a GPU run.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model = algorithm_cls(
        input_shape=checkpoint["model_input_shape"],
        hparams=checkpoint["model_hparams"],
        num_domains=checkpoint["model_num_domains"],
        num_classes=checkpoint["model_num_classes"],
    )
    model.load_state_dict(checkpoint["model_dict"])
    model.to(device)
    # The models are only used for feature extraction here, so disable
    # training-only behaviour.
    model.eval()
    return checkpoint, model


checkpoint_ERM, model_ERM = _load_domainbed_algorithm(
    algorithms.ERM,
    r"C:\Users\sproj_ha\Desktop\vision_interp\models\PACS\ERM_sketch_dinov2\model.pkl",
)
checkpoint_IRM, model_IRM = _load_domainbed_algorithm(
    algorithms.IRM,
    r"C:\Users\sproj_ha\Desktop\vision_interp\models\PACS\IRM_sketch_dinov2\model.pkl",
)

# Smoke test: push one dummy image through each backbone and print the shapes of
# the patch-token features (ERM) and the CLS-token feature (IRM).
with torch.no_grad():  # shapes only -- no autograd graph needed
    x = torch.ones(1, 3, 224, 224, device=device)
    x = model_ERM.featurizer.network.forward_features(x)['x_norm_patchtokens']
    y = torch.ones(1, 3, 224, 224, device=device)
    y = model_IRM.featurizer.network.forward_features(y)['x_norm_clstoken']
print(x.shape, y.shape)

Using cache found in C:\Users\sproj_ha/.cache\torch\hub\facebookresearch_dinov2_main
Using cache found in C:\Users\sproj_ha/.cache\torch\hub\facebookresearch_dinov2_main


torch.Size([1, 256, 768]) torch.Size([1, 768])


In [3]:
# Short name -> trained model. The names are reused later to look up each model's
# activations in a batch (keys of the form "activations_<name>").
models = dict(ERM=model_ERM, IRM=model_IRM)

activation_dir = "activations/PACS_sketch_DinoV2_ERM_IRM"

image_loader = Load_PACS(
    batch_size=32,
    root_dir="C:/Users/sproj_ha/Desktop/vision_interp/datasets/data/PACS",
)

# NOTE(review): generate=False presumably reuses activations already stored under
# `activation_dir` instead of recomputing them -- confirm in lib.activation_generator.
activations_dataloader = Load_activation_dataloader(
    models=models,
    image_dataloader=image_loader,
    save_dir=activation_dir,
    generate=False,
    max_seq_len=None,
    rearrange_string='n t d -> (n t) d',
)

SAE Visualization

In [4]:
# Training hyper-parameters for the sparse autoencoders.
# 768 matches the last dimension of the token features printed after loading the
# models; the dictionary is 8 times that size.
concepts = 8 * 768
epochs = 250
lr = 1e-4

# Grab one batch up front; it is used to read the activation width per model.
sample = next(iter(activations_dataloader))

In [5]:
# One TopK sparse autoencoder, optimizer and LR schedule per model name.
SAEs = {}
optimizers = {}
schedulers = {}

warmup_steps = 10       # scheduler steps spent in the linear warm-up
warmup_start_lr = 1e-6  # learning rate at the very first step
min_lr = 1e-6           # floor of the cosine schedule

for key in models:
    # Width of this model's activations, read from the sample batch.
    input_dim = sample[f"activations_{key}"].shape[-1]
    # Follow the notebook-wide `device` instead of hard-coding 'cuda'.
    SAEs[key] = TopKSAE(input_dim, nb_concepts=concepts, top_k=16, device=device.type)
    optimizers[key] = torch.optim.Adam(SAEs[key].parameters(), lr=lr)

    # Derive the warm-up factor from the actual base `lr` (it was previously computed
    # against a stale 3e-4), so the schedule really starts at `warmup_start_lr`.
    # LinearLR requires a factor in (0, 1], hence the clamp.
    warmup_scheduler = LinearLR(
        optimizers[key],
        start_factor=min(1.0, warmup_start_lr / lr),
        end_factor=1.0,
        total_iters=warmup_steps,
    )
    # NOTE(review): T_max=epochs is kept as-is; the cosine phase starts only after the
    # warm-up, so it will not quite reach `min_lr` if the scheduler steps once per
    # epoch -- confirm how train_usae steps the schedulers.
    cosine_scheduler = CosineAnnealingLR(optimizers[key], T_max=epochs, eta_min=min_lr)
    schedulers[key] = SequentialLR(
        optimizers[key],
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_steps],
    )

# Mean absolute error used as the reconstruction loss.
criterion = nn.L1Loss(reduction="mean")

In [None]:
# Train all SAEs jointly on the cached activations.
# The device follows the notebook-wide `device` (GPU if available, else CPU)
# instead of a hard-coded 'cuda'.
train_usae(
    names=list(models.keys()),
    models=SAEs,
    dataloader=activations_dataloader,
    criterion=criterion,
    nb_epochs=epochs,
    optimizers=optimizers,
    schedulers=schedulers,
    device=device.type,
)

In [None]:
# Disabled on purpose: uncomment after training to persist the SAE weights.
# The file written here is the one the next cell reloads.
# model_state_dicts = {name: model.state_dict() for name, model in SAEs.items()}
# torch.save(model_state_dicts, "./models/USAEs_DinoV2_250ep_ERM_IRM.pt")

In [6]:
# Restore the trained SAE weights saved under each model name.
# map_location keeps the load working on a CPU-only machine even if the weights
# were saved from a GPU run.
# NOTE(review): torch.load unpickles the file -- only load files you trust.
model_state_dicts = torch.load("./models/USAEs_DinoV2_250ep_ERM_IRM.pt", map_location=device)
for name, model in SAEs.items():
    # load_state_dict reports missing/unexpected keys; print it as a sanity check.
    print(model.load_state_dict(model_state_dicts[name]))

<All keys matched successfully>
<All keys matched successfully>


In [7]:
# Render the concept visualizations for every SAE and write them under `save_dir`.
visualize_concepts(
    SAEs=SAEs,
    activation_loader=activations_dataloader,
    num_concepts=concepts,
    n_images=4,
    patch_width=16,
    abort_threshold=0.0,
    save_dir="results/visualizer_Dino_ERM_IRM/",
)

100%|██████████| 313/313 [11:21<00:00,  2.18s/it]


6144
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipp

## Idea 1: Model Focused

#### Hypothesis: 
More generalizable models will have more domain-invariant concepts.  

#### Experiment: 
Take M different models (trained with different DG techniques) and train a Universal Sparse Autoencoder to obtain a shared concept space. Identify a shared concept, then check whether each model still activates that concept on similar OOD data (e.g., if the concept is the leg of a dog, it should also be activated on a sketch of a dog). This determines whether the concept is domain-invariant for a particular model. We can then attribute this invariance to the DG method used (e.g., DAN, MMD).

#### End Goal: 
A new metric/framework/insight for DG method evaluation.


## Idea 2: Concept Focused

#### Hypothesis:
Spurious concepts will contribute more to the classification decision of a less generalizable model.
Domain-invariant concepts will contribute more to the classification decision of a more generalizable model.

#### Experiment:
Take M different models (trained with different DG techniques) and train a Universal Sparse Autoencoder to obtain a shared concept space. Then identify which concepts influence a classification decision using a mutual-information proxy. Finally, identify which concepts are spurious (e.g., capturing grass for a cow) by comparing label-concept correlations across multiple domains.

#### End Goal:
Explore the relationship of spurious and invariant concepts with softmax distributions for different DG techniques. 


## Idea 3: Single Model
#### Hypothesis
There exists a tradeoff between Discrimination & Generalization

#### Experiment
Take one model and use an SAE to obtain a concept space. First, identify which parts of the feature map led to a classification decision. Then identify which concepts were activated by those parts of the feature map. Finally, use perturbations to measure how generalizable and how discriminative the model is (through concept invariance and classification accuracy, respectively).

#### End Goal:
A plug-and-play system where you input a model and it tells you how generalizable the model is.
