## Import the libraries

In [None]:
import argparse
import math
import os
import sys
import traceback
import warnings
import zipfile

import cv2
import matplotlib.pyplot as plt
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from torchview import draw_graph
from torchvision import transforms
from tqdm import tqdm

## Utils file

In [None]:
import yaml
import joblib


def dump(filename=None, value=None):
    """Serialize ``value`` to ``filename`` with joblib.

    Raises:
        ValueError: when either ``filename`` or ``value`` is None.
    """
    missing_argument = filename is None or value is None
    if missing_argument:
        raise ValueError("Could not dump file".capitalize())
    joblib.dump(value=value, filename=filename)


def load(filename=None):
    """Deserialize and return the object stored in ``filename`` via joblib.

    Raises:
        ValueError: when ``filename`` is None.
    """
    if filename is None:
        raise ValueError("Could not load file".capitalize())
    return joblib.load(filename=filename)


def config():
    """Parse ``../config.yml`` (relative to the working directory) and return its contents."""
    config_path = "../config.yml"
    with open(config_path, "r") as handle:
        parsed = yaml.safe_load(handle)
    return parsed


def device_init(device: str = "cuda"):
    """Return a ``torch.device`` for ``device``, falling back to CPU.

    "cuda" and "mps" are honoured only when the backend reports itself as
    available; any other value yields the CPU device.
    """
    # Lambdas defer attribute access so only the requested backend is probed.
    probes = {
        "cuda": lambda: torch.cuda.is_available(),
        "mps": lambda: torch.backends.mps.is_available(),
    }
    probe = probes.get(device)
    if probe is not None and probe():
        return torch.device(device)
    return torch.device("cpu")


def weight_init(m):
    """Initialise Conv* and BatchNorm* layers in place (takes one module, so it fits ``nn.Module.apply``).

    * Any layer whose class name contains "Conv" (Conv2d, ConvTranspose2d, ...):
      weights ~ N(0, 0.02); the bias is left untouched.
    * Any layer whose class name contains "BatchNorm": weights ~ N(1, 0.02),
      bias = 0.
    * All other modules are ignored.

    Layers without learnable affine parameters (e.g. ``BatchNorm2d(affine=False)``)
    are skipped instead of crashing on a ``None`` weight.
    """
    classname = m.__class__.__name__  # was `__name`, which raised AttributeError
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)


def clean_folders(folders=None):
    """Delete every regular file inside the artifact folders.

    Args:
        folders: optional iterable of folder paths. Defaults to the project's
            ``../artifacts/{train_images,test_image,train_models,best_model,metrics_path}/``
            folders.

    Sub-directories are left untouched. A missing folder is reported and
    skipped rather than aborting the whole clean-up. A file that cannot be
    removed is reported, and cleaning continues.
    """
    if folders is None:
        folders = [
            "../artifacts/train_images/",
            "../artifacts/test_image/",
            "../artifacts/train_models/",
            "../artifacts/best_model/",
            "../artifacts/metrics_path/",
        ]

    for folder in tqdm(list(folders)):
        if not os.path.isdir(folder):
            print(f"Folder not found, skipping: {folder}")
            continue

        for file in os.listdir(folder):
            file_path = os.path.join(folder, file)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
            except OSError as e:
                print(f"Error occurred while cleaning folder: {folder}")
                print(f"Error: {e}")

    # Reported once, after every folder has been processed.
    print("All files have been deleted.".capitalize())

## Dataloader

In [None]:
warnings.filterwarnings("ignore")


class Loader:
    """Prepare paired (X, y) image dataloaders for the UNIT-GAN.

    Workflow:
      * ``unzip_folder`` extracts the raw archive into
        ``config()["path"]["processed_path"]``.
      * ``create_dataloader`` reads ``dataset/X`` and ``dataset/y`` (paired by
        identical file names), splits the pairs, and pickles train/valid
        ``DataLoader`` objects.
      * ``dataset_details`` and ``display_images`` inspect the pickled loaders.
    """

    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

    def __init__(
        self,
        dataset=None,
        image_size: int = 128,
        batch_size: int = 1,
        split_size: float = 0.20,
    ):
        """
        Args:
            dataset: path to the zip archive holding the raw dataset.
            image_size: side length that images are resized and center-cropped to.
            batch_size: training batch size; the validation loader uses 16x this.
            split_size: fraction of pairs held out for validation.
        """
        self.dataset = dataset
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size

        self.imageA = []
        self.imageB = []

    def unzip_folder(self):
        """Extract ``self.dataset`` into the configured processed folder (created if needed)."""
        processed_path = config()["path"]["processed_path"]
        os.makedirs(processed_path, exist_ok=True)

        with zipfile.ZipFile(self.dataset, mode="r") as zip_ref:
            zip_ref.extractall(path=processed_path)

        print(f"Unzip folder {processed_path}".capitalize())

    def split_dataset(self, X: list, y: list):
        """Split paired lists into train/test parts (``random_state=42``).

        Returns a dict with keys X_train, X_test, y_train, y_test, or None when
        either argument is not a list.
        """
        if isinstance(X, list) and isinstance(y, list):
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=self.split_size, random_state=42
            )
            return {
                "X_train": X_train,
                "X_test": X_test,
                "y_train": y_train,
                "y_test": y_test,
            }

    def transforms(self, type: str = "image"):
        """Return the tensor pipeline: ToTensor -> Resize -> CenterCrop -> Normalize to [-1, 1].

        ``type="image"`` resizes with bicubic interpolation; any other value
        uses torchvision's default interpolation.
        """
        if type == "image":
            return transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize(
                        (self.image_size, self.image_size),
                        transforms.InterpolationMode.BICUBIC,
                    ),
                    transforms.CenterCrop((self.image_size, self.image_size)),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ]
            )
        return transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize((self.image_size, self.image_size)),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def features_extractor(self):
        """Load every (X, y) image pair, transform it, and return the train/test split.

        Files in ``dataset/X`` whose extension is not jpg/jpeg/png
        (case-insensitive) are ignored. A pair is skipped, with a message, when
        either file is missing or cannot be decoded by OpenCV.
        """
        dataset_path = os.path.join(config()["path"]["processed_path"], "dataset")

        images_path = os.path.join(dataset_path, "X")
        masks_path = os.path.join(dataset_path, "y")

        # Build the pipelines once instead of once per image.
        image_transform = self.transforms(type="image")
        # NOTE: default type is "image", so y gets the same bicubic pipeline as X.
        mask_transform = self.transforms()

        # Sorted so the seeded train/test split is reproducible across filesystems.
        for image in tqdm(sorted(os.listdir(images_path))):
            if not image.lower().endswith(self.IMAGE_EXTENSIONS):
                continue

            image_path = os.path.join(images_path, image)
            mask_path = os.path.join(masks_path, image)

            if not os.path.exists(image_path):
                print(f"Image not found: {image_path}")
                continue
            if not os.path.exists(mask_path):
                print(f"Mask not found: {mask_path}")
                continue

            X = cv2.imread(image_path)
            y = cv2.imread(mask_path)

            # cv2.imread returns None for unreadable/corrupt files.
            if X is None or y is None:
                print(f"Could not read image pair: {image_path}, {mask_path}")
                continue

            X = cv2.cvtColor(X, cv2.COLOR_BGR2RGB)
            y = cv2.cvtColor(y, cv2.COLOR_BGR2RGB)

            self.imageA.append(image_transform(Image.fromarray(X)))
            self.imageB.append(mask_transform(Image.fromarray(y)))

        assert len(self.imageA) == len(self.imageB)

        return self.split_dataset(X=self.imageA, y=self.imageB)

    def create_dataloader(self):
        """Build the train/valid DataLoaders and pickle them into ``../data/processed/``.

        Any failure is printed with its traceback, and the process exits with status 1.
        """
        # NOTE(review): dataset_details/display_images read from
        # config()["path"]["processed_path"]; presumably that is this same folder.
        output_path = "../data/processed/"
        try:
            dataset = self.features_extractor()

            train_dataloader = DataLoader(
                dataset=list(zip(dataset["X_train"], dataset["y_train"])),
                batch_size=self.batch_size,
                shuffle=True,
            )
            valid_dataloader = DataLoader(
                dataset=list(zip(dataset["X_test"], dataset["y_test"])),
                batch_size=self.batch_size * 16,
                shuffle=True,
            )

            os.makedirs(output_path, exist_ok=True)
            for filename, dataloader in [
                ("train_dataloader", train_dataloader),
                ("valid_dataloader", valid_dataloader),
            ]:
                dump(
                    filename=os.path.join(output_path, filename + ".pkl"),
                    value=dataloader,
                )

            print(
                "Train and valid dataloader saved successfully in the folder {}".format(
                    output_path
                ).capitalize()
            )

        except AssertionError as e:
            print(f"Assertion error: {e}")
            traceback.print_exc()
            sys.exit(1)
        except Exception as e:
            print(f"An error occurred: {e}")
            traceback.print_exc()
            sys.exit(1)

    @staticmethod
    def _to_display(image: torch.Tensor):
        """Convert a (C, H, W) tensor into an (H, W, C) numpy array min-max scaled to [0, 1].

        Single-channel images are returned as (H, W). Constant images map to all zeros
        instead of NaN.
        """
        array = image.detach().cpu().permute(1, 2, 0).numpy()
        if array.shape[-1] == 1:
            array = array[..., 0]
        value_range = array.max() - array.min()
        return (array - array.min()) / (value_range if value_range > 0 else 1.0)

    @staticmethod
    def display_images():
        """Plot the first validation batch as X/Y pairs and save it to ``../artifacts/files/images.png``.

        Does nothing when the processed folder does not exist.
        """
        processed_path = config()["path"]["processed_path"]
        if not os.path.exists(processed_path):
            return

        valid_dataloader = load(
            filename=os.path.join(processed_path, "valid_dataloader.pkl")
        )
        valid_X, valid_Y = next(iter(valid_dataloader))

        num_samples = valid_X.size(0)
        # max(1, ...) avoids a division by zero when the batch holds one sample.
        num_of_rows = max(1, num_samples // 2)
        num_of_cols = num_samples // num_of_rows

        plt.figure(figsize=(10, 20))

        for index, (X, y) in enumerate(zip(valid_X, valid_Y)):
            plt.subplot(2 * num_of_rows, 2 * num_of_cols, 2 * index + 1)
            plt.imshow(Loader._to_display(X))
            plt.title("X")
            plt.axis("off")

            plt.subplot(2 * num_of_rows, 2 * num_of_cols, 2 * index + 2)
            plt.imshow(Loader._to_display(y))
            plt.title("Y")
            plt.axis("off")

        plt.tight_layout()
        files_path = "../artifacts/files/"
        os.makedirs(files_path, exist_ok=True)
        plt.savefig(os.path.join(files_path, "images.png"))
        plt.show()

    @staticmethod
    def dataset_details():
        """Write batch shapes and sample counts of both loaders to ``../artifacts/files/dataset_details.csv``.

        Exits with status 1 when the processed folder does not exist.
        """
        processed_path = config()["path"]["processed_path"]
        if not os.path.exists(processed_path):
            print(f"Folder {processed_path} does not exist".capitalize())
            sys.exit(1)

        train_dataloader = load(
            filename=os.path.join(processed_path, "train_dataloader.pkl")
        )
        valid_dataloader = load(
            filename=os.path.join(processed_path, "valid_dataloader.pkl")
        )

        train_X, train_Y = next(iter(train_dataloader))
        valid_X, valid_Y = next(iter(valid_dataloader))

        # The wrapped datasets are lists of pairs, so len() equals the number of
        # samples; this avoids iterating each loader end-to-end to count them.
        total_train = len(train_dataloader.dataset)
        total_valid = len(valid_dataloader.dataset)

        files_path = "../artifacts/files"
        os.makedirs(files_path, exist_ok=True)

        pd.DataFrame(
            {
                "Train X Shape": str(train_X.size()),
                "Train Y Shape": str(train_Y.size()),
                "Valid X Shape": str(valid_X.size()),
                "Valid Y Shape": str(valid_Y.size()),
                "total_train_dataset": total_train,
                "total_valid_dataset": total_valid,
                "total_dataset": total_train + total_valid,
            },
            index=["Dataset Details"],
        ).T.to_csv(
            os.path.join(files_path, "dataset_details.csv"),
        )


if __name__ == "__main__":
    # Build the dataloaders from the raw archive, then write a summary CSV
    # and a preview figure.
    loader = Loader(
        dataset="../data/raw/dataset.zip",
        batch_size=1,
        split_size=0.20,
    )

    for prepare_step in (loader.unzip_folder, loader.create_dataloader):
        prepare_step()

    for report_step in (Loader.dataset_details, Loader.display_images):
        report_step()

## Residual Block

In [None]:
class ResidualBlock(nn.Module):
    """Residual block: (pad -> 3x3 conv -> InstanceNorm -> ReLU -> pad -> 3x3 conv -> InstanceNorm) + skip.

    Channel count and spatial size are preserved (reflection padding of 1
    with a stride-1 3x3 kernel), so the branch output is added to the input.
    """

    def __init__(self, in_channels: int = 256):
        super(ResidualBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = in_channels

        self.reflectionpad2d = 1
        self.kernel_size = 3
        self.stride_size = 1

        stack = []
        for position in range(2):
            stack.extend(
                [
                    nn.ReflectionPad2d(self.reflectionpad2d),
                    nn.Conv2d(
                        in_channels=self.in_channels,
                        out_channels=self.out_channels,
                        kernel_size=self.kernel_size,
                        stride=self.stride_size,
                    ),
                    nn.InstanceNorm2d(num_features=self.out_channels),
                ]
            )
            # Only the first conv unit is followed by an activation.
            if position == 0:
                stack.append(nn.ReLU())

        self.layers = stack
        self.residualBlock = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Return ``x + residualBlock(x)``; raises ValueError for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input should be the tensor type".capitalize())
        return x + self.residualBlock(x)


if __name__ == "__main__":
    # Stack three residual blocks and print the resulting module tree.
    image_channels = 256
    stacked_blocks = [ResidualBlock(in_channels=image_channels) for _ in range(3)]
    residual = nn.Sequential(*stacked_blocks)

    print(residual)

## Encoder 

In [None]:
class Encoder(nn.Module):
    """Image encoder with a shared residual block on top.

    Pipeline: reflection-padded stem conv -> two strided down-sampling convs
    (each doubling the channel count) -> three private residual blocks ->
    ``sharedBlocks``. ``forward`` returns ``(mu, z)``, where
    ``z = mu + N(0, 1)`` noise.
    """

    def __init__(self, in_channels: int = 3, sharedBlocks=None):
        super(Encoder, self).__init__()

        self.in_channels = in_channels
        self.out_channels = int(math.pow(2, self.in_channels + self.in_channels))
        self.kerenl_size = (self.in_channels * 2) + 1

        if not isinstance(sharedBlocks, ResidualBlock):
            raise ValueError(
                "shared_block must be an instance of ResidualBlock".capitalize()
            )

        # Registered before modelBlocks, so it comes first in parameters()/state_dict.
        self.sharedBlocks = sharedBlocks

        stem = nn.Sequential(
            nn.ReflectionPad2d(padding=self.in_channels),
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kerenl_size,
            ),
            nn.InstanceNorm2d(num_features=self.out_channels),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
        )
        self.modelBlocks = [stem]

        down_kernel = self.kerenl_size // 2
        down_stride = (self.in_channels // self.in_channels) + 1
        down_padding = self.in_channels // self.in_channels

        self.downLayers = []
        for _ in range(2):
            doubled = self.out_channels * 2
            self.downLayers.extend(
                [
                    nn.Conv2d(
                        in_channels=self.out_channels,
                        out_channels=doubled,
                        kernel_size=down_kernel,
                        stride=down_stride,
                        padding=down_padding,
                    ),
                    nn.InstanceNorm2d(num_features=doubled),
                    nn.ReLU(),
                ]
            )
            self.out_channels = doubled

        self.modelBlocks.append(nn.Sequential(*self.downLayers))
        self.modelBlocks.append(
            nn.Sequential(
                *[ResidualBlock(in_channels=self.out_channels) for _ in range(3)]
            )
        )

        self.modelBlocks = nn.Sequential(*self.modelBlocks)

    def reparameterization(self, mu: torch.Tensor):
        """Return ``mu`` plus standard-normal noise of the same shape."""
        if not isinstance(mu, torch.Tensor):
            raise ValueError("Input should be the tensor type".capitalize())
        noise = torch.randn_like(mu)
        return mu + noise

    def forward(self, x: torch.Tensor):
        """Encode ``x`` and return ``(mu, z)``."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input should be the tensor type".capitalize())
        features = self.modelBlocks(x)
        mu = self.sharedBlocks(features)
        return mu, self.reparameterization(mu)


if __name__ == "__main__":
    in_channels = 3

    batch_size = 1
    image_size = 128

    # Evaluates to 256, the channel count the encoders emit.
    shared_channels = int(in_channels * (math.pow(2, 8) - 1) / in_channels + 1)
    shared_E = ResidualBlock(in_channels=shared_channels)

    encoder1, encoder2 = [
        Encoder(in_channels=in_channels, sharedBlocks=shared_E) for _ in range(2)
    ]

    input_shape = (batch_size, in_channels, image_size, image_size)
    mu1, z1 = encoder1(torch.randn(*input_shape))
    mu2, z2 = encoder2(torch.randn(*input_shape))

    observed_sizes = {mu1.size(), mu2.size(), z1.size(), z2.size()}
    assert (
        len(observed_sizes) == 1
    ), "Shape mismatch(mu1, mu2) and (z1, z2)".capitalize()

## Generator

In [None]:
class Generator(nn.Module):
    """Decoder that maps a latent feature map back to an image.

    Pipeline: ``sharedBlock`` -> three private residual blocks -> two
    transposed-conv up-sampling stages (each halving the channels) ->
    reflection-padded conv -> Tanh. Kernel, stride and padding are derived
    from ``in_channels``; for the default 256 they are 4 / 2 / 1, and the
    output has ``kernel_size - 1 = 3`` channels.
    """

    def __init__(self, in_channels: int = 256, sharedBlock: ResidualBlock = None):
        super(Generator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = self.in_channels

        self.kernel_size = int(math.sqrt(math.sqrt(self.in_channels)))
        self.stride_size = int(math.sqrt(self.kernel_size))
        self.padding_size = self.stride_size // self.stride_size

        if not isinstance(sharedBlock, ResidualBlock):
            raise ValueError(
                "shared_block must be an instance of ResidualBlock".capitalize()
            )
        # Registered before modelBlocks, so it comes first in parameters()/state_dict.
        self.sharedBlock = sharedBlock

        residual_stack = nn.Sequential(
            *[ResidualBlock(in_channels=self.in_channels) for _ in range(3)]
        )

        self.upsampleBlocks = []
        for _ in range(2):
            halved = self.in_channels // 2
            self.upsampleBlocks.extend(
                [
                    nn.ConvTranspose2d(
                        in_channels=self.in_channels,
                        out_channels=halved,
                        kernel_size=self.kernel_size,
                        stride=self.stride_size,
                        padding=self.padding_size,
                    ),
                    nn.InstanceNorm2d(num_features=halved),
                    nn.LeakyReLU(negative_slope=0.2, inplace=True),
                ]
            )
            # self.in_channels ends up at in_channels // 4 after both stages.
            self.in_channels = halved

        output_head = nn.Sequential(
            nn.ReflectionPad2d(padding=self.kernel_size - 1),
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.kernel_size - 1,
                kernel_size=self.kernel_size + 3,
            ),
            nn.Tanh(),
        )

        self.modelBlocks = nn.Sequential(
            residual_stack,
            nn.Sequential(*self.upsampleBlocks),
            output_head,
        )

    def forward(self, x: torch.Tensor):
        """Decode ``x`` through the shared block and then the private layers."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input should be the tensor type".capitalize())
        shared_features = self.sharedBlock(x)
        return self.modelBlocks(shared_features)


if __name__ == "__main__":
    image_channels = 256

    batch_size = 1
    image_size = 32

    # Pass one shared block to both generators, mirroring how helper() wires
    # shared_G, so the smoke test actually exercises weight sharing.
    shared_G = ResidualBlock(in_channels=image_channels)

    netG1 = Generator(in_channels=image_channels, sharedBlock=shared_G)
    netG2 = Generator(in_channels=image_channels, sharedBlock=shared_G)

    assert (
        netG1.sharedBlock is netG2.sharedBlock
    ), "Generators should share the same residual block".capitalize()

    generatedImage1 = netG1(
        torch.randn(batch_size, image_channels, image_size, image_size)
    )
    generatedImage2 = netG2(
        torch.randn(batch_size, image_channels, image_size, image_size)
    )

    assert (
        generatedImage1.size() == generatedImage2.size()
    ), "Shape mismatch(generatedImage1, generatedImage2)".capitalize()

## Discriminator

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a single-channel score map.

    Four strided conv stages (channels doubling from ``2 ** (2 * in_channels)``;
    InstanceNorm on all but the first; LeakyReLU(0.2) on all) are followed by a
    final conv down to one channel. With ``in_channels=3`` each stage halves the
    spatial size, so the output is ``H/16 x W/16``.
    """

    def __init__(self, in_channels: int = 3):
        super(Discriminator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = int(math.pow(2, self.in_channels + self.in_channels))

        self.kernel_size = self.in_channels + 1
        self.stride_size = self.kernel_size // 2
        self.padding_size = self.stride_size // 2

        self.layers = list()

        for stage in range(4):
            self.layers.append(
                nn.Conv2d(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel_size,
                    stride=self.stride_size,
                    padding=self.padding_size,
                )
            )
            # The first stage is left un-normalised.
            if stage > 0:
                self.layers.append(nn.InstanceNorm2d(num_features=self.out_channels))
            self.layers.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))

            self.in_channels, self.out_channels = (
                self.out_channels,
                self.out_channels * 2,
            )

        score_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.in_channels // self.in_channels,
            kernel_size=self.kernel_size - 1,
            stride=self.stride_size - 1,
            padding=self.padding_size,
        )
        self.layers.append(nn.Sequential(score_conv))

        self.model = nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Return the discriminator score map for ``x``."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("Input should be the tensor type".capitalize())
        return self.model(x)


if __name__ == "__main__":
    image_channels = 3

    batch_size = 1
    image_size = 128

    netD = Discriminator(in_channels=image_channels)

    expected_shape = torch.Size(
        [
            batch_size,
            image_channels // image_channels,
            image_size // 16,
            image_size // 16,
        ]
    )
    prediction = netD(torch.randn(batch_size, image_channels, image_size, image_size))

    assert prediction.size() == expected_shape

## GANLoss

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss for the UNIT-GAN, implemented as mean-squared error."""

    def __init__(self, reduction: str = "mean"):
        super(GANLoss, self).__init__()

        self.name = "GANLoss for the UNIT-GAN"
        self.reduction = reduction
        self.loss = nn.MSELoss(reduction=self.reduction)

    def forward(self, predicted: torch.Tensor, actual: torch.Tensor):
        """Return ``MSELoss(predicted, actual)``; raises ValueError unless both are tensors."""
        both_tensors = isinstance(predicted, torch.Tensor) and isinstance(
            actual, torch.Tensor
        )
        if not both_tensors:
            raise ValueError("Predicted and actual should be both tensor".capitalize())
        return self.loss(predicted, actual)


if __name__ == "__main__":
    loss = GANLoss(reduction="mean")

    # Identical tensors, so the printed loss is zero.
    values = [1.0, 0.0, 1.0, 1.0]
    actual = torch.tensor(values)
    predicted = torch.tensor(values)

    print(loss(predicted, actual))


## PixelLoss

In [None]:
class PixelLoss(nn.Module):
    """Pixel-wise L1 reconstruction loss for the UNIT-GAN."""

    def __init__(self, reduction: str = "mean"):
        super(PixelLoss, self).__init__()

        self.name = "L1Loss for the UNIT-GAN"
        self.reduction = reduction
        self.loss = nn.L1Loss(reduction=self.reduction)

    def forward(self, predicted: torch.Tensor, actual: torch.Tensor):
        """Return ``L1Loss(predicted, actual)``; raises ValueError unless both are tensors."""
        if not (
            isinstance(predicted, torch.Tensor) and isinstance(actual, torch.Tensor)
        ):
            raise ValueError("Predicted and actual should be in tensor".capitalize())
        return self.loss(predicted, actual)


if __name__ == "__main__":
    loss = PixelLoss(reduction="mean")

    # Identical tensors, so the printed loss is zero.
    values = [1.0, 0.0, 1.0, 1.0]
    actual = torch.tensor(values)
    predicted = torch.tensor(values)

    print(loss(predicted, actual))

## KL Divergence

In [None]:
class KLDivergence(nn.Module):
    """Return the mean of ``x ** 2``.

    NOTE(review): presumably the KL term for a unit-variance Gaussian posterior,
    where KL to N(0, I) is 0.5 * ||mu||^2. This uses the mean instead of
    0.5 * sum, so the scale differs; confirm this is intended.
    """

    def __init__(self):
        super(KLDivergence, self).__init__()
        self.name = "KL Divergence".title()

    def forward(self, x: torch.Tensor):
        """Return ``mean(x ** 2)``; raises ValueError for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise ValueError("X should be the type of torch.Tensor".capitalize())
        squared = torch.pow(x, 2)
        return torch.mean(squared)


if __name__ == "__main__":
    batch_size = config()["dataloader"]["batch_size"]

    loss = KLDivergence()

    predicted = torch.randn((batch_size, 64))
    divergence = loss(predicted)

    assert (
        type(divergence) is torch.Tensor
    ), "Output should be the torch.Tensor".capitalize()

## Helper

In [None]:
import warnings

warnings.filterwarnings("ignore")


def load_dataloader(processed_path: str = "../data/processed/"):
    """Load the pickled train/valid dataloaders from ``processed_path``.

    Args:
        processed_path: folder containing ``train_dataloader.pkl`` and
            ``valid_dataloader.pkl`` (the folder ``Loader.create_dataloader``
            dumps into by default).

    Returns:
        dict with keys ``train_dataloader`` and ``valid_dataloader``.

    Raises:
        FileNotFoundError: if ``processed_path`` does not exist. Previously the
            function returned None here, and callers then failed with an
            unrelated TypeError when indexing the result.
    """
    if not os.path.exists(processed_path):
        raise FileNotFoundError(f"Folder {processed_path} does not exist")

    train_dataloader = load(
        filename=os.path.join(processed_path, "train_dataloader.pkl")
    )
    valid_dataloader = load(
        filename=os.path.join(processed_path, "valid_dataloader.pkl")
    )

    return {
        "train_dataloader": train_dataloader,
        "valid_dataloader": valid_dataloader,
    }


def helper(**kwargs):
    """Create every UNIT-GAN training component.

    Keyword Args:
        lr (float): learning rate for all optimizers.
        beta1, beta2 (float): Adam betas (used when ``adam`` is truthy).
        momentum (float): SGD momentum (used when ``SGD`` is truthy).
        adam (bool): select Adam; takes precedence over ``SGD``.
        SGD (bool): select SGD when ``adam`` is falsy.

    Returns:
        dict with the train/valid dataloaders, encoders ``E1``/``E2``,
        generators ``netG1``/``netG2``, discriminators ``netD1``/``netD2``,
        ``optimizerG``/``optimizerD1``/``optimizerD2``, ``criterion`` (GANLoss),
        ``pixelLoss`` (PixelLoss) and ``kl_loss`` (KLDivergence).

    Raises:
        KeyError: if one of the keyword arguments above is missing.
        ValueError: if neither ``adam`` nor ``SGD`` is truthy.
    """
    lr = kwargs["lr"]
    beta1 = kwargs["beta1"]
    beta2 = kwargs["beta2"]
    momentum = kwargs["momentum"]
    adam = kwargs["adam"]
    SGD = kwargs["SGD"]

    # Read the config file once instead of once per model.
    image_channels = config()["dataloader"]["image_channels"]

    # One residual block is shared by both encoders, another by both generators.
    shared_E = ResidualBlock(in_channels=256)
    shared_G = ResidualBlock(in_channels=256)

    E1 = Encoder(in_channels=image_channels, sharedBlocks=shared_E)
    E2 = Encoder(in_channels=image_channels, sharedBlocks=shared_E)

    # Generator's keyword is `sharedBlock` (singular); `sharedBlocks` raised TypeError.
    G1 = Generator(in_channels=256, sharedBlock=shared_G)
    G2 = Generator(in_channels=256, sharedBlock=shared_G)

    D1 = Discriminator(in_channels=image_channels)
    D2 = Discriminator(in_channels=image_channels)

    # Module.parameters() de-duplicates, so the shared blocks' weights reach the
    # optimizer once instead of once per encoder/generator (duplicates would be
    # stepped twice per update).
    generator_params = list(nn.ModuleList([E1, E2, G1, G2]).parameters())

    if adam:
        optimizerG = optim.Adam(params=generator_params, lr=lr, betas=(beta1, beta2))
        optimizerD1 = optim.Adam(params=D1.parameters(), lr=lr, betas=(beta1, beta2))
        optimizerD2 = optim.Adam(params=D2.parameters(), lr=lr, betas=(beta1, beta2))
    elif SGD:
        optimizerG = optim.SGD(params=generator_params, lr=lr, momentum=momentum)
        optimizerD1 = optim.SGD(params=D1.parameters(), lr=lr, momentum=momentum)
        optimizerD2 = optim.SGD(params=D2.parameters(), lr=lr, momentum=momentum)
    else:
        raise ValueError("Either adam or SGD must be set to True")

    criterion = GANLoss(reduction="mean")
    pixelLoss = PixelLoss(reduction="mean")
    kl_loss = KLDivergence()

    try:
        dataset = load_dataloader()
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        traceback.print_exc()
        sys.exit(1)

    return {
        "train_dataloader": dataset["train_dataloader"],
        "valid_dataloader": dataset["valid_dataloader"],
        "E1": E1,
        "E2": E2,
        "netG1": G1,
        "netG2": G2,
        "netD1": D1,
        "netD2": D2,
        "optimizerG": optimizerG,
        "optimizerD1": optimizerD1,
        "optimizerD2": optimizerD2,
        "criterion": criterion,
        "pixelLoss": pixelLoss,
        "kl_loss": kl_loss,
    }


if __name__ == "__main__":
    init = helper(
        lr=2e-4,
        beta1=0.5,
        beta2=0.999,
        momentum=0.95,
        adam=True,
        SGD=False,
    )

    train_dataloader, valid_dataloader = (
        init["train_dataloader"],
        init["valid_dataloader"],
    )
    encoder1, encoder2 = init["E1"], init["E2"]
    netG1, netG2 = init["netG1"], init["netG2"]
    netD1, netD2 = init["netD1"], init["netD2"]
    optimizerG, optimizerD1, optimizerD2 = (
        init["optimizerG"],
        init["optimizerD1"],
        init["optimizerD2"],
    )
    criterion, pixelLoss = init["criterion"], init["pixelLoss"]

    # (object, exact expected class, failure message) -- checked in order.
    expectations = [
        (train_dataloader, torch.utils.data.DataLoader, "Train dataloader shoould be torch.utils.data.DataLoader"),
        (valid_dataloader, torch.utils.data.DataLoader, "Valid dataloader shoould be torch.utils.data.DataLoader"),
        (encoder1, Encoder, "Encoder object should be Encoder class"),
        (encoder2, Encoder, "Encoder object should be Encoder class"),
        (netG1, Generator, "Generator object should be Generator class"),
        (netG2, Generator, "Generator object should be Generator class"),
        (netD1, Discriminator, "netD1 object should be Discriminator class"),
        (netD2, Discriminator, "netD2 object should be Discriminator class"),
        (optimizerG, optim.Adam, "optimizerG object should be Adam class"),
        (optimizerD1, optim.Adam, "optimizerD1 object should be Adam class"),
        (optimizerD2, optim.Adam, "optimizerD2 object should be Adam class"),
        (criterion, GANLoss, "Criterion object should be GANLoss class"),
        (pixelLoss, PixelLoss, "pixelLoss object should be PixelLoss class"),
    ]
    for component, expected_class, message in expectations:
        assert component.__class__ == expected_class, message.capitalize()

## Trainer

In [None]:
class Trainer:
    """Assemble the UNIT-GAN training components.

    ``__init__`` delegates construction of the dataloaders, encoders,
    generators, discriminators, optimizers and losses to ``helper``, moves
    every network to the device chosen by ``device_init``, and asserts that
    each component has the expected class. The ``*_regularizer`` methods
    compute optional weight penalties for a given model.
    """

    def __init__(
        self,
        epochs: int = 500,
        lr: float = 2e-5,
        beta1: float = 0.5,
        beta2: float = 0.999,
        momentum: float = 0.95,
        adam: bool = True,
        SGD: bool = False,
        device: str = "cuda",
        l1_regularization: bool = False,
        l2_regularization: bool = False,
        elasticNet_regularization: bool = False,
        verbose: bool = True,
        mlFlow: bool = True,
    ):
        self.epochs = epochs
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.momentum = momentum
        self.adam = adam
        self.SGD = SGD
        self.device = device
        self.l1_regularization = l1_regularization
        self.l2_regularization = l2_regularization
        self.elasticNet_regularization = elasticNet_regularization
        self.verbose = verbose
        self.mlFlow = mlFlow

        self.device = device_init(device=self.device)

        self.init = helper(
            lr=self.lr,
            beta1=self.beta1,
            beta2=self.beta2,
            momentum=self.momentum,
            adam=self.adam,
            SGD=self.SGD,
        )

        self.train_dataloader = self.init["train_dataloader"]
        self.valid_dataloader = self.init["valid_dataloader"]

        self.encoder1 = self.init["E1"].to(self.device)
        self.encoder2 = self.init["E2"].to(self.device)

        self.netG1 = self.init["netG1"].to(self.device)
        self.netG2 = self.init["netG2"].to(self.device)

        self.netD1 = self.init["netD1"].to(self.device)
        self.netD2 = self.init["netD2"].to(self.device)

        self.optimizerG = self.init["optimizerG"]
        self.optimizerD1 = self.init["optimizerD1"]
        self.optimizerD2 = self.init["optimizerD2"]

        self.criterion = self.init["criterion"]
        self.pixelLoss = self.init["pixelLoss"]
        # helper() may not supply a KL term; indexing it directly raised KeyError.
        self.kl_loss = self.init.get("kl_loss")
        if self.kl_loss is None:
            self.kl_loss = KLDivergence()

        self._check_components()

    def _check_components(self):
        """Assert that every component from ``helper`` has its exact expected class."""
        # The optimizer class follows the adam/SGD choice instead of always requiring Adam.
        expected_optimizer = optim.Adam if self.adam else optim.SGD
        optimizer_name = expected_optimizer.__name__

        checks = [
            (self.train_dataloader, torch.utils.data.DataLoader, "Train dataloader should be torch.utils.data.DataLoader"),
            (self.valid_dataloader, torch.utils.data.DataLoader, "Valid dataloader should be torch.utils.data.DataLoader"),
            (self.encoder1, Encoder, "Encoder object should be Encoder class"),
            (self.encoder2, Encoder, "Encoder object should be Encoder class"),
            (self.netG1, Generator, "Generator object should be Generator class"),
            (self.netG2, Generator, "Generator object should be Generator class"),
            (self.netD1, Discriminator, "netD1 object should be Discriminator class"),
            (self.netD2, Discriminator, "netD2 object should be Discriminator class"),
            (self.optimizerG, expected_optimizer, f"optimizerG object should be {optimizer_name} class"),
            (self.optimizerD1, expected_optimizer, f"optimizerD1 object should be {optimizer_name} class"),
            (self.optimizerD2, expected_optimizer, f"optimizerD2 object should be {optimizer_name} class"),
            (self.criterion, GANLoss, "Criterion object should be GANLoss class"),
            (self.pixelLoss, PixelLoss, "pixelLoss object should be PixelLoss class"),
            (self.kl_loss, KLDivergence, "KL Divergence object should be KLDivergence class"),
        ]
        for component, expected_class, message in checks:
            assert component.__class__ == expected_class, message.capitalize()

    def l1_regularizer(self, model):
        """Return the sum of the L1 norms of all ``model`` parameters.

        Raises:
            TypeError: if ``model`` is None.
        """
        if model is None:
            raise TypeError("Model should be passed in the l1 regularizer".capitalize())
        return sum(torch.norm(params, 1) for params in model.parameters())

    def l2_regularizer(self, model):
        """Return the sum of the (unsquared) L2 norms of all ``model`` parameters.

        Raises:
            TypeError: if ``model`` is None.
        """
        if model is None:
            raise TypeError("Model should be passed in the l2 regularizer".capitalize())
        return sum(torch.norm(params, 2) for params in model.parameters())

    def elasticNet_regularizer(self, model):
        """Return ``0.01 * (l1_regularizer(model) + l2_regularizer(model))``.

        Raises:
            TypeError: if ``model`` is None.
        """
        if model is None:
            raise TypeError(
                "Model should be passed in the elasticNet regularizer".capitalize()
            )
        # Call the regularizer methods, not the boolean *_regularization flags.
        l1 = self.l1_regularizer(model=model)
        l2 = self.l2_regularizer(model=model)
        return 0.01 * (l1 + l2)


if __name__ == "__main__":
    # Smoke test: constructing a Trainer builds every model, optimizer and loss
    # via helper() and runs the constructor's class assertions.
    trainer = Trainer()
