### Model Description
A modified Vision Transformer (ViT) in which the MLP layer of each block has been replaced with the SharedMLP layer, which uses template banks and coefficients to generate its weights. The same script can be used to train the other RECAST-ViT models by replacing the defined model and loading the appropriate weights.

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random


# ---- Reproducibility -------------------------------------------------------
manualSeed = 42
DEFAULT_THRESHOLD = 0.005

# Seed every random number generator the notebook relies on.
np.random.seed(manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed(manualSeed)

# cuDNN autotuning and cuDNN itself are both switched off.
cudnn.enabled = False
cudnn.benchmark = False

# ---- Device configuration --------------------------------------------------
device = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
print("Device: ", device)

# ---- Template-bank hyper-parameters ----------------------------------------
FACTORS = 6  # number of groups
# number of templates per bank, corresponds to number of layers in a group
TEMPLATES = 2
MULT = 1  # optional multiplier for the number of coefficients set
num_cf = 2  # number of coefficients sets per target module
def calculate_parameters(model):
    """Print a parameter-count summary for ``model``.

    The model must expose a ``head`` sub-module. Besides the total, trainable
    and classifier-head counts, parameters are tallied by substring of their
    qualified name (".attn.", "template_banks", "mlp", "coefficients").
    A single parameter is counted in every category whose substring it
    contains, so the categories may overlap. Nothing is returned.
    """
    every_param = list(model.parameters())
    total_params = sum(t.numel() for t in every_param)
    trainable_params = sum(t.numel() for t in every_param if t.requires_grad)
    classifier_params = sum(t.numel() for t in model.head.parameters())

    # name-substring -> running element count
    tallies = {".attn.": 0, "template_banks": 0, "mlp": 0, "coefficients": 0}
    for name, tensor in model.named_parameters():
        size = tensor.numel()
        for needle in tallies:
            if needle in name:
                tallies[needle] += size

    print("Classifier head: ", model.head)
    print(
        f"Total parameters: {total_params//1000000}M, "
        f"Trainable parameters: {trainable_params}, "
        f"Classifier parameters: {classifier_params}"
    )
    print(f"Attention parameters: {tallies['.attn.']}")
    print(f"Templates params: {tallies['template_banks']}")
    print(f"MLP params: {tallies['mlp']}")
    print(f"Coefficients params: {tallies['coefficients']}")


class MLPTemplateBank(nn.Module):
    """A bank of weight templates that are mixed into one linear-layer weight.

    The bank stores ``num_templates`` matrices in a single parameter of shape
    ``(num_templates, out_features, in_features)``, i.e. each template has the
    ``(out, in)`` layout expected by ``F.linear``.

    Args:
        num_templates: number of templates held by the bank.
        in_features: input width of the generated weight.
        out_features: output width of the generated weight.
    """

    def __init__(self, num_templates, in_features, out_features):
        super().__init__()
        self.num_templates = num_templates
        # Shape a coefficient tensor must have to broadcast over the templates.
        self.coefficient_shape = (num_templates, 1, 1)
        # Each template is initialised on its own so that Kaiming's fan-in is
        # computed from the 2-D (out, in) matrix rather than the stacked 3-D
        # tensor.
        templates = [
            torch.empty(out_features, in_features) for _ in range(num_templates)
        ]
        for template in templates:
            nn.init.kaiming_normal_(template)
        self.templates = nn.Parameter(torch.stack(templates))

    def forward(self, coefficients):
        """Return the coefficient-weighted sum of the templates.

        Args:
            coefficients: tensor broadcastable against the templates
                (``coefficient_shape`` gives one scalar per template).

        Returns:
            Tensor of shape ``(out_features, in_features)``.
        """
        return torch.sum(self.templates * coefficients, dim=0)

    def __repr__(self):
        # templates is (num_templates, out_features, in_features): axis 1 is
        # the output width and axis 2 the input width.
        num_templates, out_features, in_features = self.templates.shape
        return f"MLPTemplateBank(num_templates={num_templates}, in_features={in_features}, out_features={out_features}, coefficients={self.coefficient_shape})"


class SharedMLP(nn.Module):
    """Two-layer MLP whose weights are generated from shared template banks.

    Each layer owns ``num_cf`` coefficient sets (``num_cf`` is a module-level
    setting). For every set the bank produces one weight matrix, and the
    matrices are averaged to obtain the layer weight. Biases are owned by this
    module and are not shared.

    Args:
        bank1: template bank for the first linear layer, or ``None``.
        bank2: template bank for the second linear layer, or ``None``.
        act_layer: activation class instantiated between the two layers.
        norm_layer: accepted for interface compatibility; the module always
            uses ``nn.Identity`` as its norm.
        drop: dropout probability applied to the output.

    If either bank is ``None`` no bank is attached at all; the module can be
    constructed, but ``forward`` raises ``RuntimeError``.
    """

    def __init__(
        self, bank1, bank2, act_layer=nn.GELU, norm_layer=nn.LayerNorm, drop=0.0
    ):
        super().__init__()
        self.bank1 = None
        self.bank2 = None

        # Identity checks: banks are nn.Modules, so `!= None` is not the right
        # way to test for absence.
        if bank1 is not None and bank2 is not None:
            self.bank1 = bank1
            self.bank2 = bank2
            self.coefficients1 = nn.ParameterList(
                [
                    nn.Parameter(
                        torch.zeros(bank1.coefficient_shape), requires_grad=True
                    )
                    for _ in range(num_cf)
                ]
            )
            self.coefficients2 = nn.ParameterList(
                [
                    nn.Parameter(
                        torch.zeros(bank2.coefficient_shape), requires_grad=True
                    )
                    for _ in range(num_cf)
                ]
            )
            # Axis 1 of a bank's templates is the output width of the layer.
            self.bias1 = nn.Parameter(torch.zeros(bank1.templates.shape[1]))
            self.bias2 = nn.Parameter(torch.zeros(bank2.templates.shape[1]))

        self.act = act_layer()
        self.norm = nn.Identity()
        self.drop = nn.Dropout(drop)
        self.init_weights()

    def init_weights(self):
        """Orthogonally initialise every coefficient set (no-op without banks)."""
        if self.bank1 is not None:
            for cf in self.coefficients1:
                nn.init.orthogonal_(cf)
        if self.bank2 is not None:
            for cf in self.coefficients2:
                nn.init.orthogonal_(cf)

    def forward(self, x):
        """Apply linear -> activation -> norm -> linear -> dropout.

        Raises:
            RuntimeError: if the module was built without template banks.
        """
        if self.bank1 is None or self.bank2 is None:
            raise RuntimeError(
                "SharedMLP was constructed without template banks; "
                "both bank1 and bank2 are required to run forward()."
            )

        # Layer weight = mean of the weights generated by each coefficient set.
        weights1 = torch.stack([self.bank1(c) for c in self.coefficients1]).mean(0)
        weights2 = torch.stack([self.bank2(c) for c in self.coefficients2]).mean(0)

        x = F.linear(x, weights1, self.bias1)
        x = self.act(x)
        x = self.norm(x)
        x = F.linear(x, weights2, self.bias2)
        x = self.drop(x)
        return x


# original timm module for vision transformer
class Attention(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    Args:
        dim: channel width of the input tokens.
        num_heads: number of attention heads; each head gets
            ``dim // num_heads`` channels.
        qkv_bias: whether the fused qkv projection has a bias.
        qk_scale: score scaling factor; when falsy, ``head_dim ** -0.5``
            is used.
        attn_drop: dropout probability on the attention probabilities.
        proj_drop: dropout probability on the projected output.
    """

    def __init__(
        self,
        dim,
        num_heads=6,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the token axis of ``x`` (indexed as batch, tokens, channels)."""
        batch, tokens, channels = x.shape

        # Split the fused projection into q, k, v and move the head axis in
        # front of the token axis: (batch, heads, tokens, head_dim).
        q, k, v = (
            part.reshape(batch, tokens, self.num_heads, self.head_dim).permute(
                0, 2, 1, 3
            )
            for part in self.qkv(x).chunk(3, dim=-1)
        )

        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        probs = self.attn_drop(scores.softmax(dim=-1))  # attention proba

        context = torch.matmul(probs, v)  # attention output
        merged = context.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(merged))


class Block(nn.Module):
    """Pre-norm transformer block: self-attention followed by a SharedMLP.

    Args:
        dim: token channel width.
        num_heads: number of attention heads.
        mlp_ratio: accepted for interface compatibility; the MLP widths come
            from the template banks, not from this ratio.
        qkv_bias, qk_scale, attn_drop: forwarded to ``Attention``.
        drop: dropout probability for the attention projection and the MLP.
        drop_path: accepted for interface compatibility; the block uses
            ``nn.Identity`` in place of stochastic depth.
        act_layer, norm_layer: activation / normalisation classes.
        bank1, bank2: template banks handed to ``SharedMLP``.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        bank1=None,
        bank2=None,
    ):
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        attention_kwargs = dict(
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.attn = Attention(dim, **attention_kwargs)

        # Layer-scale and drop-path are pass-through placeholders.
        self.ls1 = nn.Identity()
        self.ls2 = nn.Identity()
        self.drop_path = nn.Identity()

        # MLP branch.
        self.norm2 = norm_layer(dim)
        self.mlp = SharedMLP(
            bank1, bank2, act_layer=act_layer, norm_layer=norm_layer, drop=drop
        )

    def forward(self, x):
        """Apply both residual branches and return the updated tokens."""
        attn_branch = self.ls1(self.attn(self.norm1(x)))
        x = x + self.drop_path(attn_branch)
        mlp_branch = self.ls2(self.mlp(self.norm2(x)))
        return x + self.drop_path(mlp_branch)


class PatchEmbed(nn.Module):
    """Turn an image into a sequence of patch embeddings.

    A strided convolution (kernel size == stride == patch size) projects every
    non-overlapping patch to ``embed_dim`` channels; the two spatial axes are
    then flattened into a single token axis.

    Args:
        img_size: side length of the (square) input image.
        patch_size: side length of a (square) patch.
        in_chans: number of input channels.
        embed_dim: embedding width of each patch token.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.img_size = (img_size, img_size)
        self.patch_size = (patch_size, patch_size)
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=0,
        )
        self.norm = nn.Identity()

    def forward(self, x):
        """Return tokens laid out as (batch, num_patches, embed_dim)."""
        feature_map = self.proj(x)
        # Merge the two spatial axes, then put channels last.
        tokens = feature_map.flatten(2, 3).permute(0, 2, 1)
        return self.norm(tokens)


class VisionTransformer(nn.Module):
    """Vision Transformer whose block MLPs draw their weights from shared banks.

    The ``depth`` blocks are partitioned into ``FACTORS`` groups of consecutive
    layers. Each group owns two ``MLPTemplateBank`` instances (one per MLP
    layer) holding ``TEMPLATES`` templates; every block of the group receives
    the same pair of banks.

    Args:
        img_size: side length of the square input image.
        patch_size: side length of a square patch.
        in_chans: number of input channels.
        num_classes: classifier outputs; ``<= 0`` replaces the head with
            ``nn.Identity``.
        embed_dim: token width.
        depth: number of transformer blocks; must be a multiple of ``FACTORS``.
        num_heads: attention heads per block.
        mlp_ratio: hidden width of the MLP relative to ``embed_dim``.
        qkv_bias, qk_scale: forwarded to every block.
        drop_rate: dropout probability (positional, block and head dropout).
        attn_drop_rate: dropout probability on attention weights.
        drop_path_rate: upper end of the per-group drop-path schedule handed
            to the blocks.
        norm_layer: normalisation class.

    Raises:
        ValueError: if ``depth`` is not divisible by the number of groups.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
    ):
        super().__init__()
        self.img_size = img_size
        self.dim = embed_dim
        self.num_classes = num_classes
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        self.num_features = self.embed_dim
        self.num_prefix_tokens = 1
        self.num_patches = (img_size // patch_size) ** 2
        self.has_class_token = True
        self.cls_token = nn.Parameter(torch.ones(1, 1, self.embed_dim))
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        print("Num patches: ", self.num_patches)
        # One positional embedding per patch plus one for the class token.
        embed_len = self.num_patches + self.num_prefix_tokens
        self.pos_embed = nn.Parameter(
            torch.ones(1, embed_len, embed_dim) * 0.02,
            requires_grad=True,
        )

        self.pos_drop = nn.Dropout(p=drop_rate)
        self.patch_drop = nn.Identity()
        self.fc_norm = nn.Identity()
        self.head_drop = nn.Dropout(drop_rate)

        self.num_groups = FACTORS
        # Without this check a non-divisible depth fails later with an opaque
        # IndexError (group index past the last bank) or ZeroDivisionError
        # (depth smaller than the number of groups).
        if depth % self.num_groups != 0:
            raise ValueError(
                f"depth ({depth}) must be divisible by the number of "
                f"template-bank groups ({self.num_groups})"
            )
        self.num_layers_in_group = (
            depth // self.num_groups
        )  # how many consecutive layers share the same template bank
        print("Num layers in group: ", self.num_layers_in_group)
        self.num_templates = TEMPLATES
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        # First MLP layer: embed_dim -> mlp_hidden_dim.
        self.template_banks1 = nn.ModuleList(
            [
                MLPTemplateBank(self.num_templates, embed_dim, mlp_hidden_dim)
                for _ in range(self.num_groups)
            ]
        )
        # Second MLP layer: mlp_hidden_dim -> embed_dim.
        self.template_banks2 = nn.ModuleList(
            [
                MLPTemplateBank(self.num_templates, mlp_hidden_dim, embed_dim)
                for _ in range(self.num_groups)
            ]
        )
        self.depth = depth

        dpr = [
            x.item() for x in torch.linspace(0, drop_path_rate, self.num_groups)
        ]  # stochastic depth decay rule (one rate per group)
        self.blocks = nn.ModuleList()
        for i in range(depth):
            group_idx = i // self.num_layers_in_group
            self.blocks.append(
                Block(
                    dim=self.embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop_rate,
                    attn_drop=attn_drop_rate,
                    drop_path=dpr[group_idx],
                    norm_layer=norm_layer,
                    bank1=self.template_banks1[group_idx],
                    bank2=self.template_banks2[group_idx],
                )
            )
        print(f"Num blocks: {len(self.blocks)}")
        self.norm = norm_layer(self.embed_dim)
        self.head = (
            nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )
        self._init_weights()

    def _init_weights(self):
        """Re-initialise every ``nn.Linear`` and ``nn.LayerNorm`` in the model.

        Template banks, coefficients and the patch convolution are not touched
        here; they keep the initialisation given by their own constructors.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.normal_(m.bias, std=1e-6)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)

    def _pos_embed(self, x):
        """Prepend the class token, add positional embeddings, apply dropout."""
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + self.pos_embed
        return self.pos_drop(x)

    def forward_features(self, x):
        """Return the normalised token sequence (class token first)."""
        x = self.patch_embed(x)
        x = self._pos_embed(x)
        x = self.patch_drop(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return x

    def forward_head(self, x):
        """Classify from the class token (position 0 of the sequence)."""
        x = x[:, 0]
        x = self.fc_norm(x)
        x = self.head_drop(x)
        x = self.head(x)
        return x

    def forward(self, x):
        """Run the full model: feature extraction followed by the head."""
        features = self.forward_features(x)
        head_output = self.forward_head(features)
        return head_output


### Dataloaders
Feel free to use any dataloader of your choice. I have used the following dataloader for my experiments.

In [None]:
import torch
import torchvision.transforms as transforms
import tqdm
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
import torch
import pandas as pd
import os
import json
from torchvision import datasets
from torchvision.transforms import transforms
# Write a base dataloader class for image classification
class ImageDataset(Dataset):
    """Base class for the image-classification datasets in this file.

    Subclasses fill in the image paths, labels, transforms and class map.
    ``__getitem__`` returns the *untransformed* PIL image; the transforms are
    stored as attributes and are applied by whoever wraps the dataset.
    """

    def __init__(self):
        self.data_path = ""
        self.data_name = ""
        self.num_classes = 0
        self.train_transform = None
        # Defined on the base class so that every dataset exposes the same
        # attributes even if a subclass does not set them.
        self.test_transform = None
        self.class_map = {}
        self.train_csv_path = ""
        self.image_paths = []
        self.labels = []

    def get_num_classes(self):
        """Return the number of classes declared by the dataset."""
        return self.num_classes

    def __getitem__(self, index):
        """Return ``(image, label)`` with the image loaded as an RGB PIL image."""
        img_path = self.image_paths[index]
        label = self.labels[index]
        # convert() yields an independent, fully loaded image, so the file
        # handle opened by Image.open can be released right away.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        return img, label

    def __len__(self):
        return len(self.image_paths)

    @property
    def label_dict(self):
        """Map each index in ``range(num_classes)`` to its ``class_map`` entry.

        NOTE(review): ``class_map`` is indexed with integers; if the JSON file
        it is loaded from is an object with string keys this lookup fails.
        Confirm the JSON layout.
        """
        return {i: self.class_map[i] for i in range(self.num_classes)}

    def __repr__(self):
        # len(self) rather than the bound method object self.__len__.
        return f"ImageDataset({self.data_name}) with {len(self)} instances"


class CARS(ImageDataset):
    """The "cars" dataset (196 classes).

    Image paths and labels come from ``cars.csv`` and the class names from
    ``CARS.json``, both located under ``BASE_PATH``. ``CARS_DATA`` and
    ``BASE_PATH`` are expected to be defined elsewhere.
    """

    def __init__(self):
        super().__init__()
        self.data_path = CARS_DATA
        self.data_name = "cars"
        # Normalisation statistics, used for both the train and test pipeline.
        stats = dict(mean=[-0.0639, 0.0145, 0.2118], std=[1.2796, 1.3035, 1.3343])
        # ToTensor comes first, so every later transform operates on tensors.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                transforms.Normalize(**stats),
            ]
        )
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.Normalize(**stats),
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "cars.csv")
        # Parse the CSV once and take both columns from the same frame.
        frame = pd.read_csv(self.train_csv_path)
        self.image_paths = frame["fname"].values
        self.labels = frame["class"].values.tolist()
        self.num_classes = 196
        self.split = None
        # json file that contains the class names
        self.class_json = os.path.join(BASE_PATH, "CARS.json")
        # Context manager so the file handle is closed after loading.
        with open(self.class_json) as class_file:
            self.class_map = json.load(class_file)


class AIRCRAFT(ImageDataset):
    """The "aircraft" dataset (55 classes).

    Image paths and labels come from ``aircrafts.csv`` and the class names
    from ``AIRCRAFTS.json``, both located under ``BASE_PATH``.
    ``AIRCRAFT_DATA`` and ``BASE_PATH`` are expected to be defined elsewhere.
    """

    def __init__(self):
        super().__init__()
        self.data_path = AIRCRAFT_DATA
        self.data_name = "aircraft"
        # Normalisation statistics, used for both the train and test pipeline.
        stats = dict(mean=[-0.0266, 0.2407, 0.5663], std=[0.9745, 0.9684, 1.1040])
        # ToTensor comes first, so every later transform operates on tensors.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                transforms.Normalize(**stats),
            ]
        )
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.Normalize(**stats),
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "aircrafts.csv")
        # Parse the CSV once and take both columns from the same frame.
        frame = pd.read_csv(self.train_csv_path)
        self.image_paths = frame["fname"].values
        self.labels = frame["class"].values
        self.num_classes = 55
        self.split = None
        # json file that contains the class names
        self.class_json = os.path.join(BASE_PATH, "AIRCRAFTS.json")
        # Context manager so the file handle is closed after loading.
        with open(self.class_json) as class_file:
            self.class_map = json.load(class_file)


class FLOWERS(ImageDataset):
    """The "flowers" dataset.

    Image paths and labels come from ``flowers.csv`` and the class names from
    ``FLOWERS.json``, both located under ``BASE_PATH``. ``FLOWERS_DATA`` and
    ``BASE_PATH`` are expected to be defined elsewhere.
    """

    def __init__(self):
        super().__init__()
        self.data_path = FLOWERS_DATA
        self.data_name = "flowers"
        # Normalisation statistics, used for both the train and test pipeline.
        stats = dict(mean=[0.5642, 0.7694, 0.8410], std=[0.2560, 0.2589, 0.2783])
        # ToTensor comes first, so every later transform operates on tensors.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                transforms.Normalize(**stats),
            ]
        )
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.Normalize(**stats),
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "flowers.csv")
        # Parse the CSV once and take both columns from the same frame.
        frame = pd.read_csv(self.train_csv_path)
        self.image_paths = frame["fname"].values
        self.labels = frame["class"].values
        # Labels are not 0-indexed (original author's note); presumably the
        # count therefore includes an unused index 0 -- TODO confirm.
        self.num_classes = 103
        self.split = None
        # json file that contains the class names
        self.class_json = os.path.join(BASE_PATH, "FLOWERS.json")
        # Context manager so the file handle is closed after loading.
        with open(self.class_json) as class_file:
            self.class_map = json.load(class_file)


class SCENES(ImageDataset):
    """The "scenes" dataset (67 classes).

    Image paths and labels come from ``scenes.csv`` and the class names from
    ``SCENES.json``, both located under ``BASE_PATH``. ``SCENES_DATA`` and
    ``BASE_PATH`` are expected to be defined elsewhere.
    """

    def __init__(self):
        super().__init__()
        self.data_path = SCENES_DATA
        self.data_name = "scenes"
        # Normalisation statistics, used for both the train and test pipeline.
        stats = dict(
            mean=[-0.0081, -0.1473, -0.1866], std=[1.1616, 1.1583, 1.1599]
        )
        # ToTensor comes first, so every later transform operates on tensors.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                transforms.Normalize(**stats),
            ]
        )
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.Normalize(**stats),
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "scenes.csv")
        # Parse the CSV once and take both columns from the same frame.
        frame = pd.read_csv(self.train_csv_path)
        self.image_paths = frame["fname"].values
        self.labels = frame["class"].values
        self.num_classes = 67
        self.split = None
        # json file that contains the class names
        self.class_json = os.path.join(BASE_PATH, "SCENES.json")
        # Context manager so the file handle is closed after loading.
        with open(self.class_json) as class_file:
            self.class_map = json.load(class_file)


class CHARS(ImageDataset):
    """The "chars" dataset.

    Image paths and labels come from ``chars.csv`` and the class names from
    ``CHARS.json``, both located under ``BASE_PATH``. ``CHARS_DATA`` and
    ``BASE_PATH`` are expected to be defined elsewhere. Unlike most sibling
    datasets, neither pipeline normalises the images.
    """

    def __init__(self):
        super().__init__()
        self.data_path = CHARS_DATA
        self.data_name = "chars"
        # ToTensor comes first, so every later transform operates on tensors.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                # Normalisation is disabled; statistics kept for reference:
                # transforms.Normalize(mean=[1.4986, 1.6615, 1.8764], std=[1.6015, 1.6373, 1.6300])
            ]
        )
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                # Normalisation is disabled; statistics kept for reference:
                # transforms.Normalize(mean=[1.4986, 1.6615, 1.8764], std=[1.6015, 1.6373, 1.6300])
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "chars.csv")
        # Parse the CSV once and take both columns from the same frame.
        frame = pd.read_csv(self.train_csv_path)
        self.image_paths = frame["fname"].values
        self.labels = frame["class"].values
        # Labels are not 0-indexed (original author's note); presumably the
        # count therefore includes an unused index 0 -- TODO confirm.
        self.num_classes = 63
        self.split = None
        # json file that contains the class names
        self.class_json = os.path.join(BASE_PATH, "CHARS.json")
        # Context manager so the file handle is closed after loading.
        with open(self.class_json) as class_file:
            self.class_map = json.load(class_file)


class BIRDS(ImageDataset):
    """The "birds" dataset.

    Image paths and labels come from ``birds.csv`` and the class names from
    ``BIRDS.json``, both located under ``BASE_PATH``. ``BIRDS_DATA`` and
    ``BASE_PATH`` are expected to be defined elsewhere.
    """

    def __init__(
        self,
    ):
        super().__init__()
        self.data_path = BIRDS_DATA
        self.data_name = "birds"
        # ToTensor comes first, so every later transform operates on tensors.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                transforms.Normalize(
                    mean=[0.0049, 0.1962, 0.1152], std=[1.0027, 1.0053, 1.1734]
                ),
            ]
        )
        # NOTE(review): the training pipeline normalises but the test pipeline
        # does not -- confirm this asymmetry is intended.
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                # transforms.Normalize(mean=[0.0049, 0.1962, 0.1152], std=[1.0027, 1.0053, 1.1734])
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "birds.csv")
        # Parse the CSV once and take both columns from the same frame.
        frame = pd.read_csv(self.train_csv_path)
        self.image_paths = frame["fname"].values
        self.labels = frame["class"].values
        # Labels are not 0-indexed (original author's note); presumably the
        # count therefore includes an unused index 0 -- TODO confirm.
        self.num_classes = 201
        self.split = None
        # json file that contains the class names
        self.class_json = os.path.join(BASE_PATH, "BIRDS.json")
        # Context manager so the file handle is closed after loading.
        with open(self.class_json) as class_file:
            self.class_map = json.load(class_file)


class ACTION(ImageDataset):
    """The "actions" dataset.

    Image paths and labels come from ``action.csv`` and the class names from
    ``ACTION.json``, both located under ``BASE_PATH``. ``ACTION_DATA`` and
    ``BASE_PATH`` are expected to be defined elsewhere.

    NOTE(review): only ``train_transform`` is configured here; no test
    pipeline is set by this class.
    """

    def __init__(self):
        super().__init__()
        self.data_path = ACTION_DATA
        self.data_name = "actions"
        # ToTensor comes first, so every later transform operates on tensors.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "action.csv")
        # Parse the CSV once and take both columns from the same frame.
        frame = pd.read_csv(self.train_csv_path)
        self.image_paths = frame["fname"].values
        self.labels = frame["class"].values
        # Labels are not 0-indexed (original author's note); presumably the
        # count therefore includes an unused index 0 -- TODO confirm.
        self.num_classes = 20
        self.split = None
        # json file that contains the class names
        self.class_json = os.path.join(BASE_PATH, "ACTION.json")
        # Context manager so the file handle is closed after loading.
        with open(self.class_json) as class_file:
            self.class_map = json.load(class_file)


class SVHN(ImageDataset):
    """SVHN wrapper built on ``torchvision.datasets.SVHN`` (10 classes).

    Unlike the CSV-backed datasets, items are read from the wrapped
    torchvision dataset, the transform is applied inside ``__getitem__``, and
    each item is a triple ``(image, label, task_id)``.

    Args:
        split: split name forwarded to ``datasets.SVHN``.
        transform: optional transform applied to each image. When ``None``
            the default pipeline (ToTensor, resize to 224x224, centre crop,
            normalise) is used.
    """

    # TODO: this one is a bit tricky
    def __init__(self, split="train", transform=None):
        super().__init__()
        self.data_path = SVHN_DATA
        self.data_name = "svhn"
        self.task_id = 6  # Assign a unique task_id for SVHN
        self.split = split
        # Honour a caller-supplied transform; previously the argument was
        # accepted but always overwritten by the default pipeline.
        if transform is None:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize((224, 224)),
                    transforms.CenterCrop(224),
                    # transforms.RandomHorizontalFlip(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            )
        self.transform = transform
        # download=True: the data is fetched into SVHN_DATA if it is missing.
        self.dataset = datasets.SVHN(root=SVHN_DATA, split=split, download=True)
        self.num_classes = 10

    def __getitem__(self, index):
        """Return ``(image, label, task_id)`` with the transform applied."""
        img, label = self.dataset[index]
        if self.transform:
            img = self.transform(img)
        return img, label, self.task_id

    def __len__(self):
        return len(self.dataset)

def collate_fn(batch):
    """Collate ``(image, label, task_id)`` samples into ``(images, labels)``.

    Images are stacked along a new leading axis and labels become a tensor.
    The task ids are unpacked but not returned.
    """
    image_seq, label_seq, _task_ids = zip(*batch)
    return torch.stack(image_seq, dim=0), torch.tensor(label_seq)


class TransformedDataset(torch.utils.data.Dataset):
    """Wrap a dataset of ``(image, label)`` pairs and apply a transform lazily.

    Tensor images are first converted to a NumPy array with the leading axis
    moved last (axes reordered 1, 2, 0); other image types are passed to the
    transform unchanged. A falsy ``transform`` disables the transform step.
    """

    def __init__(self, dataset, transform):
        self.dataset, self.transform = dataset, transform

    def __getitem__(self, index):
        sample, target = self.dataset[index]
        if isinstance(sample, torch.Tensor):
            sample = sample.numpy().transpose(1, 2, 0)
        transformed = self.transform(sample) if self.transform else sample
        return transformed, target

    def __len__(self):
        return len(self.dataset)


def _split_train_val_test(dataset, train_size, val_size):
    """Randomly split ``dataset`` into train/val/test subsets.

    ``train_size`` and ``val_size`` are fractions of the dataset length; the
    test subset receives whatever remains after both are truncated to ints.
    """
    total = len(dataset)
    n_train = int(train_size * total)
    n_val = int(val_size * total)
    lengths = [n_train, n_val, total - n_train - n_val]
    return torch.utils.data.random_split(dataset, lengths)


def _cifar_dataloaders(dataset_name, train_size, val_size, batch_size):
    """Build the loaders for "cifar10" / "cifar100".

    NOTE(review): all three returned loaders are carved out of the CIFAR
    *training* split (with the training transform); the official test split
    is instantiated, and downloaded if missing, but never returned.
    """

    def resize_ops():
        return [
            transforms.ToTensor(),
            transforms.Resize((224, 224)),
            transforms.CenterCrop(224),
        ]

    if dataset_name == "cifar10":
        cifar_cls, n_classes = datasets.CIFAR10, 10
        train_transform = transforms.Compose(
            resize_ops() + [transforms.RandomHorizontalFlip()]
        )
        test_transform = transforms.Compose(resize_ops())
    else:
        cifar_cls, n_classes = datasets.CIFAR100, 100
        stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_transform = transforms.Compose(
            resize_ops()
            + [
                transforms.RandomHorizontalFlip(),
                transforms.Normalize(*stats, inplace=True),
            ]
        )
        test_transform = transforms.ToTensor()

    cifar_train = cifar_cls(
        root=CIFAR_DATA, train=True, download=True, transform=train_transform
    )
    # Instantiated exactly as before (this may trigger a download); the
    # resulting object is not used.
    cifar_cls(root=CIFAR_DATA, train=False, download=True, transform=test_transform)

    subsets = _split_train_val_test(cifar_train, train_size, val_size)
    # Only the first (training) subset is shuffled.
    loaders = [
        torch.utils.data.DataLoader(
            subset, batch_size=batch_size, shuffle=(position == 0)
        )
        for position, subset in enumerate(subsets)
    ]
    return loaders[0], loaders[1], loaders[2], n_classes


def get_dataloaders(
    dataset_name, train_size=0.8, val_size=0.0, batch_size=32, mode="train"
):
    """Create train/val/test ``DataLoader`` objects for a named dataset.

    Args:
        dataset_name: one of "cars", "aircraft", "flowers", "scenes", "chars",
            "birds", "cifar10", "cifar100".
        train_size: fraction of the data used for training.
        val_size: fraction of the data used for validation; the remainder
            becomes the test subset.
        batch_size: batch size of every loader.
        mode: accepted but not used.

    Returns:
        ``(train_loader, val_loader, test_loader, num_classes)``.

    Raises:
        ValueError: if ``dataset_name`` is not one of the supported names.
    """
    if dataset_name in ("cifar10", "cifar100"):
        return _cifar_dataloaders(dataset_name, train_size, val_size, batch_size)

    # Factories are wrapped in lambdas so that only the requested dataset
    # class is looked up and instantiated.
    factories = {
        "cars": lambda: CARS(),
        "aircraft": lambda: AIRCRAFT(),
        "flowers": lambda: FLOWERS(),
        "scenes": lambda: SCENES(),
        "chars": lambda: CHARS(),
        "birds": lambda: BIRDS(),
    }
    factory = factories.get(dataset_name)
    if factory is None:
        raise ValueError(f"Dataset {dataset_name} not found")
    dataset = factory()

    # split the dataset into train, val, and test
    train_dataset, val_dataset, test_dataset = _split_train_val_test(
        dataset, train_size, val_size
    )
    print(
        f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}, Test size: {len(test_dataset)}"
    )
    print(dataset.test_transform)

    def make_loader(subset, transform, shuffle):
        # NOTE(review): drop_last=True is also applied to the val/test
        # loaders, so a trailing partial batch is left out of evaluation.
        return torch.utils.data.DataLoader(
            TransformedDataset(subset, transform),
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=True,
            pin_memory=True,
        )

    # Each split gets its own transform: augmentation only for training.
    train_loader = make_loader(train_dataset, dataset.train_transform, True)
    val_loader = make_loader(val_dataset, dataset.test_transform, False)
    test_loader = make_loader(test_dataset, dataset.test_transform, False)

    return train_loader, val_loader, test_loader, dataset.get_num_classes()

### Training loop
I stored the dataloaders in a pickle file beforehand (`dataloader_dict`) and look them up by task name in the training loop.
`SHARED_WEIGHT` refers to the reconstructed weights of the model.

In [None]:
def calculate_parameters(model):
    """Print a parameter-count summary for ``model``.

    Reports the classifier head module, the total parameter count (in whole
    millions), the number of trainable parameters, the number of parameters
    in ``model.head``, and per-group counts for parameters whose names
    contain ``".attn."``, ``"template_banks"``, ``"mlp"`` or
    ``"coefficients"``.

    Args:
        model: a module exposing ``parameters()``, ``named_parameters()`` and
            a ``head`` submodule.

    Returns:
        None. The summary is written to stdout.
    """
    # Name fragments that define each reported group. A parameter whose name
    # contains several fragments is counted in every matching group.
    group_keys = (".attn.", "template_banks", "mlp", "coefficients")
    group_counts = dict.fromkeys(group_keys, 0)

    # Overall and trainable totals in a single pass.
    total_count = 0
    trainable_count = 0
    for weight in model.parameters():
        size = weight.numel()
        total_count += size
        if weight.requires_grad:
            trainable_count += size

    head_count = sum(weight.numel() for weight in model.head.parameters())

    for name, weight in model.named_parameters():
        for key in group_keys:
            if key in name:
                group_counts[key] += weight.numel()

    print("Classifier head: ", model.head)
    print(
        f"Total parameters: {total_count // 1000000}M, Trainable parameters: {trainable_count}, Classifier parameters: {head_count}"
    )
    print(f"Attention parameters: {group_counts['.attn.']}")
    print(f"Templates params: {group_counts['template_banks']}")
    print(f"MLP params: {group_counts['mlp']}")
    print(f"Coefficients params: {group_counts['coefficients']}")

In [None]:
import gc

# Per-task fine-tuning script: for every task, rebuild the ViT from the shared
# checkpoint, freeze everything except the new classifier head and the
# "coefficients" parameters, train, and report the best test accuracy.

# NOTE: `dataloader_dict` must map every non-CIFAR task name to a dict with
# "train" and "test" dataloaders. It is initialised empty here, so fill it
# (e.g. from the pickle file prepared beforehand) before the loop below is
# reached, otherwise the first non-CIFAR task raises KeyError.
dataloader_dict = {}
TASK_NAME = ["cars", "aircraft", "flowers", "scenes", "chars", "birds", "cifar10", "cifar100"]
num_classes = [196, 55, 103, 67, 63, 201, 10, 100]
SHARED_WEIGHT = "<SharedWeight.pt>"

# Loop-invariant setup: resolve the device and read the checkpoint once.
# map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only host;
# the model is moved to `device` after its weights are loaded.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
target_state_dict = torch.load(SHARED_WEIGHT, map_location="cpu")["state_dict"]

for task, num_class in zip(TASK_NAME, num_classes):
    print(f"Task: {task}, Num classes: {num_class}")
    if task == "cifar100" or task == "cifar10":
        print("Using CIFAR")
        # For CIFAR the class count returned by get_dataloaders overrides the
        # value taken from `num_classes`.
        train_loader, _, test_loader, num_class = get_dataloaders(
            task, train_size=0.8, val_size=0.0, batch_size=256, mode="train"
        )
    else:
        train_loader = dataloader_dict[task]["train"]
        test_loader = dataloader_dict[task]["test"]
    print(
        f"Train size: {len(train_loader.dataset)}, Test size: {len(test_loader.dataset)}"
    )

    # Build the backbone with a 1000-way head, load the shared weights into
    # it, then swap in a task-specific head below.
    training_model = VisionTransformer(
        img_size=224,
        patch_size=16,
        in_chans=3,
        num_classes=1000,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        norm_layer=nn.LayerNorm,
    )
    # strict=False: the printed result lists any missing/unexpected keys.
    print(training_model.load_state_dict(target_state_dict, strict=False))

    # Freeze the whole backbone first ...
    for param in training_model.parameters():
        param.requires_grad = False

    # ... then attach a fresh, trainable classifier head for this task ...
    training_model.head = nn.Linear(training_model.embed_dim, num_class, bias=True)
    training_model.head.weight.requires_grad = True
    training_model.head.bias.requires_grad = True

    # ... and unfreeze every parameter whose name contains "coefficients".
    cf_parameters = []
    for n, p in training_model.named_parameters():
        if "coefficients" in n:
            p.requires_grad = True
            cf_parameters.append(p)
        if p.requires_grad:
            print(f"Trainable: {n}")
    calculate_parameters(training_model)
    training_model.to(device)
    optimizer = torch.optim.AdamW(
        training_model.parameters(), lr=2e-3, weight_decay=1e-6
    )  # ViT
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=33, gamma=0.1)
    print(optimizer)
    criterion = nn.CrossEntropyLoss()

    # Log roughly four times per epoch; clamp to 1 so a loader with fewer
    # than four batches does not cause a modulo-by-zero below.
    printing_iter = max(1, len(train_loader) // 4)
    num_epochs = 100  # TODO
    mini_batch_size = 128  # TODO
    best_acc = 0.0
    best_loss = 0.0
    best_model = None
    for epoch in range(num_epochs):
        training_model.train()
        classifier_loss = 0.0  # sum of per-step losses, kept as a Python float
        num_steps = 0  # optimizer steps taken so far in this epoch
        for i, (images, labels) in tqdm.tqdm(
            enumerate(train_loader), total=len(train_loader), desc="training"
        ):
            # Each loaded batch is processed in slices of `mini_batch_size`,
            # with one optimizer step per slice.
            for j in range(0, images.size(0), mini_batch_size):
                optimizer.zero_grad()
                classifier_output = training_model(
                    images[j : j + mini_batch_size].to(device)
                )
                loss = criterion(
                    classifier_output, labels[j : j + mini_batch_size].to(device)
                )
                loss.backward()
                # gradient clipping
                torch.nn.utils.clip_grad_norm_(training_model.parameters(), 1.0)
                optimizer.step()
                # .item() detaches the value; accumulating the tensor itself
                # would keep autograd history alive across iterations.
                classifier_loss += loss.item()
                num_steps += 1

            if i % printing_iter == 0:
                print(
                    f"Epoch {epoch}, Iteration {i}, LR: {optimizer.param_groups[0]['lr']}, Loss: {classifier_loss / num_steps}"
                )

        training_model.eval()

        with torch.no_grad():
            total = 0  # number of evaluated samples
            correct = 0
            val_loss = 0.0  # sum of per-sample losses
            for images, labels in tqdm.tqdm(
                test_loader, total=len(test_loader), desc="Evaluating"
            ):
                for j in range(0, images.size(0), mini_batch_size):
                    sub_images = images[j : j + mini_batch_size].to(device)
                    sub_labels = labels[j : j + mini_batch_size].to(device)
                    classifier_output = training_model(sub_images)
                    predicted = classifier_output.argmax(dim=1)
                    total += sub_labels.size(0)
                    correct += (predicted == sub_labels).sum().item()
                    # criterion returns the mean over the slice; weight it by
                    # the slice size so the final division by `total` yields
                    # a true per-sample average.
                    sub_val_loss = criterion(classifier_output, sub_labels)
                    val_loss += sub_val_loss.item() * sub_labels.size(0)

            accuracy = 100 * correct / total
            average_val_loss = val_loss / total
            print(
                f"Validation Accuracy: {accuracy}, Average Loss: {average_val_loss}"
            )

            if accuracy > best_acc:
                best_acc = accuracy
                best_loss = average_val_loss
                # state_dict() returns references to the live tensors, so copy
                # them; otherwise later epochs would overwrite the snapshot.
                best_model = {
                    name: tensor.detach().to("cpu", copy=True)
                    for name, tensor in training_model.state_dict().items()
                }
                print(f"New best accuracy: {best_acc}, and current Loss: {best_loss}")
        scheduler.step()

    print(f"Best accuracy: {best_acc}, Loss: {best_loss}")

    # Release per-task objects before the next task allocates its own.
    del (
        training_model,
        optimizer,
        criterion,
        scheduler,
        train_loader,
        test_loader,
        best_model,
    )
    gc.collect()
    torch.cuda.empty_cache()
