#### Setup


In [9]:
import torch
import torchvision
import timm
import os
import matplotlib.pyplot as plt
import torch.nn as nn
from einops import rearrange
from tqdm import tqdm
from overcomplete.models import DinoV2, ViT, ResNet, ViT_Large, SigLIP
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR


from lib.data_handlers import Load_ImageNet100, Load_PACS, Load_ImageNet100Sketch
from lib.universal_trainer import train_usae
from lib.activation_generator import Load_activation_dataloader
from lib.eval import evaluate_models
from lib.visualizer import visualize_concepts
from lib.mlp import train_mlp, test_mlp, prune_weights, check_sparsity


# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hand unused cached GPU memory back to the driver before loading models.
torch.cuda.empty_cache()

#### Train MLP

Set parameters: the number of SAE concepts, the pretrained SAE, and the full ViT model.

In [10]:
## Number of Concepts
# Dictionary size of the SAE: 8x the 768-wide input given to TopKSAE below.
concepts = 768 * 8

## Our Pretrained SAE
# Use the device chosen in the setup cell rather than a hard-coded "cuda",
# so this cell also runs on a CPU-only machine. `device.type` is the plain
# string ("cuda" / "cpu"), i.e. the same kind of value that was passed before.
sae = TopKSAE(768, nb_concepts=concepts, top_k=16, device=device.type)

# map_location remaps tensors that were saved from a GPU onto `device`;
# without it torch.load fails when CUDA is unavailable.
sae_checkpoint = torch.load("./models/ViT_MLP.pt", map_location=device)
# The checkpoint is a dict of state dicts; the SAE weights live under 'ViT'.
sae.load_state_dict(sae_checkpoint['ViT'])


## Our Full ViT Model
vit_full = timm.create_model('vit_base_patch16_224', pretrained=True)

In [11]:
## Data loaders for both domains
inet_trainloader, image_dataset = Load_ImageNet100(
    transform=None,
    batch_size=64,
    shuffle=True,
    dataset_allow=True,
    train=True,
)
inet_testloader = Load_ImageNet100(
    transform=None,
    batch_size=256,
    shuffle=True,
    dataset_allow=False,
    train=False,
)
sketch_trainloader, sketch_testloader = Load_ImageNet100Sketch(
    transform=None,
    batch_size=4,
    shuffle=True,
    train=False,
)

# Per-domain lookup tables; every later cell iterates over these keys.
trainloaders = dict(INET=inet_trainloader, Sketch=sketch_trainloader)
testloaders = dict(INET=inet_testloader, Sketch=sketch_testloader)


## One linear head (concepts -> 100 logits), optimizer and LR schedule per domain
mlps, optimizers, schedulers = {}, {}, {}

for domain in testloaders:
    linear_head = nn.Linear(concepts, 100)
    adam = torch.optim.Adam(params=linear_head.parameters(), lr=1e-4)

    mlps[domain] = linear_head
    optimizers[domain] = adam
    # step_size=1 with gamma=0.5: the learning rate halves on every scheduler step.
    schedulers[domain] = torch.optim.lr_scheduler.StepLR(optimizer=adam, gamma=0.5, step_size=1)

loss_fn = nn.CrossEntropyLoss()

# Number of training epochs per domain.
epochs = dict(INET=3, Sketch=7)

# Sparsity constraint
alpha = 1.0

## Training

In [None]:
# Train one linear head per domain and save it to disk.
#
# Training is slow (the logged run below took tens of minutes), so a domain is
# skipped when its checkpoint already exists. Set FORCE_RETRAIN = True to
# retrain and overwrite the saved heads.
FORCE_RETRAIN = False

for domain in trainloaders.keys():
    # The "_7ep" suffix is used for every domain (even though INET trains for
    # fewer epochs) because the loading cell reads exactly this filename.
    checkpoint_path = f"./models/MLPs/mlp_vit_alpha1_{domain}_7ep.pth"

    if os.path.exists(checkpoint_path) and not FORCE_RETRAIN:
        print(f"Domain {domain}: found {checkpoint_path}, skipping training")
        continue

    print(f"====> Beginning Training for Domain: {domain}")

    for epoch in range(epochs[domain]):
        train_mlp(
            image_loader=trainloaders[domain],
            mlp=mlps[domain],
            optimizer=optimizers[domain],
            loss_fn=loss_fn,
            vit_full=vit_full,
            sae=sae,
            internal_map=image_dataset.class_to_idx,
            alpha=alpha,
            epoch=epoch
        )
        # Advance the LR schedule once per epoch.
        if schedulers[domain] is not None:
            schedulers[domain].step()

    # torch.save does not create missing parent directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
    torch.save(mlps[domain].state_dict(), checkpoint_path)
    print(f"Domain {domain} MLP Saved")

====> Beginning Training for Domain: INET


Epoch 0: 100%|██████████| 2032/2032 [15:28<00:00,  2.19batch/s, loss=8.5324] 


Epoch 0 finished. Average loss: 101.9885


Epoch 1: 100%|██████████| 2032/2032 [08:30<00:00,  3.98batch/s, loss=4.2550]


Epoch 1 finished. Average loss: 4.3454


Epoch 2: 100%|██████████| 2032/2032 [15:48<00:00,  2.14batch/s, loss=2.4734]


Epoch 2 finished. Average loss: 2.4281
Domain INET MLP Saved
====> Beginning Training for Domain: Sketch


Epoch 0: 100%|██████████| 1146/1146 [00:50<00:00, 22.61batch/s, loss=20.5407]


Epoch 0 finished. Average loss: 195.2722


Epoch 1: 100%|██████████| 1146/1146 [00:52<00:00, 21.98batch/s, loss=10.7439]


Epoch 1 finished. Average loss: 13.8518


Epoch 2: 100%|██████████| 1146/1146 [00:51<00:00, 22.06batch/s, loss=5.8571] 


Epoch 2 finished. Average loss: 7.3724


Epoch 3: 100%|██████████| 1146/1146 [00:49<00:00, 23.08batch/s, loss=3.9974] 


Epoch 3 finished. Average loss: 4.7621


Epoch 4: 100%|██████████| 1146/1146 [00:50<00:00, 22.76batch/s, loss=3.0687]


Epoch 4 finished. Average loss: 3.5995


Epoch 5: 100%|██████████| 1146/1146 [01:24<00:00, 13.53batch/s, loss=2.6366] 


Epoch 5 finished. Average loss: 3.0377


Epoch 6: 100%|██████████| 1146/1146 [00:45<00:00, 25.41batch/s, loss=2.9927] 

Epoch 6 finished. Average loss: 2.7231
Domain Sketch MLP Saved





## Loading

In [12]:
# Restore the trained linear heads from disk, one checkpoint per domain.
for domain in trainloaders.keys():
    checkpoint_path = f"./models/MLPs/mlp_vit_alpha1_{domain}_7ep.pth"
    # map_location="cpu" lets a checkpoint written from GPU tensors load on a
    # CPU-only machine; load_state_dict then copies the values into the
    # existing parameters of the head, wherever they live.
    head_state = torch.load(checkpoint_path, map_location="cpu")
    mlps[domain].load_state_dict(head_state)

Pre-Sparse Accuracy: evaluate the loaded MLP heads on each domain's test loader before pruning.

In [13]:
# Evaluate every loaded head on the test loader of its own domain.
for domain in trainloaders.keys():
    # Label each result so the two accuracy lines can be told apart
    # (same message as the post-sparse evaluation cell).
    print(f"Testing MLP: {domain}")
    test_mlp(
        image_loader=testloaders[domain],
        mlp=mlps[domain],
        sae=sae,
        loss_fn=loss_fn,
        vit_full=vit_full,
        # The ImageNet-100 class mapping is used for both domains.
        internal_map=image_dataset.class_to_idx
    )

Testing: 100%|██████████| 20/20 [00:28<00:00,  1.45s/batch, loss=0.1393]


Test finished. Avg loss: 0.2088, Accuracy: 95.67%


Testing: 100%|██████████| 128/128 [00:04<00:00, 29.69batch/s, loss=0.0299] 

Test finished. Avg loss: 0.5158, Accuracy: 89.71%





Sparsification: prune the weights of each MLP head and report the resulting sparsity.

In [15]:
# Prune each domain's head and report how sparse the result is.
# NOTE(review): the visualization cell reads `mlps` and still prints only a few
# non-zero weights per class, which suggests prune_weights may modify the head
# in place -- confirm against lib.mlp before relying on `mlps` staying dense.
sparse_mlps = {}

for domain in trainloaders:
    pruned_head = prune_weights(mlps[domain], k=0.5)
    sparse_mlps[domain] = pruned_head

    sparsity = check_sparsity(pruned_head)
    print(f"{domain} MLP Sparsity: {sparsity:.2f}")



INET MLP Sparsity: 99.70
Sketch MLP Sparsity: 99.82


Post-Sparse Accuracy: repeat the evaluation using the pruned MLP heads.

In [10]:
# Arguments that are identical for every domain's evaluation run.
shared_eval_kwargs = dict(
    sae=sae,
    loss_fn=loss_fn,
    vit_full=vit_full,
    internal_map=image_dataset.class_to_idx,
)

# Score each pruned head on the test loader of its own domain.
for domain in trainloaders:
    print(f"Testing MLP: {domain}")
    test_mlp(
        image_loader=testloaders[domain],
        mlp=sparse_mlps[domain],
        **shared_eval_kwargs,
    )

Testing MLP: INET


Testing: 100%|██████████| 20/20 [00:36<00:00,  1.83s/batch, loss=0.1685]


Test finished. Avg loss: 0.2094, Accuracy: 95.57%
Testing MLP: Sketch


Testing: 100%|██████████| 128/128 [00:04<00:00, 28.27batch/s, loss=0.0000]

Test finished. Avg loss: 0.5742, Accuracy: 86.35%





## Visualization 

In [16]:
# Compare, class by class, which SAE concepts the pruned linear heads use.
#
# The weights are read from `sparse_mlps` (the outputs of prune_weights) rather
# than from `mlps`, so the listing no longer depends on whether prune_weights
# happened to modify the dense heads in place.
# Assumes the pruned heads expose `.weight` like the nn.Linear(concepts, 100)
# layers they were built from -- TODO confirm against lib.mlp.prune_weights.
weights_ood = sparse_mlps["Sketch"].weight.detach().cpu()  # expected shape: (100, 6144)
weights_iid = sparse_mlps["INET"].weight.detach().cpu()  # expected shape: (100, 6144)

weights = {"OOD": weights_ood, "IID": weights_iid}

nonzero_indices = {}

# Collect nonzero indices per class
for tag, weight in weights.items():
    nonzero_indices[tag] = []
    print(f"\n================ {tag} ==================")
    for class_idx in range(weight.shape[0]):
        # Weight row of this class: one entry per SAE concept
        class_weights = weight[class_idx]

        # Concept indices whose weight is not exactly zero
        nonzero_idx = torch.nonzero(class_weights, as_tuple=True)[0]

        # Sort by signed weight, largest first; positive weights therefore
        # come before negative ones (this is not a sort by magnitude).
        sorted_vals, sorted_idx = torch.sort(
            class_weights[nonzero_idx], descending=True
        )
        sorted_features = nonzero_idx[sorted_idx].tolist()
        sorted_strengths = sorted_vals.tolist()

        nonzero_indices[tag].append(set(sorted_features))

        # Print nicely
        print(f"\nClass {class_idx}: {len(sorted_features)} non-zero features (sorted by weight, positives first)")
        for f_idx, f_val in zip(sorted_features, sorted_strengths):
            print(f"    Feature {f_idx}: weight={f_val:.4f}")

# Print the concepts that are non-zero in BOTH heads, per class
print("\n============== Common Non-Zero Indices ==============")
for class_idx in range(weights_ood.shape[0]):
    common = nonzero_indices["OOD"][class_idx].intersection(nonzero_indices["IID"][class_idx])

    # For strength display, take from OOD & IID separately
    ood_weights = weights["OOD"][class_idx]
    iid_weights = weights["IID"][class_idx]

    # Sort common features by the mean of the two signed weights, largest first.
    # .item() turns the 0-d tensor into a Python float so sorted() compares
    # plain numbers instead of tensors.
    common_sorted = sorted(
        common,
        key=lambda i: ((ood_weights[i] + iid_weights[i]) / 2).item(),
        reverse=True
    )

    print(f"\nClass {class_idx}: {len(common)} common non-zero features (sorted by avg weight)")
    for f_idx in common_sorted:
        print(
            f"    Feature {f_idx}: "
            f"OOD={ood_weights[f_idx]:.4f}, "
            f"IID={iid_weights[f_idx]:.4f}"
        )




Class 0: 10 non-zero features (sorted by weight, positives first)
    Feature 606: weight=0.0052
    Feature 4395: weight=0.0021
    Feature 1611: weight=0.0017
    Feature 1543: weight=0.0015
    Feature 2621: weight=0.0005
    Feature 3304: weight=0.0005
    Feature 3736: weight=0.0005
    Feature 1971: weight=0.0004
    Feature 757: weight=0.0002
    Feature 4261: weight=0.0001

Class 1: 5 non-zero features (sorted by weight, positives first)
    Feature 521: weight=0.0037
    Feature 1971: weight=0.0025
    Feature 4261: weight=0.0017
    Feature 5594: weight=0.0006
    Feature 136: weight=0.0003

Class 2: 14 non-zero features (sorted by weight, positives first)
    Feature 92: weight=0.0070
    Feature 2128: weight=0.0048
    Feature 448: weight=0.0045
    Feature 2998: weight=0.0039
    Feature 1971: weight=0.0029
    Feature 1705: weight=0.0011
    Feature 1758: weight=0.0009
    Feature 2621: weight=0.0009
    Feature 4955: weight=0.0005
    Feature 5461: weight=0.0005
    Fe

In [7]:
# Single-entry lookup tables keyed by backbone name, in the form expected by
# the `SAEs=` / `models=` arguments of the activation and visualization helpers.
SAEs = {
    "ViT": sae
}

# Build the backbone on the device selected in the setup cell instead of a
# hard-coded "cuda", so the cell also runs on a CPU-only machine.
models = {
    "ViT": ViT(device=device.type)
}

In [None]:
# Concept visualization on ImageNet-100 training images.
#
# `SAEs` and `models` are taken from the cell directly above; rebuilding them
# here would instantiate a second copy of the ViT backbone for no benefit.

# Larger batches than during MLP training: this loader only feeds the backbone
# to collect activations. The returned dataset is not needed here; it is bound
# to a named throwaway instead of `_`, which IPython uses for the last output.
inet_trainloader, _inet_dataset = Load_ImageNet100(transform=None, batch_size=256, shuffle=True, dataset_allow=True, train=True)

activations_dataloader = Load_activation_dataloader(
    models=models,
    image_dataloader=inet_trainloader,
    max_seq_len=196,  # presumably 14 x 14 patch tokens per image (cf. patch_width below) -- confirm
    save_dir="./activations/ViTMLP_INET",
    generate=True,
    rearrange_string='n t d -> (n t) d'  # flatten (image, token) into one row axis
)

visualize_concepts(
    activation_loader=activations_dataloader,
    SAEs=SAEs,
    save_dir="./results/visualize2_VITMLP_INET",
    patch_width=14,
    num_concepts=concepts,
    n_images=4,
    abort_threshold=0.0,
)

In [8]:
# Concept visualization on the ImageNet-100 Sketch training images,
# using the same backbone and SAE tables (`models`, `SAEs`) as before.
sketch_trainloader = Load_ImageNet100Sketch(
    transform=None,
    batch_size=250,
    shuffle=True,
    train=True,
)

# Step 1: run the backbone over the sketch images and collect token activations.
sketch_activation_kwargs = dict(
    models=models,
    image_dataloader=sketch_trainloader,
    max_seq_len=196,
    save_dir="./activations/ViTMLP_sketch",
    generate=True,
    rearrange_string='n t d -> (n t) d',
)
activations_dataloader = Load_activation_dataloader(**sketch_activation_kwargs)

# Step 2: render the top images for every SAE concept.
sketch_visualization_kwargs = dict(
    activation_loader=activations_dataloader,
    SAEs=SAEs,
    save_dir="./results/visualize2_VITMLP_sketch",
    patch_width=14,
    num_concepts=concepts,
    n_images=8,
    abort_threshold=0.0,
)
visualize_concepts(**sketch_visualization_kwargs)

Processing Batches:   0%|          | 0/21 [00:00<?, ?it/s]

<torch.utils.data.dataloader.DataLoader object at 0x0000022E26821AE0>


Processing Batches: 100%|██████████| 21/21 [00:45<00:00,  2.15s/it]


Generating Activations


100%|██████████| 21/21 [00:55<00:00,  2.62s/it]


Saving Concepts


100%|██████████| 6144/6144 [15:44<00:00,  6.51it/s]
