### Import Libraries

In [None]:
import os
import re
import cv2
import sys
import math
import yaml
import nltk
import torch
import joblib
import zipfile
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from textwrap import fill
import torch.optim as optim
import matplotlib.pyplot as plt
from nltk.corpus import stopwords
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

### Install the requirements.txt dependencies

In [None]:
!pip install -r requirements.txt

### Utility files

In [None]:
# Make sure the NLTK "stopwords" corpus is available, then keep the English
# stop words in a set so text_preprocessing() can test membership cheaply.
nltk.download("stopwords")
stop_words = set(stopwords.words("english"))

def config_files():
    """Parse ``../notebook_config.yml`` and return its content.

    The file is re-read on every call, so callers that need several values
    should keep the returned object instead of calling this repeatedly.

    Returns:
        Whatever ``yaml.safe_load`` produces for the configuration file.
    """
    config_path = "../notebook_config.yml"
    with open(config_path, "r") as stream:
        settings = yaml.safe_load(stream)
    return settings


def dump_file(value: None, filename: None):
    """Serialize ``value`` to ``filename`` with joblib.

    When either argument is ``None`` nothing is written: an error message is
    printed and the function returns ``None``.
    """
    if value is None or filename is None:
        print("Error: 'value' and 'filename' must be provided.".capitalize())
        return

    joblib.dump(value=value, filename=filename)


def load_file(filename: None):
    """Load and return the object stored in ``filename`` by joblib.

    When ``filename`` is ``None`` an error message is printed and ``None`` is
    returned instead.
    """
    if filename is None:
        print("Error: 'filename' must be provided.".capitalize())
        return None

    return joblib.load(filename=filename)


def weight_init(m):
    """Initialise linear and convolutional layers in place.

    Modules whose class name contains ``"Linear"`` or ``"Conv"`` get
    Kaiming-normal weights and, when a bias exists, a zero bias. Every other
    module is left untouched.

    Args:
        m: The module to initialise (typically supplied by ``model.apply``).
    """
    layer_kind = m.__class__.__name__
    if "Linear" not in layer_kind and "Conv" not in layer_kind:
        return

    nn.init.kaiming_normal_(m.weight.data)
    if m.bias is not None:
        nn.init.constant_(m.bias.data, 0.0)


def device_init(device: str = "cuda"):
    """Resolve a device name to a usable ``torch.device``.

    Args:
        device: ``"cuda"`` or ``"mps"`` to request an accelerator; any other
            value selects the CPU.

    Returns:
        The requested accelerator when it is available, otherwise the CPU.
    """
    if device == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    if device == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

def clean_folders():
    """Delete the regular files stored in the artifact output folders.

    A warning is issued on every call because the deletion is irreversible.
    Only files directly inside each folder are removed; sub-directories are
    left alone. A folder that does not exist is skipped, and a file that
    cannot be deleted is reported and skipped so the remaining folders are
    still cleaned.
    """
    TRAIN_MODELS: str = "../artifacts/checkpoints/train_models/"
    BEST_MODEL: str = "../artifacts/checkpoints/best_model/"
    METRICS_PATH: str = "../artifacts/metrics/"
    TRAIN_IMAGES: str = "../artifacts/outputs/train_images/"
    TEST_IMAGE: str = "../artifacts/outputs/test_image/"

    warnings.warn(
        "Warning! This will remove the previous files, which might be useful in the future. "
        "You may want to download the previous files or use MLflow to track and manage them. "
        "Suggestions for updating them are welcome."
    )

    for path in tqdm(
        [TRAIN_MODELS, BEST_MODEL, METRICS_PATH, TRAIN_IMAGES, TEST_IMAGE]
    ):
        # Nothing to clean when the folder was never created.
        if not os.path.isdir(path):
            print(f"Skipping {path}: directory does not exist")
            continue

        for entry in os.listdir(path):
            file_path = os.path.join(path, entry)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except OSError as e:
                print(f"Error occurred while deleting {file_path}: {e}")

        # The original called str.format() without an argument, which raised
        # IndexError as soon as the first folder had been processed.
        print(f"{path} folders completed")


def text_preprocessing(instance):
    """Normalise one report string for vocabulary building.

    Steps, in order: drop newlines, quotes, parentheses, ``XXXX`` markers and
    ``x-dddd`` codes; drop every character that is not whitespace or a letter
    other than x/X; lower-case; remove English stop words.

    Args:
        instance: The raw text.

    Returns:
        The remaining words joined by single spaces.
    """
    without_markers = re.sub(r'[\n\'"()]+|XXXX|x-\d{4}', "", instance)
    # NOTE(review): this character class excludes x/X, so every "x" inside
    # ordinary words is removed as well -- confirm that this is intended.
    letters_only = re.sub(r"[^a-wy-zA-WY-Z\s]", "", without_markers)

    tokens = letters_only.lower().split()
    kept = [token for token in tokens if token not in stop_words]

    return " ".join(kept)

def plot_images(
    predicted: bool = False, device: str = "cuda", model=None, epoch: int = 1
):
    """Save a grid showing the first training batch with labels and reports.

    The pickled training dataloader and ``vocabulary.csv`` are read from the
    configured processed-data folder. The figure is written to the configured
    train-images folder as ``image{epoch}.png``.

    Args:
        predicted: When True, ``model`` is run on the batch and its
            thresholded output is added to every title.
        device: Device the batch is moved to before calling ``model``.
        model: Callable accepting ``image=`` and ``text=`` keyword tensors.
            Only used, and only required, when ``predicted`` is True.
        epoch: Number embedded in the output file name.

    Any error raised while loading, predicting or plotting is printed rather
    than propagated, so a plotting failure does not abort the caller.
    """
    # Parse the configuration once instead of once per value.
    artifacts = config_files()["artifacts"]
    processed_path = artifacts["processed_data_path"]
    train_images_path = artifacts["train_images"]
    try:
        train_dataloader = load_file(
            filename=os.path.join(processed_path, "train_dataloader.pkl")
        )
        vocabularies = pd.read_csv(os.path.join(processed_path, "vocabulary.csv"))
        vocabularies["index"] = vocabularies["index"].astype(int)

        images, texts, labels = next(iter(train_dataloader))

        # Only run the model when predictions are requested; previously it was
        # called unconditionally, so the default model=None always failed.
        predict = None
        if predicted:
            if model is None:
                raise ValueError("A model must be provided when 'predicted' is True.")
            with torch.no_grad():
                predict = model(image=images.to(device), text=texts.to(device))
            # NOTE(review): Classifier in this file has no sigmoid, so this
            # thresholds raw outputs at 0.5 -- confirm the intended cut-off.
            predict = torch.where(predict > 0.5, 1, 0)
            predict = predict.detach().cpu().numpy()

        num_images = images.size(0)
        num_rows = int(math.sqrt(num_images))
        num_cols = math.ceil(num_images / num_rows)

        # squeeze=False keeps `axes` an array even for a 1x1 grid, so
        # flatten() also works for a batch that holds a single image.
        _, axes = plt.subplots(num_rows, num_cols, figsize=(12, 8), squeeze=False)
        axes = axes.flatten()

        for index, (image, ax) in enumerate(zip(images, axes)):
            image = image.squeeze().permute(1, 2, 0).detach().cpu().numpy()
            # Min-max scale to [0, 1] for imshow.
            image = (image - image.min()) / (image.max() - image.min())
            label = labels[index].item()
            text_sequences = texts[index].detach().cpu().numpy().tolist()
            # NOTE(review): isin() yields the words in vocabulary order, not
            # in the order they occur in the report -- confirm this is fine.
            words = vocabularies[vocabularies["index"].isin(text_sequences)][
                "vocabulary"
            ].tolist()
            medical_report = " ".join(words).replace("<UNK>", "").strip()
            wrapped_report = fill(medical_report, width=30)

            if predicted:
                title_text = f"**Label**: {label}\n**Predicted**: {predict[index]}\nReport: {wrapped_report}".title()
            else:
                title_text = f"Label: {label}\nReport: {wrapped_report}"

            ax.set_title(title_text, fontsize=9, loc="center")
            ax.imshow(image)
            ax.axis("off")

        # Hide the grid cells that received no image.
        for unused_ax in axes[num_images:]:
            unused_ax.axis("off")

        plt.tight_layout()
        plt.savefig(os.path.join(train_images_path, "image{}.png".format(epoch)))

    except Exception as e:
        print(f"Error in plot_images: {e}")

### Patch Embedding

In [None]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each of them.

    A strided convolution (kernel = stride = ``patch_size``) projects every
    patch to ``dimension`` features, and a learned positional embedding is
    added to the resulting sequence.

    Args:
        channels: Number of input image channels.
        patch_size: Side length of a square patch, in pixels.
        image_size: Side length of the square input image, in pixels.
        dimension: Embedding size; defaults to ``patch_size**2 * channels``.
    """

    def __init__(
        self,
        channels: int = 3,
        patch_size: int = 16,
        image_size: int = 128,
        dimension: int = None,
    ):
        super(PatchEmbedding, self).__init__()

        self.in_channels = channels
        self.patch_size = patch_size
        self.dimension = dimension
        self.image_size = image_size

        if self.dimension is None:
            self.dimension = (self.patch_size**2) * self.in_channels

        self.number_of_pathches = (self.image_size // self.patch_size) ** 2

        self.kernel_size, self.stride_size = self.patch_size, self.patch_size
        # Patches must tile the image exactly, so no padding is applied. The
        # previous value (patch_size // patch_size == 1) shifted the patch grid
        # by one pixel and produced the wrong number of patches for
        # patch_size <= 2.
        self.padding_size = 0

        self.projection = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.dimension,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=False,
        )

        # The leading dimension of 1 broadcasts over the batch.
        self.positional_embeddings = nn.Parameter(
            torch.randn(1, self.number_of_pathches, self.dimension),
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor):
        """Embed a batch of images.

        Args:
            x: Image batch of shape ``(batch, channels, image_size, image_size)``.

        Returns:
            Tensor of shape ``(batch, number_of_patches, dimension)``.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch tensor.".capitalize())

        x = self.projection(x)
        # (B, D, H', W') -> (B, D, H'*W') -> (B, H'*W', D)
        x = x.flatten(start_dim=2).transpose(1, 2)
        return self.positional_embeddings + x

    @staticmethod
    def total_params(model):
        """Return the parameter count of a ``PatchEmbedding`` (else ``None``)."""
        if isinstance(model, PatchEmbedding):
            return sum(params.numel() for params in model.parameters())


if __name__ == "__main__":
    # Smoke test: one random image must map to (1, number_of_patches, dimension).
    # The configuration file is parsed once instead of once per value.
    patch_config = config_files()["patchEmbeddings"]
    image_channels = patch_config["channels"]
    image_size = patch_config["image_size"]
    patch_size = patch_config["patch_size"]
    dimension = patch_config["dimension"]

    pathEmbedding = PatchEmbedding(
        channels=image_channels,
        patch_size=patch_size,
        image_size=image_size,
        dimension=dimension,
    )

    # A batch of one image (previously spelled image_channels // image_channels).
    image = torch.randn((1, image_channels, image_size, image_size))

    assert (pathEmbedding(image).size()) == (
        1,
        (image_size // patch_size) ** 2,
        dimension,
    )

### Scaled Dot Product -> self attention 

In [None]:
def scaled_dot_product(query: torch.Tensor, key: torch.Tensor, values: torch.Tensor):
    """Compute scaled dot-product attention.

    The scores ``query @ key^T`` are divided by the square root of the last
    dimension of ``query``, soft-maxed over the key axis and used to average
    ``values``.

    Args:
        query: Tensor whose last two dimensions are ``(queries, depth)``.
        key: Tensor whose last two dimensions are ``(keys, depth)``.
        values: Tensor whose last two dimensions are ``(keys, value_depth)``.

    Returns:
        Tensor whose last two dimensions are ``(queries, value_depth)``.

    Raises:
        ValueError: If any argument is not a ``torch.Tensor``.
    """
    if not all(isinstance(item, torch.Tensor) for item in (query, key, values)):
        raise ValueError("All inputs must be torch tensors.".capitalize())

    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-1, -2)) / scale
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, values)

if __name__ == "__main__":
    # Smoke test: attention must keep the (batch, heads, tokens, depth) shape.
    attention_shape = (1, 8, 64, 32)
    query = torch.randn(attention_shape)
    key = torch.randn(attention_shape)
    values = torch.randn(attention_shape)

    assert scaled_dot_product(query, key, values).size() == attention_shape

### Multi Head Attention Layer-> attention is all you need

In [None]:
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head self-attention built on ``scaled_dot_product``.

    One bias-free linear layer produces queries, keys and values. They are
    split into ``nheads`` heads, attended independently and concatenated
    again. No output projection is applied.

    Args:
        nheads: Number of attention heads.
        dimension: Model dimension; must be divisible by ``nheads``.
    """

    def __init__(self, nheads: int = 8, dimension: int = 256):
        super(MultiHeadAttentionLayer, self).__init__()
        self.nheads = nheads
        self.dimension = dimension

        warnings.warn(
            "Please ensure that the dimension is a multiple of the number of heads in the encoder block (e.g., 256 % 8 = 0). "
            "This is a requirement for the Transformer Encoder Block to function properly. "
            "If not, you might need to adjust the dimension or the number of heads."
        )
        assert (
            dimension % self.nheads == 0
        ), "Dimension mismatched with nheads and dimension".title()

        self.QKV = nn.Linear(
            in_features=self.dimension, out_features=3 * self.dimension, bias=False
        )

    def _split_heads(self, tensor: torch.Tensor):
        """Reshape ``(B, N, dimension)`` into ``(B, nheads, N, head_dim)``."""
        head_dim = self.dimension // self.nheads
        tensor = tensor.view(tensor.size(0), tensor.size(1), self.nheads, head_dim)
        return tensor.permute(0, 2, 1, 3)

    def forward(self, x: torch.Tensor):
        """Apply self-attention to a token sequence.

        Args:
            x: Tensor of shape ``(batch, tokens, dimension)``.

        Returns:
            Tensor of shape ``(batch, tokens, dimension)``.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch tensor.".capitalize())

        query, key, values = torch.chunk(input=self.QKV(x), chunks=3, dim=-1)
        query = self._split_heads(query)
        key = self._split_heads(key)
        values = self._split_heads(values)

        attention_output = scaled_dot_product(query=query, key=key, values=values)

        # Bring the token axis back in front of the head axis before merging
        # the heads. Calling view() directly on the (B, heads, N, head_dim)
        # layout, as before, reinterpreted memory and mixed features of
        # different tokens and heads.
        attention_output = attention_output.permute(0, 2, 1, 3).contiguous()
        return attention_output.view(
            attention_output.size(0), attention_output.size(1), self.dimension
        )

    @staticmethod
    def total_params(model):
        """Return the parameter count of a ``MultiHeadAttentionLayer``.

        Returns ``None`` for any other object. The check previously tested
        for ``TransformerEncoderBlock``, so it never matched this class.
        """
        if isinstance(model, MultiHeadAttentionLayer):
            return sum(params.numel() for params in model.parameters())
        return None


if __name__ == "__main__":
    # Smoke test: self-attention must preserve (batch, tokens, dimension).
    # The configuration is parsed once and `dimension` is read only once
    # (it used to be assigned twice from the same key).
    notebook_config = config_files()
    patch_config = notebook_config["patchEmbeddings"]

    nheads = notebook_config["transfomerEncoderBlock"]["nheads"]
    image_channels = patch_config["channels"]
    image_size = patch_config["image_size"]
    patch_size = patch_config["patch_size"]
    dimension = patch_config["dimension"]

    multihead_attention = MultiHeadAttentionLayer(
        nheads=nheads, dimension=dimension
    )

    image = torch.randn((1, (image_size // patch_size) ** 2, dimension))

    assert (multihead_attention(image).size()) == (
        (1, (image_size // patch_size) ** 2, dimension)
    )

### Feed Forward Neural Network

In [None]:
class FeedForwardNeuralNetwork(nn.Module):
    """Position-wise feed-forward network: Linear -> activation -> Dropout -> Linear.

    Args:
        in_features: Size of the input and of the output.
        out_features: Size of the hidden layer.
        dropout: Dropout probability applied after the activation.
        activation: One of ``"relu"``, ``"gelu"`` or ``"selu"``.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(
        self,
        in_features: int = 256,
        out_features: int = 4 * 256,
        dropout: float = 0.1,
        activation: str = "relu",
    ):
        super(FeedForwardNeuralNetwork, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.activation = activation

        if activation == "relu":
            self.activation_function = nn.ReLU()
        elif activation == "gelu":
            self.activation_function = nn.GELU()
        elif activation == "selu":
            self.activation_function = nn.SELU(inplace=True)
        else:
            raise ValueError(
                "Invalid activation function. Choose from'relu', 'gelu', or'selu'."
            )

        # The layers are listed explicitly. The earlier loop swapped
        # self.in_features / self.out_features while building them, leaving
        # both attributes with values that no longer matched the constructor
        # arguments. Layer order (and therefore the state-dict keys) is kept.
        self.layers = [
            nn.Linear(in_features=in_features, out_features=out_features),
            self.activation_function,
            nn.Dropout(self.dropout),
            nn.Linear(in_features=out_features, out_features=in_features),
        ]

        self.network = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Apply the network to the last dimension of ``x``.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor.")
        return self.network(x)


if __name__ == "__main__":
    # Smoke test: print the output size for a (1, 64, 256) input.
    network = FeedForwardNeuralNetwork(
        in_features=256, out_features=1024, dropout=0.1, activation="relu"
    )
    sample = torch.randn((1, 64, 256))

    print(network(sample).size())

### Layer Normalization

In [None]:
class LayerNormalization(nn.Module):
    """Layer normalisation over the last dimension with learned scale and shift.

    Args:
        dimension: Size of the normalised (last) dimension.
        layer_norm_eps: Constant added to the variance for numerical stability.
    """

    def __init__(self, dimension: int = 256, layer_norm_eps: float = 1e-5):
        super(LayerNormalization, self).__init__()
        self.dimension = dimension
        self.layer_norm_eps = layer_norm_eps

        # Scale (alpha) and shift (beta), broadcast over the leading axes.
        self.alpha = nn.Parameter(
            data=torch.ones((1, 1, self.dimension)), requires_grad=True
        )
        self.beta = nn.Parameter(
            data=torch.zeros((1, 1, self.dimension)), requires_grad=True
        )

    def forward(self, x: torch.Tensor):
        """Normalise ``x`` along its last dimension.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor.".capitalize())

        mean = torch.mean(x, dim=-1, keepdim=True)
        # Layer normalisation uses the biased (population) variance.
        # torch.var defaults to the unbiased estimator, which rescales the
        # output and yields NaN when the normalised dimension has size 1.
        variance = torch.var(x, dim=-1, unbiased=False, keepdim=True)

        y = (x - mean) / torch.sqrt(variance + self.layer_norm_eps)

        return self.alpha * y + self.beta


if __name__ == "__main__":
    # Smoke test: normalisation must not change the tensor shape. The earlier
    # assert only checked that size() was truthy, which it always is.
    layer_normalization = LayerNormalization(dimension=256)
    assert layer_normalization(torch.randn((1, 64, 256))).size() == (1, 64, 256)

### Transformer Encoder Block

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One post-norm encoder block: self-attention then feed-forward.

    Each sub-layer is followed by dropout, a residual addition and layer
    normalisation.

    Args:
        dimension: Model dimension.
        nheads: Number of attention heads.
        dim_feedforward: Base size of the feed-forward hidden layer; the
            hidden layer actually created has ``4 * dim_feedforward`` units.
        dropout: Dropout probability.
        layer_norm_eps: Epsilon of the layer normalisation.
        activation: Activation name passed to the feed-forward network.
    """

    def __init__(
        self,
        dimension: int = 256,
        nheads: int = 8,
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-5,
        activation: str = "relu",
    ):
        super().__init__()
        self.dimension = dimension
        self.nheads = nheads
        self.dim_feedforward = dim_feedforward
        self.activation = activation
        self.layer_norm_eps = layer_norm_eps

        self.multihead_attention = MultiHeadAttentionLayer(
            nheads=nheads, dimension=dimension
        )
        # NOTE(review): one LayerNormalization instance serves both
        # sub-layers, so they share scale/shift parameters -- confirm intent.
        self.layer_normalization = LayerNormalization(
            dimension=dimension, layer_norm_eps=layer_norm_eps
        )
        # NOTE(review): the hidden width is 4 * dim_feedforward (4096 with the
        # default of 1024) -- confirm the extra factor of four is wanted.
        self.feed_forward_network = FeedForwardNeuralNetwork(
            in_features=dimension,
            out_features=4 * dim_feedforward,
            activation=activation,
            dropout=dropout,
        )
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor):
        """Run the block on a ``(batch, tokens, dimension)`` tensor.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor.")

        attended = self.dropout(self.multihead_attention(x))
        x = self.layer_normalization(torch.add(attended, x))

        transformed = self.dropout(self.feed_forward_network(x))
        return self.layer_normalization(torch.add(transformed, x))

    @staticmethod
    def total_params(model):
        """Return the parameter count of a block, or ``None`` for other objects."""
        if not isinstance(model, TransformerEncoderBlock):
            return None
        return sum(weights.numel() for weights in model.parameters())

if __name__ == "__main__":
    # Smoke test: print the block's output size for one random sequence.
    # The configuration is parsed once and `dimension` is read only once
    # (it used to be assigned twice from the same key).
    notebook_config = config_files()
    patch_config = notebook_config["patchEmbeddings"]
    encoder_config = notebook_config["transfomerEncoderBlock"]

    nheads = encoder_config["nheads"]
    image_channels = patch_config["channels"]
    image_size = patch_config["image_size"]
    patch_size = patch_config["patch_size"]
    dimension = patch_config["dimension"]
    dropout = encoder_config["dropout"]
    activation = encoder_config["activation"]

    transformer_encoder_block = TransformerEncoderBlock(
        dimension=dimension,
        nheads=nheads,
        dim_feedforward=dimension,
        dropout=dropout,
        activation=activation,
        layer_norm_eps=1e-5,
    )

    sequence = torch.randn((1, (image_size // patch_size) ** 2, dimension))
    print(transformer_encoder_block(sequence).size())

### Transformer Encoder -> attention is all you need

In [None]:
class TransformerEncoder(nn.Module):
    """A stack of ``num_encoder_layers`` ``TransformerEncoderBlock`` modules.

    Args:
        dimension: Model dimension.
        nheads: Number of attention heads per block.
        num_encoder_layers: Number of stacked blocks.
        dim_feedforward: Feed-forward size passed to every block.
        dropout: Dropout probability passed to every block.
        layer_norm_eps: Layer-normalisation epsilon passed to every block.
        activation: Activation name passed to every block.
    """

    def __init__(
        self,
        dimension: int = 256,
        nheads: int = 8,
        num_encoder_layers: int = 6,
        dim_feedforward: int = 1024,
        dropout: float = 0.0,
        layer_norm_eps: float = 1e-05,
        activation: str = "relu",
    ):
        super(TransformerEncoder, self).__init__()
        self.dimension = dimension
        self.nheads = nheads
        self.num_encoder_layers = num_encoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.activation = activation

        self.layers = nn.Sequential(
            *[
                TransformerEncoderBlock(
                    dimension=self.dimension,
                    nheads=self.nheads,
                    dim_feedforward=self.dim_feedforward,
                    dropout=self.dropout,
                    # Previously omitted, so a caller-supplied epsilon was
                    # silently replaced by the block's default.
                    layer_norm_eps=self.layer_norm_eps,
                    activation=self.activation,
                )
                for _ in tqdm(range(self.num_encoder_layers))
            ]
        )

    def forward(self, x: torch.Tensor):
        """Pass ``x`` through every block in order.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor.".capitalize())

        for layer in self.layers:
            x = layer(x=x)
        return x


if __name__ == "__main__":
    # Smoke test: print the encoder's output size for a (1, 64, 256) input.
    encoder_settings = dict(
        dimension=256,
        nheads=8,
        num_encoder_layers=8,
        dim_feedforward=1024,
        dropout=0.1,
        layer_norm_eps=1e-05,
        activation="relu",
    )
    transformerEncoder = TransformerEncoder(**encoder_settings)

    sample = torch.randn((1, 64, 256))
    print(transformerEncoder(sample).size())

### Loss Function: BCEWithLogitsLoss (default), BCELoss or CrossEntropyLoss

In [None]:
class LossFunction(nn.Module):
    """Thin wrapper that selects a PyTorch loss by name.

    Args:
        loss_name: ``"BCEWithLogitsLoss"``, ``"BCELoss"`` or ``"CCE"``
            (cross-entropy).
        reduction: Reduction mode handed to the selected loss.

    Raises:
        ValueError: If ``loss_name`` is not one of the supported names.
    """

    def __init__(self, loss_name: str = "BCEWithLogitsLoss", reduction: str = "mean"):
        super().__init__()
        self.loss_name = loss_name
        self.reduction = reduction

        supported = (
            ("BCEWithLogitsLoss", nn.BCEWithLogitsLoss),
            ("BCELoss", nn.BCELoss),
            ("CCE", nn.CrossEntropyLoss),
        )
        for name, loss_class in supported:
            if loss_name == name:
                self.criterion = loss_class(reduction=reduction)
                break
        else:
            raise ValueError(f"Unsupported loss function: {loss_name}")

    def forward(self, predicted: torch.Tensor, actual: torch.Tensor):
        """Return the loss between ``predicted`` and ``actual``.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        if not (isinstance(predicted, torch.Tensor) and isinstance(actual, torch.Tensor)):
            raise ValueError(
                "Both predicted and actual must be torch.Tensor.".capitalize()
            )
        return self.criterion(predicted, actual)


if __name__ == "__main__":
    # Smoke test: print the BCE-with-logits loss for eight hand-written scores.
    criterion = LossFunction(loss_name="BCEWithLogitsLoss", reduction="mean")

    predicted = torch.tensor(
        [2.5, -1.0, 1.5, -2.0, 2.0, 3.0, -1.5, -0.5], dtype=torch.float
    )
    actual = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0, 0.0], dtype=torch.float)

    loss = criterion(predicted, actual)
    print("Loss:", loss.item())

### Vision Transformer Model 

In [None]:
class VisionTransformer(nn.Module):
    """Image encoder: ``PatchEmbedding`` followed by a ``TransformerEncoder``.

    Args:
        channels: Number of image channels.
        patch_size: Side length of a square patch.
        image_size: Side length of the square input image.
        dimension: Embedding / model dimension.
        nheads: Number of attention heads.
        num_encoder_layers: Number of encoder blocks.
        dim_feedforward: Feed-forward size passed to the encoder.
        dropout: Dropout probability passed to the encoder.
        layer_norm_eps: Layer-normalisation epsilon passed to the encoder.
        activation: Activation name passed to the encoder.
    """

    def __init__(
        self,
        channels: int = 3,
        patch_size: int = 16,
        image_size: int = 128,
        dimension: int = 256,
        nheads: int = 8,
        num_encoder_layers: int = 6,
        dim_feedforward: int = 1024,
        dropout: float = 0.0,
        layer_norm_eps: float = 1e-05,
        activation: str = "relu",
    ):
        super().__init__()

        self.channels = channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.dimension = dimension
        self.nheads = nheads
        self.num_encoder_layers = num_encoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.activation = activation

        self.patch_embedding = PatchEmbedding(
            channels=self.channels,
            patch_size=self.patch_size,
            image_size=self.image_size,
            dimension=self.dimension,
        )
        self.transformer_encoder = TransformerEncoder(
            dimension=dimension,
            nheads=nheads,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            layer_norm_eps=layer_norm_eps,
            activation=activation,
        )

    def forward(self, x: torch.Tensor):
        """Encode an image batch into a sequence of patch features.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor.".capitalize())

        patches = self.patch_embedding(x)
        return self.transformer_encoder(patches)

    @staticmethod
    def total_params(model):
        """Return the number of parameters of a ``VisionTransformer``.

        Raises:
            ValueError: If ``model`` is not a ``VisionTransformer``.
        """
        if not isinstance(model, VisionTransformer):
            raise ValueError("Input must be a VisionTransformer model.".capitalize())
        return sum(weights.numel() for weights in model.parameters())


if __name__ == "__main__":
    # Smoke test: one random image must yield (1, number_of_patches, dimension).
    # The configuration file is parsed once instead of once per value.
    notebook_config = config_files()
    patch_config = notebook_config["patchEmbeddings"]
    encoder_config = notebook_config["transfomerEncoderBlock"]

    image_channels = patch_config["channels"]
    image_size = patch_config["image_size"]
    patch_size = patch_config["patch_size"]
    dimension = patch_config["dimension"]
    nheads = encoder_config["nheads"]
    activation = encoder_config["activation"]
    dropout = encoder_config["dropout"]
    num_encoder_layers = encoder_config["num_encoder_layers"]
    dimension_feedforward = encoder_config["dimension_feedforward"]
    # float() is applied to the configured value before it is used.
    layer_norm_eps = float(encoder_config["layer_norm_eps"])

    image = torch.randn((1, image_channels, image_size, image_size))

    vision_transformer = VisionTransformer(
        channels=image_channels,
        patch_size=patch_size,
        image_size=image_size,
        dimension=dimension,
        nheads=nheads,
        activation=activation,
        dropout=dropout,
        num_encoder_layers=num_encoder_layers,
        dim_feedforward=dimension_feedforward,
        layer_norm_eps=layer_norm_eps,
    )

    assert vision_transformer(image).size() == (
        1,
        (image_size // patch_size) ** 2,
        dimension,
    )

### Text Transformer

In [None]:
class TextTransformerEncoder(nn.Module):
    """Text encoder: token embedding followed by a ``TransformerEncoder``.

    The image settings are read from the notebook configuration. The sequence
    length equals the number of image patches and the embedding table holds
    ``image_size * dimension`` entries.

    Args:
        dimension: Embedding / model dimension. When ``None`` it falls back to
            ``patch_size**2 * channels``, the same default ``PatchEmbedding``
            uses.
        nheads: Number of attention heads.
        num_encoder_layers: Number of encoder blocks.
        dim_feedforward: Feed-forward size passed to the encoder.
        dropout: Dropout probability passed to the encoder.
        layer_norm_eps: Layer-normalisation epsilon passed to the encoder.
        activation: Activation name passed to the encoder.

    Raises:
        ValueError: If the image settings are missing from the configuration.
    """

    def __init__(
        self,
        dimension: int = 256,
        nheads: int = 8,
        num_encoder_layers: int = 6,
        dim_feedforward: int = 1024,
        dropout: float = 0.0,
        layer_norm_eps: float = 1e-05,
        activation: str = "relu",
    ):
        super(TextTransformerEncoder, self).__init__()
        self.dimension = dimension
        self.nheads = nheads
        self.num_encoder_layers = num_encoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.activation = activation

        try:
            # Parse the configuration once instead of once per value.
            patch_config = config_files()["patchEmbeddings"]
            self.image_size = patch_config["image_size"]
            self.patch_size = patch_config["patch_size"]
            self.image_channels = patch_config["channels"]
        except KeyError as error:
            raise ValueError(
                "Image configuration not found in the config files.".capitalize()
            ) from error

        self.sequence_length = (self.image_size // self.patch_size) ** 2

        if self.dimension is None:
            # Same default as PatchEmbedding. The earlier fallback multiplied
            # by image_size and left image_channels unused.
            self.dimension = (self.patch_size**2) * self.image_channels

        # The message now states the formula that is actually used below.
        warnings.warn(
            "The number of vocabularies (unique words) is calculated by multiplying the image size by the dimension. "
            "If you need a larger vocabulary size, please increase the image size or the dimension."
        )

        self.number_of_vocabularies = self.image_size * self.dimension

        self.embedding = nn.Embedding(
            num_embeddings=self.number_of_vocabularies, embedding_dim=self.dimension
        )
        self.transformer_encoder = TransformerEncoder(
            dimension=self.dimension,
            nheads=self.nheads,
            num_encoder_layers=self.num_encoder_layers,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            layer_norm_eps=self.layer_norm_eps,
            activation=self.activation,
        )

    def forward(self, x: torch.Tensor):
        """Embed token ids and encode the resulting sequence.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor.".capitalize())

        x = self.embedding(x)
        return self.transformer_encoder(x)

    @staticmethod
    def total_params(model):
        """Return the number of parameters of a ``TextTransformerEncoder``.

        Raises:
            ValueError: If ``model`` is not a ``TextTransformerEncoder``.
        """
        if isinstance(model, TextTransformerEncoder):
            return sum(params.numel() for params in model.parameters())
        raise ValueError(
            "Input must be a TextTransformerEncoder model.".capitalize()
        )


if __name__ == "__main__":
    # Smoke test: random token ids must yield (1, sequence_length, dimension).
    # The configuration file is parsed once instead of once per value.
    notebook_config = config_files()
    patch_config = notebook_config["patchEmbeddings"]
    encoder_config = notebook_config["transfomerEncoderBlock"]

    image_channels = patch_config["channels"]
    image_size = patch_config["image_size"]
    patch_size = patch_config["patch_size"]
    dimension = patch_config["dimension"]
    nheads = encoder_config["nheads"]
    activation = encoder_config["activation"]
    dropout = encoder_config["dropout"]
    num_encoder_layers = encoder_config["num_encoder_layers"]
    dimension_feedforward = encoder_config["dimension_feedforward"]
    layer_norm_eps = float(encoder_config["layer_norm_eps"])

    sequence_length = (image_size // patch_size) ** 2

    # Token ids are drawn from [0, sequence_length).
    textual_data = torch.randint(0, sequence_length, (1, sequence_length))

    text_transfomer = TextTransformerEncoder(
        dimension=dimension,
        nheads=nheads,
        num_encoder_layers=num_encoder_layers,
        dim_feedforward=dimension_feedforward,
        dropout=dropout,
        layer_norm_eps=layer_norm_eps,
        activation=activation,
    )

    assert (text_transfomer(textual_data).size()) == (1, sequence_length, dimension)

### Multi Modal Classifier

In [None]:
class Classifier(nn.Module):
    """Three-layer MLP head that maps fused features to one value per sample.

    The input width is ``2 * dimension``. Two ``Linear -> ReLU -> BatchNorm1d``
    stages follow (the first outputs ``dimension // 2`` features, the second
    half of that), and a final ``Linear`` produces a single output.

    After construction ``in_features`` holds the input width of the final
    layer and ``out_features`` half of that value.

    Args:
        dimension: Feature size of each of the two fused modalities.
    """

    def __init__(self, dimension: int = 256):
        super().__init__()
        width_in, width_out = dimension * 2, dimension // 2

        stack = []
        for _ in range(2):
            stack.append(nn.Linear(in_features=width_in, out_features=width_out))
            stack.append(nn.ReLU(inplace=True))
            stack.append(nn.BatchNorm1d(num_features=width_out))
            width_in, width_out = width_out, width_out // 2

        # width_in // width_in evaluates to 1: a single output unit.
        stack.append(
            nn.Linear(in_features=width_in, out_features=width_in // width_in)
        )

        self.in_features = width_in
        self.out_features = width_out
        self.layers = stack
        self.classifier = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor):
        """Return a flat tensor holding one value per input row.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor.".capitalize())
        return self.classifier(x).view(-1)

    @staticmethod
    def total_params(model):
        """Return the number of parameters of a ``Classifier``.

        Raises:
            ValueError: If ``model`` is not a ``Classifier``.
        """
        if not isinstance(model, Classifier):
            raise ValueError("Input must be a Classifier model.".capitalize())
        return sum(weights.numel() for weights in model.parameters())
        

class MultiModalClassifier(nn.Module):
    """Fuse image and text encoders and classify the combined features.

    The image goes through a ``VisionTransformer`` and the text through a
    ``TextTransformerEncoder``. Both outputs are averaged over the sequence
    axis, concatenated and passed to a ``Classifier``.

    Args:
        channels: Number of image channels.
        patch_size: Side length of a square image patch.
        image_size: Side length of the square input image.
        dimension: Model dimension shared by both encoders.
        nheads: Number of attention heads.
        num_encoder_layers: Number of encoder blocks per encoder.
        dim_feedforward: Feed-forward size passed to both encoders.
        dropout: Dropout probability passed to both encoders.
        layer_norm_eps: Layer-normalisation epsilon passed to both encoders.
        activation: Activation name passed to both encoders.
    """

    def __init__(
        self,
        channels: int = 3,
        patch_size: int = 16,
        image_size: int = 128,
        dimension: int = 256,
        nheads: int = 8,
        num_encoder_layers: int = 6,
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
        layer_norm_eps: float = 1e-05,
        activation: str = "relu",
    ):
        super().__init__()

        self.channels = channels
        self.patch_size = patch_size
        self.image_size = image_size
        self.dimension = dimension
        self.nheads = nheads
        self.num_encoder_layers = num_encoder_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.layer_norm_eps = layer_norm_eps
        self.activation = activation

        # Settings that both encoders receive unchanged.
        encoder_settings = dict(
            dimension=dimension,
            nheads=nheads,
            num_encoder_layers=num_encoder_layers,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            layer_norm_eps=layer_norm_eps,
            activation=activation,
        )

        self.vision_transformer = VisionTransformer(
            channels=channels,
            patch_size=patch_size,
            image_size=image_size,
            **encoder_settings,
        )
        self.text_transformer = TextTransformerEncoder(**encoder_settings)
        self.classifier = Classifier(dimension=dimension)

    def forward(self, image: torch.Tensor, text: torch.Tensor):
        """Classify a batch of (image, text) pairs.

        Returns:
            The output of ``Classifier`` for the fused features.

        Raises:
            ValueError: If either input is not a ``torch.Tensor``.
        """
        if not (isinstance(image, torch.Tensor) and isinstance(text, torch.Tensor)):
            raise ValueError("Both inputs must be torch.Tensor.".capitalize())

        image_sequence = self.vision_transformer(image)
        text_sequence = self.text_transformer(text)

        # Average over the sequence axis, then join both modalities.
        pooled_image = torch.mean(input=image_sequence, dim=1)
        pooled_text = torch.mean(input=text_sequence, dim=1)
        fusion = torch.cat((pooled_image, pooled_text), dim=1)

        return self.classifier(fusion)


if __name__ == "__main__":
    # Smoke test: the classifier must return one value per sample.
    # The configuration file is parsed once instead of once per value.
    notebook_config = config_files()
    patch_config = notebook_config["patchEmbeddings"]
    encoder_config = notebook_config["transfomerEncoderBlock"]

    image_channels = patch_config["channels"]
    image_size = patch_config["image_size"]
    patch_size = patch_config["patch_size"]
    dimension = patch_config["dimension"]
    nheads = encoder_config["nheads"]
    activation = encoder_config["activation"]
    dropout = encoder_config["dropout"]
    num_encoder_layers = encoder_config["num_encoder_layers"]
    dimension_feedforward = encoder_config["dimension_feedforward"]
    layer_norm_eps = float(encoder_config["layer_norm_eps"])
    batch_size = notebook_config["dataloader"]["batch_size"]

    number_of_patches = (image_size // patch_size) ** 2
    number_of_sequences = (image_size // patch_size) ** 2

    # Build a batch of `batch_size` samples. The earlier demo fed one sample
    # while asserting a (batch_size, 1) result, and BatchNorm1d in training
    # mode does not accept a single sample.
    # NOTE(review): this still requires the configured batch_size to be > 1.
    images = torch.randn(batch_size, image_channels, image_size, image_size)
    texts = torch.randint(0, number_of_sequences, (batch_size, number_of_sequences))

    classifier = MultiModalClassifier(
        channels=image_channels,
        patch_size=patch_size,
        image_size=image_size,
        dimension=dimension,
        nheads=nheads,
        num_encoder_layers=num_encoder_layers,
        dim_feedforward=dimension_feedforward,
        dropout=dropout,
        layer_norm_eps=layer_norm_eps,
        activation=activation,
    )
    # The forward pass used to be executed twice in a row; once is enough.
    output = classifier(image=images, text=texts)
    assert output.unsqueeze(-1).size() == (
        batch_size,
        1,
    ), "Multi Modal Classifier output size mismatch".capitalize()

### DataLoader

In [None]:
class Loader:
    """Prepare the image/report dataset and persist train/test dataloaders.

    The loader reads ``image_labels_reports.csv`` (columns ``img``, ``text``
    and ``label``), builds a word-level vocabulary from the preprocessed
    reports, encodes every report as a fixed-length list of vocabulary
    indices, loads the matching images from the extracted archive and wraps
    everything in two ``DataLoader`` objects that are pickled into the
    processed-data folder.

    Args:
        channels: Number of image channels (stored on the instance only).
        image_size: Side length, in pixels, images are resized to.
        patch_size: Patch side length; together with ``image_size`` it fixes
            the token sequence length ``(image_size // patch_size) ** 2``.
        batch_size: Batch size of both dataloaders.
        split_size: Fraction of samples (taken from the end) used for testing.
    """

    def __init__(
        self,
        channels: int = 3,
        image_size: int = 128,
        patch_size: int = 16,
        batch_size: int = 4,
        split_size: float = 0.25,
    ):
        self.channels = channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.batch_size = batch_size
        self.split_size = split_size

        # Index 0 doubles as the unknown-word id and the padding id.
        self.vocabulary = {"<UNK>": 0}
        self.images_data = list()
        self.labels_data = list()
        self.textual_data = list()
        self.text_to_sequence = list()
        self.sequences = list()

        self.sequence_length = (self.image_size // self.patch_size) ** 2

    def image_transform(self):
        """Return the torchvision pipeline applied to every PIL image."""
        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def preprocess_csv_file(self):
        """Load the CSV file and clean its ``text`` column.

        Returns:
            dict with the ``labels``, ``reports`` (preprocessed text) and
            ``images`` columns plus the full ``dataframe``.

        Raises:
            ValueError: If any of the ``text``, ``label`` or ``img`` columns
                is missing.
        """
        dataframe_path = os.path.join(
            config_files()["artifacts"]["raw_data_path"], "image_labels_reports.csv"
        )
        df = pd.read_csv(dataframe_path)

        missing_columns = [
            column for column in ("text", "label", "img") if column not in df.columns
        ]
        if missing_columns:
            raise ValueError(
                f"Required column(s) missing in the CSV file: {missing_columns}"
            )

        # Preprocess once and reuse the cleaned column for the reports.
        df["text"] = df["text"].apply(text_preprocessing)

        return {
            "labels": df.loc[:, "label"],
            "reports": df.loc[:, "text"],
            "images": df.loc[:, "img"],
            "dataframe": df,
        }

    def create_vocabularies(self, instance):
        """Add every unseen word of ``instance`` to ``self.vocabulary``."""
        # Tokenise exactly like create_sequences (whitespace split) so the
        # vocabulary never contains tokens the encoder cannot produce.
        for word in instance.split():
            if word not in self.vocabulary:
                self.vocabulary[word] = len(self.vocabulary)

    def unzip_image_dataset(self):
        """Extract ``image_dataset.zip`` into ``<processed>/image_dataset``.

        Raises:
            FileNotFoundError: If the raw data folder does not exist.
        """
        if os.path.exists(config_files()["artifacts"]["raw_data_path"]):
            image_data_path = os.path.join(
                config_files()["artifacts"]["raw_data_path"], "image_dataset.zip"
            )
            processed_data_path = config_files()["artifacts"]["processed_data_path"]

            with zipfile.ZipFile(file=image_data_path, mode="r") as zip_file:
                zip_file.extractall(
                    path=os.path.join(processed_data_path, "image_dataset")
                )

            print(
                "Image dataset unzipped successfully in the folder {}".capitalize().format(
                    processed_data_path
                )
            )
        else:
            raise FileNotFoundError("Could not extract image dataset".capitalize())

    def create_sequences(self, instance, sequence_length: int):
        """Encode ``instance`` as exactly ``sequence_length`` vocabulary ids.

        Unknown words map to ``<UNK>``; longer texts are truncated and shorter
        ones are right-padded with the ``<UNK>`` id. The result is also
        appended to ``self.sequences``.
        """
        unknown_id = self.vocabulary["<UNK>"]
        sequence = [self.vocabulary.get(word, unknown_id) for word in instance.split()]

        if len(sequence) > sequence_length:
            sequence = sequence[:sequence_length]
        elif len(sequence) < sequence_length:
            sequence.extend([unknown_id] * (sequence_length - len(sequence)))

        assert (
            len(sequence) == sequence_length
        ), f"Error: Sequence length is {len(sequence)} instead of {sequence_length}"

        self.sequences.append(sequence)

        return sequence

    def extracted_image_and_text_features(self):
        """Build the aligned image tensors, labels and token sequences.

        Writes ``vocabulary.csv`` to the processed-data folder, then walks the
        extracted image folder and keeps every readable image that has a row
        in the CSV file.

        Returns:
            dict with ``images`` (list of tensors), ``labels`` and
            ``text_to_sequence`` (long tensors when the conversion succeeds).
        """
        dataset = self.preprocess_csv_file()
        images = dataset["images"]
        reports = dataset["reports"]
        dataframe = dataset["dataframe"]

        try:
            reports.apply(self.create_vocabularies)

            pd.DataFrame(
                list(self.vocabulary.items()), columns=["vocabulary", "index"]
            ).to_csv(
                os.path.join(
                    config_files()["artifacts"]["processed_data_path"], "vocabulary.csv"
                )
            )
        except Exception as e:
            print(f"Error occurred while creating vocabularies: {e}")
            sys.exit(1)

        dataframe["sequences"] = reports.apply(
            self.create_sequences, sequence_length=self.sequence_length
        )

        all_image_path = os.path.join(
            config_files()["artifacts"]["processed_data_path"], "image_dataset"
        )

        # Build the lookup set and the transform once instead of per file.
        known_images = set(images.values.tolist())
        transform = self.image_transform()

        for image in tqdm(os.listdir(all_image_path), desc="Processing Images"):
            try:
                if image not in known_images:
                    print(f"Image not found in dataset: {image}")
                    continue

                row_mask = dataframe["img"] == image
                try:
                    text = dataframe.loc[row_mask, "text"].values[0]
                    label = dataframe.loc[row_mask, "label"].values[0]
                    sequences = dataframe.loc[row_mask, "sequences"].values[0]
                except IndexError:
                    print(f"Missing data for image: {image}")
                    continue

                single_image_path = os.path.join(all_image_path, image)

                if not single_image_path.lower().endswith(("jpeg", "png", "jpg")):
                    print(f"Invalid file format: {single_image_path}")
                    continue

                if not os.path.exists(single_image_path):
                    print(f"File does not exist: {single_image_path}")
                    continue

                try:
                    image_data = cv2.imread(single_image_path)
                    # cv2.imread signals failure by returning None, so check
                    # before handing the result to another OpenCV call.
                    if image_data is None:
                        raise ValueError(
                            f"Corrupted or unreadable image: {single_image_path}"
                        )
                    image_data = cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)

                    image_data = Image.fromarray(image_data)
                    image_data = transform(image_data)

                    if not isinstance(image_data, torch.Tensor):
                        raise TypeError(
                            f"Expected torch.Tensor but got {type(image_data)} for {image}"
                        )

                    self.images_data.append(image_data)
                    self.labels_data.append(label)
                    self.textual_data.append(text)
                    self.text_to_sequence.append(sequences)

                except Exception as e:
                    print(f"Error processing image '{image}': {e}")
                    continue

            except Exception as e:
                print(f"Unexpected error processing file '{image}': {e}")

        assert (
            len(self.images_data)
            == len(self.labels_data)
            == len(self.textual_data)
            == len(self.text_to_sequence)
        ), "Mismatch: 'Image data', 'labels', 'text', and 'text to sequence' are not equal"

        try:
            self.labels_data = torch.tensor(self.labels_data, dtype=torch.long)
            self.text_to_sequence = torch.tensor(
                self.text_to_sequence, dtype=torch.long
            )
        except Exception as e:
            # Best effort: report the real cause; create_dataloader rejects
            # non-tensor labels/sequences afterwards.
            print(f"Tensor conversion failed: {e}")

        return {
            "images": self.images_data,
            "labels": self.labels_data,
            "text_to_sequence": self.text_to_sequence,
        }

    @staticmethod
    def _split_tail(data, test_size: int):
        """Return ``(head, tail)`` where ``tail`` is the last ``test_size`` items.

        Uses an explicit split index so that ``test_size == 0`` yields an empty
        tail (``data[:-0]`` would instead yield an empty head).
        """
        split_index = max(len(data) - test_size, 0)
        return data[:split_index], data[split_index:]

    def create_dataloader(self):
        """Create, pickle and return the train and test dataloaders.

        Returns:
            ``(train_dataloader, test_dataloader)``, or ``(None, None)`` when
            anything fails (the error is printed).
        """
        try:
            dataset = self.extracted_image_and_text_features()
            images = dataset["images"]
            labels = dataset["labels"]
            text_to_sequence = dataset["text_to_sequence"]

            if not isinstance(labels, torch.Tensor) or not isinstance(
                text_to_sequence, torch.Tensor
            ):
                raise TypeError("Labels and text sequences must be torch.Tensor.")

            # All three collections have the same length (asserted upstream),
            # so a single test size keeps the splits aligned.
            test_size = int(len(images) * self.split_size)

            train_images, test_images = self._split_tail(images, test_size)
            train_labels, test_labels = self._split_tail(labels, test_size)
            train_texts, test_texts = self._split_tail(text_to_sequence, test_size)

            if (
                len(train_images) == 0
                or len(train_labels) == 0
                or len(train_texts) == 0
            ):
                raise ValueError("Train dataset is empty! Check split size.")
            if len(test_images) == 0 or len(test_labels) == 0 or len(test_texts) == 0:
                raise ValueError("Test dataset is empty! Check split size.")

            train_dataloader = DataLoader(
                dataset=list(zip(train_images, train_texts, train_labels)),
                batch_size=self.batch_size,
                shuffle=True,
            )
            test_dataloader = DataLoader(
                dataset=list(zip(test_images, test_texts, test_labels)),
                batch_size=self.batch_size,
                shuffle=False,
            )

            processed_data_path = config_files()["artifacts"]["processed_data_path"]
            for filename, value in [
                ("train_dataloader.pkl", train_dataloader),
                ("test_dataloader.pkl", test_dataloader),
            ]:
                dump_file(
                    value=value,
                    filename=os.path.join(processed_data_path, filename),
                )

            print(
                "Dataloaders created successfully in the folder {}".capitalize().format(
                    processed_data_path
                )
            )

            return train_dataloader, test_dataloader

        except Exception as e:
            print(f"Error in create_dataloader: {e}")
            return None, None

    @staticmethod
    def details_dataset():
        """Write ``dataset_details.csv`` summarising the pickled dataloaders."""
        processed_data_path = config_files()["artifacts"]["processed_data_path"]

        train_dataloader = load_file(
            filename=os.path.join(processed_data_path, "train_dataloader.pkl")
        )
        test_dataloader = load_file(
            filename=os.path.join(processed_data_path, "test_dataloader.pkl")
        )

        train_images, _, train_labels = next(iter(train_dataloader))
        _, test_sequences, _ = next(iter(test_dataloader))

        # The sample count is the dataset length; no need to iterate batches.
        total_train_dataset = len(train_dataloader.dataset)
        total_test_dataset = len(test_dataloader.dataset)

        pd.DataFrame(
            {
                "Dataset": ["Train", "Test"],
                "Size": [total_train_dataset, total_test_dataset],
                "Number of Batches": [len(train_dataloader), len(test_dataloader)],
                "Image Size": str([train_images.size()]),
                "Sequence Size": str([test_sequences.size()]),
                "Label Size": str([train_labels.size()]),
                "Label Type": str([train_labels.dtype]),
                "Text Type": str([test_sequences.dtype]),
            }
        ).to_csv(os.path.join(processed_data_path, "dataset_details.csv"))
        print(
            "Dataset details saved successfully in the folder {}".capitalize().format(
                processed_data_path
            )
        )

    @staticmethod
    def display_images():
        """Plot one training batch with its label and decoded report."""
        try:
            processed_data_path = config_files()["artifacts"]["processed_data_path"]

            train_dataloader = load_file(
                filename=os.path.join(processed_data_path, "train_dataloader.pkl")
            )
            vocabularies = pd.read_csv(
                os.path.join(processed_data_path, "vocabulary.csv")
            )
            vocabularies["index"] = vocabularies["index"].astype(int)
            # id -> word lookup keeps the words in report order when decoding.
            index_to_word = dict(
                zip(
                    vocabularies["index"].tolist(),
                    vocabularies["vocabulary"].astype(str).tolist(),
                )
            )

            images, texts, labels = next(iter(train_dataloader))

            num_images = images.size(0)
            num_rows = int(math.sqrt(num_images))
            num_cols = math.ceil(num_images / num_rows)

            # squeeze=False always returns a 2-D array of axes, so a batch of
            # one image no longer breaks the flatten() call.
            fig, axes = plt.subplots(
                num_rows, num_cols, figsize=(12, 8), squeeze=False
            )
            axes = axes.flatten()

            for index, (image, ax) in enumerate(zip(images, axes)):
                image = image.squeeze().permute(1, 2, 0).detach().cpu().numpy()
                image = (image - image.min()) / (image.max() - image.min())
                label = labels[index].item()
                text_sequences = texts[index].detach().cpu().numpy().tolist()
                words = [
                    index_to_word[token]
                    for token in text_sequences
                    if token in index_to_word and index_to_word[token] != "<UNK>"
                ]
                medical_report = " ".join(words).strip()
                wrapped_report = fill(medical_report, width=30)

                title_text = f"Label: {label}\nReport: {wrapped_report}"

                ax.set_title(title_text, fontsize=9, loc="center")
                ax.imshow(image)
                ax.axis("off")

            # Hide grid cells that have no image.
            for ax in axes[num_images:]:
                ax.axis("off")

            plt.tight_layout()
            plt.show()

        except Exception as e:
            print(f"Error in display_images: {e}")

if __name__ == "__main__":
    # Take the image geometry and batch size from the same config the model is
    # built from, so the token sequence length produced by the loader
    # ((image_size // patch_size) ** 2) matches what the model is sized for.
    config = config_files()
    patch_config = config["patchEmbeddings"]

    loader = Loader(
        channels=patch_config["channels"],
        image_size=patch_config["image_size"],
        patch_size=patch_config["patch_size"],
        batch_size=config["dataloader"]["batch_size"],
        split_size=0.25,
    )

    loader.unzip_image_dataset()
    loader.create_dataloader()

### Helper function

In [None]:
def load_dataloader():
    """Load the pickled train and test dataloaders.

    Returns:
        dict with keys ``train_dataloader`` and ``valid_dataloader`` (the
        latter is read from ``test_dataloader.pkl``).

    Raises:
        FileNotFoundError: If the processed-data folder does not exist.
            Previously the function returned ``None`` in that case, which
            only surfaced later as an opaque subscript error in the caller.
    """
    processed_data_path = config_files()["artifacts"]["processed_data_path"]
    if not os.path.exists(processed_data_path):
        raise FileNotFoundError(
            f"Processed data path does not exist: {processed_data_path}"
        )

    pickled_loaders = (
        ("train_dataloader", "train_dataloader.pkl"),
        ("valid_dataloader", "test_dataloader.pkl"),
    )
    return {
        key: load_file(filename=os.path.join(processed_data_path, filename))
        for key, filename in pickled_loaders
    }


def helper(**kwargs):
    """Assemble the model, optimizer, loss and dataloaders for training.

    Keyword Args:
        model: A ready model, or ``None`` to build a ``MultiModalClassifier``
            from the notebook config.
        lr: Learning rate.
        beta1, beta2: Adam beta coefficients.
        momentum: SGD momentum.
        weight_decay: Weight decay passed to whichever optimizer is selected.
        adam: Use Adam when truthy (takes precedence over ``SGD``).
        SGD: Use SGD when truthy and ``adam`` is falsy.

    Returns:
        dict with keys ``train_dataloader``, ``test_dataloader``, ``model``,
        ``optimizer`` and ``criterion``.

    Raises:
        ValueError: If neither ``adam`` nor ``SGD`` is selected.

    Exits the process with status 1 when the dataloaders cannot be loaded.
    """
    model = kwargs["model"]
    lr: float = kwargs["lr"]
    beta1: float = kwargs["beta1"]
    beta2: float = kwargs["beta2"]
    momentum: float = kwargs["momentum"]
    weight_decay: float = kwargs["weight_decay"]
    adam: bool = kwargs["adam"]
    SGD: bool = kwargs["SGD"]

    if model is None:
        # Parse the YAML once rather than once per hyper-parameter.
        config = config_files()
        patch_config = config["patchEmbeddings"]
        encoder_config = config["transfomerEncoderBlock"]

        classifier = MultiModalClassifier(
            channels=patch_config["channels"],
            patch_size=patch_config["patch_size"],
            image_size=patch_config["image_size"],
            nheads=encoder_config["nheads"],
            dropout=encoder_config["dropout"],
            activation=encoder_config["activation"],
            dimension=patch_config["dimension"],
            num_encoder_layers=encoder_config["num_encoder_layers"],
            dim_feedforward=encoder_config["dimension_feedforward"],
            layer_norm_eps=float(encoder_config["layer_norm_eps"]),
        )
    else:
        classifier = model

    if adam:
        optimizer = optim.Adam(
            params=classifier.parameters(),
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=weight_decay,
        )
    elif SGD:
        # weight_decay used to be silently dropped for SGD; forward it so both
        # optimizers honour the same argument.
        optimizer = optim.SGD(
            params=classifier.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
    else:
        raise ValueError("Optimizer not supported".capitalize())

    criterion = LossFunction(loss_name="BCEWithLogitsLoss", reduction="mean")

    try:
        dataset = load_dataloader()
    except Exception as e:
        print(f"Error while loading dataset: {e}")
        sys.exit(1)

    if dataset is None:
        # load_dataloader may yield nothing when the processed folder is absent.
        print("Error while loading dataset: no dataloaders were found")
        sys.exit(1)

    return {
        "train_dataloader": dataset["train_dataloader"],
        "test_dataloader": dataset["valid_dataloader"],
        "model": classifier,
        "optimizer": optimizer,
        "criterion": criterion,
    }


if __name__ == "__main__":
    # Sanity check of helper(): every returned component has the expected type.
    trainer_config = config_files()["trainer"]  # parse the YAML once

    init = helper(
        model=None,
        lr=float(trainer_config["lr"]),
        beta1=trainer_config["beta1"],
        beta2=trainer_config["beta2"],
        momentum=trainer_config["momentum"],
        weight_decay=float(trainer_config["weight_decay"]),
        adam=trainer_config["adam"],
        SGD=trainer_config["SGD"],
    )

    # helper() prefers Adam when both flags are set, otherwise falls back to
    # SGD, so the expected optimizer follows the config instead of always
    # being Adam.
    expected_optimizer = optim.Adam if trainer_config["adam"] else optim.SGD

    assert isinstance(init["model"], MultiModalClassifier)
    assert isinstance(init["optimizer"], expected_optimizer)
    assert isinstance(init["criterion"], LossFunction)
    assert isinstance(init["train_dataloader"], torch.utils.data.DataLoader)
    assert isinstance(init["test_dataloader"], torch.utils.data.DataLoader)


###  Trainer Class

In [None]:
class Trainer:
    """Train a ``MultiModalClassifier`` and track loss/accuracy per epoch.

    All components (model, optimizer, criterion, dataloaders) are obtained
    from ``helper``; initialisation problems are printed and terminate the
    process with status 1.

    Args:
        model: Model to train, or ``None`` to let ``helper`` build one.
        epochs: Number of training epochs.
        lr: Learning rate.
        beta1, beta2: Adam beta coefficients.
        momentum: SGD momentum.
        step_size, gamma: ``StepLR`` parameters (used when ``lr_scheduler``).
        l1_lambda, l2_lambda: Regularisation strengths; ``l2_lambda`` is also
            forwarded to ``helper`` as the optimizer weight decay.
        device: ``"cuda"``, ``"mps"`` or anything else for CPU.
        adam, SGD: Optimizer selection flags.
        l1_regularization, l2_regularization: Add the corresponding penalty
            to the training objective.
        lr_scheduler: Step a ``StepLR`` scheduler once per epoch.
        verbose: Print the per-epoch progress line.
        mlflow: Stored on the instance; not used by this class.
    """

    def __init__(
        self,
        model=None,
        epochs: int = 100,
        lr: float = 2e-4,
        beta1: float = 0.5,
        beta2: float = 0.999,
        momentum: float = 0.95,
        step_size: int = 20,
        gamma: float = 0.75,
        l1_lambda: float = 0.01,
        l2_lambda: float = 0.01,
        device: str = "cuda",
        adam: bool = True,
        SGD: bool = False,
        l1_regularization: bool = False,
        l2_regularization: bool = False,
        lr_scheduler: bool = False,
        verbose: bool = True,
        mlflow: bool = False,
    ):
        try:
            self.model = model
            self.epochs = epochs
            self.lr = lr
            self.beta1 = beta1
            self.beta2 = beta2
            self.momentum = momentum
            self.step_size = step_size
            self.gamma = gamma
            self.l1_lambda = l1_lambda
            self.l2_lambda = l2_lambda
            self.adam = adam
            self.SGD = SGD
            self.device = device
            self.l1_regularization = l1_regularization
            self.l2_regularization = l2_regularization
            self.lr_scheduler = lr_scheduler
            self.verbose = verbose
            self.mlflow = mlflow

            self.device = device_init(device=self.device)

            self.init = helper(
                model=self.model,
                lr=self.lr,
                beta1=self.beta1,
                beta2=self.beta2,
                momentum=self.momentum,
                weight_decay=self.l2_lambda,
                adam=self.adam,
                SGD=self.SGD,
            )
            try:
                self.train_dataloader = self.init["train_dataloader"]
                self.test_dataloader = self.init["test_dataloader"]
                assert isinstance(
                    self.train_dataloader, torch.utils.data.DataLoader
                ), "Train_dataloader is not a valid DataLoader"
                assert isinstance(
                    self.test_dataloader, torch.utils.data.DataLoader
                ), "Test_dataloader is not a valid DataLoader"
            except KeyError as e:
                print(f"DataLoader Initialization Error: Missing key {e}")
                sys.exit(1)

            try:
                self.model = self.init["model"]
                assert isinstance(
                    self.model, MultiModalClassifier
                ), "Model must be an instance of MultiModalClassifier"
            except KeyError:
                print(
                    "Model Initialization Error: 'model' key missing from helper return dictionary"
                )
                sys.exit(1)
            except AssertionError as e:
                print(e)
                sys.exit(1)

            try:
                self.optimizer = self.init["optimizer"]
                if self.adam:
                    assert isinstance(
                        self.optimizer, optim.Adam
                    ), "Optimizer should be Adam"
                elif self.SGD:
                    assert isinstance(
                        self.optimizer, optim.SGD
                    ), "Optimizer should be SGD"
            except KeyError:
                print(
                    "Optimizer Initialization Error: 'optimizer' key missing from helper return dictionary"
                )
                sys.exit(1)
            except AssertionError as e:
                print(e)
                sys.exit(1)

            try:
                self.criterion = self.init["criterion"]
                assert isinstance(
                    self.criterion, LossFunction
                ), "Criterion should be a PyTorch loss function"
            except KeyError:
                print(
                    "Criterion Initialization Error: 'criterion' key missing from helper return dictionary"
                )
                sys.exit(1)
            except AssertionError as e:
                print(e)
                sys.exit(1)

            if self.lr_scheduler:
                # Reached through the already-imported `optim` package so no
                # separate StepLR import is required.
                self.scheduler = optim.lr_scheduler.StepLR(
                    optimizer=self.optimizer, step_size=self.step_size, gamma=self.gamma
                )

            self.model = self.model.to(self.device)
            self.criterion = self.criterion.to(self.device)

        except Exception as e:
            print(f"Unexpected Error in Trainer Initialization: {e}")
            sys.exit(1)

        self.train_models = config_files()["artifacts"]["train_models"]
        self.best_model = config_files()["artifacts"]["best_model"]
        self.metrics_path = config_files()["artifacts"]["metrics"]

        self.model_history = {
            "train_loss": [],
            "test_loss": [],
            "train_accuracy": [],
            "test_accuracy": [],
        }

        # Lowest mean training loss seen so far; drives best-model saving.
        self.loss = float("inf")

    def l1_regularizer(self, model):
        """Return ``l1_lambda`` times the sum of L1 norms of all parameters.

        Raises:
            ValueError: If ``model`` is not a ``MultiModalClassifier``.
        """
        if isinstance(model, MultiModalClassifier):
            return self.l1_lambda * sum(
                torch.norm(input=params, p=1) for params in model.parameters()
            )
        else:
            raise ValueError("Model must be an instance of MultiModalClassifier")

    def l2_regularizer(self, model=None):
        """Return ``l2_lambda`` times the sum of L2 norms of all parameters.

        Args:
            model: Model to penalise; defaults to ``self.model``. (The method
                used to read an undefined/global ``model`` name.)

        Raises:
            ValueError: If the model is not a ``MultiModalClassifier``.
        """
        if model is None:
            model = self.model

        if isinstance(model, MultiModalClassifier):
            return self.l2_lambda * sum(
                torch.norm(input=params, p=2) for params in model.parameters()
            )
        else:
            raise ValueError("Model must be an instance of MultiModalClassifier")

    def saved_checkpoints(self, train_loss: float, epoch: int):
        """Save the per-epoch weights and, on improvement, the best model."""
        # Create the target folders up front so a fresh checkout can train.
        os.makedirs(self.best_model, exist_ok=True)
        os.makedirs(self.train_models, exist_ok=True)

        if self.loss > train_loss:
            self.loss = train_loss
            torch.save(
                {
                    "model": self.model.state_dict(),
                    "optimizer": self.optimizer.state_dict(),
                    "train_loss": self.loss,
                    "epoch": epoch,
                },
                os.path.join(self.best_model, "best_model.pth"),
            )
        torch.save(
            self.model.state_dict(),
            os.path.join(self.train_models, "model{}.pth".format(epoch)),
        )

    def update_train(self, **kwargs):
        """Run one optimisation step and return the criterion loss.

        Keyword Args:
            predicted: Model output for the batch.
            labels: Target labels for the batch.

        Returns:
            The criterion loss as a float. Regularisation penalties, when
            enabled, are part of the optimised objective but are not included
            in the returned value so it stays comparable to validation loss.
        """
        predicted = kwargs["predicted"].float()
        labels = kwargs["labels"].float()

        self.optimizer.zero_grad()

        loss = self.criterion(predicted, labels)

        objective = loss
        if self.l1_regularization:
            objective = objective + self.l1_regularizer(model=self.model)
        if self.l2_regularization:
            # NOTE: l2_lambda is also handed to the optimizer as weight decay
            # in __init__, so enabling this adds an explicit penalty on top.
            objective = objective + self.l2_regularizer(model=self.model)

        objective.backward()
        self.optimizer.step()

        return loss.item()

    def display_progress(
        self,
        train_loss: list,
        valid_loss: list,
        train_accuracy: list,
        valid_accuracy: list,
        kwargs: dict,
    ):
        """Print the mean metrics of one epoch.

        Nothing is printed unless all four metric arguments are lists.
        ``kwargs["epochs"]`` carries the 1-based epoch number.
        """
        if (
            isinstance(train_loss, list)
            and isinstance(valid_loss, list)
            and isinstance(train_accuracy, list)
            and isinstance(valid_accuracy, list)
        ):
            train_loss = np.mean(train_loss)
            valid_loss = np.mean(valid_loss)
            train_accuracy = np.mean(train_accuracy)
            valid_accuracy = np.mean(valid_accuracy)
            number_of_epochs = self.epochs
            epoch = kwargs["epochs"]

            print(
                f"Epoch [{epoch}/{number_of_epochs}] | "
                f"Train Loss: {train_loss:.4f} | "
                f"Test Loss: {valid_loss:.4f} | "
                f"Train Acc: {train_accuracy:.4f} | "
                f"Valid Acc: {valid_accuracy:.4f}"
            )

    def train(self):
        """Run the full training loop and pickle the metric history.

        Each epoch trains on ``train_dataloader``, evaluates on
        ``test_dataloader``, optionally steps the scheduler, saves checkpoints
        and calls ``plot_images``. Errors in the per-epoch bookkeeping are
        printed and training continues.
        """
        for epoch in tqdm(range(self.epochs), desc="Training"):
            train_loss = []
            valid_loss = []
            train_accuracy = []
            valid_accuracy = []

            # Enable dropout and other train-time layers for the training pass.
            self.model.train()
            for images, texts, labels in self.train_dataloader:
                images = images.to(self.device)
                texts = texts.to(self.device)
                labels = labels.to(self.device)

                predicted = self.model(image=images, text=texts)

                train_loss.append(self.update_train(predicted=predicted, labels=labels))

                # NOTE(review): helper() builds the criterion with
                # "BCEWithLogitsLoss"; if the model emits raw logits this 0.5
                # threshold should be applied after a sigmoid — confirm against
                # the model's final layer.
                predicted = torch.where(predicted > 0.5, 1, 0)
                predicted = predicted.detach().cpu().numpy()
                labels = labels.detach().cpu().numpy()

                train_accuracy.append(accuracy_score(labels, predicted))

            # Evaluate deterministically and without building autograd graphs.
            self.model.eval()
            with torch.no_grad():
                for images, texts, labels in self.test_dataloader:
                    images = images.to(self.device)
                    texts = texts.to(self.device)
                    labels = labels.to(self.device)

                    predicted = self.model(image=images, text=texts)

                    valid_loss.append(
                        self.criterion(predicted.float(), labels.float()).item()
                    )

                    predicted = torch.where(predicted > 0.5, 1, 0)
                    predicted = predicted.detach().cpu().numpy()
                    labels = labels.detach().cpu().numpy()

                    valid_accuracy.append(accuracy_score(labels, predicted))

            if self.lr_scheduler:
                self.scheduler.step()

            try:
                if self.verbose:
                    self.display_progress(
                        train_loss=train_loss,
                        valid_loss=valid_loss,
                        train_accuracy=train_accuracy,
                        valid_accuracy=valid_accuracy,
                        kwargs={"epochs": epoch + 1},
                    )

                train_loss_mean = np.mean(train_loss)
                valid_loss_mean = np.mean(valid_loss)
                train_acc_mean = np.mean(train_accuracy)
                valid_acc_mean = np.mean(valid_accuracy)

                self.saved_checkpoints(train_loss=train_loss_mean, epoch=epoch + 1)

                self.model_history["train_loss"].append(train_loss_mean)
                self.model_history["train_accuracy"].append(train_acc_mean)
                self.model_history["test_loss"].append(valid_loss_mean)
                self.model_history["test_accuracy"].append(valid_acc_mean)

                plot_images(
                    predicted=True,
                    device=self.device,
                    model=self.model,
                    epoch=epoch + 1,
                )

            except KeyError as e:
                print(f"[Error] Missing key in function arguments: {e}")
            except TypeError as e:
                print(f"[Error] Type mismatch in function arguments: {e}")
            except ValueError as e:
                print(f"[Error] Invalid value encountered: {e}")
            except FileNotFoundError:
                print(
                    "Error: Checkpoint directory not found. Ensure the save path exists."
                )
            except PermissionError:
                print(
                    "Error: Permission denied. Cannot write to the checkpoint directory."
                )
            except Exception as e:
                print(f"[Unexpected Error] {e}")

        os.makedirs(self.metrics_path, exist_ok=True)
        dump_file(
            value=self.model_history,
            filename=os.path.join(self.metrics_path, "history.pkl"),
        )

    @staticmethod
    def display_history():
        """Plot the pickled loss/accuracy history and save ``history.png``."""
        metrics_path = config_files()["artifacts"]["metrics"]
        history = load_file(filename=os.path.join(metrics_path, "history.pkl"))
        if history is not None:
            _, axes = plt.subplots(2, 2, figsize=(10, 10), sharex=True)

            axes[0, 0].plot(history["train_loss"], label="Train Loss")
            axes[0, 0].plot(history["test_loss"], label="Test Loss")
            axes[0, 0].set_title("Loss")
            axes[0, 0].set_xlabel("Epochs")
            axes[0, 0].legend()

            axes[0, 1].plot(history["train_accuracy"], label="Train Accuracy")
            axes[0, 1].plot(history["test_accuracy"], label="Test Accuracy")
            axes[0, 1].set_title("Accuracy")
            axes[0, 1].set_xlabel("Epochs")
            axes[0, 1].legend()

            # Only the top row carries plots; hide the unused bottom row.
            axes[1, 0].axis("off")
            axes[1, 1].axis("off")

            plt.tight_layout()
            plt.savefig(os.path.join(metrics_path, "history.png"))
            plt.show()
            print("History saved as 'history.png' in the metrics folder".capitalize())
        else:
            print("No history found".capitalize())

if __name__ == "__main__":
    # Read the trainer section once instead of re-parsing the YAML per key.
    trainer_config = config_files()["trainer"]

    model = trainer_config["model"]
    epochs = trainer_config["epochs"]
    lr = float(trainer_config["lr"])
    beta1 = float(trainer_config["beta1"])
    beta2 = float(trainer_config["beta2"])
    momentum = trainer_config["momentum"]
    step_size = trainer_config["step_size"]
    gamma = trainer_config["gamma"]
    l1_lambda = trainer_config["l1_lambda"]
    l2_lambda = trainer_config["l2_lambda"]
    device = trainer_config["device"]
    adam = trainer_config["adam"]
    SGD = trainer_config["SGD"]
    l1_regularization = trainer_config["l1_regularization"]
    l2_regularization = trainer_config["l2_regularization"]
    lr_scheduler = trainer_config["lr_scheduler"]
    verbose = trainer_config["verbose"]
    mlflow = trainer_config["mlflow"]

    # model=None lets the Trainer build the classifier from the config; the
    # `model` value read above is not forwarded.
    trainer = Trainer(
        model=None,
        epochs=epochs,
        lr=lr,
        beta1=beta1,
        beta2=beta2,
        momentum=momentum,
        step_size=step_size,
        gamma=gamma,
        l1_lambda=l1_lambda,
        l2_lambda=l2_lambda,
        device=device,
        adam=adam,
        SGD=SGD,
        l1_regularization=l1_regularization,
        l2_regularization=l2_regularization,
        lr_scheduler=lr_scheduler,
        verbose=verbose,
        mlflow=mlflow,
    )