In this notebook we provide the reconstruction pipeline code for the RECAST network, where the RECAST module is used in the MLP layers of a Vision Transformer. A similar pipeline can be used for other types of networks as well (RECAST applied to the attention layers, CNN layers, etc.) by changing the [model definition](models) and the forward pass in `reconstruction_loss_dynamic()` accordingly.

## Model Definition

In [None]:
import numpy as np
import torch
import torch.nn as nn
import argparse
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import matplotlib.pyplot as plt
import gc
import torch.backends.cudnn as cudnn
import numpy as np
import random
import os
import shutil
import tqdm
# Single seed shared by every random number generator seeded below.
manualSeed = 42
# NOTE(review): not referenced anywhere else in the visible notebook.
DEFAULT_THRESHOLD = 5e-3

# Seed Python's `random`, PyTorch (CPU and the current CUDA device) and NumPy.
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed(manualSeed)
np.random.seed(manualSeed)
# Switch off the cuDNN autotuner and then cuDNN itself
# (presumably for run-to-run reproducibility -- confirm).
cudnn.benchmark = False
torch.backends.cudnn.enabled = False
# Device configuration: use the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: ", device)
# NOTE(review): not referenced anywhere else in the visible notebook.
batch_size = 256
import gc

# Template-sharing hyper-parameters, read as globals by the model classes below.
FACTORS = 6 # number of groups (one pair of template banks per group)
TEMPLATES = 2 # number of templates per bank, corresponds to number of layers in a group
MULT = 1 # optional multiplier for the number of coefficients set (unused in the visible code)
num_cf = 2 # number of coefficients sets per target module

def calculate_parameters(model):
    """Print a breakdown of the parameter counts of ``model``.

    Reports the total and trainable element counts, the size of ``model.head``
    and, for each name marker below, the summed size of every parameter whose
    name contains that marker.  The markers are tested independently, so one
    parameter may be counted under several of them.

    Args:
        model: an ``nn.Module`` that exposes a ``head`` sub-module.

    Returns:
        None.  The statistics are only printed.
    """
    # marker searched in the parameter name -> accumulated element count
    buckets = {".attn.": 0, "template_banks": 0, "mlp": 0, "coefficients": 0}
    total_params = 0
    trainable_params = 0

    for name, param in model.named_parameters():
        size = param.numel()
        total_params += size
        if param.requires_grad:
            trainable_params += size
        for marker in buckets:
            if marker in name:
                buckets[marker] += size

    classifer_params = sum(weight.numel() for weight in model.head.parameters())

    print("Classifier head: ", model.head)
    # The total is reported in whole millions (integer division).
    print(f"Total parameters: {total_params // 1000000}M, Trainable parameters: {trainable_params}, Classifier parameters: {classifer_params}")
    print(f"Attention parameters: {buckets['.attn.']}")
    print(f"Templates params: {buckets['template_banks']}")
    print(f"MLP params: {buckets['mlp']}")
    print(f"Coefficients params: {buckets['coefficients']}")

class MLPTemplateBank(nn.Module):
    """Bank of weight templates that are linearly combined into one weight matrix.

    ``templates`` is a single parameter of shape
    ``(num_templates, out_features, in_features)``.  Calling the bank with a
    coefficient tensor scales each template and sums over the template axis,
    giving an ``(out_features, in_features)`` matrix usable with ``F.linear``.

    Args:
        num_templates: number of templates stored in the bank.
        in_features: input width of the generated linear weight.
        out_features: output width of the generated linear weight.
    """

    def __init__(self, num_templates, in_features, out_features):
        super().__init__()
        self.num_templates = num_templates
        # Shape a coefficient tensor needs in order to broadcast over the templates.
        self.coefficient_shape = (num_templates, 1, 1)
        # Each template is initialised on its own so the Kaiming fan is computed
        # from a single (out_features, in_features) matrix.
        templates = [torch.empty(out_features, in_features) for _ in range(num_templates)]
        for template in templates:
            init.kaiming_normal_(template)
        self.templates = nn.Parameter(torch.stack(templates))

    def forward(self, coefficients):
        """Return ``sum_t coefficients[t] * templates[t]``.

        Args:
            coefficients: tensor broadcastable against ``templates``
                (``coefficient_shape`` is the shape used in this file).
        """
        return torch.sum(self.templates * coefficients, dim=0)

    def __repr__(self):
        # templates is laid out as (num_templates, out_features, in_features).
        num_templates, out_features, in_features = self.templates.shape
        return f"MLPTemplateBank(num_templates={num_templates}, in_features={in_features}, out_features={out_features}, coefficients={self.coefficient_shape})"
        
class SharedMLP(nn.Module):
    """Two-layer MLP whose weight matrices are generated from shared template banks.

    Each projection owns ``num_cf`` (module-level constant) coefficient sets.
    Every set is passed through the corresponding bank and the resulting
    matrices are averaged to form the weight used by ``F.linear``.  The biases
    are ordinary per-module parameters sized from ``templates.shape[1]`` of the
    respective bank.

    Args:
        bank1: template bank for the first projection, or ``None``.
        bank2: template bank for the second projection, or ``None``.
        act_layer: activation class instantiated between the two projections.
        norm_layer: accepted for signature compatibility; not used (``norm`` is
            always ``nn.Identity``).
        drop: dropout probability applied to the output.

    The banks are only attached when *both* are given; a module built without
    them holds no parameters and cannot be run.
    """
    # TARGET MODULE
    def __init__(self, bank1, bank2, act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop=0.):
        super().__init__()
        self.bank1 = None
        self.bank2 = None

        if bank1 is not None and bank2 is not None:
            self.bank1 = bank1
            self.bank2 = bank2
            self.coefficients1 = nn.ParameterList([nn.Parameter(torch.zeros(bank1.coefficient_shape), requires_grad=True) for _ in range(num_cf)])
            self.coefficients2 = nn.ParameterList([nn.Parameter(torch.zeros(bank2.coefficient_shape), requires_grad=True) for _ in range(num_cf)])
            self.bias1 = nn.Parameter(torch.zeros(bank1.templates.shape[1]))
            self.bias2 = nn.Parameter(torch.zeros(bank2.templates.shape[1]))

        self.act = act_layer()
        self.norm = nn.Identity()
        self.drop = nn.Dropout(drop)
        self.init_weights()

    def init_weights(self):
        """Orthogonally initialise every coefficient set (no-op without banks)."""
        if self.bank1 is not None:
            for cf in self.coefficients1:
                nn.init.orthogonal_(cf)
        if self.bank2 is not None:
            for cf in self.coefficients2:
                nn.init.orthogonal_(cf)

    @staticmethod
    def _mixed_weight(bank, coefficient_sets):
        """Average the weight matrices ``bank`` produces for each coefficient set."""
        return torch.stack([bank(cf) for cf in coefficient_sets]).mean(0)

    def forward(self, x):
        """Apply linear -> activation -> norm -> linear -> dropout to ``x``.

        Raises:
            RuntimeError: if the module was built without template banks.
        """
        if self.bank1 is None or self.bank2 is None:
            # Fail with a clear message instead of an UnboundLocalError further down.
            raise RuntimeError("SharedMLP was constructed without template banks and cannot be run")

        weights1 = self._mixed_weight(self.bank1, self.coefficients1)
        weights2 = self._mixed_weight(self.bank2, self.coefficients2)

        x = F.linear(x, weights1, self.bias1)
        x = self.act(x)
        x = self.norm(x)
        x = F.linear(x, weights2, self.bias2)
        x = self.drop(x)
        return x


# original timm module for vision transformer
class Attention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        dim: channel width of the input tokens.
        num_heads: number of attention heads; each head is ``dim // num_heads`` wide.
        qkv_bias: whether the fused qkv projection has a bias.
        qk_scale: explicit score scale; when falsy, ``head_dim ** -0.5`` is used.
        attn_drop: dropout probability on the attention probabilities.
        proj_drop: dropout probability on the projected output.
    """

    def __init__(
            self, dim, num_heads=6, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale if qk_scale else self.head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token axis of ``x`` (unpacked as batch, tokens, channels)."""
        batch, tokens, channels = x.shape

        def split_heads(t):
            # (batch, tokens, channels) -> (batch, heads, tokens, head_dim)
            return t.reshape(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

        query, key, value = (split_heads(part) for part in self.qkv(x).chunk(3, dim=-1))

        scores = (query @ key.transpose(-2, -1)) * self.scale
        probs = self.attn_drop(F.softmax(scores, dim=-1))  # attention probabilities

        context = probs @ value  # attention output, still split per head
        merged = context.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))


class Block(nn.Module):
    """Pre-norm transformer encoder block using a template-bank MLP.

    ``mlp_ratio`` and ``drop_path`` are accepted but unused here: the MLP sizes
    come from the banks, and ``drop_path``/``ls1``/``ls2`` are identity modules.

    Args:
        dim: token channel width.
        num_heads: number of attention heads.
        qkv_bias, qk_scale, attn_drop: forwarded to :class:`Attention`.
        drop: dropout used for the attention projection and the MLP output.
        act_layer, norm_layer: activation / normalisation classes.
        bank1, bank2: template banks handed to :class:`SharedMLP`.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, bank1=None, bank2=None):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = nn.Identity()
        self.ls2 = nn.Identity()
        self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = SharedMLP(
            bank1,
            bank2,
            act_layer=act_layer,
            norm_layer=norm_layer,
            drop=drop,
        )

    def forward(self, x):
        """Apply the attention and MLP sub-layers, each with a residual connection."""
        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(self.ls1(attn_branch))
        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(self.ls2(mlp_branch))

class PatchEmbed(nn.Module):
    """Turn an image into a sequence of patch embeddings with a strided convolution.

    Args:
        img_size: image side length (stored as an ``(H, W)`` tuple).
        patch_size: patch side length; used as both kernel size and stride.
        in_chans: number of input channels.
        embed_dim: number of output channels, i.e. the token width.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,
        )
        self.norm = nn.Identity()

    def forward(self, x):
        """Project ``x`` and return tokens laid out as (batch, patches, embed_dim)."""
        feature_map = self.proj(x)
        # Merge the two spatial axes into one, then move channels to the last axis.
        tokens = torch.flatten(feature_map, start_dim=2, end_dim=3).transpose(1, 2)
        return self.norm(tokens)


# create vision transformer with template bank
class VisionTransformer(nn.Module):
    """Vision Transformer whose MLP weights are generated from shared template banks.

    The ``depth`` encoder blocks are split into ``FACTORS`` groups of consecutive
    blocks.  Every group owns one pair of :class:`MLPTemplateBank` modules (one
    per MLP projection), each holding ``TEMPLATES`` templates, and every block of
    the group is handed that pair.  ``FACTORS`` and ``TEMPLATES`` are module-level
    constants.

    Args:
        img_size, patch_size, in_chans: input image geometry.
        num_classes: size of the classification head (``<= 0`` gives an identity head).
        embed_dim: token channel width.
        depth: number of encoder blocks; must be a multiple of ``FACTORS``.
        num_heads, qkv_bias, qk_scale, attn_drop_rate: attention settings.
        mlp_ratio: hidden width of the MLP relative to ``embed_dim``.
        drop_rate: dropout used for positions, blocks and the head.
        drop_path_rate: forwarded to the blocks (which currently ignore it).
        norm_layer: normalisation class.

    Raises:
        ValueError: if ``FACTORS`` is not positive or ``depth`` is not a multiple
            of it (blocks could not be mapped onto the banks).
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm):
        super().__init__()
        # Block i uses bank i // (depth // FACTORS); that index only stays inside
        # the FACTORS banks when depth divides evenly, so reject anything else early.
        if FACTORS <= 0 or depth % FACTORS != 0:
            raise ValueError(
                f"depth ({depth}) must be a multiple of the number of groups FACTORS ({FACTORS})"
            )

        self.img_size = img_size
        self.dim = embed_dim
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_features = self.embed_dim
        self.num_prefix_tokens = 1
        self.num_patches = (img_size // patch_size) ** 2
        self.has_class_token = True
        self.cls_token = nn.Parameter(torch.ones(1, 1, self.embed_dim))
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
        )

        print("Num patches: ", self.num_patches)
        # One learnable position vector per patch plus one for the class token.
        self.pos_embed = nn.Parameter(torch.ones(1, self.num_patches + self.num_prefix_tokens, embed_dim)*.02, requires_grad=True)

        self.pos_drop = nn.Dropout(p=drop_rate)
        self.patch_drop = nn.Identity()
        self.fc_norm = nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)

        self.num_groups = FACTORS
        self.num_layers_in_group = depth // self.num_groups # how many consecutive encoder layers share the same template bank
        print("Num layers in group: ", self.num_layers_in_group)
        self.num_templates = TEMPLATES
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        # banks1 generate the embed_dim -> hidden weight, banks2 the hidden -> embed_dim weight.
        self.template_banks1 = nn.ModuleList([MLPTemplateBank(self.num_templates, embed_dim, mlp_hidden_dim) for _ in range(self.num_groups)])
        self.template_banks2 = nn.ModuleList([MLPTemplateBank(self.num_templates, mlp_hidden_dim, embed_dim) for _ in range(self.num_groups)])
        self.depth = depth

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_groups)]  # stochastic depth decay rule
        self.blocks = nn.ModuleList()
        for i in range(depth):
            group_idx = i // self.num_layers_in_group
            print(group_idx)
            bank1 = self.template_banks1[group_idx]
            bank2 = self.template_banks2[group_idx]
            self.blocks.append(Block(
                dim=self.embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[group_idx], norm_layer=norm_layer, bank1=bank1, bank2=bank2))
        print(f"Num blocks: {len(self.blocks)}")
        self.norm = norm_layer(self.embed_dim)
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every ``nn.Linear`` and ``nn.LayerNorm`` inside the model."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def _pos_embed(self, x):
        """Prepend the class token, add the position embedding and apply dropout."""
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x):
        """Return the normalised token sequence produced by the encoder blocks."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

    def forward_head(self, x):
        """Classify from the first token (the class token) of ``x``."""
        x = x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        x = self.head(x)
        return x

    def forward(self, x):
        """Run the encoder followed by the classification head."""
        features = self.forward_features(x)
        head_output = self.forward_head(features)
        return head_output

# Smoke test: build a throw-away instance with the configuration used below
# (embed_dim=384, depth=12, 6 heads) and report its parameter counts.
my_model = VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm)
calculate_parameters(my_model)

# Third-party dependency, used only for this layer-by-layer summary.
from pytorch_model_summary import summary
my_model.eval()
# The summary is produced from a forward pass on a single all-zero 224x224 RGB image.
print(summary(my_model, torch.zeros(1, 3,224,224), show_hierarchical=True))
# Drop the temporary model and release any cached CUDA memory.
del my_model 
torch.cuda.empty_cache()

## Load existing Weights

In [None]:
# load timm vit weights into the custom model
import timm 
# Pretrained reference model whose weights are copied / reconstructed below
# (pretrained=True fetches the checkpoint through timm).
vit = timm.create_model("vit_small_patch16_224", pretrained=True)
print("Base model")
calculate_parameters(vit)
# freeze the weights of the timm model: it only serves as the reconstruction target
for param in vit.parameters():
    param.requires_grad = False
# Names of the custom-model state-dict entries that load_weights() filled from
# the pretrained model.  Everything not listed here is later treated as trainable.
FOUND = []
# load the weights from the timm model to the custom model
def load_weights(model, timm_model):
    """Copy every compatible tensor from ``timm_model`` into ``model``.

    A pretrained entry is copied when the custom model has an entry with the
    same name and shape.  The MLP biases are stored under different names in
    the custom model (``mlp.fc1.bias`` -> ``mlp.bias1``, ``mlp.fc2.bias`` ->
    ``mlp.bias2``) and are copied under the translated name.  Entries without a
    counterpart are reported with ``Key not found`` and skipped; entries whose
    shape differs are skipped silently.

    Args:
        model: custom model that receives the weights (modified in place).
        timm_model: pretrained model providing the weights.

    Returns:
        ``model``, after loading the merged state dict.

    Side effects:
        Appends the name of every filled entry to the module-level ``FOUND`` list.
    """
    # pretrained-name fragment -> custom-model name fragment
    renamed = (("mlp.fc1.bias", "mlp.bias1"), ("mlp.fc2.bias", "mlp.bias2"))

    model_dict = model.state_dict()
    new_dict = {}
    for k, v in timm_model.state_dict().items():
        target_key = k
        if k not in model_dict:
            target_key = next((k.replace(old, new) for old, new in renamed if old in k), None)
            if target_key is None:
                print("Key not found: ", k)
                continue
            print(f"{k} --> {target_key}")
            if target_key not in model_dict:
                # e.g. the pretrained model has more blocks than the custom one;
                # report it instead of failing with a KeyError.
                print("Key not found: ", k)
                continue
        # check the shape of the weights
        if model_dict[target_key].shape == v.shape:
            new_dict[target_key] = v
            FOUND.append(target_key)
    model_dict.update(new_dict)
    model.load_state_dict(model_dict)
    return model

# Build the template-bank ViT and fill it with every pretrained tensor that fits.
tpbvit = load_weights(
    VisionTransformer(img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4., qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm),
    vit,
)
tpbvit_state_dict = list(tpbvit.state_dict())
print("Found: ", len(FOUND))

# Entries the pretrained model could not provide are the ones left to be learned.
target_params = [key for key in tpbvit_state_dict if key not in FOUND]

# Freeze everything that was copied from the pretrained model.
for name, param in tpbvit.named_parameters():
    if name in target_params:
        continue
    param.requires_grad = False
    # check if the param is layer norm or batch norm
    # if "norm" in name:
    #     param.requires_grad = True

calculate_parameters(tpbvit)

# List the parameters that remain trainable.
for n, p in tpbvit.named_parameters():
    if not p.requires_grad:
        continue
    print(n, p.requires_grad)

## Reconstruction Process

In [None]:
import torch 
import torch.nn as nn
import torch.optim as optim
import timm 
import copy
import torch 
import torch.nn as nn
import torch.optim as optim
import timm 
import copy
def _mean_bank_weight(bank, coefficient_sets, noise_std, training):
    """Average the weights ``bank`` generates for each coefficient set.

    While ``training`` is true, Gaussian noise with standard deviation
    ``noise_std`` is added to every coefficient set before it is used.
    """
    weights = []
    for coeff in coefficient_sets:
        if training:
            coeff = coeff + torch.randn_like(coeff) * noise_std
        weights.append(bank(coeff))
    return torch.stack(weights).mean(0)


def reconstruction_loss_dynamic(current_model, pretrained_model, criterion=nn.SmoothL1Loss(), w1_weight=2.0, w2_weight=2.0, noise_std=1e-4):
    """Measure how well the template banks reproduce the pretrained MLP weights.

    For every block of ``current_model`` the weight of each MLP projection is
    rebuilt from its template bank and coefficient sets and compared with the
    matching ``blocks.{i}.mlp.fc1.weight`` / ``blocks.{i}.mlp.fc2.weight`` entry
    of ``pretrained_model``.  A projection whose bank is ``None`` contributes
    nothing.

    Args:
        current_model: model exposing ``blocks``; each block has ``mlp.bank1``,
            ``mlp.bank2``, ``mlp.coefficients1`` and ``mlp.coefficients2``.
        pretrained_model: model whose ``state_dict()`` holds the target weights.
        criterion: loss applied to (reconstructed, target) weight pairs.
        w1_weight: multiplier for the first-projection loss.
        w2_weight: multiplier for the second-projection loss.
        noise_std: standard deviation of the noise added to the coefficients
            while ``current_model.training`` is true (0 disables the noise).

    Returns:
        Tuple ``(loss_dict, total_loss, w1_loss, w2_loss)``: ``loss_dict`` is
        currently always empty, ``total_loss`` is the differentiable sum over
        all blocks (the float ``0.0`` when nothing contributed), and ``w1_loss``
        / ``w2_loss`` are the summed per-projection losses as Python floats.
    """
    corr_state_dict = pretrained_model.state_dict()
    loss_dict = {}
    total_loss = 0.0
    w1_loss = 0.0
    w2_loss = 0.0

    # Determine the device of the current model
    device = next(current_model.parameters()).device
    training = current_model.training

    for block_idx, block in enumerate(current_model.blocks):
        mlp = block.mlp
        # Reset per block so a projection without a bank never re-adds the
        # loss computed for a previous block.
        w1_l = None
        w2_l = None

        if mlp.bank1 is not None:
            rebuilt1 = _mean_bank_weight(mlp.bank1, mlp.coefficients1, noise_std, training)
            corr_weight1 = corr_state_dict[f'blocks.{block_idx}.mlp.fc1.weight'].to(device)
            w1_l = criterion(rebuilt1, corr_weight1) * w1_weight

        if mlp.bank2 is not None:
            rebuilt2 = _mean_bank_weight(mlp.bank2, mlp.coefficients2, noise_std, training)
            corr_weight2 = corr_state_dict[f'blocks.{block_idx}.mlp.fc2.weight'].to(device)
            w2_l = criterion(rebuilt2, corr_weight2) * w2_weight

        if w1_l is not None:
            total_loss = total_loss + w1_l
            w1_loss += w1_l.item()
        if w2_l is not None:
            total_loss = total_loss + w2_l
            w2_loss += w2_l.item()

    return loss_dict, total_loss, w1_loss, w2_loss

In [None]:
# Fit the template banks and coefficients to the pretrained MLP weights.
# No data is involved: every "epoch" is one optimisation step on the
# weight-reconstruction loss computed directly from the two models.
optimizer = optim.RMSprop(
         tpbvit.parameters(),
        lr=1e-1,
    )
# Multiply the learning rate by 0.1 after every 200 scheduler steps.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.1)
# Copy of the model with the lowest loss seen so far, and that loss.
best_model = None
best_model_loss = 1e9
for epoch in range(1000):
    optimizer.zero_grad()
    losses_dict, total_loss, w1, w2 = reconstruction_loss_dynamic(tpbvit.to(device), vit)
    # Snapshot before the update, so the copy holds the weights that produced this loss.
    if total_loss < best_model_loss:
        best_model_loss = total_loss.item()
        best_model = copy.deepcopy(tpbvit)
        # print("New Best Loss: ", best_model_loss)
    total_loss.backward()
    # Cap the global gradient norm at 1.0 before stepping.
    torch.nn.utils.clip_grad_norm_(tpbvit.parameters(), max_norm=1.0)
    optimizer.step()
    scheduler.step()
    # Progress report every 10th step (the printed epoch index is zero-based).
    if (epoch+1) %10 == 0:
        print("Epoch: ", epoch, "Loss: ", total_loss.item(), "W1: ", w1, "W2: ", w2)
        print("LR: ", optimizer.param_groups[0]['lr'])

## Evaluate Reconstruction

### Comparing Cosine Similarity

In [None]:
def evaluate_dynamic_reconstruction(current_model, pretrained_model):
    """Report the cosine similarity between rebuilt and pretrained MLP parameters.

    For each block and each MLP projection that has a template bank, the weight
    is rebuilt as the mean of the bank outputs over the coefficient sets and
    compared (flattened) with the pretrained ``fc1``/``fc2`` weight; the bias
    is compared with the pretrained bias.  Per-block values and the averages
    over all blocks are printed.

    Args:
        current_model: model exposing ``blocks`` and ``num_layers_in_group``.
        pretrained_model: model whose ``state_dict()`` holds the reference tensors.

    Returns:
        Dict mapping parameter names of ``current_model`` to a similarity.  The
        coefficient entries and the bank entry ``template_banks{k}.{group}.templates``
        carry the weight similarity; the bank entry is overwritten by each block
        of the group, so it ends up holding the value of the last one.
    """
    cosine = nn.CosineSimilarity(dim=0)
    reference = pretrained_model.state_dict()
    # Running sums used for the final averages.
    sums = {"w1": 0.0, "w2": 0.0, "b1": 0.0, "b2": 0.0}
    # Values shown in the per-block line; an entry keeps its previous value when
    # the block has no bank for that projection.
    shown = {"w1": 0, "w2": 0, "b1": 0, "b2": 0}
    similarity_dict = {}

    for block_idx, block in enumerate(current_model.blocks):
        for slot in ("1", "2"):
            bank = getattr(block.mlp, f"bank{slot}")
            if bank is None:
                continue
            coefficient_sets = getattr(block.mlp, f"coefficients{slot}")
            rebuilt = torch.stack([bank(cf) for cf in coefficient_sets]).mean(0) # TODO
            target_weight = reference[f"blocks.{block_idx}.mlp.fc{slot}.weight"]
            weight_sim = cosine(rebuilt.view(-1), target_weight.view(-1))
            sums[f"w{slot}"] += weight_sim.item()
            shown[f"w{slot}"] = weight_sim

            bias = getattr(block.mlp, f"bias{slot}")
            target_bias = reference[f"blocks.{block_idx}.mlp.fc{slot}.bias"]
            bias_sim = cosine(bias.view(-1), target_bias.view(-1))
            sums[f"b{slot}"] += bias_sim.item()
            shown[f"b{slot}"] = bias_sim

            group_idx = block_idx // current_model.num_layers_in_group
            similarity_dict[f"template_banks{slot}.{group_idx}.templates"] = weight_sim.item()
            similarity_dict[f"blocks.{block_idx}.mlp.bias{slot}"] = bias_sim.item()
            for cf_idx, _ in enumerate(coefficient_sets):
                similarity_dict[f"blocks.{block_idx}.mlp.coefficients{slot}.{cf_idx}"] = weight_sim.item()

        print("Block: ", block_idx, "W1: ", shown["w1"], "W2: ", shown["w2"], "B1: ", shown["b1"], "B2: ", shown["b2"])

    num_blocks = len(current_model.blocks)
    print("W1: ", sums["w1"]/num_blocks, "W2: ", sums["w2"]/num_blocks, "B1: ", sums["b1"]/num_blocks, "B2: ", sums["b2"]/num_blocks)
    return similarity_dict

# Compare the best reconstruction with the pretrained weights, then report its
# parameter counts.  Note that `.to("cpu")` moves best_model itself to the CPU.
evaluate_dynamic_reconstruction(best_model.to("cpu"), vit)
calculate_parameters(best_model)

### Comparing layerwise feature similarity

In [None]:
import torch
import torch.nn as nn
from PIL import Image
import os
import random
import numpy as np
# Load the batch from disk
# "<DATALOADER_BATCH_PATH>" is a placeholder: point it at a saved input batch.
# NOTE(review): torch.load unpickles the file, so only load files you trust.
# Presumably the file holds a single image tensor (it is later moved with
# `.to(...)` and fed to the patch embedding) -- confirm.
batch = torch.load("<DATALOADER_BATCH_PATH>")

# Modified forward function to return intermediate features
def forward_with_intermediate(model, x, cust = True):
    """Run the encoder of ``model`` and also collect the output of every block.

    Args:
        model: model exposing ``patch_embed``, ``blocks`` and ``norm``, plus
            ``_pos_embed``/``patch_drop`` (custom path) or
            ``cls_token``/``pos_embed``/``pos_drop`` (classic path).
        x: input batch handed to ``model.patch_embed``.
        cust: when true, use the model's own ``_pos_embed`` and ``patch_drop``;
            otherwise prepend the class token and add the position embedding here.

    Returns:
        Tuple ``(normed_output, intermediate_features)`` where the list holds
        the un-normalised output of each block, in order.
    """
    batch_count = x.shape[0]
    tokens = model.patch_embed(x)

    if cust:
        print("Custom")
        tokens = model.patch_drop(model._pos_embed(tokens))
    else:
        print("classic")
        cls_tokens = model.cls_token.expand(batch_count, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)
        tokens = model.pos_drop(tokens + model.pos_embed)

    # Both paths share the same block loop.
    intermediate_features = []
    for blk in model.blocks:
        tokens = blk(tokens)
        intermediate_features.append(tokens)

    return model.norm(tokens), intermediate_features

# Function to compare features layerwise
def compare_features_layerwise(model1, model2, batch):
    """Compare the per-block features of two models on the same batch.

    The comparison runs on the CPU: both models are moved there (in place) and
    a CPU copy of ``batch`` is used.  ``model1`` is run through the custom
    forward path and ``model2`` through the classic one.

    Args:
        model1: custom template-bank model.
        model2: reference model.
        batch: input batch accepted by both models.

    Returns:
        List with one float per block: the cosine similarity between the
        flattened features of each sample, averaged over the batch.
    """
    # Move the batch once; no detour through another device is needed.
    cpu_batch = batch.to("cpu")
    _, features1 = forward_with_intermediate(model1.to("cpu"), cpu_batch, True)
    _, features2 = forward_with_intermediate(model2.to("cpu"), cpu_batch, False)

    cos = nn.CosineSimilarity(dim=-1)  # similarity along the flattened feature axis
    similarities = []

    for f1, f2 in zip(features1, features2):
        # Flatten everything except the batch dimension (flatten, unlike view,
        # also accepts non-contiguous features).
        f1_flat = torch.flatten(f1, start_dim=1)
        f2_flat = torch.flatten(f2, start_dim=1)

        similarities.append(cos(f1_flat, f2_flat).mean().item())

    return similarities

# Put both models on the same device and in eval mode so that dropout and
# similar layers behave deterministically during the comparison.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
best_model = best_model.to(device)
vit = vit.to(device)
best_model.eval()
vit.eval()
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Compare features layerwise (no gradients are needed for evaluation).
with torch.no_grad():
    similarities = compare_features_layerwise(best_model.to(device), vit.to(device), batch.to(device))
av = np.mean(similarities)
print(av)

# Print similarities for each layer
for layer_idx, sim in enumerate(similarities, start=1):
    print(f"Layer {layer_idx} similarity: {sim:.4f}")

plt.figure(figsize=(6, 6))
plt.ylim(-1, 2)
plt.plot(range(1, len(similarities) + 1), similarities, marker='o')
# The original title ran "num_cf = ..." straight into "Average Sim"; add a separator.
plt.title(f'Layerwise Cosine Similarity factor={FACTORS}, templates={TEMPLATES} num_cf = {num_cf}, Average Sim: {av:.3f}')
plt.xlabel("Layer")
plt.ylabel("Cosine Similarity")
plt.grid(True)
plt.show()
