#### Setup


In [6]:
import torch
import torchvision
from lib.data_handlers import Load_ImageNet100, Load_PACS, Load_ImageNet100Sketch
from overcomplete.models import DinoV2, ViT, ResNet, ViT_Large, SigLIP
from torch.utils.data import DataLoader, TensorDataset
from overcomplete.sae import TopKSAE, train_sae
from overcomplete.visualization import (overlay_top_heatmaps, evidence_top_images, zoom_top_images, contour_top_image)
import os
import matplotlib.pyplot as plt
from einops import rearrange
from lib.universal_trainer import train_usae
from lib.activation_generator import Load_activation_dataloader
import torch.nn as nn
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR
from lib.eval import evaluate_models
from lib.visualizer import visualize_concepts
from tqdm import tqdm
import timm


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()


## ImageNet

#### Train SAE

In [7]:
models = {
    "ViT": ViT(device="cuda")
}

activations_dir="/activations/ViT_Solo"

image_loader = Load_ImageNet100(transform=None, batch_size=256, shuffle=True)

activations_dataloader = Load_activation_dataloader(
    models=models,
    image_dataloader=image_loader,
    max_seq_len=196,
    save_dir=activations_dir,
    generate=False,
    rearrange_string='n t d -> (n t) d'
)

Set Parameters

In [8]:
concepts = 768 * 8
epochs = 100
lr=3e-4
sample = next(iter(activations_dataloader))

Execute Flow

In [9]:
SAEs = {}
optimizers = {}
schedulers = {}

for key, model in models.items():

    SAEs[key] = TopKSAE(
        sample[f"activations_{key}"].shape[-1],
        nb_concepts=concepts,
        top_k=16,
        device="cuda",
    )
    optimizers[key] = torch.optim.Adam(SAEs[key].parameters(), lr=lr)

    # Set up a Linear + Cosine Scheduler
    warmup_scheduler = LinearLR(
        optimizers[key], start_factor=1e-6 / 3e-4, end_factor=1.0, total_iters=10
    )
    cosine_scheduler = CosineAnnealingLR(optimizers[key], T_max=epochs, eta_min=1e-6)
    schedulers[key] = SequentialLR(
        optimizers[key],
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[25],
    )

criterion = nn.L1Loss(reduction="mean")  # change to mean reduction

In [6]:
train_usae(
    names=list(models.keys()),
    models=SAEs,
    dataloader=activations_dataloader,
    criterion=criterion,
    nb_epochs=epochs,
    optimizers=optimizers,
    schedulers=schedulers,
    device="cuda",
    seed=42,
)

Epoch 1/100: 100%|██████████| 508/508 [02:28<00:00,  3.41it/s, loss=0.47] 



[Epoch 1] Loss: 275.5159 | Time: 148.84s | Dead Features: 0.0%


Epoch 2/100: 100%|██████████| 508/508 [02:28<00:00,  3.42it/s, loss=0.43] 



[Epoch 2] Loss: 230.0995 | Time: 148.54s | Dead Features: 0.1%


Epoch 3/100: 100%|██████████| 508/508 [02:29<00:00,  3.40it/s, loss=0.436]



[Epoch 3] Loss: 220.1320 | Time: 149.27s | Dead Features: 1.1%


Epoch 4/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.431]



[Epoch 4] Loss: 215.0126 | Time: 152.71s | Dead Features: 3.1%


Epoch 5/100: 100%|██████████| 508/508 [02:33<00:00,  3.31it/s, loss=0.426]



[Epoch 5] Loss: 211.4538 | Time: 153.30s | Dead Features: 7.3%


Epoch 6/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.413]



[Epoch 6] Loss: 208.7933 | Time: 151.73s | Dead Features: 12.2%


Epoch 7/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.407]



[Epoch 7] Loss: 206.6270 | Time: 151.06s | Dead Features: 16.0%


Epoch 8/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.404]



[Epoch 8] Loss: 204.9744 | Time: 151.85s | Dead Features: 18.4%


Epoch 9/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.408]



[Epoch 9] Loss: 203.5284 | Time: 151.55s | Dead Features: 23.5%


Epoch 10/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.395]



[Epoch 10] Loss: 202.3575 | Time: 151.57s | Dead Features: 27.2%


Epoch 11/100: 100%|██████████| 508/508 [02:31<00:00,  3.34it/s, loss=0.388]



[Epoch 11] Loss: 201.2594 | Time: 151.91s | Dead Features: 32.6%


Epoch 12/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.394]



[Epoch 12] Loss: 200.3950 | Time: 151.48s | Dead Features: 35.7%


Epoch 13/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.378]



[Epoch 13] Loss: 199.5158 | Time: 151.03s | Dead Features: 41.6%


Epoch 14/100: 100%|██████████| 508/508 [02:29<00:00,  3.40it/s, loss=0.396]



[Epoch 14] Loss: 198.8250 | Time: 149.62s | Dead Features: 43.8%


Epoch 15/100: 100%|██████████| 508/508 [02:28<00:00,  3.43it/s, loss=0.4]  



[Epoch 15] Loss: 198.1498 | Time: 148.16s | Dead Features: 47.9%


Epoch 16/100: 100%|██████████| 508/508 [02:29<00:00,  3.41it/s, loss=0.384]



[Epoch 16] Loss: 197.5553 | Time: 149.15s | Dead Features: 50.8%


Epoch 17/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.39] 



[Epoch 17] Loss: 196.9506 | Time: 151.79s | Dead Features: 53.3%


Epoch 18/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.387]



[Epoch 18] Loss: 196.3991 | Time: 151.33s | Dead Features: 56.9%


Epoch 19/100: 100%|██████████| 508/508 [02:31<00:00,  3.34it/s, loss=0.375]



[Epoch 19] Loss: 195.9306 | Time: 151.89s | Dead Features: 58.1%


Epoch 20/100: 100%|██████████| 508/508 [02:31<00:00,  3.34it/s, loss=0.393]



[Epoch 20] Loss: 195.4657 | Time: 151.99s | Dead Features: 60.1%


Epoch 21/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.378]



[Epoch 21] Loss: 195.1071 | Time: 151.09s | Dead Features: 61.6%


Epoch 22/100: 100%|██████████| 508/508 [02:30<00:00,  3.37it/s, loss=0.388]



[Epoch 22] Loss: 194.6770 | Time: 150.57s | Dead Features: 63.3%


Epoch 23/100: 100%|██████████| 508/508 [02:30<00:00,  3.37it/s, loss=0.382]



[Epoch 23] Loss: 194.3863 | Time: 150.56s | Dead Features: 64.0%


Epoch 24/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.381]



[Epoch 24] Loss: 194.0021 | Time: 152.54s | Dead Features: 65.2%


Epoch 25/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.374]



[Epoch 25] Loss: 193.7440 | Time: 152.77s | Dead Features: 66.0%


Epoch 26/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.377]



[Epoch 26] Loss: 193.4124 | Time: 151.82s | Dead Features: 66.8%


Epoch 27/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.38] 



[Epoch 27] Loss: 193.1904 | Time: 151.38s | Dead Features: 67.2%


Epoch 28/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.374]



[Epoch 28] Loss: 192.8892 | Time: 152.35s | Dead Features: 67.5%


Epoch 29/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.376]



[Epoch 29] Loss: 192.6739 | Time: 151.12s | Dead Features: 67.7%


Epoch 30/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.377]



[Epoch 30] Loss: 192.4422 | Time: 151.71s | Dead Features: 68.7%


Epoch 31/100: 100%|██████████| 508/508 [02:30<00:00,  3.37it/s, loss=0.386]



[Epoch 31] Loss: 192.2223 | Time: 150.72s | Dead Features: 68.8%


Epoch 32/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.377]



[Epoch 32] Loss: 192.0388 | Time: 152.24s | Dead Features: 70.3%


Epoch 33/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.371]



[Epoch 33] Loss: 191.8168 | Time: 151.22s | Dead Features: 69.9%


Epoch 34/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.378]



[Epoch 34] Loss: 191.6827 | Time: 151.29s | Dead Features: 70.0%


Epoch 35/100: 100%|██████████| 508/508 [02:30<00:00,  3.36it/s, loss=0.37] 



[Epoch 35] Loss: 191.4511 | Time: 150.99s | Dead Features: 70.4%


Epoch 36/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.374]



[Epoch 36] Loss: 191.3605 | Time: 151.76s | Dead Features: 70.6%


Epoch 37/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.379]



[Epoch 37] Loss: 191.1678 | Time: 151.46s | Dead Features: 70.9%


Epoch 38/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.374]



[Epoch 38] Loss: 191.1027 | Time: 152.36s | Dead Features: 71.3%


Epoch 39/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.38] 



[Epoch 39] Loss: 190.9334 | Time: 151.70s | Dead Features: 70.9%


Epoch 40/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.379]



[Epoch 40] Loss: 190.8471 | Time: 151.32s | Dead Features: 71.2%


Epoch 41/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.376]



[Epoch 41] Loss: 190.6903 | Time: 152.13s | Dead Features: 71.0%


Epoch 42/100: 100%|██████████| 508/508 [02:30<00:00,  3.38it/s, loss=0.373]



[Epoch 42] Loss: 190.5816 | Time: 150.48s | Dead Features: 71.2%


Epoch 43/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.386]



[Epoch 43] Loss: 190.4681 | Time: 152.06s | Dead Features: 70.9%


Epoch 44/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.371]



[Epoch 44] Loss: 190.3513 | Time: 152.16s | Dead Features: 71.0%


Epoch 45/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.373]



[Epoch 45] Loss: 190.2807 | Time: 152.29s | Dead Features: 70.8%


Epoch 46/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.371]



[Epoch 46] Loss: 190.1426 | Time: 151.54s | Dead Features: 71.2%


Epoch 47/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.37] 



[Epoch 47] Loss: 190.1185 | Time: 152.32s | Dead Features: 70.8%


Epoch 48/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.382]



[Epoch 48] Loss: 189.9860 | Time: 151.59s | Dead Features: 71.5%


Epoch 49/100: 100%|██████████| 508/508 [02:30<00:00,  3.36it/s, loss=0.373]



[Epoch 49] Loss: 189.9804 | Time: 150.98s | Dead Features: 71.4%


Epoch 50/100: 100%|██████████| 508/508 [02:30<00:00,  3.37it/s, loss=0.373]



[Epoch 50] Loss: 189.8543 | Time: 150.57s | Dead Features: 71.4%


Epoch 51/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.374]



[Epoch 51] Loss: 189.8455 | Time: 151.63s | Dead Features: 71.5%


Epoch 52/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.385]



[Epoch 52] Loss: 189.7402 | Time: 151.39s | Dead Features: 71.4%


Epoch 53/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.369]



[Epoch 53] Loss: 189.7173 | Time: 152.04s | Dead Features: 71.5%


Epoch 54/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.368]



[Epoch 54] Loss: 189.6273 | Time: 152.63s | Dead Features: 71.5%


Epoch 55/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.356]



[Epoch 55] Loss: 189.5840 | Time: 152.73s | Dead Features: 71.5%


Epoch 56/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.371]



[Epoch 56] Loss: 189.5421 | Time: 152.39s | Dead Features: 71.7%


Epoch 57/100: 100%|██████████| 508/508 [02:32<00:00,  3.32it/s, loss=0.381]



[Epoch 57] Loss: 189.4653 | Time: 152.83s | Dead Features: 71.5%


Epoch 58/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.364]



[Epoch 58] Loss: 189.4459 | Time: 151.87s | Dead Features: 71.5%


Epoch 59/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.373]



[Epoch 59] Loss: 189.3472 | Time: 152.54s | Dead Features: 71.7%


Epoch 60/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.379]



[Epoch 60] Loss: 189.3511 | Time: 152.20s | Dead Features: 71.7%


Epoch 61/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.379]



[Epoch 61] Loss: 189.2421 | Time: 151.39s | Dead Features: 71.8%


Epoch 62/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.368]



[Epoch 62] Loss: 189.2655 | Time: 151.40s | Dead Features: 71.8%


Epoch 63/100: 100%|██████████| 508/508 [03:12<00:00,  2.64it/s, loss=0.382]



[Epoch 63] Loss: 189.1631 | Time: 192.37s | Dead Features: 71.8%


Epoch 64/100: 100%|██████████| 508/508 [03:12<00:00,  2.65it/s, loss=0.366]



[Epoch 64] Loss: 189.1802 | Time: 192.04s | Dead Features: 72.1%


Epoch 65/100: 100%|██████████| 508/508 [03:10<00:00,  2.66it/s, loss=0.369]



[Epoch 65] Loss: 189.0934 | Time: 190.95s | Dead Features: 71.7%


Epoch 66/100: 100%|██████████| 508/508 [03:13<00:00,  2.63it/s, loss=0.368]



[Epoch 66] Loss: 189.0883 | Time: 193.31s | Dead Features: 71.7%


Epoch 67/100: 100%|██████████| 508/508 [03:11<00:00,  2.66it/s, loss=0.38] 



[Epoch 67] Loss: 189.0327 | Time: 191.20s | Dead Features: 71.6%


Epoch 68/100: 100%|██████████| 508/508 [03:12<00:00,  2.64it/s, loss=0.374]



[Epoch 68] Loss: 189.0072 | Time: 192.50s | Dead Features: 71.7%


Epoch 69/100: 100%|██████████| 508/508 [03:04<00:00,  2.75it/s, loss=0.372]



[Epoch 69] Loss: 188.9887 | Time: 184.46s | Dead Features: 71.6%


Epoch 70/100: 100%|██████████| 508/508 [03:12<00:00,  2.63it/s, loss=0.38] 



[Epoch 70] Loss: 188.9313 | Time: 192.97s | Dead Features: 71.5%


Epoch 71/100: 100%|██████████| 508/508 [03:13<00:00,  2.63it/s, loss=0.376]



[Epoch 71] Loss: 188.9324 | Time: 193.19s | Dead Features: 71.6%


Epoch 72/100: 100%|██████████| 508/508 [03:14<00:00,  2.61it/s, loss=0.376]



[Epoch 72] Loss: 188.8419 | Time: 194.46s | Dead Features: 71.8%


Epoch 73/100: 100%|██████████| 508/508 [03:05<00:00,  2.74it/s, loss=0.37] 



[Epoch 73] Loss: 188.8683 | Time: 185.68s | Dead Features: 71.6%


Epoch 74/100: 100%|██████████| 508/508 [02:30<00:00,  3.37it/s, loss=0.377]



[Epoch 74] Loss: 188.7770 | Time: 150.90s | Dead Features: 71.6%


Epoch 75/100: 100%|██████████| 508/508 [02:30<00:00,  3.37it/s, loss=0.356]



[Epoch 75] Loss: 188.8070 | Time: 150.73s | Dead Features: 71.7%


Epoch 76/100: 100%|██████████| 508/508 [02:32<00:00,  3.32it/s, loss=0.372]



[Epoch 76] Loss: 188.7270 | Time: 152.79s | Dead Features: 71.8%


Epoch 77/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.365]



[Epoch 77] Loss: 188.7494 | Time: 151.71s | Dead Features: 71.3%


Epoch 78/100: 100%|██████████| 508/508 [02:32<00:00,  3.32it/s, loss=0.369]



[Epoch 78] Loss: 188.6831 | Time: 152.97s | Dead Features: 71.7%


Epoch 79/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.368]



[Epoch 79] Loss: 188.6868 | Time: 151.36s | Dead Features: 71.5%


Epoch 80/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.378]



[Epoch 80] Loss: 188.6604 | Time: 152.51s | Dead Features: 71.7%


Epoch 81/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.38] 



[Epoch 81] Loss: 188.6281 | Time: 151.83s | Dead Features: 71.6%


Epoch 82/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.373]



[Epoch 82] Loss: 188.6269 | Time: 151.47s | Dead Features: 71.7%


Epoch 83/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.375]



[Epoch 83] Loss: 188.5672 | Time: 152.33s | Dead Features: 71.7%


Epoch 84/100: 100%|██████████| 508/508 [02:31<00:00,  3.36it/s, loss=0.363]



[Epoch 84] Loss: 188.5940 | Time: 151.33s | Dead Features: 71.5%


Epoch 85/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.374]



[Epoch 85] Loss: 188.5142 | Time: 152.31s | Dead Features: 71.8%


Epoch 86/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.37] 



[Epoch 86] Loss: 188.5592 | Time: 151.70s | Dead Features: 71.7%


Epoch 87/100: 100%|██████████| 508/508 [02:31<00:00,  3.34it/s, loss=0.372]



[Epoch 87] Loss: 188.4761 | Time: 151.94s | Dead Features: 71.7%


Epoch 88/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.375]



[Epoch 88] Loss: 188.5193 | Time: 152.69s | Dead Features: 71.7%


Epoch 89/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.381]



[Epoch 89] Loss: 188.4429 | Time: 152.24s | Dead Features: 71.7%


Epoch 90/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.382]



[Epoch 90] Loss: 188.4697 | Time: 152.30s | Dead Features: 71.8%


Epoch 91/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.373]



[Epoch 91] Loss: 188.4235 | Time: 152.76s | Dead Features: 71.7%


Epoch 92/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.365]



[Epoch 92] Loss: 188.4192 | Time: 152.15s | Dead Features: 71.8%


Epoch 93/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.381]



[Epoch 93] Loss: 188.3983 | Time: 151.71s | Dead Features: 71.9%


Epoch 94/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.354]



[Epoch 94] Loss: 188.3731 | Time: 152.40s | Dead Features: 71.9%


Epoch 95/100: 100%|██████████| 508/508 [02:32<00:00,  3.33it/s, loss=0.375]



[Epoch 95] Loss: 188.3853 | Time: 152.36s | Dead Features: 71.9%


Epoch 96/100: 100%|██████████| 508/508 [02:31<00:00,  3.34it/s, loss=0.368]



[Epoch 96] Loss: 188.3297 | Time: 151.93s | Dead Features: 71.9%


Epoch 97/100: 100%|██████████| 508/508 [02:31<00:00,  3.34it/s, loss=0.375]



[Epoch 97] Loss: 188.3638 | Time: 151.95s | Dead Features: 71.8%


Epoch 98/100: 100%|██████████| 508/508 [02:32<00:00,  3.34it/s, loss=0.373]



[Epoch 98] Loss: 188.2775 | Time: 152.21s | Dead Features: 72.0%


Epoch 99/100: 100%|██████████| 508/508 [02:31<00:00,  3.35it/s, loss=0.377]



[Epoch 99] Loss: 188.3318 | Time: 151.61s | Dead Features: 71.7%


Epoch 100/100: 100%|██████████| 508/508 [02:32<00:00,  3.32it/s, loss=0.367]


[Epoch 100] Loss: 188.2568 | Time: 152.88s | Dead Features: 71.7%





In [None]:
model_state_dicts = {name: model.state_dict() for name, model in SAEs.items()}
torch.save(model_state_dicts, "./models/ViT_MLP.pt")

: 

In [10]:
model_path = "./models/ViT_MLP.pt"
state = torch.load(model_path)

for name, sae in SAEs.items():
    print(sae.load_state_dict(state[name]))

<All keys matched successfully>


In [None]:
### SAVE THE Z_sAES and then on them
visualize_concepts(
    activation_loader=activations_dataloader,
    SAEs=SAEs,
    num_concepts=concepts,
    n_images=8,
    patch_width=14, 
    save_dir="results/visualizer_ViT_MLP/",
    abort_threshold=0.0,
)

100%|██████████| 508/508 [14:12<00:00,  1.68s/it]


6144
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipping cuz 0
Skipp

In [None]:
vit_full = timm.create_model('vit_base_patch16_224', pretrained=True)
mlp = nn.Linear(concepts, 1000) # 1000 Classes
vit_extractor = models["ViT"]

##################



##################


