In [None]:
import os
import sys
import cv2
import math
import torch
import zipfile
import argparse
from PIL import Image
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
from torchview import draw_graph
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

# Utils

In [None]:
import os
import sys
import yaml
import joblib
import torch
import traceback

# Make modules under ./src/ importable. The path is relative to the current
# working directory at the time this cell runs, not to this file.
sys.path.append("./src/")


def dump(value=None, filename=None):
    """Serialize *value* to *filename* using joblib.

    Raises:
        ValueError: if either ``value`` or ``filename`` is None.
    """
    if value is None or filename is None:
        raise ValueError("Both value and filename are required".capitalize())

    joblib.dump(value=value, filename=filename)


def load(filename=None):
    """Deserialize and return the object joblib stored at *filename*.

    Raises:
        ValueError: if ``filename`` is None.
    """
    if filename is None:
        raise ValueError("Filename is required".capitalize())

    return joblib.load(filename=filename)


def device_init(device="cuda"):
    """Return the requested torch device, falling back to CPU.

    "cuda" and "mps" are honoured only when the corresponding backend
    reports itself available; any other value yields the CPU device.
    """
    if device == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")
    if device == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def weight_init(m):
    """Initialise a module's weights in place; intended for ``model.apply``.

    * Any module whose class name contains "Conv": weights ~ N(0, 0.02).
    * Any module whose class name contains "BatchNorm2d": weights ~ N(1, 0.02),
      bias set to 0.
    Other modules are left untouched.
    """
    layer_type = type(m).__name__

    if "Conv" in layer_type:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif "BatchNorm2d" in layer_type:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


def config():
    """Parse and return the project's YAML configuration.

    The file is opened as ``../../config.yml`` relative to the current
    working directory (not to this file) and parsed with ``yaml.safe_load``.
    """
    with open("../../config.yml", "r") as config_file:
        settings = yaml.safe_load(config_file)
    return settings


def clean():
    """Empty each artefact directory listed in the ``path`` config section.

    Directories handled, in order: FILES_PATH, TRAIN_IMAGES, TEST_IMAGE,
    TRAIN_MODELS, BEST_MODEL, METRICS_PATH. Every entry inside each one is
    removed with ``os.remove``.

    NOTE(review): ``os.remove`` cannot delete sub-directories, so a nested
    folder inside any of these paths would raise.

    Raises:
        KeyError: if any of the six keys is missing (checked before deleting).
        FileNotFoundError: when a directory does not exist; directories
            earlier in the list have already been emptied by then.
    """
    path_section = config()["path"]
    keys = (
        "FILES_PATH",
        "TRAIN_IMAGES",
        "TEST_IMAGE",
        "TRAIN_MODELS",
        "BEST_MODEL",
        "METRICS_PATH",
    )
    # Resolve every key up front so a missing key fails before anything is deleted.
    targets = [path_section[key] for key in keys]

    for target in targets:
        if not os.path.exists(target):
            raise FileNotFoundError(f"{target} does not exist".capitalize())

        for entry in os.listdir(target):
            os.remove(os.path.join(target, entry))

        print(f"Deleted all files in {target}".capitalize())


# Dataloader



In [None]:
class Loader:
    """Build train/test/valid DataLoaders from a zipped image/mask dataset.

    ``feature_extractor`` reads ``<RAW_PATH>/dataset/{train,test}/{image,mask}``
    (``RAW_PATH`` comes from ``config()``), matching each mask to the image
    with the same filename. The ``train`` folder is split into train/test
    subsets; the ``test`` folder becomes the validation set.

    Args:
        image_path: path of the zip archive extracted by ``unzip_folder``.
        image_size: side length images and masks are resized/cropped to.
        batch_size: DataLoader batch size (validation uses ``4 * batch_size``).
        split_size: fraction of the ``train`` folder held out as the test split.
    """

    def __init__(
        self,
        image_path=None,
        image_size: int = 128,
        batch_size: int = 16,
        split_size: float = 0.2,
    ):
        super(Loader, self).__init__()

        self.image_path = image_path
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size

        self.train_images = list()
        self.valid_images = list()

        self.train_masks = list()
        self.valid_masks = list()

    def unzip_folder(self):
        """Extract the archive at ``image_path`` into the configured ``RAW_PATH``.

        Raises:
            FileNotFoundError: if ``image_path`` is unset or does not exist.
        """
        if self.image_path is None or not os.path.exists(self.image_path):
            raise FileNotFoundError(
                "Image path not found in the Loader class".capitalize()
            )

        with zipfile.ZipFile(self.image_path, "r") as zip_file:
            zip_file.extractall(path=config()["path"]["RAW_PATH"])

    def image_transforms(self):
        """Resize (bicubic), convert to tensor, center-crop, normalise 3 channels to [-1, 1]."""
        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def mask_transforms(self):
        """Single-channel variant of ``image_transforms`` (currently unused)."""
        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.Normalize((0.5,), (0.5,)),
            ]
        )

    def split_dataset(self, X: list, y: list):
        """Split paired lists with a fixed seed (42).

        Returns:
            dict with keys ``X_train``, ``X_test``, ``y_train``, ``y_test``.

        Raises:
            TypeError: if ``X`` or ``y`` is not a list.
        """
        if isinstance(X, list) and isinstance(y, list):
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=self.split_size, random_state=42
            )

            return {
                "X_train": X_train,
                "X_test": X_test,
                "y_train": y_train,
                "y_test": y_test,
            }

        else:
            raise TypeError("X and y must be of type list".capitalize())

    def feature_extractor(self):
        """Load and transform every image/mask pair, then split the train folder.

        Images whose mask file is missing, or that OpenCV cannot read, are
        skipped with a message.

        Returns:
            tuple: (the ``split_dataset`` dict for the ``train`` folder,
            ``{"valid_images": [...], "valid_masks": [...]}`` for ``test``).
        """
        self.directory = os.path.join(config()["path"]["RAW_PATH"], "dataset")
        self.train_directory = os.path.join(self.directory, "train")
        self.valid_directory = os.path.join(self.directory, "test")

        # The transform pipeline is stateless, so build it once, not per image.
        # NOTE(review): masks go through the 3-channel image pipeline; the
        # single-channel mask_transforms is not used. Kept as-is.
        transform = self.image_transforms()

        # Pair each folder with the lists it fills instead of recovering the
        # split name from the path string (which fails on "\\" separators).
        splits = [
            (self.train_directory, self.train_images, self.train_masks),
            (self.valid_directory, self.valid_images, self.valid_masks),
        ]

        for directory, image_store, mask_store in tqdm(splits):
            self.images = os.path.join(directory, "image")
            self.masks = os.path.join(directory, "mask")

            # Sorted so the seeded train/test split is reproducible;
            # os.listdir order is arbitrary.
            for image_name in sorted(os.listdir(self.images)):
                image_file = os.path.join(self.images, image_name)
                mask_file = os.path.join(self.masks, image_name)

                if not os.path.isfile(mask_file):
                    print(f"No mask found for {image_name}; skipping")
                    continue

                extracted_image = cv2.imread(image_file)
                extracted_mask = cv2.imread(mask_file)

                # cv2.imread returns None instead of raising on unreadable files.
                if extracted_image is None or extracted_mask is None:
                    print(f"Could not read {image_name} or its mask; skipping")
                    continue

                extracted_image = Image.fromarray(
                    cv2.cvtColor(extracted_image, cv2.COLOR_BGR2RGB)
                )
                extracted_mask = Image.fromarray(
                    cv2.cvtColor(extracted_mask, cv2.COLOR_BGR2RGB)
                )

                image_store.append(transform(extracted_image))
                mask_store.append(transform(extracted_mask))

        assert len(self.train_images) == len(
            self.train_masks
        ), "Number of images and masks do not match".capitalize()
        assert len(self.valid_images) == len(
            self.valid_masks
        ), "Number of images and masks do not match".capitalize()

        # Let split errors (e.g. from train_test_split) propagate; returning
        # None here only makes create_dataloader fail on the unpacking.
        dataset = self.split_dataset(X=self.train_images, y=self.train_masks)

        return dataset, {
            "valid_images": self.valid_images,
            "valid_masks": self.valid_masks,
        }

    def create_dataloader(self):
        """Build the three DataLoaders and pickle them into ``PROCESSED_PATH``.

        Writes ``train_dataloader.pkl`` (shuffled), ``test_dataloader.pkl``
        and ``valid_dataloader.pkl`` (batch size ``4 * batch_size``).
        """
        train_dataset, valid_dataset = self.feature_extractor()
        processed_path = config()["path"]["PROCESSED_PATH"]

        train_dataloader = DataLoader(
            dataset=list(zip(train_dataset["X_train"], train_dataset["y_train"])),
            batch_size=self.batch_size,
            shuffle=True,
        )
        test_dataloader = DataLoader(
            dataset=list(zip(train_dataset["X_test"], train_dataset["y_test"])),
            batch_size=self.batch_size,
            shuffle=False,
        )
        valid_dataloader = DataLoader(
            dataset=list(
                zip(valid_dataset["valid_images"], valid_dataset["valid_masks"])
            ),
            batch_size=self.batch_size * 4,
            shuffle=False,
        )

        for filename, value in [
            ("train_dataloader", train_dataloader),
            ("test_dataloader", test_dataloader),
            ("valid_dataloader", valid_dataloader),
        ]:
            dump(
                value=value,
                filename=os.path.join(processed_path, filename) + ".pkl",
            )

        print("Dataloader is saved in the folder of {}".format(processed_path))

    @staticmethod
    def _min_max_scale(array):
        """Rescale *array* to [0, 1]; a constant array becomes all zeros, not NaN."""
        shifted = array - array.min()
        value_range = shifted.max()
        return shifted / value_range if value_range > 0 else shifted

    @staticmethod
    def display_images():
        """Plot the first training batch as image/mask pairs and save it.

        Loads ``train_dataloader.pkl`` from ``PROCESSED_PATH`` and writes
        ``image.png`` into ``FILES_PATH``.
        """
        paths = config()["path"]
        dataset = load(os.path.join(paths["PROCESSED_PATH"], "train_dataloader.pkl"))

        images, masks = next(iter(dataset))

        number_of_rows = int(math.sqrt(images.size(0)))
        number_of_columns = int(images.size(0) // number_of_rows)

        plt.figure(figsize=(10, 10))

        plt.suptitle("Images and Masks".capitalize())

        for index, image in enumerate(images):
            # Permute before squeezing so single-channel tensors also work.
            image = image.permute(1, 2, 0).squeeze().detach().cpu().numpy()
            mask = masks[index].permute(1, 2, 0).squeeze().detach().cpu().numpy()

            image = Loader._min_max_scale(image)
            mask = Loader._min_max_scale(mask)

            plt.subplot(2 * number_of_rows, 2 * number_of_columns, 2 * index + 1)
            plt.imshow(image)
            plt.title("Image")
            plt.axis("off")

            plt.subplot(2 * number_of_rows, 2 * number_of_columns, 2 * index + 2)
            plt.imshow(mask, cmap="gray")
            plt.title("Mask")
            plt.axis("off")

        plt.tight_layout()
        plt.savefig(os.path.join(paths["FILES_PATH"], "image.png"))
        plt.show()

        print(
            "Images saved in the folder: {}".format(paths["FILES_PATH"]).capitalize()
        )

# Scaled Dot Product

In [None]:
def scaled_dot_product_attention(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, channels: int
):
    """Compute ``softmax(query @ key^T / sqrt(channels)) @ value`` on 4-D tensors.

    The product is taken over the last two axes, so the score matrix is
    ``(B, heads, query.size(2), query.size(2))``. In MultiHeadAttentionLayer
    the inputs are ``(batch, heads, channels // heads, H * W)``, which makes
    this attention across channels within each head.

    Args:
        query, key, value: 4-D tensors; ``key`` must share ``query``'s shape
            for the score-shape assertion to hold.
        channels: divisor inside the square root used to scale the scores.

    Returns:
        Tensor of shape ``(query.size(0), query.size(1), query.size(2), value.size(3))``.

    Raises:
        ValueError: if any of query, key or value is not a torch.Tensor
            (previously the function silently returned None).
    """
    if not (
        isinstance(query, torch.Tensor)
        and isinstance(key, torch.Tensor)
        and isinstance(value, torch.Tensor)
    ):
        raise ValueError("Query, key and value must be torch.Tensor".capitalize())

    result = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(channels)

    assert result.size() == (
        query.size(0),
        query.size(1),
        query.size(2),
        query.size(2),
    ), "result size is not correct".capitalize()

    result = torch.softmax(result, dim=-1)

    attention = torch.matmul(result, value)

    assert attention.size() == (
        query.size(0),
        query.size(1),
        query.size(2),
        value.size(3),
    ), "attention size is not correct".capitalize()

    return attention

# MultiHead Attention Layer

In [None]:
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head self-attention on a ``(B, channels, H, W)`` feature map.

    A 1x1 convolution produces query/key/value maps. Each map is reshaped to
    ``(B, nheads, channels // nheads, H * W)`` and passed to
    ``scaled_dot_product_attention``. The result is reshaped back to
    ``(B, channels, H, W)`` and projected by a second 1x1 convolution.

    Args:
        channels: input/output channel count; must be divisible by ``nheads``.
        nheads: number of attention heads.
        bias: whether the 1x1 convolutions use a bias term.
    """

    def __init__(self, channels: int = 128, nheads: int = 8, bias: bool = True):
        super(MultiHeadAttentionLayer, self).__init__()

        self.channels = channels
        self.nheads = nheads
        self.bias = bias

        assert (
            self.channels % self.nheads == 0
        ), "Channels must be divisible by number of heads".capitalize()

        self.kernel_size = 1
        self.stride = 1
        self.padding = 0

        self.QKV = nn.Conv2d(
            in_channels=self.channels,
            out_channels=3 * self.channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=self.bias,
        )

        self.layers = nn.Conv2d(
            in_channels=self.channels,
            out_channels=self.channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=self.bias,
        )

    def forward(self, x: torch.Tensor):
        """Apply attention to ``x`` of shape ``(B, channels, H, W)``; returns the same shape.

        Raises:
            ValueError: if ``x`` is not a torch.Tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())

        # Remember the spatial size so the output can be restored exactly.
        # The previous code derived it as (H*W) // channels, which only
        # matched when channels == H == W.
        batch_size, _, height, width = x.size()
        head_dim = self.channels // self.nheads

        query, key, value = torch.chunk(input=self.QKV(x), chunks=3, dim=1)

        assert (
            query.size() == key.size() == value.size()
        ), "QKV must have the same size".capitalize()

        # (B, C, H, W) -> (B, nheads, C // nheads, H * W). Locals instead of
        # attributes so the module does not keep the last graph alive.
        query = query.reshape(batch_size, self.nheads, head_dim, height * width)
        key = key.reshape(batch_size, self.nheads, head_dim, height * width)
        value = value.reshape(batch_size, self.nheads, head_dim, height * width)

        # NOTE(review): scores are scaled by sqrt(channels), not sqrt(head_dim).
        # Kept unchanged to preserve existing behaviour/checkpoints.
        attention = scaled_dot_product_attention(
            query=query, key=key, value=value, channels=self.channels
        )

        assert (
            attention.size() == query.size() == key.size() == value.size()
        ), "Attention output must have the same size as QKV"

        attention = attention.reshape(batch_size, self.channels, height, width)

        return self.layers(attention)

# FeedForward Network

In [None]:
class FeedForwardNeuralNetwork(nn.Module):
    """Position-wise feed-forward network made of two 1x1 convolutions.

    Layout of ``self.model``: Conv(channels -> 3*channels), activation,
    Dropout2d(dropout), Conv(3*channels -> channels).

    Args:
        channels: input and output channel count.
        dropout: probability for the Dropout2d after the activation.
        activation: "leaky_relu" (slope 0.2, in place), "gelu"; anything
            else selects ReLU.
        bias: whether both convolutions use a bias term.
    """

    def __init__(
        self,
        channels: int = 128,
        dropout: float = 0.1,
        activation: str = "relu",
        bias: bool = True,
    ):
        super(FeedForwardNeuralNetwork, self).__init__()

        self.channels = channels
        self.dropout = dropout
        self.bias = bias

        self.kernel_size = 1
        self.stride = 1
        self.padding = 0

        self.activation = self._build_activation(activation)

        hidden_channels = 3 * channels
        expand = nn.Conv2d(
            in_channels=channels,
            out_channels=hidden_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=bias,
        )
        channel_dropout = nn.Dropout2d(p=dropout)
        project = nn.Conv2d(
            in_channels=hidden_channels,
            out_channels=channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=bias,
        )

        self.layers = [expand, self.activation, channel_dropout, project]

        # Final values the original construction loop left behind.
        self.in_channels = channels
        self.out_channels = channels

        self.model = nn.Sequential(*self.layers)

    @staticmethod
    def _build_activation(name):
        """Map an activation name to a fresh module instance (ReLU by default)."""
        if name == "leaky_relu":
            return nn.LeakyReLU(inplace=True, negative_slope=0.2)
        if name == "gelu":
            return nn.GELU()
        return nn.ReLU()

    def forward(self, x: torch.Tensor):
        """Run the feed-forward stack; raises ValueError for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())
        return self.model(x)

# Encoder Block

In [None]:
class EncoderBlock(nn.Module):
    """Downsampling block: 4x4/stride-2 conv, ReLU, 3x3 conv, optional BatchNorm.

    With padding 1, the first convolution halves even spatial sizes and the
    second keeps them. The two convolutions are wrapped in their own inner
    Sequential (index 0 of ``encoder_block``), followed by the BatchNorm when
    ``batch_norm`` is True.
    """

    def __init__(
        self, in_channels: int = 128, out_channels: int = 256, batch_norm: bool = True
    ):
        super(EncoderBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_batchnorm = batch_norm

        self.kernel_size = 4
        self.stride_size = 2
        self.padding_size = 1

        downsample = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
        )
        refine = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=self.kernel_size - 1,
            stride=1,
            padding=self.padding_size,
        )

        self.layers = [nn.Sequential(downsample, nn.ReLU(inplace=True), refine)]
        if batch_norm:
            self.layers.append(nn.BatchNorm2d(num_features=out_channels))

        self.encoder_block = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Apply the block; raises ValueError for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())
        return self.encoder_block(x)

# Decoder Block

In [None]:
class DecoderBlock(nn.Module):
    """Upsampling block: 4x4/stride-2 transposed conv, ReLU, 3x3 conv, optional BatchNorm.

    With padding 1, the transposed convolution doubles the spatial size and
    the 3x3 convolution keeps it. The convolutions sit in an inner Sequential
    (index 0 of ``decoder_block``), followed by the BatchNorm when ``batchnorm``
    is True.
    """

    def __init__(
        self, in_channels: int = 1024, out_channels: int = 512, batchnorm: bool = True
    ):
        super(DecoderBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_batchnorm = batchnorm

        self.kernel_size = 4
        self.stride_size = 2
        self.padding_size = 1

        upsample = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
        )
        refine = nn.Conv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

        self.layers = [nn.Sequential(upsample, nn.ReLU(inplace=True), refine)]
        if batchnorm:
            self.layers.append(nn.BatchNorm2d(num_features=out_channels))

        self.decoder_block = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Apply the block; raises ValueError for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())
        return self.decoder_block(x)

# Attention CNN Block

In [None]:
class attentionCNNBlock(nn.Module):
    """Transformer-style block on feature maps.

    Two residual sub-layers run in sequence: multi-head attention, then the
    feed-forward network. Each sub-layer's output goes through dropout, is
    added to its input, and is normalised. The same BatchNorm2d instance is
    used after both sub-layers, so its running statistics mix the two.
    """

    def __init__(
        self,
        channels: int = 128,
        nheads: int = 8,
        dropout: float = 0.1,
        activation: str = "relu",
        bias: bool = True,
    ):
        super(attentionCNNBlock, self).__init__()

        self.channels = channels
        self.nheads = nheads
        self.dropout = dropout
        self.activation = activation
        self.bias = bias

        self.multihead_attention = MultiHeadAttentionLayer(
            channels=channels, nheads=nheads, bias=bias
        )
        self.feedforward_network = FeedForwardNeuralNetwork(
            channels=channels, dropout=dropout, activation=activation, bias=bias
        )
        self.batch_norm = nn.BatchNorm2d(num_features=channels)

    def _residual(self, sublayer, inputs):
        """Return batch_norm(dropout(sublayer(inputs)) + inputs)."""
        transformed = torch.dropout(
            input=sublayer(x=inputs), p=self.dropout, train=self.training
        )
        return self.batch_norm(torch.add(transformed, inputs))

    def forward(self, x: torch.Tensor):
        """Apply the attention and feed-forward residual sub-layers in order."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor".capitalize())

        attended = self._residual(self.multihead_attention, x)
        return self._residual(self.feedforward_network, attended)

# Attention CNN

In [None]:
class attentionCNN(nn.Module):
    """U-Net-like network with a stack of attention CNN blocks at the input.

    Pipeline:
    1. A 3x3 input conv maps ``image_channels`` to ``image_size`` channels.
    2. ``num_layers`` attentionCNNBlocks follow.
    3. Three EncoderBlocks, then three DecoderBlocks. decoder1 is concatenated
       with encoder2, and decoder2 with encoder1, along channels.
    4. A 3x3 output conv maps back to ``image_channels``, followed by ``tanh``.

    NOTE(review): ``image_size`` is used as the feature-channel width, not as
    a spatial size.

    Args:
        image_channels: channels of the input and of the output.
        image_size: base channel width of the network.
        nheads: attention heads per attentionCNNBlock.
        dropout: dropout probability inside the attention blocks.
        num_layers: number of attentionCNNBlocks.
        activation: feed-forward activation name.
        bias: bias flag for the input/output and attention convolutions.
        detect_anomaly: when True, ``forward`` enables
            ``torch.autograd.set_detect_anomaly(True)``. This was previously
            unconditional. It is a global debugging mode that slows autograd,
            so it is now off by default.
    """

    def __init__(
        self,
        image_channels: int = 3,
        image_size: int = 128,
        nheads: int = 8,
        dropout: float = 0.1,
        num_layers: int = 8,
        activation: str = "relu",
        bias: bool = True,
        detect_anomaly: bool = False,
    ):
        super(attentionCNN, self).__init__()

        self.image_channels = image_channels
        self.image_size = image_size
        self.nheads = nheads
        self.dropout = dropout
        self.num_layers = num_layers
        self.activation = activation
        self.bias = bias
        self.detect_anomaly = detect_anomaly

        self.kernel_size = 3
        self.stride_size = 1
        self.padding_size = 1

        self.input_block = nn.Conv2d(
            in_channels=self.image_channels,
            out_channels=self.image_size,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=self.bias,
        )

        self.attention_cnn_block = nn.Sequential(
            *[
                attentionCNNBlock(
                    channels=self.image_size,
                    nheads=self.nheads,
                    dropout=self.dropout,
                    activation=self.activation,
                    bias=self.bias,
                )
                for _ in tqdm(range(self.num_layers))
            ]
        )

        self.encoder1 = EncoderBlock(
            in_channels=self.image_size,
            out_channels=self.image_size * 2,
            batch_norm=True,
        )
        self.encoder2 = EncoderBlock(
            in_channels=self.image_size * 2,
            out_channels=self.image_size * 4,
            batch_norm=True,
        )
        self.encoder3 = EncoderBlock(
            in_channels=self.image_size * 4,
            out_channels=self.image_size * 8,
            batch_norm=False,
        )

        # in_channels of decoder2/decoder3 include the concatenated skip features.
        self.decoder1 = DecoderBlock(
            in_channels=self.image_size * 8,
            out_channels=self.image_size * 4,
            batchnorm=True,
        )
        self.decoder2 = DecoderBlock(
            in_channels=self.image_size * 8,
            out_channels=self.image_size * 2,
            batchnorm=True,
        )
        self.decoder3 = DecoderBlock(
            in_channels=self.image_size * 4,
            out_channels=self.image_size,
            batchnorm=True,
        )

        self.output_block = nn.Conv2d(
            in_channels=self.image_size,
            out_channels=self.image_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=self.bias,
        )

    def forward(self, x: torch.Tensor):
        """Run the network; output is ``tanh``-squashed into [-1, 1].

        Raises:
            ValueError: if ``x`` is not a torch.Tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input must be a torch.Tensor")

        if self.detect_anomaly:
            torch.autograd.set_detect_anomaly(True)

        x = self.input_block(x)
        x = self.attention_cnn_block(x)

        encoder1 = self.encoder1(x)
        encoder2 = self.encoder2(encoder1)
        encoder3 = self.encoder3(encoder2)

        decoder1 = self.decoder1(encoder3)
        decoder1 = torch.cat((decoder1, encoder2), dim=1)

        decoder2 = self.decoder2(decoder1)
        decoder2 = torch.cat((decoder2, encoder1), dim=1)

        decoder3 = self.decoder3(decoder2)

        output = self.output_block(decoder3)

        return torch.tanh(output)

    @staticmethod
    def total_params(model=None):
        """Return the total number of elements across all parameters of *model*.

        Raises:
            TypeError: if *model* is not an attentionCNN.
        """
        if isinstance(model, attentionCNN):
            return sum(params.numel() for params in model.parameters())

        else:
            raise TypeError("Model should be attentionCNN".capitalize())

# BCE Loss

In [None]:
class BinaryCrossEntropyLoss(nn.Module):
    """Thin ``nn.Module`` wrapper around ``nn.BCEWithLogitsLoss``.

    ``pred`` is treated as raw logits (the sigmoid happens inside
    ``BCEWithLogitsLoss``). ``reduction`` is forwarded unchanged.
    """

    def __init__(self, reduction="mean"):
        super(BinaryCrossEntropyLoss, self).__init__()

        self.reduction = reduction
        self.loss = nn.BCEWithLogitsLoss(reduction=reduction)

    def forward(self, pred: torch.Tensor, target: torch.Tensor):
        """Return the BCE-with-logits loss; raises TypeError for non-tensor inputs."""
        both_tensors = isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor)
        if not both_tensors:
            raise TypeError("pred and target must be torch.Tensor".capitalize())
        return self.loss(pred, target)


if __name__ == "__main__":
    # Smoke test for BinaryCrossEntropyLoss using the requested reduction.
    parser = argparse.ArgumentParser(description="BCELoss for attentionCNN".title())
    parser.add_argument(
        "--reduction",
        type=str,
        default="mean",
        choices=["mean", "sum", "none"],
        help="mean or sum or none".capitalize(),
    )
    # parse_known_args ignores unrelated argv entries (e.g. the ones a
    # notebook kernel is launched with) instead of exiting.
    args, _ = parser.parse_known_args()

    loss = BinaryCrossEntropyLoss(reduction=args.reduction)

    predicted = torch.tensor([1.0, 0.0, 1.0, 0.0])
    target = torch.tensor([1.0, 0.0, 1.0, 0.0])

    # "none" keeps one loss per element; "mean"/"sum" reduce to a scalar.
    expected_size = predicted.size() if args.reduction == "none" else torch.Size([])

    assert (
        loss(predicted, target).size() == expected_size
    ), "BCELoss is not working".capitalize()


# Dice Loss

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2 * sum(p * t) + smooth) / (sum(p) + sum(t) + smooth)``.

    Inputs are flattened. No activation is applied to ``pred``.

    Args:
        smooth: additive term in numerator and denominator. It keeps the
            ratio defined for empty masks and makes two empty masks score a
            perfect loss of 0. Previously it was only in the denominator, so
            that case returned 1.
    """

    def __init__(self, smooth: float = 1e-3):
        super(DiceLoss, self).__init__()

        self.name = "DiceLoss".title()
        self.smooth = smooth

    def forward(self, pred: torch.Tensor, target: torch.Tensor):
        """Return the scalar Dice loss; raises TypeError for non-tensor inputs."""
        if not (isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor)):
            raise TypeError("pred and target must be torch.Tensor".capitalize())

        # reshape handles non-contiguous inputs without an explicit copy call.
        pred = pred.reshape(-1)
        target = target.reshape(-1)

        intersection = (pred * target).sum()
        dice_coefficient = (2 * intersection + self.smooth) / (
            pred.sum() + target.sum() + self.smooth
        )

        return 1 - dice_coefficient

# Focal Loss

In [None]:
class FocalLoss(nn.Module):
    """Binary focal loss on probabilities: ``mean(alpha * (1 - p_t) ** gamma * BCE)``.

    The BCE term and ``p_t = exp(-BCE)`` are computed per element before
    averaging, so easy, well-classified elements are down-weighted
    individually. Previously the modulation was applied to the already
    averaged BCE. ``predicted`` must be probabilities in [0, 1], as required
    by ``nn.BCELoss``.

    Args:
        alpha: constant weighting factor.
        gamma: focusing exponent.
    """

    def __init__(self, alpha: float = 0.25, gamma: int = 2):
        super(FocalLoss, self).__init__()

        self.alpha = alpha
        self.gamma = gamma

    def forward(self, predicted: torch.Tensor, target: torch.Tensor):
        """Return the scalar focal loss; raises TypeError for non-tensor inputs."""
        if not (isinstance(predicted, torch.Tensor) and isinstance(target, torch.Tensor)):
            raise TypeError("Predicted and target must be torch.Tensor".capitalize())

        # reshape (not view) so non-contiguous inputs are accepted.
        predicted = predicted.reshape(-1)
        target = target.reshape(-1)

        elementwise_bce = nn.BCELoss(reduction="none")(predicted, target)
        pt = torch.exp(-elementwise_bce)

        return (self.alpha * (1 - pt) ** self.gamma * elementwise_bce).mean()

# IoU Loss

In [None]:
class IoULoss(nn.Module):
    """Soft Jaccard loss: ``1 - (|P∩T| + smooth) / (|P∪T| + smooth)``.

    The soft intersection is ``sum(p * t)`` and the union is
    ``sum(p) + sum(t) - intersection``, computed on flattened inputs.
    ``smooth`` is added to both numerator and denominator, so two empty masks
    give a loss of 0. With it only in the denominator, that case returned 1.
    """

    def __init__(self, smooth: float = 1e-6):
        super(IoULoss, self).__init__()

        self.name = "Iou Loss".title()
        self.smooth = smooth

    def forward(self, predicted: torch.Tensor, target: torch.Tensor):
        """Return the scalar IoU loss; raises TypeError for non-tensor inputs."""
        if not (isinstance(predicted, torch.Tensor) and isinstance(target, torch.Tensor)):
            raise TypeError("Predicted and Target must be torch.Tensor".capitalize())

        # reshape (not view) so non-contiguous inputs are accepted.
        predicted = predicted.reshape(-1)
        target = target.reshape(-1)

        intersection = (predicted * target).sum()
        union = predicted.sum() + target.sum() - intersection

        return 1 - (intersection + self.smooth) / (union + self.smooth)

# Tversky Loss

In [None]:
class TverskyLoss(nn.Module):
    """Tversky loss: ``1 - (TP + smooth) / (TP + alpha * FP + beta * FN + smooth)``.

    TP, FP and FN are soft counts summed over all elements. The defaults
    ``alpha = beta = 0.5`` reproduce the previous formula. ``smooth`` keeps
    the ratio defined when every input is zero; the old formula returned NaN
    in that case.

    Args:
        name: label stored on the instance.
        alpha: weight of false positives.
        beta: weight of false negatives.
        smooth: additive term in numerator and denominator.
    """

    def __init__(
        self,
        name: str = "TveskyLoss",
        alpha: float = 0.5,
        beta: float = 0.5,
        smooth: float = 1e-6,
    ):
        super(TverskyLoss, self).__init__()

        self.name = name
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, predicted: torch.Tensor, target: torch.Tensor):
        """Return the scalar Tversky loss; raises TypeError for non-tensor inputs."""
        if not (isinstance(predicted, torch.Tensor) and isinstance(target, torch.Tensor)):
            raise TypeError("Predicted and target must be torch.Tensor".capitalize())

        TP = torch.sum(predicted * target)
        FP = torch.sum(predicted * (1 - target))
        FN = torch.sum((1 - predicted) * target)

        return 1 - (TP + self.smooth) / (
            TP + self.alpha * FP + self.beta * FN + self.smooth
        )

# Combo Loss

In [None]:
class ComboLoss(nn.Module):
    """Sum of Dice, focal and BCE-with-logits losses on flattened inputs.

    The three terms are added and the total is averaged with ``.mean()``.
    That only has an effect when ``reduction="none"`` makes the BCE term
    non-scalar.

    NOTE(review): the same ``predicted`` tensor goes to FocalLoss, which uses
    ``nn.BCELoss`` and expects probabilities in [0, 1]. It also goes to
    BinaryCrossEntropyLoss, which wraps ``BCEWithLogitsLoss`` and expects
    logits. Confirm which of the two callers actually provide.
    """

    def __init__(
        self,
        alpha: float = 0.5,
        gamma: int = 2,
        smooth: float = 1e-4,
        reduction: str = "mean",
    ):
        super(ComboLoss, self).__init__()

        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth
        self.reduction = reduction

        self.dice_loss = DiceLoss(smooth=smooth)
        self.focal_loss = FocalLoss(alpha=alpha, gamma=gamma)
        self.bce_loss = BinaryCrossEntropyLoss(reduction=reduction)

    def forward(self, predicted: torch.Tensor, target: torch.Tensor):
        """Return the combined loss; raises TypeError for non-tensor inputs."""
        inputs_ok = isinstance(predicted, torch.Tensor) and isinstance(target, torch.Tensor)
        if not inputs_ok:
            raise TypeError("Predicted and target must be torch.Tensor".capitalize())

        flat_predicted = predicted.contiguous().view(-1)
        flat_target = target.contiguous().view(-1)

        dice_term = self.dice_loss(flat_predicted, flat_target)
        focal_term = self.focal_loss(flat_predicted, flat_target)
        bce_term = self.bce_loss(flat_predicted, flat_target)

        total = dice_term + focal_term + bce_term
        return total.mean()

# MSE Loss

In [None]:
class MeanSquaredLoss(nn.Module):
    """Thin ``nn.Module`` wrapper around ``nn.MSELoss`` with a configurable reduction."""

    def __init__(self, reduction="mean"):
        super(MeanSquaredLoss, self).__init__()

        self.reduction = reduction
        self.loss = nn.MSELoss(reduction=reduction)

    def forward(self, pred: torch.Tensor, target: torch.Tensor):
        """Return the MSE between *pred* and *target*; TypeError for non-tensors."""
        if not isinstance(pred, torch.Tensor) or not isinstance(target, torch.Tensor):
            raise TypeError("pred and target must be torch.Tensor".capitalize())
        return self.loss(pred, target)

# Helpers

In [None]:
import os
import torch.optim as optim

def load_dataloader():
    """Load the pickled train/valid/test DataLoaders from ``PROCESSED_PATH``.

    Returns:
        dict with keys ``train_dataloader``, ``valid_dataloader`` and
        ``test_dataloader``. Each value is loaded from ``<key>.pkl``.

    Raises:
        FileNotFoundError: if ``PROCESSED_PATH`` does not exist.
    """
    # Parse config.yml once instead of once per lookup.
    processed_path = config()["path"]["PROCESSED_PATH"]

    if not os.path.exists(processed_path):
        raise FileNotFoundError(
            "dataloader cannot be imported from the helper method".capitalize()
        )

    return {
        name: load(filename=os.path.join(processed_path, f"{name}.pkl"))
        for name in ("train_dataloader", "valid_dataloader", "test_dataloader")
    }


def helper(**kwargs):
    """Assemble the model, optimizer, loss and DataLoaders used for training.

    Required keyword arguments: ``model`` (a model, or None to build an
    attentionCNN from the ``attentionCNN`` config section), ``lr``, ``beta1``,
    ``beta2``, ``momentum``, ``adam``, ``SGD``, ``loss``, ``smooth``,
    ``alpha`` and ``gamma``.

    ``loss`` selects "dice", "focal", "IoU", "tversky" or "mse". Any other
    value falls back to BinaryCrossEntropyLoss.

    Returns:
        dict with ``train_dataloader``, ``valid_dataloader``,
        ``test_dataloader``, ``model``, ``optimizer`` and ``loss``.

    Raises:
        ValueError: if neither ``adam`` nor ``SGD`` is truthy. Previously no
            optimizer was created and the function failed at ``return``.
        Exception: anything raised by ``load_dataloader`` is printed and
            re-raised, instead of surfacing later as a NameError.
    """
    model = kwargs["model"]
    lr = kwargs["lr"]
    beta1 = kwargs["beta1"]
    beta2 = kwargs["beta2"]
    momentum = kwargs["momentum"]
    adam = kwargs["adam"]
    SGD = kwargs["SGD"]
    loss = kwargs["loss"]
    smooth = kwargs["smooth"]
    alpha = kwargs["alpha"]
    gamma = kwargs["gamma"]

    if model is None:
        # Read the config once rather than once per hyper-parameter.
        model_config = config()["attentionCNN"]
        model = attentionCNN(
            image_channels=model_config["image_channels"],
            image_size=model_config["image_size"],
            nheads=model_config["nheads"],
            dropout=model_config["dropout"],
            num_layers=model_config["num_layers"],
            activation=model_config["activation"],
            bias=True,
        )

    if adam:
        optimizer = optim.Adam(
            params=model.parameters(),
            lr=lr,
            betas=(beta1, beta2),
            weight_decay=config()["Trainer"]["weight_decay"],
        )
    elif SGD:
        optimizer = optim.SGD(params=model.parameters(), lr=lr, momentum=momentum)
    else:
        raise ValueError("Either adam or SGD must be enabled".capitalize())

    try:
        dataloader = load_dataloader()
    except Exception as e:
        # Report, then propagate. Continuing would only fail below on the
        # unbound ``dataloader`` name.
        print("An error is occurred in the file ", e)
        raise

    if loss == "dice":
        loss = DiceLoss(smooth=smooth)
    elif loss == "focal":
        loss = FocalLoss(alpha=alpha, gamma=gamma)
    elif loss == "IoU":
        loss = IoULoss(smooth=smooth)
    elif loss == "tversky":
        loss = TverskyLoss()
    elif loss == "mse":
        loss = MeanSquaredLoss()
    else:
        loss = BinaryCrossEntropyLoss(reduction="mean")

    return {
        "train_dataloader": dataloader["train_dataloader"],
        "valid_dataloader": dataloader["valid_dataloader"],
        "test_dataloader": dataloader["test_dataloader"],
        "model": model,
        "optimizer": optimizer,
        "loss": loss,
    }

# Trainer

In [None]:
import torch
import mlflow
import dagshub
import traceback
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import save_image
from torch.optim.lr_scheduler import StepLR

class Trainer:
    """Train an ``attentionCNN`` segmentation model and track the run with MLflow.

    Construction builds the dataloaders, model, optimizer and criterion through
    ``helper``. It then optionally re-initialises the weights, attaches a
    ``StepLR`` scheduler and initialises DagsHub. Finally it cleans the output
    folders via ``clean()`` and activates the MLflow experiment named in
    ``config.yml``.
    """

    def __init__(
        self,
        model=None,
        epochs: int = 100,
        lr: float = 0.0001,
        beta1: float = 0.5,
        beta2: float = 0.999,
        momentum: float = 0.90,
        adam: bool = True,
        SGD: bool = False,
        loss="bce",
        smooth: float = 1e-4,
        alpha: float = 0.25,
        gamma: int = 2,
        step_size: int = 20,
        device: str = "cuda",
        lr_scheduler: bool = False,
        l1_regularization: bool = False,
        l2_regularization: bool = False,
        elasticnet_regularization: bool = False,
        is_weight_init: bool = False,
        is_mlflow: bool = False,
        verbose: bool = True,
        scheduler_gamma: float = 0.1,
    ):
        """Configure the trainer.

        Args:
            model: ``attentionCNN`` to train, or ``None`` to build one from config.
            epochs: number of passes over the training dataloader.
            lr, beta1, beta2, momentum: optimizer hyper-parameters (betas for
                Adam, momentum for SGD).
            adam, SGD: optimizer selection forwarded to ``helper`` (Adam wins
                when both are set).
            loss: criterion name forwarded to ``helper``.
            smooth, alpha, gamma: criterion hyper-parameters forwarded to
                ``helper`` (``alpha``/``gamma`` feed the focal loss).
            step_size: epochs between learning-rate decays when ``lr_scheduler``.
            device: requested device, resolved through ``device_init``.
            lr_scheduler: attach a ``StepLR`` scheduler stepped once per epoch.
            l1_regularization, l2_regularization, elasticnet_regularization: add
                a 0.01-weighted parameter-norm penalty (the first enabled wins).
            is_weight_init: apply ``weight_init`` to every sub-module.
            is_mlflow: call ``dagshub.init`` with the configured repository.
            verbose: print per-epoch losses instead of a short progress line.
            scheduler_gamma: multiplicative decay factor used by ``StepLR``.

        Raises:
            TypeError: if ``helper`` returns non-DataLoader loaders or a model
                that is not an ``attentionCNN``.
        """
        self.model = model
        self.epochs = epochs
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.momentum = momentum
        self.adam = adam
        self.SGD = SGD
        self.loss = loss
        self.smooth = smooth
        self.alpha = alpha
        self.gamma = gamma
        self.step_size = step_size
        self.device = device
        self.lr_scheduler = lr_scheduler
        self.l1_regularization = l1_regularization
        self.l2_regularization = l2_regularization
        self.elasticnet_regularization = elasticnet_regularization
        self.is_weight_init = is_weight_init
        self.is_mlflow = is_mlflow
        self.verbose = verbose
        self.scheduler_gamma = scheduler_gamma

        self.device = device_init(device=self.device)

        self.init = helper(
            model=self.model,
            lr=self.lr,
            beta1=self.beta1,
            beta2=self.beta2,
            momentum=self.momentum,
            adam=self.adam,
            SGD=self.SGD,
            loss=self.loss,
            smooth=self.smooth,
            alpha=self.alpha,
            gamma=self.gamma,
        )

        # Validate with real exceptions (asserts vanish under `python -O`).
        for name in ("train_dataloader", "valid_dataloader", "test_dataloader"):
            if not isinstance(self.init[name], torch.utils.data.DataLoader):
                raise TypeError("{} is not a dataloader".format(name).capitalize())
        if not isinstance(self.init["model"], attentionCNN):
            raise TypeError("model is not a model".capitalize())

        self.train_dataloader = self.init["train_dataloader"]
        self.valid_dataloader = self.init["valid_dataloader"]
        self.test_dataloader = self.init["test_dataloader"]

        self.model = self.init["model"].to(self.device)

        self.optimizer = self.init["optimizer"]
        self.criterion = self.init["loss"]

        if self.lr_scheduler:
            # `gamma` belongs to the focal loss (default 2); feeding it to StepLR
            # would *double* the learning rate every `step_size` epochs.
            self.scheduler = StepLR(
                optimizer=self.optimizer,
                step_size=self.step_size,
                gamma=self.scheduler_gamma,
            )

        if self.is_weight_init:
            self.model.apply(weight_init)

        if self.is_mlflow:
            dagshub.init(
                repo_owner=config()["MLFlow"]["MLFLOW_USERNAME"],
                repo_name=config()["MLFlow"]["MLFLOW_REPONAME"],
                mlflow=False,
            )

        try:
            clean()
        except Exception as e:
            # Best effort: a failed cleanup should not prevent training.
            print("An error occurred: " + str(e))
            traceback.print_exc()

        # From here on `self.loss` tracks the best validation loss seen so far
        # (compared in `saved_checkpoints`); the criterion lives in `self.criterion`.
        self.loss = float("inf")

        self.model_history = {"train_loss": [], "test_loss": []}

        # `set_experiment` creates the experiment on first use and reuses it on
        # later runs; `create_experiment` raised once the name already existed.
        mlflow.set_experiment(
            experiment_name=config()["MLFlow"]["MLFLOW_EXPERIMENT_NAME"]
        )

    def l1_loss(self, model=None):
        """Return ``0.01 * sum(||p||_1)`` over all parameters of ``model``.

        Raises:
            ValueError: if ``model`` is not a ``torch.nn.Module``.
        """
        if not isinstance(model, torch.nn.Module):
            raise ValueError("model is not a model".capitalize())
        # `sum` is required: multiplying a float by a bare generator is a TypeError.
        return 0.01 * sum(torch.norm(params, 1) for params in model.parameters())

    def l2_loss(self, model=None):
        """Return ``0.01 * sum(||p||_2)`` over all parameters of ``model``.

        Raises:
            ValueError: if ``model`` is not a ``torch.nn.Module``.
        """
        if not isinstance(model, torch.nn.Module):
            raise ValueError("model is not a model".capitalize())
        return 0.01 * sum(torch.norm(params, 2) for params in model.parameters())

    def elasticnet_loss(self, model=None):
        """Return ``0.01 * sum(||p||_1 + ||p||_2)`` over all parameters of ``model``.

        Raises:
            ValueError: if ``model`` is not a ``torch.nn.Module``.
        """
        if not isinstance(model, torch.nn.Module):
            raise ValueError("model is not a model".capitalize())
        return 0.01 * sum(
            torch.norm(params, 1) + torch.norm(params, 2)
            for params in model.parameters()
        )

    def saved_checkpoints(self, **kwargs):
        """Save this epoch's weights and, if ``valid_loss`` improved, the best model.

        Keyword Args:
            epoch: epoch index used in the per-epoch filename.
            train_loss, valid_loss: mean losses stored with the best checkpoint.

        Raises:
            ValueError: if any keyword argument is missing.
        """
        try:
            epoch = kwargs["epoch"]
            train_loss = kwargs["train_loss"]
            valid_loss = kwargs["valid_loss"]
        except KeyError:
            raise ValueError(
                "Missing required arguments: 'epoch', 'train_loss', or 'valid_loss'"
            )

        paths = config()["path"]

        if self.loss > valid_loss:
            self.loss = valid_loss
            torch.save(
                {
                    "model": self.model.state_dict(),
                    "train_loss": train_loss,
                    "valid_loss": valid_loss,
                    "epoch": epoch,
                },
                os.path.join(paths["BEST_MODEL"], "best_model.pth"),
            )

        torch.save(
            self.model.state_dict(),
            os.path.join(paths["TRAIN_MODELS"], f"model_epoch_{epoch}.pth"),
        )

    def updated_training_model(self, **kwargs):
        """Run one optimisation step on a batch and return its (penalised) loss.

        Keyword Args:
            image: input batch, already on ``self.device``.
            mask: target batch, already on ``self.device``.

        Raises:
            ValueError: if ``image`` or ``mask`` is missing.
        """
        self.optimizer.zero_grad()

        try:
            image = kwargs["image"]
            mask = kwargs["mask"]
        except KeyError:
            raise ValueError("Missing a dataloader".capitalize())

        predicted = self.model(image)
        loss = self.criterion(predicted, mask)

        # Out-of-place addition: an in-place `+=` can corrupt autograd when the
        # criterion's last op saves its output for the backward pass.
        if self.l1_regularization:
            loss = loss + self.l1_loss(model=self.model)
        elif self.l2_regularization:
            loss = loss + self.l2_loss(model=self.model)
        elif self.elasticnet_regularization:
            loss = loss + self.elasticnet_loss(model=self.model)

        loss.backward()
        self.optimizer.step()

        return loss.item()

    def saved_training_images(self, **kwargs):
        """Save predictions and masks for the first test batch as PNG grids.

        Keyword Args:
            epoch: used in the ``image{epoch}.png`` / ``real_image{epoch}.png`` names.

        Raises:
            ValueError: if ``epoch`` is missing.
        """
        try:
            epoch = kwargs["epoch"]
        except KeyError:
            raise ValueError("Missing a dataloader".capitalize())

        images, mask = next(iter(self.test_dataloader))
        images = images.to(self.device)

        # Predict in eval mode without building a graph, then restore the mode.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                predicted = self.model(images)
        finally:
            self.model.train(was_training)

        train_images = config()["path"]["TRAIN_IMAGES"]
        save_image(
            predicted,
            os.path.join(train_images, "image{}.png".format(epoch)),
            normalize=True,
        )
        save_image(
            mask,
            os.path.join(train_images, "real_image{}.png".format(epoch)),
            normalize=True,
        )

    def display_progress(self, **kwargs):
        """Print the epoch summary (losses when ``verbose``, else a short line).

        Raises:
            ValueError: if ``epoch``, ``train_loss`` or ``valid_loss`` is missing.
        """
        try:
            epoch = kwargs["epoch"]
            train_loss = kwargs["train_loss"]
            valid_loss = kwargs["valid_loss"]
        except KeyError:
            raise ValueError("Missing a dataloader".capitalize())

        if self.verbose:
            print(
                "Epochs: [{}/{}] - train_loss: {:.4f} - valid_loss: {:.4f}".format(
                    epoch, self.epochs, train_loss, valid_loss
                )
            )
        else:
            print(
                "Epochs: [{}/{}] is completed".format(epoch, self.epochs).capitalize()
            )

    def _log_hyperparameters(self):
        """Log the model and training configuration to the active MLflow run once."""
        try:
            model_config = config()["attentionCNN"]
            params = {
                "channels": model_config["image_channels"],
                "image_size": model_config["image_size"],
                "nheads": model_config["nheads"],
                "dropout": model_config["dropout"],
                "num_layers": model_config["num_layers"],
                "activation": model_config["activation"],
                "bias": model_config["bias"],
                "num_epochs": self.epochs,
                "lr": self.lr,
                "beta1": self.beta1,
                "beta2": self.beta2,
                "momentum": self.momentum,
                "adam": self.adam,
                "SGD": self.SGD,
                # Class name only: the optimizer repr can exceed MLflow's value limit.
                "optimizer": type(self.optimizer).__name__,
                "loss": type(self.criterion).__name__,
                "smooth": self.smooth,
                "alpha": self.alpha,
                "gamma": self.gamma,
                "step_size": self.step_size,
                "device": self.device,
                "lr_scheduler": self.lr_scheduler,
                "scheduler_gamma": self.scheduler_gamma,
                "l1_regularization": self.l1_regularization,
                "l2_regularization": self.l2_regularization,
                "elasticnet_regularization": self.elasticnet_regularization,
                "is_weight_init": self.is_weight_init,
                "is_mlflow": self.is_mlflow,
                "verbose": self.verbose,
            }
            mlflow.log_params({key: str(value) for key, value in params.items()})
        except Exception as e:
            print("An error occured in the MLFflow log_params: {}".format(e))
            traceback.print_exc()

    def _train_one_epoch(self):
        """Run one optimisation pass over ``train_dataloader``; return per-batch losses."""
        self.model.train()
        losses = []
        for image, mask in self.train_dataloader:
            image = image.to(self.device)
            mask = mask.to(self.device)
            try:
                losses.append(self.updated_training_model(image=image, mask=mask))
            except Exception as e:
                print("An error occured: {}".format(e))
        return losses

    def _validate(self):
        """Return per-batch criterion values on ``valid_dataloader`` (eval, no grad)."""
        self.model.eval()
        losses = []
        with torch.no_grad():
            for image, mask in self.valid_dataloader:
                image = image.to(self.device)
                mask = mask.to(self.device)
                losses.append(self.criterion(self.model(image), mask).item())
        return losses

    def train(self):
        """Train for ``self.epochs`` epochs inside an MLflow run.

        Each epoch trains, validates on ``valid_dataloader``, steps the optional
        scheduler, prints progress, saves sample images and checkpoints, records
        the history and logs metrics. Afterwards the history is dumped to
        ``METRICS_PATH/history.pkl`` and the model is logged to MLflow.
        """
        with mlflow.start_run():
            # Parameters are constant for the run: log them once, not per epoch.
            self._log_hyperparameters()

            for epoch in tqdm(range(self.epochs)):
                train_loss = np.mean(self._train_one_epoch())
                valid_loss = np.mean(self._validate())

                if self.lr_scheduler:
                    self.scheduler.step()

                try:
                    self.display_progress(
                        epoch=epoch + 1,
                        train_loss=train_loss,
                        valid_loss=valid_loss,
                    )
                except Exception as e:
                    print("An error occured: {}".format(e).capitalize())

                try:
                    self.saved_training_images(epoch=epoch)
                except Exception as e:
                    print("An error occured: {}".format(e).capitalize())

                try:
                    self.saved_checkpoints(
                        epoch=epoch,
                        train_loss=train_loss,
                        valid_loss=valid_loss,
                    )
                except Exception as e:
                    print("An error occured in : {}".format(e).capitalize())

                self.model_history["train_loss"].append(train_loss)
                self.model_history["test_loss"].append(valid_loss)

                try:
                    mlflow.log_metric("train_loss", train_loss, step=epoch + 1)
                    mlflow.log_metric("val_loss", valid_loss, step=epoch + 1)
                except Exception as e:
                    print("An error occured in the MLFflow log_metric: {}".format(e))
                    traceback.print_exc()

            metrics_path = config()["path"]["METRICS_PATH"]
            dump(
                value=self.model_history,
                filename=os.path.join(metrics_path, "history.pkl"),
            )
            print(
                "Model history saved in the folder {}".format(metrics_path).capitalize()
            )

            try:
                mlflow.pytorch.log_model(self.model, "attentionCNNModel")
            except Exception as e:
                print("An error occured in the MLFflow log_params: {}".format(e))
                traceback.print_exc()

    @staticmethod
    def display_history():
        """Plot the saved train/validation loss curves and save them as a PNG."""
        metrics_path = config()["path"]["METRICS_PATH"]
        history = load(filename=os.path.join(metrics_path, "history.pkl"))

        fig, axes = plt.subplots(1, 2, figsize=(15, 5))
        fig.suptitle("Loss Evolution")

        curves = (
            ("train_loss", "Train Loss", "b"),
            ("test_loss", "Test Loss", "r"),
        )
        for axis, (key, title, color) in zip(axes, curves):
            axis.plot(history[key], label=title, color=color)
            axis.grid(True, which="both", linestyle="--", linewidth=0.5)
            axis.set_title(title)
            axis.set_xlabel("Epochs")
            axis.set_ylabel("Loss")
            axis.legend()

        plt.savefig(os.path.join(metrics_path, "loss_evolution.png"))
        plt.show()

# Tester

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


class Tester:
    """Evaluate the best saved ``attentionCNN`` checkpoint on a stored dataloader.

    ``data`` selects which pickled dataloader (``"test"`` or ``"valid"``) is
    read from ``config()["path"]["PROCESSED_PATH"]``.
    """

    def __init__(self, data: str = "test", device: str = "cuda"):
        self.data = data
        self.device = device_init(device=device)

    def load_dataloader(self):
        """Load the pickled dataloader selected by ``self.data``.

        Raises:
            ValueError: if ``self.data`` is neither ``"test"`` nor ``"valid"``.
        """
        if self.data not in ("test", "valid"):
            raise ValueError("Invalid data type".capitalize())

        return load(
            filename=os.path.join(
                config()["path"]["PROCESSED_PATH"],
                "{}_dataloader.pkl".format(self.data),
            )
        )

    def compute_iou(self, predicted_mask, real_mask, threshold=0.4):
        """
        Compute the mean Intersection over Union (IoU) between predicted and real masks.

        Parameters:
        - predicted_mask: PyTorch tensor whose first dimension is the batch, e.g.
          (batch, channels, height, width) or (batch, height, width); values may be
          continuous or binary.
        - real_mask: PyTorch tensor of the same shape with the real mask values.
        - threshold: values >= threshold count as foreground in both tensors.

        Returns:
        - iou: The IoU averaged over the batch as a float; a sample whose masks are
          both empty scores 1.0.
        """
        batch_size = predicted_mask.size(0)
        # Flatten everything after the batch axis so any rank >= 2 is supported.
        binary_pred_masks = (predicted_mask >= threshold).reshape(batch_size, -1)
        binary_real_masks = (real_mask >= threshold).reshape(batch_size, -1)

        intersection = torch.logical_and(binary_pred_masks, binary_real_masks).sum(dim=1)
        union = torch.logical_or(binary_pred_masks, binary_real_masks).sum(dim=1)

        iou = torch.where(
            union == 0,
            torch.ones_like(intersection, dtype=torch.float),
            intersection.float() / union.float(),
        )

        return iou.mean().item()

    def compute_dice_score(self, y_pred, y_true, threshold=0.5):
        """
        Compute the mean Dice score for a batch of predicted masks and real masks.

        Parameters:
        - y_pred: PyTorch tensor of predicted mask values with the batch first,
          e.g. (batch_size, channels, height, width).
        - y_true: PyTorch tensor with the real (binary) mask values, same shape.
        - threshold: predictions >= threshold count as foreground.

        Returns:
        - mean_dice: The Dice score averaged over the batch as a float; a sample
          whose masks are both empty scores 1.0.
        """
        batch_size = y_pred.size(0)
        # reshape (not view) so non-contiguous inputs also work.
        y_pred_flat = (y_pred >= threshold).float().reshape(batch_size, -1)
        y_true_flat = y_true.float().reshape(batch_size, -1)

        intersection = (y_pred_flat * y_true_flat).sum(dim=1)
        total = y_pred_flat.sum(dim=1) + y_true_flat.sum(dim=1)

        # Guard the 0/0 case (both masks empty), which previously produced NaN.
        dice = torch.where(
            total == 0, torch.ones_like(total), 2 * intersection / total
        )

        # Return the score itself, as documented (previously returned 1 - score).
        return dice.mean().item()

    def select_model(self):
        """Build an ``attentionCNN`` from the config; print and return None on failure."""
        try:
            model_config = config()["attentionCNN"]
            model = attentionCNN(
                image_channels=model_config["image_channels"],
                image_size=model_config["image_size"],
                nheads=model_config["nheads"],
                dropout=model_config["dropout"],
                num_layers=model_config["num_layers"],
                activation=model_config["activation"],
                bias=model_config["bias"],
            )
        except Exception as e:
            print("An error occurred to load the model: ", e)
        else:
            return model

    @staticmethod
    def _normalize(array):
        """Min-max scale ``array`` into [0, 1]; a constant array maps to zeros."""
        shifted = array - array.min()
        value_range = array.max() - array.min()
        if value_range == 0:
            return shifted
        return shifted / value_range

    def plot_images(self):
        """Plot image / prediction / mask triplets for one batch and report the IoU.

        Loads ``BEST_MODEL/best_model.pth``, predicts on the first batch of the
        selected dataloader, saves the figure to ``TEST_IMAGE/{data}_result.png``
        and prints the batch IoU.
        """
        try:
            model = self.select_model()
        except Exception as e:
            print("An error occurred to load the model: ", e)
            return
        if model is None:
            # select_model has already reported why the model could not be built.
            return

        model = model.to(self.device)

        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        state_dict = torch.load(
            os.path.join(config()["path"]["BEST_MODEL"], "best_model.pth"),
            map_location=self.device,
        )
        model.load_state_dict(state_dict["model"])
        # Inference mode: disable dropout and skip graph construction.
        model.eval()

        images, mask = next(iter(self.load_dataloader()))
        mask = mask.to(self.device)
        with torch.no_grad():
            predicted = model(images.to(self.device))

        IoU = self.compute_iou(predicted_mask=predicted, real_mask=mask)

        batch_size = images.size(0)
        num_of_rows = max(1, int(math.sqrt(batch_size)))
        num_of_cols = math.ceil(batch_size / num_of_rows)
        # Each sample occupies three adjacent cells, so a grid row holds
        # 3 * num_of_cols cells and triplets never wrap across rows.
        grid_cols = 3 * num_of_cols

        plt.figure(figsize=(num_of_cols * 12, num_of_rows * 4))

        panels = (
            ("Real Image", None),
            ("Predicted Image", "gray"),
            ("Mask Image", "gray"),
        )
        for index, sample in enumerate(zip(images, predicted, mask)):
            for offset, (tensor, (title, cmap)) in enumerate(
                zip(sample, panels), start=1
            ):
                picture = tensor.permute(1, 2, 0).cpu().detach().numpy()
                plt.subplot(num_of_rows, grid_cols, 3 * index + offset)
                plt.imshow(self._normalize(picture), cmap=cmap)
                plt.axis("off")
                plt.title(title)

        result_path = os.path.join(
            config()["path"]["TEST_IMAGE"], "{}_result.png".format(self.data)
        )

        plt.tight_layout()
        plt.savefig(result_path)
        plt.show()

        print(
            "IoU score # {:.4f} and result image saved the result image to {}".format(
                IoU, result_path
            )
        )