In this notebook we provide the reconstruction pipeline code for the RECAST network, where the RECAST module is used in the MLP layers of a Vision Transformer. A similar pipeline can be used for other types of networks as well (RECAST applied to the attention layers, CNN layers, etc.) by changing the [model definition](models) and the forward pass in `reconstruction_loss_dynamic()` accordingly.

## Model Definition

In [None]:
import numpy as np
import torch
import torch.nn as nn
import argparse
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import gc
import torch.backends.cudnn as cudnn
import numpy as np
import random


# Reproducibility: seed every RNG the notebook touches and switch cuDNN off
# (benchmark autotuning disabled, backend disabled).
import gc

manualSeed = 42
DEFAULT_THRESHOLD = 5e-3  # not referenced elsewhere in the visible notebook

for _seed_fn in (random.seed, torch.manual_seed, torch.cuda.manual_seed, np.random.seed):
    _seed_fn(manualSeed)
cudnn.benchmark = False
torch.backends.cudnn.enabled = False

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)

# RECAST configuration
FACTORS = 6    # number of template-bank groups the transformer depth is split into
TEMPLATES = 2  # templates stored in each bank
MULT = 1       # reconstructor coefficient sets = TEMPLATES * MULT
num_cf = 2     # coefficient sets held by every SharedMLP layer

def calculate_parameters(model):
    """Print a parameter-count breakdown for ``model``.

    Reports the total, trainable and classifier-head counts, then the number
    of parameters whose qualified name contains ".attn.", "template_banks",
    "mlp" or "coefficients". These name buckets are not exclusive: a parameter
    matching several substrings is added to each of them.

    Args:
        model: an ``nn.Module`` exposing a ``head`` submodule.

    Returns:
        None. Everything is written to stdout.
    """
    all_params = list(model.parameters())
    total_params = sum(p.numel() for p in all_params)
    trainable_params = sum(p.numel() for p in all_params if p.requires_grad)
    classifier_params = sum(p.numel() for p in model.head.parameters())

    # substring -> running element count
    buckets = {".attn.": 0, "template_banks": 0, "mlp": 0, "coefficients": 0}
    for name, param in model.named_parameters():
        for needle in buckets:
            if needle in name:
                buckets[needle] += param.numel()

    print("Classifier head: ", model.head)
    print(
        f"Total parameters: {total_params//1000000}M, Trainable parameters: {trainable_params}, Classifier parameters: {classifier_params}"
    )
    print(f"Attention parameters: {buckets['.attn.']}")
    print(f"Templates params: {buckets['template_banks']}")
    print(f"MLP params: {buckets['mlp']}")
    print(f"Coefficients params: {buckets['coefficients']}")


class MLPTemplateBank(nn.Module):
    """A bank of ``num_templates`` linear-layer weight templates (RECAST).

    Templates are stored as one parameter of shape
    ``(num_templates, out_features, in_features)``, i.e. each template has the
    ``F.linear`` weight layout. Calling the bank with a coefficient tensor of
    shape ``coefficient_shape == (num_templates, 1, 1)`` returns the weighted
    sum ``sum_t coefficients[t] * templates[t]`` of shape
    ``(out_features, in_features)``.
    """

    def __init__(self, num_templates, in_features, out_features):
        super(MLPTemplateBank, self).__init__()
        self.num_templates = num_templates
        self.coefficient_shape = (num_templates, 1, 1)
        templates = [
            torch.empty(out_features, in_features) for _ in range(num_templates)
        ]
        # Initialise each template independently (same RNG order as before).
        for template in templates:
            init.kaiming_normal_(template)
        self.templates = nn.Parameter(torch.stack(templates))

    def forward(self, coefficients):
        """Mix the templates; ``coefficients`` broadcasts against the template stack."""
        return torch.sum(self.templates * coefficients, dim=0)

    def __repr__(self):
        # templates is (num_templates, out_features, in_features): dim 1 is the
        # output width and dim 2 the input width (they were swapped before).
        num_templates, out_features, in_features = self.templates.shape
        return (
            f"MLPTemplateBank(num_templates={num_templates}, in_features={in_features}, "
            f"out_features={out_features}, coefficients={self.coefficient_shape})"
        )


class SharedMLP(nn.Module):
    """Two-layer MLP whose weights are mixed from shared template banks (RECAST).

    The layer owns ``num_cf`` (module-level constant) coefficient tensors per
    bank. The effective weight of each projection is the mean over those
    coefficient sets of ``bank(coefficients)``. Biases are ordinary per-layer
    parameters.

    Args:
        bank1: bank producing the first projection's weight, or None.
        bank2: bank producing the second projection's weight, or None.
        act_layer: activation class instantiated between the projections.
        norm_layer: accepted for timm-signature compatibility; unused (the
            hidden norm is ``nn.Identity``).
        drop: dropout probability applied after the second projection.

    Both banks must be supplied. If either is None, no coefficients or biases
    are created and ``forward`` raises ``RuntimeError``.
    """

    def __init__(
        self, bank1, bank2, act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop=0.0
    ):
        super(SharedMLP, self).__init__()
        self.bank1 = None
        self.bank2 = None

        if bank1 is not None and bank2 is not None:
            self.bank1 = bank1
            self.bank2 = bank2
            # num_cf coefficient tensors per bank, each bank.coefficient_shape.
            self.coefficients1 = nn.ParameterList(
                [
                    nn.Parameter(
                        torch.zeros(bank1.coefficient_shape), requires_grad=True
                    )
                    for _ in range(num_cf)
                ]
            )
            self.coefficients2 = nn.ParameterList(
                [
                    nn.Parameter(
                        torch.zeros(bank2.coefficient_shape), requires_grad=True
                    )
                    for _ in range(num_cf)
                ]
            )
            # For MLPTemplateBank, templates are (num_templates, out, in), so
            # shape[1] is the output width of each projection.
            self.bias1 = nn.Parameter(torch.zeros(bank1.templates.shape[1]))
            self.bias2 = nn.Parameter(torch.zeros(bank2.templates.shape[1]))

        self.act = act_layer()
        self.norm = nn.Identity()
        self.drop = nn.Dropout(drop)
        self.init_weights()

    def init_weights(self):
        """Orthogonally initialise every coefficient tensor (if banks are set)."""
        if self.bank1 is not None:
            for cf in self.coefficients1:
                nn.init.orthogonal_(cf)
        if self.bank2 is not None:
            for cf in self.coefficients2:
                nn.init.orthogonal_(cf)

    @staticmethod
    def _mixed_weight(bank, coefficients):
        """Average ``bank(c)`` over every coefficient tensor ``c``."""
        return torch.stack([bank(c) for c in coefficients]).mean(0)

    def forward(self, x):
        # Without banks there are no weights to apply; fail with a clear message
        # instead of an UnboundLocalError.
        if self.bank1 is None or self.bank2 is None:
            raise RuntimeError(
                "SharedMLP was built without template banks; pass both bank1 and bank2."
            )
        weights1 = self._mixed_weight(self.bank1, self.coefficients1)
        weights2 = self._mixed_weight(self.bank2, self.coefficients2)

        x = F.linear(x, weights1, self.bias1)
        x = self.act(x)
        x = self.norm(x)
        x = F.linear(x, weights2, self.bias2)
        x = self.drop(x)
        return x


# original timm module for vision transformer
class Attention(nn.Module):
    """Multi-head self-attention in the layout used by timm's ViT.

    A single ``qkv`` linear produces queries, keys and values, which are split
    into ``num_heads`` heads of width ``dim // num_heads``. Scaled dot-product
    attention follows, and the heads are merged back through ``proj``.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # An explicit qk_scale wins; otherwise use 1/sqrt(head_dim).
        self.scale = qk_scale or self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        batch, tokens, channels = x.shape

        def split_heads(t):
            # (batch, tokens, channels) -> (batch, heads, tokens, head_dim)
            return t.reshape(batch, tokens, self.num_heads, self.head_dim).permute(0, 2, 1, 3)

        q, k, v = map(split_heads, self.qkv(x).chunk(3, dim=-1))

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        merged = torch.matmul(weights, v).transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))


class Block(nn.Module):
    """Pre-norm transformer block: attention and a RECAST SharedMLP, each residual.

    NOTE(review): ``mlp_ratio`` and ``drop_path`` are accepted for timm
    signature compatibility but have no effect. The MLP width comes from the
    banks, and drop-path / layer-scale are ``nn.Identity``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        bank1=None,
        bank2=None,
    ):
        super().__init__()
        # Submodules are created in the same order as timm's Block so parameter
        # initialisation consumes the RNG identically.
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = nn.Identity()
        self.ls2 = nn.Identity()
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = SharedMLP(
            bank1, bank2, act_layer=act_layer, norm_layer=norm_layer, drop=drop
        )

    def forward(self, x):
        attn_branch = self.drop_path(self.ls1(self.attn(self.norm1(x))))
        x = x + attn_branch
        mlp_branch = self.drop_path(self.ls2(self.mlp(self.norm2(x))))
        return x + mlp_branch


class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly embed them.

    A strided ``Conv2d`` (kernel = stride = patch size) maps ``(B, C, H, W)``
    to ``(B, embed_dim, H/p, W/p)``. The result is flattened to
    ``(B, num_patches, embed_dim)``.

    ``img_size`` and ``patch_size`` may be ints (square) or ``(h, w)`` pairs.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = tuple(img_size) if isinstance(img_size, (tuple, list)) else (img_size, img_size)
        patch_size = (
            tuple(patch_size) if isinstance(patch_size, (tuple, list)) else (patch_size, patch_size)
        )
        self.img_size = img_size
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=0
        )
        self.norm = nn.Identity()

    def forward(self, x):
        x = self.proj(x)
        # (B, D, h, w) -> (B, h*w, D) without building an nn.Flatten every call.
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
        return x


class VisionTransformer(nn.Module):
    """ViT backbone whose MLP layers are RECAST ``SharedMLP`` modules.

    The ``depth`` blocks are split into ``num_groups`` consecutive groups of
    equal size. All blocks in group ``g`` share ``template_banks1[g]`` (first
    projection, embed_dim -> mlp_hidden_dim) and ``template_banks2[g]``
    (second projection, mlp_hidden_dim -> embed_dim). They differ only in
    their own coefficients and biases. Attention layers are ordinary,
    unshared timm-style layers.

    Arguments mirror timm's ``VisionTransformer``, plus:
        num_groups: number of template-bank groups; defaults to the
            module-level ``FACTORS``. Must be a positive divisor of ``depth``.
        num_templates: templates per bank; defaults to ``TEMPLATES``.

    Raises:
        ValueError: if ``num_groups`` does not evenly divide ``depth``.

    NOTE(review): ``drop_path_rate`` is forwarded to ``Block``, but ``Block``
    uses ``nn.Identity`` for drop-path, so it currently has no effect.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
        num_groups=None,
        num_templates=None,
    ):
        super().__init__()
        num_groups = FACTORS if num_groups is None else num_groups
        num_templates = TEMPLATES if num_templates is None else num_templates
        # Blocks map to groups via i // (depth // num_groups). An uneven split
        # would index past the last bank (or divide by zero), so reject it up front.
        if num_groups <= 0 or depth % num_groups != 0:
            raise ValueError(
                f"depth={depth} must be divisible by a positive num_groups (got {num_groups})"
            )

        self.img_size = img_size
        self.dim = embed_dim
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_features = self.embed_dim
        self.num_prefix_tokens = 1  # the class token
        self.num_patches = (img_size // patch_size) ** 2
        self.has_class_token = True
        self.cls_token = nn.Parameter(torch.ones(1, 1, self.embed_dim))
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        num_patches = self.num_patches
        print("Num patches: ", num_patches)
        self.pos_embed = nn.Parameter(
            torch.ones(1, num_patches + self.num_prefix_tokens, embed_dim) * 0.02,
            requires_grad=True,
        )

        self.pos_drop = nn.Dropout(p=drop_rate)
        self.patch_drop = nn.Identity()
        self.fc_norm = nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)

        self.num_groups = num_groups
        self.num_layers_in_group = depth // self.num_groups
        print("Num layers in group: ", self.num_layers_in_group)
        self.num_templates = num_templates
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.template_banks1 = nn.ModuleList(
            [
                MLPTemplateBank(self.num_templates, embed_dim, mlp_hidden_dim)
                for _ in range(self.num_groups)
            ]
        )
        self.template_banks2 = nn.ModuleList(
            [
                MLPTemplateBank(self.num_templates, mlp_hidden_dim, embed_dim)
                for _ in range(self.num_groups)
            ]
        )
        self.depth = depth

        # stochastic depth decay rule, one rate per group
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_groups)]
        self.blocks = nn.ModuleList()
        for i in range(depth):
            group_idx = i // self.num_layers_in_group
            self.blocks.append(
                Block(
                    dim=self.embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[group_idx],
                    norm_layer=norm_layer,
                    bank1=self.template_banks1[group_idx],
                    bank2=self.template_banks2[group_idx],
                )
            )
        print(f"Num blocks: {len(self.blocks)}")
        self.norm = norm_layer(self.embed_dim)
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )
        self._init_weights()

    def _init_weights(self):
        """Xavier-init every nn.Linear (tiny-normal bias); reset every LayerNorm."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def _pos_embed(self, x):
        """Prepend the class token and add the learned position embedding."""
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)
        x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x):
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

    def forward_head(self, x):
        x = x[:, 0]  # class-token features
        x = self.fc_norm(x)
        x = self.head_drop(x)
        x = self.head(x)
        return x

    def forward(self, x):
        features = self.forward_features(x)
        head_output = self.forward_head(features)
        return head_output

# Build the RECAST ViT-S/16 twice (10-class without qkv bias, 1000-class with
# qkv bias), print parameter breakdowns, then a hierarchical summary.
vit_small_kwargs = dict(
    img_size=224, patch_size=16, in_chans=3, embed_dim=384, depth=12, num_heads=6,
    mlp_ratio=4., qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
    norm_layer=nn.LayerNorm,
)
my_model = VisionTransformer(num_classes=10, qkv_bias=False, **vit_small_kwargs)
calculate_parameters(my_model)

my_model = VisionTransformer(num_classes=1000, qkv_bias=True, **vit_small_kwargs)
calculate_parameters(my_model)

from pytorch_model_summary import summary
my_model.eval()
print(summary(my_model, torch.zeros(1, 3, 224, 224), show_hierarchical=True))
# Release the throwaway model before the real one is built.
del my_model
torch.cuda.empty_cache()

## Load existing weights from timm

In [None]:
# Pretrained timm ViT-S/16: source of the attention/norm/embedding weights and
# the reconstruction target for the RECAST MLPs. It is kept frozen.
import timm
vit = timm.create_model("vit_small_patch16_224", pretrained=True)
vit.requires_grad_(False)
# State-dict keys of the custom model that load_weights() filled from `vit`.
FOUND = []
# load the weights from the timm model to the custom model
def load_weights(model, timm_model):
    model_dict = model.state_dict()
    timm_dict = timm_model.state_dict()
    new_dict = {}
    for k, v in timm_dict.items():
        if k in model_dict:
            # check the shape of the weights
            if model_dict[k].shape == timm_dict[k].shape:
                new_dict[k] = v 
                FOUND.append(k)
        else:
            if "mlp.fc1.bias" in k:
                new_k = k.replace("mlp.fc1.bias", "mlp.bias1") # name in model_dict
                print(f"{k} --> {new_k}")
                if model_dict[new_k].shape == timm_dict[k].shape:
                    new_dict[new_k] = v
                    FOUND.append(new_k)
            elif "mlp.fc2.bias" in k:
                new_k = k.replace("mlp.fc2.bias", "mlp.bias2",)
                print(f"{k} --> {new_k}")
                if model_dict[new_k].shape == timm_dict[k].shape:
                    new_dict[new_k] = v
                    FOUND.append(new_k)
            else:
                print("Key not found: ", k)
    model_dict.update(new_dict)
    model.load_state_dict(model_dict)
    return model

# Build the RECAST ViT-S/16 and initialise it from the pretrained timm weights.
tpbvit = VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm)

tpbvit = load_weights(tpbvit, vit)
tpbvit_state_dict = list(tpbvit.state_dict().keys())
print("Found: ", len(FOUND))

# State-dict keys that were NOT copied from timm are the reconstruction targets.
_found_keys = set(FOUND)
target_params = [key for key in tpbvit_state_dict if key not in _found_keys]

# Freeze every loaded tensor and all per-layer coefficient tensors. Whatever
# stays trainable is listed by the loop at the end of this cell.
_target_keys = set(target_params)
for name, param in tpbvit.named_parameters():
    is_coefficient = "coefficients1" in name or "coefficients2" in name
    if is_coefficient or name not in _target_keys:
        param.requires_grad = False

calculate_parameters(tpbvit)

for param_name, param in tpbvit.named_parameters():
    if param.requires_grad:
        print(param_name, param.requires_grad)

## Main Reconstruction Loop

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter
import tqdm
import copy
class GroupWeightReconstructor(nn.Module):
    """Encoder that predicts template coefficients for a stack of layer weights.

    Each input weight matrix is flattened and passed through a small MLP,
    which emits ``num_cf * num_templates`` values (``num_cf = num_templates *
    MULT``). These are reshaped into ``num_cf`` coefficient tensors of shape
    ``(num_templates, 1, 1)``. The reconstruction of a layer is the mean of
    ``bank(c)`` over its ``num_cf`` coefficient tensors.

    Args:
        in_features, out_features: dimensions of each input weight matrix
            (only their product, the flattened size, is used here).
        num_templates: templates in ``bank``.
        hidden_dim: width of the encoder's hidden layers.
        bank: template bank called as ``bank(coefficients)``; it is registered
            as a submodule, so its parameters are part of this module.
    """

    def __init__(self, in_features, out_features, num_templates, hidden_dim, bank):
        super(GroupWeightReconstructor, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_templates = num_templates
        self.num_cf = int(self.num_templates * MULT)
        flat_dim = in_features * out_features
        self.encoder = nn.Sequential(
            nn.Linear(flat_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, num_templates * self.num_cf),
        )
        self.bank = bank

    def forward(self, x):
        """Return ``(reconstructed_weights, coefficients)`` for the stack ``x``.

        ``coefficients`` has shape ``(len(x), num_cf, num_templates, 1, 1)``.
        """
        n = x.shape[0]
        coeffs = self.encoder(x.view(n, -1)).view(n, self.num_cf, self.num_templates, 1, 1)
        recon = torch.stack(
            [torch.stack([self.bank(c) for c in per_layer]).mean(0) for per_layer in coeffs]
        )
        return recon, coeffs

def train_group_reconstructor(model, fc_weights, num_epochs=2500, learning_rate=2e-3, train_bank=True):
    """Fit a GroupWeightReconstructor to ``fc_weights`` and keep its best state.

    Optimises smooth-L1 reconstruction loss with RMSprop, a cosine LR schedule
    and gradient-norm clipping (max 1.0). Loss and mean cosine similarity are
    logged to TensorBoard every epoch.

    Args:
        model: reconstructor with an optional ``bank`` submodule.
        fc_weights: stacked target weights, fed to ``model`` as one batch.
        num_epochs: number of full-batch optimisation steps.
        learning_rate: initial RMSprop learning rate.
        train_bank: when False, the parameters of ``model.bank`` are left out
            of the optimizer and are not updated.

    Returns:
        ``model`` itself, restored to the state with the highest similarity.
        Restoring in place (instead of returning a deep copy) keeps
        ``model.bank`` the same object that the custom model shares, so its
        templates match the returned encoder.
    """
    bank = getattr(model, "bank", None)
    excluded = {id(p) for p in bank.parameters()} if (not train_bank and bank is not None) else set()
    # Must be a list: a generator would be exhausted by the optimizer and
    # clip_grad_norm_ would then silently clip nothing.
    params_to_optimize = [p for p in model.parameters() if id(p) not in excluded]
    optimizer = torch.optim.RMSprop(params_to_optimize, lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    similarity_metric = nn.CosineSimilarity(dim=1)  # per-layer similarity over the batch
    best_sim = -float('inf')
    best_state = None
    writer = SummaryWriter()
    try:
        for epoch in tqdm.tqdm(range(num_epochs), desc="Training Group Reconstructor"):
            # Clears bank grads too, even when the bank is not being optimised.
            model.zero_grad(set_to_none=True)
            recon_batch, _ = model(fc_weights)
            loss = F.smooth_l1_loss(recon_batch, fc_weights)
            sim = similarity_metric(recon_batch.view(recon_batch.size(0), -1),
                                    fc_weights.view(fc_weights.size(0), -1)).mean()
            sim_value = sim.item()  # plain float: do not retain the autograd graph

            if sim_value > best_sim:
                best_sim = sim_value
                # Snapshot before this epoch's update, matching recon_batch.
                best_state = copy.deepcopy(model.state_dict())

            loss.backward()
            torch.nn.utils.clip_grad_norm_(params_to_optimize, max_norm=1.0)
            optimizer.step()
            scheduler.step()

            writer.add_scalar('Loss/train', loss.item(), epoch)
            writer.add_scalar('Similarity/train', sim_value, epoch)
            if (epoch + 1) % 500 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Similarity: {sim_value:.4f}')
                print("Best sim: ", best_sim)
    finally:
        writer.close()

    if best_state is not None:
        model.load_state_dict(best_state)
    return model

def prepare_and_train_reconstructors(timm_model, custom_model, train_banks=False, hidden_dim=8):
    """Train one FC1 and one FC2 reconstructor per template-bank group.

    For each group, the timm ``mlp.fc1`` / ``mlp.fc2`` weights of the group's
    blocks are stacked. A ``GroupWeightReconstructor`` is then trained on each
    stack against the group's shared bank. The per-group and depth-weighted
    overall mean cosine similarity are printed.

    Args:
        timm_model: reference model whose ``blocks.{i}.mlp.fc{1,2}.weight``
            entries are the targets.
        custom_model: RECAST VisionTransformer providing the banks and the
            grouping (``num_groups``, ``num_layers_in_group``, ``depth``).
        train_banks: forwarded to ``train_group_reconstructor`` as
            ``train_bank``.
        hidden_dim: encoder hidden width of each reconstructor.

    Returns:
        ``(reconstructors, custom_model)``, where ``reconstructors[g]`` is the
        ``(fc1_reconstructor, fc2_reconstructor)`` pair for group ``g``.
    """
    reconstructors = []
    overall_similarity = 0.0
    similarity_metric = nn.CosineSimilarity(dim=1)
    # state_dict() rebuilds an OrderedDict on every call; fetch it once.
    timm_state = timm_model.state_dict()

    for group in range(custom_model.num_groups):
        print(f"\nTraining reconstructors for group {group}")
        bank1 = custom_model.template_banks1[group]
        bank2 = custom_model.template_banks2[group]

        # Collect weights for all layers in this group
        start_layer = group * custom_model.num_layers_in_group
        end_layer = min((group + 1) * custom_model.num_layers_in_group, custom_model.depth)
        layers = range(start_layer, end_layer)
        fc1_weights = torch.stack([timm_state[f'blocks.{i}.mlp.fc1.weight'] for i in layers])
        fc2_weights = torch.stack([timm_state[f'blocks.{i}.mlp.fc2.weight'] for i in layers])

        # fc*_weights are (layers, out_features, in_features)
        rec1 = GroupWeightReconstructor(fc1_weights.shape[2], fc1_weights.shape[1], custom_model.num_templates, hidden_dim, bank1)
        rec2 = GroupWeightReconstructor(fc2_weights.shape[2], fc2_weights.shape[1], custom_model.num_templates, hidden_dim, bank2)

        print(f"Training Reconstructor for group {group} FC1 ....")
        trained_rec1 = train_group_reconstructor(rec1, fc1_weights, train_bank=train_banks)
        print(f"Training Reconstructor for group {group} FC2 ....")
        trained_rec2 = train_group_reconstructor(rec2, fc2_weights, train_bank=train_banks)

        reconstructors.append((trained_rec1, trained_rec2))

        # Compute final similarity for this group
        with torch.no_grad():
            recon_fc1, _ = trained_rec1(fc1_weights)
            recon_fc2, _ = trained_rec2(fc2_weights)
            sim_fc1 = similarity_metric(recon_fc1.view(recon_fc1.size(0), -1),
                                        fc1_weights.view(fc1_weights.size(0), -1)).mean()
            sim_fc2 = similarity_metric(recon_fc2.view(recon_fc2.size(0), -1),
                                        fc2_weights.view(fc2_weights.size(0), -1)).mean()
            current_similarity = ((sim_fc1 + sim_fc2) / 2).item()
        print(f"Group {group} av similarity: {current_similarity}")
        # Weight by group size so the overall figure is a per-block average.
        overall_similarity += current_similarity * len(layers)

    overall_similarity /= custom_model.depth
    print(f"\nOverall average similarity across all blocks: {overall_similarity:.4f}")

    return reconstructors, custom_model

def _group_layer_range(custom_model, group):
    """Return the block indices that share template-bank group ``group``."""
    start_layer = group * custom_model.num_layers_in_group
    end_layer = min((group + 1) * custom_model.num_layers_in_group, custom_model.depth)
    return range(start_layer, end_layer)


def set_custom_model_weights(custom_model, trained_reconstructors, timm_model):
    """Write reconstructed coefficients and timm biases into the RECAST model.

    For every group, the trained reconstructors encode the group's timm
    ``fc1``/``fc2`` weights into coefficient sets. These are copied into each
    block's ``SharedMLP`` coefficients, and the timm biases are copied into
    ``bias1``/``bias2``. Values are copied (``copy_``), so the custom model
    never aliases the timm model's tensors. Afterwards, the per-block cosine
    similarity between the timm weights and the bank-mixed weights is printed.

    Args:
        custom_model: RECAST VisionTransformer, modified in place.
        trained_reconstructors: one ``(fc1_rec, fc2_rec)`` pair per group.
        timm_model: reference model supplying weights and biases.

    Returns:
        ``custom_model``.

    Raises:
        AssertionError: if the number of reconstructor pairs != num_groups.
        ValueError: if a reconstructor's coefficient-set count differs from
            the number of coefficient tensors held by a block's SharedMLP.
    """
    device = next(custom_model.parameters()).device
    assert len(trained_reconstructors) == custom_model.num_groups, "Mismatch in number of groups"

    similarity_metric = nn.CosineSimilarity(dim=0)
    timm_state = timm_model.state_dict()  # fetched once; reused below

    with torch.no_grad():
        # First, set all weights
        for group in range(custom_model.num_groups):
            layers = _group_layer_range(custom_model, group)
            fc1_weights = torch.stack([timm_state[f'blocks.{i}.mlp.fc1.weight'].to(device) for i in layers])
            fc2_weights = torch.stack([timm_state[f'blocks.{i}.mlp.fc2.weight'].to(device) for i in layers])

            _, coefficients1 = trained_reconstructors[group][0](fc1_weights)
            _, coefficients2 = trained_reconstructors[group][1](fc2_weights)

            for i, layer_idx in enumerate(layers):
                mlp = custom_model.blocks[layer_idx].mlp
                n1, n2 = len(mlp.coefficients1), len(mlp.coefficients2)
                # The reconstructor was trained averaging over coefficients.shape[1]
                # sets; a different count in the MLP would not reproduce its output.
                if coefficients1.shape[1] != n1 or coefficients2.shape[1] != n2:
                    raise ValueError(
                        f"Block {layer_idx}: reconstructors produced {coefficients1.shape[1]}/"
                        f"{coefficients2.shape[1]} coefficient sets but the MLP holds {n1}/{n2}"
                    )
                for j in range(n1):
                    mlp.coefficients1[j].copy_(coefficients1[i, j])
                for j in range(n2):
                    mlp.coefficients2[j].copy_(coefficients2[i, j])

                mlp.bias1.copy_(timm_state[f'blocks.{layer_idx}.mlp.fc1.bias'])
                mlp.bias2.copy_(timm_state[f'blocks.{layer_idx}.mlp.fc2.bias'])

        # Now, calculate similarities for all layers
        overall_similarity_fc1 = 0.0
        overall_similarity_fc2 = 0.0
        num_layers = 0

        for group in range(custom_model.num_groups):
            bank1 = custom_model.template_banks1[group]
            bank2 = custom_model.template_banks2[group]
            for layer_idx in _group_layer_range(custom_model, group):
                mlp = custom_model.blocks[layer_idx].mlp
                fc1_weight = timm_state[f'blocks.{layer_idx}.mlp.fc1.weight'].to(device)
                fc2_weight = timm_state[f'blocks.{layer_idx}.mlp.fc2.weight'].to(device)

                # Reconstruct weights exactly as SharedMLP.forward mixes them
                my_weight1 = torch.stack([bank1(c) for c in mlp.coefficients1]).mean(0)
                my_weight2 = torch.stack([bank2(c) for c in mlp.coefficients2]).mean(0)

                sim1 = similarity_metric(fc1_weight.view(-1), my_weight1.view(-1)).item()
                sim2 = similarity_metric(fc2_weight.view(-1), my_weight2.view(-1)).item()

                overall_similarity_fc1 += sim1
                overall_similarity_fc2 += sim2
                num_layers += 1

                print(f"Block {layer_idx} - FC1 Similarity: {sim1:.4f}, FC2 Similarity: {sim2:.4f}")

    if num_layers:
        avg_similarity_fc1 = overall_similarity_fc1 / num_layers
        avg_similarity_fc2 = overall_similarity_fc2 / num_layers
        print(f"\nOverall average similarity - FC1: {avg_similarity_fc1:.4f}, FC2: {avg_similarity_fc2:.4f}")

    return custom_model

# Usage: train the per-group reconstructors (train_banks=True requests that the
# template banks be optimised as well), then write the resulting coefficients
# and the timm biases into the RECAST model.
trained_reconstructors, custom_model = prepare_and_train_reconstructors(
    timm_model=vit, custom_model=tpbvit, train_banks=True
)
custom_model = set_custom_model_weights(
    custom_model=custom_model, trained_reconstructors=trained_reconstructors, timm_model=vit
)

## Evaluate Cosine Similarity and Feature Similarity

In [None]:
import torch
import torch.nn as nn
from PIL import Image
import os
import random
import numpy as np
import matplotlib.pyplot as plt
# Load a pre-saved input batch from disk. "<batch.pt>" is a placeholder path to
# fill in; presumably it holds a (B, 3, 224, 224) image tensor (TODO confirm).
# map_location="cpu" lets a batch saved from a GPU session load on CPU-only
# machines; compare_features_layerwise() moves it to `device` afterwards.
# NOTE: torch.load unpickles its input, so only load trusted files.
batch = torch.load("<batch.pt>", map_location="cpu")

def forward_with_intermediate(model, x, cust=True):
    """Run a ViT forward pass and also return every block's output.

    Args:
        model: ViT exposing ``patch_embed``, ``blocks`` and ``norm``. With
            ``cust=True`` it must also provide ``_pos_embed`` and
            ``patch_drop``. With ``cust=False`` it must provide ``cls_token``,
            ``pos_embed`` and ``pos_drop`` (the "classic" token preparation).
        x: input batch passed to ``patch_embed``.
        cust: selects which of the two token-preparation paths is used.

    Returns:
        ``(normed_output, per_block_outputs)``. The list holds the output of
        each block before the final norm.
    """
    batch_size = x.shape[0]
    tokens = model.patch_embed(x)
    if cust:
        print("Custom")
        tokens = model.patch_drop(model._pos_embed(tokens))
    else:
        print("classic")
        cls = model.cls_token.expand(batch_size, -1, -1)
        tokens = model.pos_drop(torch.cat((cls, tokens), dim=1) + model.pos_embed)

    per_block = []
    for block in model.blocks:
        tokens = block(tokens)
        per_block.append(tokens)
    return model.norm(tokens), per_block

def compare_features_layerwise(model1, model2, batch):
    """Return the mean per-sample cosine similarity between two models, block by block.

    ``model1`` is run through the custom token-preparation path and ``model2``
    through the classic one (see ``forward_with_intermediate``). Each block's
    output is flattened per sample before comparison.

    Returns:
        A list of Python floats, one per block pair.
    """
    batch = batch.to(device)
    _, custom_feats = forward_with_intermediate(model1, batch, True)
    _, reference_feats = forward_with_intermediate(model2, batch, False)

    cos = nn.CosineSimilarity(dim=-1)  # compare along the flattened feature axis

    def flat(t):
        # keep the batch dimension, flatten everything else
        return t.view(t.size(0), -1)

    return [cos(flat(a), flat(b)).mean().item() for a, b in zip(custom_feats, reference_feats)]


# Compare the reconstructed RECAST model against the timm reference, block by
# block, then plot the per-block cosine similarity.
with torch.no_grad():
    similarities = compare_features_layerwise(custom_model.to(device), vit.to(device), batch)
av = np.mean(similarities)
print(av)
# Print similarities for each layer
for layer_num, sim in enumerate(similarities, start=1):
    print(f"Layer {layer_num} similarity: {sim:.4f}")

plt.figure(figsize=(6, 6))
plt.ylim(-1, 2)
plt.plot(range(1, len(similarities) + 1), similarities, marker='o')
plt.title(f'Layerwise Cosine Similarity factor={FACTORS}, templates={TEMPLATES} Average Sim: {av:.4f}')
plt.xlabel("Layer")
plt.ylabel("Cosine Similarity")
plt.grid(True)
plt.show()
        
