# TO-DO

Things we still need to try or finish:
- 1 Group: 
    - Implement attention map.
    - Try patchifying with convolutions.
    - Find out how Tensorboard works. 
- 2 Group:
    - Try Ensemble on Classification (Potential Bright Idea :) 
    - Try different versions of Attention (ARPR).
    - Try different versions of Embedding (LaPE, 2D Positional Embeddings). 

- Others:
    - Make a version with MNIST. 
    - Put LogSoftmax also in the MHSA? 
    - Retry to use Augmentation. 
    - Start writing Report. 

### QUESTIONS/DOUBTS
- A-ViT paper, Section 3, second column, after equation (4): "We incorporate H(.) into the existing vision transformer block by allocating a single neuron in the MLP layer to do the task". How exactly is this single neuron allocated?

### REFERENCES

- https://arxiv.org/pdf/2112.07658.pdf
- https://medium.com/mlearning-ai/vision-transformers-from-scratch-pytorch-a-step-by-step-guide-96c3313c2e0c

### 0: IMPORTING LIBRARIES AND SETTING THE SEEDS

In [None]:

# Importing PyTorch-related libraries
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
from torchvision import transforms
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToPILImage
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchmetrics.classification import Accuracy, MulticlassF1Score, MulticlassPrecision, MulticlassRecall

# Importing PyTorch Lightning-Related libraries
import pytorch_lightning as pl
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from pytorch_lightning.callbacks import TQDMProgressBar, LearningRateMonitor, ModelCheckpoint

# Importing General Libraries
import os
import csv
import PIL
import random
import numpy as np
from PIL import Image
import seaborn as sns
from pathlib import Path
from scipy.stats import norm
import matplotlib.pyplot as plt
from collections import OrderedDict


In [None]:

def seed_everything(seed):
    """
    Seed every random number generator used in this notebook for reproducibility.

    Exports PYTHONHASHSEED, seeds Python's `random`, NumPy, PyTorch Lightning
    and PyTorch (CPU and CUDA entry points), then sets the cuDNN flags
    `deterministic = True` and `benchmark = False`.

    Arguments:
        - seed {int} : Number of the seed.
    """
    # Environment variables only hold text, hence the string conversion.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # Each library-level seeding entry point receives the very same seed.
    for seeder in (
        random.seed,
        np.random.seed,
        pl.seed_everything,
        torch.manual_seed,
        torch.cuda.manual_seed,
    ):
        seeder(seed)

    # cuDNN: request deterministic behaviour and switch benchmarking off.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Set the seed.
seed_everything(31)

### 1: DATA INSPECTION

#### 1.1: CREATION OF THE LABEL DICTIONARY

In [None]:

# Start from an empty encoded-label -> actual-label lookup table.
mapping_dict = {}

# Walk through words.txt one line at a time.
with open('/kaggle/input/tiny-imagenet/tiny-imagenet-200/words.txt', 'r') as file:
    for line in file:
        # Fields are separated by tabs; surrounding whitespace is removed first.
        tokens = line.strip().split('\t')

        # Lines with fewer than two fields carry no mapping and are ignored.
        if len(tokens) < 2:
            continue

        # First field is the encoded label, second one is the actual label.
        encoded_label, actual_label = tokens[:2]
        mapping_dict[encoded_label] = actual_label

# Print the mapping dictionary.
#print(mapping_dict)


#### 1.2: DISPLAYING EXAMPLES OF THE DATASET

In [None]:

# Loading the dataset using ImageFolder (no transform, so samples stay PIL images).
dataset0 = datasets.ImageFolder(root="/kaggle/input/tiny-imagenet/tiny-imagenet-200/train/", transform=None)

# Extract class names and their counts.
# A single bincount pass over the targets replaces one full list scan per class;
# the result is still a plain list with one integer count per class index.
class_names = dataset0.classes
class_counts = np.bincount(
    np.asarray(dataset0.targets, dtype=np.int64),
    minlength=len(class_names),
).tolist()

# Setting the seed so the same ten images are drawn on every run.
np.random.seed(31)

# Create a grid of 10 images (2 rows x 5 columns) with labels.
plt.figure(figsize=(15, 8))
for i in range(10):

    # Randomly select an image and its corresponding label.
    index = np.random.randint(len(dataset0))
    image, label = dataset0[index]

    # Display the image with its (encoded) class name.
    plt.subplot(2, 5, i+1)
    plt.imshow(np.array(image))  # Convert the PIL Image to a numpy array
    plt.title(f"Label: {class_names[label]}")
    plt.axis('off')

# Displaying Datasets examples.
plt.tight_layout()
plt.show()

#### 1.3: DISPLAYING EXAMPLES OF THE DATASET WITH DECODED LABELS

In [None]:

# Loading the dataset using ImageFolder (no transform, so samples stay PIL images).
dataset0 = datasets.ImageFolder(root="/kaggle/input/tiny-imagenet/tiny-imagenet-200/train/", transform=None)

# Extract class names and their counts.
# A single bincount pass over the targets replaces one full list scan per class;
# the result is still a plain list with one integer count per class index.
class_names = dataset0.classes
class_counts = np.bincount(
    np.asarray(dataset0.targets, dtype=np.int64),
    minlength=len(class_names),
).tolist()

# Setting the seed so the same ten images are drawn on every run.
np.random.seed(31)

# Create a grid of 10 images (2 rows x 5 columns) with labels.
plt.figure(figsize=(15, 8))

for i in range(10):

    # Randomly select an image and its corresponding label.
    index = np.random.randint(len(dataset0))
    image, encoded_label = dataset0[index]

    # Look up the actual label using the mapping dictionary built from words.txt.
    actual_label = mapping_dict.get(class_names[encoded_label], "Unknown Label")

    # Trim the label to 15 characters plus an ellipsis if it is longer than that.
    actual_label_trimmed = actual_label[:15] + '...' if len(actual_label) > 15 else actual_label

    # Display the image with its decoded label.
    plt.subplot(2, 5, i+1)
    plt.imshow(np.array(image))
    plt.title(f"Label: {actual_label_trimmed}", wrap=True)
    plt.axis('off')

# Displaying Dataset examples.
plt.tight_layout()
plt.show()

### 2: DATA-MODULE DEFINITION

#### 2.0: CUSTOMIZED TRANSFORM CLASS

In [None]:

class AdaViT_Transformations:
    """
    Callable that turns an input image sample into a tensor.

    The torchvision pipeline is assembled once at construction time and reused
    for every sample, rather than being rebuilt on each call.
    """

    def __init__(self):
        # Define the series of image transformations with "transforms.Compose".
        self._pipeline = transforms.Compose([
            transforms.ToTensor(),
            # Additional transformations can be added here, for example:
            # transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            # transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __call__(self, sample):
        """
        Call method to perform transformations on the input sample.

        Args:
        - sample: Input image sample in a format accepted by `transforms.ToTensor`
          (the datasets in this file pass PIL images).

        Returns:
        - transformed_sample (torch.Tensor): Transformed image sample.
        """
        return self._pipeline(sample)

#### 2.1: CUSTOMIZED TRAINING SET VERSION

In [None]:

class CustomTrainingTinyImagenet(ImageFolder):
    """
    ImageFolder variant for the Tiny ImageNet training split.

    Adds explicit class-name <-> index lookup tables on top of ImageFolder and
    returns `(sample, target)` pairs where `target` is the class index.
    """

    def __init__(self, root, transform=None):
        """
        Custom dataset class for Tiny ImageNet Training data.

        Args:
        - root (str): Root directory containing the dataset.
        - transform (callable, optional): Optional transform to be applied to the Input Image.
        """
        super().__init__(root, transform=transform)

        # Create mappings between class labels and numerical indices.
        self.class_to_index = {cls: idx for idx, cls in enumerate(sorted(self.classes))}
        self.index_to_class = {idx: cls for cls, idx in self.class_to_index.items()}

    def __getitem__(self, index):
        """
        Method to retrieve an item from the dataset.

        Args:
        - index (int): Index of the item to retrieve.

        Returns:
        - sample (torch.Tensor): Transformed image sample.
        - target (int): Numerical index corresponding to the class label.
        """
        # ImageFolder already stores the class index next to every path. It is
        # the index of the class directory in the sorted class list, i.e. the
        # same value `class_to_index` yields, so the label no longer has to be
        # re-derived from a fixed directory depth (which raised KeyError for
        # any layout other than root/<class>/<subdir>/<file>).
        path, target = self.samples[index]

        # Load the image using the default loader.
        sample = self.loader(path)

        # Apply the specified transformations, if any.
        if self.transform is not None:
            sample = self.transform(sample)

        return sample, target

    def get_class_from_index(self, index):
        """
        Method to retrieve the class label from a numerical index.

        Args:
        - index (int): Numerical index corresponding to the class label.

        Returns:
        - class_label (str): Class label corresponding to the numerical index.
        """
        return self.index_to_class[index]

#### 2.2: CUSTOMIZED VALIDATION SET VERSION

In [None]:

class CustomValidationTinyImagenet(pl.LightningDataModule):
    """
    Indexable collection of the Tiny ImageNet validation images and labels.

    NOTE(review): this class is indexed like a dataset (`__len__` /
    `__getitem__`) and is handed to a DataLoader by AViT_DataModule, yet it
    derives from `pl.LightningDataModule`. The base class is kept so existing
    `isinstance` checks keep working; `torch.utils.data.Dataset` looks like
    the intended base - confirm before changing.
    """

    # Annotation file used when none is found next to the image directory.
    DEFAULT_LABEL_PATH = "/kaggle/input/tiny-imagenet/tiny-imagenet-200/val/val_annotations.txt"

    def __init__(self, root, transform=None):
        """
        Custom data module for Tiny ImageNet Validation data.

        Args:
        - root (str): Root directory containing the validation images.
        - transform (callable, optional): Optional transform to be applied to the Input Image.
        """
        # Initialise the base class before setting any attribute.
        super().__init__()
        self.root = Path(root)
        self.transform = transform

        # Load and preprocess labels.
        self.labels = self.load_labels()
        self.label_to_index = {label: idx for idx, label in enumerate(sorted(set(self.labels.values())))}
        self.index_to_label = {idx: label for label, idx in self.label_to_index.items()}

    def load_labels(self):
        """
        Method to load and Pre-Process Labels from the Validation Dataset.

        The annotation file is looked up as `val_annotations.txt` in the parent
        of the image directory; if it is not there, DEFAULT_LABEL_PATH is used.
        Each usable line holds tab-separated fields whose first two entries are
        the image name and its label. Lines with fewer than two fields are skipped.

        Returns:
        - labels (dict): Dictionary mapping image names to labels.
        """
        label_path = self.root.parent / "val_annotations.txt"
        if not label_path.is_file():
            label_path = Path(self.DEFAULT_LABEL_PATH)

        labels = {}
        with open(label_path, "r") as f:
            for line in f:
                # Drop the line terminator so a two-column line yields a clean label.
                parts = line.rstrip("\r\n").split("\t")
                if len(parts) < 2:
                    continue
                labels[parts[0]] = parts[1]

        return labels

    def __len__(self):
        """
        Method to get the length of the dataset.

        Returns:
        - length (int): Number of annotated images.
        """
        return len(self.labels)

    def __getitem__(self, index):
        """
        Method to retrieve an item from the dataset.

        Args:
        - index (int): Index of the item to retrieve; the image file is
          expected to be named `val_<index>.JPEG` inside `root`.

        Returns:
        - image (torch.Tensor): Transformed image sample.
        - label (int): Numerical index corresponding to the class label.

        Raises:
        - KeyError: If the image has no entry in the annotation file.
        """
        image_name = f"val_{index}.JPEG"
        image_path = self.root / image_name

        # Open the image using PIL and convert to RGB.
        image = Image.open(image_path).convert("RGB")

        # Apply the specified transformations, if any.
        if self.transform:
            image = self.transform(image)

        # Fail with an explicit message instead of looking up a placeholder label.
        try:
            label_str = self.labels[image_name]
        except KeyError:
            raise KeyError(f"No annotation found for image '{image_name}'") from None

        # Convert string label to numerical index using the mapping.
        label = self.label_to_index[label_str]

        return image, label

    def get_label_from_index(self, index):
        """
        Method to retrieve the class label from a numerical index.

        Args:
        - index (int): Numerical index corresponding to the class label.

        Returns:
        - class_label (str): Class label corresponding to the numerical index.
        """
        return self.index_to_label[index]

#### 2.3: GENERAL DATA-MODULE DEFINITION

In [None]:

class AViT_DataModule(pl.LightningDataModule):
    """Lightning data module bundling the Tiny ImageNet training and validation datasets."""

    def __init__(self, train_data_dir, val_data_dir, batch_size, num_workers=4):
        """
        Store the dataset locations and the DataLoader settings.

        Args:
        - train_data_dir (str): Directory path for the training dataset.
        - val_data_dir (str): Directory path for the validation dataset.
        - batch_size (int): Batch size shared by both DataLoaders.
        - num_workers (int, optional): Number of DataLoader workers (default is 4).
        """
        super().__init__()
        self.train_data_dir, self.val_data_dir = train_data_dir, val_data_dir
        self.batch_size, self.num_workers = batch_size, num_workers

        # The same AdaViT_Transformations instance is applied to both splits.
        self.transform = AdaViT_Transformations()

    def setup(self, stage=None):
        """
        Instantiate the training and validation datasets.

        Args:
        - stage (str, optional): Not inspected; both datasets are always created.
        """
        # Training split: class-per-directory image folder.
        self.train_dataset = CustomTrainingTinyImagenet(
            self.train_data_dir,
            transform=self.transform,
        )

        # Validation split: images plus the annotation file.
        self.val_dataset = CustomValidationTinyImagenet(
            self.val_data_dir,
            transform=self.transform,
        )

    def train_dataloader(self):
        """
        Build the DataLoader for the Training Dataset (shuffled).

        Returns:
        - train_dataloader (DataLoader): DataLoader for Training.
        """
        loader_options = dict(batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)
        return DataLoader(self.train_dataset, **loader_options)

    def val_dataloader(self):
        """
        Build the DataLoader for the Validation Dataset (not shuffled).

        Returns:
        - val_dataloader (DataLoader): DataLoader for Validation.
        """
        loader_options = dict(batch_size=self.batch_size, num_workers=self.num_workers)
        return DataLoader(self.val_dataset, **loader_options)

#### 2.4: TESTING TRAINING AND VALIDATION DATALOADERS

In [None]:

def show_images_labels(images, labels, title):
    """
    Display Images with corresponding Labels in a single row.

    Parameters:
    - images (sequence of tensors): Image tensors, one subplot column each.
    - labels (sequence): Corresponding Labels, shown in each subplot title.
    - title (str): Title for the entire figure.

    Returns:
    None
    """
    # squeeze=False always yields a 2-D array of axes, so a single image no
    # longer produces a bare Axes object that cannot be indexed.
    fig, axs = plt.subplots(1, len(images), figsize=(8, 4), squeeze=False)

    # Set the title for the entire figure.
    fig.suptitle(title)

    # One converter instance is enough for all images.
    to_pil = transforms.ToPILImage()

    # Iterate over the axes of the only row together with Images and Labels.
    for ax, img, label in zip(axs[0], images, labels):
        # Display each Image in its subplot.
        ax.imshow(to_pil(img))

        # Set the title for each subplot with the corresponding label.
        ax.set_title(f"Label: {label}")

        # Turn off axis labels for better Visualization.
        ax.axis('off')

    # Show the entire figure.
    plt.show()


In [None]:

# Define the AViT_DataModule, pointing at the Tiny ImageNet training folder and
# at the folder holding the validation images.
data_module = AViT_DataModule(
    train_data_dir="/kaggle/input/tiny-imagenet/tiny-imagenet-200/train/",
    val_data_dir="/kaggle/input/tiny-imagenet/tiny-imagenet-200/val/images/",
    batch_size=512  
)

# Setup the datasets behind the Dataloaders.
data_module.setup()

# Get the first batch from the Training DataLoader (shuffled).
train_dataloader = data_module.train_dataloader()
train_batch = next(iter(train_dataloader))

# Get the first batch from the Validation DataLoader.
val_dataloader = data_module.val_dataloader()
val_batch = next(iter(val_dataloader))

# Each batch is indexed as [0] = images and [1] = labels.
# Show two Images from the Training Batch.
show_images_labels(train_batch[0][:2], train_batch[1][:2], title='Training Batch')

# Show two Images from the Validation Batch.
show_images_labels(val_batch[0][:2], val_batch[1][:2], title='Validation Batch')

### 3: MODEL DEFINITION

#### 3.0: PATCHING FUNCTION DEFINITION

In [None]:

def Make_Patches_from_Image(images, n_patches):
    """
    Extract non-overlapping square patches from a batch of square images.

    Patches are enumerated row by row (patch index = row * n_patches + column)
    and each patch is flattened in (channel, row, column) order.

    Parameters:
    - images (torch.Tensor): Input images tensor with shape (batch_size, channels, height, width).
    - n_patches (int): Number of patches in each dimension.

    Returns:
    torch.Tensor: Extracted patches tensor with shape (batch_size, n_patches^2, patch_size^2 * channels),
    in the default floating point dtype and on the same device as `images`.

    Raises:
    AssertionError: If the images are not square.
    ValueError: If the image side is not a positive multiple of n_patches.
    """
    # Get the dimensions of the input images.
    n, c, h, w = images.shape

    # Ensure that the input images are square.
    assert h == w, "make_patches_from_image method is implemented for square images only!"

    if n_patches <= 0 or h % n_patches != 0:
        raise ValueError(f"Image side {h} is not divisible into {n_patches} patches per dimension")

    patch_size = h // n_patches

    # Split both spatial axes into (patch index, offset inside the patch):
    # (n, c, patch_row, row_offset, patch_col, col_offset).
    grid = images.reshape(n, c, n_patches, patch_size, n_patches, patch_size)

    # Allocate the output next to the input so no device round-trip is needed.
    patches = torch.empty(
        n,
        n_patches ** 2,
        c * patch_size ** 2,
        dtype=torch.get_default_dtype(),
        device=images.device,
    )

    # Reorder to (n, patch_row, patch_col, c, row_offset, col_offset) and copy
    # everything in one vectorized operation instead of a per-patch Python loop.
    patches.view(n, n_patches, n_patches, c, patch_size, patch_size).copy_(
        grid.permute(0, 2, 4, 1, 3, 5)
    )

    return patches


In [None]:

# Helper function to Visualize Patches.
def visualize_patches(images, n_patches, title):
    """
    Visualize the patches extracted from the first image of a batch.

    Parameters:
    - images (torch.Tensor): Input images tensor with shape (batch_size, channels, height, width).
    - n_patches (int): Number of patches in each dimension.
    - title (str): Title for the entire figure.

    Returns:
    None
    """
    # Extract patches from the input images using Make_Patches_from_Image.
    patches = Make_Patches_from_Image(images, n_patches)

    # squeeze=False keeps `axs` two-dimensional even when n_patches == 1.
    fig, axs = plt.subplots(n_patches, n_patches, figsize=(8, 8), squeeze=False)
    fig.suptitle(title)

    # Channel count and patch size are taken from the input instead of assuming RGB.
    channels = images.shape[1]
    patch_size = images.shape[-1] // n_patches

    # Loop over each patch in both dimensions.
    for i in range(n_patches):
        for j in range(n_patches):
            # Calculate the index of the patch.
            patch_index = i * n_patches + j

            # Reshape each patch to (channels, patch_size, patch_size) and move
            # the channel axis last, as expected by imshow.
            patch = patches[0, patch_index].reshape(channels, patch_size, patch_size)
            patch = patch.detach().cpu().numpy().transpose(1, 2, 0)

            # Single-channel patches are shown as plain 2-D arrays.
            if channels == 1:
                patch = patch[:, :, 0]

            # Display the patch in the subplot.
            axs[i, j].imshow(patch)
            axs[i, j].axis('off')

    # Show the entire figure.
    plt.show()


In [None]:
# Visualize the 8 x 8 patch grid of the first image in the Training batch
# (train_batch[0] holds the images of the batch).
visualize_patches(train_batch[0], n_patches=8, title='Training Patches')

# Visualize the 8 x 8 patch grid of the first image in the Validation batch.
visualize_patches(val_batch[0], n_patches=8, title='Validation Patches')

#### 3.1: POSITIONAL EMBEDDING DEFINITION

In [None]:
def get_positional_embeddings_Basic(sequence_length, d):
    """
    Generate sinusoidal Positional Embeddings (reference formulation).

    Entry (i, j) is sin(i / 10000^(j / d)) for even j and
    cos(i / 10000^((j - 1) / d)) for odd j.

    Parameters:
    - sequence_length (int): Length of the input sequence.
    - d (int): Dimension of the embeddings.

    Returns:
    torch.Tensor: Positional Embeddings tensor of shape (sequence_length, d).
    """
    # Column vector of positions and row vector of embedding dimensions.
    positions = np.arange(sequence_length, dtype=np.float64)[:, None]
    dims = np.arange(d)

    # Odd dimensions reuse the exponent of the preceding even dimension.
    exponents = (dims - dims % 2) / d
    angles = positions / np.power(10000.0, exponents)

    # Sine on even dimensions, cosine on odd ones, computed for the whole table
    # at once instead of one Python-level call per entry.
    table = np.where(dims % 2 == 0, np.sin(angles), np.cos(angles))

    return torch.from_numpy(table).to(torch.get_default_dtype())


##### 3.1.1: SINUSOIDAL POSITIONAL ENCODING (SPE)



In [None]:

def get_positional_embeddings_SPE(sequence_length, d):
    """
    Generate sinusoidal Positional Embeddings for the Transformer Model.

    Entry (pos, 2k) is sin(pos / 10000^(2k / d)) and entry (pos, 2k + 1) is
    cos(pos / 10000^(2k / d)), i.e. the same table that
    get_positional_embeddings_Basic computes entry by entry.

    Parameters:
    - sequence_length (int): Length of the input sequence.
    - d (int): Dimension of the embeddings (even or odd).

    Returns:
    torch.Tensor: Positional Embeddings tensor of shape (sequence_length, d).
    """
    # Generate a column tensor of positions from 0 to sequence_length - 1.
    positions = torch.arange(0, sequence_length).float().view(-1, 1)

    # div_term[k] = 10000^(-2k / d), one entry per even embedding dimension.
    div_term = torch.exp(torch.arange(0, d, 2).float() * -(np.log(10000.0) / d))

    # div_term already holds the inverse frequencies, so positions must be
    # multiplied by it; dividing would invert the frequencies.
    angles = positions * div_term

    # Initialize the embeddings tensor with zeros.
    embeddings = torch.zeros(sequence_length, d)

    # Sine on even dimensions, cosine on odd ones. For odd d there is one cosine
    # column fewer than sine columns, hence the slice.
    embeddings[:, 0::2] = torch.sin(angles)
    embeddings[:, 1::2] = torch.cos(angles[:, : d // 2])

    return embeddings

##### 3.1.2: LAYER-ADAPTIVE POSITIONAL EMBEDDING (LaPE)

In [None]:

def get_positional_embeddings_LaPE(sequence_length, d, num_layers):
    """
    Generate Layer-adaptive Positional Embeddings for the Transformer Model.

    For every layer, one sine column and one cosine column are computed from
    the positions and a single per-layer divisor, then broadcast over column
    ranges of that layer's (sequence_length, d) slice.

    NOTE(review): several details look unintended and should be confirmed
    against the intended LaPE formulation before this function is relied on:
    - the column split point is `sequence_length // 2`, although it slices the
      `d` axis;
    - `div_terms` has one entry per even index in [0, d) but is indexed by
      `layer`, so `num_layers` larger than that count raises IndexError;
    - positions are divided by `div_terms[layer]`, and every column of a range
      receives the same values.

    Parameters:
    - sequence_length (int): Length of the input sequence.
    - d (int): Dimension of the embeddings.
    - num_layers (int): Number of layers in the Transformer model.

    Returns:
    torch.Tensor: Layer-adaptive Positional Embeddings tensor of shape (sequence_length, d, num_layers).
    """
    # Generate a column tensor of positions from 0 to sequence_length - 1.
    positions = torch.arange(0, sequence_length).float().view(-1, 1)

    # One term per even index in [0, d); indexed by layer further below.
    div_terms = torch.exp(torch.arange(0, d, 2).float() * -(np.log(10000.0) / d))

    # Initialize the embeddings tensor with zeros.
    embeddings = torch.zeros(sequence_length, d, num_layers)

    # Split point applied to the column (d) axis below.
    # NOTE(review): derived from sequence_length rather than d - confirm.
    seq_len_div_2 = sequence_length // 2

    # For each layer, write the sine column into columns [0, seq_len_div_2) and
    # the cosine column into the remaining columns; the (sequence_length, 1)
    # right-hand side is broadcast across the selected columns.
    for layer in range(num_layers):
        embeddings[:, :, layer][:, 0:seq_len_div_2] = torch.sin(positions / div_terms[layer])
        embeddings[:, :, layer][:, seq_len_div_2:] = torch.cos(positions / div_terms[layer])

    return embeddings


##### 3.1.2.1: IMAGE POSITIONAL EMBEDDINGS.

##### 3.1.3: VISUALIZE POSITIONAL EMBEDDINGS

In [None]:

# Helper function to Visualize Positional Embeddings.
def visualize_positional_embeddings(embeddings):
    """
    Visualize the Positional Embeddings as one curve per embedding dimension.

    Parameters:
    - embeddings (torch.Tensor): Positional embeddings tensor; positions are
      expected on axis 0 and embedding dimensions on axis 1.

    Returns:
    None
    """
    # Detach and move to CPU first so that tensors living on an accelerator or
    # tracked by autograd can be converted to NumPy.
    values = embeddings.detach().cpu()

    # Get the number of dimensions (d) from the Embeddings Tensor.
    d = values.size(1)

    # Set the figure size for a larger image.
    plt.figure(figsize=(12, 6))

    # Plot each dimension separately.
    for i in range(d):
        plt.plot(values[:, i].numpy(), label=f'Dimension {i}')

    # Set plot labels.
    plt.xlabel('Position')
    plt.ylabel('Embedding Value')
    plt.title('Visualization of Positional Embeddings')

    # Place the legend on the right and diminish its size.
    plt.legend(loc='center left', bbox_to_anchor=(1, 0.5), fontsize='small')

    # Show the plot.
    plt.show()


In [None]:
# Helper function to Visualize Positional Embeddings as a Heatmap.
def visualize_positional_embeddings_heatmap(embeddings):
    """
    Visualize the Positional Embeddings as a Heatmap.

    Parameters:
    - embeddings (torch.Tensor): 2-D positional embeddings tensor with
      positions on axis 0 and embedding dimensions on axis 1.

    Returns:
    None
    """
    # Detach and move to CPU first so that tensors living on an accelerator or
    # tracked by autograd can be converted to NumPy. The table is transposed so
    # positions run along the x-axis and dimensions along the y-axis.
    values = embeddings.detach().cpu().T.numpy()

    # Set the figure size for a larger image.
    plt.figure(figsize=(12, 6))

    # Create a heatmap for the positional embeddings.
    sns.heatmap(values, cmap='viridis', cbar_kws={'label': 'Embedding Value'})

    # Set plot labels and title.
    plt.xlabel('Position')
    plt.ylabel('Dimension')
    plt.title('Visualization of Positional Embeddings (Heatmap)')

    # Show the plot.
    plt.show()


In [None]:
# SPE table for 65 positions and 32 dimensions, drawn as one curve per dimension.
positional_embeddings = get_positional_embeddings_SPE(65, 32)
visualize_positional_embeddings(positional_embeddings)

In [None]:
# Loop-based reference table for 65 positions and 32 dimensions, drawn as one curve per dimension.
positional_embeddings = get_positional_embeddings_Basic(65, 32)
visualize_positional_embeddings(positional_embeddings)

In [None]:
# LaPE table for 65 positions, 32 dimensions and 4 layers.
# NOTE(review): the LaPE helper returns a 3-D tensor while the plotting helper
# slices only axes 0 and 1 - confirm the resulting plot is the intended one.
positional_embeddings = get_positional_embeddings_LaPE(65, 32, 4)
visualize_positional_embeddings(positional_embeddings)

In [None]:
# SPE table for 65 positions and 32 dimensions, drawn as a heatmap.
positional_embeddings = get_positional_embeddings_SPE(65, 32)
visualize_positional_embeddings_heatmap(positional_embeddings)

In [None]:
# Loop-based reference table for 65 positions and 32 dimensions, drawn as a heatmap.
positional_embeddings = get_positional_embeddings_Basic(65, 32)
visualize_positional_embeddings_heatmap(positional_embeddings)

#### 3.2: MULTI-HEAD SELF-ATTENTION DEFINITION

In [None]:

class MyMHSA(nn.Module):
    """
    Multi-Head Self Attention (MHSA) Module.

    The token dimension is split into `n_heads` contiguous chunks; every head
    owns its own query/key/value linear maps acting on its chunk only, and the
    head outputs are concatenated back along the token dimension.
    """

    def __init__(self, d, n_heads=2):
        """
        Build the per-head linear mappings.

        Parameters:
        - d (int): Dimension of the input tokens.
        - n_heads (int): Number of attention heads; must divide `d`.

        Returns:
        None
        """
        super().__init__()
        self.d = d
        self.n_heads = n_heads

        assert d % n_heads == 0, f"Can't divide dimension {d} into {n_heads} heads"

        # Split the dimension into n_heads parts.
        d_head = d // n_heads

        # Linear mappings for Query(q), Key(k), and Value(v) for each head.
        self.q_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(d_head, d_head) for _ in range(self.n_heads)])

        self.d_head = d_head
        self.softmax = nn.Softmax(dim=-1)

        # Initialize weights.
        self.initialize_weights_msa()

    def forward(self, sequences):
        """
        Forward pass of the MHSA module.

        All sequences of the batch are processed at once with batched matrix
        products; only the (small) loop over heads remains in Python.

        Parameters:
        - sequences (torch.Tensor): Input token sequences with shape (N, seq_length, token_dim).

        Returns:
        torch.Tensor: Output tensor after MHSA with shape (N, seq_length, token_dim).
        """
        scale = self.d_head ** 0.5

        head_outputs = []
        for head, (q_mapping, k_mapping, v_mapping) in enumerate(
            zip(self.q_mappings, self.k_mappings, self.v_mappings)
        ):
            # Extract the chunk of the token dimension that belongs to this head.
            seq = sequences[..., head * self.d_head: (head + 1) * self.d_head]
            q, k, v = q_mapping(seq), k_mapping(seq), v_mapping(seq)

            # Scaled dot-product attention: softmax over the key axis.
            attention = self.softmax(q @ k.transpose(-2, -1) / scale)
            head_outputs.append(attention @ v)

        # Concatenate the results coming from the different Heads along the token dimension.
        return torch.cat(head_outputs, dim=-1)

    def initialize_weights_msa(self):
        """
        Initialize the weights of the q, k, v linear layers with Xavier-uniform values.

        Biases keep the default nn.Linear initialization.

        Parameters:
        None

        Returns:
        None
        """
        # Initialize weights for the q, k, v values, head by head.
        for q_mapping, k_mapping, v_mapping in zip(self.q_mappings, self.k_mappings, self.v_mappings):
            nn.init.xavier_uniform_(q_mapping.weight)
            nn.init.xavier_uniform_(k_mapping.weight)
            nn.init.xavier_uniform_(v_mapping.weight)


#### 3.3: ViT BLOCK DEFINITION

In [None]:
class MyViTBlock(nn.Module):
    """
    Transformer encoder block.

    Applies layer-normalized multi-head self-attention followed by a
    layer-normalized two-layer GELU MLP, each wrapped in a residual connection.
    """

    def __init__(self, hidden_d, n_heads, mlp_ratio=10):
        """
        Parameters:
        - hidden_d (int): Dimension of the tokens entering and leaving the block.
        - n_heads (int): Number of attention heads.
        - mlp_ratio (int): Width of the hidden MLP layer as a multiple of hidden_d.
        """
        super().__init__()

        self.hidden_d, self.n_heads = hidden_d, n_heads
        mlp_width = mlp_ratio * hidden_d

        # Attention sub-layer.
        self.norm1 = nn.LayerNorm(hidden_d)
        self.mhsa = MyMHSA(hidden_d, n_heads)

        # Feed-forward sub-layer.
        self.norm2 = nn.LayerNorm(hidden_d)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_d, mlp_width),
            nn.GELU(),
            nn.Linear(mlp_width, hidden_d),
        )

        # Initialize weights.
        self.initialize_weights_block()

    def forward(self, x):
        """
        Parameters:
        - x (torch.Tensor): Input tokens; the last dimension must be hidden_d.

        Returns:
        torch.Tensor: Tokens with the same shape as the input.
        """
        # Residual connection around the attention sub-layer.
        attended = self.mhsa(self.norm1(x))
        hidden = x + attended

        # Residual connection around the MLP sub-layer.
        return hidden + self.mlp(self.norm2(hidden))

    def initialize_weights_block(self):
        """Apply Xavier-uniform initialization to the weights of the MLP's linear layers."""
        linear_layers = (module for module in self.mlp if isinstance(module, nn.Linear))
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight)

#### 3.4: ViT MODEL DEFINITION

##### 3.4.1: TARGET DISTRIBUTIONAL PRIOR DEFINITION

In [None]:

def get_distribution_target(length=4, max=1, target_depth=3, buffer=0.02):
    """
    Build the Target Distributional Prior.

    A unit-variance Gaussian density centred on `target_depth` is evaluated at
    the integers 0 .. length - 1 and rescaled so that its first `target_depth`
    entries add up to 1 - buffer.

    Parameters:
    - length (int): Number of points of the distribution.
    - max (float): Not used by the computation.
    - target_depth (int): Centre of the Gaussian and number of leading entries
      used for the normalization.
    - buffer (float): Amount subtracted from 1 to obtain the normalized mass.

    Returns:
    numpy.ndarray: Target distributional prior of shape (length,).
    """
    # Integer depths 0 .. length - 1 at which the density is evaluated.
    depths = np.arange(length)

    # Gaussian density with mean target_depth and standard deviation 1.
    density = norm.pdf(depths, loc=target_depth, scale=1)

    # Mass of the entries located before target_depth.
    head_mass = sum(density[:target_depth])

    # Rescale so that this leading mass becomes exactly 1 - buffer.
    density *= (1. - buffer) / head_mass

    return density


##### 3.4.2: MYVIT CLASS DEFINITION

In [None]:

class MyViT(nn.Module):
    """
    Vision Transformer with adaptive token halting.

    Images are split into patches, linearly embedded, prepended with a
    learnable classification token and summed with positional embeddings.
    While the tokens traverse the transformer blocks, a per-token halting
    score is read from the first embedding channel; tokens whose accumulated
    score reaches the threshold are masked out for the remaining blocks.

    Besides returning the class log-probabilities, every forward pass stores
    two auxiliary loss terms on the module: `ponder_loss` and
    `distr_prior_loss`.
    """

    def __init__(self, chw, n_patches, n_blocks, hidden_d, n_heads, out_d):
        """
        Initialize the MyViT model.

        Parameters:
        - chw (tuple): Input shape (C, H, W).
        - n_patches (int): Number of patches along each spatial side
          (H and W must both be divisible by it).
        - n_blocks (int): Number of transformer blocks.
        - hidden_d (int): Dimension of the hidden layer.
        - n_heads (int): Number of attention heads.
        - out_d (int): Output dimension.
        """

        # Super Constructor.
        super(MyViT, self).__init__()

        # Attributes.
        self.chw = chw  # ( C , H , W )
        self.n_patches = n_patches
        self.n_blocks = n_blocks
        self.n_heads = n_heads
        self.hidden_d = hidden_d
        self.mlp_ratio = 100  # Width multiplier of the classification head only.

        # Halting Prior Distribution Loss and Target Distribution.
        self.ponder_loss = 0
        self.distr_prior_loss = 0
        self.kl_loss = nn.KLDivLoss(reduction='batchmean')
        # Registered as a (non-persistent) buffer so that it follows the module
        # across devices without changing the state_dict layout. One entry per
        # block, so that it always matches the per-layer halting distribution.
        self.register_buffer(
            'distr_target',
            torch.tensor(get_distribution_target(length=n_blocks), dtype=torch.float32),
            persistent=False,
        )

        # Input and Patches Sizes.
        assert chw[1] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        assert chw[2] % n_patches == 0, "Input shape not entirely divisible by number of patches"
        # Exact integer division is guaranteed by the two checks above.
        self.patch_size = (chw[1] // n_patches, chw[2] // n_patches)

        # 1) Linear Mapper.
        self.input_d = int(chw[0] * self.patch_size[0] * self.patch_size[1])
        self.linear_mapper = nn.Linear(self.input_d, self.hidden_d)

        # 2) Learnable Classification Token.
        self.class_token = nn.Parameter(torch.rand(1, self.hidden_d))

        # 3) Positional Embedding.
        self.register_buffer('positional_embeddings', get_positional_embeddings_SPE(n_patches ** 2 + 1, hidden_d), persistent=False)

        # 4) Transformer Encoder Blocks.
        self.blocks = nn.ModuleList([MyViTBlock(hidden_d, n_heads) for _ in range(n_blocks)])

        # 5) Classification MLP.
        self.mlp = nn.Sequential(
            nn.Linear(self.hidden_d, self.mlp_ratio * self.hidden_d),
            nn.GELU(),
            nn.Linear(self.mlp_ratio * self.hidden_d, out_d),
            nn.LogSoftmax(dim=-1)
        )

        # Initialize weights.
        self.initialize_weights()

    def forward(self, images):
        """
        Forward pass of the MyViT model.

        Side effects: overwrites `self.ponder_loss` and `self.distr_prior_loss`
        with the auxiliary losses of this batch.

        Parameters:
        - images (torch.Tensor): Input images tensor (batch first).

        Returns:
        torch.Tensor: Log-probabilities produced by the classification head
        from the accumulated classification token.
        """

        # Dividing Images into Patches.
        n = images.shape[0]
        patches = Make_Patches_from_Image(images, self.n_patches).to(self.positional_embeddings.device)

        # Running Linear Layer Tokenization.
        # Map the Vector corresponding to each patch to the Hidden Size Dimension.
        tokens = self.linear_mapper(patches)

        # Adding Classification Token to the Tokens.
        tokens = torch.cat((self.class_token.expand(n, 1, -1), tokens), dim=1)

        # Adding Positional Embedding.
        out = tokens + self.positional_embeddings.repeat(n, 1, 1)  # (batch, tokens, hidden)

        ### Halting Procedure ###
        # Read both sizes from the shape: indexing a sample (out[1]) fails for batch size 1.
        bs, total_token_count = out.shape[0], out.shape[1]
        # All bookkeeping tensors live on the same device as the tokens.
        device = out.device
        c_token = torch.zeros(bs, total_token_count, device=device)  # Accumulated halting score.
        r = torch.ones(bs, total_token_count, device=device)         # Remainder.
        rho = torch.zeros(bs, total_token_count, device=device)      # Ponder cost per token.
        mask = torch.ones(bs, total_token_count, device=device)      # 1 = token still active.

        # Halting Hyperparameters.
        gamma = 5
        beta = -10
        alpha_p = 5e-4
        alpha_d = 0.1
        eps = 0.01
        threshold = 1 - eps
        output = None  # Final output of Adaptive Vision Transformer Block.
        halting_score_layer = []  # List of Layer Halting Score Average.
        last_block_idx = len(self.blocks) - 1

        # Transformer Blocks.
        for i, block in enumerate(self.blocks):

            token_mask = mask.view(bs, total_token_count, 1)

            # Tokens halted in previous layers are zeroed. This is a regular
            # (autograd-tracked) multiplication: going through `.data` would
            # detach the graph and stop gradients from reaching the patch
            # embedding, the class token and the earlier blocks.
            out = out * token_mask

            # Pass data through each layer (block).
            out = block(out)  # (batch, tokens, hidden)

            # Compute Halting Scores from the first embedding channel of each token.
            h_score = torch.sigmoid(gamma * out[:, :, 0] + beta)  # (batch, tokens)
            h_token = h_score

            # Update list with mean of halting score just computed.
            # NOTE(review): `[1:]` slices the batch axis (drops sample 0), not the
            # class token; the mean is NaN for batch size 1. Kept unchanged to
            # preserve the existing numerics -- confirm whether `[:, 1:]` was meant.
            halting_score_layer.append(torch.mean(h_score[1:]))

            # Set all token halting scores to one if we reached the last layer (block),
            # so that every remaining token halts there.
            if i == last_block_idx:
                h_token = torch.ones_like(h_score)

            # Last Layer Protection: halted tokens contribute nothing.
            out = out * token_mask

            # Update Accumulator.
            c_token = c_token + h_token

            # Update rho: every still-active token pays one layer.
            rho = rho + mask

            # Case 1: Threshold reached in this Iteration (only for active tokens).
            reached_token = (c_token > threshold).float() * mask
            delta1 = out * r.view(bs, total_token_count, 1) * reached_token.view(bs, total_token_count, 1)
            rho = rho + r * reached_token

            # Case 2: Threshold not reached.
            not_reached_token = (c_token < threshold).float()
            r = r - (not_reached_token * h_token)
            delta2 = out * h_token.view(bs, total_token_count, 1) * not_reached_token.view(bs, total_token_count, 1)

            # Update the mask: only tokens below the threshold stay active.
            mask = not_reached_token

            if output is None:
                output = delta1 + delta2
            else:
                output = output + (delta1 + delta2)

        # Halting Prior Distribution.
        halting_score_distr = torch.stack(halting_score_layer)
        halting_score_distr = halting_score_distr / torch.sum(halting_score_distr)
        halting_score_distr = torch.clamp(halting_score_distr, 0.01, 0.99)

        # Kullback-Leibler Divergence.
        self.distr_prior_loss = alpha_d * self.kl_loss(halting_score_distr.log(), self.distr_target)

        # Ponder Loss.
        self.ponder_loss = alpha_p * torch.mean(rho)

        # Getting the Classification Token only.
        output = output[:, 0]  # (batch, hidden)

        return self.mlp(output)  # Map to output dimension (classification head).

    def initialize_weights(self):
        """
        Initialize the patch embedding, the classification token and the
        first linear layer of the classification head.
        """

        # Patch embedding and classification token.
        nn.init.xavier_uniform_(self.linear_mapper.weight)
        nn.init.normal_(self.class_token.data)

        # First linear layer of the Classification MLP.
        nn.init.xavier_uniform_(self.mlp[0].weight)
    

#### 3.5: GENERAL AViT MODEL DEFINITION

In [None]:

class AViT_Model(MyViT, pl.LightningModule):
    """
    LightningModule wrapper around MyViT: adds the loss, the evaluation
    metrics, the training/validation steps and the optimizer configuration.
    """

    def __init__(self, input_d, n_patches, n_blocks, hidden_d, n_heads, out_d):
        """
        Build the underlying MyViT and the loss/metric objects.

        Parameters:
        - input_d (tuple): Input shape, forwarded to MyViT as its `chw` argument.
        - n_patches (int): Number of patches.
        - n_blocks (int): Number of transformer blocks.
        - hidden_d (int): Dimension of the hidden layer.
        - n_heads (int): Number of attention heads.
        - out_d (int): Output dimension (number of classes).
        """
        super().__init__(input_d, n_patches, n_blocks, hidden_d, n_heads, out_d)

        # Classification loss.
        self.loss = CrossEntropyLoss()

        # Top-k accuracies.
        self.acc_top1 = Accuracy(task="multiclass", num_classes=out_d)
        self.acc_top3 = Accuracy(task="multiclass", num_classes=out_d, top_k=3)
        self.acc_top5 = Accuracy(task="multiclass", num_classes=out_d, top_k=5)
        self.acc_top10 = Accuracy(task="multiclass", num_classes=out_d, top_k=10)

        # Macro-averaged F1, precision and recall.
        self.f1score = MulticlassF1Score(num_classes=out_d, average='macro')
        self.precision = MulticlassPrecision(num_classes=out_d, average='macro')
        self.recall = MulticlassRecall(num_classes=out_d, average='macro')

        # Per-step result buffers for the "on_ ... _epoch_end" hooks.
        self.training_step_outputs = []
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def _step(self, batch):
        """
        Shared computation for the training, validation and test steps.

        Parameters:
        - batch (tuple): (inputs, targets) pair.

        Returns:
        tuple: (loss, top-1, top-3, top-5, top-10 accuracy, F1, precision, recall).
        The loss is the cross entropy plus the ponder and distributional-prior
        terms that the forward pass stored on the module.
        """
        inputs, targets = batch
        preds = self(inputs)

        # The two auxiliary terms are refreshed by the forward call above.
        loss = self.loss(preds, targets) + self.ponder_loss + self.distr_prior_loss

        metric_fns = (
            self.acc_top1,
            self.acc_top3,
            self.acc_top5,
            self.acc_top10,
            self.f1score,
            self.precision,
            self.recall,
        )
        metric_values = tuple(metric(preds, targets) for metric in metric_fns)

        return (loss, *metric_values)

    def training_step(self, batch, batch_idx):
        """
        Run one training step, record and log its loss and top-1 accuracy.

        Parameters:
        - batch (tuple): Input batch tuple.
        - batch_idx (int): Batch index.

        Returns:
        torch.Tensor: Training loss.
        """
        loss, acc, *_ = self._step(batch)

        # Keep plain Python floats of this step's results.
        self.training_step_outputs.append({
            "train_loss": loss.item(),
            "train_acc": acc.item(),
        })

        # Log both per step and aggregated per epoch.
        for name, value in (("train_loss", loss), ("train_acc", acc)):
            self.log(name, value, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """
        Run one validation step, record and log its loss and top-k accuracies.

        Parameters:
        - batch (tuple): Input batch tuple.
        - batch_idx (int): Batch index.

        Returns:
        None
        """
        loss, acc1, acc3, acc5, acc10, *_ = self._step(batch)

        step_values = {
            "val_loss": loss,
            "val_acc": acc1,
            "val_acc_3": acc3,
            "val_acc_5": acc5,
            "val_acc_10": acc10,
        }

        # Store plain floats for the end-of-epoch summary.
        self.validation_step_outputs.append(
            {name: value.item() for name, value in step_values.items()}
        )

        # Log every value, aggregated over the epoch.
        for name, value in step_values.items():
            self.log(name, value, on_epoch=True, prog_bar=True, logger=True)

    def on_validation_epoch_end(self):
        """
        Average the stored validation step results, log and print them,
        then empty the buffer.

        Returns:
        None
        """

        def epoch_mean(key):
            # Unweighted mean of the per-step values stored under `key`.
            return torch.tensor([step[key] for step in self.validation_step_outputs]).mean()

        loss_tot = epoch_mean("val_loss")
        acc_tot = epoch_mean("val_acc")
        acc_tot_3 = epoch_mean("val_acc_3")
        acc_tot_5 = epoch_mean("val_acc_5")
        acc_tot_10 = epoch_mean("val_acc_10")

        # Log the epoch means.
        epoch_means = (
            ("val_loss", loss_tot),
            ("val_acc", acc_tot),
            ("val_acc_3", acc_tot_3),
            ("val_acc_5", acc_tot_5),
            ("val_acc_10", acc_tot_10),
        )
        for name, value in epoch_means:
            self.log(name, value)

        # One-line summary on stdout.
        print(
            f'Epoch {self.current_epoch} Validation Loss -> {loss_tot}'
            f'      Validation Accuracy -> {acc_tot}'
            f'      Validation Accuracy Top-3 -> {acc_tot_3}'
            f'      Validation Accuracy Top-5-> {acc_tot_5}'
            f'      Validation Accuracy Top-10-> {acc_tot_10}'
        )

        # Free the buffer for the next epoch.
        self.validation_step_outputs.clear()


    def configure_optimizers(self):
        """
        Create the optimizer used for training.

        Returns:
        torch.optim.Optimizer: Adam with lr=1.5e-3 and weight_decay=1.5e-4.
        """
        # No learning-rate scheduler is attached; a CosineAnnealingLR
        # (T_max=5, eta_min=1.5e-6) variant is currently disabled.
        return optim.Adam(self.parameters(), lr=1.5e-3, weight_decay=1.5e-4)


### 4: MODEL TRAINING

#### 4.1: CALLBACKS DEFINITION

In [None]:
# Checkpointing: keep the single checkpoint with the highest "val_acc",
# plus the most recent one.
my_checkpoint_call = ModelCheckpoint(
    monitor="val_acc",
    mode="max",
    save_top_k=1,
    save_last=True,
    dirpath="/kaggle/working/checkpoints/",
    filename="Best_Model",
)

# Learning-rate monitoring, recorded once per epoch.
my_lr_monitor_call = LearningRateMonitor(logging_interval="epoch")

# Early stopping on "val_loss" (lower is better).
my_early_stopping_call = pl.callbacks.EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=30,
    min_delta=0.001,
)

# Progress bar, refreshed every 10 batches.
my_progress_bar_call = TQDMProgressBar(refresh_rate=10)

# Loggers (not callbacks): TensorBoard and CSV, both writing under the same directory.
tb_logger = TensorBoardLogger(save_dir="/kaggle/working/logs", name="AViT")
csv_logger = CSVLogger(save_dir="/kaggle/working/logs", name="AViT")


#### 4.2: MODEL INSTANTIATION & TRAINING

In [None]:

# Instantiate the Adaptive Vision Transformer Model.
model = AViT_Model((3, 64, 64),
                   n_patches=8,
                   n_blocks=4,
                   hidden_d=32,
                   n_heads=4,
                   out_d=200)

# Data module pointing at the Tiny ImageNet training and validation folders.
datamodule = AViT_DataModule(train_data_dir="/kaggle/input/tiny-imagenet/tiny-imagenet-200/train/",
                             val_data_dir="/kaggle/input/tiny-imagenet/tiny-imagenet-200/val/images/",
                             batch_size=512)

# Setup the Dataloaders.
# The call must target the object created just above (`datamodule`).
datamodule.setup()

# Create a PyTorch Lightning Trainer.
trainer = pl.Trainer(
    max_epochs=10,
    accelerator="auto",
    devices="auto",
    log_every_n_steps=1,
    logger=tb_logger,
    callbacks=[my_progress_bar_call,
               my_checkpoint_call,
               my_lr_monitor_call,
               my_early_stopping_call,
               ]
)


In [None]:
# Train the model: run the fit loop of `trainer` on `model`, using `datamodule` as the data source.
trainer.fit(model, datamodule)

### 5: BEST MODEL EXTRACTION

In [None]:
# Location of the top-scoring checkpoint tracked by the checkpoint callback.
best_model_path = my_checkpoint_call.best_model_path

# Rebuild the model from that checkpoint. The constructor arguments are passed
# explicitly and must match the ones used for training.
best_model = AViT_Model.load_from_checkpoint(
    best_model_path,
    **{
        "input_d": (3, 64, 64),
        "n_patches": 8,
        "n_blocks": 4,
        "hidden_d": 32,
        "n_heads": 4,
        "out_d": 200,
    },
)

# Score of the best checkpoint (the monitored validation accuracy), as a Python float.
best_model_accuracy = trainer.checkpoint_callback.best_model_score.item()
print(f"Best Model Accuracy: {best_model_accuracy}")

### 6: SAVING THE BEST MODEL

In [None]:
import shutil

# 1) Plain PyTorch weights file (.pth): only the model's state dict.
model_path = f"/kaggle/working/best_model_acc_{best_model_accuracy:.5f}.pth"
torch.save(best_model.state_dict(), model_path)

# 2) PyTorch Lightning checkpoint (.ckpt): Model State Dictionary + Training State + Optimizer State.
# Writing the state dict with torch.save would only duplicate the .pth file under
# another extension; the complete checkpoint already exists at `best_model_path`,
# so it is copied to its final name instead.
ckpt_path = f"/kaggle/working/best_model_acc_{best_model_accuracy:.5f}.ckpt"
shutil.copyfile(best_model_path, ckpt_path)

### 7: TRAINING FROM A SAVED CHECKPOINT

# Load the saved weights (a state dict written with torch.save).
# map_location="cpu" lets the file be read on machines without the device it was
# saved from; load_state_dict copies the values into the model's own parameters.
checkpoint = torch.load("/kaggle/input/adavit-model-checkpoints/best_model_acc.pth", map_location="cpu")

# Plain-dict copy of the loaded mapping (parameter name -> tensor).
state_dict = dict(checkpoint)

# Instantiate the Loaded Model (same schema as the checkpoint).
loaded_model = AViT_Model((3, 64, 64),
                          n_patches=8,
                          n_blocks=4,
                          hidden_d=32,
                          n_heads=4,
                          out_d=200)

# Now, load the state_dict into the Model.
loaded_model.load_state_dict(state_dict)


# Continue training from the loaded weights with a fresh Trainer. Only the model
# weights are restored: the file holds no optimizer or epoch state.
# NOTE(review): the callback objects below were already used by `trainer`;
# confirm that carrying over their internal state is intended.
resume_trainer = pl.Trainer(
    max_epochs=1,
    accelerator="auto",
    devices="auto",
    log_every_n_steps=1,
    logger=tb_logger,
    callbacks=[my_progress_bar_call,
               my_checkpoint_call,
               my_lr_monitor_call,
               my_early_stopping_call,
               ]
)

# Train the Model that was just loaded.
resume_trainer.fit(loaded_model, datamodule)