In [None]:
import sys
import os
import yaml
import zipfile
import joblib
import cv2
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import OrderedDict
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torchview import draw_graph

In [None]:
import warnings

warnings.filterwarnings("ignore")

In [None]:
def config():
    """Load the project settings from ``../../config.yml``.

    The path is relative, so it is resolved against the current working
    directory of the process.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file content.
    """
    with open("../../config.yml", "r") as stream:
        return yaml.safe_load(stream)

In [None]:
def dump(value = None, filename = None):
    """Persist ``value`` to ``filename`` with ``joblib.dump``.

    The call is a silent no-op when either argument is ``None``.
    """
    if value is None or filename is None:
        return

    joblib.dump(value=value, filename=filename)
        
def load(filename = None):
    """Return the object stored in ``filename`` via ``joblib.load``.

    Returns ``None`` without touching the disk when ``filename`` is ``None``.
    """
    if filename is None:
        return None

    return joblib.load(filename=filename)
    
    
def device_init(device = "mps"):
    """Resolve a device name to a ``torch.device``, falling back to CPU.

    Args:
        device: ``"mps"`` or ``"cuda"`` request that backend when it is
            available; any other value selects the CPU.

    Returns:
        The requested ``torch.device`` when its backend reports itself as
        available, otherwise ``torch.device("cpu")``.
    """
    if device == "mps":
        selected = "mps" if torch.backends.mps.is_available() else "cpu"
    elif device == "cuda":
        selected = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        selected = "cpu"

    return torch.device(selected)
    
def weights_init(m):
    """Re-initialise one module in place, chosen by its class name.

    Intended to be passed to ``nn.Module.apply``.

    * class name contains ``"Conv"``: weights drawn from N(0.0, 0.02)
    * class name contains ``"BatchNorm"``: weights drawn from N(1.0, 0.02)
      and bias set to 0
    * anything else is left untouched

    Args:
        m: The module to initialise.
    """
    layer_type = type(m).__name__

    if "Conv" in layer_type:
        nn.init.normal_(m.weight.data, mean=0.0, std=0.02)

    elif "BatchNorm" in layer_type:
        nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
        nn.init.constant_(m.bias.data, val=0.0)

In [None]:
class Loader():
    """Turn a folder of images into pickled ``DataLoader`` objects.

    The raw images are expected under ``<raw_path>/images/<category>/`` where
    ``raw_path`` comes from the project configuration. In paired mode the
    target images live in the ``y`` category and are matched to source images
    by file name.

    Args:
        image_path: Path of the zip archive extracted by ``unzip_folder``.
        image_size: Side length (pixels) every image is resized/cropped to.
        channels: Number of channels produced by the transform pipeline.
        batch_size: Batch size of the training dataloader; the test and the
            full dataloaders use 4x and 8x this value.
        split_size: Fraction of samples placed in the test split.
        paired_images: Pair images by file name (takes precedence).
        unpaired_images: Fill ``X``/``y`` from alternating category folders.
    """

    def __init__(self, image_path = None, image_size = 256, channels = 3, batch_size = 1, split_size = 0.20, paired_images = False, unpaired_images = True):
        self.image_path = image_path
        self.image_size = image_size
        self.channels = channels
        self.batch_size = batch_size
        self.split_size = split_size
        self.paired_image = paired_images
        self.unpaired_image = unpaired_images

        self.config = config()

        self.X = []
        self.y = []

    def transforms(self):
        """Build the preprocessing pipeline applied to every PIL image.

        Returns:
            A ``transforms.Compose`` that resizes, converts to a tensor,
            converts to grayscale with ``self.channels`` output channels,
            centre-crops and normalises with mean 0.5 / std 0.5 per channel.
        """
        # One mean/std entry per channel, so the pipeline follows ``channels``
        # instead of being fixed to three.
        per_channel = (0.5,) * self.channels

        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Grayscale(num_output_channels=self.channels),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.Normalize(per_channel, per_channel),
            ]
        )

    def data_split(self, X, y):
        """Split ``X``/``y`` into train and test parts (fixed ``random_state=42``).

        Returns:
            dict with keys ``X_train``, ``X_test``, ``y_train``, ``y_test``.
        """
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=self.split_size, random_state=42)

        return {
            "X_train": X_train,
            "X_test": X_test,
            "y_train": y_train,
            "y_test": y_test,
        }

    def unzip_folder(self):
        """Extract ``self.image_path`` into the configured raw-data directory.

        Raises:
            Exception: If the raw-data directory does not exist.
        """
        raw_path = self.config["path"]["raw_path"]

        if not os.path.exists(raw_path):
            raise Exception("Unable to unzip the folder as the path is not exists".capitalize())

        with zipfile.ZipFile(self.image_path, "r") as zip_ref:
            zip_ref.extractall(raw_path)

    def _load_tensor(self, image_path, transform):
        """Read one image with OpenCV, convert BGR to RGB and apply ``transform``."""
        image = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)

        return transform(Image.fromarray(image))

    def _collect_paired(self, images_root, transform):
        """Append every image whose file name exists in both the source folder and ``y``."""
        # The source folder is the first category that is not the target
        # folder "y"; the folder name is used as a whole.
        sources = [name for name in self.categories if name != "y"]

        if not sources:
            raise Exception("Unable to find the source images folder".capitalize())

        source_dir = os.path.join(images_root, sources[0])
        target_dir = os.path.join(images_root, "y")

        # List the target folder once; a set keeps the membership test cheap.
        target_names = set(os.listdir(target_dir))

        for image in os.listdir(source_dir):
            if image not in target_names:
                continue

            tensor_X = self._load_tensor(os.path.join(source_dir, image), transform)
            tensor_y = self._load_tensor(os.path.join(target_dir, image), transform)

            self.X.append(tensor_X)
            self.y.append(tensor_y)

    def _collect_unpaired(self, images_root, transform):
        """Append images from odd-indexed categories to ``X`` and even-indexed ones to ``y``."""
        # NOTE(review): the index comes from os.listdir order, which is not
        # guaranteed to be stable across machines - confirm this is intended.
        for index, category in enumerate(self.categories):
            folder = os.path.join(images_root, category)
            bucket = self.X if index % 2 else self.y

            for image in os.listdir(folder):
                bucket.append(self._load_tensor(os.path.join(folder, image), transform))

    def extract_features(self):
        """Load, transform and split every image under ``<raw_path>/images``.

        ``self.X`` and ``self.y`` are rebuilt from scratch on each call.

        Returns:
            The dict produced by ``data_split``.
        """
        self.directory = os.path.join(self.config["path"]["raw_path"])
        images_root = os.path.join(self.directory, "images")

        # Only sub-directories are categories; stray files next to them would
        # otherwise make the per-category os.listdir call fail.
        self.categories = [
            name
            for name in os.listdir(images_root)
            if os.path.isdir(os.path.join(images_root, name))
        ]

        # Start from empty lists so that calling this method twice (for
        # example directly and again through create_dataloader) does not
        # duplicate every sample.
        self.X = []
        self.y = []

        # The pipeline does not depend on the image, so build it once.
        transform = self.transforms()

        if self.paired_image:
            self._collect_paired(images_root, transform)

        elif self.unpaired_image:
            self._collect_unpaired(images_root, transform)

        return self.data_split(self.X, self.y)

    def create_dataloader(self):
        """Build the train/test/full dataloaders and pickle them.

        Writes ``train_dataloader.pkl``, ``test_dataloader.pkl`` and
        ``dataloader.pkl`` into the configured processed-data directory.

        Raises:
            Exception: If the processed-data directory does not exist.
        """
        processed_path = self.config["path"]["processed_path"]

        # Check the destination before doing the expensive image loading.
        if not os.path.exists(processed_path):
            raise Exception("Unable to create the pickle file".capitalize())

        self.data = self.extract_features()

        train_dataloader = DataLoader(
            dataset=list(zip(self.data["X_train"], self.data["y_train"])),
            batch_size=self.batch_size,
            shuffle=True,)

        test_dataloader = DataLoader(
            dataset=list(zip(self.data["X_test"], self.data["y_test"])),
            batch_size=self.batch_size*4,
            shuffle=True,)

        dataloader = DataLoader(
            dataset=list(zip(self.X, self.y)),
            batch_size=self.batch_size*8,
            shuffle=True,)

        for filename, value in (
            ("train_dataloader.pkl", train_dataloader),
            ("test_dataloader.pkl", test_dataloader),
            ("dataloader.pkl", dataloader),
        ):
            dump(value=value, filename=os.path.join(processed_path, filename))

    @staticmethod
    def plot_images():
        """Plot one batch of (X, y) pairs side by side and save it as ``images.png``.

        Raises:
            Exception: If the processed-data or the files directory is missing.
        """
        config_files = config()
        processed_path = config_files["path"]["processed_path"]

        if not os.path.exists(processed_path):
            raise Exception("Unable to plot the images as the path is not exists".capitalize())

        dataloader = load(os.path.join(processed_path, "dataloader.pkl"))

        X, y = next(iter(dataloader))

        # Each sample needs two cells (X and y). Keep the original 8x4 grid
        # for small batches and grow the row count when the batch is larger.
        columns = 4
        rows = max(8, -(-(2 * len(X)) // columns))

        plt.figure(figsize=(10, 10))

        for index, image in enumerate(X):
            image_X = image.squeeze().permute(1, 2, 0).cpu().detach().numpy()
            image_y = y[index].squeeze().permute(1, 2, 0).cpu().detach().numpy()

            # Min-max rescale to [0, 1] for display.
            image_X = (image_X - image_X.min())/(image_X.max() - image_X.min())
            image_y = (image_y - image_y.min())/(image_y.max() - image_y.min())

            plt.subplot(rows, columns, 2 * index + 1)
            plt.imshow(image_X)
            plt.axis("off")

            plt.subplot(rows, columns, 2 * index + 2)
            plt.imshow(image_y)
            plt.axis("off")

        plt.tight_layout()

        if os.path.exists(config_files["path"]["files_path"]):
            plt.savefig(os.path.join(config_files["path"]["files_path"], "images.png"))
        else:
            raise Exception("Unable to save the images as the path is not exists".capitalize())

        plt.show()

    @staticmethod
    def dataset_details():
        """Write sample counts, batch counts and tensor shapes to ``dataset_details.csv``.

        Raises:
            Exception: If the processed-data or the files directory is missing.
        """
        config_files = config()
        path = config_files["path"]["processed_path"]
        files_path = config_files["path"]["files_path"]

        if not os.path.exists(path):
            raise Exception("Unable to load the dataloaders as the processed path does not exist".capitalize())

        if not os.path.exists(files_path):
            raise Exception("Unable to save the dataset details as the path does not exist".capitalize())

        train_dataloader = load(filename=os.path.join(path, "train_dataloader.pkl"))
        test_dataloader = load(filename=os.path.join(path, "test_dataloader.pkl"))
        dataloader = load(filename=os.path.join(path, "dataloader.pkl"))

        # len(dataset) equals the number of samples the loader yields, without
        # iterating over every batch just to count them.
        pd.DataFrame(
            {
                "train_data(total)": str(len(train_dataloader.dataset)),
                "test_data(total)": str(len(test_dataloader.dataset)),
                "data(total)": str(len(dataloader.dataset)),
                "train_data(batch)": str(len(train_dataloader)),
                "test_data(batch)": str(len(test_dataloader)),
                "train_data(shape)": str(train_dataloader.dataset[0][0].shape),
                "test_data(shape)": str(test_dataloader.dataset[0][0].shape),
                "data(shape)": str(dataloader.dataset[0][0].shape),
            },
            index=["Details dataset".capitalize()],
        ).T.to_csv(os.path.join(files_path, "dataset_details.csv"))


if __name__ == "__main__":
    # NOTE(review): machine-specific absolute path; it is only read by
    # unzip_folder(), which is currently disabled below.
    loader = Loader(
        image_path="/Users/shahmuhammadraditrahman/Desktop/images.zip", paired_images=True)
    # loader.unzip_folder()
    # create_dataloader() runs extract_features() itself. Calling
    # extract_features() here as well appended every image to X/y twice.
    loader.create_dataloader()
    loader.dataset_details()
    loader.plot_images()

#### Create the Generator

In [None]:
class InputBlock(nn.Module):
    """First generator stage: 7x7 reflect-padded conv, InstanceNorm, ReLU.

    The convolution uses stride 1 and padding 3, so height and width are
    preserved; it has no bias term.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the convolution.
    """

    def __init__(self, in_channels = 3, out_channels = 64):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel, self.stride, self.padding = 7, 1, 3

        self.input_block = self.block()

    def block(self):
        """Assemble the named ``conv`` / ``instance_norm`` / ``ReLU`` stack."""
        convolution = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            padding_mode="reflect",
            bias=False,
        )
        normalisation = nn.InstanceNorm2d(num_features=self.out_channels)
        activation = nn.ReLU(inplace=True)

        return nn.Sequential(
            OrderedDict(
                [
                    ("conv", convolution),
                    ("instance_norm", normalisation),
                    ("ReLU", activation),
                ]
            )
        )

    def forward(self, x):
        """Apply the block; raises ``Exception`` when ``x`` is not a tensor."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Unable to process the input".capitalize())

        return self.input_block(x)
        
if __name__ == "__main__":
    in_channels = 3
    out_channels = 64
    config_files = config()

    input_block = InputBlock(
        in_channels=in_channels,
        out_channels=out_channels,
    )

    print(input_block(torch.randn(1, 3, 256, 256)).size())
    # torchsummary's summary() prints its own report and returns None, so
    # wrapping it in print() only added a stray "None" line.
    summary(model=input_block, input_size=(3, 256, 256))
    draw_graph(model=input_block, input_data=torch.randn(1, 3, 256, 256)).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netG_input_block"), format="jpeg"
    )

In [None]:
class EncoderBlock(nn.Module):
    """Down-sampling stage: 3x3 stride-2 conv, InstanceNorm, ReLU.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the convolution.
    """

    def __init__(self, in_channels = 64, out_channels = 128):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel, self.stride, self.padding = 3, 2, 1

        self.encoder_block = self.block()

    def block(self):
        """Assemble the named ``conv`` / ``instance_norm`` / ``ReLU`` stack."""
        stages = [
            (
                "conv",
                nn.Conv2d(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel,
                    stride=self.stride,
                    padding=self.padding,
                    bias=False,
                ),
            ),
            ("instance_norm", nn.InstanceNorm2d(num_features=self.out_channels)),
            ("ReLU", nn.ReLU(inplace=True)),
        ]

        return nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Apply the block; raises ``Exception`` when ``x`` is not a tensor."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Unable to process the input".capitalize())

        return self.encoder_block(x)


if __name__ == "__main__":
    in_channels = 64
    out_channels = 128
    num_repetitive = 2

    # Load the configuration here so this cell does not rely on a variable
    # left behind by an earlier cell.
    config_files = config()

    layers = []

    for _ in tqdm(range(num_repetitive)):
        layers.append(EncoderBlock(in_channels=in_channels, out_channels=out_channels))

        in_channels = out_channels
        out_channels *= 2

    model = nn.Sequential(*layers)

    print(model(torch.randn(1, 64, 256, 256)).size())
    # torchsummary's summary() prints its own report and returns None, so
    # wrapping it in print() only added a stray "None" line.
    summary(model=model, input_size=(64, 256, 256))
    draw_graph(model=model, input_data=torch.randn(1, 64, 256, 256)).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netG_encoder_block"), format="jpeg"
    )

In [None]:
class ResidualBlock(nn.Module):
    """Residual stage: ``x + F(x)`` where ``F`` is conv-norm-ReLU-conv-norm.

    Both convolutions are 3x3 with padding 1 and no bias, and both are built
    with the same ``in_channels``/``out_channels`` pair.

    Args:
        in_channels: Input channels of each convolution.
        out_channels: Output channels of each convolution.
    """

    def __init__(self, in_channels = 256, out_channels = 256):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel, self.stride, self.padding = 3, 1, 1

        self.residual_block = self.block()

    def block(self):
        """Assemble ``conv1, instance_norm1, ReLU, conv2, instance_norm2``."""
        named_layers = []

        for position in (1, 2):
            convolution = nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel,
                padding=self.padding,
                bias=False,
            )
            named_layers.append(("conv{}".format(position), convolution))
            named_layers.append(
                (
                    "instance_norm{}".format(position),
                    nn.InstanceNorm2d(num_features=self.out_channels),
                )
            )

            # Only the first conv/norm pair is followed by an activation.
            if position == 1:
                named_layers.append(("ReLU", nn.ReLU(inplace=True)))

        return nn.Sequential(OrderedDict(named_layers))

    def forward(self, x):
        """Return ``x`` plus the block output; raises ``Exception`` for non-tensors."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Unable to process the input".capitalize())

        return x + self.residual_block(x)

if __name__ == "__main__":
    in_channels = 256
    num_repetitive = 9

    # Load the configuration here so this cell does not rely on a variable
    # left behind by an earlier cell.
    config_files = config()

    layers = []

    for idx in tqdm(range(num_repetitive)):
        layers+=[
            ResidualBlock(
            in_channels=in_channels, out_channels=in_channels)
        ]

    model = nn.Sequential(*layers)
    print(model(torch.randn(1, 256, 64, 64)).size())
    # torchsummary's summary() prints its own report and returns None, so
    # wrapping it in print() only added a stray "None" line.
    summary(model=model, input_size=(256, 64, 64))
    draw_graph(model=model, input_data=torch.randn(1, 256, 64, 64)).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netG_residual_block"), format="jpeg"
    )

In [None]:
class DecoderBlock(nn.Module):
    """Up-sampling stage: 3x3 stride-2 transposed conv, InstanceNorm, ReLU.

    The transposed convolution uses padding 1 and output padding 1 and has
    no bias term.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the transposed convolution.
    """

    def __init__(self, in_channels = 256, out_channels = 128):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel, self.stride, self.padding = 3, 2, 1
        self.output_padding = 1

        self.decoder_block = self.block()

    def block(self):
        """Assemble the named ``convTranspose`` / ``instance_norm`` / ``ReLU`` stack."""
        upsample = nn.ConvTranspose2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
            bias=False,
        )

        return nn.Sequential(
            OrderedDict(
                convTranspose=upsample,
                instance_norm=nn.InstanceNorm2d(num_features=self.out_channels),
                ReLU=nn.ReLU(inplace=True),
            )
        )

    def forward(self, x):
        """Apply the block; raises ``Exception`` when ``x`` is not a tensor."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Unable to process the input".capitalize())

        return self.decoder_block(x)


if __name__ == "__main__":
    in_channels = 256
    out_channels = 128
    num_repetitive = 2

    # Load the configuration here so this cell does not rely on a variable
    # left behind by an earlier cell.
    config_files = config()

    layers = []

    for _ in tqdm(range(num_repetitive)):
        layers.append(
            DecoderBlock(in_channels=in_channels, out_channels=out_channels)
        )
        in_channels = out_channels
        out_channels //= 2

    model = nn.Sequential(*layers)

    print(model(torch.randn(1, 256, 64, 64)).size())
    # torchsummary's summary() prints its own report and returns None, so
    # wrapping it in print() only added a stray "None" line.
    summary(model=model, input_size=(256, 64, 64))
    draw_graph(model=model, input_data=torch.randn(1, 256, 64, 64)).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netG_decoder_block"), format="jpeg"
    )

In [None]:
class Generator(nn.Module):
    """Image-to-image generator built from the blocks defined in this notebook.

    Layout: one ``InputBlock`` (to 64 channels), two ``EncoderBlock`` stages
    (64 -> 128 -> 256), nine ``ResidualBlock`` stages at 256 channels, two
    ``DecoderBlock`` stages (256 -> 128 -> 64) and a final 7x7 convolution
    followed by ``Tanh``.

    The output has the same number of channels as the input, so two
    generators created with the same ``in_channels`` can be chained.

    Args:
        in_channels: Channels of the input image (and of the generated image).

    Attributes:
        in_channels: The constructor argument, unchanged.
        out_channels: Width of the first feature map (64).
        layers: Python list of the stages wrapped by ``model``.
        model: ``nn.Sequential`` of every stage except the output head.
        output: ``nn.Sequential`` holding the final convolution and ``Tanh``.
    """

    def __init__(self, in_channels = 3):
        super(Generator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = 64

        self.kernel = 7
        self.stride = 1
        self.padding = 3

        # Track the running channel count in a local so the public
        # attributes keep describing the network rather than loop state.
        width = self.out_channels

        self.layers = [InputBlock(in_channels=in_channels, out_channels=width)]

        # Each encoder stage doubles the channel count: 64 -> 128 -> 256.
        for _ in tqdm(range(2)):
            self.layers.append(EncoderBlock(in_channels=width, out_channels=width * 2))
            width *= 2

        # Residual stages keep the channel count unchanged.
        for _ in tqdm(range(9)):
            self.layers.append(ResidualBlock(in_channels=width, out_channels=width))

        # Each decoder stage halves the channel count: 256 -> 128 -> 64.
        for _ in tqdm(range(2)):
            self.layers.append(DecoderBlock(in_channels=width, out_channels=width // 2))
            width //= 2

        self.model = nn.Sequential(*self.layers)

        self.output = nn.Sequential(
            nn.Conv2d(
                in_channels=width,
                # Emit as many channels as the input image has; this used to
                # be fixed to 3 regardless of ``in_channels``.
                out_channels=in_channels,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
                bias=False,
            ),
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate an image from ``x``.

        Raises:
            Exception: If ``x`` is not a ``torch.Tensor``.
        """
        if isinstance(x, torch.Tensor):
            x = self.model(x)
            return self.output(x)

        else:
            raise Exception("Unable to process the input".capitalize())


if __name__ == "__main__":
    in_channels = 3

    # Load the configuration here so this cell does not rely on a variable
    # left behind by an earlier cell.
    config_files = config()

    netG = Generator(in_channels=in_channels)

    print(netG(torch.randn((1, 3, 256, 256))).size())

    # torchsummary's summary() prints its own report and returns None, so
    # wrapping it in print() only added a stray "None" line.
    summary(model=netG, input_size=(3, 256, 256))

    draw_graph(model=netG, input_data=torch.randn((1, 3, 256, 256))).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netG_model"), format="jpeg"
    )

#### Create the Discriminator 

In [None]:
class DiscriminatorBlock(nn.Module):
    """Discriminator stage: 4x4 conv, optional InstanceNorm, LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the convolution.
        use_norm: Insert an ``InstanceNorm2d`` after the convolution.
        stride: Stride of the convolution.
    """

    def __init__(self, in_channels=3, out_channels=64, use_norm=False, stride=2):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.use_norm = use_norm

        self.kernel, self.stride, self.padding = 4, stride, 1

        self.discriminator_block = self.block()

    def block(self):
        """Assemble ``conv``, the optional ``instance_norm`` and ``lRelU``."""
        named_layers = [
            (
                "conv",
                nn.Conv2d(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel,
                    stride=self.stride,
                    padding=self.padding,
                    bias=False,
                ),
            )
        ]

        if self.use_norm:
            named_layers.append(
                ("instance_norm", nn.InstanceNorm2d(num_features=self.out_channels))
            )

        named_layers.append(("lRelU", nn.LeakyReLU(negative_slope=0.2, inplace=True)))

        return nn.Sequential(OrderedDict(named_layers))

    def forward(self, x):
        """Apply the block; raises ``Exception`` when ``x`` is not a tensor."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Unable to process the input".capitalize())

        return self.discriminator_block(x)


if __name__ == "__main__":
    in_channels = 3
    out_channels = 64
    num_repetitive = 4

    config_files = config()

    layers = []

    for idx in tqdm(range(num_repetitive)):
        layers.append(
            DiscriminatorBlock(
                in_channels=in_channels,
                out_channels=out_channels,
                # The first stage has no normalisation; the last one keeps
                # the spatial size by using stride 1.
                use_norm=False if idx == 0 else True,
                stride=1 if idx == num_repetitive - 1 else 2,
            )
        )
        in_channels = out_channels
        out_channels *= 2

    model = nn.Sequential(*layers)

    print(model(torch.randn((1, 3, 256, 256))).size())

    # torchsummary's summary() prints its own report and returns None, so
    # wrapping it in print() only added a stray "None" line.
    summary(model = model, input_size=(3, 256, 256))

    draw_graph(model = model, input_data=torch.randn((1, 3, 256, 256))).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netD_discriminator_block"), format="jpeg"
    )

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator: four ``DiscriminatorBlock`` stages plus a 1-channel head.

    The first stage has no normalisation, the last stage uses stride 1 and
    the channel count doubles after every stage. The head is a 4x4
    convolution (stride 1, padding 1, no bias) producing a single channel.

    Note that ``in_channels`` and ``out_channels`` are advanced while the
    stages are built, so after construction they no longer hold the
    constructor arguments.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels produced by the first stage.
    """

    def __init__(self, in_channels = 3, out_channels = 64):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel, self.stride, self.padding = 4, 1, 1

        self.layers = []
        final_stage = 3

        for idx in tqdm(range(4)):
            stage = DiscriminatorBlock(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                use_norm=idx != 0,
                stride=2 if idx != final_stage else 1,
            )
            self.layers.append(stage)

            # The next stage consumes what this one produced and doubles it.
            self.in_channels, self.out_channels = self.out_channels, self.out_channels * 2

        self.model = nn.Sequential(*self.layers)

        # ``out_channels`` was doubled once more after the last stage, hence // 2.
        head = nn.Conv2d(
            in_channels=self.out_channels//2,
            out_channels=1,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            bias=False,
        )
        self.output = nn.Sequential(head)

    def forward(self, x):
        """Score ``x``; raises ``Exception`` when ``x`` is not a tensor."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Unable to process the input".capitalize())

        return self.output(self.model(x))


if __name__ == "__main__":
    in_channels = 3

    # Load the configuration here so this cell does not rely on a variable
    # left behind by an earlier cell.
    config_files = config()

    netD = Discriminator(in_channels=in_channels)

    print(netD(torch.randn(1, 3, 256, 256)).size())
    # torchsummary's summary() prints its own report and returns None, so
    # wrapping it in print() only added a stray "None" line.
    summary(model = netD, input_size=(3, 256, 256))
    draw_graph(model=netD, input_data=torch.randn((1, 3, 256, 256))).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netD_model"), format="jpeg"
    )

#### Define the Loss

In [None]:
class GANLoss(nn.Module):
    """Adversarial loss implemented as mean-squared error (``nn.MSELoss``).

    Args:
        reduction: Reduction mode forwarded to ``nn.MSELoss``.
    """

    def __init__(self, reduction="mean"):
        super().__init__()

        self.name = "GANLoss".title()
        self.reduction = reduction

        self.adversarial_loss = nn.MSELoss(reduction=reduction)

    def forward(self, actual, pred):
        """Return the MSE between ``actual`` and ``pred``.

        Raises:
            Exception: If either argument is not a ``torch.Tensor``.
        """
        both_tensors = isinstance(actual, torch.Tensor) and isinstance(pred, torch.Tensor)

        if not both_tensors:
            raise Exception("Unable to process the input".capitalize())

        return self.adversarial_loss(actual, pred)


if __name__ == "__main__":
    # Identical tensors, so the printed loss is zero.
    loss = GANLoss()

    actual = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0])
    pred = actual.clone()

    print(loss(actual, pred))

In [None]:
class CycleLoss(nn.Module):
    """Cycle-consistency loss implemented as L1 distance (``nn.L1Loss``).

    Args:
        reduction: Reduction mode forwarded to ``nn.L1Loss``.
    """

    def __init__(self, reduction = "mean"):
        super().__init__()

        self.name = "CycleLoss".title()
        self.reduction = reduction

        self.cycle_loss = nn.L1Loss(reduction=reduction)

    def forward(self, actual, pred):
        """Return the L1 distance between ``actual`` and ``pred``.

        Raises:
            Exception: If either argument is not a ``torch.Tensor``.
        """
        for candidate in (actual, pred):
            if not isinstance(candidate, torch.Tensor):
                raise Exception("Unable to process the input".capitalize())

        return self.cycle_loss(actual, pred)
        
        
if __name__ == "__main__":
    # Identical tensors, so the printed loss is zero.
    loss = CycleLoss()

    pred = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0])
    actual = pred.clone()

    print(loss(actual, pred))

In [None]:
class PixelLoss(nn.Module):
    """Pixel-wise loss implemented as L1 distance (``nn.L1Loss``).

    Args:
        reduction: Reduction mode forwarded to ``nn.L1Loss``.
    """

    def __init__(self, reduction = "mean"):
        super().__init__()

        self.name = "PixelLoss".title()
        self.reduction = reduction

        self.pixel_loss = nn.L1Loss(reduction=reduction)

    def forward(self, actual, pred):
        """Return the L1 distance between ``actual`` and ``pred``.

        Raises:
            Exception: If either argument is not a ``torch.Tensor``.
        """
        if not all(isinstance(item, torch.Tensor) for item in (actual, pred)):
            raise Exception("Unable to process the input".capitalize())

        return self.pixel_loss(actual, pred)
        
if __name__ == "__main__":
    # Identical tensors, so the printed loss is zero.
    loss = PixelLoss()

    values = [1.0, 0.0, 1.0, 0.0, 1.0]
    actual, pred = torch.tensor(values), torch.tensor(values)

    print(loss(actual, pred))

#### Define the helper method

In [None]:
def load_dataloader():
    """Load the three pickled dataloaders from the processed-data directory.

    Returns:
        dict with keys ``train_dataloader``, ``test_dataloader`` and
        ``dataloader``, each holding the object read from ``<key>.pkl``.

    Raises:
        Exception: If the processed-data directory does not exist.
    """
    # Parse the configuration once instead of once per lookup.
    path = config()["path"]["processed_path"]

    if not os.path.exists(path):
        raise Exception("Can't load dataloader".capitalize())

    return {
        name: load(filename=os.path.join(path, "{}.pkl".format(name)))
        for name in ("train_dataloader", "test_dataloader", "dataloader")
    }


def helpers(**kwargs):
    """Create the networks, optimizers, losses and dataloaders used for training.

    Keyword Args:
        in_channels: Image channels passed to both generators and both
            discriminators.
        lr: Learning rate of every optimizer.
        adam: Use ``optim.Adam`` with betas ``(0.5, 0.999)`` (takes precedence).
        SGD: Use ``optim.SGD`` when ``adam`` is falsy.

    Returns:
        dict with the two generators (``netG_XtoY``, ``netG_YtoX``), the two
        discriminators (``netD_X``, ``netD_Y``), the three optimizers, the
        three loss modules and the three dataloaders.

    Raises:
        KeyError: If one of the keyword arguments is missing.
        ValueError: If neither ``adam`` nor ``SGD`` is truthy.
    """
    in_channels = kwargs["in_channels"]
    lr = kwargs["lr"]
    adam = kwargs["adam"]
    SGD = kwargs["SGD"]

    # Without this check the function failed only at the very end with an
    # UnboundLocalError, after building four networks and loading the data.
    if not (adam or SGD):
        raise ValueError("Unable to create the optimizers: enable either adam or SGD")

    out_channels = 64

    netG_XtoY = Generator(in_channels=in_channels)
    netG_YtoX = Generator(in_channels=in_channels)

    netD_X = Discriminator(in_channels=in_channels, out_channels=out_channels)
    netD_Y = Discriminator(in_channels=in_channels, out_channels=out_channels)

    # Both generators are updated by one shared optimizer.
    generator_params = list(netG_XtoY.parameters()) + list(netG_YtoX.parameters())

    if adam:
        optimizerG = optim.Adam(params=generator_params, lr=lr, betas=(0.5, 0.999))
        optimizerD_X = optim.Adam(params=netD_X.parameters(), lr=lr, betas=(0.5, 0.999))
        optimizerD_Y = optim.Adam(params=netD_Y.parameters(), lr=lr, betas=(0.5, 0.999))

    else:
        # NOTE(review): the generator optimizer has no momentum while both
        # discriminator optimizers use 0.95 - confirm this is intentional.
        optimizerG = optim.SGD(params=generator_params, lr=lr)
        optimizerD_X = optim.SGD(params=netD_X.parameters(), lr=lr, momentum=0.95)
        optimizerD_Y = optim.SGD(params=netD_Y.parameters(), lr=lr, momentum=0.95)

    adversarial_loss = GANLoss()
    cycle_loss = CycleLoss()
    pixel_loss = PixelLoss()

    dataloader = load_dataloader()

    return {
        "netG_XtoY": netG_XtoY,
        "netG_YtoX": netG_YtoX,
        "netD_X": netD_X,
        "netD_Y": netD_Y,
        "optimizerG": optimizerG,
        "optimizerD_X": optimizerD_X,
        "optimizerD_Y": optimizerD_Y,
        "adversarial_loss": adversarial_loss,
        "cycle_loss": cycle_loss,
        "pixel_loss": pixel_loss,
        "dataloader": dataloader["dataloader"],
        "train_dataloader": dataloader["train_dataloader"],
        "test_dataloader": dataloader["test_dataloader"],
    }


if __name__ == "__main__":
    # Build everything once and print each component in a fixed order.
    init = helpers(lr=0.0002, adam=True, SGD=False, in_channels=3)

    for key in (
        "netG_XtoY",
        "netG_YtoX",
        "netD_X",
        "netD_Y",
        "optimizerG",
        "optimizerD_X",
        "optimizerD_Y",
        "adversarial_loss",
        "cycle_loss",
        "pixel_loss",
        "dataloader",
        "train_dataloader",
        "test_dataloader",
    ):
        print(init[key])

#### Trainer

In [None]:
class Trainer:
    def __init__(
        self,
        in_channels=3,
        epochs=500,
        lr=0.0002,
        device="mps",
        adam=True,
        SGD=False,
        lr_scheduler=False,
        is_display=True,
        is_weight_init=False,
        is_save_image=True,
    ):
        """
        Collect every object a training run needs.

        Args:
            in_channels: forwarded to ``helpers`` when the networks are built.
            epochs: number of epochs ``train`` iterates over.
            lr: learning rate forwarded to ``helpers``.
            device: device name resolved through ``device_init``.
            adam: forwarded to ``helpers`` (selects the Adam optimizers).
            SGD: forwarded to ``helpers`` (selects the SGD optimizers).
            lr_scheduler: when truthy, a ``StepLR`` (step_size=10, gamma=0.85)
                is attached to each of the three optimizers.
            is_display: when truthy, ``show_progress`` prints the mean losses.
            is_weight_init: when truthy, ``weights_init`` is applied to all
                four networks.
            is_save_image: when truthy, ``train`` may call ``saved_train_images``.
        """
        self.in_channels = in_channels
        self.epochs = epochs
        self.lr = lr
        self.adam = adam
        self.SGD = SGD
        self.lr_scheduler = lr_scheduler
        self.is_display = is_display
        self.is_weight_init = is_weight_init
        self.is_save_image = is_save_image
        self.device = device

        self.init = helpers(
            lr=self.lr, adam=self.adam, SGD=self.SGD, in_channels=self.in_channels
        )

        # From here on self.device holds the resolved device, not the name.
        self.device = device_init(device=self.device)

        # Expose every component built by helpers() under an attribute of the
        # same name (networks, optimizers, criteria and dataloaders).
        for component in (
            "netG_XtoY",
            "netG_YtoX",
            "netD_X",
            "netD_Y",
            "optimizerG",
            "optimizerD_X",
            "optimizerD_Y",
            "adversarial_loss",
            "cycle_loss",
            "pixel_loss",
            "dataloader",
            "train_dataloader",
            "test_dataloader",
        ):
            setattr(self, component, self.init[component])

        networks = (self.netG_XtoY, self.netG_YtoX, self.netD_X, self.netD_Y)

        for network in networks:
            network.to(self.device)

        if self.is_weight_init:
            for network in networks:
                network.apply(weights_init)

        if self.lr_scheduler:
            self.schedulerG, self.schedulerD_X, self.schedulerD_Y = [
                StepLR(optimizer=optimizer, step_size=10, gamma=0.85)
                for optimizer in (
                    self.optimizerG,
                    self.optimizerD_X,
                    self.optimizerD_Y,
                )
            ]

        # Per-epoch mean losses, appended to by train().
        self.total_netG_loss = []
        self.total_netD_X_loss = []
        self.total_netD_Y_loss = []

        self.history = {key: [] for key in ("G_loss", "D_X_loss", "D_Y_loss")}

        # Lowest mean generator loss seen so far (see saved_train_best_model).
        self.loss = float("inf")

        self.config = config()

    def l1(self, model):
        """
        Return 0.01 times the sum of the L1 norms of ``model``'s parameters.

        Raises:
            Exception: if ``model`` is not a ``Generator`` instance.
        """
        if not isinstance(model, Generator):
            raise Exception(
                "Cannot able to use L1 regularization with Generator".capitalize()
            )

        penalty = sum(torch.norm(weights, 1) for weights in model.parameters())
        return 0.01 * penalty

    def l2(self, model):
        """
        Return 0.01 times the sum of the L2 norms of ``model``'s parameters.

        Raises:
            Exception: if ``model`` is not a ``Generator`` instance.
        """
        if not isinstance(model, Generator):
            raise Exception(
                "Cannot able to use L2 regularization with Generator".capitalize()
            )

        penalty = sum(torch.norm(weights, 2) for weights in model.parameters())
        return 0.01 * penalty

    def elastic_loss(self, model):
        """
        Return 0.01 times the sum of the ``l1`` and ``l2`` penalties of ``model``.

        Both penalties already carry their own 0.01 factor, so the raw norms
        end up scaled by 1e-4.

        Raises:
            Exception: if ``model`` is not a ``Generator`` instance.
        """
        if not isinstance(model, Generator):
            raise Exception(
                "Cannot able to use elastic regularization with Generator".capitalize()
            )

        return 0.01 * (self.l1(model=model) + self.l2(model=model))

    def saved_checkpoints_netG_XtoY(self, epoch=None):
        if os.path.exists(self.config["path"]["netG_XtoY_path"]):
            torch.save(
                self.netG_XtoY.state_dict(),
                os.path.join(
                    self.config["path"]["netG_XtoY_path"],
                    "netG_XtoY{}.pth".format(epoch),
                ),
            )

        else:
            raise Exception("Cannot able to save the netG_XtoY model".capitalize())

    def saved_checkpoints_netG_YtoX(self, epoch=None):
        if os.path.exists(self.config["path"]["netG_YtoX_path"]):
            torch.save(
                self.netG_YtoX.state_dict(),
                os.path.join(
                    self.config["path"]["netG_YtoX_path"],
                    "netG_YtoX{}.pth".format(epoch),
                ),
            )

        else:
            raise Exception("Cannot able to save the netG_XtoY model".capitalize())

    def saved_train_best_model(self, **kwargs):
        if os.path.exists(self.config["path"]["best_model_path"]):
            path = self.config["path"]["best_model_path"]

            if self.loss > np.mean(kwargs["netG_loss"]):
                self.loss = np.mean(kwargs["netG_loss"])

                torch.save(
                    {
                        "netG_XtoY": self.netG_XtoY.state_dict(),
                        "netG_YtoX": self.netG_YtoX.state_dict(),
                        "epoch": kwargs["epoch"],
                        "loss": np.mean(kwargs["netG_loss"]),
                    },
                    os.path.join(path, "best_model.pth"),
                )

        else:
            raise Exception("Cannot able to save the best_model".capitalize())

    def saved_model_history(self, **kwargs):
        if os.path.exists(self.config["path"]["files_path"]):
            path = self.config["path"]["files_path"]
            pd.DataFrame(
                {
                    "netG_loss": kwargs["netG_loss"],
                    "netD_X_loss": kwargs["netD_X_loss"],
                    "netD_Y_loss": kwargs["netD_Y_loss"],
                }
            ).to_csv(os.path.join(path, "model_history.csv"))

        else:
            raise Exception("Cannot be saved the model history".capitalize())

    def saved_train_images(self, **kwargs):
        """
        Save a preview of the generators' output for the given epoch.

        The first batch of the pickled ``train_dataloader.pkl`` (read from
        ``config["path"]["processed_path"]``) is translated X -> Y and back to
        X; both results are written to ``config["path"]["train_results"]`` as
        ``predict_y{epoch}.png`` and ``reconstructed_x{epoch}.png``.

        Keyword Args:
            epoch: epoch number used in the file names.

        Raises:
            Exception: if the processed-data directory does not exist.
        """
        if os.path.exists(self.config["path"]["processed_path"]):
            path = self.config["path"]["processed_path"]

            X, _ = next(iter(load(filename=os.path.join(path, "train_dataloader.pkl"))))

            # The preview is only written to disk, so skip building an
            # autograd graph for the two forward passes.
            with torch.no_grad():
                predict_y = self.netG_XtoY(X.to(self.device))
                reconstructed_x = self.netG_YtoX(predict_y)

            for name, batch in (
                ("predict_y", predict_y),
                ("reconstructed_x", reconstructed_x),
            ):
                save_image(
                    batch,
                    os.path.join(
                        self.config["path"]["train_results"],
                        name + "{}.png".format(kwargs["epoch"]),
                    ),
                    nrow=1,
                )
        else:
            raise Exception("Cannot be saved the processed images".capitalize())

    def show_progress(self, **kwargs):
        if self.is_display:
            print(
                "Epochs: [{}/{}] - netG_loss: [{:.4f}] - netD_X_loss: {:.4f} - netD_Y_loss: {:.4f}".format(
                    kwargs["epoch"],
                    self.epochs,
                    np.mean(kwargs["netG_loss"]),
                    np.mean(kwargs["netD_X_loss"]),
                    np.mean(kwargs["netD_Y_loss"]),
                )
            )
        else:
            print(
                "Epochs: [{}/{}] is completed".capitalize().format(
                    kwargs["epochs"], self.epochs
                )
            )

    def update_train_netG(self, **kwargs):
        self.optimizerG.zero_grad()

        fake_y = self.netG_XtoY(kwargs["X"])
        fake_y_predict = self.netD_Y(fake_y)
        fake_y_loss = self.adversarial_loss(
            fake_y_predict, torch.ones_like(fake_y_predict)
        )

        fake_x = self.netG_YtoX(kwargs["y"])
        real_x_predict = self.netD_X(fake_x)
        fake_x_loss = self.adversarial_loss(
            real_x_predict, torch.ones_like(real_x_predict)
        )

        reconstructed_x = self.netG_YtoX(fake_y)
        reconstructed_x_loss = self.cycle_loss(kwargs["X"], reconstructed_x)

        reconstructed_y = self.netG_XtoY(fake_x)
        reconstructed_y_loss = self.cycle_loss(kwargs["y"], reconstructed_y)

        pixel_loss_y = self.pixel_loss(kwargs["y"], fake_y)
        pixel_loss_x = self.pixel_loss(kwargs["X"], fake_x)

        total_G_loss = (
            (0.5 * (fake_y_loss + fake_x_loss))
            + (0.5 * (reconstructed_x_loss + reconstructed_y_loss))
            + (0.5 * (pixel_loss_x + pixel_loss_y))
        )

        total_G_loss.backward()
        self.optimizerG.step()

        return total_G_loss.item()

    def update_train_netD_X(self, **kwargs):
        self.optimizerD_X.zero_grad()

        fake_x = self.netG_YtoX(kwargs["y"])
        real_x_predict = self.netD_X(kwargs["X"])
        fake_x_predict = self.netD_X(fake_x)

        real_x_loss = self.adversarial_loss(
            real_x_predict, torch.ones_like(real_x_predict)
        )
        fake_x_loss = self.adversarial_loss(
            fake_x_predict, torch.zeros_like(fake_x_predict)
        )

        d_x_loss = (real_x_loss + fake_x_loss) / 2

        d_x_loss.backward()
        self.optimizerD_X.step()

        return d_x_loss.item()

    def update_train_netD_Y(self, **kwargs):
        self.optimizerD_Y.zero_grad()

        fake_y = self.netG_XtoY(kwargs["X"])
        real_y_predict = self.netD_Y(kwargs["y"])
        fake_y_predict = self.netD_Y(fake_y)

        real_y_loss = self.adversarial_loss(
            real_y_predict, torch.ones_like(real_y_predict)
        )
        fake_y_loss = self.adversarial_loss(
            fake_y_predict, torch.zeros_like(fake_y_predict)
        )

        d_y_loss = (real_y_loss + fake_y_loss) / 2

        d_y_loss.backward()
        self.optimizerD_Y.step()

        return d_y_loss.item()

    def train(self):
        """
        Run the training loop and persist its artefacts.

        For each epoch every batch of ``self.train_dataloader`` is moved to
        ``self.device`` and used to update ``netD_X``, ``netD_Y`` and then the
        generators. A batch that raises is reported and skipped. After each
        epoch the progress line is printed, both generator checkpoints and
        (when improved) the best model are saved, the epoch means are recorded
        and, if enabled, the LR schedulers are stepped. Every 50th epoch a
        preview image pair is saved when ``self.is_save_image`` is truthy.

        After the last epoch the loss curves are copied into ``self.history``,
        written to CSV via ``saved_model_history`` and pickled into
        ``config["path"]["metrics_path"]``; failures there are only printed.
        """
        warnings.filterwarnings("ignore")

        # Preview images are written once every this many epochs.
        image_interval = 50

        for epoch in tqdm(range(self.epochs)):
            netG_loss = []
            netD_X_loss = []
            netD_Y_loss = []

            for idx, (X, y) in enumerate(self.train_dataloader):
                try:
                    X = X.to(self.device)
                    y = y.to(self.device)

                    netD_X_loss.append(self.update_train_netD_X(X=X, y=y))
                    netD_Y_loss.append(self.update_train_netD_Y(X=X, y=y))
                    netG_loss.append(self.update_train_netG(X=X, y=y))

                except Exception as e:
                    # Best effort: a failing batch is reported and skipped.
                    print(f"An error occurred during the training process: {e}")
                    continue

            try:
                self.show_progress(
                    netG_loss=netG_loss,
                    netD_X_loss=netD_X_loss,
                    netD_Y_loss=netD_Y_loss,
                    epoch=epoch + 1,
                )

                self.saved_checkpoints_netG_XtoY(epoch=epoch + 1)
                self.saved_checkpoints_netG_YtoX(epoch=epoch + 1)
                self.saved_train_best_model(epoch=epoch + 1, netG_loss=netG_loss)

            except Exception as e:
                print(
                    f"An error occurred while saving checkpoints or updating progress: {e}"
                )

            self.total_netG_loss.append(np.mean(netG_loss))
            self.total_netD_X_loss.append(np.mean(netD_X_loss))
            self.total_netD_Y_loss.append(np.mean(netD_Y_loss))

            # The previous test `(epoch + 1) % 50 and ...` was truthy on every
            # epoch that is NOT a multiple of 50, i.e. the inverse of a
            # periodic snapshot; compare the remainder with zero explicitly.
            if (epoch + 1) % image_interval == 0 and self.is_save_image:
                self.saved_train_images(epoch=epoch + 1)

            if self.lr_scheduler:
                self.schedulerG.step()
                self.schedulerD_X.step()
                self.schedulerD_Y.step()

        try:
            self.history["G_loss"].extend(self.total_netG_loss)
            self.history["D_X_loss"].extend(self.total_netD_X_loss)
            self.history["D_Y_loss"].extend(self.total_netD_Y_loss)

            self.saved_model_history(
                netG_loss=self.total_netG_loss,
                netD_X_loss=self.total_netD_X_loss,
                netD_Y_loss=self.total_netD_Y_loss,
            )

            if os.path.exists(self.config["path"]["metrics_path"]):
                for name, losses in (
                    ("netG", self.total_netG_loss),
                    ("netD_X", self.total_netD_X_loss),
                    ("netD_Y", self.total_netD_Y_loss),
                ):
                    dump(
                        value=losses,
                        filename=os.path.join(
                            self.config["path"]["metrics_path"], name + ".pkl"
                        ),
                    )

            else:
                raise Exception("Cannot be saved the metrics".capitalize())
        except Exception as e:
            print(f"An error occurred while saving the training history: {e}")

    @staticmethod
    def plot_history():
        """
        Plot the pickled loss curves side by side and save the figure.

        Reads ``netG.pkl``, ``netD_X.pkl`` and ``netD_Y.pkl`` from
        ``config["path"]["metrics_path"]``, draws each one that exists in its
        own panel of a 1x3 grid, writes ``model_history.jpeg`` to the same
        directory and shows the figure. Missing pickles are reported on stdout.

        Raises:
            Exception: if the metrics directory does not exist.
        """
        config_files = config()

        if not os.path.exists(config_files["path"]["metrics_path"]):
            raise Exception("Cannot be open the metrics files".capitalize())

        path = config_files["path"]["metrics_path"]

        panels = (
            ("netG.pkl", "Generator Loss"),
            ("netD_X.pkl", "Discriminator X Loss"),
            ("netD_Y.pkl", "Discriminator Y Loss"),
        )

        plt.figure(figsize=(20, 10))

        for position, (filename, label) in enumerate(panels, start=1):
            file = os.path.join(path, filename)

            if not os.path.exists(file):
                print(f"Error: {file} does not exist.")
                continue

            data = load(filename=file)
            plt.subplot(1, 3, position)
            plt.plot(data, label=label)
            plt.xlabel("Epochs")
            plt.ylabel("Loss")
            plt.title(label)
            plt.legend()

        plt.tight_layout()

        if os.path.exists(config_files["path"]["metrics_path"]):
            plt.savefig(
                os.path.join(config_files["path"]["metrics_path"], "model_history.jpeg")
            )

        plt.show()

#### Test Model

In [None]:
import imageio

class TestModel:
    def __init__(self, netG_XtoY=None, netG_YtoX=None, device="mps"):
        """
        Prepare two fresh generators on the requested device.

        Args:
            netG_XtoY: fallback source for the X->Y weights, consulted by
                ``select_best_model``.
            netG_YtoX: fallback source for the Y->X weights, consulted by
                ``select_best_model``.
            device: device name resolved through ``device_init``.
        """
        self.XtoY, self.YtoX = netG_XtoY, netG_YtoX

        self.device = device_init(device)
        self.config = config()

        self.netG_XtoY, self.netG_YtoX = Generator(), Generator()

        for generator in (self.netG_XtoY, self.netG_YtoX):
            generator.to(self.device)

    def select_best_model(self):
        if os.path.exists(self.config["path"]["best_model_path"]):
            model_state_dict = torch.load(
                os.path.join(self.config["path"]["best_model_path"], "best_model.pth")
            )

            self.netG_XtoY.load_state_dict(model_state_dict["netG_XtoY"])
            self.netG_YtoX.load_state_dict(model_state_dict["netG_YtoX"])

        else:
            if isinstance(self.XtoY, Generator) and isinstance(self.YtoX, Generator):
                state_dict_XtoY = torch.load(self.XtoY)
                state_dict_YtoX = torch.load(self.YtoX)

                self.netG_XtoY.load_state_dict(state_dict_XtoY)
                self.netG_YtoX.load_state_dict(state_dict_YtoX)

            else:
                raise ValueError("XtoY and YtoX should be defined".capitalize())

    def create_gif_file(self):
        """
        Assemble the saved training previews into ``train_results.gif``.

        Image files are read from ``config["path"]["train_results"]`` in
        ascending order of the last number in their file name (ties broken by
        name) and written to ``config["path"]["gif_path"]``.

        Raises:
            Exception: if either directory is missing or no image is found.
        """
        import re

        if os.path.exists(self.config["path"]["train_results"]):
            path = self.config["path"]["train_results"]

            def frame_order(filename):
                # Sort numerically so "...10.png" comes after "...9.png".
                numbers = re.findall(r"\d+", filename)
                return (int(numbers[-1]) if numbers else -1, filename)

            # os.listdir() returns entries in arbitrary order and may contain
            # non-image files (for example hidden system files), so filter and
            # sort before decoding.
            frame_names = sorted(
                (
                    name
                    for name in os.listdir(path)
                    if name.lower().endswith((".png", ".jpg", ".jpeg"))
                ),
                key=frame_order,
            )

            if not frame_names:
                raise Exception(
                    "Cannot create the GIF file: no training images were found"
                )

            self.images = [
                imageio.imread(os.path.join(path, name)) for name in frame_names
            ]

            if os.path.exists(self.config["path"]["gif_path"]):
                path = self.config["path"]["gif_path"]

                imageio.mimsave(
                    os.path.join(path, "train_results.gif"), self.images, "GIF"
                )

            else:
                raise Exception("Cannot create the GIF file".capitalize())

        else:
            raise Exception(
                "Cannot extract the images from the train images directory".capitalize()
            )

    def load_dataloader(self):
        if os.path.exists(self.config["path"]["processed_path"]):
            path = self.config["path"]["processed_path"]

            self.test_dataloader = load(
                filename=os.path.join(path, "test_dataloader.pkl")
            )

            return self.test_dataloader

        else:
            raise Exception("processed_path does not exist".capitalize())

    def image_normalized(self, image=None):
        if image is not None:
            return (image - image.min()) / (image.max() - image.min())

        else:
            raise ValueError("image should be a torch.Tensor".capitalize())

    def create_plot(self, **kwargs):
        """
        Plot each sample as a row of X, predicted Y, real Y and reconstructed X.

        Keyword Args:
            X: batch from domain X.
            y: matching batch from domain Y.

        The figure is saved as ``test_result.png`` into
        ``config["path"]["test_result"]`` when that directory exists, and is
        shown in every case.
        """
        plt.figure(figsize=(40, 40))

        X = kwargs["X"]
        y = kwargs["y"]

        # Pure inference: no autograd graph is needed for plotting.
        with torch.no_grad():
            predicted_y = self.netG_XtoY(X.to(self.device))
            reconstructed_x = self.netG_YtoX(predicted_y)

        columns = 4
        # Keep the original 16-row layout, but grow it when the batch holds
        # more than 16 samples so the subplot index never leaves the grid.
        rows = max(4 * 4, len(predicted_y))

        def to_display(tensor):
            # Move channels last, convert to NumPy and rescale to [0, 1] for
            # imshow. Assumes a (C, H, W) tensor — TODO confirm for C == 1.
            array = tensor.squeeze().permute(1, 2, 0).cpu().detach().numpy()
            return self.image_normalized(image=array)

        for index, image in enumerate(predicted_y):
            panels = (
                ("X", X[index]),
                ("pred_Y", image),
                ("Y", y[index]),
                ("Reconstructed_X", reconstructed_x[index]),
            )

            for offset, (title, tensor) in enumerate(panels, start=1):
                plt.subplot(rows, columns, columns * index + offset)
                plt.imshow(to_display(tensor))
                plt.title(title)
                plt.axis("off")

        plt.tight_layout()
        if os.path.exists(self.config["path"]["test_result"]):
            path = self.config["path"]["test_result"]
            plt.savefig(os.path.join(path, "test_result.png"))
            print(
                """The result is saved as test_result.png in the "./outputs/test_result" directory"""
            )
        plt.show()

    def test(self):
        try:
            self.select_best_model()
        except Exception as e:
            print("An error occurred {}".format(e))
        else:
            self.test_dataloader = self.load_dataloader()

            X, y = next(iter(self.test_dataloader))

            self.create_plot(X=X, y=y)
            self.create_gif_file()

#### Inference

In [None]:
class Inference(TestModel):

    def __init__(
        self,
        image_size=512,
        channels=3,
        dataloader="dataloader",
        image=None,
        best_model=True,
        XtoY=None,
        YtoX=None,
        device="mps",
    ):
        """
        Configure single-image and batch inference.

        Args:
            image_size: side length images are resized to in ``transforms``.
            channels: channel count used by ``transforms``.
            dataloader: stored on the instance; ``batch_images`` replaces it
                with the loaded dataloader.
            image: path of the image processed by ``single_image``.
            best_model: stored on the instance.
            XtoY: fallback source for the X->Y generator weights.
            YtoX: fallback source for the Y->X generator weights.
            device: device name forwarded to ``TestModel``.
        """
        # TestModel.__init__ only accepts netG_XtoY, netG_YtoX and device;
        # forwarding dataloader/best_model made every construction fail with
        # a TypeError, so they are kept as plain attributes here instead.
        super(Inference, self).__init__(
            netG_XtoY=XtoY,
            netG_YtoX=YtoX,
            device=device,
        )
        self.dataloader = dataloader
        self.best_model = best_model

        self.image_size = image_size
        self.channels = channels
        self.image = image

        self.batch_path = self.config["path"]["batch_results_path"]
        self.single_path = self.config["path"]["single_results_path"]

    def transforms(self):
        """
        Build the preprocessing pipeline applied to a PIL image.

        Steps: bicubic resize to ``image_size`` x ``image_size``, conversion to
        a tensor, ``Grayscale`` with ``channels`` output channels, a centre crop
        of the same size and normalisation with mean 0.5 / std 0.5 per channel.

        NOTE(review): ``Grayscale`` discards colour information even when
        ``channels`` is 3 — confirm this matches the training preprocessing.
        """
        # One mean/std entry per channel, so the pipeline also works when
        # ``channels`` is not 3 (a fixed 3-tuple fails on 1-channel tensors).
        mean = (0.5,) * self.channels
        std = (0.5,) * self.channels

        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size), Image.BICUBIC),
                transforms.ToTensor(),
                transforms.Grayscale(num_output_channels=self.channels),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.Normalize(mean, std),
            ]
        )

    def single_image(self):
        """
        Translate the image at ``self.image`` and save the results.

        The image is read, preprocessed with ``transforms``, passed through
        X -> Y and back to X, and both outputs are written to
        ``self.single_path`` as ``pred_y.png`` and ``reconstructed_x.png``.

        Raises:
            ValueError: if the image cannot be read.
            Exception: any other failure is printed and re-raised.
        """
        try:
            self.select_best_model()

            read_x = cv2.imread(self.image)
            if read_x is None:
                raise ValueError(f"Image at path {self.image} could not be loaded.")

            # OpenCV decodes colour images as BGR while PIL interprets the
            # array as RGB; convert so the channels are not swapped.
            read_x = cv2.cvtColor(read_x, cv2.COLOR_BGR2RGB)

            X = self.transforms()(Image.fromarray(read_x))
            X = X.unsqueeze(0).to(self.device)

            # Pure inference: no autograd graph is needed.
            with torch.no_grad():
                predict_y = self.netG_XtoY(X)
                reconstructed_x = self.netG_YtoX(predict_y)

            predict_y = predict_y.squeeze().permute(1, 2, 0).cpu().detach().numpy()
            reconstructed_x = (
                reconstructed_x.squeeze().permute(1, 2, 0).cpu().detach().numpy()
            )

            predict_y = self.image_normalized(image=predict_y)
            reconstructed_x = self.image_normalized(image=reconstructed_x)

            # Same guarantee batch_images gives for its output directory.
            os.makedirs(self.single_path, exist_ok=True)

            plt.imshow(predict_y)
            plt.savefig(os.path.join(self.single_path, "pred_y.png"))
            plt.close()

            plt.imshow(reconstructed_x)
            plt.savefig(os.path.join(self.single_path, "reconstructed_x.png"))
            plt.close()

        except Exception as e:
            print(f"An error occurred: {e}")
            raise

    def batch_images(self):
        """
        Translate every batch of the loaded dataloader and save the results.

        For each sample the X -> Y prediction and the X -> Y -> X
        reconstruction are written to ``self.batch_path`` as ``pred_y{k}.png``
        and ``reconstructed_x{k}.png``, where ``k`` counts samples from 1 across
        all batches. The directory is created if needed.

        Raises:
            Exception: any failure is printed and re-raised.
        """
        try:
            self.select_best_model()
            self.dataloader = self.load_dataloader()

            # Loop-invariant: make sure the output directory exists once
            # instead of checking it for every image.
            os.makedirs(self.batch_path, exist_ok=True)

            def to_display(tensor):
                # Channels last, NumPy, rescaled to [0, 1] for imshow.
                array = tensor.squeeze().permute(1, 2, 0).detach().cpu().numpy()
                return self.image_normalized(image=array)

            count = 0
            for X, _ in self.dataloader:
                # Pure inference: no autograd graph is needed.
                with torch.no_grad():
                    predicted_y = self.netG_XtoY(X.to(self.device))
                    reconstructed_x = self.netG_YtoX(predicted_y)

                for idx, image in tqdm(
                    enumerate(predicted_y), total=len(predicted_y)
                ):
                    pred_y = to_display(image)
                    revert_x = to_display(reconstructed_x[idx])

                    plt.imshow(pred_y)
                    plt.savefig(os.path.join(self.batch_path, f"pred_y{count + 1}.png"))
                    plt.close()

                    plt.imshow(revert_x)
                    plt.savefig(
                        os.path.join(self.batch_path, f"reconstructed_x{count + 1}.png")
                    )
                    plt.close()

                    count += 1

        except Exception as e:
            print(f"An error occurred: {e}")
            raise