# Installing python modules used later on

In [None]:
# Only needed when training on Google Colab; comment this cell out for local runs.

!pip install pytorch_msssim
!pip install torchinfo

# Mount Google Drive

In [None]:
# Only needed when training on Google Colab; comment this cell out for local runs.

from google.colab import drive
drive.mount('/content/drive', force_remount=True)


# Importing relevant things

In [None]:
import torch
from torchinfo import summary
import random
import numpy as np
from torch import nn, optim
from tqdm.auto import tqdm
from torch.utils.data import DataLoader, Subset, TensorDataset
import os
import matplotlib.pyplot as plt
from torchvision.transforms import transforms
import torch.nn.functional as F
from sklearn import metrics
from sklearn.cluster import KMeans
from scipy.optimize import linear_sum_assignment
from pytorch_msssim import SSIM
import random

# Config params

In [None]:
class Config:
    """Central place for paths, device selection and hyper-parameters."""

    # --- storage locations -------------------------------------------------
    # Active settings point at Google Drive (Colab). For a local run, comment
    # the two Drive-based paths out and enable the local ones further down.
    drive_path = "/content/drive/MyDrive"
    datasets_path = drive_path + "/splitted_cifar10_dataset.npz"
    weights_path = drive_path + "/weights"

    # local alternative
    # datasets_path = f"../../Dataset/splitted_cifar10_dataset.npz"
    # weights_path = f"../weights"

    # --- compute device ----------------------------------------------------
    # Use the GPU whenever one is visible to torch, otherwise fall back to CPU.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # --- kernel density estimation ------------------------------------------
    sigma = 0.1      # Gaussian kernel bandwidth
    num_nodes = 11   # number of sample points per feature

    # --- optimisation --------------------------------------------------------
    learning_rate = 0.0015
    weight_decay = 1e-4
    grad_clip = 0.1

    # --- bag construction ----------------------------------------------------
    batch_size = 1   # bags per batch
    ucc_limit = 4    # largest unique-class count assigned to a bag
    rcc_limit = 10   # length of the per-class count vector (number of classes)
    bag_size = 12    # images per bag


config = Config()

print(f"Training on {config.device}")

# Loading the dataset

In [None]:
# Open the pre-split .npz archive referenced by the config and pull out the
# six arrays (train / val / test images and labels), printing each shape as a
# quick sanity check.
# NOTE(review): the object returned by np.load is never closed in this cell;
# confirm no later cell still needs it before wrapping this in a `with` block.
splitted_dataset = np.load(config.datasets_path)

# training split
x_train = splitted_dataset['x_train']
print(f"x_train shape :{x_train.shape}")

y_train = splitted_dataset['y_train']
print(f"y_train shape :{y_train.shape}")

# validation split
x_val = splitted_dataset['x_val']
print(f"x_val shape :{x_val.shape}")

y_val = splitted_dataset['y_val']
print(f"y_val shape :{y_val.shape}")

# test split
x_test = splitted_dataset['x_test']
print(f"x_test shape :{x_test.shape}")

y_test = splitted_dataset['y_test']
print(f"y_test shape: {y_test.shape}")


# Custom Dataloader

This dataloader moves data directly to the device when yielding data

In [None]:
class DeviceDataLoader(DataLoader):
    """DataLoader that moves every yielded batch onto the compute device.

    The target device is CUDA when torch reports it as available, else CPU.
    """

    def __init__(self, dataset, batch_size, shuffle=True):
        super().__init__(dataset, batch_size=batch_size, shuffle=shuffle)
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device_name)

    def __iter__(self):
        # Lazily wrap the parent iterator so each batch is transferred only
        # at the moment it is requested.
        yield from map(self._move_to_device, super().__iter__())

    def _move_to_device(self, batch):
        """Recursively transfer tensors found in ``batch`` to ``self.device``.

        Dicts keep their keys, lists and tuples both come back as lists, and
        any non-tensor leaf is returned untouched.
        """
        if isinstance(batch, dict):
            return {name: self._move_to_device(entry) for name, entry in batch.items()}
        if isinstance(batch, (list, tuple)):
            return [self._move_to_device(entry) for entry in batch]
        if isinstance(batch, torch.Tensor):
            return batch.to(self.device)
        return batch


# Defining Dataset

In [None]:
class Dataset:
    """Build the per-class image pools, bags and dataloaders used by the experiments.

    From the raw train / val / test arrays this class creates:

    * one ``TensorDataset`` per class ("pure" sub-datasets),
    * KDE dataloaders whose items are bags containing images of one class only,
    * per-class autoencoder dataloaders (batch size 1),
    * UCC dataloaders yielding ``(bag, ucc_one_hot)``,
    * UCC+RCC dataloaders yielding ``(bag, ucc_one_hot, rcc_counts)``.
    """

    def __init__(self, x_train, y_train, x_val, y_val, x_test, y_test,
                 debug=False, apply_augmentation=True,
                 batch_size=config.batch_size, bag_size=config.bag_size,
                 ucc_limit=config.ucc_limit, rcc_limit=config.rcc_limit
                 ):
        '''
        Note: all six data arguments are numpy arrays. Images are expected in
        channel-last layout with values in 0..255 (they are divided by 255
        below), labels are integer class ids in ``[0, rcc_limit)``.

        :param x_train: training images
        :param y_train: training labels
        :param x_val: validation images (randomly subsampled, see below)
        :param y_val: validation labels
        :param x_test: test images
        :param y_test: test labels
        :param debug: build only ``debug_bag_size`` bags per split and source the
            KDE / autoencoder loaders from the validation pool instead of test
        :param apply_augmentation: randomly augment images drawn for training bags
        :param batch_size: number of bags per batch
        :param bag_size: number of images per bag
        :param ucc_limit: largest unique-class count assigned to a bag
        :param rcc_limit: number of classes (length of the class-count vector)
        '''
        self.num_classes = rcc_limit
        self.bag_size = bag_size
        self.ucc_limit = ucc_limit
        self.rcc_limit = rcc_limit
        self.batch_size = batch_size
        self.debug = debug
        self.debug_bag_size = 6
        self.apply_augmentation = apply_augmentation

        # Augmentation pipelines; one of them is picked at random per image.
        self.transforms = [
            # identity
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.ToTensor()
            ]),
            # random horizontal flips
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ]),
            # random rotations
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomRotation(3),
                transforms.ToTensor()
            ]),
            # random rotations & flips
            transforms.Compose([
                transforms.ToPILImage(),
                transforms.RandomRotation(3),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor()
            ])
        ]

        # Convert to float tensors (labels stay as class ids, not one-hot).
        self.x_train = torch.from_numpy(x_train).to(dtype=torch.float32)
        self.y_train = torch.from_numpy(y_train).to(dtype=torch.float32)

        self.x_test = torch.from_numpy(x_test).to(dtype=torch.float32)
        self.y_test = torch.from_numpy(y_test).to(dtype=torch.float32)

        # Shrink the validation split to (at most) 1/10th of the test size by
        # sampling without replacement. The permutation must be drawn over the
        # validation set itself: indices drawn over the test set can be out of
        # range for x_val, or never reach part of it.
        val_count = min(len(x_val), len(x_test) // 10)
        val_indices = torch.randperm(len(x_val))[:val_count].numpy()
        x_val = x_val[val_indices]
        y_val = y_val[val_indices]

        self.x_val = torch.from_numpy(x_val).to(dtype=torch.float32)
        self.y_val = torch.from_numpy(y_val).to(dtype=torch.float32)

        # Scale pixel values into the range 0 -> 1.
        self.x_train /= 255
        self.x_test /= 255
        self.x_val /= 255

        print("Converted numpy to torch tensors")

        # create subdatasets ([class_0_imgs, class_1_imgs,... class_9_imgs])
        self.train_sub_datasets = self.create_sub_datasets(self.x_train, self.y_train)
        self.test_sub_datasets = self.create_sub_datasets(self.x_test, self.y_test)
        self.val_sub_datasets = self.create_sub_datasets(self.x_val, self.y_val)

        if not self.debug:
            # create dataloaders
            print("Creating KDE dataloaders")
            self.kde_test_dataloaders = self.create_kde_dataloaders(self.test_sub_datasets)

            print("Created KDE dataloaders, now creating autoencoder dataloaders")
            # batch size is 1 as we care about image level features anyway
            self.autoencoder_test_dataloaders = [DeviceDataLoader(test_sub_dataset, 1) for test_sub_dataset in
                                                 self.test_sub_datasets]
        else:
            # create dataloaders
            print("Creating debug KDE dataloaders")
            self.kde_test_dataloaders = self.create_kde_dataloaders(self.val_sub_datasets)

            print("Created debug KDE dataloaders, now creating debug autoencoder dataloaders")
            # batch size is 1 as we care about image level features anyway
            self.autoencoder_test_dataloaders = [DeviceDataLoader(val_sub_dataset, 1) for val_sub_dataset in
                                                 self.val_sub_datasets]
        print("Created autoencoder dataloaders, now creating ucc dataloaders")
        self.ucc_train_dataloader, self.ucc_test_dataloader, self.ucc_val_dataloader = self.get_dataloaders_for_ucc()
        print("Created ucc dataloaders, now creating rcc dataloaders")
        self.ucc_rcc_train_dataloader, self.ucc_rcc_test_dataloader, self.ucc_rcc_val_dataloader = self.get_dataloaders_for_ucc_and_rcc()

        print("Initilized all dataloaders")

    # create dataloaders
    def create_kde_dataloaders(self, sub_datasets):
        """Return one dataloader per class whose items are single-class bags.

        Each per-class pool is cut, in order, into consecutive bags of
        ``bag_size`` images; a trailing remainder that cannot fill a whole bag
        is dropped.

        :param sub_datasets: list of per-class ``TensorDataset`` objects
        :raises ValueError: if a class has fewer than ``bag_size`` images
        """
        kde_datasets = []

        for pure_sub_dataset in tqdm(sub_datasets):
            images = pure_sub_dataset.tensors[0]
            full_bags = len(images) // self.bag_size
            if full_bags == 0:
                raise ValueError(
                    f"A class sub-dataset has only {len(images)} images, "
                    f"fewer than bag_size={self.bag_size}"
                )

            # (full_bags * bag_size, C, H, W) -> (full_bags, bag_size, C, H, W)
            bags = images[: full_bags * self.bag_size].reshape(
                full_bags, self.bag_size, *images.shape[1:]
            )
            kde_datasets.append(TensorDataset(bags))

        print("Finished constructing the kde_datasets from the test dataset, now creating dataloaders")

        # NOTE. the batch size here can be different if required.
        kde_data_loaders = [DeviceDataLoader(kde_sub_dataset, self.batch_size) for kde_sub_dataset in kde_datasets]
        return kde_data_loaders

    def get_dataloaders_for_ucc(self):
        """Return ``(train, test, val)`` dataloaders yielding ``(bag, ucc_one_hot)``."""
        train_dataset_with_ucc, test_dataset_with_ucc, val_dataset_with_ucc = self.construct_datasets_with_ucc()
        return DeviceDataLoader(train_dataset_with_ucc, self.batch_size), \
            DeviceDataLoader(test_dataset_with_ucc, self.batch_size), \
            DeviceDataLoader(val_dataset_with_ucc, self.batch_size)

    def get_dataloaders_for_ucc_and_rcc(self):
        """Return ``(train, test, val)`` dataloaders yielding ``(bag, ucc_one_hot, rcc_counts)``."""
        train_dataset_with_ucc_and_rcc, test_dataset_with_ucc_and_rcc, val_dataset_with_ucc_and_rcc = self.construct_datasets_with_ucc_and_rcc()
        return DeviceDataLoader(train_dataset_with_ucc_and_rcc, self.batch_size), \
            DeviceDataLoader(test_dataset_with_ucc_and_rcc, self.batch_size), \
            DeviceDataLoader(val_dataset_with_ucc_and_rcc, self.batch_size)

    # create sub datasets
    def create_sub_datasets(self, x, y):
        """Split ``(x, y)`` into one ``TensorDataset`` per class label.

        Images are converted from channel-last to channel-first layout.
        Classes with no samples are skipped, so the returned list can be
        shorter than ``num_classes``.

        :param x: image tensor of shape (N, H, W, C)
        :param y: label tensor holding class ids
        :return: list of ``TensorDataset(images, labels)``, in class order
        """
        sub_datasets = []

        for class_label in range(self.num_classes):
            # Row indices of the samples that belong to the current class
            indices = torch.where(y == class_label)[0]
            if len(indices) == 0:
                continue

            # Index once and permute the whole class at once instead of
            # re-wrapping every image with torch.tensor(...).
            x_class = x[indices].permute(0, 3, 1, 2).contiguous()
            y_class = y[indices]
            sub_datasets.append(TensorDataset(x_class, y_class))
        return sub_datasets

    # pick random image from ith class
    def pick_random_from_ith_sub_dataset(self, sub_datasets, i, is_eval):
        """Draw one random image from the ``i``-th per-class pool.

        When augmentation is enabled and ``is_eval`` is false, a randomly
        chosen transform from ``self.transforms`` is applied to the image.
        """
        assert 0 <= i < self.num_classes
        sub_dataset = sub_datasets[i]
        random_idx = random.randint(0, len(sub_dataset) - 1)
        random_img = sub_dataset[random_idx][0]
        if self.apply_augmentation and not is_eval:
            random_transform = random.choice(self.transforms)
            random_img = random_transform(random_img)
        return random_img.to(torch.float32)

    def _build_bags(self, sub_datasets, is_eval):
        """Sample bags and their labels; shared by the UCC and UCC+RCC builders.

        Bags cycle through ucc = 1..ucc_limit. For each bag, ``ucc`` distinct
        classes are drawn and every class contributes ``bag_size // ucc``
        random images, laid out class after class. Images are sampled with
        replacement, so not every image of the pool is necessarily used.
        (An alternative strategy - one guaranteed image per chosen class at
        random positions, the remaining slots filled by classes drawn at
        random from the chosen ones - was tried and is intentionally not used.)

        :return: ``(bag_tensors, ucc_tensors, rcc_tensors)`` lists of tensors
        :raises ValueError: if ``bag_size`` is not divisible by a ucc value
        """
        total_images = sum(len(sub_dataset) for sub_dataset in sub_datasets)
        total_bags = total_images // self.bag_size
        num_bags = self.debug_bag_size if self.debug else total_bags

        bag_tensors = []
        ucc_tensors = []
        rcc_tensors = []

        for b in tqdm(range(num_bags)):
            # this will keep picking ucc (1 -> ucc_limit) in a cyclic manner
            ucc = (b % self.ucc_limit) + 1
            img_per_class = self.bag_size // ucc
            if img_per_class * ucc != self.bag_size:
                # Otherwise some bag slots would stay empty and stacking fails.
                raise ValueError(
                    f"bag_size={self.bag_size} is not divisible by ucc={ucc}"
                )

            chosen_classes = random.sample(list(range(self.num_classes)), ucc)
            bag_tensor = self.create_bag()
            rcc_tensor = [0] * self.rcc_limit

            bag_pos = 0
            for chosen_class in chosen_classes:
                for _ in range(img_per_class):
                    bag_tensor[bag_pos] = self.pick_random_from_ith_sub_dataset(
                        sub_datasets, chosen_class, is_eval
                    )
                    rcc_tensor[chosen_class] += 1
                    bag_pos += 1

            bag_tensors.append(torch.stack(bag_tensor))
            ucc_tensors.append(self.one_hot(ucc, self.ucc_limit))
            rcc_tensors.append(torch.tensor(rcc_tensor, dtype=torch.float32))

        return bag_tensors, ucc_tensors, rcc_tensors

    # construct UCC dataset
    def construct_datasets_with_ucc(self):
        """Build the ``(train, test, val)`` UCC datasets; only train is augmented."""
        train_dataset_with_ucc = self.construct_dataset_with_ucc(self.train_sub_datasets, False)
        test_dataset_with_ucc = self.construct_dataset_with_ucc(self.test_sub_datasets, True)
        val_dataset_with_ucc = self.construct_dataset_with_ucc(self.val_sub_datasets, True)

        return train_dataset_with_ucc, test_dataset_with_ucc, val_dataset_with_ucc

    def construct_dataset_with_ucc(self, sub_datasets, is_eval):
        """Return a ``TensorDataset`` of ``(bag, ucc_one_hot)`` pairs."""
        bag_tensors, ucc_tensors, _ = self._build_bags(sub_datasets, is_eval)

        return TensorDataset(
            torch.stack(bag_tensors),
            torch.stack(ucc_tensors)
        )

    # create UCC and RCC dataset
    def construct_datasets_with_ucc_and_rcc(self):
        """Build the ``(train, test, val)`` UCC+RCC datasets; only train is augmented."""
        train_dataset_with_ucc_and_rcc = self.construct_dataset_with_ucc_and_rcc(self.train_sub_datasets, False)
        # The test set is drawn from the test pool and the val set from the
        # val pool (these two were previously swapped).
        test_dataset_with_ucc_and_rcc = self.construct_dataset_with_ucc_and_rcc(self.test_sub_datasets, True)
        val_dataset_with_ucc_and_rcc = self.construct_dataset_with_ucc_and_rcc(self.val_sub_datasets, True)

        return train_dataset_with_ucc_and_rcc, test_dataset_with_ucc_and_rcc, val_dataset_with_ucc_and_rcc

    def construct_dataset_with_ucc_and_rcc(self, sub_datasets, is_eval):
        """Return a ``TensorDataset`` of ``(bag, ucc_one_hot, rcc_counts)`` triples."""
        bag_tensors, ucc_tensors, rcc_tensors = self._build_bags(sub_datasets, is_eval)

        return TensorDataset(
            torch.stack(bag_tensors),
            torch.stack(ucc_tensors),
            torch.stack(rcc_tensors),
        )

    # util
    def one_hot(self, label, limit):
        """Return a float one-hot vector of length ``limit`` for a 1-based ``label``."""
        one_hot = torch.zeros(limit)

        # labels are in the range [1, limit]; shift to the index range [0, limit - 1]
        one_hot[label - 1] = 1
        return one_hot.to(torch.float32)

    def create_bag(self):
        """Return an empty bag: a list of ``bag_size`` ``None`` placeholders."""
        return [None] * self.bag_size


## Creating the dataset object

In [None]:
# Build every dataloader up front. debug=False creates the full-size bag
# datasets and takes the KDE / autoencoder loaders from the test split;
# apply_augmentation=True enables random transforms on training bags.
dataset = Dataset(x_train, y_train, x_val, y_val, x_test, y_test, debug=False, apply_augmentation=True)

# Checking what one bag looks like

In [None]:
# Display the images of the first bag of the first batch from each UCC
# dataloader (train / test / val) as a visual sanity check.
import matplotlib.pyplot as plt

tensor_to_img_transform = transforms.ToPILImage()
dataloaders = [dataset.ucc_train_dataloader, dataset.ucc_test_dataloader, dataset.ucc_val_dataloader]
names = ["train", "test", "val"]
n_cols = 4  # images per row in the preview grid

for ucc_dataloader, name in zip(dataloaders, names):
    print(f"Checking out {name}")
    for data in ucc_dataloader:
        batches, _ = data
        bag = batches[0]  # only the first bag of the batch is shown
        # Derive the row count from the bag length so the grid also fits bags
        # that do not contain exactly 12 images.
        n_rows = (len(bag) + n_cols - 1) // n_cols
        for image_index, image in enumerate(bag):
            image = tensor_to_img_transform(image)
            plt.subplot(n_rows, n_cols, image_index + 1)
            plt.imshow(image)  # Display the image
            # NOTE(review): the number in the title is the image position inside the bag.
            plt.title(f"Bag {image_index + 1}")
            plt.axis('off')  # Turn off axis labels
        plt.show()
        break  # one batch per dataloader is enough



## Define the class names

In [None]:
# Human-readable names for the ten classes.
# NOTE(review): presumably list index i corresponds to label id i in y_* - confirm against the dataset.
class_names = ["airplane", "automobile", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]


## SSIM Loss definition

In [None]:
class SSIMLoss(nn.Module):
    """Structural-similarity loss: ``1 - SSIM(x, y)``.

    :param data_range: value range of the compared images. Defaults to 1.0
        because the images in this notebook are scaled to 0 -> 1; the library
        default of 255 is meant for 0..255 inputs and would make the SSIM
        stabilising constants dominate, pushing the loss towards zero. Pass
        ``data_range=255`` explicitly for unscaled images.
    """

    def __init__(self, data_range=1.0):
        super(SSIMLoss, self).__init__()
        self.ssim = SSIM(data_range=data_range)

    def forward(self, x, y):
        """Return the loss for two image batches ``x`` and ``y``.

        Higher similarity gives a lower loss.
        """
        return 1 - self.ssim(x, y)

# Model architectures

## Autoencoder

In [None]:
class ResidualZeroPaddingBlock(nn.Module):
    """Two 3x3 convolutions wrapped in a residual connection.

    A *first* block pre-activates its input, optionally upsamples it (x2,
    nearest) and projects the shortcut with a 1x1 convolution whenever the
    residual branch changes the tensor shape. A non-first block adds the
    unmodified input to the twice-activated convolution output.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        first_block=False,
        down_sample=False,
        up_sample=False,
    ):
        super().__init__()
        self.first_block = first_block
        self.down_sample = down_sample
        self.up_sample = up_sample

        if up_sample:
            self.upsampling = nn.Upsample(scale_factor=2, mode="nearest")

        # Down-sampling is done by striding the first conv (and the shortcut).
        stride = 2 if down_sample else 1
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.skip_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)

        # Xavier-uniform weights; main-path biases start at 0.1. The shortcut
        # bias keeps its default initialisation.
        for conv in (self.conv1, self.conv2):
            nn.init.xavier_uniform_(conv.weight)
            nn.init.constant_(conv.bias, 0.1)
        nn.init.xavier_uniform_(self.skip_conv.weight)

    def forward(self, x):
        if not self.first_block:
            hidden = F.relu(self.conv1(x))
            return x + F.relu(self.conv2(hidden))

        shortcut = F.relu(x)
        if self.up_sample:
            shortcut = self.upsampling(shortcut)

        residual = self.conv2(F.relu(self.conv1(shortcut)))
        if shortcut.shape != residual.shape:
            # Channel count or resolution changed: project the shortcut.
            shortcut = self.skip_conv(shortcut)
        return shortcut + residual

class WideResidualBlocks(nn.Module):
    """A stack of ``n`` residual blocks applied one after another.

    The first block maps ``in_channels`` to ``out_channels`` and is the only
    one that changes the spatial resolution. The remaining blocks keep both
    channel count and resolution, which their plain ``x + out`` residual sum
    requires; striding them as well would make that sum fail on mismatched
    shapes.

    :param in_channels: channels of the input tensor
    :param out_channels: channels produced by every block
    :param n: number of residual blocks
    :param down_sample: halve the resolution in the first block
    :param up_sample: double the resolution in the first block
    """

    def __init__(
        self, in_channels, out_channels, n, down_sample=False, up_sample=False
    ):
        super(WideResidualBlocks, self).__init__()
        self.blocks = nn.Sequential(
            *[
                ResidualZeroPaddingBlock(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    first_block=(i == 0),
                    # resampling belongs to the first block only
                    down_sample=down_sample and i == 0,
                    up_sample=up_sample and i == 0,
                )
                for i in range(n)
            ]
        )

    def forward(self, x):
        return self.blocks(x)

class Reshape(nn.Module):
    """Layer that views its input as ``(batch, *target_shape)``."""

    def __init__(self, *target_shape):
        super().__init__()
        self.target_shape = target_shape

    def forward(self, x):
        # The leading (batch) dimension is kept; the rest is re-viewed.
        new_shape = (x.shape[0], *self.target_shape)
        return x.view(new_shape)

class Autoencoder(nn.Module):
    """Convolutional autoencoder with a 2048-dimensional sigmoid latent code.

    The encoder halves the resolution twice and flattens 128 feature maps into
    8192 values, which matches 3x32x32 inputs (128 * 8 * 8). The decoder
    mirrors it: a linear layer back to 8192 values, a reshape to (128, 8, 8),
    two upsampling stages and a final 3-channel convolution.
    """

    def __init__(self):
        super().__init__()

        encoder_layers = [
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            WideResidualBlocks(16, 32, 1),
            WideResidualBlocks(32, 64, 1, down_sample=True),
            WideResidualBlocks(64, 128, 1, down_sample=True),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(8192, 2048, bias=False),
            nn.Sigmoid(),  # keeps every latent feature inside (0, 1)
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(2048, 8192),
            Reshape(128, 8, 8),
            WideResidualBlocks(128, 128, 1, up_sample=True),
            WideResidualBlocks(128, 64, 1, up_sample=True),
            WideResidualBlocks(64, 32, 1),
            nn.ReLU(),
            nn.Conv2d(32, 3, kernel_size=3, padding=1),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(latent, reconstruction)`` for an image batch ``x``."""
        latent = self.encoder(x.to(torch.float32))
        reconstruction = self.decoder(latent).to(torch.float32)
        return latent, reconstruction

## Kernel Density Estimator

In [None]:
class KDE(nn.Module):
    """Gaussian kernel density estimation layer over the images of a bag.

    For every feature, the values of all bag members are turned into a
    density evaluated at ``num_nodes`` equally spaced points in [0, 1] and
    normalised so that the ``num_nodes`` values sum to one.

    :param device: kept as an attribute for compatibility with existing
        callers; the computation itself follows the device of the input
    :param num_nodes: number of sample points per feature
    :param sigma: bandwidth (standard deviation) of the Gaussian kernel
    """

    def __init__(self, device=config.device, num_nodes=config.num_nodes, sigma=config.sigma):
        super(KDE, self).__init__()
        self.num_nodes = num_nodes
        self.sigma = sigma
        self.device = device
        print("KDE Layer initialized")

    def forward(self, data):
        """Map ``(batch, bag, J)`` features to ``(batch, J * num_nodes)`` densities.

        The output is feature-major: the ``num_nodes`` values of feature 0
        come first, then those of feature 1, and so on.
        """
        batch_size, bag_size, num_features = data.size()  # Batch, bag, J

        # Sample points are created next to the data so the layer keeps
        # working when the model lives on another device than the one stored
        # at construction time.
        k_sample_points = torch.linspace(
            0, 1, steps=self.num_nodes, device=data.device, dtype=data.dtype
        )  # (num_nodes,)

        # Gaussian kernel constants
        k_alfa = float(1 / np.sqrt(2 * np.pi * np.square(self.sigma)))
        k_beta = -1 / (2 * self.sigma ** 2)

        # One broadcast expression replaces the Python loop over features.
        k_diff = k_sample_points - data.unsqueeze(-1)  # (Batch, bag, J, num_nodes)
        k_result = k_alfa * torch.exp(k_beta * torch.square(k_diff))  # (Batch, bag, J, num_nodes)
        k_out_unnormalized = k_result.sum(dim=1)  # (Batch, J, num_nodes)
        k_norm_coeff = k_out_unnormalized.sum(dim=-1, keepdim=True)  # (Batch, J, 1)
        k_out = k_out_unnormalized / k_norm_coeff  # (Batch, J, num_nodes)

        return k_out.reshape(batch_size, num_features * self.num_nodes)  # (Batch, J*num_nodes)


## UCC Prediction model

In [None]:
# UCC Prediction model
class UCCPredictor(nn.Module):
    """Predict a bag's unique-class count from its per-image latent features.

    Input:  (Batch, Bag, J) latent features.
    Output: (Batch, ucc_limit) sigmoid scores.

    The first linear layer expects 5632 inputs, i.e. a KDE output of width
    22528 that has been halved twice by the two pooling layers.
    """

    def __init__(self, device=config.device, ucc_limit=config.ucc_limit):
        super().__init__()
        self.kde = KDE(device)

        # Two stride-2 poolings shrink the KDE vector to a quarter of its length.
        pooling_layers = [
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=2, stride=2),
            nn.ReLU(),
        ]
        classifier_layers = [
            nn.Linear(5632, 256, dtype=torch.float32),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(256, 32, dtype=torch.float32),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(32, ucc_limit, dtype=torch.float32),
            nn.Sigmoid(),
        ]
        self.stack = nn.Sequential(*pooling_layers, *classifier_layers)

        # Xavier-normal weights and a constant 0.1 bias for every weighted layer.
        weighted_layer_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
        for layer in self.modules():
            if not isinstance(layer, weighted_layer_types):
                continue
            nn.init.xavier_normal_(layer.weight)
            nn.init.constant_(layer.bias, 0.1)

        print("UCC Predictor model initialized")

    def forward(self, x):
        densities = self.kde(x)  # (Batch, J * num_nodes)
        return self.stack(densities)  # (Batch, ucc_limit)


# Combined UCC model
class CombinedUCCModel(nn.Module):
    """Autoencoder followed by the UCC predictor working on its latent codes."""

    def __init__(self, device=config.device):
        super().__init__()
        self.autoencoder = Autoencoder()
        self.ucc_predictor = UCCPredictor(device)
        print("Combined UCC model initialized")

    def forward(self, batch):
        """Run a batch of bags through both stages.

        Input:  [batch, bag, channels, height, width]
        Output: ucc scores [batch, ucc_limit] and the decoded images
                [batch * bag, channels, height, width].
        """
        # Stage 1: fold the bag dimension into the batch dimension so the
        # autoencoder sees plain images.
        n_batches, n_per_bag, channels, rows, cols = batch.size()
        flat_images = batch.view(n_batches * n_per_bag, channels, rows, cols).to(torch.float32)
        encoded, decoded = self.autoencoder(flat_images)

        # Stage 2: regroup the latent codes per bag and predict the ucc.
        _, feature_size = encoded.size()
        bag_features = encoded.view(n_batches, n_per_bag, feature_size)
        ucc_logits = self.ucc_predictor(bag_features)

        return ucc_logits, decoded

## RCC Prediction model

This is the additional multi task path which predicts the "Real Class Counts"

In [None]:
# RCC Prediction model
class RCCPredictor(nn.Module):
    """Predict per-class image counts of a bag from its latent features.

    Input:  (Batch, Bag, J) latent features.
    Output: (Batch, rcc_limit) non-negative values (final ReLU).

    The first linear layer expects 5632 inputs, i.e. a KDE output of width
    22528 that has been halved twice by the two pooling layers.
    """

    def __init__(self, device=config.device, rcc_limit=config.rcc_limit):
        super().__init__()
        self.kde = KDE(device)

        # Two stride-2 poolings shrink the KDE vector to a quarter of its length.
        pooling_layers = [
            nn.MaxPool1d(kernel_size=2, stride=2),
            nn.ReLU(),
            nn.AvgPool1d(kernel_size=2, stride=2),
            nn.ReLU(),
        ]
        regression_layers = [
            nn.Linear(5632, 256, dtype=torch.float32),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(256, 32, dtype=torch.float32),
            nn.Dropout(0.1),
            nn.ReLU(),
            nn.Linear(32, rcc_limit, dtype=torch.float32),
            nn.ReLU(),
        ]
        self.stack = nn.Sequential(*pooling_layers, *regression_layers)

        # Xavier-normal weights; biases keep their default initialisation.
        weighted_layer_types = (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)
        for layer in self.modules():
            if isinstance(layer, weighted_layer_types):
                nn.init.xavier_normal_(layer.weight)

    def forward(self, x):
        densities = self.kde(x)  # (Batch, J * num_nodes)
        return self.stack(densities)  # (Batch, rcc_limit)


# Combined RCC model
class CombinedRCCModel(nn.Module):
    """Autoencoder feeding both the UCC and the RCC predictor (multi-task)."""

    def __init__(self, device=config.device):
        super().__init__()
        self.autoencoder = Autoencoder()
        self.ucc_predictor = UCCPredictor(device)
        self.rcc_predictor = RCCPredictor(device)

    def forward(self, batch):
        """Run a batch of bags through the autoencoder and both heads.

        Input:  [batch, bag, channels, height, width]
        Output: rcc values [batch, rcc_limit], ucc scores [batch, ucc_limit]
                and the decoded images [batch * bag, channels, height, width]
                - in that order.
        """
        # Stage 1: fold the bag dimension into the batch dimension so the
        # autoencoder sees plain images.
        n_batches, n_per_bag, channels, rows, cols = batch.size()
        flat_images = batch.view(n_batches * n_per_bag, channels, rows, cols).to(torch.float32)
        encoded, decoded = self.autoencoder(flat_images)

        # Stage 2: regroup the latent codes per bag and feed both heads.
        _, feature_size = encoded.size()
        bag_features = encoded.view(n_batches, n_per_bag, feature_size)
        ucc_logits = self.ucc_predictor(bag_features)
        rcc_logits = self.rcc_predictor(bag_features)
        return rcc_logits, ucc_logits, decoded

# How many trainable params are there in my models?

## Combined UCC model trainable params

In [None]:
# Combined UCC model
# Instantiate the model on the configured device and print a per-layer
# summary (sizes, parameter counts, kernel sizes, mult-adds).
# input_size describes one bag of 12 images of 3x32x32.
# NOTE(review): batch_dim=0 presumably lets torchinfo insert the batch
# dimension in front of input_size - confirm against the torchinfo docs.
combined_ucc = CombinedUCCModel(config.device).to(config.device)
summary(combined_ucc, input_size=(12, 3, 32, 32), device=config.device, batch_dim=0,col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"], verbose=1)


## Combined RCC model trainable params

In [None]:
#Combined RCC model
# Instantiate the multi-task model on the configured device and print a
# per-layer summary (sizes, parameter counts, kernel sizes, mult-adds).
# input_size describes one bag of 12 images of 3x32x32.
# NOTE(review): batch_dim=0 presumably lets torchinfo insert the batch
# dimension in front of input_size - confirm against the torchinfo docs.
combined_rcc = CombinedRCCModel(config.device).to(config.device)
summary(combined_rcc, input_size=(12, 3, 32, 32), device=config.device, batch_dim=0,col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"], verbose=1)

# EXPERIMENT-1 : UCC Model

This model tries to replicate the paper where we have an autoencoder path and a ucc path.

Similarly, Experiment 2 will be the improved model.

## Code for plotting the model stats

In [None]:
def plot_ucc_model_stats(
        experiment, epochs,
        ucc_training_losses, ae_training_losses, combined_training_losses,
        ucc_training_accuracy,
        ucc_validation_losses, ae_validation_losses, combined_validation_losses,
        ucc_validation_accuracy
    ):
    """Draw a 2x2 grid of loss/accuracy curves for the UCC experiment.

    Top row shows training loss and training accuracy, bottom row shows
    validation loss and validation accuracy. Every series is plotted against
    ``epochs``; ``experiment`` is used as the prefix of each subplot title.
    The figure is shown and then closed.
    """
    _, grid = plt.subplots(2, 2, figsize=(15, 15))

    # (grid position, quantity shown on the y axis, [(series, colour, legend label), ...])
    panels = [
        ((0, 0), 'Training Loss', [
            (ucc_training_losses, "red", "UCC Training Loss"),
            (ae_training_losses, "blue", "AE Training Loss"),
            (combined_training_losses, "green", "Combined Training Loss"),
        ]),
        ((0, 1), 'Training Accuracy', [
            (ucc_training_accuracy, "red", "UCC Training Accuracy"),
        ]),
        ((1, 0), 'Validation Loss', [
            (ucc_validation_losses, "red", "UCC Validation Loss"),
            (ae_validation_losses, "blue", "AE Validation Loss"),
            (combined_validation_losses, "green", "Combined Validation Loss"),
        ]),
        ((1, 1), 'Validation Accuracy', [
            (ucc_validation_accuracy, "red", "UCC Validation Accuracy"),
        ]),
    ]

    for (row, col), quantity, curves in panels:
        panel = grid[row, col]
        for series, colour, legend_label in curves:
            panel.plot(epochs, series, marker="o", color=colour, label=legend_label)
        panel.set_title(f'{experiment}: {quantity} vs Epochs')
        panel.set_xlabel('Epochs')
        panel.set_ylabel(quantity)
        panel.legend()

    # Keep the subplots from overlapping, then render
    plt.tight_layout()
    plt.show()

    # Release the figure so repeated calls do not pile up open figures
    plt.clf()
    plt.cla()
    plt.close()

## UCC Trainer class

In [None]:
class UCCTrainer:
    """Trainer for the combined autoencoder + UCC model (Experiment 1).

    Bundles the training loop, validation/test evaluation, checkpointing and the
    two analysis metrics computed on the learned features: the minimum pairwise
    Jensen-Shannon divergence between per-class KDE distributions and the
    K-means clustering accuracy of the encoder features.
    """

    def __init__(self,
                 name, ucc_model,
                 dataset, save_dir, device=None):
        """Wire up data loaders, optimizer, loss criteria and history buffers.

        Args:
            name: Experiment name; checkpoints are stored under ``save_dir/name``.
            ucc_model: Model called as ``ucc_logits, decoded = ucc_model(images)``.
                The analysis methods also use ``ucc_model.autoencoder.encoder``
                and ``ucc_model.ucc_predictor.kde``.
            dataset: Object exposing ``ucc_train_dataloader``,
                ``ucc_val_dataloader``, ``ucc_test_dataloader``,
                ``kde_test_dataloaders`` and ``autoencoder_test_dataloaders``.
            save_dir: Root directory for checkpoints (created if missing).
            device: Device checkpoints are mapped onto when loaded. ``None``
                (the default) means ``config.device``, looked up at call time
                rather than frozen when the class is defined.
        """
        self.name = name
        self.save_dir = save_dir
        self.device = config.device if device is None else device

        # data
        self.dataset = dataset
        self.train_loader = dataset.ucc_train_dataloader
        self.test_loader = dataset.ucc_test_dataloader
        self.val_loader = dataset.ucc_val_dataloader
        self.kde_loaders = dataset.kde_test_dataloaders  # each dataloader here will return shape of (batch, bag, 3,32,32) of a pure dataset
        self.autoencoder_loaders = dataset.autoencoder_test_dataloaders

        # create the directory if it doesn't exist!
        os.makedirs(self.save_dir, exist_ok=True)
        os.makedirs(os.path.join(self.save_dir, self.name), exist_ok=True)

        self.ucc_model = ucc_model

        # Adam optimizer(s)
        self.ucc_optimizer = optim.Adam(self.ucc_model.parameters(), lr=config.learning_rate,
                                        weight_decay=config.weight_decay)

        # Loss criterion(s)
        # self.ae_loss_criterion = nn.MSELoss()
        self.ae_loss_criterion = SSIMLoss()
        self.ucc_loss_criterion = nn.CrossEntropyLoss()

        # Transforms
        self.tensor_to_img_transform = transforms.ToPILImage()

        # Values which can change based on loaded checkpoint
        self.start_epoch = 0
        self.epoch_numbers = []
        self.training_ae_losses = []
        self.training_ucc_losses = []
        self.training_losses = []
        self.training_ucc_accuracies = []

        self.val_ae_losses = []
        self.val_ucc_losses = []
        self.val_losses = []
        self.val_ucc_accuracies = []

        # Running counters filled by calculate_ucc_loss_and_acc
        self.train_correct_predictions = 0
        self.train_total_batches = 0
        self.eval_correct_predictions = 0
        self.eval_total_batches = 0

    # main train code
    def train(self,
              num_epochs,
              resume_epoch_num=None,
              load_from_checkpoint=False,
              epoch_saver_count=2):
        """Train for ``num_epochs`` epochs, validating after every epoch.

        Args:
            num_epochs: Number of epochs to run in this call.
            resume_epoch_num: Epoch number of the checkpoint to resume from;
                ``None`` picks the most recently modified checkpoint file. Only
                used when ``load_from_checkpoint`` is true.
            load_from_checkpoint: Whether to restore model, optimizer and
                history from a checkpoint before training.
            epoch_saver_count: A checkpoint is written whenever
                ``(epoch + 1) % epoch_saver_count == 0``.

        Returns:
            The running history tuple produced by
            ``get_current_running_history_state_hook``.
        """
        torch.cuda.empty_cache()

        # initialize the params from the saved checkpoint
        self.init_params_from_checkpoint_hook(load_from_checkpoint, resume_epoch_num)

        # set up scheduler
        self.init_scheduler_hook(num_epochs)

        # Custom progress bar for total epochs with color and displaying average epoch batch_loss
        total_progress_bar = tqdm(
            total=num_epochs, desc="Total Epochs", position=0,
            bar_format="{desc}: {percentage}% |{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
            dynamic_ncols=True, ncols=100, colour='red'
        )

        # Train loop
        for epoch in range(self.start_epoch, self.start_epoch + num_epochs):
            # Custom progress bar for each epoch with color
            epoch_progress_bar = tqdm(
                total=len(self.train_loader),
                desc=f"Epoch {epoch + 1}/{self.start_epoch + num_epochs}",
                position=1,
                leave=False,
                dynamic_ncols=True,
                ncols=100,
                colour='green'
            )

            # validation (and checkpoint saving) may have changed the mode, so reset it every epoch
            self.ucc_model.train()

            # running sums of the per-batch losses for this epoch
            epoch_training_loss = 0.0
            epoch_ae_loss = 0.0
            epoch_ucc_loss = 0.0

            # iterate over each batch
            for images, one_hot_ucc_labels in self.train_loader:
                # forward propagate through the combined model
                ucc_logits, decoded = self.ucc_model(images)

                # calculate losses from both paths for a batch of bags
                ae_loss = self.calculate_autoencoder_loss(images, decoded)
                ucc_loss, batch_ucc_accuracy = self.calculate_ucc_loss_and_acc(ucc_logits, one_hot_ucc_labels, True)

                # calculate combined loss
                batch_loss = ae_loss + ucc_loss

                # backpropagate the combined loss
                batch_loss.backward()

                # Gradient clipping: clamp every gradient element to [-grad_clip, grad_clip]
                nn.utils.clip_grad_value_(self.ucc_model.parameters(), config.grad_clip)

                # optimizer step and gradient reset for the combined model
                self.ucc_optimizer.step()
                self.ucc_optimizer.zero_grad()

                # the OneCycleLR schedule is sized in batches, so it is stepped per batch
                self.ucc_scheduler.step()

                # add to epoch batch_loss
                epoch_training_loss += batch_loss.item()
                epoch_ae_loss += ae_loss.item()
                epoch_ucc_loss += ucc_loss.item()

                # Update the epoch progress bar (overwrite in place)
                batch_stats = {
                    "batch_loss": batch_loss.item(),
                    "ae_loss": ae_loss.item(),
                    "ucc_loss": ucc_loss.item(),
                    "batch_ucc_acc": batch_ucc_accuracy
                }

                epoch_progress_bar.set_postfix(batch_stats)
                epoch_progress_bar.update(1)

            # close the epoch progress bar
            epoch_progress_bar.close()

            # calculate average epoch train statistics
            avg_train_stats = self.calculate_avg_train_stats_hook(epoch_training_loss, epoch_ae_loss, epoch_ucc_loss)

            # calculate validation statistics
            avg_val_stats = self.validation_hook()

            # Store running history
            self.store_running_history_hook(epoch, avg_train_stats, avg_val_stats)

            # Show epoch stats
            print(f"# Epoch {epoch + 1}")
            epoch_postfix = self.calculate_and_print_epoch_stats_hook(avg_train_stats, avg_val_stats)

            # Update the total progress bar
            total_progress_bar.set_postfix(epoch_postfix)
            total_progress_bar.update(1)

            # Save model checkpoint periodically
            need_to_save_model_checkpoint = (epoch + 1) % epoch_saver_count == 0
            if need_to_save_model_checkpoint:
                print(f"Going to save model {self.name} @ Epoch:{epoch + 1}")
                self.save_model_checkpoint_hook(epoch)

            print("-" * 60)

        # Close the total progress bar
        total_progress_bar.close()

        # Return the current state
        return self.get_current_running_history_state_hook()

    # hooks
    def init_params_from_checkpoint_hook(self, load_from_checkpoint, resume_epoch_num):
        """Restore model, optimizer and history from a checkpoint if requested.

        Does nothing when ``load_from_checkpoint`` is false.
        """
        if load_from_checkpoint:
            # NOTE: resume_epoch_num can be None here if we want to load from the most recently saved checkpoint!
            checkpoint_path = self.get_model_checkpoint_path(resume_epoch_num)
            # map_location lets a checkpoint written on a GPU be restored on a CPU-only machine
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            # load previous state of models
            self.ucc_model.load_state_dict(checkpoint['ucc_model_state_dict'])

            # load previous state of optimizers
            self.ucc_optimizer.load_state_dict(checkpoint['ucc_optimizer_state_dict'])

            # Things we are keeping track of
            self.start_epoch = checkpoint['epoch']
            self.epoch_numbers = checkpoint['epoch_numbers']

            self.training_losses = checkpoint['training_losses']
            self.training_ae_losses = checkpoint['training_ae_losses']
            self.training_ucc_losses = checkpoint['training_ucc_losses']
            self.training_ucc_accuracies = checkpoint['training_ucc_accuracies']

            self.val_losses = checkpoint['val_losses']
            self.val_ae_losses = checkpoint['val_ae_losses']
            self.val_ucc_losses = checkpoint['val_ucc_losses']
            self.val_ucc_accuracies = checkpoint['val_ucc_accuracies']

            print(f"Model checkpoint for {self.name} is loaded from {checkpoint_path}!")

    def init_scheduler_hook(self, num_epochs):
        """Create a OneCycleLR schedule spanning ``num_epochs * len(train_loader)`` steps."""
        # the schedule is defined per batch of bags, matching the per-batch step() in train()
        self.ucc_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.ucc_optimizer,
            config.learning_rate,
            epochs=num_epochs,
            steps_per_epoch=len(self.train_loader)
        )

    def calculate_autoencoder_loss(self, images, decoded):
        """Return the reconstruction loss between the input bags and ``decoded``.

        ``images`` is [batch, bag, channels, height, width]; it is flattened to
        [batch * bag, channels, height, width] so it lines up with ``decoded``.
        """
        batch_size, bag_size, num_channels, height, width = images.size()
        batches_of_bag_images = images.view(batch_size * bag_size, num_channels, height, width).to(torch.float32)
        ae_loss = self.ae_loss_criterion(decoded, batches_of_bag_images)  # compares (Batch * Bag, 3,32,32)
        return ae_loss

    def calculate_ucc_loss_and_acc(self, ucc_logits, one_hot_ucc_labels, is_train_mode=True):
        """Return ``(ucc_loss, batch_accuracy)`` and update the running counters.

        Args:
            ucc_logits: [batch, num_ucc_classes] logits.
            one_hot_ucc_labels: Targets with the same leading shape; the class
                index is recovered with an argmax over dim 1.
            is_train_mode: Selects whether the train or the eval counters are
                incremented.
        """
        ucc_loss = self.ucc_loss_criterion(ucc_logits, one_hot_ucc_labels)

        # softmax is monotonic, so the argmax of the logits is already the predicted class
        predicted = torch.argmax(ucc_logits, 1)
        labels = torch.argmax(one_hot_ucc_labels, 1)
        batch_correct_predictions = (predicted == labels).sum().item()
        batch_size = labels.size(0)

        batch_ucc_accuracy = batch_correct_predictions / batch_size
        if is_train_mode:
            self.train_correct_predictions += batch_correct_predictions
            self.train_total_batches += batch_size
        else:
            self.eval_correct_predictions += batch_correct_predictions
            self.eval_total_batches += batch_size
        return ucc_loss, batch_ucc_accuracy

    def calculate_avg_train_stats_hook(self, epoch_training_loss, epoch_ae_loss, epoch_ucc_loss):
        """Average the summed epoch losses and compute the epoch training accuracy.

        Also resets the running train counters for the next epoch.
        """
        # NOTE(review): nn.CrossEntropyLoss defaults to reduction='mean', so each summed
        # UCC loss is already a per-batch mean. Dividing by len(loader) * batch_size only
        # equals the per-batch average while config.batch_size == 1 -- confirm before
        # raising the batch size.
        no_of_bags = len(self.train_loader) * config.batch_size
        avg_training_loss_for_epoch = epoch_training_loss / no_of_bags
        avg_ae_loss_for_epoch = epoch_ae_loss / no_of_bags
        avg_ucc_loss_for_epoch = epoch_ucc_loss / no_of_bags
        avg_ucc_training_accuracy = self.train_correct_predictions / self.train_total_batches

        epoch_train_stats = {
            "avg_training_loss": avg_training_loss_for_epoch,
            "avg_ae_loss": avg_ae_loss_for_epoch,
            "avg_ucc_loss": avg_ucc_loss_for_epoch,
            "avg_ucc_training_accuracy": avg_ucc_training_accuracy
        }

        # reset
        self.train_correct_predictions = 0
        self.train_total_batches = 0

        return epoch_train_stats

    def _evaluate_loader(self, loader):
        """Evaluate the model on ``loader`` in eval mode without gradients.

        Returns:
            ``(avg_loss, avg_ae_loss, avg_ucc_loss, ucc_accuracy)`` where the
            losses are the summed batch losses divided by
            ``len(loader) * config.batch_size``.
        """
        # calculate_ucc_loss_and_acc(..., False) accumulates into these counters
        self.eval_correct_predictions = 0
        self.eval_total_batches = 0

        total_loss = 0.0
        total_ae_loss = 0.0
        total_ucc_loss = 0.0

        with torch.no_grad():
            self.ucc_model.eval()

            for images, one_hot_ucc_labels in loader:
                ucc_logits, decoded = self.ucc_model(images)

                batch_ae_loss = self.calculate_autoencoder_loss(images, decoded)
                batch_ucc_loss, _ = self.calculate_ucc_loss_and_acc(ucc_logits, one_hot_ucc_labels, False)
                batch_loss = batch_ae_loss + batch_ucc_loss

                total_ae_loss += batch_ae_loss.item()
                total_ucc_loss += batch_ucc_loss.item()
                total_loss += batch_loss.item()

        no_of_bags = len(loader) * config.batch_size
        avg_loss = total_loss / no_of_bags
        avg_ucc_loss = total_ucc_loss / no_of_bags
        avg_ae_loss = total_ae_loss / no_of_bags
        ucc_accuracy = self.eval_correct_predictions / self.eval_total_batches
        return avg_loss, avg_ae_loss, avg_ucc_loss, ucc_accuracy

    def validation_hook(self):
        """Evaluate on the validation loader and show one sample reconstruction.

        Returns:
            Dict with keys ``avg_val_loss``, ``avg_val_ae_loss``,
            ``avg_val_ucc_loss`` and ``avg_val_ucc_training_accuracy``.
        """
        avg_val_loss, avg_val_ae_loss, avg_val_ucc_loss, avg_val_ucc_training_accuracy = \
            self._evaluate_loader(self.val_loader)

        print("Finished computing val stats, now showing a sample reconstruction")
        # show some sample predictions
        self.show_sample_reconstructions(self.val_loader)

        return {
            "avg_val_loss": avg_val_loss,
            "avg_val_ae_loss": avg_val_ae_loss,
            "avg_val_ucc_loss": avg_val_ucc_loss,
            "avg_val_ucc_training_accuracy": avg_val_ucc_training_accuracy
        }

    def calculate_and_print_epoch_stats_hook(self, avg_train_stats, avg_val_stats):
        """Print the epoch's train/val statistics and return them as a flat dict.

        The returned dict is used as the postfix of the overall progress bar.
        """
        epoch_loss = avg_train_stats["avg_training_loss"]
        epoch_ae_loss = avg_train_stats["avg_ae_loss"]
        epoch_ucc_loss = avg_train_stats["avg_ucc_loss"]
        epoch_ucc_accuracy = avg_train_stats["avg_ucc_training_accuracy"]

        epoch_val_loss = avg_val_stats["avg_val_loss"]
        epoch_val_ae_loss = avg_val_stats["avg_val_ae_loss"]
        epoch_val_ucc_loss = avg_val_stats["avg_val_ucc_loss"]
        epoch_val_ucc_accuracy = avg_val_stats["avg_val_ucc_training_accuracy"]

        print(
            f"[TRAIN]: Epoch Loss: {epoch_loss} | AE Loss: {epoch_ae_loss} | UCC Loss: {epoch_ucc_loss} | UCC Acc: {epoch_ucc_accuracy}")
        print(
            f"[VAL]: Val Loss: {epoch_val_loss} | Val AE Loss: {epoch_val_ae_loss} | Val UCC Loss: {epoch_val_ucc_loss} | Val UCC Acc: {epoch_val_ucc_accuracy}")

        return {
            "epoch_loss": epoch_loss,
            "epoch_ae_loss": epoch_ae_loss,
            "epoch_ucc_loss": epoch_ucc_loss,
            "epoch_ucc_acc": epoch_ucc_accuracy,
            "epoch_val_loss": epoch_val_loss,
            "epoch_val_ae_loss": epoch_val_ae_loss,
            "epoch_val_ucc_loss": epoch_val_ucc_loss,
            "epoch_val_ucc_acc": epoch_val_ucc_accuracy
        }

    def store_running_history_hook(self, epoch, avg_train_stats, avg_val_stats):
        """Append this epoch's statistics to the running history lists."""
        self.epoch_numbers.append(epoch + 1)

        self.training_ae_losses.append(avg_train_stats["avg_ae_loss"])
        self.training_ucc_losses.append(avg_train_stats["avg_ucc_loss"])
        self.training_losses.append(avg_train_stats["avg_training_loss"])
        self.training_ucc_accuracies.append(avg_train_stats["avg_ucc_training_accuracy"])

        self.val_ae_losses.append(avg_val_stats["avg_val_ae_loss"])
        self.val_ucc_losses.append(avg_val_stats["avg_val_ucc_loss"])
        self.val_losses.append(avg_val_stats["avg_val_loss"])
        self.val_ucc_accuracies.append(avg_val_stats["avg_val_ucc_training_accuracy"])

    def get_current_running_history_state_hook(self):
        """Return the history lists as a 9-tuple.

        Order: epoch numbers, training AE / UCC / combined losses, training UCC
        accuracies, validation AE / UCC / combined losses, validation UCC
        accuracies.
        """
        return self.epoch_numbers, \
            self.training_ae_losses, self.training_ucc_losses, self.training_losses, self.training_ucc_accuracies, \
            self.val_ae_losses, self.val_ucc_losses, self.val_losses, self.val_ucc_accuracies

    def save_model_checkpoint_hook(self, epoch):
        """Write model, optimizer and history to ``model_epoch_{epoch + 1}.pt``.

        Side effect: the model is left in train mode afterwards.
        """
        # set it to train mode to save the weights (but doesn't matter apparently!)
        self.ucc_model.train()

        # create the directory if it doesn't exist
        model_save_directory = os.path.join(self.save_dir, self.name)
        os.makedirs(model_save_directory, exist_ok=True)

        # Checkpoint the model at the end of each epoch
        checkpoint_path = os.path.join(model_save_directory, f'model_epoch_{epoch + 1}.pt')
        torch.save(
            {
                'ucc_model_state_dict': self.ucc_model.state_dict(),
                'ucc_optimizer_state_dict': self.ucc_optimizer.state_dict(),
                'epoch': epoch + 1,
                'epoch_numbers': self.epoch_numbers,
                'training_losses': self.training_losses,
                'training_ae_losses': self.training_ae_losses,
                'training_ucc_losses': self.training_ucc_losses,
                'training_ucc_accuracies': self.training_ucc_accuracies,
                'val_losses': self.val_losses,
                'val_ae_losses': self.val_ae_losses,
                'val_ucc_losses': self.val_ucc_losses,
                'val_ucc_accuracies': self.val_ucc_accuracies,
            },
            checkpoint_path
        )
        print(f"Saved the model checkpoint for experiment {self.name} for epoch {epoch + 1}")

    def test_model(self):
        """Evaluate on the test loader and show one sample reconstruction.

        Returns:
            Dict with keys ``avg_test_loss``, ``avg_test_ae_loss``,
            ``avg_test_ucc_loss`` and ``avg_test_ucc_training_accuracy``.
        """
        avg_test_loss, avg_test_ae_loss, avg_test_ucc_loss, avg_test_ucc_training_accuracy = \
            self._evaluate_loader(self.test_loader)

        # show some sample predictions
        self.show_sample_reconstructions(self.test_loader)

        return {
            "avg_test_loss": avg_test_loss,
            "avg_test_ae_loss": avg_test_ae_loss,
            "avg_test_ucc_loss": avg_test_ucc_loss,
            "avg_test_ucc_training_accuracy": avg_test_ucc_training_accuracy
        }

    def show_sample_reconstructions(self, dataloader):
        """Plot the first image of the first bag from ``dataloader`` beside its reconstruction."""
        # Create a subplot grid
        _, axes = plt.subplots(1, 2, figsize=(3, 3))

        with torch.no_grad():
            # set all models to eval mode
            self.ucc_model.eval()

            for val_data in dataloader:
                val_images, _ = val_data

                # flatten [batch, bag, C, H, W] to [batch * bag, C, H, W] so it lines up with the decoder output
                batch_size, bag_size, num_channels, height, width = val_images.size()
                bag_val_images = val_images.view(batch_size * bag_size, num_channels, height, width)
                print("Reshaped the original image into bag format")

                # forward propagate through the model
                _, val_reconstructed_images = self.ucc_model(val_images)
                print("Got a sample reconstruction, now trying to reshape in order to show an example")

                # take only one image from the bag
                sample_image = bag_val_images[0]
                predicted_image = val_reconstructed_images[0]

                # get it to cpu
                sample_image = sample_image.to("cpu")
                predicted_image = predicted_image.to("cpu")

                # convert to PIL Image
                sample_image = self.tensor_to_img_transform(sample_image)
                predicted_image = self.tensor_to_img_transform(predicted_image)

                axes[0].imshow(sample_image)
                axes[0].set_title(f"Orig", color='green')
                axes[0].axis('off')

                axes[1].imshow(predicted_image)
                axes[1].set_title(f"Recon", color='red')
                axes[1].axis('off')

                # show only one image
                break

        plt.tight_layout()
        plt.show()

    def js_divergence(self, p, q):
        """
        Calculate the Jensen-Shannon Divergence between two probability distributions p and q.

        JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m) with m = 0.5 * (p + q).

        Args:
        p (torch.Tensor): Probability distribution p.
        q (torch.Tensor): Probability distribution q.

        Returns:
        torch.Tensor: Jensen-Shannon Divergence between p and q. The
        'batchmean' reduction divides the summed divergence by the size of the
        first dimension of the inputs.
        """
        # Calculate the average distribution 'm'
        m = 0.5 * (p + q)

        # F.kl_div(input, target) computes KL(target || exp(input)), so the mixture 'm'
        # must be the log-space input and p / q the targets. 'm' is zero only where both
        # p and q are zero; clamping keeps log(m) finite there so 0 * log(0) stays 0.
        log_m = m.clamp_min(torch.finfo(m.dtype).tiny).log()
        kl_div_p = F.kl_div(log_m, p, reduction='batchmean')
        kl_div_q = F.kl_div(log_m, q, reduction='batchmean')

        # Compute the JS Divergence
        js_divergence = 0.5 * (kl_div_p + kl_div_q)

        return js_divergence

    def calculate_min_js_divergence(self):
        """Return the smallest pairwise JS divergence between per-class KDE distributions.

        For every loader in ``self.kde_loaders`` the KDE outputs of all bags are
        averaged into one distribution; the JS divergence is then evaluated for
        every pair of classes, printed, and the minimum is returned.
        """
        num_classes = len(self.kde_loaders)
        kde_per_class = {class_idx: 0.0 for class_idx in range(num_classes)}

        # Pure inference: eval mode and no autograd graph, otherwise the running sums
        # below would keep every batch's graph alive.
        self.ucc_model.eval()
        with torch.no_grad():
            # find the average kde across all classes
            for class_idx, pure_class_kde_loader in tqdm(enumerate(self.kde_loaders)):
                num_bags_in_class = 0
                for images in pure_class_kde_loader:
                    # get the first element
                    images = images[0]

                    # Stage 1: pass through the encoder to get the latent features
                    # batch data is of shape (batch, bag, C, H, W) -> (batch * bag, C, H, W)
                    batch_size, bag_size, num_channels, height, width = images.size()
                    batches_of_bag_images = images.view(batch_size * bag_size, num_channels, height, width).to(
                        torch.float32)
                    latent_features = self.ucc_model.autoencoder.encoder(batches_of_bag_images)
                    latent_features = latent_features.to(torch.float32)

                    # Stage 2: restore the bag axis using the sizes of this very batch
                    # (rather than config.bag_size) and pass through the KDE layer
                    _, feature_size = latent_features.size()
                    latent_features = latent_features.view(batch_size, bag_size, feature_size)
                    batch_kde_distributions = self.ucc_model.ucc_predictor.kde(latent_features)

                    # Stage 3: accumulate the sum over bags
                    num_bags_in_class += batch_kde_distributions.size(0)
                    kde_distributions = torch.sum(batch_kde_distributions, dim=0)
                    kde_per_class[class_idx] += kde_distributions

                # Stage 4: take the average over all bags of this class
                kde_per_class[class_idx] /= num_bags_in_class

            # find the js_divergence
            min_divergence = torch.inf
            best_i = None
            best_j = None
            for i in range(num_classes):
                for j in range(i + 1, num_classes):
                    divergence = self.js_divergence(kde_per_class[i], kde_per_class[j])
                    print(f"JS Divergence between {i} & {j} is {divergence}")
                    if divergence < min_divergence:
                        min_divergence = divergence
                        best_i = i
                        best_j = j

        print(f"Min JS Divergence is {min_divergence} between classes {best_i} & {best_j}")
        # return the min divergence
        return min_divergence

    def calculate_clustering_accuracy(self):
        """Cluster the encoder features with K-means and return the clustering accuracy.

        Latent features of every sample in ``self.autoencoder_loaders`` are
        clustered into 10 groups; clusters are matched to the 10 ground-truth
        labels with the Hungarian algorithm (``linear_sum_assignment``) and the
        fraction of correctly assigned samples is returned.
        """
        all_latent_features = []
        truth_labels_arr = []

        # Pure inference: eval mode and no autograd graph per image
        self.ucc_model.eval()
        with torch.no_grad():
            for pure_autoencoder_loader in self.autoencoder_loaders:
                for batch_idx, data in tqdm(enumerate(pure_autoencoder_loader)):
                    # NOTE(review): the original notes say each batch holds one image and
                    # one label, (1,3,32,32) and (1,1) -- label.item() below relies on it.
                    image, label = data
                    latent_features = self.ucc_model.autoencoder.encoder(image)

                    latent_features = latent_features.squeeze().detach().cpu().numpy()
                    label = label.squeeze().detach().cpu().numpy()

                    all_latent_features.append(latent_features)
                    truth_labels_arr.append(label.item())

        all_latent_features = np.array(all_latent_features)
        truth_labels_arr = np.array(truth_labels_arr)
        print("Got the latent features for all test images, now doing Kmeans")

        # Do kmeans fit
        estimator = KMeans(n_clusters=10, init='k-means++', n_init=10)
        estimator.fit(all_latent_features)
        predicted_clustering_labels = estimator.labels_

        print("Got the kmeans predicted labels, now computing clustering accuracy")

        # cost[t, c] = 1 - (fraction of class-t samples that landed in cluster c)
        cost_matrix = np.zeros((10, 10))
        num_samples = np.zeros(10)
        for truth_val in range(10):
            temp_sample_indices = np.where(truth_labels_arr == truth_val)[0]
            num_samples[truth_val] = temp_sample_indices.shape[0]

            temp_predicted_labels = predicted_clustering_labels[temp_sample_indices]

            for predicted_val in range(10):
                temp_matching_pairs = np.where(temp_predicted_labels == predicted_val)[0]
                cost_matrix[truth_val, predicted_val] = 1 - (
                        temp_matching_pairs.shape[0] / temp_sample_indices.shape[0])

        # minimum-cost one-to-one matching between classes and clusters
        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        cost = cost_matrix[row_ind, col_ind]

        # weight each class's matched fraction by its sample count
        clustering_acc = ((1 - cost) * num_samples).sum() / num_samples.sum()
        return clustering_acc

    # find the most recent file and return the path
    def get_model_checkpoint_path(self, epoch_num=None):
        """Return the path of the checkpoint to load.

        Args:
            epoch_num: Epoch whose ``model_epoch_{epoch_num}.pt`` file is wanted.
                ``None`` selects the most recently modified file in the
                experiment directory.

        Raises:
            FileNotFoundError: ``epoch_num`` is ``None`` and the experiment
                directory contains no files.
        """
        directory = os.path.join(self.save_dir, self.name)
        if epoch_num is not None:
            return os.path.join(directory, f"model_epoch_{epoch_num}.pt")

        # Only regular files are candidates (sub-directories are ignored)
        candidates = [os.path.join(directory, entry) for entry in os.listdir(directory)]
        candidates = [path for path in candidates if os.path.isfile(path)]
        if not candidates:
            raise FileNotFoundError(f"No checkpoint files found in {directory}")

        # Most recently modified file wins
        return max(candidates, key=os.path.getmtime)

## Creating the model instances


In [None]:
# Experiment 1 setup: name, checkpoint directory, model and trainer
experiment1 = "ucc"
save_dir = os.path.abspath(config.weights_path)

# Fresh combined autoencoder + UCC model on the configured device
ucc_model = CombinedUCCModel(config.device)
ucc_model = ucc_model.to(config.device)

# The trainer stores its checkpoints under save_dir/experiment1
ucc_trainer = UCCTrainer(
    name=experiment1,
    ucc_model=ucc_model,
    dataset=dataset,
    save_dir=save_dir,
)

## Training the model

In [None]:
print("Going to start training..")

# Train for 20 epochs, checkpointing after every epoch; train() returns the
# running history as a 9-tuple, unpacked here in its documented order.
(
    exp1_epoch_numbers,
    exp1_training_ae_losses,
    exp1_training_ucc_losses,
    exp1_training_losses,
    exp1_training_ucc_accuracies,
    exp1_val_ae_losses,
    exp1_val_ucc_losses,
    exp1_val_losses,
    exp1_val_ucc_accuracies,
) = ucc_trainer.train(num_epochs=20, epoch_saver_count=1)

## Additional Training if required

In [None]:
# exp1_epoch_numbers, exp1_training_ae_losses, exp1_training_ucc_losses, exp1_training_losses, exp1_training_ucc_accuracies, exp1_val_ae_losses, exp1_val_ucc_losses, exp1_val_losses, exp1_val_ucc_accuracies = ucc_trainer.train(10, epoch_saver_count=2, load_from_checkpoint=True, resume_epoch_num=42)

## Plotting the model stats

In [None]:
# Plot the Experiment 1 training/validation curves from the recorded history
plot_ucc_model_stats(
    experiment=experiment1,
    epochs=exp1_epoch_numbers,
    ucc_training_losses=exp1_training_ucc_losses,
    ae_training_losses=exp1_training_ae_losses,
    combined_training_losses=exp1_training_losses,
    ucc_training_accuracy=exp1_training_ucc_accuracies,
    ucc_validation_losses=exp1_val_ucc_losses,
    ae_validation_losses=exp1_val_ae_losses,
    combined_validation_losses=exp1_val_losses,
    ucc_validation_accuracy=exp1_val_ucc_accuracies,
)

## Testing the model

In [None]:
# Evaluate on the test loader; as the cell's last expression, the returned dict of
# average test losses and UCC accuracy is displayed by the notebook.
ucc_trainer.test_model()

## Calculating the Min JS Divergence

In [None]:
# Smallest pairwise JS divergence between the per-class averaged KDE distributions
exp1_min_js_divg = ucc_trainer.calculate_min_js_divergence()
# Bare expression so the notebook displays the value
exp1_min_js_divg

## Calculating the Clustering Accuracy

In [None]:
# K-means (10 clusters) on the encoder features, matched to the true labels
exp1_clustering_accuracies = ucc_trainer.calculate_clustering_accuracy()
# Bare expression so the notebook displays the value
exp1_clustering_accuracies

In [None]:
#----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------#

# EXPERIMENT-2 : UCC-RCC Model

This model is an improvement to the original model as we are also trying to predict the RCC (Real Class Counts) as a separate multitask path. This approach in theory should improve the accuracy of the model.

Additionally we use the SSIM loss for the autoencoder as that is known to be a good loss function when it comes to autoencoders.



## Code for plotting the model stats

In [None]:
def plot_ucc_rcc_model_stats(
        experiment, epochs,
        ucc_training_losses, ae_training_losses, rcc_training_losses, combined_training_losses,
        ucc_training_accuracy, rcc_training_accuracy,
        ucc_validation_losses, ae_validation_losses, rcc_validation_losses, combined_validation_losses,
        ucc_validation_accuracy, rcc_validation_accuracy
    ):
    """Plot the loss and accuracy history of the UCC-RCC experiment.

    A 2x2 grid is drawn: training losses (top-left), training accuracies
    (top-right), validation losses (bottom-left) and validation accuracies
    (bottom-right). Every curve is plotted against ``epochs``.

    Args:
        experiment: Name shown at the start of every panel title.
        epochs: X-axis values shared by all curves.
        ucc_training_losses, ae_training_losses, rcc_training_losses,
        combined_training_losses: Per-epoch training losses.
        ucc_training_accuracy, rcc_training_accuracy: Per-epoch training
            accuracies.
        ucc_validation_losses, ae_validation_losses, rcc_validation_losses,
        combined_validation_losses: Per-epoch validation losses.
        ucc_validation_accuracy, rcc_validation_accuracy: Per-epoch
            validation accuracies.
    """
    figure, grid = plt.subplots(2, 2, figsize=(20, 20))

    # One entry per panel: (axes, caption, [(series, colour, legend text), ...])
    panels = [
        (grid[0, 0], 'Training Loss', [
            (ucc_training_losses, "red", "UCC Training Loss"),
            (ae_training_losses, "blue", "AE Training Loss"),
            (rcc_training_losses, "yellow", "RCC Training Loss"),
            (combined_training_losses, "green", "Combined Training Loss"),
        ]),
        (grid[0, 1], 'Training Accuracy', [
            (ucc_training_accuracy, "red", "UCC Training Accuracy"),
            (rcc_training_accuracy, "green", "RCC Training Accuracy"),
        ]),
        (grid[1, 0], 'Validation Loss', [
            (ucc_validation_losses, "red", "UCC Validation Loss"),
            (ae_validation_losses, "blue", "AE Validation Loss"),
            (rcc_validation_losses, "yellow", "RCC Validation Loss"),
            (combined_validation_losses, "green", "Combined Validation Loss"),
        ]),
        (grid[1, 1], 'Validation Accuracy', [
            (ucc_validation_accuracy, "red", "UCC Validation Accuracy"),
            (rcc_validation_accuracy, "green", "RCC Validation Accuracy"),
        ]),
    ]

    for axis, caption, curves in panels:
        for series, colour, legend_text in curves:
            axis.plot(epochs, series, marker="o", color=colour, label=legend_text)
        axis.set_title(f'{experiment}: {caption} vs Epochs')
        axis.set_xlabel('Epochs')
        axis.set_ylabel(caption)
        axis.legend()

    # Keep the panels from overlapping, then render.
    plt.tight_layout()
    plt.show()

    # Release the pyplot state once the figure has been shown.
    plt.clf()
    plt.cla()
    plt.close()


## RCC Trainer class

In [None]:

class RCCTrainer:
    """Train and evaluate the combined autoencoder + UCC + RCC model.

    The wrapped model is expected to return ``(rcc_logits, ucc_logits,
    decoded)`` for a batch of bags. Every optimisation step minimises the sum
    of three losses: the autoencoder reconstruction loss, the UCC
    cross-entropy loss and the RCC mean-squared-error loss. The trainer also
    records the per-epoch history, saves/restores checkpoints and provides
    the evaluation utilities used by the notebook (test pass, minimum JS
    divergence and clustering accuracy).
    """

    # Names of the per-epoch history lists. Each name is both an instance
    # attribute and a key of the checkpoint dictionary.
    _HISTORY_KEYS = (
        'epoch_numbers',
        'training_losses', 'training_ae_losses', 'training_ucc_losses', 'training_rcc_losses',
        'training_ucc_accuracies', 'training_rcc_accuracies',
        'val_losses', 'val_ae_losses', 'val_ucc_losses', 'val_rcc_losses',
        'val_ucc_accuracies', 'val_rcc_accuracies',
    )

    def __init__(self,
                 name, rcc_model,
                 dataset, save_dir, device=None):
        """Create the trainer.

        Args:
            name: Experiment name; also the checkpoint sub-directory name.
            rcc_model: Model returning ``(rcc_logits, ucc_logits, decoded)``.
            dataset: Object exposing ``ucc_rcc_train_dataloader``,
                ``ucc_rcc_val_dataloader``, ``ucc_rcc_test_dataloader``,
                ``kde_test_dataloaders`` and ``autoencoder_test_dataloaders``.
            save_dir: Root directory under which checkpoints are stored.
            device: Device used as ``map_location`` when loading checkpoints.
                Defaults to ``config.device``, resolved when the trainer is
                created rather than when the class is defined.
        """
        self.name = name
        self.save_dir = save_dir
        self.device = config.device if device is None else device

        # data
        self.train_loader = dataset.ucc_rcc_train_dataloader
        self.test_loader = dataset.ucc_rcc_test_dataloader
        self.val_loader = dataset.ucc_rcc_val_dataloader
        # one loader per pure class; each batch's first element is the image
        # tensor of shape (batch, bag, channels, height, width)
        self.kde_loaders = dataset.kde_test_dataloaders
        self.autoencoder_loaders = dataset.autoencoder_test_dataloaders

        # create the checkpoint directory (and its parents) if needed
        os.makedirs(os.path.join(self.save_dir, self.name), exist_ok=True)

        self.rcc_model = rcc_model

        # Adam optimizer; the LR scheduler is (re)built by init_scheduler_hook
        self.rcc_optimizer = optim.Adam(self.rcc_model.parameters(), lr=config.learning_rate,
                                        weight_decay=config.weight_decay)
        self.rcc_scheduler = None

        # Loss criterion(s)
        self.ae_loss_criterion = SSIMLoss()
        self.ucc_loss_criterion = nn.CrossEntropyLoss()
        self.rcc_loss_criterion = nn.MSELoss()

        # Transforms
        self.tensor_to_img_transform = transforms.ToPILImage()

        # Values which can change based on a loaded checkpoint
        self.start_epoch = 0
        for history_key in self._HISTORY_KEYS:
            setattr(self, history_key, [])

        # Running accuracy counters. The eval counters are created here as
        # well so the loss helpers can be used before any validation pass.
        self._reset_train_counters()
        self._reset_eval_counters()

    # ------------------------------------------------------------------
    # small private helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _ratio(total, count):
        """Return ``total / count``, or NaN when ``count`` is zero."""
        return total / count if count else float("nan")

    def _reset_train_counters(self):
        """Zero the running accuracy counters of the training pass."""
        self.train_ucc_correct_predictions = 0
        self.train_ucc_total_batches = 0
        self.train_rcc_correct_predictions = 0
        self.train_rcc_total_batches = 0

    def _reset_eval_counters(self):
        """Zero the running accuracy counters of the evaluation passes."""
        self.eval_ucc_correct_predictions = 0
        self.eval_ucc_total_batches = 0
        self.eval_rcc_correct_predictions = 0
        self.eval_rcc_total_batches = 0

    # main train code
    def train(self,
              num_epochs,
              resume_epoch_num=None,
              load_from_checkpoint=False,
              epoch_saver_count=2):
        """Train for ``num_epochs`` epochs, validating after each one.

        Args:
            num_epochs: Number of epochs to run in this call.
            resume_epoch_num: Epoch of the checkpoint to resume from; ``None``
                selects the most recently modified checkpoint file. Only used
                when ``load_from_checkpoint`` is true.
            load_from_checkpoint: Restore model, optimizer and history first.
            epoch_saver_count: Save a checkpoint whenever the 1-based epoch
                number is a multiple of this value.

        Returns:
            The tuple produced by ``get_current_running_history_state_hook``.

        Raises:
            ValueError: If ``epoch_saver_count`` is zero.
        """
        if not epoch_saver_count:
            # fail before training instead of with a ZeroDivisionError after
            # the first epoch has already been computed
            raise ValueError("epoch_saver_count must be a non-zero integer")

        torch.cuda.empty_cache()

        # initialize the params from the saved checkpoint
        self.init_params_from_checkpoint_hook(load_from_checkpoint, resume_epoch_num)

        # set up scheduler
        self.init_scheduler_hook(num_epochs)

        # Progress bar over the epochs of this call
        total_progress_bar = tqdm(
            total=num_epochs, desc=f"Total Epochs", position=0,
            bar_format="{desc}: {percentage:3.0f}% |{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
            dynamic_ncols=True, ncols=100, colour='red'
        )

        first_epoch = self.start_epoch
        final_epoch = first_epoch + num_epochs
        try:
            for epoch in range(first_epoch, final_epoch):
                avg_train_stats = self._train_one_epoch(epoch, final_epoch)

                # calculate validation statistics
                avg_val_stats = self.validation_hook()

                # Store running history
                self.store_running_history_hook(epoch, avg_train_stats, avg_val_stats)

                # Show epoch stats
                print(f"# Epoch {epoch + 1}")
                epoch_postfix = self.calculate_and_print_epoch_stats_hook(avg_train_stats, avg_val_stats)

                total_progress_bar.set_postfix(epoch_postfix)
                total_progress_bar.update(1)

                # Remember how far we got so that a later train() call that
                # does not reload a checkpoint continues the epoch numbering
                # instead of overwriting earlier history and checkpoints.
                self.start_epoch = epoch + 1

                # Save model checkpoint periodically
                if (epoch + 1) % epoch_saver_count == 0:
                    print(f"Going to save model {self.name} @ Epoch:{epoch + 1}")
                    self.save_model_checkpoint_hook(epoch)

                print("-" * 60)
        finally:
            # also closes the bar when training is interrupted
            total_progress_bar.close()

        # Return the current state
        return self.get_current_running_history_state_hook()

    def _train_one_epoch(self, epoch, final_epoch):
        """Run one optimisation pass over ``train_loader``; return its stats."""
        epoch_progress_bar = tqdm(
            total=len(self.train_loader),
            desc=f"Epoch {epoch + 1}/{final_epoch}",
            position=1,
            leave=False,
            dynamic_ncols=True,
            ncols=100,
            colour='green'
        )

        self.rcc_model.train()
        # start from clean counters even if a previous epoch was interrupted
        self._reset_train_counters()

        epoch_training_loss = 0.0
        epoch_ae_loss = 0.0
        epoch_ucc_loss = 0.0
        epoch_rcc_loss = 0.0

        try:
            for images, one_hot_ucc_labels, rcc_labels in self.train_loader:
                # forward propagate through the combined model
                rcc_logits, ucc_logits, decoded = self.rcc_model(images)

                # the three task losses for this batch of bags
                ae_loss = self.calculate_autoencoder_loss(images, decoded)
                ucc_loss, batch_ucc_accuracy = self.calculate_ucc_loss_and_acc(ucc_logits, one_hot_ucc_labels, True)
                rcc_loss, batch_rcc_accuracy = self.calculate_rcc_loss_and_acc(rcc_logits, rcc_labels, True)

                # combined loss, back-propagated in one go
                batch_loss = ae_loss + ucc_loss + rcc_loss
                batch_loss.backward()

                # clip every gradient element to [-grad_clip, grad_clip]
                nn.utils.clip_grad_value_(self.rcc_model.parameters(), config.grad_clip)

                self.rcc_optimizer.step()
                self.rcc_optimizer.zero_grad()

                # the one-cycle schedule advances once per batch
                self.rcc_scheduler.step()

                epoch_training_loss += batch_loss.item()
                epoch_ae_loss += ae_loss.item()
                epoch_ucc_loss += ucc_loss.item()
                epoch_rcc_loss += rcc_loss.item()

                epoch_progress_bar.set_postfix({
                    "batch_loss": batch_loss.item(),
                    "ae_loss": ae_loss.item(),
                    "ucc_loss": ucc_loss.item(),
                    "rcc_loss": rcc_loss.item(),
                    "batch_ucc_acc": batch_ucc_accuracy,
                    "batch_rcc_acc": batch_rcc_accuracy
                })
                epoch_progress_bar.update(1)
        finally:
            epoch_progress_bar.close()

        return self.calculate_avg_train_stats_hook(epoch_training_loss, epoch_ae_loss, epoch_ucc_loss,
                                                   epoch_rcc_loss)

    # hooks
    def init_params_from_checkpoint_hook(self, load_from_checkpoint, resume_epoch_num):
        """Restore model, optimizer and history from a checkpoint if requested.

        Args:
            load_from_checkpoint: When false this method does nothing.
            resume_epoch_num: Epoch of the checkpoint to load; ``None`` loads
                the most recently modified file of the experiment directory.
        """
        if not load_from_checkpoint:
            return

        checkpoint_path = self.get_model_checkpoint_path(resume_epoch_num)
        # map_location lets a checkpoint written on one device be restored on
        # another (e.g. saved on GPU, resumed on CPU)
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # load previous state of the model and the optimizer
        self.rcc_model.load_state_dict(checkpoint['rcc_model_state_dict'])
        self.rcc_optimizer.load_state_dict(checkpoint['rcc_optimizer_state_dict'])

        # things we are keeping track of
        self.start_epoch = checkpoint['epoch']
        for history_key in self._HISTORY_KEYS:
            setattr(self, history_key, checkpoint[history_key])

        print(f"Model checkpoint for {self.name} is loaded from {checkpoint_path}!")

    def init_scheduler_hook(self, num_epochs):
        """Build a fresh one-cycle LR schedule spanning ``num_epochs`` epochs."""
        self.rcc_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            self.rcc_optimizer,
            config.learning_rate,
            epochs=num_epochs,
            steps_per_epoch=len(self.train_loader)
        )

    def calculate_autoencoder_loss(self, images, decoded):
        """Return the reconstruction loss between ``decoded`` and ``images``.

        Args:
            images: Tensor of shape (batch, bag, channels, height, width).
            decoded: Reconstructions, compared against the images flattened
                to (batch * bag, channels, height, width).
        """
        batch_size, bag_size, num_channels, height, width = images.size()
        # reshape (rather than view) also accepts non-contiguous inputs
        batches_of_bag_images = images.reshape(batch_size * bag_size, num_channels, height, width).to(torch.float32)
        return self.ae_loss_criterion(decoded, batches_of_bag_images)

    def calculate_ucc_loss_and_acc(self, ucc_logits, one_hot_ucc_labels, is_train_mode=True):
        """Return the UCC loss and the batch accuracy; update the counters.

        Args:
            ucc_logits: Tensor of shape (batch, num_ucc_classes).
            one_hot_ucc_labels: One-hot labels with the same shape.
            is_train_mode: Selects whether the train or the eval running
                counters are updated.

        Returns:
            ``(ucc_loss, batch_ucc_accuracy)``; the accuracy is a float.
        """
        ucc_loss = self.ucc_loss_criterion(ucc_logits, one_hot_ucc_labels)

        # softmax is monotonic, so the arg-max of the logits is the prediction
        predicted = torch.argmax(ucc_logits, 1)
        labels = torch.argmax(one_hot_ucc_labels, 1)
        batch_correct_predictions = (predicted == labels).sum().item()
        batch_size = labels.size(0)

        batch_ucc_accuracy = self._ratio(batch_correct_predictions, batch_size)
        if is_train_mode:
            self.train_ucc_correct_predictions += batch_correct_predictions
            self.train_ucc_total_batches += batch_size
        else:
            self.eval_ucc_correct_predictions += batch_correct_predictions
            self.eval_ucc_total_batches += batch_size
        return ucc_loss, batch_ucc_accuracy

    # NOTE: To improve this I can also add a rcc-ucc-enforcement loss where
    # the number of unique classes should match the ucc exactly.
    def calculate_rcc_loss_and_acc(self, rcc_logits, rcc_labels, is_train_mode=True):
        """Return the RCC loss and the batch accuracy; update the counters.

        The accuracy is the fraction of count predictions that equal their
        label after rounding to the nearest integer. It is normalised by the
        number of compared elements, so it reaches 1.0 when every count is
        right and stays correct for partial batches.

        Args:
            rcc_logits: Predicted class counts.
            rcc_labels: True class counts, same shape as ``rcc_logits``.
            is_train_mode: Selects whether the train or the eval running
                counters are updated.

        Returns:
            ``(rcc_loss, batch_rcc_accuracy)``; the accuracy is a float.
        """
        rcc_loss = self.rcc_loss_criterion(rcc_logits, rcc_labels)

        # round the regression output to the nearest integer count
        predicted = torch.round(rcc_logits.detach()).to(torch.float32)
        matches = predicted == rcc_labels
        batch_correct_predictions = matches.sum().item()
        num_compared = matches.numel()

        batch_rcc_accuracy = self._ratio(batch_correct_predictions, num_compared)
        if is_train_mode:
            self.train_rcc_correct_predictions += batch_correct_predictions
            self.train_rcc_total_batches += num_compared
        else:
            self.eval_rcc_correct_predictions += batch_correct_predictions
            self.eval_rcc_total_batches += num_compared
        return rcc_loss, batch_rcc_accuracy

    def calculate_avg_train_stats_hook(self, epoch_training_loss, epoch_ae_loss, epoch_ucc_loss, epoch_rcc_loss):
        """Turn the epoch's summed losses and counters into averages.

        The summed values are sums of per-batch losses, so they are divided
        by the number of batches. The running train counters are reset
        afterwards.

        Returns:
            Dict with keys ``avg_training_loss``, ``avg_ae_loss``,
            ``avg_ucc_loss``, ``avg_rcc_loss``, ``avg_ucc_training_accuracy``
            and ``avg_rcc_training_accuracy``.
        """
        num_batches = len(self.train_loader)

        epoch_train_stats = {
            "avg_training_loss": self._ratio(epoch_training_loss, num_batches),
            "avg_ae_loss": self._ratio(epoch_ae_loss, num_batches),
            "avg_ucc_loss": self._ratio(epoch_ucc_loss, num_batches),
            "avg_rcc_loss": self._ratio(epoch_rcc_loss, num_batches),
            "avg_ucc_training_accuracy": self._ratio(self.train_ucc_correct_predictions,
                                                     self.train_ucc_total_batches),
            "avg_rcc_training_accuracy": self._ratio(self.train_rcc_correct_predictions,
                                                     self.train_rcc_total_batches)
        }

        self._reset_train_counters()
        return epoch_train_stats

    def _evaluate(self, loader, prefix):
        """Run a gradient-free pass over ``loader`` and average the results.

        Args:
            loader: Iterable of ``(images, one_hot_ucc_labels, rcc_labels)``.
            prefix: Word inserted in the result keys (``"val"``/``"test"``).

        Returns:
            Dict with keys ``avg_<prefix>_loss``, ``avg_<prefix>_ae_loss``,
            ``avg_<prefix>_ucc_loss``, ``avg_<prefix>_rcc_loss``,
            ``avg_<prefix>_ucc_training_accuracy`` and
            ``avg_<prefix>_rcc_training_accuracy``.
        """
        self._reset_eval_counters()

        total_loss = 0.0
        total_ae_loss = 0.0
        total_ucc_loss = 0.0
        total_rcc_loss = 0.0

        self.rcc_model.eval()
        with torch.no_grad():
            for images, one_hot_ucc_labels, rcc_labels in loader:
                rcc_logits, ucc_logits, decoded = self.rcc_model(images)

                batch_ae_loss = self.calculate_autoencoder_loss(images, decoded)
                batch_ucc_loss, _ = self.calculate_ucc_loss_and_acc(ucc_logits, one_hot_ucc_labels, False)
                batch_rcc_loss, _ = self.calculate_rcc_loss_and_acc(rcc_logits, rcc_labels, False)
                batch_loss = batch_ae_loss + batch_ucc_loss + batch_rcc_loss

                total_loss += batch_loss.item()
                total_ae_loss += batch_ae_loss.item()
                total_ucc_loss += batch_ucc_loss.item()
                total_rcc_loss += batch_rcc_loss.item()

        num_batches = len(loader)
        return {
            f"avg_{prefix}_loss": self._ratio(total_loss, num_batches),
            f"avg_{prefix}_ae_loss": self._ratio(total_ae_loss, num_batches),
            f"avg_{prefix}_ucc_loss": self._ratio(total_ucc_loss, num_batches),
            f"avg_{prefix}_rcc_loss": self._ratio(total_rcc_loss, num_batches),
            f"avg_{prefix}_ucc_training_accuracy": self._ratio(self.eval_ucc_correct_predictions,
                                                               self.eval_ucc_total_batches),
            f"avg_{prefix}_rcc_training_accuracy": self._ratio(self.eval_rcc_correct_predictions,
                                                               self.eval_rcc_total_batches)
        }

    def validation_hook(self):
        """Evaluate on the validation loader and show one reconstruction.

        Returns:
            Dict with keys ``avg_val_loss``, ``avg_val_ae_loss``,
            ``avg_val_ucc_loss``, ``avg_val_rcc_loss``,
            ``avg_val_ucc_training_accuracy`` and
            ``avg_val_rcc_training_accuracy``.
        """
        avg_val_stats = self._evaluate(self.val_loader, "val")

        # show some sample predictions
        self.show_sample_reconstructions(self.val_loader)
        return avg_val_stats

    def calculate_and_print_epoch_stats_hook(self, avg_train_stats, avg_val_stats):
        """Print the epoch summary and return it as a progress-bar postfix.

        Args:
            avg_train_stats: Dict from ``calculate_avg_train_stats_hook``.
            avg_val_stats: Dict from ``validation_hook``.

        Returns:
            Dict of the same twelve values under ``epoch_*`` keys.
        """
        epoch_loss = avg_train_stats["avg_training_loss"]
        epoch_ae_loss = avg_train_stats["avg_ae_loss"]
        epoch_ucc_loss = avg_train_stats["avg_ucc_loss"]
        epoch_rcc_loss = avg_train_stats["avg_rcc_loss"]
        epoch_ucc_accuracy = avg_train_stats["avg_ucc_training_accuracy"]
        epoch_rcc_accuracy = avg_train_stats["avg_rcc_training_accuracy"]

        epoch_val_loss = avg_val_stats["avg_val_loss"]
        epoch_val_ae_loss = avg_val_stats["avg_val_ae_loss"]
        epoch_val_ucc_loss = avg_val_stats["avg_val_ucc_loss"]
        epoch_val_rcc_loss = avg_val_stats["avg_val_rcc_loss"]
        epoch_val_ucc_accuracy = avg_val_stats["avg_val_ucc_training_accuracy"]
        epoch_val_rcc_accuracy = avg_val_stats["avg_val_rcc_training_accuracy"]

        print(
            f"[TRAIN]: Epoch Loss: {epoch_loss} | AE Loss: {epoch_ae_loss} | UCC Loss: {epoch_ucc_loss} | UCC Acc: {epoch_ucc_accuracy} | RCC Loss: {epoch_rcc_loss} | RCC Acc: {epoch_rcc_accuracy}")
        print(
            f"[VAL]: Val Loss: {epoch_val_loss} | Val AE Loss: {epoch_val_ae_loss} | Val UCC Loss: {epoch_val_ucc_loss} | Val UCC Acc: {epoch_val_ucc_accuracy} | Val RCC Loss: {epoch_val_rcc_loss} | Val RCC Acc: {epoch_val_rcc_accuracy}")

        return {
            "epoch_loss": epoch_loss,
            "epoch_ae_loss": epoch_ae_loss,
            "epoch_ucc_loss": epoch_ucc_loss,
            "epoch_rcc_loss": epoch_rcc_loss,
            "epoch_ucc_acc": epoch_ucc_accuracy,
            "epoch_rcc_acc": epoch_rcc_accuracy,
            "epoch_val_loss": epoch_val_loss,
            "epoch_val_ae_loss": epoch_val_ae_loss,
            "epoch_val_ucc_loss": epoch_val_ucc_loss,
            "epoch_val_rcc_loss": epoch_val_rcc_loss,
            "epoch_val_ucc_acc": epoch_val_ucc_accuracy,
            "epoch_val_rcc_acc": epoch_val_rcc_accuracy
        }

    def store_running_history_hook(self, epoch, avg_train_stats, avg_val_stats):
        """Append the stats of the 0-based ``epoch`` to the history lists."""
        self.epoch_numbers.append(epoch + 1)

        self.training_ae_losses.append(avg_train_stats["avg_ae_loss"])
        self.training_ucc_losses.append(avg_train_stats["avg_ucc_loss"])
        self.training_rcc_losses.append(avg_train_stats["avg_rcc_loss"])
        self.training_losses.append(avg_train_stats["avg_training_loss"])
        self.training_ucc_accuracies.append(avg_train_stats["avg_ucc_training_accuracy"])
        self.training_rcc_accuracies.append(avg_train_stats["avg_rcc_training_accuracy"])

        self.val_ae_losses.append(avg_val_stats["avg_val_ae_loss"])
        self.val_ucc_losses.append(avg_val_stats["avg_val_ucc_loss"])
        self.val_rcc_losses.append(avg_val_stats["avg_val_rcc_loss"])
        self.val_losses.append(avg_val_stats["avg_val_loss"])
        self.val_ucc_accuracies.append(avg_val_stats["avg_val_ucc_training_accuracy"])
        self.val_rcc_accuracies.append(avg_val_stats["avg_val_rcc_training_accuracy"])

    def get_current_running_history_state_hook(self):
        """Return the history lists as a 13-tuple.

        Order: epoch numbers; training AE, UCC, RCC and combined losses;
        training UCC and RCC accuracies; validation AE, UCC, RCC and combined
        losses; validation UCC and RCC accuracies.
        """
        return (
            self.epoch_numbers,
            self.training_ae_losses, self.training_ucc_losses, self.training_rcc_losses,
            self.training_losses, self.training_ucc_accuracies, self.training_rcc_accuracies,
            self.val_ae_losses, self.val_ucc_losses, self.val_rcc_losses,
            self.val_losses, self.val_ucc_accuracies, self.val_rcc_accuracies,
        )

    def save_model_checkpoint_hook(self, epoch):
        """Write ``model_epoch_<epoch + 1>.pt`` into the experiment directory.

        The file holds the model and optimizer state dicts, the 1-based epoch
        number and every history list. The model's train/eval mode is left
        untouched.
        """
        # create the directory if it doesn't exist
        model_save_directory = os.path.join(self.save_dir, self.name)
        os.makedirs(model_save_directory, exist_ok=True)

        checkpoint = {
            'rcc_model_state_dict': self.rcc_model.state_dict(),
            'rcc_optimizer_state_dict': self.rcc_optimizer.state_dict(),
            'epoch': epoch + 1,
        }
        for history_key in self._HISTORY_KEYS:
            checkpoint[history_key] = getattr(self, history_key)

        checkpoint_path = os.path.join(model_save_directory, f'model_epoch_{epoch + 1}.pt')
        torch.save(checkpoint, checkpoint_path)
        print(f"Saved the model checkpoint for experiment {self.name} for epoch {epoch + 1}")

    def test_model(self):
        """Evaluate on the test loader and show one reconstruction.

        Returns:
            Dict with keys ``avg_test_loss``, ``avg_test_ae_loss``,
            ``avg_test_ucc_loss``, ``avg_test_rcc_loss``,
            ``avg_test_ucc_training_accuracy`` and
            ``avg_test_rcc_training_accuracy``.
        """
        avg_test_stats = self._evaluate(self.test_loader, "test")

        # show some sample predictions
        self.show_sample_reconstructions(self.test_loader)
        return avg_test_stats

    def show_sample_reconstructions(self, dataloader):
        """Show the first image of ``dataloader`` next to its reconstruction.

        Args:
            dataloader: Iterable of ``(images, ucc_labels, rcc_labels)``; only
                the first batch is used. Nothing is drawn into the axes when
                the loader is empty.
        """
        fig, axes = plt.subplots(1, 2, figsize=(3, 3))

        first_batch = next(iter(dataloader), None)
        if first_batch is not None:
            val_images, _, _ = first_batch

            self.rcc_model.eval()
            with torch.no_grad():
                # reshape to (batch * bag, channels, height, width)
                batch_size, bag_size, num_channels, height, width = val_images.size()
                bag_val_images = val_images.reshape(batch_size * bag_size, num_channels, height, width)
                print("Reshaped the original image into bag format")

                # forward propagate through the model
                _, _, val_reconstructed_images = self.rcc_model(val_images)
                print("Got a sample reconstruction, now trying to reshape in order to show an example")

            # take only the first image of the bag, on the CPU, as PIL images
            sample_image = self.tensor_to_img_transform(bag_val_images[0].to("cpu"))
            predicted_image = self.tensor_to_img_transform(val_reconstructed_images[0].to("cpu"))

            axes[0].imshow(sample_image)
            axes[0].set_title(f"Orig", color='green')
            axes[0].axis('off')

            axes[1].imshow(predicted_image)
            axes[1].set_title(f"Recon", color='red')
            axes[1].axis('off')

        plt.tight_layout()
        plt.show()
        # this runs once per epoch; release the figure so they do not pile up
        plt.close(fig)

    def js_divergence(self, p, q):
        """
        Calculate the Jensen-Shannon Divergence between two probability distributions p and q.

        ``JS(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m)`` with
        ``m = 0.5 * (p + q)``. ``F.kl_div(input, target)`` computes
        ``KL(target || exp(input))``, hence ``log(m)`` is the input and
        ``p`` / ``q`` are the targets.

        The ``'batchmean'`` reduction divides the summed divergence by
        ``p.size(0)``: for 2-D inputs that is the mean over rows, for 1-D
        inputs the sum divided by the vector length.

        Args:
        p (torch.Tensor): Probability distribution p.
        q (torch.Tensor): Probability distribution q.

        Returns:
        torch.Tensor: Jensen-Shannon Divergence between p and q.
        """
        # Calculate the average distribution 'm'
        m = 0.5 * (p + q)

        # m is zero only where both p and q are zero; clamping keeps log(m)
        # finite there, and the zero targets make those terms vanish
        log_m = m.clamp_min(1e-12).log()

        # KL divergence of 'p' and of 'q' from 'm'
        kl_div_p = F.kl_div(log_m, p, reduction='batchmean')
        kl_div_q = F.kl_div(log_m, q, reduction='batchmean')

        return 0.5 * (kl_div_p + kl_div_q)

    def calculate_min_js_divergence(self):
        """Return the smallest pairwise JS divergence between class KDEs.

        For every pure-class loader the bags are encoded, passed through the
        model's KDE layer and averaged into one distribution per class; the
        JS divergence is then computed for every pair of classes.

        Returns:
            The minimum divergence (a tensor), or ``torch.inf`` when there
            are fewer than two classes.

        Raises:
            ValueError: If a class loader yields no bags.
        """
        num_classes = len(self.kde_loaders)
        kde_per_class = {}

        # evaluation only: no autograd graph, layers in inference mode
        self.rcc_model.eval()
        with torch.no_grad():
            for class_idx, pure_class_kde_loader in enumerate(tqdm(self.kde_loaders)):
                kde_sum = 0.0
                num_bags_in_class = 0
                for batch in pure_class_kde_loader:
                    images = batch[0]

                    # Stage.1 encode: (batch, bag, C, H, W) -> (batch * bag, features)
                    batch_size, bag_size, num_channels, height, width = images.size()
                    batches_of_bag_images = images.reshape(
                        batch_size * bag_size, num_channels, height, width).to(torch.float32)
                    latent_features = self.rcc_model.autoencoder.encoder(batches_of_bag_images)
                    latent_features = latent_features.to(torch.float32)

                    # Stage.2 KDE on (batch, bag, features), using the bag
                    # size of the data itself
                    latent_features = latent_features.reshape(batch_size, bag_size, -1)
                    batch_kde_distributions = self.rcc_model.ucc_predictor.kde(latent_features)

                    # Stage.3 accumulate the per-bag distributions
                    num_bags_in_class += batch_kde_distributions.size(0)
                    kde_sum = kde_sum + torch.sum(batch_kde_distributions, dim=0)

                if num_bags_in_class == 0:
                    raise ValueError(f"KDE loader for class {class_idx} yielded no bags")

                # Stage.4 average over the bags of the class
                kde_per_class[class_idx] = kde_sum / num_bags_in_class

        # find the pair of classes with the smallest divergence
        min_divergence = torch.inf
        best_i = None
        best_j = None
        for i in range(num_classes):
            for j in range(i + 1, num_classes):
                divergence = self.js_divergence(kde_per_class[i], kde_per_class[j])
                print(f"JS Divergence between {i} & {j} is {divergence}")
                if divergence < min_divergence:
                    min_divergence = divergence
                    best_i = i
                    best_j = j

        print(f"Min JS Divergence is {min_divergence} between classes {best_i} & {best_j}")
        return min_divergence

    @staticmethod
    def _clustering_accuracy(truth_labels_arr, predicted_clustering_labels, num_classes):
        """Return the accuracy under the best one-to-one cluster mapping.

        The cost of mapping class ``t`` to cluster ``c`` is one minus the
        fraction of class-``t`` samples assigned to ``c``; the optimal
        assignment is found with ``linear_sum_assignment``. A class without
        samples keeps a cost of 1 and has zero weight in the result.
        """
        cost_matrix = np.ones((num_classes, num_classes))
        num_samples = np.zeros(num_classes)
        for truth_val in range(num_classes):
            in_class = truth_labels_arr == truth_val
            class_count = int(in_class.sum())
            num_samples[truth_val] = class_count
            if class_count == 0:
                continue
            cluster_counts = np.bincount(predicted_clustering_labels[in_class], minlength=num_classes)
            cost_matrix[truth_val] = 1 - cluster_counts / class_count

        row_ind, col_ind = linear_sum_assignment(cost_matrix)
        cost = cost_matrix[row_ind, col_ind]
        return ((1 - cost) * num_samples).sum() / num_samples.sum()

    def calculate_clustering_accuracy(self, num_classes=10):
        """Cluster the encoder features with K-means and score the clusters.

        Args:
            num_classes: Number of clusters / label values ``0..num_classes-1``.

        Returns:
            Clustering accuracy under the best cluster-to-class assignment.
        """
        latent_batches = []
        truth_labels = []

        # evaluation only: no autograd graph, layers in inference mode
        self.rcc_model.eval()
        with torch.no_grad():
            for pure_autoencoder_loader in self.autoencoder_loaders:
                for image, label in tqdm(pure_autoencoder_loader):
                    latent_features = self.rcc_model.autoencoder.encoder(image)
                    # keep one row per image so any batch size works
                    latent_batches.append(
                        latent_features.reshape(latent_features.size(0), -1).cpu().numpy())
                    truth_labels.extend(label.reshape(-1).cpu().tolist())

        all_latent_features = np.concatenate(latent_batches, axis=0)
        truth_labels_arr = np.array(truth_labels)
        print("Got the latent features for all test images, now doing Kmeans")

        # Do kmeans fit
        estimator = KMeans(n_clusters=num_classes, init='k-means++', n_init=10)
        estimator.fit(all_latent_features)
        predicted_clustering_labels = estimator.labels_

        print("Got the kmeans predicted labels, now computing clustering accuracy")

        return self._clustering_accuracy(truth_labels_arr, predicted_clustering_labels, num_classes)

    def get_model_checkpoint_path(self, epoch_num=None):
        """Return the path of a checkpoint of this experiment.

        Args:
            epoch_num: Epoch whose checkpoint is wanted. ``None`` selects the
                most recently modified regular file of the experiment
                directory.

        Raises:
            FileNotFoundError: If ``epoch_num`` is ``None`` and the directory
                holds no files.
        """
        directory = os.path.join(self.save_dir, self.name)
        if epoch_num is not None:
            return os.path.join(directory, f"model_epoch_{epoch_num}.pt")

        # only regular files qualify (sub-directories are ignored)
        candidates = [os.path.join(directory, entry) for entry in os.listdir(directory)]
        candidates = [path for path in candidates if os.path.isfile(path)]
        if not candidates:
            raise FileNotFoundError(f"No checkpoint files found in {directory}")

        return max(candidates, key=os.path.getmtime)


## Creating the model instances


In [None]:
# Experiment name; RCCTrainer also uses it as the checkpoint sub-directory.
experiment2 = "ucc-rcc"
# Absolute path of the directory under which checkpoints are written.
save_dir = os.path.abspath(config.weights_path)
# Build the combined model and place it on the configured device.
rcc_model = CombinedRCCModel(config.device).to(config.device)

#creating the trainer
rcc_trainer = RCCTrainer(experiment2, rcc_model, dataset, save_dir)

## Training the model

In [None]:
# Unpack the 13 history lists in the order returned by
# RCCTrainer.get_current_running_history_state_hook().
(
    exp2_epoch_numbers,
    exp2_training_ae_losses,
    exp2_training_ucc_losses,
    exp2_training_rcc_losses,
    exp2_training_losses,
    exp2_training_ucc_accuracies,
    exp2_training_rcc_accuracies,
    exp2_val_ae_losses,
    exp2_val_ucc_losses,
    exp2_val_rcc_losses,
    exp2_val_losses,
    exp2_val_ucc_accuracies,
    exp2_val_rcc_accuracies,
) = rcc_trainer.train(1, epoch_saver_count=1)

## Additional Training if required


In [None]:
 # exp2_epoch_numbers, exp2_training_ae_losses, exp2_training_ucc_losses, exp2_training_rcc_losses, exp2_training_losses, exp2_training_ucc_accuracies, exp2_training_rcc_accuracies, exp2_val_ae_losses, exp2_val_ucc_losses, exp2_val_rcc_losses, exp2_val_losses, exp2_val_ucc_accuracies, exp2_val_rcc_accuracies = rcc_trainer.train(10, epoch_saver_count=2, load_from_checkpoint=True, resume_epoch_num=42)

## Plotting the model stats

In [None]:
# Plot the experiment-2 curves; keyword arguments make the mapping between
# each history list and the panel it feeds explicit.
plot_ucc_rcc_model_stats(
    experiment=experiment2,
    epochs=exp2_epoch_numbers,
    ucc_training_losses=exp2_training_ucc_losses,
    ae_training_losses=exp2_training_ae_losses,
    rcc_training_losses=exp2_training_rcc_losses,
    combined_training_losses=exp2_training_losses,
    ucc_training_accuracy=exp2_training_ucc_accuracies,
    rcc_training_accuracy=exp2_training_rcc_accuracies,
    ucc_validation_losses=exp2_val_ucc_losses,
    ae_validation_losses=exp2_val_ae_losses,
    rcc_validation_losses=exp2_val_rcc_losses,
    combined_validation_losses=exp2_val_losses,
    ucc_validation_accuracy=exp2_val_ucc_accuracies,
    rcc_validation_accuracy=exp2_val_rcc_accuracies,
)

## Testing the model

In [None]:
# Evaluate experiment 2 on the test loader; as the last expression of the
# cell, the dict of averaged test statistics is displayed by the notebook.
rcc_trainer.test_model()

## Calculating the Min JS Divergence

In [None]:
# Smallest pairwise JS divergence between the per-class KDE distributions ...
exp2_min_js_divg = rcc_trainer.calculate_min_js_divergence()
# ... echoed as the cell output.
exp2_min_js_divg

## Calculating the Clustering Accuracy

In [None]:
# K-means clustering accuracy of the encoder's latent features ...
exp2_clustering_accuracies = rcc_trainer.calculate_clustering_accuracy()
# ... echoed as the cell output.
exp2_clustering_accuracies