### Lib

In [None]:
import os
import sys
import cv2
import math
import yaml
import torch
import joblib
import zipfile
import argparse
import torch.nn as nn
import pandas as pd
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

## Utility files

In [None]:
def dump(value=None, filename=None):
    """Persist ``value`` to ``filename`` using joblib.

    Raises:
        ValueError: when either ``value`` or ``filename`` is None.
    """
    missing_argument = value is None or filename is None
    if missing_argument:
        raise ValueError("Both value and filename must be provided".capitalize())

    joblib.dump(value=value, filename=filename)


def load(filename=None):
    """Return the object joblib previously stored at ``filename``.

    Raises:
        ValueError: when ``filename`` is None.
    """
    if filename is None:
        raise ValueError("Filename must be provided".capitalize())

    return joblib.load(filename=filename)


def device_init(device="cuda"):
    """Return a ``torch.device`` for the requested backend, or CPU as a fallback.

    "cuda" and "mps" are honoured only when torch reports that backend as
    available; any other name (or an unavailable backend) yields the CPU device.
    """
    # Lambdas defer attribute access so only the requested backend is probed.
    backend_probes = {
        "cuda": lambda: torch.cuda.is_available(),
        "mps": lambda: torch.backends.mps.is_available(),
    }

    probe = backend_probes.get(device)
    if probe is None or not probe():
        return torch.device("cpu")

    return torch.device(device)


def config(filename="../../config.yml"):
    """Load and return the project YAML configuration.

    Args:
        filename: Path of the YAML file. The default is the original hard-coded
            location, which is resolved relative to the current working
            directory (not to this file); pass an explicit path when running
            from elsewhere.

    Raises:
        OSError (e.g. FileNotFoundError): if the file cannot be opened.
    """
    # safe_load only builds plain Python types, never arbitrary objects.
    # YAML is UTF-8 by spec, so do not depend on the platform locale encoding.
    with open(filename, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)

## Dataloader

In [None]:
class Loader:
    """Build, persist and inspect train/validation DataLoaders for an image-folder dataset.

    Paths and labels come from the project YAML config (``config()``): the
    ``path`` section supplies RAW_DATA_PATH, PROCESSED_DATA_PATH and FILES_PATH,
    and ``dataloader.labels`` lists the class sub-folder names; a folder's
    position in that list is its integer label.

    Args:
        image_path: Zip archive extracted by :meth:`unzip_folder`.
        image_channels: Stored only; the transform pipeline always normalises
            three channels.
        image_size: Side length images are resized to.
        batch_size: Batch size of both DataLoaders.
        split_size: Fraction of samples held out for validation.
    """

    def __init__(
        self,
        image_path: str = None,
        image_channels: int = 3,
        image_size: int = 128,
        batch_size: int = 64,
        split_size: float = 0.25,
    ):
        self.image_path = image_path
        self.image_channels = image_channels
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size

        # Parallel lists: transformed image tensors and their integer labels.
        self.X = list()
        self.Y = list()

        try:
            self.CONFIG = config()
        except Exception as e:
            # NOTE(review): on failure the path attributes below stay unset, so
            # later methods raise AttributeError.
            print("An error occurred while loading config file: ", e)
        else:
            self.RAW_DATA_PATH = self.CONFIG["path"]["RAW_DATA_PATH"]
            self.PROCESSED_DATA_PATH = self.CONFIG["path"]["PROCESSED_DATA_PATH"]

    def dataset_split(self, X: list, y: list):
        """Split ``X``/``y`` into train and test parts with a fixed seed (42).

        Returns:
            dict with keys "X_train", "X_test", "y_train", "y_test".

        Raises:
            TypeError: if ``X`` or ``y`` is not a list.
        """
        if not (isinstance(X, list) and isinstance(y, list)):
            raise TypeError("X and y must be list".capitalize())

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.split_size, random_state=42
        )

        return {
            "X_train": X_train,
            "X_test": X_test,
            "y_train": y_train,
            "y_test": y_test,
        }

    def transforms(self):
        """Return the per-image pipeline: resize, to tensor, center-crop, normalise.

        With mean = std = 0.5 per channel, ToTensor's [0, 1] output is mapped
        to [-1, 1]. The CenterCrop is a no-op after resizing to the same size.
        """
        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def unzip_folder(self):
        """Extract the ``image_path`` zip archive into RAW_DATA_PATH.

        Raises:
            FileNotFoundError: if RAW_DATA_PATH does not exist.
        """
        if not os.path.exists(self.RAW_DATA_PATH):
            raise FileNotFoundError("RAW Path not found".capitalize())

        with zipfile.ZipFile(self.image_path, "r") as archive:
            archive.extractall(self.RAW_DATA_PATH)

    def extract_features(self):
        """Load every image under ``RAW_DATA_PATH/dataset/<label>/`` and split the set.

        ``self.X``/``self.Y`` are rebuilt from scratch on every call, so calling
        this more than once (directly and again through
        :meth:`create_dataloader`) does not duplicate samples.

        Returns:
            The dict produced by :meth:`dataset_split`.
        """
        self.directory = os.path.join(self.RAW_DATA_PATH, "dataset")
        self.categories = config()["dataloader"]["labels"]

        # Reset: appending to the previous call's lists duplicated every sample
        # and let identical images land in both the train and test splits.
        self.X = list()
        self.Y = list()

        transform = self.transforms()  # build the pipeline once, not per image

        for category in tqdm(self.categories):
            category_path = os.path.join(self.directory, category)
            label = self.categories.index(category)

            for filename in os.listdir(category_path):
                image_file = os.path.join(category_path, filename)

                # Case-insensitive so e.g. "IMG_01.JPG" is not skipped.
                if not filename.lower().endswith((".jpg", ".png", ".jpeg")):
                    print("Image not found".capitalize())
                    continue

                image = cv2.imread(image_file)

                # cv2.imread returns None (it does not raise) for unreadable or
                # corrupt files; skip them instead of crashing in cvtColor.
                if image is None:
                    print("Unable to read image: {}".format(image_file))
                    continue

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

                self.X.append(transform(Image.fromarray(image)))
                self.Y.append(label)

        assert len(self.X) == len(
            self.Y
        ), "Image size and Label size not equal".capitalize()

        try:
            dataset = self.dataset_split(X=self.X, y=self.Y)
        except Exception as e:
            print("An error occured: ", e)
            # Re-raise: returning None only deferred the failure to a confusing
            # "'NoneType' object is not subscriptable" in create_dataloader.
            raise

        return dataset

    def create_dataloader(self):
        """Build train/validation DataLoaders and pickle them into PROCESSED_DATA_PATH."""
        dataset = self.extract_features()

        train_dataloader = DataLoader(
            dataset=list(zip(dataset["X_train"], dataset["y_train"])),
            batch_size=self.batch_size,
            shuffle=True,
        )

        valid_dataloader = DataLoader(
            dataset=list(zip(dataset["X_test"], dataset["y_test"])),
            batch_size=self.batch_size,
            shuffle=True,
        )

        # joblib does not create missing parent directories.
        os.makedirs(self.PROCESSED_DATA_PATH, exist_ok=True)

        for value, filename in [
            (train_dataloader, "train_dataloader"),
            (valid_dataloader, "valid_dataloader"),
        ]:
            dump(
                value=value,
                filename=os.path.join(
                    self.PROCESSED_DATA_PATH, "{}.pkl".format(filename)
                ),
            )

        print("Dataloader is saved in the folder {}".format(self.PROCESSED_DATA_PATH))

    @staticmethod
    def display_images():
        """Plot the first training batch with its labels and save it as FILES_PATH/image.png."""
        CONFIG = config()
        FILES_PATH = CONFIG["path"]["FILES_PATH"]
        PROCESSED_PATH = CONFIG["path"]["PROCESSED_DATA_PATH"]
        labels = CONFIG["dataloader"]["labels"]
        batch_size = CONFIG["dataloader"]["batch_size"]

        os.makedirs(FILES_PATH, exist_ok=True)

        dataloader = load(filename=os.path.join(PROCESSED_PATH, "train_dataloader.pkl"))
        data, label = next(iter(dataloader))

        # Clamp to >= 1: a batch smaller than sqrt(batch_size) used to produce
        # zero rows and then a ZeroDivisionError for the column count.
        number_of_rows = max(1, len(data) // max(1, int(math.sqrt(batch_size))))
        number_of_columns = len(data) // number_of_rows

        plt.figure(figsize=(20, 20))

        for index, image in enumerate(data):
            X = image.squeeze().permute(1, 2, 0).detach().cpu().numpy()

            # Rescale to [0, 1] for imshow; guard against constant images.
            value_range = X.max() - X.min()
            X = (X - X.min()) / (value_range if value_range > 0 else 1)

            plt.subplot(2 * number_of_rows, 2 * number_of_columns, 2 * index + 1)
            plt.imshow(X)
            plt.title(labels[int(label[index])].capitalize())
            plt.axis("off")

        plt.tight_layout()
        plt.savefig(os.path.join(FILES_PATH, "image.png"))
        plt.show()
        plt.close()

        print("Images are saved in {}".format(FILES_PATH))

    @staticmethod
    def dataset_details():
        """Write sample counts and batch tensor shape to FILES_PATH/dataset_details.csv."""
        CONFIG = config()
        FILES_PATH = CONFIG["path"]["FILES_PATH"]
        PROCESSED_PATH = CONFIG["path"]["PROCESSED_DATA_PATH"]

        os.makedirs(FILES_PATH, exist_ok=True)

        train_dataloader = load(
            filename=os.path.join(PROCESSED_PATH, "train_dataloader.pkl")
        )
        valid_dataloader = load(
            filename=os.path.join(PROCESSED_PATH, "valid_dataloader.pkl")
        )

        train_data, _ = next(iter(train_dataloader))

        # One pass per loader: every batch holds as many labels as images, so
        # the image count is also the label count (the old code walked each
        # loader three times to compute the same numbers).
        train_size = sum(images.size(0) for images, _ in train_dataloader)
        valid_size = sum(images.size(0) for images, _ in valid_dataloader)

        csv_path = os.path.join(FILES_PATH, "dataset_details.csv")

        pd.DataFrame(
            {
                "train_data": [train_size],
                "train_labels": [train_size],
                "valid_data": [valid_size],
                "valid_labels": [valid_size],
                "total_data": [train_size + valid_size],
                "batch_size": [train_data.size(0)],
                "channels": [train_data.size(1)],
                "height": [train_data.size(2)],
                "width": [train_data.size(3)],
            },
            index=["Dataset Details"],
        ).to_csv(csv_path)

        # The path is printed verbatim; .capitalize() on it used to lower-case
        # every character after the first.
        print("Dataset details saved to {}".format(csv_path))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Dataloader for ViT".title())

    # Parse the YAML config once instead of once per default value.
    dataloader_config = config()["dataloader"]

    parser.add_argument(
        "--image_path",
        type=str,
        default=dataloader_config["image_path"],
        help="Path to the image dataset".capitalize(),
    )
    parser.add_argument(
        "--channels",
        type=int,
        default=dataloader_config["channels"],
        help="Number of channels in the image".capitalize(),
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=dataloader_config["image_size"],
        help="Size of the image".capitalize(),
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=dataloader_config["batch_size"],
        help="Batch size for the dataloader".capitalize(),
    )
    parser.add_argument(
        "--split_size",
        type=float,
        default=dataloader_config["split_size"],
        help="Split size for the dataloader".capitalize(),
    )
    args = parser.parse_args()

    loader = Loader(
        image_path=args.image_path,
        image_channels=args.channels,
        image_size=args.image_size,
        batch_size=args.batch_size,
        split_size=args.split_size,
    )
    # Uncomment to extract the raw zip archive first:
    # loader.unzip_folder()

    # create_dataloader() runs extract_features() itself; an extra call here
    # only repeated the whole image-loading pass.
    loader.create_dataloader()

    # Reporting is best-effort: log a failure and continue with the next report.
    for report in (Loader.display_images, Loader.dataset_details):
        try:
            report()
        except Exception as e:
            print("An error occurred", e)


# Transformer Encoder

The encoder is assembled from the following components, implemented in order below:

1. Positional Encoding
2. Scaled Dot-Product Attention
3. Multi-Head Attention Layer
4. Layer Normalization
5. Pointwise Feed-Forward Neural Network
6. Encoder Block
7. Transformer Encoder

## Positional Encoding

In [None]:
class PositionalEncoding(nn.Module):
    """Learnable positional-embedding table initialised with sinusoids.

    Entry ``(pos, j)`` of the table starts as::

        sin(pos / constant ** (2 * (j // 2) / dimension))   for even j
        cos(pos / constant ** (2 * (j // 2) / dimension))   for odd j

    so each (sin, cos) pair shares one frequency, as in "Attention Is All You
    Need". The table is an ``nn.Parameter`` and is updated during training.

    Args:
        sequence_length: Maximum number of positions the table covers.
        dimension: Size of each position's embedding.
        constant: Base of the geometric progression of wavelengths.
    """

    def __init__(
        self, sequence_length: int = 200, dimension: int = 512, constant: int = 10000
    ):
        super(PositionalEncoding, self).__init__()

        self.sequence_length = sequence_length
        self.dimension = dimension
        self.constant = constant

        # Compute in float64 (like the former math.sin/cos loop), then store in
        # the default dtype.
        positions = torch.arange(self.sequence_length, dtype=torch.float64).unsqueeze(1)
        indices = torch.arange(self.dimension)

        # Both members of the pair (2i, 2i + 1) use exponent 2i / dimension.
        # The former loop used 2 * index / dimension (twice the correct value),
        # which left most of the upper half of the dimensions nearly constant
        # (sin ~ 0, cos ~ 1) across positions.
        exponents = (2 * (indices // 2)).to(torch.float64) / self.dimension
        angles = positions / torch.pow(float(self.constant), exponents)

        positional_encoding = torch.where(
            indices % 2 == 0, torch.sin(angles), torch.cos(angles)
        ).to(torch.get_default_dtype())

        self.positional_encoding = nn.Parameter(positional_encoding, requires_grad=True)

    def forward(self, x: torch.Tensor):
        """Return the embeddings for the first ``x.size(1)`` positions.

        Only ``x.size(1)`` is read from ``x``. The result has shape
        ``(1, x.size(1), dimension)`` and is NOT added to ``x`` here; the
        caller does that.

        Raises:
            TypeError: if ``x`` is not a torch.Tensor.
            ValueError: if ``x.size(1)`` exceeds ``sequence_length`` (the slice
                would otherwise silently come back shorter than the input).
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        if x.size(1) > self.sequence_length:
            raise ValueError(
                "Sequence length {} exceeds the maximum of {}".format(
                    x.size(1), self.sequence_length
                )
            )

        return self.positional_encoding.unsqueeze(0)[:, : x.size(1), :]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Positional Encoding for the transformer".title()
    )

    # (flag, default, help text) for each integer option.
    integer_options = (
        ("--sequence_length", 200, "Length of the sequence"),
        ("--dimension", config()["ViT"]["dimension"], "Dimension of the positional encoding"),
        ("--constant", 10000, "Constant used in the positional encoding"),
    )
    for option_flag, option_default, help_text in integer_options:
        parser.add_argument(
            option_flag, type=int, default=option_default, help=help_text.capitalize()
        )

    args = parser.parse_args()

    positional_encoding = PositionalEncoding(
        sequence_length=args.sequence_length,
        dimension=args.dimension,
        constant=args.constant,
    )

    # The module returns one (1, seq, dim) table regardless of batch size.
    expected_shape = (1, args.sequence_length, args.dimension)
    assert (
        positional_encoding(
            torch.randn((40, args.sequence_length, args.dimension))
        ).size()
        == expected_shape
    )


## Scaled Dot Product Attention

In [None]:
def scaled_dot_product_attention(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask=None
):
    """Compute ``softmax(query @ key^T / sqrt(d_k)) @ value``.

    Args:
        query: Tensor whose last two dims are (query_len, d_k), e.g.
            (batch, nheads, query_len, d_k).
        key: Same leading dims as ``query``; last two dims (key_len, d_k).
        value: Same leading dims; last two dims (key_len, d_v).
        mask: Optional.
            - A 2-D mask of shape (batch, key_len) is expanded to
              (batch, 1, 1, key_len) and broadcast over heads and query positions.
            - A mask of any other rank is used as given and must already
              broadcast against the (..., query_len, key_len) scores.
            - Boolean masks follow torch's SDPA convention (True = may attend).
            - Masks of any other dtype are added to the scores (e.g. 0 / -inf).

    Returns:
        Tensor of shape (..., query_len, d_v).

    Raises:
        TypeError: if query, key or value is not a torch.Tensor.
        ValueError: if query/key feature sizes or key/value lengths differ.
    """
    if not (
        isinstance(query, torch.Tensor)
        and isinstance(key, torch.Tensor)
        and isinstance(value, torch.Tensor)
    ):
        raise TypeError("All inputs must be of type torch.Tensor".capitalize())

    # Only these sizes must agree; query_len may differ from key_len (cross
    # attention) and d_v may differ from d_k.
    if query.size(-1) != key.size(-1) or key.size(-2) != value.size(-2):
        raise ValueError(
            "query/key feature sizes or key/value lengths do not match".capitalize()
        )

    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(query.size(-1))

    if mask is not None:
        if mask.dim() == 2:
            mask = mask.unsqueeze(1).unsqueeze(2)

        if mask.dtype == torch.bool:
            # Use the most negative finite value instead of -inf so a row whose
            # keys are all masked gives a uniform distribution rather than NaN.
            scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
        else:
            scores = torch.add(scores, mask)

    attention_weight = torch.softmax(input=scores, dim=-1)

    return torch.matmul(attention_weight, value)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Scaled dot product attetion for transformer".title()
    )

    args = parser.parse_args()

    batch_size, nheads, dimension = (
        config()["ViT"][setting] for setting in ("batch_size", "nheads", "dimension")
    )

    # Per-head tensors laid out as (batch, heads, sequence, head_dim).
    head_shape = (batch_size, nheads, 200, dimension // nheads)
    query, key, value = (torch.randn(head_shape) for _ in range(3))
    mask = torch.randint(0, 2, (batch_size, 200))

    # Exercise the unmasked path first, then the masked one.
    for attention_mask in (None, mask):
        attention = scaled_dot_product_attention(
            query=query,
            key=key,
            value=value,
            mask=attention_mask,
        )

        assert attention.size() == head_shape, "Sizes of inputs are not equal".capitalize()


## Multi Head Attention Layer

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention.

    The forward pass runs four steps:
    1. A single linear layer (``QKV``) produces queries, keys and values.
    2. These are split into ``nheads`` heads and attended with
       ``scaled_dot_product_attention``.
    3. The heads are concatenated back to ``dimension`` features.
    4. The result is passed through the output projection ``layers``.

    Args:
        dimension: Model width; must be divisible by ``nheads``.
        nheads: Number of attention heads.
        dropout: Stored on the module but not applied inside this layer.
        bias: Whether the projections use a bias term.
    """

    def __init__(
        self,
        dimension: int = 512,
        nheads: int = 8,
        dropout: float = 0.1,
        bias: bool = True,
    ):
        super(MultiHeadAttention, self).__init__()

        self.dimension = dimension
        self.nheads = nheads
        self.dropout = dropout
        self.bias = bias

        assert (
            self.dimension % self.nheads == 0
        ), "dimension must be divisible by nheads".capitalize()

        self.QKV = nn.Linear(
            in_features=self.dimension, out_features=3 * self.dimension, bias=self.bias
        )

        self.layers = nn.Linear(
            in_features=self.dimension, out_features=self.dimension, bias=self.bias
        )

    def forward(self, x: torch.Tensor, mask=None):
        """Attend over ``x`` of shape (batch, seq_len, dimension); returns the same shape.

        ``mask`` is passed unchanged to ``scaled_dot_product_attention``.
        Intermediate tensors are kept local rather than stored on the module,
        so no autograd graph outlives the call.

        Raises:
            TypeError: if ``x`` is not a torch.Tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())

        batch_size, sequence_length, _ = x.size()
        head_dimension = self.dimension // self.nheads

        def split_heads(tensor):
            # (batch, seq, dimension) -> (batch, nheads, seq, head_dimension)
            return tensor.view(
                batch_size, sequence_length, self.nheads, head_dimension
            ).permute(0, 2, 1, 3)

        query, key, value = torch.chunk(input=self.QKV(x), chunks=3, dim=-1)

        attention = scaled_dot_product_attention(
            query=split_heads(query),
            key=split_heads(key),
            value=split_heads(value),
            mask=mask,
        )

        # Merge heads in two steps:
        # 1. Move the head axis back next to the feature axis:
        #    (batch, nheads, seq, hd) -> (batch, seq, nheads, hd).
        # 2. Flatten to (batch, seq, dimension).
        # Viewing (batch, nheads, seq, hd) directly as (batch, seq, dimension),
        # as before, mixed values from different positions and heads in each
        # output row.
        attention = attention.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.dimension
        )

        return self.layers(attention)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="MultiHeadAttention Layer for the transformer".title()
    )

    # (flag, type, ViT config key, help text)
    for option_flag, option_type, config_key, help_text in (
        ("--dimension", int, "dimension", "Dimension of the input tensor"),
        ("--nheads", int, "nheads", "Number of heads in the multihead attention layer"),
        ("--dropout", float, "dropout", "Dropout rate for the multihead attention layer"),
    ):
        parser.add_argument(
            option_flag,
            type=option_type,
            default=config()["ViT"][config_key],
            help=help_text.capitalize(),
        )

    args = parser.parse_args()

    batch_size = config()["ViT"]["batch_size"]

    attention = MultiHeadAttention(
        dimension=args.dimension, nheads=args.nheads, dropout=args.dropout, bias=True
    )

    # Self-attention must preserve the (batch, seq, dim) shape.
    expected_shape = (batch_size, 200, args.dimension)
    assert (
        attention(torch.randn(expected_shape)).size() == expected_shape
    ), "MultiHeadAttention Layer is not working properly".capitalize()


## Layer Normalization

In [None]:
class LayerNormalization(nn.Module):
    """Layer normalisation over the last dimension, with learnable scale and shift.

    Computes ``gamma * (x - mean) / sqrt(var + epsilon) + beta``. The mean and
    the biased (population) variance are taken over the last axis, matching
    ``torch.nn.LayerNorm``.

    Args:
        normalized_shape: Size of the last dimension of the input.
        epsilon: Added to the variance for numerical stability.
    """

    def __init__(self, normalized_shape: int = 512, epsilon: float = 1e-05):
        super(LayerNormalization, self).__init__()

        self.normalized_shape = normalized_shape
        self.epsilon = epsilon

        self.gamma = nn.Parameter(torch.ones((normalized_shape,)))
        self.beta = nn.Parameter(torch.zeros((normalized_shape,)))

    def forward(self, x: torch.Tensor):
        """Normalise ``x`` over its last dimension; returns a tensor of the same shape.

        Raises:
            TypeError: if ``x`` is not a torch.Tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())

        mean = x.mean(dim=-1, keepdim=True)
        centered = x - mean

        # Biased estimator (divide by N, not N - 1), as LayerNorm specifies;
        # torch.var defaults to the unbiased estimator, which the old code used.
        variance = (centered * centered).mean(dim=-1, keepdim=True)

        return self.gamma * centered / torch.sqrt(variance + self.epsilon) + self.beta


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Layer Normalization for the Transformer".title()
    )

    # (flag, type, ViT config key, help text)
    for option_flag, option_type, config_key, help_text in (
        ("--normalized_shape", int, "dimension", "The shape of the input tensor"),
        ("--epsilon", float, "eps", "The epsilon value for the variance"),
    ):
        parser.add_argument(
            option_flag,
            type=option_type,
            default=config()["ViT"][config_key],
            help=help_text.capitalize(),
        )

    args = parser.parse_args()

    batch_size = config()["ViT"]["batch_size"]

    layer_norm = LayerNormalization(
        normalized_shape=args.normalized_shape, epsilon=args.epsilon
    )

    # Normalisation must not change the tensor shape.
    expected_shape = (batch_size, 200, args.normalized_shape)
    assert (
        layer_norm(torch.rand(expected_shape)).size() == expected_shape
    ), "Layer Normalization failed".capitalize()


## PointWise Feed Forward Neural Network

In [None]:
class FeedForwardNetwork(nn.Module):
    """Position-wise feed-forward network.

    Layers: Linear(in -> out), activation, Dropout, Linear(out -> in).

    Args:
        in_features: Input (and output) width.
        out_features: Hidden width of the intermediate layer.
        dropout: Dropout probability applied after the activation.
        activation: "relu", "gelu", "elu", "leaky_relu" or "silu". Any other
            value falls back to ReLU.
        bias: Whether the linear layers use a bias term.
    """

    def __init__(
        self,
        in_features: int = 512,
        out_features: int = 2048,
        dropout: float = 0.5,
        activation: str = "relu",
        bias: bool = True,
    ):
        super(FeedForwardNetwork, self).__init__()

        # These attributes keep the constructor's values; the old loop swapped
        # them while building the layers, leaving in_features == hidden width.
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.bias = bias

        activation_factories = {
            "elu": lambda: nn.ELU(inplace=True),
            "gelu": lambda: nn.GELU(),
            "leaky_relu": lambda: nn.LeakyReLU(negative_slope=0.1, inplace=True),
            # Offered by the CLI; previously fell through to ReLU silently.
            "silu": lambda: nn.SiLU(inplace=True),
        }
        self.activation = activation_factories.get(
            activation, lambda: nn.ReLU(inplace=True)
        )()

        self.layers = [
            nn.Linear(in_features=in_features, out_features=out_features, bias=bias),
            self.activation,
            nn.Dropout(p=dropout),
            nn.Linear(in_features=out_features, out_features=in_features, bias=bias),
        ]

        self.network = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Apply the network to the last dimension of ``x``; output has ``x``'s shape.

        Raises:
            TypeError: if ``x`` is not a torch.Tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())

        return self.network(x)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Pointwise Feed Forward Network for Transformer".title()
    )

    # Integer options whose defaults come from the ViT config: (flag, key, help).
    for option_flag, config_key, help_text in (
        ("--in_features", "dimension", "Number of input features"),
        ("--out_features", "dim_feedforward", "Number of output features"),
    ):
        parser.add_argument(
            option_flag,
            type=int,
            default=config()["ViT"][config_key],
            help=help_text.capitalize(),
        )

    parser.add_argument(
        "--activation",
        type=str,
        default="gelu",
        choices=["gelu", "relu", "silu", "leaky_relu", "elu"],
        help="Activation function".capitalize(),
    )

    args = parser.parse_args()

    batch_size = config()["ViT"]["batch_size"]
    dimension = args.in_features

    x = torch.randn((batch_size, 200, dimension))

    net = FeedForwardNetwork(
        in_features=args.in_features,
        out_features=args.out_features,
        activation=args.activation,
        bias=True,
    )

    # The network maps back to the input width, so the shape is unchanged.
    expected_shape = (batch_size, 200, dimension)
    assert (
        net(x=x).size() == expected_shape
    ), "Output shape is incorrect in PointWise FeedForward Network".capitalize()


## Encoder Layer

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One post-norm Transformer encoder layer.

    ``forward`` computes::

        x = layernorm(x + Dropout(MultiHeadAttention(x, mask)))
        x = feedforward_layernorm(x + Dropout(FeedForwardNetwork(x)))

    Args:
        dimension: Model width (input and output feature size).
        nheads: Number of attention heads; must divide ``dimension``.
        dropout: Dropout probability on both residual branches (also passed to
            the attention and feed-forward sub-modules).
        dim_feedforward: Hidden width of the feed-forward network.
        epsilon: LayerNorm epsilon.
        activation: Feed-forward activation name.
        bias: Whether the linear projections use bias terms.
    """

    def __init__(
        self,
        dimension: int = 512,
        nheads: int = 8,
        dropout: float = 0.1,
        dim_feedforward: int = 2048,
        epsilon: float = 1e-5,
        activation: str = "relu",
        bias: bool = True,
    ):
        super(TransformerEncoderBlock, self).__init__()

        self.dimension = dimension
        self.nheads = nheads
        self.dropout = dropout
        self.epsilon = epsilon
        self.dim_feedforward = dim_feedforward
        self.activation = activation
        self.bias = bias

        self.multihead_attention = MultiHeadAttention(
            dimension=self.dimension,
            nheads=self.nheads,
            dropout=self.dropout,
            bias=self.bias,
        )

        self.feedforward_network = FeedForwardNetwork(
            in_features=self.dimension,
            out_features=self.dim_feedforward,
            dropout=self.dropout,
            activation=self.activation,
            bias=self.bias,
        )

        # Each sub-layer gets its own LayerNorm (its own gamma/beta), as in the
        # standard Transformer; previously ``layernorm`` was reused for both.
        # ``layernorm`` keeps its name for the attention sub-layer.
        # NOTE: checkpoints saved before this change lack the
        # ``feedforward_layernorm.*`` keys; load them with strict=False or copy
        # ``layernorm``'s parameters across.
        self.layernorm = LayerNormalization(
            normalized_shape=self.dimension, epsilon=self.epsilon
        )
        self.feedforward_layernorm = LayerNormalization(
            normalized_shape=self.dimension, epsilon=self.epsilon
        )

    def forward(self, x: torch.Tensor, mask=None):
        """Apply the block to ``x`` of shape (batch, seq_len, dimension); same shape out.

        Raises:
            TypeError: if ``x`` is not a torch.Tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())

        # Attention sub-layer with residual connection.
        attended = self.multihead_attention(x=x, mask=mask)
        attended = torch.dropout(input=attended, p=self.dropout, train=self.training)
        x = self.layernorm(torch.add(attended, x))

        # Feed-forward sub-layer with residual connection.
        transformed = self.feedforward_network(x=x)
        transformed = torch.dropout(input=transformed, p=self.dropout, train=self.training)
        return self.feedforward_layernorm(torch.add(transformed, x))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Encoder Block for the Transformer".title()
    )

    # Every option defaults to the ViT config entry with the same name.
    for option_name, option_type, help_text in (
        ("dimension", int, "Dimension of the input tensor"),
        ("nheads", int, "Number of heads in the multi-head attention"),
        ("dim_feedforward", int, "Dimension of the feedforward network"),
        ("dropout", float, "Dropout rate"),
        ("activation", str, "Activation function"),
        ("eps", float, "Epsilon value for the layer normalization"),
    ):
        parser.add_argument(
            "--" + option_name,
            type=option_type,
            default=config()["ViT"][option_name],
            help=help_text.capitalize(),
        )

    args = parser.parse_args()

    batch_size = config()["ViT"]["batch_size"]

    transformerEncoder = TransformerEncoderBlock(
        dimension=args.dimension,
        nheads=args.nheads,
        dim_feedforward=args.dim_feedforward,
        dropout=args.dropout,
        activation=args.activation,
        bias=True,
        epsilon=args.eps,
    )

    # The encoder block must preserve the (batch, seq, dim) shape.
    expected_shape = (batch_size, 200, args.dimension)
    assert (
        transformerEncoder(torch.randn(*expected_shape)).size() == expected_shape
    ), "Encoder block is not working properly".capitalize()


## Transformer - Encoder

In [None]:
class TransformerEncoder(nn.Module):
    """Stack of ``num_encoder_layers`` ``TransformerEncoderBlock`` layers.

    Every block is built with the same hyper-parameters. The blocks are held in
    an ``nn.Sequential``, but ``forward`` calls them one at a time so that
    ``mask`` reaches each of them.
    """

    def __init__(
        self,
        dimension: int = 512,
        nheads: int = 8,
        num_encoder_layers: int = 8,
        dropout: float = 0.1,
        dim_feedforward: int = 2048,
        epsilon: float = 1e-5,
        activation: str = "relu",
        bias: bool = True,
    ):
        super().__init__()

        self.dimension = dimension
        self.nheads = nheads
        self.num_encoder_layers = num_encoder_layers
        self.dropout = dropout
        self.dim_feedforward = dim_feedforward
        self.epsilon = epsilon
        self.activation = activation
        self.bias = bias

        block_settings = {
            "dimension": self.dimension,
            "nheads": self.nheads,
            "dropout": self.dropout,
            "dim_feedforward": self.dim_feedforward,
            "epsilon": self.epsilon,
            "activation": self.activation,
            "bias": self.bias,
        }

        encoder_blocks = [
            TransformerEncoderBlock(**block_settings)
            for _ in range(self.num_encoder_layers)
        ]
        self.transformerEncoder = nn.Sequential(*encoder_blocks)

    def forward(self, x: torch.Tensor, mask=None):
        """Pass ``x`` through every block in order, giving each the same ``mask``.

        Raises:
            TypeError: if ``x`` is not a torch.Tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())

        hidden = x
        for encoder_block in self.transformerEncoder:
            hidden = encoder_block(x=hidden, mask=mask)

        return hidden


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Transformer Encoder block for ViT".title()
    )
    parser.add_argument(
        "--dimension",
        type=int,
        default=config()["ViT"]["dimension"],
        help="Dimension of the input tensor".capitalize(),
    )
    parser.add_argument(
        "--nheads",
        type=int,
        default=config()["ViT"]["nheads"],
        help="Number of heads in the multi-head attention".capitalize(),
    )
    parser.add_argument(
        "--dim_feedforward",
        type=int,
        default=config()["ViT"]["dim_feedforward"],
        help="Dimension of the feedforward network".capitalize(),
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=config()["ViT"]["dropout"],
        help="Dropout rate".capitalize(),
    )
    parser.add_argument(
        "--activation",
        type=str,
        default=config()["ViT"]["activation"],
        help="Activation function".capitalize(),
    )
    parser.add_argument(
        "--eps",
        type=float,
        default=config()["ViT"]["eps"],
        help="Epsilon value for the layer normalization".capitalize(),
    )
    parser.add_argument(
        "--num_layers",
        type=int,
        default=config()["ViT"]["num_layers"],
        help="Number of layers in the transformer encoder".capitalize(),
    )

    args = parser.parse_args()

    batch_size = config()["ViT"]["batch_size"]

    transformerEncoder = TransformerEncoder(
        dimension=args.dimension,
        nheads=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.num_layers,
        dropout=args.dropout,
        activation=args.activation,
        bias=True,
        epsilon=args.eps,
    )

    # The input width must follow --dimension; a hard-coded 512 broke the check
    # for any other configured dimension.
    output = transformerEncoder(torch.randn(batch_size, 200, args.dimension))

    assert output.size() == (
        batch_size,
        200,
        args.dimension,
    ), "TransformerEncoder output size is incorrect"

    print("TransformerEncoder test passed")
