## Import the Libraries

In [None]:
import os
import sys
import math
import yaml
import torch
import joblib
import warnings
import torch.nn as nn

## Utility Functions

In [None]:
def config_files():
    """Load the notebook configuration file.

    Reads ``../../notebook_config.yml`` (resolved against the current
    working directory) and parses it with ``yaml.safe_load``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file's contents.
    """
    config_path = "../../notebook_config.yml"
    with open(config_path, mode="r") as stream:
        parsed = yaml.safe_load(stream)
    return parsed


def dump_files(value=None, filename=None):
    """Serialize ``value`` to ``filename`` using ``joblib.dump``.

    Args:
        value: Object to persist. ``None`` is accepted as long as a
            ``filename`` is supplied.
        filename: Destination passed straight through to ``joblib.dump``.

    Raises:
        ValueError: If ``filename`` is ``None``; there is nowhere to write
            to, so the call is rejected before reaching ``joblib``.
    """
    if filename is None:
        if value is None:
            # Keep the original message for the "nothing supplied" case.
            raise ValueError("Either values or filename must be provided".capitalize())
        # Previously this case slipped past the guard and was only rejected
        # later, inside joblib, with a less helpful message.
        raise ValueError("Filename must be provided".capitalize())

    joblib.dump(value=value, filename=filename)


def load_files(filename: str = None):
    """Load an object previously stored on disk via ``joblib.load``.

    Args:
        filename: Path handed to ``joblib.load``.

    Returns:
        The object returned by ``joblib.load``.

    Raises:
        ValueError: If ``filename`` is ``None``.
    """
    if filename is not None:
        return joblib.load(filename=filename)

    raise ValueError("Filename must be provided".capitalize())


def device_init(device: str = "cuda"):
    """Resolve a device name to a ``torch.device``, falling back to CPU.

    Args:
        device: ``"cuda"`` or ``"mps"`` to request that backend; any other
            value selects the CPU.

    Returns:
        ``torch.device`` for the requested backend when the matching
        ``is_available()`` check succeeds, otherwise ``torch.device("cpu")``.
    """
    # Only the requested backend's availability is probed.
    if device == "cuda":
        backend_ready = torch.cuda.is_available()
    elif device == "mps":
        backend_ready = torch.backends.mps.is_available()
    else:
        backend_ready = False

    return torch.device(device if backend_ready else "cpu")


def weight_init(m):
    """Initialise convolution and batch-norm parameters in place.

    Modules are matched by class name, so the function can be passed to
    ``nn.Module.apply``:

    * class name contains ``"Conv"``: ``weight`` ~ Normal(0.0, 0.02)
    * class name contains ``"BatchNorm"``: ``weight`` ~ Normal(1.0, 0.02)
      and ``bias`` set to 0

    Any other module is left untouched.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__

    # Name matching also hits user-defined containers (e.g. "ConvBlock") and
    # BatchNorm layers built with ``affine=False``; neither owns a tensor
    # ``weight``/``bias``, so look the attributes up defensively instead of
    # crashing with AttributeError.
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if isinstance(weight, torch.Tensor):
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if isinstance(weight, torch.Tensor):
            nn.init.normal_(weight.data, 1.0, 0.02)
        if isinstance(bias, torch.Tensor):
            nn.init.constant_(bias.data, 0)

## Patch Embedding

In [None]:
class PositionalEncoding(nn.Module):
    """Learnable additive embedding applied to a token sequence.

    Holds one trainable tensor of shape ``(1, 1, dimension)``, initialised
    from a standard normal distribution, and adds it to the input via
    broadcasting.

    NOTE(review): because the parameter has singleton batch *and* sequence
    axes, the same vector is added to every token, so it does not encode
    position. Making it per-position would require the sequence length and
    would change the parameter shape, so it is left as-is here.

    Args:
        dimension: Size of the last (feature) axis of the inputs.

    Raises:
        ValueError: If ``dimension`` is not positive.
    """

    def __init__(self, dimension: int = 512):
        super(PositionalEncoding, self).__init__()
        if dimension <= 0:
            raise ValueError(f"dimension must be a positive integer, got {dimension}")
        self.dimension = dimension

        # The leading axes are plain 1s (previously spelled
        # ``dimension // dimension``) so the parameter broadcasts over the
        # batch and sequence axes of the input.
        self.encoding = nn.Parameter(
            torch.randn(1, 1, self.dimension), requires_grad=True
        )

    def forward(self, x: torch.Tensor):
        """Return ``x + encoding``.

        Args:
            x: Tensor whose shape broadcasts against ``(1, 1, dimension)``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")
        return x + self.encoding


class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each one.

    A strided ``nn.Conv2d`` (kernel = stride = ``patch_size``, no padding,
    no bias) projects every patch to ``embedding_dimension`` features. The
    resulting feature map is rearranged into a token sequence and passed
    through ``PositionalEncoding``.

    Args:
        image_channels: Number of channels of the input images.
        image_size: Side length of the (square) input images; only used to
            compute ``total_patches``.
        patch_size: Side length of each square patch.
        embedding_dimension: Feature size of every token. ``None`` selects
            ``image_channels * patch_size * patch_size`` and emits a warning.
    """

    def __init__(
        self,
        image_channels: int = 3,
        image_size: int = 224,
        patch_size: int = 16,
        embedding_dimension: int = 512,
    ):
        super(PatchEmbedding, self).__init__()
        self.image_channels = image_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.embedding_dimension = embedding_dimension

        self.total_patches = (self.image_size // self.patch_size) ** 2

        if self.embedding_dimension is None:
            warnings.warn(
                "Embedding dimension not specified. Using the default value calculated as: image_channels × patch_size × patch_size."
            )
            # Product, as the warning states (the previous expression used
            # ``**`` and produced an astronomically large value).
            self.embedding_dimension = (
                self.image_channels * self.patch_size * self.patch_size
            )

        self.projection = nn.Conv2d(
            in_channels=self.image_channels,
            out_channels=self.embedding_dimension,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            # No padding: with kernel == stride the windows then tile the
            # image exactly, giving image_size // patch_size patches per
            # side, in agreement with ``total_patches``. Padding of 1
            # shifted every patch by a pixel and could add an extra
            # row/column of tokens for some image sizes.
            padding=0,
            bias=False,
        )
        self.encoding = PositionalEncoding(dimension=self.embedding_dimension)

    def forward(self, x: torch.Tensor):
        """Embed a batch of images.

        Args:
            x: Image batch accepted by the projection convolution, i.e.
                ``(batch, image_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, num_patches, embedding_dimension)``
            with patches ordered row by row.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        x = self.projection(x)  # (batch, embed, rows, cols)
        # Collapse the spatial grid, then move the feature axis last. A bare
        # ``view`` to (batch, rows * cols, embed) would only reinterpret the
        # memory and mix channels with positions. ``contiguous`` keeps the
        # result usable with ``.view`` downstream.
        x = x.flatten(2).transpose(1, 2).contiguous()
        return self.encoding(x)


if __name__ == "__main__":
    # Smoke test: a batch of 64 RGB 224x224 images should come out as
    # (64, number_of_patches, embedding_dimension).
    image_channels, image_size, patch_size = 3, 224, 16
    embedding_dimension = 768

    patchEmbedding = PatchEmbedding(
        image_channels=image_channels,
        image_size=image_size,
        patch_size=patch_size,
        embedding_dimension=embedding_dimension,
    )

    images = torch.randn((64, 3, 224, 224))

    expected_size = (64, (image_size // patch_size) ** 2, 768)
    assert (
        patchEmbedding(images).size() == expected_size
    ), "Patch Embedding is not working properly".capitalize()

## Multi-Head Attention Layer

In [None]:
def scaled_dot_product(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    """Compute scaled dot-product attention.

    Evaluates ``softmax(query @ key^T / sqrt(d_k)) @ value`` where ``d_k`` is
    the size of the last axis of ``query``. The last two axes are treated as
    ``(sequence, features)``; any leading axes are handled by
    ``torch.matmul`` batching.

    Args:
        query: Tensor of shape ``(..., seq_q, d_k)``.
        key: Tensor of shape ``(..., seq_k, d_k)``.
        value: Tensor of shape ``(..., seq_k, d_v)``.

    Returns:
        Tensor of shape ``(..., seq_q, d_v)``.

    Raises:
        TypeError: If any of the three inputs is not a ``torch.Tensor``.
    """
    # Reject the call when *any* input is not a tensor. The earlier
    # ``not A and B and C`` form only fired when query alone was wrong.
    if not all(isinstance(t, torch.Tensor) for t in (query, key, value)):
        raise TypeError("All inputs must be torch.Tensor")

    # Scale by the feature size. Reading it from the transposed key would
    # pick up the sequence length instead.
    d_k = query.size(-1)

    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value)


class MultiHeadAttentionLayer(nn.Module):
    """Multi-head self-attention with a fused query/key/value projection.

    One bias-free ``nn.Linear`` maps the input to ``3 * dimension`` features,
    which are split into query, key and value, divided across ``nheads``
    heads, attended with ``scaled_dot_product`` and merged back to
    ``dimension`` features. There is no output projection.

    Args:
        nheads: Number of attention heads; must divide ``dimension``.
        dimension: Feature size of the input and output tokens.
    """

    def __init__(self, nheads: int = 6, dimension: int = 768):
        super(MultiHeadAttentionLayer, self).__init__()
        self.nheads = nheads
        self.dimension = dimension

        assert (
            self.dimension % self.nheads == 0
        ), "Dimension must be divisible by number of heads".capitalize()

        # The previous unconditional "Invalid number of dimensions" warning
        # fired on every construction, valid or not, so it was dropped.

        self.QKV = nn.Linear(
            in_features=self.dimension, out_features=3 * self.dimension, bias=False
        )

    def forward(self, x: torch.Tensor):
        """Apply self-attention to a token sequence.

        Args:
            x: Tensor of shape ``(batch, tokens, dimension)``.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        batch, tokens = x.size(0), x.size(1)
        head_dim = self.dimension // self.nheads

        query, key, value = torch.chunk(input=self.QKV(x), chunks=3, dim=-1)

        # (batch, tokens, dimension) -> (batch, heads, tokens, head_dim)
        query = query.view(batch, tokens, self.nheads, head_dim).permute(0, 2, 1, 3)
        key = key.view(batch, tokens, self.nheads, head_dim).permute(0, 2, 1, 3)
        value = value.view(batch, tokens, self.nheads, head_dim).permute(0, 2, 1, 3)

        attention = scaled_dot_product(query=query, key=key, value=value)

        # Undo the head split: bring the token axis back in front of the
        # head axis *before* flattening. A bare ``view`` to
        # (batch, tokens, heads, head_dim) would only relabel the axes and
        # mix data from different tokens and heads.
        attention = attention.permute(0, 2, 1, 3).reshape(
            batch, tokens, self.nheads * head_dim
        )

        assert (
            attention.size() == x.size()
        ), "Attention output must have the same size as input".capitalize()

        return attention


if __name__ == "__main__":
    # Smoke test: self-attention must return a tensor shaped like its input.
    nheads, dimension = 8, 256

    images = torch.randn((1, 196, 256))
    multihead_attention = MultiHeadAttentionLayer(nheads=nheads, dimension=dimension)

    assert (
        multihead_attention(x=images).size() == images.size()
    ), "MultiHeadAttention is not working properly".capitalize()