In [1]:
import os
import sys
import cv2
import yaml
import torch
import joblib
import zipfile
import traceback
import numpy as np
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
import torch.optim as optim
import matplotlib.pyplot as plt
from torchsummary import summary
from torchview import draw_graph
from torchvision import transforms
from collections import OrderedDict
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from sklearn.model_selection import train_test_split

In [2]:
def config():
    """Load the project configuration from ``../../config.yml``.

    The path is resolved relative to the current working directory.

    Returns:
        The parsed YAML document, as produced by ``yaml.safe_load``.
    """
    with open("../../config.yml", "r") as stream:
        return yaml.safe_load(stream)

In [None]:
class PathException(Exception):
    """Exception carrying a human-readable message about a path problem.

    Args:
        message: Text describing the failure; also exposed as ``.message``.
    """

    def __init__(self, message):
        # Keep the text on the instance so callers can read it without str().
        self.message = message
        super().__init__(message)

def validate_path(path):
    """Return ``path`` unchanged when it exists on disk, otherwise raise.

    Args:
        path: File or directory path to check with ``os.path.exists``.

    Returns:
        The same ``path`` value that was passed in.

    Raises:
        PathException: If ``path`` does not exist.
    """
    if os.path.exists(path):
        return path

    # No exception is being handled at this point, so printing a traceback
    # would only emit "NoneType: None"; the raised error carries the details.
    raise PathException("{} Path does not exist".capitalize().format(path))


def dump(value=None, filename=None):
    """Serialize ``value`` to ``filename`` with ``joblib.dump``.

    Does nothing when either argument is ``None``.

    Args:
        value: Object to persist.
        filename: Destination path.
    """
    if value is None or filename is None:
        return

    joblib.dump(value, filename)

def load(filename=None):
    """Load and return the object stored at ``filename`` with ``joblib.load``.

    Args:
        filename: Path of the file to read.

    Returns:
        The deserialized object, or ``None`` when ``filename`` is ``None``.
    """
    if filename is None:
        return None

    return joblib.load(filename)


def device_init(device="cuda"):
    """Resolve a device name to a ``torch.device``, falling back to CPU.

    Args:
        device: ``"cuda"``, ``"mps"`` or anything else (treated as CPU).

    Returns:
        The requested accelerator device when it is available, otherwise
        ``torch.device("cpu")``.
    """
    if device == "mps" and torch.backends.mps.is_available():
        return torch.device("mps")

    if device == "cuda" and torch.cuda.is_available():
        return torch.device("cuda")

    return torch.device("cpu")


def weight_init(m):
    """Initialise a module in place; intended for ``nn.Module.apply``.

    Modules whose class name contains ``"Conv"`` get weights drawn from
    N(0, 0.02). Modules whose class name contains ``"BatchNorm"`` get weights
    drawn from N(1, 0.02) and a zero bias. Every other module is left as is.

    Args:
        m: The module to initialise.
    """
    layer_type = type(m).__name__

    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return

    if "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0)

In [None]:
class Loader():
    """Prepare, persist and inspect the image/mask dataloaders.

    ``unzip_folder`` extracts the archive into the configured raw directory,
    ``create_dataloader`` builds the train/test ``DataLoader`` objects and
    pickles them into the processed directory, and the static helpers
    ``plot_images`` / ``dataset_details`` read those pickles back.

    Args:
        image_path: Path of the zip archive holding the dataset.
        channels: Number of image channels (stored only; files are always read
            as 3-channel images by ``feature_extractor``).
        image_size: Side length, in pixels, of the square output tensors.
        batch_size: Training batch size; the test loader uses 8x this value.
        split_size: Fraction of the pairs reserved for the test split.

    Raises:
        PathException: If a configured directory does not exist.
    """

    def __init__(self, image_path = None, channels = 3, image_size = 256, batch_size = 1, split_size = 0.20):
        self.image_path = image_path
        self.channels = channels
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size

        self.images = []
        self.masks = []

        self.config = config()

        self.raw_path = validate_path(self.config["path"]["raw_path"])
        self.processed_path = validate_path(self.config["path"]["processed_path"])
        self.files_path = validate_path(self.config["path"]["files_path"])

    def transforms(self):
        """Return the input-image pipeline: resize, crop, tensor, scale to [-1, 1]."""
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size), Image.BICUBIC),
            transforms.CenterCrop((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

    def mask_transforms(self):
        """Return the mask pipeline: same geometry as ``transforms`` but kept in [0, 1].

        Masks are multiplied with the images and compared against sigmoid
        outputs during training, so they must not be shifted to [-1, 1].
        """
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size), Image.BICUBIC),
            transforms.CenterCrop((self.image_size, self.image_size)),
            transforms.ToTensor(),
        ])

    def unzip_folder(self):
        """Extract the archive at ``self.image_path`` into the raw directory."""
        with zipfile.ZipFile(self.image_path, "r") as zip_ref:
            zip_ref.extractall(self.raw_path)

        print("The unzip folder of image saved in {}".format(self.raw_path))

    def split_dataset(self, X, y):
        """Split paired lists into train/test parts with a fixed seed.

        Args:
            X: List of image tensors.
            y: List of mask tensors, aligned with ``X``.

        Returns:
            Dict with the keys ``X_train``, ``X_test``, ``y_train``, ``y_test``.

        Raises:
            ValueError: If ``X`` or ``y`` is not a list.
        """
        if isinstance(X, list) and isinstance(y, list):
            X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=self.split_size, random_state=42)
            return {
                "X_train": X_train,
                "X_test": X_test,
                "y_train": y_train,
                "y_test": y_test
            }

        else:
            raise ValueError("X and y must be of type list".capitalize())

    @staticmethod
    def _pair_by_stem(image_names, mask_names):
        """Pair image and mask file names that share a stem, sorted by image name.

        Matching on the stem (name without its final extension) instead of on
        list position makes the pairing independent of directory listing order
        and tolerant of file names containing several dots.
        """
        masks_by_stem = {os.path.splitext(name)[0]: name for name in mask_names}

        pairs = []
        for image_name in sorted(image_names):
            stem = os.path.splitext(image_name)[0]
            if stem in masks_by_stem:
                pairs.append((image_name, masks_by_stem[stem]))

        return pairs

    @staticmethod
    def _rescale(array):
        """Min-max scale ``array`` to [0, 1]; a constant array maps to zeros."""
        low = array.min()
        span = array.max() - low

        if span == 0:
            return array - low

        return (array - low) / span

    def feature_extractor(self):
        """Read every image/mask pair, convert it to tensors and split the result.

        Files are expected under ``<raw_path>/datasets/images`` and
        ``<raw_path>/datasets/masks``. Pairs are matched by file stem; files
        that OpenCV cannot decode are skipped with a message.

        Returns:
            The dict produced by ``split_dataset``.

        Raises:
            ValueError: If no readable image/mask pair is found.
        """
        self.directory = os.path.join(self.raw_path, "datasets")
        self.images_path = os.path.join(self.directory, "images")
        self.masks_path = os.path.join(self.directory, "masks")

        self.masks_list = os.listdir(self.masks_path)

        # Start from a clean slate so a second call does not duplicate samples.
        self.images = []
        self.masks = []

        # Build the pipelines once instead of once per file.
        image_transform = self.transforms()
        mask_transform = self.mask_transforms()

        pairs = self._pair_by_stem(os.listdir(self.images_path), self.masks_list)

        for image_name, mask_name in tqdm(pairs):
            image_X = cv2.imread(os.path.join(self.images_path, image_name))
            image_Y = cv2.imread(os.path.join(self.masks_path, mask_name))

            # cv2.imread returns None for anything it cannot decode
            # (hidden files, corrupt images, ...).
            if image_X is None or image_Y is None:
                print("Skipping unreadable pair # {} / {}".format(image_name, mask_name))
                continue

            image_X = Image.fromarray(cv2.cvtColor(image_X, cv2.COLOR_BGR2RGB))
            image_Y = Image.fromarray(cv2.cvtColor(image_Y, cv2.COLOR_BGR2RGB))

            self.images.append(image_transform(image_X))
            self.masks.append(mask_transform(image_Y))

        if not self.images:
            raise ValueError(
                "No matching image/mask pairs found in {}".format(self.directory)
            )

        return self.split_dataset(X=self.images, y=self.masks)

    def create_dataloader(self):
        """Build the train/test dataloaders and pickle them to the processed directory."""
        data = self.feature_extractor()

        self.train_dataloader = DataLoader(
            dataset=list(zip(data["X_train"], data["y_train"])),
            batch_size=self.batch_size,
            shuffle=True
        )

        self.test_dataloader = DataLoader(
            dataset=list(zip(data["X_test"], data["y_test"])),
            batch_size=self.batch_size * 8,
            shuffle=True,
        )

        for filename, value in tqdm([("train_dataloader", self.train_dataloader), ("test_dataloader", self.test_dataloader)]):
            dump(
                value=value,
                filename=os.path.join(self.processed_path, "{}.pkl".format(filename))
            )

        print("all the dataloaders are saved in # {}".format(self.processed_path))

    @staticmethod
    def plot_images():
        """Plot one test batch as image/mask pairs and save it as ``image.jpeg``."""
        config_files = config()

        processed_path = validate_path(path=config_files["path"]["processed_path"])
        files_path = validate_path(path=config_files["path"]["files_path"])

        test_dataloader = load(filename=os.path.join(processed_path, "test_dataloader.pkl"))

        image_X, image_Y = next(iter(test_dataloader))

        plt.figure(figsize=(20, 10))

        # Two panels per sample, eight panels per row; the row count follows
        # the batch size so larger batches do not overflow the grid.
        columns = 8
        panels = 2 * len(image_X)
        rows = max(1, (panels + columns - 1) // columns)

        for index, (image) in enumerate(image_X):
            X = image.permute(1, 2, 0).cpu().detach().numpy()
            Y = image_Y[index].permute(1, 2, 0).cpu().detach().numpy()

            X = Loader._rescale(X)
            Y = Loader._rescale(Y)

            for idx, (title, value) in enumerate([("X", X), ("Y", Y)]):

                plt.subplot(rows, columns, 2 * index + idx + 1)
                plt.imshow(value)
                plt.title(title)
                plt.axis("off")

        plt.savefig(os.path.join(files_path, "image.jpeg"))

        plt.show()

    @staticmethod
    def dataset_details():
        """Write sample counts and batch sizes of both loaders to ``dataset_details.csv``."""
        config_files = config()

        processed_path = validate_path(path=config_files["path"]["processed_path"])
        files_path = validate_path(path=config_files["path"]["files_path"])

        train_dataloader = load(filename=os.path.join(processed_path, "train_dataloader.pkl"))
        test_dataloader = load(filename=os.path.join(processed_path, "test_dataloader.pkl"))

        pd.DataFrame(
            {
                "train_image(Total)": str(sum(X.size(0) for X, _ in train_dataloader)),
                "test_image(Total)": str(sum(X.size(0) for X, _ in test_dataloader)),
                "train_image(Batch)": str(train_dataloader.batch_size),
                "test_image(Batch)": str(test_dataloader.batch_size),

            },
            index = ["Quantity"]
        ).T.to_csv(os.path.join(files_path, "dataset_details.csv"))


if __name__== "__main__":
    # The archive location can be overridden with the DATASETS_ZIP environment
    # variable; the previous hard-coded location remains the default.
    archive_path = os.environ.get(
        "DATASETS_ZIP", "/Users/shahmuhammadraditrahman/Desktop/datasets.zip"
    )

    loader = Loader(image_path=archive_path)

    # Extract the archive, then build and pickle the dataloaders.
    loader.unzip_folder()
    loader.create_dataloader()

    # Both helpers read the pickles written by create_dataloader().
    loader.plot_images()
    loader.dataset_details()

#### Generator block

In [None]:
class EncoderBlock(nn.Module):
    """Down-sampling block: 4x4 stride-2 convolution, then optional extras.

    The layers are applied in the order conv -> LeakyReLU(0.2) -> BatchNorm,
    where the last two are each optional.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        use_leaky_relu: Append the LeakyReLU activation when true.
        use_batch_norm: Append the batch normalisation when true.
    """

    def __init__(self, in_channels = 3, out_channels = 64, use_leaky_relu = True, use_batch_norm = False):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.leaky_relu = use_leaky_relu
        self.batch_norm = use_batch_norm

        # Kernel 4 / stride 2 / padding 1 halves the spatial size.
        self.kernel_size, self.stride_size, self.padding_size = 4, 2, 1

        self.encoder = self.block()

    def block(self):
        """Assemble the layers into a named ``nn.Sequential``."""
        stages = [
            (
                "conv",
                nn.Conv2d(
                    self.in_channels,
                    self.out_channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride_size,
                    padding=self.padding_size,
                    bias=False,
                ),
            )
        ]

        if self.leaky_relu:
            stages.append(("leaky_ReLU", nn.LeakyReLU(negative_slope=0.2, inplace=True)))

        if self.batch_norm:
            stages.append(("batch_norm", nn.BatchNorm2d(num_features=self.out_channels)))

        return nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Run the block on ``x``.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("X should be in the format of tensor".capitalize())

        return self.encoder(x)

    @staticmethod
    def total_params(model):
        """Count all parameters of an ``nn.Sequential``.

        Raises:
            ValueError: If ``model`` is not an ``nn.Sequential``.
        """
        if not isinstance(model, nn.Sequential):
            raise ValueError("Model should be in the format of Sequential".capitalize())

        return sum(tensor.numel() for tensor in model.parameters())


if __name__ == "__main__":
    in_channels = 3
    out_channels = 64

    # One row per encoder stage: (in, out, use_leaky_relu, use_batch_norm).
    stage_specs = [
        (in_channels, out_channels, True, False),
        (out_channels, out_channels * 2, True, True),
        (out_channels * 2, out_channels * 4, True, True),
        (out_channels * 4, out_channels * 8, True, True),
        (out_channels * 8, out_channels * 8, True, True),
        (out_channels * 8, out_channels * 8, True, True),
        (out_channels * 8, out_channels * 8, False, False),
    ]

    encoder1, encoder2, encoder3, encoder4, encoder5, encoder6, encoder7 = [
        EncoderBlock(
            in_channels=source,
            out_channels=target,
            use_leaky_relu=leaky,
            use_batch_norm=norm,
        )
        for source, target, leaky, norm in stage_specs
    ]

    layers = [encoder1, encoder2, encoder3, encoder4, encoder5, encoder6, encoder7]

    model = nn.Sequential(*layers)

    # Sanity check on the parameter count of the seven stacked stages.
    assert EncoderBlock.total_params(model=model) == 15342336

    print(model(torch.randn(1, 3, 256, 256)).size())

In [None]:
class DecoderBlock(nn.Module):
    """Up-sampling block: 4x4 stride-2 transposed convolution with optional skip.

    Unless ``last_layer`` is set, the transposed convolution is followed by
    ReLU and BatchNorm (in that order).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the transposed convolution.
        last_layer: When true, only the transposed convolution is used.
    """

    def __init__(self, in_channels = 512, out_channels = 512, last_layer=False):
        super(DecoderBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.last_layer = last_layer

        # Kernel 4 / stride 2 / padding 1 doubles the spatial size.
        self.kernel_size = 4
        self.stride_size = 2
        self.padding_size = 1

        self.decoder = self.block()

    def block(self):
        """Assemble the layers into a named ``nn.Sequential``."""
        layers = OrderedDict()

        layers["convTranspose"] = nn.ConvTranspose2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=False
        )

        if not self.last_layer:
            layers["ReLU"] = nn.ReLU(inplace=True)
            layers["batch_norm"] = nn.BatchNorm2d(num_features=self.out_channels)

        return nn.Sequential(layers)

    def forward(self, x, skip_info=None):
        """Up-sample ``x`` and optionally concatenate a skip connection.

        Args:
            x: Input feature map.
            skip_info: Optional tensor concatenated to the output along the
                channel dimension (``dim=1``).

        Returns:
            The decoded tensor, with ``skip_info`` appended when given.

        Raises:
            ValueError: If ``x`` is not a tensor, or ``skip_info`` is neither
                ``None`` nor a tensor. (Previously these cases silently
                returned ``None``.)
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("X should be in the format of tensor".capitalize())

        if skip_info is not None and not isinstance(skip_info, torch.Tensor):
            raise ValueError("Skip info should be in the format of tensor".capitalize())

        x = self.decoder(x)

        if skip_info is None:
            return x

        return torch.cat((x, skip_info), dim=1)

    @staticmethod
    def total_params(model):
        """Count all parameters of an ``nn.Sequential``.

        Raises:
            ValueError: If ``model`` is not an ``nn.Sequential``.
        """
        if isinstance(model, torch.nn.modules.container.Sequential):
            return sum(params.numel() for params in model.parameters())

        else:
            raise ValueError("Model should be in the format of Sequential".capitalize())
        
if __name__ == "__main__":
    # Smoke test: a default-sized decoder block must be constructible.
    in_channels = out_channels = 512

    model = DecoderBlock(in_channels, out_channels)

In [None]:
class Generator(nn.Module):
    """U-Net style generator: seven encoder blocks, seven decoder blocks.

    Each of the first six decoder outputs is concatenated with the encoder
    output of matching resolution before being fed to the next decoder. The
    final decoder maps back to ``in_channels`` channels and applies no
    activation.

    Args:
        in_channels: Channels of the input image (and of the output).
    """

    def __init__(self, in_channels = 3):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = 64

        base = self.out_channels

        # (in, out, use_leaky_relu, use_batch_norm) for encoder1 .. encoder7.
        encoder_specs = [
            (self.in_channels, base, True, False),
            (base, base * 2, True, True),
            (base * 2, base * 4, True, True),
            (base * 4, base * 8, True, True),
            (base * 8, base * 8, True, True),
            (base * 8, base * 8, True, True),
            (base * 8, base * 8, False, False),
        ]
        for level, (source, target, leaky, norm) in enumerate(encoder_specs, start=1):
            setattr(
                self,
                "encoder{}".format(level),
                EncoderBlock(
                    in_channels=source,
                    out_channels=target,
                    use_leaky_relu=leaky,
                    use_batch_norm=norm,
                ),
            )

        # (in, out) for decoder1 .. decoder6; inputs from decoder2 onwards are
        # doubled because the previous output carries a concatenated skip.
        decoder_specs = [
            (base * 8, base * 8),
            (base * 8 * 2, base * 8),
            (base * 8 * 2, base * 8),
            (base * 8 * 2, base * 4),
            (base * 4 * 2, base * 2),
            (base * 2 * 2, base),
        ]
        for level, (source, target) in enumerate(decoder_specs, start=1):
            setattr(
                self,
                "decoder{}".format(level),
                DecoderBlock(in_channels=source, out_channels=target),
            )

        self.decoder7 = DecoderBlock(
            in_channels=base * 2,
            out_channels=in_channels,
            last_layer=True,
        )

    def forward(self, x):
        """Map an input batch to an output of the same spatial size.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("X should be in the format of tensor".capitalize())

        # Contracting path: remember every encoder output for the skips.
        skips = []
        features = x
        for level in range(1, 8):
            features = getattr(self, "encoder{}".format(level))(features)
            skips.append(features)

        # The deepest output is the bottleneck, not a skip connection.
        features = skips.pop()

        # Expanding path: decoder k is paired with encoder (7 - k).
        for level in range(1, 7):
            features = getattr(self, "decoder{}".format(level))(features, skips[-level])

        return self.decoder7(features)

    @staticmethod
    def total_params(model):
        """Count all parameters of a ``Generator``.

        Raises:
            ValueError: If ``model`` is not a ``Generator``.
        """
        if not isinstance(model, Generator):
            raise ValueError("Model should be in the Generator".capitalize())

        return sum(tensor.numel() for tensor in model.parameters())
        
        
if __name__ == "__main__":
    config_files = config()
    files_path = validate_path(config_files["path"]["files_path"])

    in_channels = 3

    netG = Generator(in_channels=in_channels)

    # Layer-by-layer summary of the generator.
    print(summary(model=netG, input_size=(3, 256, 256)))

    # Render the computation graph next to the other generated files.
    graph = draw_graph(model=netG, input_data=torch.randn(1, 3, 256, 256))
    graph.visual_graph.render(
        filename=os.path.join(files_path, "netG"), format="jpeg"
    )

    # The generator must preserve the input shape and have the expected size.
    generated = netG(torch.randn(1, 3, 256, 256))
    assert generated.size() == (1, 3, 256, 256)

    assert netG.total_params(netG) == 41828992

##### Discriminator

In [None]:
class DiscriminatorBlock(nn.Module):
    """Convolutional block used to build the discriminator.

    A convolution is followed by LeakyReLU(0.2) and BatchNorm, or by Tanh only
    when ``last_layer`` is set.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride_size: Convolution stride.
        padding_size: Convolution padding.
        last_layer: Use the Tanh head instead of LeakyReLU + BatchNorm.
    """

    def __init__(self, in_channels = 3, out_channels = 64, kernel_size = 4, stride_size = 2, padding_size = 1, last_layer = False):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.last_layer = last_layer

        self.kernel_size = kernel_size
        self.stride_size = stride_size
        self.padding_size = padding_size

        self.discriminator_block = self.block()

    def block(self):
        """Assemble the layers into a named ``nn.Sequential``."""
        convolution = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
        )

        if self.last_layer:
            tail = [("tanh", nn.Tanh())]
        else:
            tail = [
                ("leaky_ReLU", nn.LeakyReLU(negative_slope=0.2, inplace=True)),
                ("batch_norm", nn.BatchNorm2d(num_features=self.out_channels)),
            ]

        return nn.Sequential(OrderedDict([("conv", convolution)] + tail))

    def forward(self, x):
        """Run the block on ``x``.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input should be a tensor".capitalize())

        return self.discriminator_block(x)

    @staticmethod
    def total_params(model):
        """Count the trainable parameters of an ``nn.Sequential``.

        Raises:
            ValueError: If ``model`` is not an ``nn.Sequential``.
        """
        if not isinstance(model, nn.Sequential):
            raise ValueError("Model should be a Sequential model".capitalize())

        return sum(tensor.numel() for tensor in model.parameters() if tensor.requires_grad)


if __name__ == "__main__":
    in_channels = 3
    out_channels = 64

    layers = []

    # Three stride-2 blocks followed by two stride-1 blocks; the very last
    # block collapses to a single channel and uses the Tanh head.
    strides = [2, 2, 2, 1, 1]

    for position, stride in enumerate(strides):
        is_last = position == len(strides) - 1

        layers.append(
            DiscriminatorBlock(
                in_channels=in_channels,
                out_channels=1 if is_last else out_channels,
                stride_size=stride,
                last_layer=is_last,
            )
        )

        in_channels, out_channels = out_channels, out_channels * 2

    model = nn.Sequential(*layers)

    assert DiscriminatorBlock.total_params(model=model) == 2766657
    assert model(torch.randn(1, 3, 256, 256)).size() == (1, 1, 30, 30)

In [None]:
class Discriminator(nn.Module):
    """Multi-scale discriminator built from five ``DiscriminatorBlock`` stages.

    Three stride-2 blocks are followed by two stride-1 blocks; the last block
    outputs a single channel. ``forward`` returns the flattened activations of
    every stage concatenated together, not just the final map.

    Args:
        in_channels: Channels of the input image. (Previously this argument
            was ignored and 3 was always used.)
    """

    def __init__(self, in_channels = 3):
        super(Discriminator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = 64

        # Track the running widths in locals so the public attributes keep
        # describing the first stage instead of ending up with loop leftovers.
        stage_in = self.in_channels
        stage_out = self.out_channels

        layers = []

        for _ in range(3):
            layers.append(
                DiscriminatorBlock(in_channels=stage_in, out_channels=stage_out)
            )

            stage_in = stage_out
            stage_out *= 2

        for idx in range(2):
            layers.append(
                DiscriminatorBlock(
                    in_channels=stage_in,
                    out_channels=1 if (idx == 1) else stage_out,
                    stride_size=1,
                    last_layer=(idx == 1),
                )
            )
            stage_in = stage_out
            stage_out *= 2

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the flattened activations of all stages, concatenated.

        Args:
            x: Input batch.

        Returns:
            A 2-D tensor with one row per sample holding every stage's
            activations, flattened and joined along ``dim=1``.

        Raises:
            ValueError: If ``x`` is not a ``torch.Tensor``.
        """
        if isinstance(x, torch.Tensor):
            features = []

            for stage in self.model:
                x = stage(x)
                features.append(x.view(x.size(0), -1))

            return torch.cat(features, dim=1)

        else:
            raise ValueError("X should be in the format of tenor".capitalize())

    @staticmethod
    def total_params(model):
        """Count the trainable parameters of a ``Discriminator``.

        Raises:
            ValueError: If ``model`` is not a ``Discriminator``.
        """
        if isinstance(model, Discriminator):
            return sum(params.numel() for params in model.parameters() if params.requires_grad)

        else:
            raise ValueError("Model should be in the format of Discriminator".capitalize())
        
if __name__ == "__main__":
    config_files = config()
    files_path = validate_path(path=config_files["path"]["files_path"])

    in_channels = 3

    netD = Discriminator(in_channels=in_channels)

    # Layer-by-layer summary of the discriminator.
    print(summary(model=netD, input_size=(in_channels, 256, 256)))

    # Render the computation graph next to the other generated files.
    graph = draw_graph(model=netD, input_data=torch.randn(1, in_channels, 256, 256))
    graph.visual_graph.render(
        filename=os.path.join(files_path, "netD"), format="jpeg"
    )

    assert Discriminator.total_params(model=netD) == 2766657

    # All five stage outputs, flattened and concatenated per sample.
    flattened = netD(torch.randn(1, 3, 256, 256))
    assert flattened.size() == (1, 2327940)

In [None]:
class L1Loss(nn.Module):
    """Thin wrapper around ``nn.L1Loss`` that validates its inputs.

    Args:
        reduction: Reduction mode forwarded to ``nn.L1Loss``.
    """

    def __init__(self, reduction = "mean"):
        super().__init__()

        self.loss_name = "L1Loss".title()
        self.reduction = reduction
        self.l1_loss = nn.L1Loss(reduction=self.reduction)

    def forward(self, actual, predicted):
        """Return the L1 loss between ``actual`` and ``predicted``.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        both_tensors = isinstance(actual, torch.Tensor) and isinstance(predicted, torch.Tensor)

        if not both_tensors:
            raise ValueError("Actual and Predicted should be in the format of tensor".capitalize())

        return self.l1_loss(actual, predicted)
        
if __name__ == "__main__":
    l1loss = L1Loss(reduction="mean")

    # Identical tensors, so the reported loss is zero.
    actual = torch.tensor([1.0, 0.0, 1.0])
    predicted = torch.tensor([1.0, 0.0, 1.0])

    total_loss = l1loss(actual, predicted)
    print("Total loss is # {}".format(total_loss))

In [None]:
class DiceLoss(nn.Module):
    """Soft Dice loss: ``1 - (2 * |P.A| + s) / (|P| + |A| + s)``.

    Args:
        smooth: Constant ``s`` added to numerator and denominator to avoid a
            division by zero when both inputs are empty.
    """

    def __init__(self, smooth=0.001):
        super(DiceLoss, self).__init__()

        self.loss_name = "Dice Loss".capitalize()

        self.smooth = smooth

    def forward(self, predicted, actual):
        """Return the Dice loss between ``predicted`` and ``actual``.

        Both tensors are flattened before the overlap is computed, so they
        only need to hold the same number of elements.

        Raises:
            ValueError: If either argument is not a ``torch.Tensor``.
        """
        if not (isinstance(predicted, torch.Tensor) and isinstance(actual, torch.Tensor)):
            raise ValueError(
                "Predicted and actual should be in the format of tensor".capitalize()
            )

        # reshape (unlike view) also accepts non-contiguous tensors.
        actual = actual.reshape(-1)
        predicted = predicted.reshape(-1)

        intersection = (predicted * actual).sum()
        dice = (2.0 * intersection + self.smooth) / (
            predicted.sum() + actual.sum() + self.smooth
        )

        return 1.0 - dice

#### Helper methods

In [3]:
def load_dataloader():
    """Load the pickled train and test dataloaders from the processed directory.

    Returns:
        Dict with the keys ``train_dataloader`` and ``test_dataloader``.

    Raises:
        PathException: If the configured processed directory does not exist.
    """
    processed_path = validate_path(path=config()["path"]["processed_path"])

    train_dataloader = load(filename=os.path.join(processed_path, "train_dataloader.pkl"))
    test_dataloader = load(filename=os.path.join(processed_path, "test_dataloader.pkl"))

    loaders = {}
    loaders["train_dataloader"] = train_dataloader
    loaders["test_dataloader"] = test_dataloader

    return loaders


def helpers(**kwargs):
    """Create the networks, optimizers, losses and dataloaders for training.

    Keyword Args:
        channels: Image channels passed to both networks (required).
        lr: Learning rate for both optimizers (required).
        adam: Use Adam when true (required; wins over ``SGD``).
        SGD: Use SGD when true and ``adam`` is false (required).
        beta1: First Adam beta (default 0.5).
        beta2: Second Adam beta (default 0.999).
        momentum: SGD momentum (default 0.9).
        smooth: Dice-loss smoothing constant (default 0.001).

    Returns:
        Dict with the keys ``netG``, ``netD``, ``optimizerG``, ``optimizerD``,
        ``l1loss``, ``diceloss``, ``train_dataloader`` and ``test_dataloader``.

    Raises:
        KeyError: If a required keyword is missing.
        ValueError: If neither ``adam`` nor ``SGD`` is enabled.
    """
    channels = kwargs["channels"]
    lr = kwargs["lr"]
    adam = kwargs["adam"]
    SGD = kwargs["SGD"]

    # Optional hyper-parameters; the defaults are the values that used to be
    # hard-coded, so callers that do not pass them see no change.
    beta1 = kwargs.get("beta1", 0.5)
    beta2 = kwargs.get("beta2", 0.999)
    momentum = kwargs.get("momentum", 0.9)
    smooth = kwargs.get("smooth", 0.001)

    # Fail early and clearly instead of hitting an unbound optimizer variable
    # after the networks have already been built.
    if not (adam or SGD):
        raise ValueError("Either adam or SGD must be enabled to build the optimizers")

    netG = Generator(in_channels=channels)
    netD = Discriminator(in_channels=channels)

    if adam:
        optimizerG = optim.Adam(params=netG.parameters(), lr=lr, betas=(beta1, beta2))
        optimizerD = optim.Adam(params=netD.parameters(), lr=lr, betas=(beta1, beta2))

    else:
        optimizerG = optim.SGD(params=netG.parameters(), lr=lr, momentum=momentum)
        optimizerD = optim.SGD(params=netD.parameters(), lr=lr, momentum=momentum)

    l1loss = L1Loss(reduction="mean")
    diceloss = DiceLoss(smooth=smooth)
    dataloader = load_dataloader()

    return {
        "netG": netG,
        "netD": netD,
        "optimizerG": optimizerG,
        "optimizerD": optimizerD,
        "l1loss": l1loss,
        "diceloss": diceloss,
        "train_dataloader": dataloader["train_dataloader"],
        "test_dataloader": dataloader["test_dataloader"],
    }

##### Trainer

In [None]:
import os
import argparse
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR
from torchvision.utils import save_image

from .helper import helpers
from .utils import weight_init, device_init, config, validate_path, dump, load
from .generator import Generator
from .discriminator import Discriminator


class Trainer:
    """Adversarial training loop for the segmentation generator and its critic.

    The generator predicts a mask; the discriminator sees the input image
    multiplied by the real mask and by the predicted mask, and the L1 distance
    between its two feature vectors is the adversarial signal.

    Args:
        channels: Image channels for both networks.
        lr: Learning rate for both optimizers.
        epochs: Number of training epochs.
        adam: Use Adam optimizers.
        SGD: Use SGD optimizers (when ``adam`` is false).
        device: Preferred device name, resolved by ``device_init``.
        beta1: First Adam beta, forwarded to ``helpers``.
        beta2: Second Adam beta, forwarded to ``helpers``.
        momentum: SGD momentum, forwarded to ``helpers``.
        smooth: Dice smoothing constant, forwarded to ``helpers``.
        step_size: ``StepLR`` step size (used when ``lr_scheduler`` is true).
        gamma: ``StepLR`` decay factor (used when ``lr_scheduler`` is true).
        lr_scheduler: Enable the ``StepLR`` schedulers.
        is_display: Print losses after every epoch.
        is_weight_init: Apply ``weight_init`` to both networks.

    Raises:
        PathException: If a configured output directory does not exist.
    """

    def __init__(
        self,
        channels=3,
        lr=0.0002,
        epochs=100,
        adam=True,
        SGD=False,
        device="cuda",
        beta1=0.5,
        beta2=0.999,
        momentum=0.90,
        smooth=0.001,
        step_size=10,
        gamma=0.5,
        lr_scheduler=False,
        is_display=True,
        is_weight_init=True,
    ):

        self.channels = channels
        self.lr = lr
        self.epochs = epochs
        self.adam = adam
        self.SGD = SGD
        self.beta1 = beta1
        self.beta2 = beta2
        self.momentum = momentum
        self.smooth = smooth
        self.step_size = step_size
        self.gamma = gamma
        self.lr_scheduler = lr_scheduler
        self.is_display = is_display
        self.is_weight_init = is_weight_init

        self.device = device_init(device=device)

        self.init = helpers(
            channels=self.channels,
            lr=self.lr,
            adam=self.adam,
            SGD=self.SGD,
            beta1=self.beta1,
            beta2=self.beta2,
            momentum=self.momentum,
            smooth=self.smooth,
        )

        self.netG = self.init["netG"]
        self.netD = self.init["netD"]

        self.optimizerG = self.init["optimizerG"]
        self.optimizerD = self.init["optimizerD"]

        self.l1loss = self.init["l1loss"]
        self.diceloss = self.init["diceloss"]

        self.train_dataloader = self.init["train_dataloader"]
        self.test_dataloader = self.init["test_dataloader"]

        self.netG.to(self.device)
        self.netD.to(self.device)

        if self.is_weight_init:
            self.netG.apply(weight_init)
            self.netD.apply(weight_init)

        if self.lr_scheduler:
            self.schedulerG = StepLR(
                optimizer=self.optimizerG, step_size=self.step_size, gamma=self.gamma
            )
            self.schedulerD = StepLR(
                optimizer=self.optimizerD, step_size=self.step_size, gamma=self.gamma
            )

        # Best (lowest) generator loss seen so far; drives best-model saving.
        self.loss = float("inf")
        self.total_netG_loss = []
        self.total_netD_loss = []
        self.history = {"netG_loss": [], "netD_loss": []}

        # Read the configuration once instead of once per path.
        paths = config()["path"]

        self.train_images_path = validate_path(path=paths["train_images_path"])
        self.train_model = validate_path(path=paths["train_model_path"])
        self.best_model = validate_path(path=paths["best_model_path"])
        self.metrics_path = validate_path(path=paths["metrics_path"])

    def l1_loss(self, model):
        """Return the mean L1 norm over the generator's parameter tensors.

        Raises:
            ValueError: If ``model`` is not a ``Generator``.
        """
        if isinstance(model, Generator):
            # Stack the per-tensor norms; a bare generator has no .mean().
            return torch.stack(
                [torch.norm(params, 1) for params in model.parameters()]
            ).mean()
        else:
            raise ValueError("The model is not a Generator".capitalize())

    def l2_loss(self, model):
        """Return the mean L2 norm over the generator's parameter tensors.

        Raises:
            ValueError: If ``model`` is not a ``Generator``.
        """
        if isinstance(model, Generator):
            return torch.stack(
                [torch.norm(params, 2) for params in model.parameters()]
            ).mean()
        else:
            raise ValueError("The model is not a Generator".capitalize())

    def elastic_loss(self, model):
        """Return ``l1_loss(model) + l2_loss(model)``.

        Raises:
            ValueError: If ``model`` is not a ``Generator``.
        """
        if isinstance(model, Generator):
            l1 = self.l1_loss(model=model)
            l2 = self.l2_loss(model=model)

            return l1 + l2

        else:
            raise ValueError("The model is not a Generator".capitalize())

    def saved_model_checkpoints(self, **kwargs):
        """Save the epoch checkpoint and, on improvement, the best model.

        Keyword Args:
            epoch: Epoch number used in the checkpoint file name.
            netG_loss: Mean generator loss of the epoch.
        """
        torch.save(
            self.netG.state_dict(),
            os.path.join(self.train_model, "netG{}.pth".format(kwargs["epoch"])),
        )

        if self.loss > kwargs["netG_loss"]:
            self.loss = kwargs["netG_loss"]
            torch.save(
                {
                    "netG": self.netG.state_dict(),
                    "netD": self.netD.state_dict(),
                    # Store a plain float so the checkpoint does not depend
                    # on unpickling a numpy scalar when it is loaded.
                    "loss": float(kwargs["netG_loss"]),
                },
                os.path.join(self.best_model, "best_model.pth"),
            )

    def saved_training_images(self, **kwargs):
        """Save ``predicted_mask`` as ``image<epoch + 1>.png`` in the train-images directory."""
        save_image(
            kwargs["predicted_mask"],
            os.path.join(
                self.train_images_path, "image{}.png".format(kwargs["epoch"] + 1)
            ),
            nrow=32,
            normalize=True,
        )

    def show_progress(self, **kwargs):
        """Print the epoch summary; losses are included only when ``is_display`` is set."""
        if self.is_display:
            print(
                "Epochs - [{}/{}] - netG_loss: {:.4f} - netD_loss: {:.4f}".format(
                    kwargs["epoch"],
                    kwargs["epochs"],
                    kwargs["netG_loss"],
                    kwargs["netD_loss"],
                )
            )

        else:
            print(
                "Epochs - [{}/{}] is completed".format(
                    kwargs["epoch"], kwargs["epochs"]
                )
            )

    def update_train_netG(self, **kwargs):
        """Run one generator step and return its loss as a float.

        The loss is ``0.1 * dice(fake_mask, mask)`` plus the L1 distance
        between the discriminator features of the real-masked and the
        fake-masked image.

        Keyword Args:
            image: Input image batch.
            mask: Ground-truth mask batch.
        """
        self.optimizerG.zero_grad()

        images = kwargs["image"]
        masks = kwargs["mask"]

        fake_masks = self.netG(images)
        fake_masks = torch.sigmoid(fake_masks)

        fakeB = images * fake_masks
        realA = images * masks

        fake_masks_loss = self.diceloss(fake_masks, masks)

        real_predict = self.netD(realA)
        # No detach here: the adversarial term must back-propagate through the
        # discriminator into the generator, otherwise only Dice trains netG.
        fake_predict = self.netD(fakeB)

        multiscale_loss = self.l1loss(real_predict, fake_predict)

        lossG = 0.1 * fake_masks_loss + multiscale_loss

        lossG.backward()
        self.optimizerG.step()

        return lossG.item()

    def update_train_netD(self, **kwargs):
        """Run one discriminator step and return its loss as a float.

        The discriminator maximises the L1 feature distance (loss is
        ``1 - L1``); afterwards its parameters are clamped to [-0.01, 0.01].

        Keyword Args:
            image: Input image batch.
            mask: Ground-truth mask batch.
        """
        self.optimizerD.zero_grad()

        images = kwargs["image"]
        masks = kwargs["mask"]

        # The generator is fixed during this step, so skip building its graph.
        with torch.no_grad():
            fake_masks = self.netG(images)
            fake_masks = torch.sigmoid(fake_masks)

        realA = images * masks
        fakeB = images * fake_masks

        real_predict = self.netD(realA)
        fake_predict = self.netD(fakeB.detach())

        lossD = 1 - self.l1loss(real_predict, fake_predict)

        lossD.backward()
        self.optimizerD.step()

        for params in self.netD.parameters():
            params.data.clamp_(-0.01, 0.01)

        return lossD.item()

    def train(self):
        """Train for ``self.epochs`` epochs, then persist the loss history.

        Errors raised by a batch or by the end-of-epoch bookkeeping are
        printed and that batch/epoch is skipped, so one failure does not abort
        the run.
        """
        for epoch in tqdm(range(self.epochs)):
            netG_loss = []
            netD_loss = []

            for index, (image, mask) in enumerate(self.train_dataloader):
                try:
                    image = image.to(self.device)
                    mask = mask.to(self.device)

                    netD_loss.append(self.update_train_netD(image=image, mask=mask))
                    netG_loss.append(self.update_train_netG(image=image, mask=mask))

                except Exception as e:
                    print(
                        f"Error during training at epoch {epoch + 1}, batch {index}: {e}"
                    )
                    continue

            try:
                self.show_progress(
                    epoch=epoch + 1,
                    epochs=self.epochs,
                    netD_loss=np.mean(netD_loss),
                    netG_loss=np.mean(netG_loss),
                )

                image, mask = next(iter(self.test_dataloader))
                image = image.to(self.device)

                # Preview only: no gradients are needed for the saved sample.
                with torch.no_grad():
                    predicted_mask = self.netG(image)

                self.saved_training_images(
                    image=image, mask=mask, predicted_mask=predicted_mask, epoch=epoch
                )

                self.saved_model_checkpoints(
                    netG_loss=np.mean(netG_loss),
                    epoch=epoch + 1,
                )

                if self.lr_scheduler:
                    self.schedulerD.step()
                    self.schedulerG.step()

                self.total_netG_loss.append(np.mean(netG_loss))
                self.total_netD_loss.append(np.mean(netD_loss))

            except Exception as e:
                print(f"Error during post-epoch processing at epoch {epoch + 1}: {e}")
                continue

        try:
            # history stores each run as one list, hence loss[0] in plot_history.
            self.history["netG_loss"].append(self.total_netG_loss)
            self.history["netD_loss"].append(self.total_netD_loss)

            pd.DataFrame(
                {
                    "netG_loss": self.total_netG_loss,
                    "netD_loss": self.total_netD_loss,
                }
            ).to_csv(os.path.join(self.metrics_path, "model_history.csv"))

            for filename, value in [
                ("history.pkl", self.history),
                ("netG_loss.pkl", self.total_netG_loss),
                ("netD_loss.pkl", self.total_netD_loss),
            ]:
                dump(value=value, filename=os.path.join(self.metrics_path, filename))

            print(
                "Saved the model history in a csv file in {}\nSaved the model history in pickle format in {}".format(
                    self.metrics_path, self.metrics_path
                )
            )

        except Exception as e:
            print(f"Error during post-training processing: {e}")
            return

    @staticmethod
    def plot_history():
        """Plot the saved generator/discriminator loss curves side by side.

        Reads ``history.pkl`` from the metrics directory and writes the figure
        to ``model_history.jpeg`` in the same directory.

        Raises:
            PathException: If the configured metrics directory does not exist.
        """
        metrics_path = config()["path"]["metrics_path"]
        metrics_path = validate_path(path=metrics_path)

        plt.figure(figsize=(20, 20))

        history = load(filename=os.path.join(metrics_path, "history.pkl"))

        for index, (title, loss) in enumerate(
            [
                ("netG_loss", history["netG_loss"]),
                ("netD_loss", history["netD_loss"]),
            ]
        ):
            plt.subplot(1, 2, index + 1)

            plt.plot(loss[0], label=title)
            plt.title(f"{title}")
            plt.xlabel("Epochs")
            plt.ylabel("Loss")
            plt.legend()

        plt.tight_layout()
        # Save the matplotlib figure; torchvision's save_image expects a
        # tensor and cannot be called with only a path.
        plt.savefig(os.path.join(metrics_path, "model_history.jpeg"))
        plt.show()

#### Testing

In [None]:
import os
import torch
import argparse
import matplotlib.pyplot as plt

from .utils import validate_path, config, device_init, load
from .generator import Generator


class TestModel:
    """
    Run the saved generator on one batch of the saved test dataloader and
    store a figure that pairs every generated image with its mask.

    Args:
        device: Preferred device name, resolved through ``device_init``.
    """

    def __init__(self, device="cuda"):
        try:
            # Read the YAML configuration once instead of once per key.
            paths = config()["path"]

            self.device = device_init(device=device)
            # NOTE: despite its name, this attribute holds the directory that
            # contains "test_dataloader.pkl", not a DataLoader object.
            self.test_dataloader = validate_path(paths["processed_path"])
            self.best_model = validate_path(path=paths["best_model_path"])
            self.test_image_path = validate_path(path=paths["test_image_path"])
            self.netG = Generator().to(self.device)
        except Exception as e:
            print(f"Error during initialization: {e}")
            raise

    def load_dataset(self):
        """Return the object stored in ``test_dataloader.pkl``."""
        try:
            return load(
                filename=os.path.join(self.test_dataloader, "test_dataloader.pkl")
            )
        except Exception as e:
            print(f"Error loading dataset: {e}")
            raise

    def load_best_model(self):
        """Load the ``"netG"`` weights of ``best_model.pth`` into ``self.netG``."""
        try:
            # map_location lets a checkpoint saved on one device (e.g. CUDA)
            # be restored on whatever device this instance resolved to.
            checkpoint = torch.load(
                os.path.join(self.best_model, "best_model.pth"),
                map_location=self.device,
            )
            self.netG.load_state_dict(checkpoint["netG"])
        except Exception as e:
            print(f"Error loading the best model: {e}")
            raise

    @staticmethod
    def _grid_shape(batch_size):
        """Return ``(rows, columns)`` of a subplot grid with room for two panels per sample."""
        # Clamp to 1 so that a single-sample batch no longer divides by zero.
        num_row = max(1, batch_size // 2)
        num_columns = max(1, batch_size // num_row)
        return 2 * num_row, 2 * num_columns

    @staticmethod
    def _to_displayable(tensor):
        """Convert one image tensor into a min-max scaled numpy array for ``imshow``."""
        data = tensor.detach().cpu()

        if data.dim() == 3:
            # Channels-first -> channels-last, which is what imshow expects.
            data = data.permute(1, 2, 0)
            if data.size(-1) == 1:
                # A single channel is shown as a plain 2-D image.
                data = data.squeeze(-1)

        array = data.numpy()
        lowest = array.min()
        value_range = array.max() - lowest

        if value_range == 0:
            # Constant image: avoid 0/0 and return an all-zero array instead.
            return array - lowest

        return (array - lowest) / value_range

    def plot(self):
        """
        Generate predictions for the first test batch and save them as
        ``test.png`` in the test image directory.
        """
        try:
            image, mask = next(iter(self.load_dataset()))

            # Inference only: use evaluation behaviour and skip autograd.
            self.netG.eval()
            with torch.no_grad():
                predict = self.netG(image.to(self.device))

            num_rows, num_columns = self._grid_shape(image.size(0))

            plt.figure(figsize=(20, 10))

            for index, generated in enumerate(predict):
                plt.subplot(num_rows, num_columns, 2 * index + 1)
                plt.imshow(self._to_displayable(generated), cmap="gray")
                plt.axis("off")
                plt.title("Generated Image")

                plt.subplot(num_rows, num_columns, 2 * index + 2)
                # cmap only takes effect for 2-D (single-channel) masks.
                plt.imshow(self._to_displayable(mask[index]), cmap="gray")
                plt.title("Mask")
                plt.axis("off")

            plt.tight_layout()
            plt.savefig(os.path.join(self.test_image_path, "test.png"))
            plt.show()

            print("Test image is saved in the path: ", self.test_image_path)
        except Exception as e:
            print(f"Error during plotting: {e}")
            raise

    def test(self):
        """Load the best checkpoint, then plot predictions for one test batch."""
        try:
            self.load_best_model()
            self.plot()
        except Exception as e:
            print(f"Error during testing: {e}")
            raise