### Model Definition

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import numpy as np
import random


# ---- Global constants ----
manualSeed = 42  # single seed shared by every RNG seeded below
DEFAULT_THRESHOLD = 5e-3
GEN_KERNEL = 3  # kernel size handed to TemplateBank when banks are built
num_cf = 2  # number of coefficient sets each SConv2d keeps

# ---- Reproducibility: seed the Python, NumPy and PyTorch generators ----
np.random.seed(manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
torch.cuda.manual_seed(manualSeed)

# cuDNN autotuning and cuDNN itself are both switched off.
cudnn.benchmark = False
torch.backends.cudnn.enabled = False

# ---- Device configuration ----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device: ", device)


class TemplateBank(nn.Module):
    """A bank of convolution kernels that are mixed by per-template coefficients.

    The bank owns one parameter, ``templates``, of shape
    ``(num_templates, out_planes, in_planes, kernel_size, kernel_size)``.
    ``forward`` multiplies it by a coefficient tensor and sums over the
    template axis to produce a single convolution weight.

    Args:
        num_templates: number of kernels stored in the bank.
        in_planes: input channels of each kernel.
        out_planes: output channels of each kernel.
        kernel_size: spatial size of each (square) kernel.
    """

    def __init__(self, num_templates, in_planes, out_planes, kernel_size):
        super(TemplateBank, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        # Shape a coefficient tensor must have to broadcast against `templates`.
        self.coefficient_shape = (num_templates, 1, 1, 1, 1)
        self.kernel_size = kernel_size

        templates = []
        for _ in range(num_templates):
            # Explicitly uninitialised storage; filled by Kaiming init right away.
            template = torch.empty(out_planes, in_planes, kernel_size, kernel_size)
            nn.init.kaiming_normal_(template)
            templates.append(template)
        self.templates = nn.Parameter(
            torch.stack(templates)
        )  # this is what we will freeze later

    def forward(self, coefficients):
        """Return the coefficient-weighted sum of the templates.

        Args:
            coefficients: tensor broadcastable against ``templates``
                (``coefficient_shape`` is the shape used by SConv2d).

        Returns:
            Tensor of shape ``(out_planes, in_planes, kernel_size, kernel_size)``
            when ``coefficients`` has ``coefficient_shape``.
        """
        return (self.templates * coefficients).sum(0)

    def __repr__(self):
        # All fields live inside one pair of parentheses.
        return (
            f"{self.__class__.__name__}("
            f"num_templates={self.coefficient_shape[0]}, "
            f"kernel_size={self.kernel_size}, "
            f"in_planes={self.in_planes}, "
            f"out_planes={self.out_planes})"
        )


class SConv2d(nn.Module):
    """2-D convolution whose kernel is generated from a shared template bank.

    The layer keeps ``num_cf`` coefficient tensors (created as zeros, each
    with ``bank.coefficient_shape``). On every forward pass each coefficient
    tensor is turned into a kernel by the bank and the kernels are averaged.
    """

    # TARGET MODULE
    def __init__(self, bank, stride=1, padding=1):
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.bank = bank
        self.num_templates = bank.coefficient_shape[0]

        coefficient_sets = []
        for _ in range(num_cf):
            coefficient_sets.append(nn.Parameter(torch.zeros(bank.coefficient_shape)))
        self.coefficients = nn.ParameterList(coefficient_sets)

    def forward(self, input):
        """Convolve ``input`` with the mean of the bank-generated kernels."""
        kernels = [self.bank(coeff) for coeff in self.coefficients]
        averaged_kernel = torch.stack(kernels).mean(0)
        return F.conv2d(
            input, averaged_kernel, stride=self.stride, padding=self.padding
        )


class CustomResidualBlock(nn.Module):
    """Two-convolution residual block with optional template-bank convolutions.

    When both ``bank1`` and ``bank2`` are supplied the two 3x3 convolutions are
    ``SConv2d`` layers drawing from those banks; otherwise they are ordinary
    bias-free ``nn.Conv2d`` layers.

    Args:
        in_channels: channels of the block input.
        out_channels: channels of the block output.
        stride: stride of the first convolution (and of the shortcut projection).
        downsample: accepted for call-site compatibility but NOT used; the
            shortcut projection is always rebuilt here from ``stride`` and the
            channel counts.
        bank1: template bank for the first convolution, or None.
        bank2: template bank for the second convolution, or None.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        stride=1,
        downsample=None,
        bank1=None,
        bank2=None,
    ):
        super(CustomResidualBlock, self).__init__()
        self.bank1 = bank1
        self.bank2 = bank2

        # Compare against None explicitly: module truthiness would depend on
        # whether the bank type happens to define __len__/__bool__.
        # Padding is always 1 for the 3x3 convolutions.
        if bank1 is not None and bank2 is not None:
            self.conv1 = SConv2d(bank1, stride=stride, padding=1)
            self.conv2 = SConv2d(bank2, stride=1, padding=1)
        else:
            self.conv1 = nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                stride=stride,
                padding=1,
                bias=False,
            )
            self.conv2 = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=False,
            )

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut projection: a 1x1 convolution + batch norm whenever the
        # residual branch changes resolution or channel count.
        # NOTE(review): the `downsample` argument is intentionally left unused
        # to keep existing behaviour; confirm whether it should be honoured.
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialise convolutions, batch norms and SConv2d coefficients."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, SConv2d):
                # Replaces the all-zero coefficients SConv2d is created with.
                for coefficient in m.coefficients:
                    nn.init.orthogonal_(coefficient)

    def forward(self, x):
        """Apply conv-bn-relu-conv-bn, add the shortcut, then a final ReLU."""
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class ResNetTPB(nn.Module):
    """ResNet-style network whose later blocks in each stage share template banks.

    Args:
        block: residual block class; it is called positionally as
            ``block(in, out, stride, downsample)`` for the first block of a
            stage and with ``in_channels``/``out_channels``/``bank1``/``bank2``
            keywords for the remaining ones.
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=10):
        super().__init__()
        self.inplanes = 64
        self.layers = layers

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU and max pooling.
        self.conv1 = nn.Conv2d(
            3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Four stages, registered as layer1..layer4 in this order.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (planes, stride) in enumerate(stage_specs):
            stage = self._make_layer(block, planes, layers[idx], stride=stride)
            setattr(self, f"layer{idx + 1}", stage)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage: a plain first block followed by bank-backed blocks."""
        needs_projection = stride != 1 or self.inplanes != planes
        if needs_projection:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.inplanes, planes, kernel_size=1, stride=stride, bias=False
                ),
                nn.BatchNorm2d(planes),
            )
        else:
            downsample = None

        stage = [block(self.inplanes, planes, stride, downsample)]

        # DYNAMICALLY CALCULATE THE NUMBER OF TEMPLATES TO USE FOR EACH RESIDUAL BLOCK
        # Both counts below describe a 3x3 planes->planes kernel, so the ratio
        # is 1 and the template count equals the number of remaining blocks
        # (with a floor of one).
        params_per_conv = 9 * planes * planes
        params_per_template = 9 * planes * planes
        remaining_params = (blocks - 1) * params_per_conv
        num_templates1 = max(1, int(remaining_params / params_per_template))
        # You could potentially use a different calculation here
        num_templates2 = num_templates1

        print(
            f"Layer with {planes} planes, {blocks} blocks, using {num_templates1} templates for conv1 and {num_templates2} for conv2"
        )

        # One bank per convolution position, shared by every remaining block.
        tpbank1 = TemplateBank(num_templates1, planes, planes, GEN_KERNEL)
        tpbank2 = TemplateBank(num_templates2, planes, planes, GEN_KERNEL)

        self.inplanes = planes
        for _ in range(blocks - 1):
            shared_block = block(
                in_channels=self.inplanes,
                out_channels=planes,
                bank1=tpbank1,
                bank2=tpbank2,
            )
            stage.append(shared_block)

        return nn.Sequential(*stage)

    def forward(self, x):
        """Run the stem, the four stages, global pooling and the classifier."""
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        x = torch.flatten(self.avgpool(x), 1)
        return self.fc(x)


def test():
    """Smoke test: push one random 32x32 RGB image through a [2, 2, 2, 2] ResNetTPB."""
    model = ResNetTPB(CustomResidualBlock, [2, 2, 2, 2])
    dummy_batch = torch.randn(1, 3, 32, 32)
    logits = model(dummy_batch)
    print(logits.size())


### Dataloaders
Feel free to use any dataloader of your choice. I have used the following dataloader for my experiments.

In [None]:
import torch
import torchvision.transforms as transforms
import tqdm
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
import torch
import pandas as pd
import os
import json
from torchvision import datasets
from torchvision.transforms import transforms
# Base dataset class shared by the image-classification datasets below
class ImageDataset(Dataset):
    """Base class for the image-classification datasets in this file.

    Subclasses fill in the attributes created here (paths, labels,
    transforms, class map). Items are returned as ``(PIL RGB image, label)``;
    no transform is applied by this class.
    """

    def __init__(self):
        self.data_path = ""
        self.data_name = ""
        self.num_classes = 0
        self.train_transform = None
        # Defined here so every dataset exposes the attribute, even when a
        # subclass does not set its own evaluation transform.
        self.test_transform = None
        self.train_csv_path = ""
        self.image_paths = []
        self.labels = []
        # Read by `label_dict`; subclasses replace it with their class names.
        self.class_map = {}

    def get_num_classes(self):
        """Return the number of classes declared by the dataset."""
        return self.num_classes

    def __getitem__(self, index):
        """Load the image at ``index`` as RGB and return it with its label."""
        img_path = self.image_paths[index]
        label = self.labels[index]
        img = Image.open(img_path).convert("RGB")

        return img, label

    def __len__(self):
        return len(self.image_paths)

    @property
    def label_dict(self):
        """Map each class index in ``range(num_classes)`` to its ``class_map`` entry."""
        return {i: self.class_map[i] for i in range(self.num_classes)}

    def __repr__(self):
        # Call len(); interpolating `self.__len__` printed the bound method.
        return f"ImageDataset({self.data_name}) with {len(self)} instances"


class CARS(ImageDataset):
    """Dataset for the "cars" task (196 classes).

    Image paths and labels come from the ``fname`` and ``class`` columns of
    ``cars.csv`` under ``BASE_PATH``; class names are loaded from ``CARS.json``.
    """

    def __init__(self):
        super().__init__()
        self.data_path = CARS_DATA
        self.data_name = "cars"
        normalize = transforms.Normalize(
            mean=[-0.0639, 0.0145, 0.2118], std=[1.2796, 1.3035, 1.3343]
        )
        # Training: resize/crop to 224, random augmentation, then normalise.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                normalize,
            ]
        )
        # Evaluation: same geometry and normalisation, no augmentation.
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                normalize,
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "cars.csv")
        # Parse the CSV once and take both columns from the same frame.
        annotations = pd.read_csv(self.train_csv_path)
        self.image_paths = annotations["fname"].values
        self.labels = annotations["class"].values.tolist()
        self.num_classes = 196
        self.split = None
        # json file that contains the class names
        self.class_json = os.path.join(BASE_PATH, "CARS.json")
        with open(self.class_json) as class_file:  # close the handle promptly
            self.class_map = json.load(class_file)


class AIRCRAFT(ImageDataset):
    """Dataset for the "aircraft" task (55 classes).

    Image paths and labels come from the ``fname`` and ``class`` columns of
    ``aircrafts.csv`` under ``BASE_PATH``; class names are loaded from
    ``AIRCRAFTS.json``.
    """

    def __init__(self):
        super().__init__()
        self.data_path = AIRCRAFT_DATA
        self.data_name = "aircraft"
        normalize = transforms.Normalize(
            mean=[-0.0266, 0.2407, 0.5663], std=[0.9745, 0.9684, 1.1040]
        )
        # Training: resize/crop to 224, random augmentation, then normalise.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                normalize,
            ]
        )
        # Evaluation: same geometry and normalisation, no augmentation.
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                normalize,
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "aircrafts.csv")
        # Parse the CSV once and take both columns from the same frame.
        annotations = pd.read_csv(self.train_csv_path)
        self.image_paths = annotations["fname"].values
        self.labels = annotations["class"].values
        self.num_classes = 55
        self.split = None
        self.class_json = os.path.join(BASE_PATH, "AIRCRAFTS.json")
        with open(self.class_json) as class_file:  # close the handle promptly
            self.class_map = json.load(class_file)


class FLOWERS(ImageDataset):
    """Dataset for the "flowers" task (``num_classes`` is set to 103).

    Image paths and labels come from the ``fname`` and ``class`` columns of
    ``flowers.csv`` under ``BASE_PATH``; class names are loaded from
    ``FLOWERS.json``.
    """

    def __init__(self):
        super().__init__()
        self.data_path = FLOWERS_DATA
        self.data_name = "flowers"
        normalize = transforms.Normalize(
            mean=[0.5642, 0.7694, 0.8410], std=[0.2560, 0.2589, 0.2783]
        )
        # Training: resize/crop to 224, random augmentation, then normalise.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                normalize,
            ]
        )
        # Evaluation: same geometry and normalisation, no augmentation.
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                normalize,
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "flowers.csv")
        # Parse the CSV once and take both columns from the same frame.
        annotations = pd.read_csv(self.train_csv_path)
        self.image_paths = annotations["fname"].values
        self.labels = annotations["class"].values
        self.num_classes = 103  # not 0 indexed
        self.split = None
        self.class_json = os.path.join(BASE_PATH, "FLOWERS.json")
        with open(self.class_json) as class_file:  # close the handle promptly
            self.class_map = json.load(class_file)


class SCENES(ImageDataset):
    """Dataset for the "scenes" task (67 classes).

    Image paths and labels come from the ``fname`` and ``class`` columns of
    ``scenes.csv`` under ``BASE_PATH``; class names are loaded from
    ``SCENES.json``.
    """

    def __init__(self):
        super().__init__()
        self.data_path = SCENES_DATA
        self.data_name = "scenes"
        normalize = transforms.Normalize(
            mean=[-0.0081, -0.1473, -0.1866], std=[1.1616, 1.1583, 1.1599]
        )
        # Training: resize/crop to 224, random augmentation, then normalise.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                normalize,
            ]
        )
        # Evaluation: same geometry and normalisation, no augmentation.
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                normalize,
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "scenes.csv")
        # Parse the CSV once and take both columns from the same frame.
        annotations = pd.read_csv(self.train_csv_path)
        self.image_paths = annotations["fname"].values
        self.labels = annotations["class"].values
        self.num_classes = 67
        self.split = None
        self.class_json = os.path.join(BASE_PATH, "SCENES.json")
        with open(self.class_json) as class_file:  # close the handle promptly
            self.class_map = json.load(class_file)


class CHARS(ImageDataset):
    """Dataset for the "chars" task (``num_classes`` is set to 63).

    Image paths and labels come from the ``fname`` and ``class`` columns of
    ``chars.csv`` under ``BASE_PATH``; class names are loaded from
    ``CHARS.json``. Neither transform normalises the pixels (the Normalize
    steps are commented out).
    """

    def __init__(self):
        super().__init__()
        self.data_path = CHARS_DATA
        self.data_name = "chars"
        # Training: resize/crop to 224 plus random augmentation.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                # transforms.Normalize(mean=[1.4986, 1.6615, 1.8764], std=[1.6015, 1.6373, 1.6300])
            ]
        )
        # Evaluation: resize/crop only.
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                # transforms.Normalize(mean=[1.4986, 1.6615, 1.8764], std=[1.6015, 1.6373, 1.6300])
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "chars.csv")
        # Parse the CSV once and take both columns from the same frame.
        annotations = pd.read_csv(self.train_csv_path)
        self.image_paths = annotations["fname"].values
        self.labels = annotations["class"].values
        self.num_classes = 63  # not 0 indexed
        self.split = None
        self.class_json = os.path.join(BASE_PATH, "CHARS.json")
        with open(self.class_json) as class_file:  # close the handle promptly
            self.class_map = json.load(class_file)


class BIRDS(ImageDataset):
    """Dataset for the "birds" task (``num_classes`` is set to 201).

    Image paths and labels come from the ``fname`` and ``class`` columns of
    ``birds.csv`` under ``BASE_PATH``; class names are loaded from
    ``BIRDS.json``.
    """

    def __init__(
        self,
    ):
        super().__init__()
        self.data_path = BIRDS_DATA
        self.data_name = "birds"
        # Training: resize/crop to 224, random augmentation, then normalise.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ColorJitter(
                    brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2
                ),
                transforms.RandomAffine(
                    degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
                ),
                transforms.Normalize(
                    mean=[0.0049, 0.1962, 0.1152], std=[1.0027, 1.0053, 1.1734]
                ),
            ]
        )
        # NOTE(review): unlike train_transform, the evaluation transform does
        # not normalise (the step is commented out) — confirm this mismatch
        # between training and evaluation inputs is intended.
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                # transforms.Normalize(mean=[0.0049, 0.1962, 0.1152], std=[1.0027, 1.0053, 1.1734])
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "birds.csv")
        # Parse the CSV once and take both columns from the same frame.
        annotations = pd.read_csv(self.train_csv_path)
        self.image_paths = annotations["fname"].values
        self.labels = annotations["class"].values
        self.num_classes = 201  # not 0 indexed
        self.split = None
        self.class_json = os.path.join(BASE_PATH, "BIRDS.json")
        with open(self.class_json) as class_file:  # close the handle promptly
            self.class_map = json.load(class_file)


class ACTION(ImageDataset):
    """Dataset for the "actions" task (``num_classes`` is set to 20).

    Image paths and labels come from the ``fname`` and ``class`` columns of
    ``action.csv`` under ``BASE_PATH``; class names are loaded from
    ``ACTION.json``. Only ``train_transform`` is assigned in this class.
    """

    def __init__(self):
        super().__init__()
        self.data_path = ACTION_DATA
        self.data_name = "actions"
        # Resize/crop to 224 and normalise; no random augmentation is applied.
        self.train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                # transforms.RandomHorizontalFlip(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )
        self.train_csv_path = os.path.join(BASE_PATH, "action.csv")
        # Parse the CSV once and take both columns from the same frame.
        annotations = pd.read_csv(self.train_csv_path)
        self.image_paths = annotations["fname"].values
        self.labels = annotations["class"].values
        self.num_classes = 20  # not 0 indexed
        self.split = None
        self.class_json = os.path.join(BASE_PATH, "ACTION.json")
        with open(self.class_json) as class_file:  # close the handle promptly
            self.class_map = json.load(class_file)


class SVHN(ImageDataset):
    """Wrapper around ``torchvision.datasets.SVHN`` (10 classes).

    Unlike the CSV-backed datasets in this file, items are returned as
    ``(image, label, task_id)`` triples and the transform is applied inside
    ``__getitem__``.

    Args:
        split: split name forwarded to ``datasets.SVHN``.
        transform: optional transform applied to each image; when None a
            default resize/crop/normalise pipeline is used.
    """

    # TODO: this dataset is a bit tricky to handle
    def __init__(self, split="train", transform=None):
        super().__init__()
        self.data_path = SVHN_DATA
        self.data_name = "svhn"
        self.task_id = 6  # Assign a unique task_id for SVHN
        self.split = split
        # Honour a caller-supplied transform; previously the argument was
        # accepted but always overwritten by the default pipeline.
        if transform is None:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize((224, 224)),
                    transforms.CenterCrop(224),
                    # transforms.RandomHorizontalFlip(),
                    transforms.Normalize(
                        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                    ),
                ]
            )
        self.transform = transform
        self.dataset = datasets.SVHN(root=SVHN_DATA, split=split, download=True)
        self.num_classes = 10

    def __getitem__(self, index):
        """Return ``(transformed image, label, task_id)`` for ``index``."""
        img, label = self.dataset[index]
        if self.transform:
            img = self.transform(img)
        return img, label, self.task_id

    def __len__(self):
        return len(self.dataset)


def collate_fn(batch):
    """Collate ``(image, label, task_id)`` samples into a batch.

    Args:
        batch: sequence of 3-tuples whose first element is an image tensor
            and whose second element is a label.

    Returns:
        ``(images, labels)`` where ``images`` is the samples stacked along a
        new leading dimension and ``labels`` is a tensor of the labels. The
        task ids are discarded.
    """
    images, labels, _task_ids = zip(*batch)
    return torch.stack(images, dim=0), torch.tensor(labels)


class TransformedDataset(torch.utils.data.Dataset):
    """View over another dataset that applies a transform to each image.

    The wrapped dataset must yield ``(image, label)`` pairs. Tensor images are
    converted to a NumPy array with the first axis moved last before the
    transform runs.
    """

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        sample, target = self.dataset[index]

        if isinstance(sample, torch.Tensor):
            # Reorder axes (0, 1, 2) -> (1, 2, 0) on the NumPy view.
            sample = sample.numpy().transpose(1, 2, 0)

        if not self.transform:
            return sample, target
        return self.transform(sample), target


def _split_lengths(total, train_size, val_size):
    """Convert fractional train/val sizes into [train, val, test] lengths."""
    n_train = int(train_size * total)
    n_val = int(val_size * total)
    return [n_train, n_val, total - n_train - n_val]


def _cifar_dataloaders(
    dataset_cls,
    num_classes,
    train_transform,
    test_transform,
    train_size,
    val_size,
    batch_size,
):
    """Build train/val/test loaders for a torchvision CIFAR dataset class."""
    cifar_train = dataset_cls(
        root=CIFAR_DATA, train=True, download=True, transform=train_transform
    )
    # NOTE(review): the official test split is instantiated (as before) but
    # not used; the returned "test" loader is carved out of the training
    # split and therefore also goes through `train_transform`. Confirm
    # whether evaluation should use this split instead.
    dataset_cls(root=CIFAR_DATA, train=False, download=True, transform=test_transform)

    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        cifar_train, _split_lengths(len(cifar_train), train_size, val_size)
    )

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False
    )
    return train_loader, val_loader, test_loader, num_classes


def get_dataloaders(
    dataset_name, train_size=0.8, val_size=0.0, batch_size=32, mode="train"
):
    """Create train/val/test data loaders for a named dataset.

    Args:
        dataset_name: one of "cars", "aircraft", "flowers", "scenes",
            "chars", "birds", "cifar10" or "cifar100".
        train_size: fraction of the dataset used for training.
        val_size: fraction of the dataset used for validation; the remainder
            becomes the test split.
        batch_size: batch size of all three loaders.
        mode: currently unused; kept for call-site compatibility.

    Returns:
        ``(train_loader, val_loader, test_loader, num_classes)``.

    Raises:
        ValueError: if ``dataset_name`` is not recognised.
    """
    if dataset_name == "cifar10":
        # No normalisation is applied on this path.
        train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
            ]
        )
        test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
            ]
        )
        return _cifar_dataloaders(
            datasets.CIFAR10,
            10,
            train_transform,
            test_transform,
            train_size,
            val_size,
            batch_size,
        )

    if dataset_name == "cifar100":
        stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((224, 224)),
                transforms.CenterCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.Normalize(*stats, inplace=True),
            ]
        )
        return _cifar_dataloaders(
            datasets.CIFAR100,
            100,
            train_transform,
            transforms.ToTensor(),
            train_size,
            val_size,
            batch_size,
        )

    # CSV-backed datasets defined in this file.
    dataset_classes = {
        "cars": CARS,
        "aircraft": AIRCRAFT,
        "flowers": FLOWERS,
        "scenes": SCENES,
        "chars": CHARS,
        "birds": BIRDS,
    }
    if dataset_name not in dataset_classes:
        raise ValueError(f"Dataset {dataset_name} not found")
    dataset = dataset_classes[dataset_name]()

    # split the dataset into train, val, and test
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, _split_lengths(len(dataset), train_size, val_size)
    )
    print(
        f"Train size: {len(train_dataset)}, Val size: {len(val_dataset)}, Test size: {len(test_dataset)}"
    )
    # print(dataset.train_transform)
    print(dataset.test_transform)

    # The base dataset returns untransformed images, so each split is wrapped
    # with the transform that matches its role.
    train_dataset = TransformedDataset(train_dataset, dataset.train_transform)
    val_dataset = TransformedDataset(val_dataset, dataset.test_transform)
    test_dataset = TransformedDataset(test_dataset, dataset.test_transform)

    # NOTE(review): drop_last=True on the val/test loaders discards the final
    # partial batch, so some evaluation samples are never scored — confirm.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        pin_memory=True,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        drop_last=True,
        pin_memory=True,
    )

    return train_loader, val_loader, test_loader, dataset.get_num_classes()

### Training Loop


In [None]:
def calculate_parameters(model):
    """Print parameter counts for ``model``.

    Reports the total, untrainable and trainable parameter counts, the
    counts of parameters whose name contains "templates" or "coefficients",
    and the number of ``BatchNorm2d`` weight entries.

    Args:
        model: an ``nn.Module``.

    Returns:
        None; the counts are printed.
    """
    params = 0
    trainable_params = 0
    template_params = 0
    coefficients = 0
    for name, param in model.named_parameters():
        # numel() is an exact int even for 0-dim parameters, whereas
        # np.prod(()) is the float 1.0 and turned the totals into floats.
        count = param.numel()
        params += count
        if param.requires_grad:
            trainable_params += count
        if "templates" in name:
            template_params += count
        if "coefficients" in name:
            coefficients += count

    batch_norm_params = 0
    for module in model.modules():
        # weight is None when the layer was built with affine=False.
        if isinstance(module, nn.BatchNorm2d) and module.weight is not None:
            batch_norm_params += module.weight.numel()

    untrainable_params = params - trainable_params

    print("* number of parameters: {}".format(params))
    print("* untrainable params: {}".format(untrainable_params))
    print("* trainable params: {}".format(trainable_params))

    print("* template params: {}".format(template_params))
    print("* coefficients: {}".format(coefficients))
    print("* batch norm params: {}".format(batch_norm_params))


In [None]:
import gc

# Optional pre-built loaders: {task: {"train": loader, "test": loader}}.
# Tasks missing from this dict fall back to get_dataloaders().
dataloader_dict = {}
TASK_NAME = ["cars", "aircraft", "flowers", "scenes", "chars", "birds", "cifar10", "cifar100"]
num_classes = [196, 55, 103, 67, 63, 201, 10, 100]
SHARED_WEIGHT = "<SharedWeight.pt>"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

for task, num_class in zip(TASK_NAME, num_classes):
    print(f"Task: {task}, Num classes: {num_class}")
    is_cifar = task == "cifar100" or task == "cifar10"
    if is_cifar or task not in dataloader_dict:
        # Indexing the (initially empty) dict unconditionally raised KeyError
        # for every non-CIFAR task, so build the loaders on demand instead.
        if is_cifar:
            print("Using CIFAR")
        train_loader, _, test_loader, num_class = get_dataloaders(
            task, train_size=0.8, val_size=0.0, batch_size=256, mode="train"
        )
    else:
        train_loader = dataloader_dict[task]["train"]
        test_loader = dataloader_dict[task]["test"]
    print(
        f"Train size: {len(train_loader.dataset)}, Test size: {len(test_loader.dataset)}"
    )

    # Rebuild the shared backbone and load its weights for every task.
    training_model = ResNetTPB(CustomResidualBlock, [3, 4, 6, 3], num_classes=1000)
    # map_location keeps loading independent of the device the checkpoint
    # was saved from; the model is moved to `device` further down.
    checkpoint = torch.load(SHARED_WEIGHT, map_location="cpu")
    print(
        training_model.load_state_dict(checkpoint["state_dict"], strict=True)  # TODO
    )

    # Freeze everything, then re-enable only the new classifier head and the
    # SConv2d coefficients.
    for param in training_model.parameters():
        param.requires_grad = False

    resnet_embeddim = training_model.fc.in_features
    training_model.fc = nn.Linear(resnet_embeddim, num_class, bias=True)
    training_model.fc.weight.requires_grad = True
    training_model.fc.bias.requires_grad = True
    for n, p in training_model.named_parameters():
        if "coefficients" in n:
            p.requires_grad = True

        if p.requires_grad:
            print(f"Trainable: {n}")
    calculate_parameters(training_model)
    training_model.to(device)

    trainable_parameters = [p for p in training_model.parameters() if p.requires_grad]
    # NOTE(review): weight_decay=1e-46 is effectively zero — confirm this is
    # not a typo for a larger value.
    optimizer = torch.optim.AdamW(trainable_parameters, lr=2e-3, weight_decay=1e-46)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=33, gamma=0.1)
    criterion = nn.CrossEntropyLoss()

    # At least 1 so `i % printing_iter` cannot divide by zero on short loaders.
    printing_iter = max(1, len(train_loader) // 6)
    num_epochs = 100
    mini_batch_size = 64  # TODO
    best_acc = 0.0
    best_loss = 0.0
    best_model = None
    for epoch in range(num_epochs):
        # NOTE(review): train() also lets the frozen BatchNorm layers update
        # their running statistics — confirm that is intended.
        training_model.train()
        running_loss = 0.0
        num_steps = 0
        for i, (images, labels) in tqdm.tqdm(
            enumerate(train_loader), total=len(train_loader), desc="training"
        ):
            # Each loader batch is processed in smaller optimisation steps.
            for j in range(0, images.size(0), mini_batch_size):
                optimizer.zero_grad()
                classifier_output = training_model(
                    images[j : j + mini_batch_size].to(device)
                )
                loss = criterion(
                    classifier_output, labels[j : j + mini_batch_size].to(device)
                )
                loss.backward()
                # gradient clipping
                torch.nn.utils.clip_grad_norm_(trainable_parameters, 1.0)
                optimizer.step()
                # .item() detaches the value; summing the loss tensor itself
                # kept autograd history alive across the whole epoch.
                running_loss += loss.item()
                num_steps += 1

            if i % printing_iter == 0:
                print(
                    f"Epoch {epoch}, Iteration {i}, LR: {optimizer.param_groups[0]['lr']}, Loss: {running_loss / max(1, num_steps)}"
                )

        training_model.eval()

        with torch.no_grad():
            total = 0
            correct = 0
            val_loss = 0.0
            for images, labels in tqdm.tqdm(
                test_loader, total=len(test_loader), desc="Evaluating"
            ):
                for j in range(0, images.size(0), mini_batch_size):
                    sub_images = images[j : j + mini_batch_size].to(device)
                    sub_labels = labels[j : j + mini_batch_size].to(device)
                    classifier_output = training_model(sub_images)
                    predicted = classifier_output.argmax(dim=1)
                    sub_count = sub_labels.size(0)
                    total += sub_count
                    correct += (predicted == sub_labels).sum().item()
                    # The criterion returns a per-sample mean, so weight it by
                    # the sub-batch size to get a true per-sample average.
                    val_loss += criterion(classifier_output, sub_labels).item() * sub_count

            # Guard against an empty evaluation loader.
            accuracy = 100 * correct / total if total else 0.0
            average_loss = val_loss / total if total else 0.0
            print(f"Validation Accuracy: {accuracy}, Average Loss: {average_loss}")

            if accuracy > best_acc:
                best_acc = accuracy
                best_loss = average_loss
                # Snapshot the weights; state_dict() alone returns references
                # that later training steps would overwrite.
                best_model = {
                    key: value.detach().cpu().clone()
                    for key, value in training_model.state_dict().items()
                }
                print(f"New best accuracy: {best_acc}, and current Loss: {best_loss}")
        scheduler.step()

    print(f"Best accuracy: {best_acc}, Loss: {best_loss}")

    # Release per-task objects before the next task is set up.
    del (
        training_model,
        optimizer,
        criterion,
        scheduler,
        train_loader,
        test_loader,
        best_model,
        trainable_parameters,
        checkpoint,
    )
    gc.collect()
    torch.cuda.empty_cache()
