## Import the Libraries

In [None]:
import os
import sys
import cv2
import math
import yaml
import torch
import joblib
import zipfile
import warnings
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
import torch.optim as optim
from dotenv import load_dotenv
from torchvision import transforms
from torch.utils.data import DataLoader
from langchain_openai import ChatOpenAI
from sklearn.model_selection import train_test_split
from langchain_core.output_parsers import StrOutputParser

## Utility Files

In [None]:
def config_files():
    """Read the notebook configuration file and return its parsed content.

    Returns:
        The object produced by ``yaml.safe_load`` for
        ``../../notebook_config.yml`` (the path is relative to the current
        working directory).
    """
    config_path = "../../notebook_config.yml"

    with open(config_path, mode="r") as stream:
        settings = yaml.safe_load(stream)

    return settings


def dump_files(value=None, filename=None):
    """Serialize ``value`` to ``filename`` with joblib.

    Args:
        value: Object to persist. ``None`` is allowed as long as a filename
            is given.
        filename: Destination path of the dump.

    Raises:
        ValueError: If ``filename`` is missing. A destination is always
            required, so the check does not depend on ``value``.
    """
    if filename is None:
        if value is None:
            raise ValueError("Either values or filename must be provided".capitalize())
        raise ValueError("Filename must be provided".capitalize())

    joblib.dump(value=value, filename=filename)


def load_files(filename: str = None):
    """Load and return an object previously stored with joblib.

    Args:
        filename: Path of the joblib dump to read.

    Returns:
        The object returned by ``joblib.load``.

    Raises:
        ValueError: If no filename is given.
    """
    if filename is not None:
        return joblib.load(filename=filename)

    raise ValueError("Filename must be provided".capitalize())


def device_init(device: str = "cuda"):
    """Resolve a device name to a ``torch.device``, falling back to the CPU.

    Args:
        device: ``"cuda"`` or ``"mps"`` to request an accelerator; any other
            value selects the CPU.

    Returns:
        The requested accelerator device when torch reports it as available,
        otherwise ``torch.device("cpu")``.
    """
    if device == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")

    if device == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")

    return torch.device("cpu")


def weight_init(m):
    """Initialise a module in place, chosen by its class name.

    Modules whose class name contains ``"Conv"`` get weights drawn from a
    normal distribution (mean 0.0, std 0.02). Modules whose class name contains
    ``"BatchNorm"`` get weights drawn from a normal distribution (mean 1.0,
    std 0.02) and a zero bias. Every other module is left untouched.

    Args:
        m: The module to initialise (typically supplied by ``Module.apply``).
    """
    layer_kind = type(m).__name__

    if "Conv" in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return

    if "BatchNorm" in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

## DataLoader (Yet to be completed)

In [None]:
class Loader:
    """Build train / test / validation dataloaders from a zipped image dataset.

    The archive at ``image_path`` is extracted into ``./data/processed``. The
    ``Training`` and ``Testing`` folders found there are read class by class,
    every image is converted to a tensor, ``Training`` is split into a train
    and a test part, and the three resulting ``DataLoader`` objects are dumped
    back into ``./data/processed``.

    Args:
        image_path: Path of the zip archive holding the dataset.
        image_channels: ``1`` loads the images as single-channel grayscale;
            any other value loads them as 3-channel RGB.
        image_size: Side length (pixels) the images are resized to.
        batch_size: Batch size of the created dataloaders.
        split_size: Fraction of ``Training`` that goes to the test split.
    """

    # Maximum number of samples kept per split. Features and labels are cut
    # with the SAME limit so that both tensors always have the same length.
    # NOTE(review): these look like debugging caps - raise or remove them to
    # use the complete dataset.
    TRAIN_SAMPLE_LIMIT = 100
    TEST_SAMPLE_LIMIT = 100
    VALID_SAMPLE_LIMIT = 50

    def __init__(
        self,
        image_path: str = "./data/raw",
        image_channels: int = 3,
        image_size: int = 224,
        batch_size: int = 64,
        split_size: float = 0.25,
    ):
        self.image_path = image_path
        self.image_channels = image_channels
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size

        self.train_images = []
        self.train_labels = []
        self.valid_images = []
        self.valid_labels = []

    def unzip_folder(self):
        """Extract the archive at ``image_path`` into ``./data/processed``."""
        os.makedirs("./data/processed", exist_ok=True)

        with zipfile.ZipFile(file=self.image_path, mode="r") as file:
            file.extractall(path="./data/processed")

        print("""Extracted file saved in the "./data/processed" folder""")

    def split_dataset(self, **kwargs):
        """Split features and labels into a train and a test part.

        Args:
            **kwargs: Must contain ``X`` (list of features) and ``y`` (list of
                labels).

        Returns:
            Dict with the keys ``X_train``, ``X_test``, ``y_train``, ``y_test``.

        Raises:
            ValueError: If ``X`` or ``y`` is not a list.
        """
        X = kwargs["X"]
        y = kwargs["y"]

        # Both inputs have to be lists; rejecting only the case where BOTH are
        # wrong would let a single malformed argument through.
        if not isinstance(X, list) or not isinstance(y, list):
            raise ValueError("Invalid data type".capitalize())

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.split_size, random_state=42
        )

        return {
            "X_train": X_train,
            "X_test": X_test,
            "y_train": y_train,
            "y_test": y_test,
        }

    def image_transforms(self, type: str = "RGB"):
        """Return the preprocessing pipeline for one image.

        Args:
            type: ``"RGB"`` normalises three channels; any other value
                normalises a single channel.

        Returns:
            A ``transforms.Compose`` that resizes, converts to a tensor,
            center-crops and normalises with mean 0.5 / std 0.5 per channel.
        """
        if type == "RGB":
            mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]
        else:
            mean, std = [0.5], [0.5]

        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.Normalize(mean, std),
            ]
        )

    def extract_features(self):
        """Read every image below ``./data/processed`` and return tensors.

        Returns:
            Dict with the stacked float feature tensors ``X_train``,
            ``X_test``, ``valid_images`` and the matching long label tensors
            ``y_train``, ``y_test``, ``valid_labels``.

        Raises:
            ValueError: If a file is not a png/jpg/jpeg image or cannot be
                decoded.
        """
        train_path = "./data/processed/Training"
        valid_path = "./data/processed/Testing"

        class_names = ["glioma", "meningioma", "pituitary", "notumor"]

        grayscale = self.image_channels == 1
        read_flag = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
        # The pipeline is the same for every image, so it is built only once.
        transform = self.image_transforms("GRAY" if grayscale else "RGB")

        for path in [train_path, valid_path]:
            if path == train_path:
                images, labels = self.train_images, self.train_labels
            else:
                images, labels = self.valid_images, self.valid_labels

            for class_index, label in enumerate(class_names):
                class_folder = os.path.join(path, label)
                for file_name in tqdm(
                    os.listdir(class_folder), desc="Extracting images".title()
                ):
                    single_image_path = os.path.join(class_folder, file_name)
                    if not single_image_path.lower().endswith(("png", "jpg", "jpeg")):
                        raise ValueError("Invalid image format")

                    # Decode with the channel layout the model expects: a
                    # plain imread always yields 3 BGR channels, even when a
                    # single-channel input was requested.
                    pixels = cv2.imread(single_image_path, read_flag)
                    if pixels is None:
                        raise ValueError(
                            f"Unable to read image: {single_image_path}"
                        )
                    if not grayscale:
                        # OpenCV decodes to BGR; PIL/torchvision expect RGB.
                        pixels = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)

                    images.append(transform(Image.fromarray(pixels)))
                    labels.append(class_index)

        assert len(self.train_images) == len(self.train_labels)
        assert len(self.valid_images) == len(self.valid_labels)

        train_dataset = self.split_dataset(X=self.train_images, y=self.train_labels)

        train_cap = self.TRAIN_SAMPLE_LIMIT
        test_cap = self.TEST_SAMPLE_LIMIT
        valid_cap = self.VALID_SAMPLE_LIMIT

        return {
            "X_train": torch.stack(train_dataset["X_train"][:train_cap]).float(),
            "X_test": torch.stack(train_dataset["X_test"][:test_cap]).float(),
            "y_train": torch.tensor(
                train_dataset["y_train"][:train_cap], dtype=torch.long
            ),
            "y_test": torch.tensor(
                train_dataset["y_test"][:test_cap], dtype=torch.long
            ),
            "valid_images": torch.stack(self.valid_images[:valid_cap]).float(),
            "valid_labels": torch.tensor(
                self.valid_labels[:valid_cap], dtype=torch.long
            ),
        }

    def create_dataloader(self):
        """Create the three dataloaders and dump them to ``./data/processed``."""
        dataset = self.extract_features()

        splits = [
            ("train_dataloader.pkl", dataset["X_train"], dataset["y_train"]),
            ("test_dataloader.pkl", dataset["X_test"], dataset["y_test"]),
            ("valid_dataloader.pkl", dataset["valid_images"], dataset["valid_labels"]),
        ]

        for filename, features, targets in tqdm(
            splits, desc="Saving dataloaders".title()
        ):
            dataloader = DataLoader(
                dataset=list(zip(features, targets)),
                batch_size=self.batch_size,
                shuffle=True,
            )
            dump_files(
                value=dataloader,
                filename=os.path.join("./data/processed/", filename),
            )

        print("Files saved in the folder ./data/processed/")


if __name__ == "__main__":
    # Imported here because this entry point is the only user of argparse and
    # the shared import cell of the file does not provide it.
    import argparse

    parser = argparse.ArgumentParser(
        description="Dataloader for the Medical Assistant Task".title()
    )
    parser.add_argument(
        "--image_path",
        type=str,
        default="./data/raw/dataset.zip",
        help="Path to the dataset".capitalize(),
    )
    parser.add_argument(
        "--image_channels",
        type=int,
        default=1,
        help="Number of image channels".capitalize(),
    )
    parser.add_argument(
        "--image_size", type=int, default=224, help="Image size".capitalize()
    )
    parser.add_argument(
        "--batch_size", type=int, default=16, help="Batch size".capitalize()
    )
    parser.add_argument(
        "--split_size", type=float, default=0.30, help="Split size".capitalize()
    )

    args = parser.parse_args()

    image_path = args.image_path
    image_channels = args.image_channels
    image_size = args.image_size
    batch_size = args.batch_size
    split_size = args.split_size

    loader = Loader(
        image_path=image_path,
        image_channels=image_channels,
        image_size=image_size,
        batch_size=batch_size,
        split_size=split_size,
    )

    # Extract the archive first; create_dataloader reads the extracted folders.
    loader.unzip_folder()
    loader.create_dataloader()


## Patch Embedding

In [None]:
class PositionalEncoding(nn.Module):
    """Add a learnable offset vector to every token embedding.

    The learnable tensor has shape ``(1, 1, dimension)``, so one and the same
    vector is broadcast over the batch and over every token position.

    NOTE(review): because the vector is shared by all positions it cannot tell
    one patch from another; a per-position table of shape
    ``(1, num_patches, dimension)`` would - confirm the intent.

    Args:
        dimension: Size of the embedding (last) axis of the input.

    Raises:
        ValueError: If ``dimension`` is not positive.
    """

    def __init__(self, dimension: int = 512):
        super(PositionalEncoding, self).__init__()

        if dimension <= 0:
            raise ValueError("Dimension must be a positive integer")

        self.dimension = dimension
        self.encoding = nn.Parameter(
            torch.randn(1, 1, self.dimension), requires_grad=True
        )

    def forward(self, x: torch.Tensor):
        """Return ``x`` plus the learnable encoding (broadcast over dims 0, 1).

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        return x + self.encoding


class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A strided convolution (kernel = stride = ``patch_size``) projects every
    patch to ``embedding_dimension`` channels; the resulting feature map is
    flattened into one token per patch and passed through
    ``PositionalEncoding``.

    Args:
        image_channels: Number of channels of the input images.
        image_size: Side length of the (square) input images.
        patch_size: Side length of one patch.
        embedding_dimension: Size of each token. ``None`` falls back to
            ``image_channels * patch_size * patch_size``.
    """

    def __init__(
        self,
        image_channels: int = 3,
        image_size: int = 224,
        patch_size: int = 16,
        embedding_dimension: int = 512,
    ):
        super(PatchEmbedding, self).__init__()
        self.image_channels = image_channels
        self.image_size = image_size
        self.patch_size = patch_size
        self.embedding_dimension = embedding_dimension

        self.total_patches = (self.image_size // self.patch_size) ** 2

        if self.embedding_dimension is None:
            warnings.warn(
                "Embedding dimension not specified. Using the default value calculated as: image_channels × patch_size × patch_size."
            )
            # Number of raw values in one patch, as announced by the warning.
            self.embedding_dimension = (
                self.image_channels * self.patch_size * self.patch_size
            )

        # NOTE(review): padding=1 shifts the patch grid by one pixel; a plain
        # non-overlapping patch split would use padding=0. Kept as is so that
        # existing behaviour of the projection does not change - confirm.
        self.projection = nn.Conv2d(
            in_channels=self.image_channels,
            out_channels=self.embedding_dimension,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding=1,
            bias=False,
        )
        self.encoding = PositionalEncoding(dimension=self.embedding_dimension)

    def forward(self, x: torch.Tensor):
        """Embed ``x`` and return a tensor of shape (batch, patches, dim).

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        x = self.projection(x)  # (batch, dim, rows, cols)
        # Flatten the grid and MOVE the channel axis to the end. A bare
        # ``view`` to (batch, patches, dim) would only reinterpret memory and
        # mix values of different channels and patches inside one token.
        x = x.flatten(2).transpose(1, 2)  # (batch, rows * cols, dim)
        x = self.encoding(x)
        return x


if __name__ == "__main__":
    # Smoke test: a batch of 64 RGB images must become (64, patches, 768).
    image_channels, image_size = 3, 224
    patch_size, embedding_dimension = 16, 768

    patchEmbedding = PatchEmbedding(
        image_channels=image_channels,
        image_size=image_size,
        patch_size=patch_size,
        embedding_dimension=embedding_dimension,
    )

    images = torch.randn((64, 3, 224, 224))

    _patch_tokens = patchEmbedding(images)
    _patch_shape = (64, (image_size // patch_size) ** 2, 768)

    assert (
        _patch_tokens.size() == _patch_shape
    ), "Patch Embedding is not working properly".capitalize()

## Multi Head Attention Layer 

In [None]:
def scaled_dot_product(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    """Compute scaled dot-product attention.

    Args:
        query: Tensor whose last axis is the feature axis.
        key: Tensor with the same feature size as ``query``.
        value: Tensor whose second-to-last axis matches that of ``key``.

    Returns:
        ``softmax(query @ key^T / sqrt(d_k)) @ value`` where ``d_k`` is the
        size of the last axis of ``query``.

    Raises:
        TypeError: If any of the three inputs is not a ``torch.Tensor``.
    """
    if not all(isinstance(item, torch.Tensor) for item in (query, key, value)):
        raise TypeError("All inputs must be torch.Tensor".capitalize())

    # The scale is the feature size, read BEFORE transposing the key: after
    # the transpose the last axis is the sequence length.
    scale = math.sqrt(query.size(-1))

    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

if __name__ == "__main__":
    # Smoke test: attention must return a tensor shaped like its inputs.
    image_channels, image_size, patch_size = 1, 224, 16
    total_patches = (image_size // patch_size) ** 2
    embedding_dimension = image_channels * patch_size * patch_size

    _qkv_shape = (1, total_patches, embedding_dimension)

    attention = scaled_dot_product(
        query=torch.randn(*_qkv_shape),
        key=torch.randn(*_qkv_shape),
        value=torch.randn(*_qkv_shape),
    )

    assert (
        attention.size() == _qkv_shape
    ), "Attention output size must match the input size".capitalize()


class MultiHeadAttentionLayer(nn.Module):
    """Multi-head self-attention over a (batch, tokens, dimension) tensor.

    One bias-free linear layer produces query, key and value; they are split
    into ``nheads`` heads, attended with ``scaled_dot_product`` and merged
    back to the input shape.

    Args:
        nheads: Number of attention heads.
        dimension: Size of the last axis of the input; must be divisible by
            ``nheads``.
    """

    def __init__(self, nheads: int = 6, dimension: int = 768):
        super(MultiHeadAttentionLayer, self).__init__()
        self.nheads = nheads
        self.dimension = dimension

        assert (
            self.dimension % self.nheads == 0
        ), "Dimension must be divisible by number of heads".capitalize()

        self.QKV = nn.Linear(
            in_features=self.dimension, out_features=3 * self.dimension, bias=False
        )

    def _split_heads(self, tensor: torch.Tensor):
        """Reshape (batch, tokens, dim) to (batch, heads, tokens, dim / heads)."""
        batch, tokens, features = tensor.size()
        return tensor.view(
            batch, tokens, self.nheads, features // self.nheads
        ).permute(0, 2, 1, 3)

    def forward(self, x: torch.Tensor):
        """Return the self-attention output, shaped like ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())

        query, key, value = torch.chunk(input=self.QKV(x), chunks=3, dim=-1)

        attention = scaled_dot_product(
            query=self._split_heads(query),
            key=self._split_heads(key),
            value=self._split_heads(value),
        )

        # Undo the head split: swap the head and token axes back with
        # ``permute`` before merging. A bare ``view`` would keep the memory
        # order (batch, heads, tokens, ...) and mix tokens of different heads.
        batch, _, tokens, head_dim = attention.size()
        attention = attention.permute(0, 2, 1, 3).reshape(
            batch, tokens, self.nheads * head_dim
        )

        assert (
            attention.size() == x.size()
        ), "Attention output must have the same size as input".capitalize()

        return attention


if __name__ == "__main__":
    # Smoke test: the attention layer must preserve the input shape.
    nheads, dimension = 8, 256

    images = torch.randn((1, 196, 256))
    multihead_attention = MultiHeadAttentionLayer(nheads=nheads, dimension=dimension)

    _attended = multihead_attention(x=images)

    assert (
        _attended.size() == images.size()
    ), "MultiHeadAttention is not working properly".capitalize()

## Layer Normalization Layer

In [None]:
class LayerNormalization(nn.Module):
    """Normalise the last axis of the input and apply a learnable affine map.

    Args:
        normalized_shape: Size of the last axis of the input.
        eps: Value added to the variance for numerical stability.

    Attributes are ``alpha`` (scale, initialised to one) and ``beta`` (shift,
    initialised to zero), both of shape ``(1, 1, normalized_shape)``.
    """

    def __init__(self, normalized_shape: int = 768, eps: float = 1e-05):
        super(LayerNormalization, self).__init__()

        self.dimension = normalized_shape
        self.eps = eps
        self.alpha = nn.Parameter(
            data=torch.ones((1, 1, self.dimension)), requires_grad=True
        )
        self.beta = nn.Parameter(
            data=torch.zeros((1, 1, self.dimension)), requires_grad=True
        )

    def forward(self, x: torch.Tensor):
        """Return ``alpha * (x - mean) / sqrt(var + eps) + beta``.

        Mean and variance are taken over the last axis.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())

        mean = x.mean(dim=-1, keepdim=True)
        # Layer normalisation uses the biased (population) variance; the
        # default of torch.var is the unbiased estimator, which also yields
        # NaN when the normalised axis has a single element.
        variance = x.var(dim=-1, unbiased=False, keepdim=True)

        x_bar = (x - mean) / torch.sqrt(variance + self.eps)

        return self.alpha * x_bar + self.beta


if __name__ == "__main__":
    # Smoke test: normalisation must not change the tensor shape.
    image_channels, image_size, patch_size = 3, 224, 16

    total_patches = (image_size // patch_size) ** 2
    dimension = image_channels * patch_size * patch_size

    norm = LayerNormalization(normalized_shape=dimension)
    images = torch.randn((1, total_patches, dimension))

    _normalized = norm(images)

    assert (
        _normalized.size() == images.size()
    ), "Layer Normalization is not working properly".capitalize()

## Feed Forward Neural Network

In [None]:
class FeedForwardNeuralNetwork(nn.Module):
    """Two-layer position-wise feed-forward network.

    The network is ``Linear(d_model -> dim_feedforward)``, activation,
    dropout, ``Linear(dim_feedforward -> d_model)``; both linear layers are
    created without bias.

    Args:
        d_model: Size of the input and output feature axis.
        dim_feedforward: Size of the hidden layer.
        dropout: Dropout probability applied after the activation.
        activation: One of ``"relu"``, ``"leaky"`` or ``"gelu"``.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(
        self,
        d_model: int = 768,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "gelu",
    ):
        super(FeedForwardNeuralNetwork, self).__init__()
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation_func = activation

        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif activation == "leaky":
            self.activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        elif activation == "gelu":
            self.activation = nn.GELU()
        else:
            raise ValueError("Invalid activation function".capitalize())

        # Expansion layer first, projection back to d_model last.
        expand = nn.Linear(d_model, dim_feedforward, bias=False)
        regularizer = nn.Dropout(p=dropout)
        contract = nn.Linear(dim_feedforward, d_model, bias=False)

        self.layers = [expand, self.activation, regularizer, contract]

        # Both bookkeeping attributes end up at d_model once the stack is built.
        self.in_features = d_model
        self.out_features = d_model

        self.network = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Apply the network to ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if isinstance(x, torch.Tensor):
            return self.network(x)

        raise TypeError("Input must be a torch.Tensor")


if __name__ == "__main__":
    # Smoke test: the feed-forward network must preserve the input shape.
    image_size, patch_size, image_channels = 224, 16, 1
    dropout = 0.1

    total_patches = (image_size // patch_size) ** 2
    dimension = image_channels * patch_size * patch_size
    dim_feedforward = dimension * 4

    images = torch.randn(1, total_patches, dimension)

    network = FeedForwardNeuralNetwork(
        d_model=dimension,
        dim_feedforward=dim_feedforward,
        dropout=dropout,
    )

    _ffn_output = network(images)

    assert (
        _ffn_output.size() == images.size()
    ), "FFNN is not working properly".capitalize()

## Transformer Encoder Block

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One post-norm transformer encoder block.

    The block computes
    ``norm1(x + dropout1(attention(x)))`` followed by
    ``norm2(y + dropout2(feed_forward(y)))``.

    Args:
        nhead: Number of attention heads.
        d_model: Size of the feature axis.
        dim_feedforward: Hidden size of the feed-forward network.
        dropout: Dropout probability used in both sub-layers and inside the
            feed-forward network.
        activation: Activation name forwarded to ``FeedForwardNeuralNetwork``.
        layer_norm_eps: Epsilon of both normalisation layers.
        bias: Stored on the instance; not used by any layer built here.
    """

    def __init__(
        self,
        nhead: int = 8,
        d_model: int = 768,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "gelu",
        layer_norm_eps: float = 1e-05,
        bias: bool = False,
    ):
        super(TransformerEncoderBlock, self).__init__()
        self.nheads = nhead
        self.d_model = d_model
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation = activation
        self.layer_norm_eps = layer_norm_eps
        self.bias = bias

        self.multi_head_attention = MultiHeadAttentionLayer(
            nheads=self.nheads,
            dimension=self.d_model,
        )

        self.layer_norm1 = LayerNormalization(
            normalized_shape=self.d_model, eps=self.layer_norm_eps
        )
        self.layer_norm2 = LayerNormalization(
            normalized_shape=self.d_model, eps=self.layer_norm_eps
        )
        self.feed_forward_network = FeedForwardNeuralNetwork(
            d_model=self.d_model,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            activation=self.activation,
        )
        self.dropout1 = nn.Dropout(p=self.dropout)
        self.dropout2 = nn.Dropout(p=self.dropout)

    def forward(self, x: torch.Tensor):
        """Apply attention and feed-forward sub-layers to ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        # Attention sub-layer with residual connection.
        residual = x
        x = self.dropout1(self.multi_head_attention(x))
        x = self.layer_norm1(torch.add(x, residual))

        # Feed-forward sub-layer; dropout2 regularises this branch the same
        # way dropout1 regularises the attention branch.
        residual = x
        x = self.dropout2(self.feed_forward_network(x))
        x = self.layer_norm2(torch.add(x, residual))

        return x

if __name__ == "__main__":
    # Smoke test: an encoder block must preserve the input shape.
    _block_config = {
        "nhead": 8,
        "d_model": 256,
        "dim_feedforward": 2048,
        "dropout": 0.1,
        "activation": "gelu",
        "layer_norm_eps": 1e-05,
        "bias": False,
    }
    transformer = TransformerEncoderBlock(**_block_config)

    images = torch.randn((1, 196, 256))

    _encoded = transformer(images)

    assert (
        _encoded.size() == images.size()
    ), "Transformer is not working properly".capitalize()


## ViT - With Classifier

In [None]:
class Classifier(nn.Module):
    """Classification head that maps pooled features to class probabilities.

    The stack is ``Linear(d -> d // 4)``, ``BatchNorm1d``, activation,
    dropout, ``Linear(d // 4 -> d // 16)`` and a final
    ``Linear(d // 16 -> target_size)`` followed by ``Softmax(dim=1)``.

    NOTE(review): the head already ends in a softmax; if its output is fed to
    ``nn.CrossEntropyLoss`` (which applies log-softmax itself) the softmax is
    effectively applied twice - confirm this is intended.

    Args:
        dimension: Size of the input feature axis.
        dropout: Dropout probability after the first activation.
        activation: One of ``"relu"``, ``"leaky"``, ``"gelu"`` or ``"tanh"``.
        target_size: Number of output classes.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """

    def __init__(
        self,
        dimension: int = 768,
        dropout: float = 0.3,
        activation: str = "leaky",
        target_size: int = 4,
    ):
        super(Classifier, self).__init__()

        self.dimension = dimension
        self.dropout = dropout
        self.activation_func = activation
        self.target_size = target_size

        if self.activation_func == "relu":
            self.activation = nn.ReLU(inplace=True)
        elif self.activation_func == "leaky":
            self.activation = nn.LeakyReLU(inplace=True)
        elif self.activation_func == "gelu":
            self.activation = nn.GELU()
        elif self.activation_func == "tanh":
            self.activation = nn.Tanh()
        else:
            raise ValueError("Invalid activation function")

        self.in_features = self.dimension
        self.out_features = self.in_features // 4

        self.layers = []

        # Two shrinking linear layers; only the first one is followed by
        # normalisation, activation and dropout.
        for index in range(2):
            self.layers += [
                nn.Linear(in_features=self.in_features, out_features=self.out_features)
            ]

            if index == 0:
                self.layers += [nn.BatchNorm1d(num_features=self.out_features)]
                self.layers += [self.activation]
                self.layers += [nn.Dropout(p=self.dropout)]

            self.in_features = self.out_features
            self.out_features = self.out_features // 4

        # The output layer stays wrapped in its own Sequential so that the
        # parameter names of the head do not change.
        self.layers += [
            nn.Sequential(
                nn.Linear(in_features=self.in_features, out_features=self.target_size),
                nn.Softmax(dim=1),
            )
        ]

        self.classifier = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Return the class probabilities for ``x``.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        return self.classifier(x)


if __name__ == "__main__":
    # Smoke test: mean-pooled tokens must be mapped to four class scores.
    classifier = Classifier(dimension=256, dropout=0.3, activation="leaky")

    images = torch.randn((16, 196, 256))
    images = images.mean(dim=1)

    _class_scores = classifier(images)

    assert _class_scores.size() == (
        16,
        4,
    ), "Classifier is not working properly".capitalize()

class ViTWithClassifier(nn.Module):
    """Vision transformer: patch embedding, encoder stack and classifier head.

    Args:
        image_channels: Number of channels of the input images.
        image_size: Side length of the input images.
        patch_size: Side length of one patch.
        target_size: Stored on the instance.
            NOTE(review): it is not forwarded to ``Classifier`` here, so it
            does not influence the number of outputs - confirm.
        encoder_layer: Number of ``TransformerEncoderBlock`` instances.
        nhead: Number of attention heads per block.
        d_model: Token size used throughout the model.
        dim_feedforward: Hidden size of the feed-forward networks.
        dropout: Dropout probability for the blocks and the classifier.
        activation: Activation name for the blocks and the classifier.
        layer_norm_eps: Epsilon of the normalisation layers.
        bias: Forwarded to every encoder block.
    """

    def __init__(
        self,
        image_channels: int = 3,
        image_size: int = 224,
        patch_size: int = 16,
        target_size: int = 4,
        encoder_layer: int = 4,
        nhead: int = 8,
        d_model: int = 768,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "gelu",
        layer_norm_eps: float = 1e-05,
        bias: bool = False,
    ):
        super(ViTWithClassifier, self).__init__()

        # Keep every hyper-parameter on the instance.
        self.image_channels, self.image_size = image_channels, image_size
        self.patch_size, self.target_size = patch_size, target_size
        self.encoder_layer, self.nhead = encoder_layer, nhead
        self.d_model, self.dim_feedforward = d_model, dim_feedforward
        self.dropout, self.activation = dropout, activation
        self.layer_norm_eps, self.bias = layer_norm_eps, bias

        self.layers = []

        self.patch_embedding = PatchEmbedding(
            image_channels=image_channels,
            image_size=image_size,
            patch_size=patch_size,
            embedding_dimension=d_model,
        )

        # Build the encoder stack one block at a time (with a progress bar).
        encoder_blocks = []
        for _ in tqdm(range(encoder_layer), desc="Transformer Block".title()):
            encoder_blocks.append(
                TransformerEncoderBlock(
                    nhead=nhead,
                    d_model=d_model,
                    dim_feedforward=dim_feedforward,
                    dropout=dropout,
                    activation=activation,
                    layer_norm_eps=layer_norm_eps,
                    bias=bias,
                )
            )
        self.transformer = nn.Sequential(*encoder_blocks)

        self.classifier = Classifier(
            dimension=d_model,
            dropout=dropout,
            activation=activation,
        )

    def forward(self, x: torch.Tensor):
        """Classify a batch of images.

        The tokens produced by the encoder stack are averaged over dim 1
        before they reach the classifier.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        tokens = self.transformer(self.patch_embedding(x))
        pooled = tokens.mean(dim=1)

        return self.classifier(pooled)


if __name__ == "__main__":
    # Smoke test: 16 grayscale images must yield a (16, 4) output.
    _vit_config = {
        "image_channels": 1,
        "image_size": 224,
        "patch_size": 16,
        "target_size": 4,
        "encoder_layer": 4,
        "nhead": 8,
        "d_model": 256,
        "dim_feedforward": 4 * 256,
        "dropout": 0.1,
        "activation": "gelu",
        "layer_norm_eps": 1e-05,
        "bias": False,
    }
    vit = ViTWithClassifier(**_vit_config)

    images = torch.randn((16, 1, 224, 224))

    _vit_output = vit(images)

    assert _vit_output.size() == (
        16,
        4,
    ), "ViTWithClassifier is not working properly".capitalize()

## Helper Function

In [None]:
class Criterion(nn.Module):
    """Loss wrapper around ``nn.CrossEntropyLoss``.

    Args:
        loss_function: ``"cross_entropy"`` or ``"cross_entropy_with_logits"``.
            Both select ``nn.CrossEntropyLoss``, which itself operates on
            unnormalised scores; torch has no separate
            ``CrossEntropyLossWithLogits`` class.
        reduction: Reduction passed to the loss (e.g. ``"mean"``).

    Raises:
        ValueError: If ``loss_function`` is not supported.
    """

    def __init__(self, loss_function: str = "cross_entropy", reduction: str = "mean"):
        super(Criterion, self).__init__()
        self.loss_function = loss_function
        self.reduction = reduction

        if loss_function in ("cross_entropy", "cross_entropy_with_logits"):
            self.criterion = nn.CrossEntropyLoss(reduction=self.reduction)
        else:
            raise ValueError("Invalid loss function")

    def forward(self, actual: torch.Tensor, predicted: torch.Tensor):
        """Return the loss between predictions and targets.

        The wrapped loss expects ``(input, target)``. Callers that pass the
        arguments positionally in that order keep working unchanged. When the
        arguments follow the parameter names instead (integer class indices
        in ``actual``, floating-point scores in ``predicted``) they are
        swapped before the call.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        if not isinstance(actual, torch.Tensor) or not isinstance(
            predicted, torch.Tensor
        ):
            raise ValueError("Input must be a torch.Tensor".capitalize())

        # Integer targets can never be the score input of the loss, so this
        # combination unambiguously means (targets, scores).
        if predicted.is_floating_point() and not actual.is_floating_point():
            return self.criterion(predicted, actual)

        return self.criterion(actual, predicted)


def load_dataloader():
    """Load the three dumped dataloaders from ``./data/processed``.

    Returns:
        Dict with the keys ``train_dataloader``, ``test_dataloader`` and
        ``valid_dataloader``, each holding the object read by ``load_files``
        from the ``.pkl`` file of the same name.
    """
    dataloader_path = "./data/processed"
    loaded = {}

    for split in ("train", "test", "valid"):
        key = f"{split}_dataloader"
        pickle_path = os.path.join(dataloader_path, f"{key}.pkl")
        loaded[key] = load_files(filename=pickle_path)

    return loaded


def helper(**kwargs):
    """Assemble the model, optimizer, loss and dataloaders for training.

    Args:
        **kwargs: Must contain
            ``model`` (a module, or ``None`` to build a default
            ``ViTWithClassifier``), ``lr``, ``weight_decay``, ``adam``,
            ``beta1``, ``beta2``, ``SGD`` and ``momentum``.

    Returns:
        Dict with the keys ``classifier``, ``optimizer``, ``criterion`` and
        ``dataloader``.

    Raises:
        KeyError: If one of the required keyword arguments is missing.
        ValueError: If neither ``adam`` nor ``SGD`` is selected, or if loading
            the dataloaders or building the criterion fails (the original
            error is attached as the cause).
    """
    model = kwargs["model"]
    lr: float = kwargs["lr"]
    weight_decay: float = kwargs["weight_decay"]
    adam: bool = kwargs["adam"]
    beta1: float = kwargs["beta1"]
    beta2: float = kwargs["beta2"]
    SGD: bool = kwargs["SGD"]
    momentum: float = kwargs["momentum"]

    if model is None:
        classifier = ViTWithClassifier(
            image_channels=1,
            image_size=224,
            patch_size=16,
            target_size=4,
            encoder_layer=4,
            nhead=8,
            d_model=256,
            dim_feedforward=2048,
            dropout=0.1,
            activation="gelu",
            layer_norm_eps=1e-05,
            bias=False,
        )
    else:
        classifier = model

    # ``adam`` takes precedence when both flags are set.
    if adam:
        optimizer = optim.Adam(
            params=classifier.parameters(),
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=weight_decay,
        )
    elif SGD:
        optimizer = optim.SGD(
            params=classifier.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
    else:
        raise ValueError("Optimizer not found, use 'adam' or 'SGD'")

    try:
        dataloader = load_dataloader()
    except Exception as e:
        # Chain the cause so the real reason (e.g. a missing file) stays
        # visible in the traceback.
        raise ValueError(
            "Dataloader not found, use 'load_dataloader' to load dataloader"
        ) from e

    try:
        criterion = Criterion(loss_function="cross_entropy")
    except Exception as e:
        raise ValueError(
            "Criterion not found, use 'Criterion' to load criterion"
        ) from e

    return {
        "classifier": classifier,
        "optimizer": optimizer,
        "criterion": criterion,
        "dataloader": dataloader,
    }


if __name__ == "__main__":
    # Smoke test: helper must hand back dataloaders, model and loss wrapper.
    init = helper(
        model=None,
        lr=1e-4,
        adam=True,
        beta1=0.9,
        beta2=0.999,
        weight_decay=0.01,
        SGD=False,
        momentum=0.9,
    )

    _loaded_splits = init["dataloader"]
    train_dataloader = _loaded_splits["train_dataloader"]
    test_dataloader = _loaded_splits["test_dataloader"]
    valid_dataloader = _loaded_splits["valid_dataloader"]

    classifier = init["classifier"]

    criterion = init["criterion"]

    for _split_loader in (train_dataloader, test_dataloader, valid_dataloader):
        assert _split_loader.__class__ == torch.utils.data.dataloader.DataLoader

    assert classifier.__class__ == ViTWithClassifier
    assert criterion.__class__ == Criterion

## Trainer

In [None]:
from tqdm import tqdm
from sklearn.metrics import accuracy_score
class Trainer:
    """Train the classifier returned by ``helper`` and checkpoint it.

    The constructor asks ``helper`` for the classifier, optimizer, criterion
    and dataloaders, moves the classifier and the criterion to the selected
    device and sanity-checks the objects it received. ``train`` then iterates
    over the training dataloader for ``epochs`` epochs, evaluates on the test
    dataloader, reports the metrics and writes checkpoints below
    ``./artifacts/checkpoints``.

    Args:
        model: Forwarded to ``helper`` as ``model``.
        epochs: Number of passes over the training dataloader.
        lr: Learning rate forwarded to ``helper``.
        beta1: First Adam beta, forwarded to ``helper``.
        beta2: Second Adam beta, forwarded to ``helper``.
        weight_decay: Weight decay forwarded to ``helper``.
        momentum: SGD momentum forwarded to ``helper``.
        adam: Request an Adam optimizer from ``helper``.
        SGD: Request an SGD optimizer from ``helper`` (used when ``adam`` is
            false).
        l1_regularization: Add the L1 penalty to every training loss.
        elasticNet_regularization: Add the combined L1 + L2-norm penalty to
            every training loss (only consulted when ``l1_regularization`` is
            false).
        device: Device name handed to ``device_init``.
        verbose: Print the per-epoch metrics when true.

    Raises:
        ValueError: If neither ``adam`` nor ``SGD`` is enabled.
        AssertionError: If ``helper`` returns objects of unexpected types.
    """

    def __init__(
        self,
        model=None,
        epochs: int = 100,
        lr: float = 0.001,
        beta1: float = 0.5,
        beta2: float = 0.999,
        weight_decay: float = 0.0,
        momentum: float = 0.85,
        adam: bool = True,
        SGD: bool = False,
        l1_regularization: bool = False,
        elasticNet_regularization: bool = False,
        device: str = "cuda",
        verbose: bool = True,
    ):
        self.model = model
        self.epochs = epochs
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.weight_decay = weight_decay
        self.momentum = momentum
        self.adam = adam
        self.SGD = SGD
        self.l1_regularization = l1_regularization
        self.elasticNet_regularization = elasticNet_regularization
        self.verbose = verbose

        self.init = helper(
            model=self.model,
            lr=self.lr,
            adam=self.adam,
            beta1=self.beta1,
            beta2=self.beta2,
            weight_decay=self.weight_decay,
            SGD=self.SGD,
            momentum=self.momentum,
        )

        # Stored as the resolved ``torch.device`` rather than the raw string.
        self.device = device_init(device=device)

        dataloaders = self.init["dataloader"]
        self.train_dataloader = dataloaders["train_dataloader"]
        self.test_dataloader = dataloaders["test_dataloader"]
        self.valid_dataloader = dataloaders["valid_dataloader"]

        self.optimizer = self.init["optimizer"]
        self.criterion = self.init["criterion"].to(self.device)
        self.classifier = self.init["classifier"].to(self.device)

        # Sanity checks on what ``helper`` returned; isinstance() also accepts
        # subclasses, unlike an exact ``__class__ ==`` comparison.
        for loader in (
            self.train_dataloader,
            self.test_dataloader,
            self.valid_dataloader,
        ):
            assert isinstance(loader, torch.utils.data.dataloader.DataLoader)

        assert isinstance(self.classifier, ViTWithClassifier)
        assert isinstance(self.criterion, Criterion)

        if self.adam:
            assert isinstance(self.optimizer, torch.optim.Adam)
        elif self.SGD:
            assert isinstance(self.optimizer, torch.optim.SGD)
        else:
            raise ValueError("Optimizer not supported".capitalize())

        # Best (lowest) mean training loss seen so far; drives best-model saving.
        self.loss = float("inf")

    def l1_regularizer(self, model: "ViTWithClassifier"):
        """Return 0.01 times the sum of the L1 norms of all model parameters.

        Raises:
            ValueError: If ``model`` is not a ``ViTWithClassifier``.
        """
        if not isinstance(model, ViTWithClassifier):
            raise ValueError("Model must be a ViTWithClassifier".capitalize())
        return 0.01 * sum(
            torch.norm(input=params, p=1) for params in model.parameters()
        )

    def elasticNet_regularizer(self, model: "ViTWithClassifier"):
        """Return the combined penalty: 0.01 * sum(L1 norms) + 0.01 * sum(L2 norms).

        NOTE(review): the second term uses the plain (un-squared) L2 norm of
        each parameter tensor; confirm that this is the intended penalty.

        Raises:
            ValueError: If ``model`` is not a ``ViTWithClassifier``.
        """
        if not isinstance(model, ViTWithClassifier):
            raise ValueError("Model must be a ViTWithClassifier".capitalize())
        l1_term = sum(torch.norm(input=params, p=1) for params in model.parameters())
        l2_term = sum(torch.norm(input=params, p=2) for params in model.parameters())
        return 0.01 * l1_term + 0.01 * l2_term

    def saved_checkpoints(self, train_loss: float, epoch: int = 1):
        """Write the per-epoch weights and, on improvement, the best checkpoint.

        Args:
            train_loss: Mean training loss of the epoch that just finished.
            epoch: 1-based index of that epoch; used in the file name and
                recorded inside the best-model checkpoint.

        Raises:
            ValueError: If ``train_loss`` is not a float.
        """
        if not isinstance(train_loss, float):
            raise ValueError("Train loss must be a float")

        # ``np.mean`` yields a numpy scalar (a ``float`` subclass). Store a
        # built-in float so the checkpoint holds only plain Python values.
        train_loss = float(train_loss)

        best_dir = "./artifacts/checkpoints/best_model"
        epoch_dir = "./artifacts/checkpoints/train_models"
        os.makedirs(best_dir, exist_ok=True)
        os.makedirs(epoch_dir, exist_ok=True)

        if train_loss < self.loss:
            self.loss = train_loss
            torch.save(
                {
                    "train_loss": self.loss,
                    # The epoch that produced this loss, not the total count.
                    "epoch": epoch,
                    "model_state_dict": self.classifier.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                },
                os.path.join(best_dir, "best_model.pth"),
            )

        torch.save(
            self.classifier.state_dict(),
            os.path.join(epoch_dir, f"model{epoch}.pth"),
        )

    def update_training(self, predicted: torch.Tensor, actual: torch.Tensor):
        """Run one optimisation step and return the loss as a Python float.

        Args:
            predicted: Model output for the current batch.
            actual: Target labels for the current batch.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        # Both arguments must be tensors; reject the call if either is not.
        if not (
            isinstance(predicted, torch.Tensor) and isinstance(actual, torch.Tensor)
        ):
            raise ValueError("Predicted and Actual must be tensors".capitalize())

        self.optimizer.zero_grad()

        predicted_loss = self.criterion(predicted, actual)

        if self.l1_regularization:
            predicted_loss = predicted_loss + self.l1_regularizer(self.classifier)
        elif self.elasticNet_regularization:
            predicted_loss = predicted_loss + self.elasticNet_regularizer(
                self.classifier
            )

        predicted_loss.backward()
        self.optimizer.step()

        return predicted_loss.item()

    def display(self, **kwargs):
        """Print one line with the losses and accuracies of an epoch.

        Expected keyword arguments: ``epoch``, ``train_loss``, ``valid_loss``,
        ``train_accuracy`` and ``valid_accuracy``. A missing key raises
        ``KeyError``.
        """
        epoch = kwargs["epoch"]
        train_loss = kwargs["train_loss"]
        valid_loss = kwargs["valid_loss"]
        train_accuracy = kwargs["train_accuracy"]
        valid_accuracy = kwargs["valid_accuracy"]

        print(
            "Epochs - [{}/{}] - train_loss: {:.4f} - test_loss: {:.4f} - train_accuracy: {:.4f} - test_accuracy: {:.4f}".format(
                epoch,
                self.epochs,
                train_loss,
                valid_loss,
                train_accuracy,
                valid_accuracy,
            )
        )

    def train(self):
        """Train for ``self.epochs`` epochs, evaluating and saving after each.

        Every epoch performs one optimisation pass over the training
        dataloader followed by a gradient-free evaluation pass, prints the
        metrics when ``verbose`` is true and calls ``saved_checkpoints``.
        """
        for epoch in tqdm(range(self.epochs), desc="Training Medical-Assistant"):
            train_loss = []
            valid_loss = []
            train_predicted, train_actual = [], []
            valid_predicted, valid_actual = [], []

            # Training mode: layers such as dropout behave stochastically.
            self.classifier.train()
            for images, labels in self.train_dataloader:
                images = images.to(self.device)
                labels = labels.to(self.device)

                predicted = self.classifier(images)
                train_loss.append(
                    self.update_training(predicted=predicted, actual=labels)
                )

                train_predicted.append(
                    torch.argmax(input=predicted, dim=1).detach().cpu().numpy()
                )
                train_actual.append(labels.detach().cpu().numpy())

            # Evaluation mode without autograd: deterministic layers and no
            # graph kept in memory.
            # NOTE(review): the metrics named ``valid_*`` are computed on
            # ``test_dataloader`` (and printed as ``test_*``); confirm whether
            # ``valid_dataloader`` was intended here.
            self.classifier.eval()
            with torch.no_grad():
                for images, labels in self.test_dataloader:
                    images = images.to(self.device)
                    labels = labels.to(self.device)

                    predicted = self.classifier(images)
                    valid_loss.append(self.criterion(predicted, labels).item())

                    valid_predicted.append(
                        torch.argmax(input=predicted, dim=1).cpu().numpy()
                    )
                    valid_actual.append(labels.cpu().numpy())

            train_accuracy = accuracy_score(
                np.concatenate(train_actual),
                np.concatenate(train_predicted),
            )
            valid_accuracy = accuracy_score(
                np.concatenate(valid_actual),
                np.concatenate(valid_predicted),
            )

            mean_train_loss = float(np.mean(train_loss))
            mean_valid_loss = float(np.mean(valid_loss))

            if self.verbose:
                self.display(
                    epoch=epoch + 1,
                    train_loss=mean_train_loss,
                    valid_loss=mean_valid_loss,
                    train_accuracy=train_accuracy,
                    valid_accuracy=valid_accuracy,
                )

            self.saved_checkpoints(train_loss=mean_train_loss, epoch=epoch + 1)


if __name__ == "__main__":
    # Launch a full training run with the Adam configuration.
    trainer = Trainer(
        # What to train and for how long.
        model=None,
        epochs=100,
        # Where to run and whether to report progress.
        device="mps",
        verbose=True,
        # Optimizer choice.
        adam=True,
        SGD=False,
        # Optimizer hyper-parameters.
        lr=0.001,
        beta1=0.9,
        beta2=0.999,
        momentum=0.85,
        weight_decay=0.0001,
    )

    trainer.train()

## Medical Assistant

In [None]:
class MedicalAssistant:
    """Classify a brain scan and discuss the result through an LLM chat loop.

    Args:
        device: Device name handed to ``device_init``.
        image: Path of the image file to classify.
    """

    # Class names indexed by the classifier's predicted class id.
    LABELS = (
        "Brain: Glioma",
        "Brain: Meningioma",
        "Brain: No Tumor",
        "Brain: Pituitary",
    )

    def __init__(self, device: str = "cuda", image: str = None):
        self.image = image
        self.image_channels = 1

        # Stored as the resolved ``torch.device`` rather than the raw string.
        self.device = device_init(device=device)

        self.classifier = ViTWithClassifier(
            image_channels=self.image_channels,
            image_size=224,
            patch_size=16,
            target_size=4,
            encoder_layer=1,
            nhead=8,
            d_model=256,
            dim_feedforward=256,
            dropout=0.1,
            activation="gelu",
            layer_norm_eps=1e-05,
            bias=False,
        ).to(self.device)

        # Conversation transcript replayed to the LLM on every question.
        self.memory = []

    def load_model(self):
        """Load the best checkpoint's weights into ``self.classifier``."""
        path = "./artifacts/checkpoints/best_model/best_model.pth"
        # map_location lets a checkpoint saved on another device (for example
        # cuda or mps) be loaded on the device this assistant runs on.
        checkpoint = torch.load(path, map_location=self.device)

        self.classifier.load_state_dict(state_dict=checkpoint["model_state_dict"])

    def preprocess_image(self):
        """Read ``self.image`` and return it as a normalised batch of one.

        Returns:
            The transformed image with a leading batch dimension, moved to
            ``self.device``.

        Raises:
            ValueError: If no image path was given or the file cannot be read.
        """
        if self.image is None:
            raise ValueError("Image path must be provided")

        if self.image_channels == 1:
            transform = transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Grayscale(num_output_channels=1),
                    transforms.CenterCrop((224, 224)),
                    transforms.Normalize([0.5], [0.5]),
                ]
            )
        else:
            transform = transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )

        image = cv2.imread(self.image)
        if image is None:
            # cv2.imread signals an unreadable file by returning None.
            raise ValueError(f"Unable to read image from '{self.image}'")

        # NOTE(review): cv2.imread yields BGR channel order and the array is
        # handed to PIL unchanged; confirm the training pipeline does the same
        # before adding a colour conversion here.
        image = Image.fromarray(image)
        image = transform(image)
        image = image.unsqueeze(0)

        return image.to(self.device)

    def chat(self):
        """Classify the image, print the LLM's summary and answer questions.

        The loop reads questions from standard input until the user types
        ``exit``. Every question and every answer is appended to
        ``self.memory`` so later prompts contain the whole conversation.
        """
        image = self.preprocess_image()
        self.classifier.eval()

        with torch.no_grad():
            logits = self.classifier(image)

            probabilities = torch.softmax(input=logits, dim=1)
            confidence, _ = torch.max(probabilities, dim=1)
            score = f"{round(confidence.item() * 100, 2)}%"

            class_index = int(torch.argmax(logits, dim=1)[0].item())

        # Any index beyond the table maps to the last label, mirroring a
        # trailing ``else`` branch.
        labels = self.LABELS[min(class_index, len(self.LABELS) - 1)]

        load_dotenv()

        llm = ChatOpenAI()
        parser = StrOutputParser()

        initial_chain = classifier_prompt | llm | parser
        initial_response = initial_chain.invoke(
            {"predicted_disease": labels, "predicted_probability": score}
        )

        print(initial_response)

        self.memory.append("AI Response:\n" + " " + initial_response)

        # The chain does not change between turns, so build it once.
        qa_chain = QA_prompt | llm | parser

        while True:
            question = input("Human: ")
            if question.strip().lower() == "exit":
                break

            prompt = "\n".join(self.memory) + "\nHuman: " + question
            response = qa_chain.invoke({"question": prompt})

            # Keep both sides of the exchange so later prompts show what was
            # asked as well as what was answered.
            self.memory.append("Human: " + question)
            self.memory.append("AI Response:\n" + " " + response)

            print("AI:\n", response)


if __name__ == "__main__":
    # Imported locally so this entry point does not rely on an ``argparse``
    # import made elsewhere in the notebook.
    import argparse

    # Command-line entry point: classify one image and start the chat loop.
    parser = argparse.ArgumentParser(description="Medical Assistant chatbot".title())
    parser.add_argument(
        "--image",
        type=str,
        # Without a path there is nothing to classify, so fail at parse time
        # instead of deep inside image loading.
        required=True,
        help="Path to the image file".capitalize(),
    )
    parser.add_argument(
        "--device",
        type=str,
        # An omitted device used to arrive as None, which ``device_init``
        # resolves to the CPU; the default now states that explicitly.
        default="cpu",
        help="Device to run the model on".capitalize(),
    )

    args = parser.parse_args()

    assistant = MedicalAssistant(device=args.device, image=args.image)

    assistant.load_model()
    assistant.chat()
