### Import Lib

In [None]:
import os
import sys
import cv2
import math
import yaml
import torch
import joblib
import zipfile
import warnings
import argparse
import pandas as pd
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchview import draw_graph
from torchvision import transforms
from torch.utils.data import DataLoader

### Install the required dependencies

In [None]:
!pip install -r requirements.txt

### Utility functions

In [None]:
def config_files():
    """Load the notebook configuration.

    Returns:
        The object produced by ``yaml.safe_load`` for the file
        ``../notebook_config.yml`` (path is relative to the working directory).
    """
    config_path = "../notebook_config.yml"
    with open(config_path, "r") as stream:
        settings = yaml.safe_load(stream)
    return settings


def dump_files(value=None, filename=None):
    """Serialize ``value`` to ``filename`` with joblib.

    Args:
        value: Object to persist; must not be ``None``.
        filename: Destination path; must not be ``None``.

    Raises:
        ValueError: If either ``value`` or ``filename`` is ``None``.
    """
    # Guard clause: both arguments are mandatory despite their defaults.
    if value is None or filename is None:
        raise ValueError(
            "Determine the filename and value of a config file".capitalize()
        )

    joblib.dump(value=value, filename=filename)


def load_file(filename: str = None):
    """Load an object previously stored with joblib.

    Args:
        filename: Path of the file to read; must not be ``None``.

    Returns:
        Whatever ``joblib.load`` returns for ``filename``.

    Raises:
        ValueError: If ``filename`` is ``None``.
    """
    if filename is None:
        raise ValueError(
            "Please provide a filename to load config data from.".capitalize()
        )

    return joblib.load(filename)


def device_init(device: str = "cuda"):
    """Resolve a device name to a ``torch.device``.

    ``"cuda"`` and ``"mps"`` fall back to ``"cpu"`` when the corresponding
    backend reports that it is unavailable; any other value is passed to
    ``torch.device`` unchanged.

    Args:
        device: Requested device name.

    Returns:
        The resolved ``torch.device``.
    """
    # Lambdas keep the backend lookups lazy: a probe only runs for its own name.
    availability_probes = {
        "cuda": lambda: torch.cuda.is_available(),
        "mps": lambda: torch.backends.mps.is_available(),
    }

    for name, is_available in availability_probes.items():
        if device == name:
            resolved = name if is_available() else "cpu"
            return torch.device(resolved)

    return torch.device(device)


def weight_init(m):
    """Initialise convolution and batch-norm layers in place.

    Intended to be applied to every sub-module (e.g. via ``model.apply``).
    Matching is done on the class name:

    * names containing ``"Conv"``: Kaiming-normal weights (ReLU gain) and
      zero bias;
    * names containing ``"BatchNorm"``: weight set to 1 and bias set to 0.

    Modules that match by name but do not own a ``weight``/``bias`` tensor
    (e.g. a layer built with ``bias=False`` or ``affine=False``) are skipped
    for that tensor instead of raising.

    Args:
        m: The module to initialise.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if classname.find("Conv") != -1:
        if weight is not None:
            nn.init.kaiming_normal_(weight.data, nonlinearity="relu")
        if bias is not None:
            nn.init.constant_(bias.data, 0.0)
    # This branch must be a sibling of the Conv branch; nested inside it, it
    # could never match a BatchNorm layer.
    elif classname.find("BatchNorm") != -1:
        if weight is not None:
            nn.init.constant_(weight.data, 1.0)
        if bias is not None:
            nn.init.constant_(bias.data, 0.0)


def path_names():
    """Collect the filesystem paths declared in the notebook configuration.

    The configuration is loaded once per call and the relevant entries of its
    ``artifacts`` and ``dataloader`` sections are returned under stable keys.

    Returns:
        dict: Mapping with the keys ``raw_data_path``, ``processed_data_path``,
        ``train_models_path``, ``best_model_path``, ``files_path``,
        ``metrics_path``, ``train_images``, ``test_image`` and ``image_path``.

    Raises:
        KeyError: If an expected section or entry is missing from the config.
    """
    # Parse the YAML file a single time instead of once per looked-up key.
    config = config_files()
    artifacts = config["artifacts"]

    return {
        "raw_data_path": artifacts["raw_data_path"],
        "processed_data_path": artifacts["processed_data_path"],
        "train_models_path": artifacts["train_models"],
        "best_model_path": artifacts["best_model"],
        "files_path": artifacts["files_path"],
        "metrics_path": artifacts["metrics_path"],
        "train_images": artifacts["train_images"],
        "test_image": artifacts["test_image"],
        "image_path": config["dataloader"]["image_path"],
    }


def plot_images(
    predicted_images: torch.Tensor = None,
    predicted: bool = False,
):
    """Plot the first batch of the saved test dataloader and save the figure.

    Loads ``test_dataloader.pkl`` from the processed-data directory, takes its
    first batch and draws up to 16 image/mask pairs (plus the predicted mask
    when ``predicted`` is true). The figure is written to ``images.png`` in
    the configured files directory, shown, and closed.

    Args:
        predicted_images: Predictions indexed in the same order as the batch;
            only read when ``predicted`` is true.
        predicted: Whether to add a "Pred Mask" panel for every sample.
    """
    processed_data_path = path_names()["processed_data_path"]
    # NOTE: the name is rebound here; from now on it holds the loaded dataloader.
    processed_data_path = load_file(
        filename=os.path.join(processed_data_path, "test_dataloader.pkl")
    )
    images, masks = next(iter(processed_data_path))

    # Lay out at most 16 samples on a roughly square grid.
    max_number = min(16, images.size(0))
    num_of_rows = math.ceil(math.sqrt(max_number))
    num_of_columns = math.ceil(max_number / num_of_rows)

    plt.figure(figsize=(10, 20))

    for index, (image, mask) in enumerate(zip(images[:max_number], masks[:max_number])):
        # Channel-first tensor -> channel-last array for imshow.
        image = image.squeeze().permute(1, 2, 0).detach().cpu().numpy()
        mask = mask.squeeze().detach().cpu().numpy()

        # Min-max rescale to [0, 1] for display.
        # NOTE(review): a constant image/mask makes the denominator zero — confirm
        # whether that can occur in the dataset.
        image = (image - image.min()) / (image.max() - image.min())
        mask = (mask - mask.min()) / (mask.max() - mask.min())

        if predicted:
            # NOTE(review): requires predicted_images to be provided, and
            # permute(1, 2, 0) needs three dimensions to remain after squeeze();
            # a single-channel prediction would presumably fail here — confirm.
            pred_mask = predicted_images[index].squeeze().permute(1, 2, 0)
            pred_mask = pred_mask.detach().cpu().numpy()
            pred_mask = (pred_mask - pred_mask.min()) / (
                pred_mask.max() - pred_mask.min()
            )

            # Three consecutive panels per sample: image, mask, prediction.
            plt.subplot(3 * num_of_rows, num_of_columns, 3 * index + 1)
            plt.imshow(image)
            plt.axis("off")
            plt.title("Image")

            plt.subplot(3 * num_of_rows, num_of_columns, 3 * index + 2)
            plt.imshow(mask, cmap="gray")
            plt.axis("off")
            plt.title("Mask")

            plt.subplot(3 * num_of_rows, num_of_columns, 3 * index + 3)
            plt.imshow(pred_mask, cmap="gray")
            plt.axis("off")
            plt.title("Pred Mask")

        else:
            # Two consecutive panels per sample: image, mask.
            plt.subplot(2 * num_of_rows, num_of_columns, 2 * index + 1)
            plt.imshow(image)
            plt.axis("off")
            plt.title("Image")

            plt.subplot(2 * num_of_rows, num_of_columns, 2 * index + 2)
            plt.imshow(mask, cmap="gray")
            plt.axis("off")
            plt.title("Mask")

    plt.subplots_adjust(wspace=0.3, hspace=0.3)
    plt.tight_layout()

    # Save before show(); the figure is closed afterwards.
    save_path = os.path.join(path_names()["files_path"], "images.png")
    plt.savefig(save_path)
    plt.show()
    plt.close()
    print("Image files saved in", save_path)

def total_params(model=None):
    """Count the trainable parameters of a model.

    Args:
        model: Object exposing ``parameters()``; must not be ``None``.

    Returns:
        int: Sum of ``numel()`` over parameters with ``requires_grad`` set.

    Raises:
        ValueError: If ``model`` is ``None``.
    """
    if model is None:
        raise ValueError(
            "Please provide a model to calculate the total parameters.".capitalize()
        )

    trainable = 0
    for parameter in model.parameters():
        if parameter.requires_grad:
            trainable += parameter.numel()
    return trainable


def plot_model_architecture(
    model=None,
    input_data: torch.Tensor = None,
    model_name: str = "./artifacts/files_name/",
    format: str = "pdf",
):
    """Render the model's computation graph to a file.

    The graph is built with ``draw_graph`` and rendered into the configured
    files directory under ``model_name`` with the given ``format``.

    Args:
        model: The model to visualise; required.
        input_data: Example input tensor fed to ``draw_graph``; required.
        model_name: Output name, joined onto the files directory.
        format: Output format passed to the renderer (e.g. ``"pdf"``).

    Raises:
        ValueError: If ``model`` is missing or ``input_data`` is not a
            ``torch.Tensor``.
    """
    # Both inputs are mandatory. The check must use ``torch.Tensor`` (a type);
    # passing the ``torch`` module to isinstance raises TypeError.
    if model is None or not isinstance(input_data, torch.Tensor):
        raise ValueError(
            "Please provide a model and input data to plot the model architecture.".capitalize()
        )

    filename = path_names()["files_path"]
    draw_graph(model=model, input_data=input_data).visual_graph.render(
        filename=os.path.join(filename, model_name), format=format
    )
    print(f"Model architecture saved in {filename}/{model_name}.{format}")


### DataLoader

In [None]:
class Loader:
    """Build and persist train/test DataLoaders for an image/mask dataset.

    Workflow: ``unzip_folder`` extracts the dataset archive into the processed
    data directory, ``create_dataloader`` reads ``dataset/{train,test}/
    {image,mask}`` from there, transforms every image/mask pair to tensors and
    stores the resulting DataLoaders as ``train_dataloader.pkl`` and
    ``test_dataloader.pkl``.

    Args:
        image_path: Path to the dataset zip archive. When ``None`` the value
            configured under ``dataloader.image_path`` is used.
        image_channels: Stored on the instance; not otherwise used by this class.
        image_size: Side length images and masks are resized/cropped to.
        batch_size: Batch size of both DataLoaders.
        split_size: Stored on the instance; not otherwise used by this class.
        shuffle: Whether both DataLoaders shuffle their samples.
    """

    def __init__(
        self,
        image_path=None,
        image_channels: int = 3,
        image_size: int = 224,
        batch_size: int = 4,
        split_size: float = 0.25,
        shuffle: bool = False,
    ):
        self.image_channels = image_channels
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size
        self.shuffle = shuffle

        if image_path is None:
            self.image_path = path_names()["image_path"]
        else:
            warnings.warn(
                "Ensure the provided image path is correct; an incorrect path may result in an error.".title()
            )
            self.image_path = image_path

        self.train_images = list()
        self.train_masks = list()
        self.test_images = list()
        self.test_masks = list()

    def transform_images(self, type: str = "image"):
        """Return the preprocessing pipeline for images or masks.

        Both pipelines resize, convert to a tensor, centre-crop and normalise
        with mean 0.5 / std 0.5; they differ only in the number of channels
        the normalisation is defined for (three for images, one for masks).

        Args:
            type: ``"image"`` or ``"mask"``.

        Raises:
            ValueError: For any other ``type`` (previously ``None`` was
                returned silently, which only failed later when called).
        """
        if type == "image":
            normalize = transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        elif type == "mask":
            normalize = transforms.Normalize((0.5,), (0.5,))
        else:
            raise ValueError(f"Unsupported transform type: {type!r}")

        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.CenterCrop((self.image_size, self.image_size)),
                normalize,
            ]
        )

    def unzip_folder(self):
        """Extract the dataset archive into the processed data directory.

        Raises:
            FileNotFoundError: If ``self.image_path`` does not exist.
        """
        if not os.path.exists(self.image_path):
            raise FileNotFoundError("Image file not found".capitalize())

        destination = path_names()["processed_data_path"]
        with zipfile.ZipFile(file=self.image_path, mode="r") as zip_file:
            zip_file.extractall(path=destination)
        print(f"Dataset unzipped successfully in {destination}")

    def features_extracted(self):
        """Load and transform every image that has a same-named mask.

        Per-file problems are reported with a printed message and the file is
        skipped; a failure to list a split's directories is reported as
        ``[CRITICAL]`` and that split is skipped.

        Returns:
            dict: ``train_images``, ``train_masks``, ``test_images`` and
            ``test_masks`` lists of transformed tensors.
        """
        processed_path = path_names()["processed_data_path"]
        train_path = os.path.join(processed_path, "dataset", "train")
        valid_path = os.path.join(processed_path, "dataset", "test")

        # Fresh lists on every call so a repeated run does not duplicate samples.
        self.train_images, self.train_masks = list(), list()
        self.test_images, self.test_masks = list(), list()

        # The pipelines are stateless; build them once instead of per file.
        image_transform = self.transform_images(type="image")
        mask_transform = self.transform_images(type="mask")

        splits = [
            ("train", train_path, self.train_images, self.train_masks),
            ("test", valid_path, self.test_images, self.test_masks),
        ]

        for split, split_path, images_out, masks_out in tqdm(
            splits,
            desc="Extracted features ..",
        ):
            images_dir = os.path.join(split_path, "image")
            masks_dir = os.path.join(split_path, "mask")

            try:
                image_names = os.listdir(images_dir)
                # One directory listing per split; set membership is O(1).
                mask_names = set(os.listdir(masks_dir))
            except Exception as e:
                print(f"[CRITICAL] Failed to process dataset for {split} images: {e}")
                continue

            for image in image_names:
                if image not in mask_names:
                    continue

                try:
                    image_path = os.path.join(images_dir, image)
                    mask_image_path = os.path.join(masks_dir, image)

                    # cv2.imread returns None for unreadable files, so validate
                    # both reads before handing the arrays to PIL.
                    _image = cv2.imread(image_path)
                    if _image is None:
                        raise FileNotFoundError(
                            f"Image file not found: {image_path}"
                        )

                    _mask = cv2.imread(mask_image_path, cv2.IMREAD_GRAYSCALE)
                    if _mask is None:
                        raise FileNotFoundError(
                            f"Mask file not found: {mask_image_path}"
                        )

                    _image = cv2.cvtColor(_image, cv2.COLOR_BGR2RGB)
                    _image = image_transform(Image.fromarray(_image))
                    _mask = mask_transform(Image.fromarray(_mask))

                    # Append only once both succeeded to keep the lists aligned.
                    images_out.append(_image)
                    masks_out.append(_mask)

                except FileNotFoundError as e:
                    print(f"[WARNING] {e}")
                except cv2.error as e:
                    print(f"[ERROR] OpenCV error while processing {image}: {e}")
                except Exception as e:
                    print(f"[ERROR] Unexpected error while processing {image}: {e}")

        assert len(self.train_images) == len(
            self.train_masks
        ), "Images, Masks should be equal in size".capitalize()

        assert len(self.test_images) == len(
            self.test_masks
        ), "Images, Masks should be equal in size".capitalize()

        return {
            "train_images": self.train_images,
            "train_masks": self.train_masks,
            "test_images": self.test_images,
            "test_masks": self.test_masks,
        }

    def create_dataloader(self):
        """Create the train/test DataLoaders and save them with ``dump_files``.

        Errors are reported with a printed message rather than propagated.
        """
        try:
            dataset = self.features_extracted()

            train_dataloader = DataLoader(
                dataset=list(zip(dataset["train_images"], dataset["train_masks"])),
                batch_size=self.batch_size,
                shuffle=self.shuffle,
            )

            valid_dataloader = DataLoader(
                dataset=list(zip(dataset["test_images"], dataset["test_masks"])),
                batch_size=self.batch_size,
                shuffle=self.shuffle,
            )

            processed_path = path_names()["processed_data_path"]
            for filename, value in [
                ("train_dataloader.pkl", train_dataloader),
                ("test_dataloader.pkl", valid_dataloader),
            ]:
                dump_files(
                    value=value,
                    filename=os.path.join(processed_path, filename),
                )
                print(f"Dataloader saved as {filename}".capitalize())
        except FileNotFoundError as e:
            print(f"[WARNING] {e}")
        except cv2.error as e:
            print(f"[ERROR] OpenCV error while processing: {e}")
        except Exception as e:
            print(f"[ERROR] Unexpected error while processing: {e}")

    @staticmethod
    def display_images():
        """Plot a batch of test images and masks; errors are only printed."""
        try:
            plot_images(predicted=False)
        except Exception as e:
            print(f"[ERROR] Unexpected error while displaying images: {e}")

    @staticmethod
    def _count_samples(dataloader):
        """Return ``(image_count, mask_count)`` from one pass over a loader."""
        image_count = mask_count = 0
        for images, masks in dataloader:
            image_count += images.size(0)
            mask_count += masks.size(0)
        return image_count, mask_count

    @staticmethod
    def dataset_details():
        """Write sample counts and batch shapes to ``dataset_details.csv``.

        Reads the two saved DataLoaders from the processed data directory and
        stores the summary in the configured files directory.
        """
        processed_data_path = path_names()["processed_data_path"]
        files_path = path_names()["files_path"]

        train_dataloader = load_file(
            filename=os.path.join(processed_data_path, "train_dataloader.pkl")
        )
        valid_dataloader = load_file(
            filename=os.path.join(processed_data_path, "test_dataloader.pkl")
        )

        # The first batch of each loader supplies the reported tensor shapes.
        train_images, _ = next(iter(train_dataloader))
        _, valid_masks = next(iter(valid_dataloader))

        train_image_count, train_mask_count = Loader._count_samples(train_dataloader)
        valid_image_count, valid_mask_count = Loader._count_samples(valid_dataloader)

        pd.DataFrame(
            {
                "Train Images": train_image_count,
                "Valid Images": valid_image_count,
                "Train Masks": train_mask_count,
                "Valid Masks": valid_mask_count,
                "Image Size": str(train_images.size()),
                "Mask Size": str(valid_masks.size()),
            },
            index=["Dataset Details"],
        ).to_csv(os.path.join(files_path, "dataset_details.csv"))
        print(f"Dataset details saved to {files_path}".capitalize())


if __name__ == "__main__":
    # End-to-end data preparation run: extract the archive, build and persist
    # the dataloaders, then plot a sample batch and write the dataset summary.
    loader = Loader(
        image_path="../data/raw/dataset.zip",
        image_channels=3,
        image_size=128,
        batch_size=16,
        split_size=0.25,
        shuffle=True,
    )

    loader.unzip_folder()
    loader.create_dataloader()

    # Both helpers are static and read the dataloaders saved above from disk.
    Loader.display_images()
    Loader.dataset_details()

### Loss Functions
1. BCELoss
2. DiceLoss
3. JaccardLoss
4. FocalLoss
5. TverskyLoss

In [None]:
class BCE(nn.Module):
    """Binary cross-entropy loss selectable by name.

    Args:
        name: ``"BCEWithLogitsLoss"`` (expects raw logits) or ``"BCELoss"``
            (expects probabilities).
        reduction: Reduction passed to the underlying PyTorch loss.

    Raises:
        ValueError: If ``name`` is not one of the two supported losses. Failing
            here avoids a later ``AttributeError`` on a missing ``loss_func``.
    """

    def __init__(self, name: str = "BCEWithLogitsLoss", reduction: str = "mean"):
        super(BCE, self).__init__()
        self.name = name
        self.reduction = reduction

        if self.name == "BCEWithLogitsLoss":
            self.loss_func = nn.BCEWithLogitsLoss(reduction=self.reduction)
        elif self.name == "BCELoss":
            self.loss_func = nn.BCELoss(reduction=self.reduction)
        else:
            raise ValueError(
                f"Unsupported loss name: {self.name!r}; "
                "expected 'BCEWithLogitsLoss' or 'BCELoss'"
            )

    def forward(self, predicted: torch.Tensor, actual: torch.Tensor):
        """Return the configured loss between ``predicted`` and ``actual``.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        if isinstance(predicted, torch.Tensor) and isinstance(actual, torch.Tensor):
            return self.loss_func(predicted, actual)
        else:
            raise ValueError("Input must be torch.Tensor".capitalize())


if __name__ == "__main__":
    # Sanity check: the loss of five logits against binary targets is a tensor.
    predicted = torch.tensor([3.0, 2.0, -1.0, -3.0, 5.0])
    actual = torch.tensor([1.0, 1.0, 0.0, 0.0, 1.0])

    loss = BCE(name="BCEWithLogitsLoss", reduction="mean")
    assert isinstance(loss(predicted, actual), torch.Tensor)

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss computed per sample from raw logits.

    Args:
        reduction: ``"mean"`` or ``"sum"`` over the batch; any other value
            returns the unreduced per-sample losses.
        smooth: Constant added to numerator and denominator.
    """

    def __init__(self, reduction: str = "mean", smooth: float = 1e-5):
        super(DiceLoss, self).__init__()
        self.smooth = smooth
        self.reduction = reduction

    def forward(self, predicted: torch.Tensor, actual: torch.Tensor):
        """Return ``1 - dice`` for each sample, reduced as configured.

        Args:
            predicted: Logits; a sigmoid is applied internally. The first
                dimension is treated as the batch.
            actual: Targets with the same number of elements per sample.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        if not isinstance(predicted, torch.Tensor) or not isinstance(
            actual, torch.Tensor
        ):
            raise ValueError("Inputs must be torch.Tensor")

        predicted = torch.sigmoid(predicted)

        # reshape (unlike view) also accepts non-contiguous inputs, e.g.
        # tensors produced by permute/transpose.
        predicted = predicted.reshape(predicted.shape[0], -1)
        actual = actual.reshape(actual.shape[0], -1)

        intersection = (predicted * actual).sum(dim=1)
        union = predicted.sum(dim=1) + actual.sum(dim=1)

        dice_score = (2.0 * intersection + self.smooth) / (union + self.smooth)

        loss = 1 - dice_score

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss


if __name__ == "__main__":
    # Smoke test: random logits against a random binary mask of the same shape.
    predicted = torch.randn((64, 1, 128, 128))
    actual = torch.randint(0, 2, (64, 1, 128, 128)).float()

    loss_func = DiceLoss(reduction="mean", smooth=1e-5)
    print(loss_func(predicted, actual))

In [None]:
class JaccardLoss(nn.Module):
    """Soft Jaccard (IoU) loss computed per sample from raw logits.

    Args:
        reduction: ``"mean"`` or ``"sum"`` over the batch; any other value
            returns the unreduced per-sample losses.
        smooth: Constant added to numerator and denominator.
    """

    def __init__(self, reduction: str = "mean", smooth: float = 1e-5):
        super(JaccardLoss, self).__init__()
        self.smooth = smooth
        self.reduction = reduction

    def forward(self, predicted: torch.Tensor, actual: torch.Tensor):
        """Return ``1 - jaccard`` for each sample, reduced as configured.

        Args:
            predicted: Logits; a sigmoid is applied internally. The first
                dimension is treated as the batch.
            actual: Targets with the same number of elements per sample.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        if not isinstance(predicted, torch.Tensor) or not isinstance(
            actual, torch.Tensor
        ):
            raise ValueError("Inputs must be torch.Tensor")

        predicted = torch.sigmoid(predicted)

        # reshape (unlike view) also accepts non-contiguous inputs.
        predicted = predicted.reshape(predicted.shape[0], -1)
        actual = actual.reshape(actual.shape[0], -1)

        overlap = predicted * actual
        intersection = overlap.sum(dim=1)
        # Soft union: |A| + |B| - |A ∩ B|, evaluated element-wise.
        union = (predicted + actual - overlap).sum(dim=1)

        jaccard_score = (intersection + self.smooth) / (union + self.smooth)

        loss = 1 - jaccard_score

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        else:
            return loss


if __name__ == "__main__":
    # Smoke test: random logits against a random binary mask of the same shape.
    predicted = torch.randn((64, 1, 128, 128))
    actual = torch.randint(0, 2, (64, 1, 128, 128)).float()

    loss_func = JaccardLoss(reduction="mean", smooth=1e-5)
    print(loss_func(predicted, actual))

In [None]:
class FocalLoss(nn.Module):
    """Focal loss on raw logits, built on element-wise BCE-with-logits.

    Args:
        alpha: Constant factor applied to every element's focal weight.
        gamma: Exponent of the ``(1 - pt)`` modulating term.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            element-wise loss.
    """

    def __init__(self, alpha=0.75, gamma=2.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

        self.bce_loss = nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, predicted: torch.Tensor, actual: torch.Tensor):
        """Return the focal loss between logits and targets.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        both_tensors = isinstance(predicted, torch.Tensor) and isinstance(
            actual, torch.Tensor
        )
        if not both_tensors:
            raise ValueError("Inputs must be torch.Tensor")

        element_loss = self.bce_loss(predicted, actual)

        # pt = exp(-BCE); it is computed on a detached copy so the modulating
        # weight acts as a constant during back-propagation.
        probability = torch.exp(-element_loss.detach())

        modulation = self.alpha * (1 - probability) ** self.gamma
        weighted_loss = modulation * element_loss

        if self.reduction == "sum":
            return weighted_loss.sum()
        if self.reduction == "mean":
            return weighted_loss.mean()
        return weighted_loss


if __name__ == "__main__":

    # Smoke test: random logits against a random binary mask of the same shape.
    predicted = torch.randn((64, 1, 128, 128))
    actual = torch.randint(0, 2, (64, 1, 128, 128)).float()

    loss_func = FocalLoss(alpha=0.75, gamma=2.0, reduction="mean")

    print(loss_func(predicted, actual))

In [None]:
class TverskyLoss(nn.Module):
    """Tversky loss computed per sample from raw logits.

    Args:
        alpha: Weight of the false-positive term.
        beta: Weight of the false-negative term.
        reduction: ``"mean"`` or ``"sum"`` over the batch; any other value
            returns the unreduced per-sample losses.
    """

    def __init__(self, alpha: float = 0.7, beta: float = 0.3, reduction: str = "mean"):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha
        self.beta = beta
        self.reduction = reduction

    def forward(self, predicted: torch.Tensor, actual: torch.Tensor):
        """Return ``1 - tversky_index`` for each sample, reduced as configured.

        Args:
            predicted: Logits; a sigmoid is applied internally. The first
                dimension is treated as the batch.
            actual: Targets with the same number of elements per sample.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        if not isinstance(predicted, torch.Tensor) or not isinstance(
            actual, torch.Tensor
        ):
            raise ValueError("Inputs must be torch.Tensor")

        predicted = torch.sigmoid(predicted)

        # reshape (unlike view) also accepts non-contiguous inputs.
        predicted = predicted.reshape(predicted.size(0), -1)
        actual = actual.reshape(actual.size(0), -1)

        # Soft confusion-matrix counts per sample.
        TP = (predicted * actual).sum(dim=1)
        FP = (predicted * (1 - actual)).sum(dim=1)
        FN = ((1 - predicted) * actual).sum(dim=1)

        # Keeps the denominator non-zero when all three counts vanish.
        epsilon = 1e-8
        tversky_index = TP / (TP + self.alpha * FP + self.beta * FN + epsilon)

        tversky_loss = 1 - tversky_index

        if self.reduction == "mean":
            return tversky_loss.mean()
        elif self.reduction == "sum":
            return tversky_loss.sum()
        else:
            return tversky_loss


if __name__ == "__main__":
    # Smoke test: random logits against a random binary mask of the same shape.
    predicted = torch.randn((64, 1, 128, 128))
    actual = torch.randint(0, 2, (64, 1, 128, 128)).float()

    loss_func = TverskyLoss(alpha=0.7, beta=0.3, reduction="mean")
    # forward() takes (predicted, actual): the logits go first because the
    # sigmoid is applied to the first argument.
    print(loss_func(predicted, actual))

### Scaled Dot Product -> Self Attention Compute

In [None]:
def scaled_dot_product(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
    """Compute scaled dot-product attention.

    ``softmax(Q @ K^T / sqrt(d_k)) @ V`` where ``d_k`` is the size of the last
    dimension of ``query``/``key`` and the softmax runs over the key positions.

    Args:
        query: Tensor whose last two dimensions are (positions, d_k).
        key: Tensor whose last two dimensions are (positions, d_k).
        value: Tensor whose second-to-last dimension matches the key positions.

    Returns:
        torch.Tensor: The attention-weighted combination of ``value``.

    Raises:
        ValueError: If any argument is not a ``torch.Tensor``.
    """
    if not (
        isinstance(query, torch.Tensor)
        and isinstance(key, torch.Tensor)
        and isinstance(value, torch.Tensor)
    ):
        raise ValueError("Inputs must be torch.Tensor")

    # Read d_k before transposing: after the transpose the last dimension of
    # ``key`` is the number of positions, not the feature size.
    dimension = query.size(-1)

    logits = torch.matmul(query, key.transpose(-1, -2))
    logits = logits / (float(dimension) ** 0.5)
    weights = torch.softmax(logits, dim=-1)

    attention = torch.matmul(weights, value)

    return attention


if __name__ == "__main__":
    # Smoke test with random 4-D query/key/value tensors of identical shape;
    # prints the shape of the attention output.
    query = torch.randn(1, 4, 256, 128)
    key = torch.randn(1, 4, 256, 128)
    value = torch.randn(1, 4, 256, 128)

    attention = scaled_dot_product(query, key, value)
    print(attention.size())

### MultiHead Attention

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a single fused Q/K/V projection.

    Args:
        nheads: Number of attention heads.
        dimension: Feature size of the input; must be divisible by ``nheads``.
    """

    def __init__(self, nheads: int = 4, dimension: int = 256):
        super(MultiHeadAttention, self).__init__()
        self.nheads = nheads
        self.dimension = dimension

        assert dimension % nheads == 0, "Dimension must be divisible by nheads"

        # One linear layer produces query, key and value side by side.
        self.QKV = nn.Linear(
            in_features=self.dimension, out_features=3 * self.dimension
        )

    def _split_heads(self, tensor: torch.Tensor):
        """Reshape (batch, tokens, dimension) to (batch, heads, tokens, head_dim)."""
        head_dim = self.dimension // self.nheads
        tensor = tensor.view(tensor.size(0), tensor.size(1), self.nheads, head_dim)
        return tensor.permute(0, 2, 1, 3)

    def forward(self, x: torch.Tensor):
        """Apply self-attention to ``x`` of shape (batch, tokens, dimension).

        Returns:
            torch.Tensor: Tensor of the same shape as ``x``.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())

        query, key, value = torch.chunk(input=self.QKV(x), chunks=3, dim=-1)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        attention = scaled_dot_product(query=query, key=key, value=value)

        # (batch, heads, tokens, head_dim) -> (batch, tokens, heads, head_dim)
        # before merging the heads. Flattening without this permute would mix
        # values from different tokens into one output row.
        attention = attention.permute(0, 2, 1, 3)
        attention = attention.reshape(
            attention.size(0),
            attention.size(1),
            attention.size(2) * attention.size(3),
        )
        return attention


if __name__ == "__main__":
    # Smoke test: one sequence of 256 tokens with 512 features through 4 heads;
    # prints the output shape. Note that ``attention`` is rebound to the module.
    attention = MultiHeadAttention(nheads=4, dimension=512)
    images = torch.randn(1, 256, 512)
    print(attention(x=images).size())

### Feed Forward Neural Network: MLP

In [None]:
class FeedForwardNeuralNetwork(nn.Module):
    """Two-layer position-wise MLP: Linear -> activation -> Dropout -> Linear.

    The first layer expands ``dimension`` to ``dim_feedforward`` and the
    second projects back to ``dimension``.

    Args:
        dimension: Input and output feature size.
        dim_feedforward: Hidden feature size.
        activation: ``"relu"``, ``"leaky"``, ``"gelu"`` or ``"selu"``.
        dropout: Dropout probability applied after the activation.
        bias: Whether the linear layers use a bias term.

    Raises:
        ValueError: If ``activation`` is not supported.
    """

    def __init__(
        self,
        dimension: int = 512,
        dim_feedforward: int = 1024,
        activation: str = "relu",
        dropout: float = 0.1,
        bias: bool = False,
    ):
        super(FeedForwardNeuralNetwork, self).__init__()
        self.dimension = dimension
        self.dim_feedforward = dim_feedforward
        self.activation = activation
        self.dropout = dropout
        self.bias = bias

        if self.activation == "relu":
            self.activation_func = nn.ReLU(inplace=True)
        elif self.activation == "leaky":
            self.activation_func = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        elif self.activation == "gelu":
            # nn.GELU has no ``inplace`` argument; passing one raises TypeError.
            self.activation_func = nn.GELU()
        elif self.activation == "selu":
            self.activation_func = nn.SELU(inplace=True)
        else:
            raise ValueError("Unsupported activation function".capitalize())

        # Built explicitly so the configured sizes stored on the instance are
        # not overwritten while the layers are created.
        self.layers = [
            nn.Linear(
                in_features=self.dimension,
                out_features=self.dim_feedforward,
                bias=self.bias,
            ),
            self.activation_func,
            nn.Dropout(p=self.dropout),
            nn.Linear(
                in_features=self.dim_feedforward,
                out_features=self.dimension,
                bias=self.bias,
            ),
        ]

        self.network = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Apply the MLP to the last dimension of ``x``.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())

        x = self.network(x)

        return x


if __name__ == "__main__":
    # Smoke test: a (1, 256, 512) tensor through the MLP; prints the output shape.
    network = FeedForwardNeuralNetwork(dimension=512, dim_feedforward=1024)
    images = torch.randn((1, 256, 512))
    print(network(x=images).size())

### Layer Normalization

In [None]:
class LayerNormalization(nn.Module):
    """Layer normalisation over the last dimension with learnable scale/shift.

    Args:
        dimension: Size of the normalised (last) dimension.
        layer_norm_eps: Constant added to the variance for numerical stability.
    """

    def __init__(self, dimension: int = 512, layer_norm_eps: float = 1e-05):
        super(LayerNormalization, self).__init__()
        self.dimension = dimension
        self.layer_norm_eps = layer_norm_eps

        # Scale (alpha) starts at one and shift (beta) at zero, so the module
        # initially returns the plain normalised input.
        self.alpha = nn.Parameter(
            data=torch.ones((1, 1, self.dimension)), requires_grad=True
        )
        self.beta = nn.Parameter(
            data=torch.zeros((1, 1, self.dimension)), requires_grad=True
        )

    def forward(self, x: torch.Tensor):
        """Normalise ``x`` along its last dimension.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())

        mean = torch.mean(input=x, dim=-1, keepdim=True)
        # Layer normalisation uses the biased (population) variance; torch.var
        # applies Bessel's correction unless told otherwise.
        variance = torch.var(input=x, dim=-1, unbiased=False, keepdim=True)

        normalized_x = (x - mean) / torch.sqrt(variance + self.layer_norm_eps)

        return self.alpha * normalized_x + self.beta


if __name__ == "__main__":
    # Smoke test: normalise a (1, 256, 512) tensor; prints the output shape.
    layer_norm = LayerNormalization(dimension=512)
    images = torch.randn((1, 256, 512))
    print(layer_norm(x=images).size())

### Patch Embedding Layer

In [None]:
class PatchEmbedding(nn.Module):
    """Project a feature map into a token sequence with positional embeddings.

    The module operates on a feature map that already has ``dimension``
    channels. The number of tokens is derived from ``image_size`` divided by a
    fixed factor of 16 (``self.constant``) and by ``patch_size``.

    Args:
        image_size: Side length of the original image.
        patch_size: Kernel size and stride of the projection convolution.
        dimension: Channel count of the input map and size of each token.
        bias: Whether the projection convolution uses a bias term.
    """

    def __init__(
        self,
        image_size: int = 128,
        patch_size: int = 1,
        dimension: int = 1024,
        bias: bool = False,
    ):
        super(PatchEmbedding, self).__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.dimension = dimension
        self.bias = bias
        # Factor by which the incoming feature map is assumed to be smaller
        # than the original image.
        self.constant = 16

        assert (
            self.image_size % self.patch_size == 0
        ), "Image size must be divisible by the patch size.".title()

        warnings.warn(
            "The encoder block extracts features, which may reduce the effective image size. "
            "To mitigate this, we set the patch size to 1."
        )

        self.num_of_patches = (
            (self.image_size // self.constant) // self.patch_size
        ) ** 2
        self.in_channels, self.out_channels = self.dimension, self.dimension

        self.projection = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=self.bias,
        )

        # Attribute name (including its spelling) is kept because it is part
        # of the module's state-dict keys.
        self.postitonal_embedding = nn.Parameter(
            data=torch.randn((1, self.num_of_patches, self.dimension)),
            requires_grad=True,
        )

    def forward(self, x: torch.Tensor):
        """Return tokens of shape (batch, num_of_patches, dimension).

        Args:
            x: Feature map of shape (batch, dimension, height, width) whose
                projected height * width equals ``num_of_patches``.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())

        x = self.projection(x)
        # (B, C, H, W) -> (B, C, H*W) -> (B, H*W, C). The transpose is what
        # makes each row the channel vector of one spatial position; a plain
        # view to (B, H*W, C) would reinterpret memory and mix positions.
        x = x.flatten(2).transpose(1, 2)
        x = self.postitonal_embedding + x
        return x


if __name__ == "__main__":
    # Smoke test: an 8x8 feature map (128 // 16 per side) with 1024 channels
    # for a batch of 16; prints the shape of the resulting token sequence.
    patch_embedding = PatchEmbedding(image_size=128, patch_size=1, dimension=1024)
    images = torch.randn((16, 1024, 8, 8))
    print(patch_embedding(x=images).size())

### Encoder Block for TransUNet

In [None]:
class EncoderBlock(nn.Module):
    """Residual bottleneck block that halves the spatial resolution.

    Two parallel paths are summed:

    * shortcut (``conv1``): 1x1 convolution with stride 2, then batch norm;
    * main path (``encoder_block``): 1x1 conv -> BN -> 3x3 conv with stride 2
      and padding 1 -> BN -> 1x1 conv -> BN -> ReLU.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both paths.
    """

    def __init__(self, in_channels: int = 128, out_channels: int = 2 * 128):
        super(EncoderBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Settings of the 1x1 layers; the strided layers override them below.
        self.kernel_size = 1
        self.stride_size = 1
        self.padding_size = self.kernel_size // 2

        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride_size * 2,
                padding=self.padding_size,
                bias=False,
            ),
            nn.BatchNorm2d(num_features=self.out_channels),
        )

        # (kernel_size, stride, padding) of the three main-path convolutions.
        conv_specs = [(1, 1, 0), (3, 2, 1), (1, 1, 0)]

        self.layers = list()
        # Tracked locally so self.in_channels keeps the constructor argument.
        channels = self.in_channels

        for kernel_size, stride, padding in conv_specs:
            self.layers.append(
                nn.Conv2d(
                    in_channels=channels,
                    out_channels=self.out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=False,
                )
            )
            self.layers.append(nn.BatchNorm2d(num_features=self.out_channels))
            channels = self.out_channels

        self.layers.append(nn.ReLU(inplace=True))

        self.encoder_block = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Return the sum of the shortcut and main-path outputs.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())

        x1 = self.conv1(x)
        x2 = self.encoder_block(x)
        return x1 + x2

    @staticmethod
    def total_params(model):
        """Return the number of parameters of an ``EncoderBlock``.

        Raises:
            ValueError: If ``model`` is not an ``EncoderBlock``.
        """
        if isinstance(model, EncoderBlock):
            return sum(p.numel() for p in model.parameters())
        else:
            raise ValueError("Input must be an EncoderBlock".capitalize())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Encoder Block for the TransUNet".title()
    )
    parser.add_argument(
        "--in_channels",
        type=int,
        default=128,
        help="Input channels to encode".capitalize(),
    )
    parser.add_argument(
        "--out_channels",
        type=int,
        default=256,
        help="Output channels to decode".capitalize(),
    )
    # parse_known_args ignores arguments it does not recognise, such as the
    # ones a notebook kernel places in sys.argv; parse_args would exit on them.
    args, _ = parser.parse_known_args()

    encoder_block = EncoderBlock(
        in_channels=args.in_channels,
        out_channels=args.out_channels,
    )
    # Build the input and the expected shape from the parsed values so the
    # check stays valid when non-default channel counts are supplied.
    images = torch.randn((16, args.in_channels, 64, 64))
    assert (encoder_block(x=images).size()) == (
        16,
        args.out_channels,
        64 // 2,
        64 // 2,
    )

### Transformer Encoder Block

In [None]:
class TransformerEncoderBlock(nn.Module):
    """One transformer encoder layer: self-attention, then a feed-forward net.

    Each sub-layer output goes through dropout, is added to the sub-layer's
    input (residual connection) and is then normalised, i.e. normalisation is
    applied after each residual addition.

    Args:
        dimension: Feature width handed to the attention, normalisation and
            feed-forward sub-modules.
        nheads: Number of attention heads passed to ``MultiHeadAttention``.
        dim_feedforward: Hidden width passed to ``FeedForwardNeuralNetwork``.
        dropout: Dropout probability used after both sub-layers and passed to
            the feed-forward network.
        activation: Activation name passed to ``FeedForwardNeuralNetwork``.
        layer_norm_eps: Epsilon passed to both ``LayerNormalization`` modules.
        bias: Bias flag passed to ``FeedForwardNeuralNetwork``.
    """

    def __init__(
        self,
        dimension: int = 512,
        nheads: int = 8,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_norm_eps: float = 1e-05,
        bias: bool = False,
    ):
        super().__init__()

        # Keep the hyper-parameters on the instance for later inspection.
        self.dimension = dimension
        self.nheads = nheads
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation = activation
        self.layer_norm_eps = layer_norm_eps
        self.bias = bias

        # Sub-modules are created in a fixed order so that parameter
        # registration (and random initialisation) stays reproducible.
        self.multi_head_attention = MultiHeadAttention(
            nheads=nheads, dimension=dimension
        )
        self.layer_norm1 = LayerNormalization(
            dimension=dimension, layer_norm_eps=layer_norm_eps
        )
        self.layer_norm2 = LayerNormalization(
            dimension=dimension, layer_norm_eps=layer_norm_eps
        )
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)

        self.feed_forward_network = FeedForwardNeuralNetwork(
            dimension=dimension,
            dim_feedforward=dim_feedforward,
            activation=activation,
            dropout=dropout,
            bias=bias,
        )

    def forward(self, x: torch.Tensor):
        """Run the attention and feed-forward sub-layers on ``x``.

        Args:
            x: Input tensor; the demo cell below feeds ``(1, 256, 512)``.

        Returns:
            The output of the second normalisation.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch instance".capitalize())

        # Sub-layer 1: attention -> dropout -> residual add -> norm.
        attended = self.dropout1(self.multi_head_attention(x=x))
        x = self.layer_norm1(attended + x)

        # Sub-layer 2: feed-forward -> dropout -> residual add -> norm.
        transformed = self.dropout2(self.feed_forward_network(x=x))
        return self.layer_norm2(transformed + x)


if __name__ == "__main__":
    # Smoke test: push one random (1, 256, 512) batch through a single
    # encoder layer and print the resulting size.
    encoder_layer_kwargs = {
        "dimension": 512,
        "nheads": 4,
        "dim_feedforward": 1024,
        "dropout": 0.1,
        "activation": "relu",
        "layer_norm_eps": 1e-05,
        "bias": False,
    }
    transfomer = TransformerEncoderBlock(**encoder_layer_kwargs)

    images = torch.randn(1, 256, 512)
    print(transfomer(x=images).shape)

### ViT: Vision Transformer

In [None]:
class ViT(nn.Module):
    """Transformer bottleneck: embed patches, encode them, fold back to a map.

    The input is turned into a token sequence by ``PatchEmbedding``, passed
    through ``num_layers`` ``TransformerEncoderBlock`` layers and finally
    reshaped into a square ``(batch, dimension, side, side)`` feature map.

    Args:
        image_size: Forwarded to ``PatchEmbedding`` as ``image_size``.
        dimension: Embedding width shared by the patch embedding and every
            encoder layer.
        nheads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dim_feedforward: Hidden width of each layer's feed-forward network.
        dropout: Dropout probability used inside each encoder layer.
        activation: Activation name used inside each encoder layer.
        layer_norm_eps: Epsilon for the layer normalisations.
        bias: Bias flag forwarded to the patch embedding and encoder layers.
    """

    def __init__(
        self,
        image_size: int = 256,
        dimension: int = 512,
        nheads: int = 8,
        num_layers: int = 4,
        dim_feedforward: int = 1024,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_norm_eps: float = 1e-05,
        bias: bool = False,
    ):
        super(ViT, self).__init__()
        self.image_size = image_size
        self.dimension = dimension
        self.nheads = nheads
        self.num_layers = num_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation = activation
        self.layer_norm_eps = layer_norm_eps
        self.bias = bias

        self.patch_embedding = PatchEmbedding(
            image_size=self.image_size,
            # Every spatial position becomes its own token; this was
            # previously spelled ``image_size // image_size``, i.e. 1.
            patch_size=1,
            dimension=self.dimension,
            bias=self.bias,
        )

        self.transformer = nn.Sequential(
            *[
                TransformerEncoderBlock(
                    dimension=self.dimension,
                    nheads=self.nheads,
                    dim_feedforward=self.dim_feedforward,
                    dropout=self.dropout,
                    activation=self.activation,
                    layer_norm_eps=self.layer_norm_eps,
                    bias=self.bias,
                )
                # tqdm shows a progress bar while the layers are being built.
                for _ in tqdm(range(self.num_layers))
            ]
        )

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and return it as a square feature map.

        Args:
            x: Input accepted by ``PatchEmbedding``; the demo cell below
                feeds a ``(16, 512, 16, 16)`` tensor.

        Returns:
            Tensor of shape ``(batch, dimension, side, side)`` where
            ``side * side`` equals the number of tokens.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor`` or the token count
                is not a perfect square.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())

        # Assumes the embedding yields (batch, tokens, dimension); the index
        # choices below (dim 1 = tokens, last dim = channels) rely on that.
        x = self.patch_embedding(x)

        for transformer in self.transformer:
            x = transformer(x)

        batch_size, num_tokens, channels = x.size(0), x.size(1), x.size(-1)

        side = math.isqrt(num_tokens)
        if side * side != num_tokens:
            raise ValueError(
                "Number of tokens must be a perfect square to form a feature map"
            )

        # Move the channel axis in front of the token axis BEFORE folding the
        # tokens into a grid. A bare ``view`` on the (batch, tokens, channels)
        # layout yields the right shape but mixes token and channel values.
        x = x.transpose(1, 2).reshape(batch_size, channels, side, side)

        return x


if __name__ == "__main__":
    # Smoke test: run a random (16, 512, 16, 16) batch through the ViT
    # bottleneck and print the size of the result.
    vit_kwargs = {
        "image_size": 256,
        "dimension": 512,
        "nheads": 4,
        "num_layers": 4,
        "dim_feedforward": 1024,
        "dropout": 0.1,
        "activation": "relu",
        "layer_norm_eps": 1e-05,
        "bias": False,
    }
    vit = ViT(**vit_kwargs)

    images = torch.randn(16, 512, 16, 16)
    print(vit(x=images).shape)

### Decoder Block

In [None]:
class DecoderBlock(nn.Module):
    """Decoder stage: 2x bilinear upsampling followed by two conv stages.

    Each conv stage is ``Conv2d`` (3x3, stride 1, padding 1, no bias) ->
    ``BatchNorm2d`` -> ``ReLU``. The first stage maps ``in_channels`` to
    ``out_channels``; the second keeps ``out_channels``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.

    Note:
        After construction ``self.in_channels`` equals ``out_channels``,
        because the attribute is advanced while the stages are built.
    """

    def __init__(self, in_channels: int = 256, out_channels: int = 128):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Base values; the convolutions use ``kernel_size * 3`` (= 3).
        self.kernel_size = 1
        self.stride_size = 1
        self.padding_size = 1

        self.layers = []

        self.upsample = nn.Upsample(
            scale_factor=2, mode="bilinear", align_corners=False
        )

        # First stage consumes the incoming width, second stage the produced one.
        for stage_in_channels in (in_channels, out_channels):
            self.layers.extend(
                (
                    nn.Conv2d(
                        in_channels=stage_in_channels,
                        out_channels=out_channels,
                        kernel_size=self.kernel_size * 3,
                        stride=self.stride_size,
                        padding=self.padding_size,
                        bias=False,
                    ),
                    nn.BatchNorm2d(num_features=out_channels),
                    nn.ReLU(inplace=True),
                )
            )

        self.in_channels = self.out_channels

        self.decoder_network = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Upsample ``x`` by a factor of two and apply both conv stages.

        Args:
            x: 4-D feature map; the demo cell below feeds ``(1, 256, 8, 8)``.

        Returns:
            The output of the conv stages applied to the upsampled input.

        Raises:
            TypeError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())

        return self.decoder_network(self.upsample(x))

    @staticmethod
    def total_parameters(model):
        """Print the parameter count reported by ``total_params`` for ``model``.

        NOTE(review): ``total_params`` is not defined in this class; it is
        expected to exist at module level - confirm.

        Raises:
            ValueError: If ``model`` is not a ``DecoderBlock``.
        """
        if not isinstance(model, DecoderBlock):
            raise ValueError("Input must be a DecoderBlock".capitalize())

        print("Total Parameter: ", total_params(model=model))


if __name__ == "__main__":
    # Smoke test: decode a random (1, 256, 8, 8) feature map down to half
    # the channel count and print the resulting size.
    decoder_block = DecoderBlock(in_channels=256, out_channels=128)

    images = torch.randn(1, 256, 8, 8)

    print(decoder_block(images).shape)

### TransUNet: A combination of UNet and Transformer ViT

In [None]:
class TransUNet(nn.Module):
    """U-shaped segmentation network with a transformer (``ViT``) bottleneck.

    Data flow: a strided stem convolution, three ``EncoderBlock`` stages, the
    ``ViT`` bottleneck, then four ``DecoderBlock`` stages. The outputs of the
    first three decoders are concatenated (channel-wise) with the matching
    encoder/stem features before entering the next decoder. The last decoder
    emits a single channel.

    Args:
        image_channels: Channels of the input image.
        image_size: Input resolution; selects the base channel width and is
            forwarded to ``ViT``.
        nheads: Attention heads per transformer layer.
        num_layers: Number of transformer layers in the bottleneck.
        dim_feedforward: Hidden width of the transformer feed-forward nets.
        dropout: Dropout probability inside the transformer layers.
        activation: Activation name inside the transformer layers.
        layer_norm_eps: Epsilon for the transformer layer normalisations.
        bias: Bias flag for the stem convolution and the transformer.
    """

    def __init__(
        self,
        image_channels: int = 3,
        image_size: int = 256,
        nheads: int = 8,
        num_layers: int = 4,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        layer_norm_eps: float = 1e-05,
        bias: bool = False,
    ):
        super(TransUNet, self).__init__()

        self.image_channels = image_channels
        self.image_size = image_size
        self.nheads = nheads
        self.num_layers = num_layers
        self.dim_feedforward = dim_feedforward
        self.dropout = dropout
        self.activation = activation
        self.layer_norm_eps = layer_norm_eps
        self.bias = bias

        warnings.warn(
            "Output channel configuration is determined based on image size:\n"
            "  - If image_size > 256, out_channels = 2 ** 8\n"
            "  - If image_size > 128, out_channels = 2 ** 6\n"
            "  - Otherwise, out_channels = 2 ** 5"
        )

        # The base width is a power of two chosen from the image size. It was
        # previously computed as ``(image_channels - 1) ** k``, which equals
        # ``2 ** k`` only for RGB input: greyscale input produced 0 channels
        # (and a ZeroDivisionError further down) and wider inputs blew the
        # width up exponentially.
        if self.image_size > 256:
            self.out_channels = 2**8
        elif self.image_size > 128:
            self.out_channels = 2**6
        else:
            self.out_channels = 2**5

        # NOTE(review): the stem kernel still scales with the channel count
        # (7x7 for RGB). It is always odd, so ``kernel_size // 2`` padding
        # with stride 2 halves the resolution.
        self.kernel_size = (self.image_channels * 2) + 1
        self.stride_size = 2
        self.padding_size = self.kernel_size // 2

        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=self.image_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
                bias=self.bias,
            ),
            nn.BatchNorm2d(num_features=self.out_channels),
            nn.ReLU(inplace=True),
        )

        self.in_channels = self.out_channels

        # Encoder path: the channel count doubles at every stage.
        self.encoder1 = EncoderBlock(
            in_channels=self.in_channels, out_channels=self.in_channels * 2
        )
        self.encoder2 = EncoderBlock(
            in_channels=self.in_channels * 2, out_channels=self.in_channels * 4
        )
        self.encoder3 = EncoderBlock(
            in_channels=self.in_channels * 4, out_channels=self.in_channels * 8
        )

        # Decoder path: decoders 2-4 take the previous decoder output
        # concatenated with a skip tensor, hence their doubled input widths.
        self.decoder1 = DecoderBlock(
            in_channels=self.in_channels * 8, out_channels=self.in_channels * 4
        )
        self.decoder2 = DecoderBlock(
            in_channels=self.in_channels * 8, out_channels=self.in_channels * 2
        )
        self.decoder3 = DecoderBlock(
            in_channels=self.in_channels * 4, out_channels=self.in_channels
        )
        self.decoder4 = DecoderBlock(
            in_channels=self.in_channels * 2,
            # Single-channel output (was ``in_channels // in_channels``).
            out_channels=1,
        )

        self.vit = ViT(
            image_size=self.image_size,
            dimension=self.in_channels * 8,
            nheads=self.nheads,
            num_layers=self.num_layers,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
            activation=self.activation,
            layer_norm_eps=self.layer_norm_eps,
            bias=self.bias,
        )

    def forward(self, x: torch.Tensor):
        """Run the full encoder -> transformer -> decoder pipeline.

        Args:
            x: Input image batch; the demo cell below feeds
                ``(1, 3, 256, 256)``.

        Returns:
            The output of the last decoder stage.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())

        x = self.conv1(x)

        encoder1 = self.encoder1(x)
        encoder2 = self.encoder2(encoder1)
        encoder3 = self.encoder3(encoder2)

        bottleneck = self.vit(x=encoder3)

        # Each decoder output is joined with the skip tensor from the
        # matching encoder level along the channel axis.
        decoder1 = self.decoder1(bottleneck)
        decoder1 = torch.cat((decoder1, encoder2), dim=1)

        decoder2 = self.decoder2(decoder1)
        decoder2 = torch.cat((decoder2, encoder1), dim=1)

        decoder3 = self.decoder3(decoder2)
        decoder3 = torch.cat((decoder3, x), dim=1)

        output = self.decoder4(decoder3)

        return output


if __name__ == "__main__":
    # End-to-end smoke test: segment one random 256x256 RGB image and print
    # the size of the prediction.
    transunet_kwargs = {
        "image_channels": 3,
        "image_size": 256,
        "nheads": 4,
        "num_layers": 4,
        "dim_feedforward": 1024,
        "dropout": 0.1,
        "activation": "relu",
        "layer_norm_eps": 1e-05,
        "bias": False,
    }
    model = TransUNet(**transunet_kwargs)

    images = torch.randn(1, 3, 256, 256)
    segmented_images = model(x=images)

    print("Output size:", segmented_images.size())

### Helper function

In [None]:
def load_dataloader():
    """Load the pickled train and validation dataloaders.

    The files are read from the ``processed_data_path`` directory reported by
    ``path_names()``. The validation loader comes from
    ``test_dataloader.pkl``.

    Returns:
        A dict with the keys ``"train_dataloader"`` and ``"valid_dataloader"``.
    """
    processed_path = path_names()["processed_data_path"]

    # Result key -> file name inside the processed-data directory.
    pickle_names = {
        "train_dataloader": "train_dataloader.pkl",
        "valid_dataloader": "test_dataloader.pkl",
    }

    return {
        key: load_file(filename=os.path.join(processed_path, file_name))
        for key, file_name in pickle_names.items()
    }


def helper(**kwargs):
    """Assemble everything a training run needs.

    Keyword Args:
        model: A ``TransUNet`` instance, or ``None`` to build one from the
            ``dataloader`` / ``TransUNet`` sections of the config file.
        lr: Learning rate for the optimizer.
        beta1, beta2: Adam beta coefficients.
        weight_decay: Weight decay for either optimizer.
        momentum: SGD momentum.
        adam: Use ``optim.Adam`` when true (takes precedence over ``SGD``).
        SGD: Use ``optim.SGD`` when true and ``adam`` is false.
        loss: One of ``"bce"``, ``"BCE"``, ``"dice"``, ``"focal"``,
            ``"jaccard"``, ``"tversky"``.
        loss_smooth: Smoothing term for the dice, jaccard and tversky losses.
        alpha_focal, gamma_focal: Focal-loss parameters.
        alpha_tversky, beta_tversky: Tversky-loss parameters.

    Returns:
        A dict with the keys ``"train_dataloader"``, ``"valid_dataloader"``,
        ``"model"``, ``"optimizer"`` and ``"criterion"``.

    Raises:
        KeyError: If any of the keyword arguments above is missing.
        ValueError: If neither optimizer flag is set, the loss name is not
            supported, or ``model`` is neither ``None`` nor a ``TransUNet``.
    """
    model: TransUNet = kwargs["model"]
    lr: float = kwargs["lr"]
    beta1: float = kwargs["beta1"]
    beta2: float = kwargs["beta2"]
    weight_decay: float = kwargs["weight_decay"]
    momentum: float = kwargs["momentum"]
    adam: bool = kwargs["adam"]
    SGD: bool = kwargs["SGD"]
    loss: str = kwargs["loss"]
    loss_smooth: float = kwargs["loss_smooth"]
    alpha_focal: float = kwargs["alpha_focal"]
    gamma_focal: float = kwargs["gamma_focal"]
    alpha_tversky: float = kwargs["alpha_tversky"]
    beta_tversky: float = kwargs["beta_tversky"]

    # Validate the cheap choices first. Without these checks an unsupported
    # combination left ``optimizer`` / ``criterion`` unassigned and the
    # function failed with UnboundLocalError only after the model was built.
    if not (adam or SGD):
        raise ValueError(
            "Invalid optimizer. Expected adam or SGD to be enabled.".capitalize()
        )

    if loss not in ("bce", "BCE", "dice", "focal", "jaccard", "tversky"):
        raise ValueError(
            "Invalid loss function. Expected one of: bce, dice, focal, jaccard, tversky.".capitalize()
        )

    if model is None:
        # Read and parse the YAML file once, and only when a model has to be
        # built from it.
        config = config_files()
        dataloader_config = config["dataloader"]
        transunet_config = config["TransUNet"]

        trans_unet = TransUNet(
            image_channels=dataloader_config["image_channels"],
            image_size=dataloader_config["image_size"],
            nheads=transunet_config["nheads"],
            num_layers=transunet_config["num_layers"],
            dim_feedforward=transunet_config["dim_feedforward"],
            # float(): the original code converts these two values as well.
            dropout=float(transunet_config["dropout"]),
            activation=transunet_config["activation"],
            layer_norm_eps=float(transunet_config["layer_norm_eps"]),
            bias=transunet_config["bias"],
        )
    elif isinstance(model, TransUNet):
        trans_unet = model
    else:
        raise ValueError("Invalid model type. Expected TransUNet.".capitalize())

    if adam:
        optimizer = optim.Adam(
            params=trans_unet.parameters(),
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=weight_decay,
        )
    else:
        optimizer = optim.SGD(
            params=trans_unet.parameters(),
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )

    if loss in ("bce", "BCE"):
        criterion = BCE(name="BCEWithLogitsLoss")
    elif loss == "dice":
        criterion = DiceLoss(smooth=loss_smooth)
    elif loss == "focal":
        criterion = FocalLoss(alpha=alpha_focal, gamma=gamma_focal)
    elif loss == "jaccard":
        criterion = JaccardLoss(smooth=loss_smooth)
    else:  # "tversky" - the only remaining validated name
        criterion = TverskyLoss(
            alpha=alpha_tversky, beta=beta_tversky, smooth=loss_smooth
        )

    # Unpickle both dataloaders once instead of once per returned key.
    dataloaders = load_dataloader()

    return {
        "train_dataloader": dataloaders["train_dataloader"],
        "valid_dataloader": dataloaders["valid_dataloader"],
        "model": trans_unet,
        "optimizer": optimizer,
        "criterion": criterion,
    }


if __name__ == "__main__":
    # Smoke test for helper(): build the training components and check their
    # types. The three switches below drive BOTH the helper() call and the
    # checks, so changing one of them keeps the two in sync.
    adam, SGD = True, False
    # Supported values: "dice" | "focal" | "jaccard" | "tversky" | "BCE"
    loss = "bce"

    init = helper(
        model=None,
        lr=2e-4,
        beta1=0.9,
        beta2=0.999,
        weight_decay=1e-5,
        momentum=0.9,
        adam=adam,
        SGD=SGD,
        loss=loss,
        loss_smooth=1e-6,
        alpha_focal=0.25,
        gamma_focal=2.0,
        alpha_tversky=0.5,
        beta_tversky=0.5,
    )

    assert isinstance(init["train_dataloader"], torch.utils.data.DataLoader)
    assert isinstance(init["valid_dataloader"], torch.utils.data.DataLoader)
    assert isinstance(init["model"], TransUNet)

    if adam:
        assert isinstance(init["optimizer"], torch.optim.Adam)
    elif SGD:
        assert isinstance(init["optimizer"], torch.optim.SGD)

    if loss == "bce" or loss == "BCE":
        assert isinstance(init["criterion"], BCE)
    elif loss == "dice":
        assert isinstance(init["criterion"], DiceLoss)
    elif loss == "focal":
        assert isinstance(init["criterion"], FocalLoss)
    elif loss == "tversky":
        assert isinstance(init["criterion"], TverskyLoss)
    elif loss == "jaccard":
        assert isinstance(init["criterion"], JaccardLoss)
    else:
        raise ValueError("Invalid loss function. Expected one of: bce, dice, focal, jaccard, tversky.".capitalize())
