In [None]:
import os
import zipfile
from collections import OrderedDict

import cv2
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import yaml
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader
from torchsummary import summary
from torchview import draw_graph
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

In [4]:
def dump(value = None, filename = None):
    """Serialise ``value`` to ``filename`` with joblib.

    Both arguments are required; if either one is ``None`` an ``Exception``
    is raised and nothing is written.
    """
    if value is None or filename is None:
        raise Exception("Please provide the value and filename to dump".capitalize())

    joblib.dump(value = value, filename = filename)


def load(filename):
    """Deserialise and return the object stored at ``filename`` via joblib.

    Returns ``None`` when ``filename`` is ``None``.
    """
    # NOTE(review): joblib.load unpickles arbitrary objects - only point this
    # at files produced by this project.
    if filename is None:
        return None

    return joblib.load(filename = filename)


def device_init(device="mps"):
    """Resolve a device request to an available ``torch.device``.

    Args:
        device: Either a device string (``"mps"``, ``"cuda"``, ``"cuda:0"``,
            ``"cpu"``, ...) or an already constructed ``torch.device``.

    Returns:
        torch.device: The requested accelerator when it is available,
        otherwise the CPU device. Any request that is neither MPS nor CUDA
        resolves to the CPU.
    """
    # Accept torch.device objects too: callers in this file pass the result of
    # a previous device_init() call back in, which previously fell through to
    # an unconditional torch.device("cuda").
    if isinstance(device, torch.device):
        name = device.type
    else:
        name = str(device).split(":")[0]

    if name == "mps":
        return torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    if name == "cuda":
        # str(device) keeps an explicit index such as "cuda:1".
        return torch.device(str(device) if torch.cuda.is_available() else "cpu")

    # Unknown or CPU request: never hand back an accelerator that may not exist.
    return torch.device("cpu")

In [None]:
def config():
    """Parse ``../../config.yml`` (relative to the working directory) and return its content."""
    with open("../../config.yml", "r") as stream:
        return yaml.safe_load(stream)

In [None]:
class Loader():
    """Turn an ``images/X`` + ``images/y`` folder pair into pickled dataloaders.

    The raw and processed directories are read from the ``path`` section of
    the project configuration (``raw_path`` and ``processed_path``).

    Args:
        image_path: Path of the zip archive extracted by ``unzip_folder``.
        channels: Number of image channels (stored; the transform pipeline
            below normalises three channels).
        image_size: Side length images are resized to.
        batch_size: Training batch size; the test loader uses ``batch_size * 8``.
        split_size: Fraction of samples placed in the test split.
        paired_images: When ``True`` only file names present in both ``X`` and
            ``y`` are loaded.
        unpaired_images: Stored for callers; not read by this class.
    """

    def __init__(
        self,
        image_path = None,
        channels = 3,
        image_size = 256,
        batch_size = 1,
        split_size = 0.20,
        paired_images = False,
        unpaired_images = True
        ):
        self.image_path = image_path
        self.channels = channels
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size
        self.paired_images = paired_images
        self.unpaired_images = unpaired_images

        self.X = []
        self.y = []

        # Parse the configuration once instead of once per key.
        paths = config()["path"]
        self.raw_path = paths["raw_path"]
        self.processed_path = paths["processed_path"]

    def unzip_folder(self):
        """Extract ``image_path`` into ``raw_path``.

        Raises:
            Exception: If ``raw_path`` does not exist.
        """
        with zipfile.ZipFile(self.image_path, "r") as file:
            if os.path.exists(self.raw_path):
                file.extractall(os.path.join(self.raw_path))

            else:
                raise Exception("Cannot unzip the folder for further process".capitalize())

    def transforms(self):
        """Return the resize -> tensor -> crop -> normalise pipeline applied to every image."""
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.CenterCrop((self.image_size, self.image_size)),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

    def image_split(self, X, y):
        """Split ``X`` and ``y`` into train/test parts with a fixed random state.

        Returns:
            dict: ``X_train``, ``X_test``, ``y_train`` and ``y_test``.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.split_size, random_state=42
        )

        return {
            "X_train": X_train,
            "X_test": X_test,
            "y_train": y_train,
            "y_test": y_test,
        }

    def _list_images(self, category):
        """Return the sorted file names to load from ``<directory>/<category>``.

        Sorting makes the order deterministic (``os.listdir`` gives no ordering
        guarantee), which keeps ``X`` and ``y`` aligned. In paired mode only
        names that exist in both the ``X`` and the ``y`` folder are kept.
        """
        names = sorted(os.listdir(os.path.join(self.directory, category)))

        if self.paired_images:
            in_both = set(os.listdir(os.path.join(self.directory, "X"))) & set(
                os.listdir(os.path.join(self.directory, "y"))
            )
            names = [name for name in names if name in in_both]

        return names

    def feature_extractor(self):
        """Load and transform every image of both categories, then split them.

        Returns:
            dict: The result of ``image_split`` on the loaded tensors.

        Raises:
            Exception: In paired mode, if the two categories end up with a
                different number of readable images.
        """
        self.directory = os.path.join(self.raw_path, "images")
        self.categories = ["X", "y"]

        self.paired_check = os.listdir(os.path.join(self.directory, "y"))

        # Start clean so a second call does not append duplicates.
        self.X, self.y = [], []
        transform = self.transforms()

        for category in tqdm(self.categories):
            folder_path = os.path.join(self.directory, category)
            samples = self.X if category == "X" else self.y

            for name in self._list_images(category):
                image_path = os.path.join(folder_path, name)
                image = cv2.imread(image_path)

                # cv2.imread returns None for anything it cannot decode
                # (e.g. hidden system files); cvtColor would crash on it.
                if image is None:
                    print("Skipping unreadable image file: {}".format(image_path))
                    continue

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                samples.append(transform(Image.fromarray(image)))

        if self.paired_images and len(self.X) != len(self.y):
            raise Exception(
                "Paired images are not aligned: {} X images and {} y images".format(
                    len(self.X), len(self.y)
                )
            )

        return self.image_split(X = self.X, y = self.y)

    def create_dataloader(self):
        """Build the train/test dataloaders and pickle them into ``processed_path``.

        Raises:
            Exception: If feature extraction fails (re-raised after being
                reported) or if ``processed_path`` does not exist.
        """
        try:
            data = self.feature_extractor()

        except Exception as e:
            print("An error occurred while creating the dataloader".capitalize(), e)
            # Re-raise: continuing would leave stale or missing pickles behind
            # for the steps that run after this one.
            raise

        train_dataloader = DataLoader(
            dataset=list(zip(data["X_train"], data["y_train"])),
            batch_size=self.batch_size,
            shuffle=True
        )
        test_dataloader = DataLoader(
            dataset=list(zip(data["X_test"], data["y_test"])),
            batch_size=self.batch_size*8,
            shuffle=True
        )

        if os.path.exists(self.processed_path):
            dump(
                value=train_dataloader,
                filename=os.path.join(self.processed_path, "train_dataloader.pkl")
            )
            dump(
                value=test_dataloader,
                filename=os.path.join(self.processed_path, "test_dataloader.pkl")
            )

        else:
            raise Exception("Cannot create the dataloader for further process".capitalize())

    @staticmethod
    def plot_images():
        """Plot one test batch as X/y pairs and save it as ``images.png``.

        The figure is written to the configured ``files_path`` when that
        directory exists; otherwise it is only shown.

        Raises:
            Exception: If the pickled test dataloader cannot be found.
        """
        config_files = config()
        files_path = config_files["path"]["files_path"]
        processed_path = config_files["path"]["processed_path"]

        # Check the file that is actually loaded rather than the output folder.
        dataloader_file = os.path.join(processed_path, "test_dataloader.pkl")

        if not os.path.exists(dataloader_file):
            raise Exception("Cannot load the dataloader for further process".capitalize())

        test_dataloader = load(filename=dataloader_file)
        X, y = next(iter(test_dataloader))

        def to_display(tensor):
            # CHW tensor -> HWC array rescaled to [0, 1] for imshow.
            array = tensor.squeeze().permute(1, 2, 0).cpu().detach().numpy()
            low, high = array.min(), array.max()
            # A constant image would otherwise divide by zero.
            return (array - low) / (high - low) if high > low else array - low

        # Four X/y pairs per row; the row count follows the batch size instead
        # of a fixed 4 x 8 grid that overflows for larger batches.
        pairs_per_row = 4
        rows = max(1, -(-len(X) // pairs_per_row))

        plt.figure(figsize=(20, 2.5 * rows))

        for index, image in enumerate(X):
            for offset, (title, tensor) in enumerate((("X", image), ("y", y[index]))):
                plt.subplot(rows, 2 * pairs_per_row, 2 * index + 1 + offset)
                plt.imshow(to_display(tensor))
                plt.axis("off")
                plt.title(title)

        plt.tight_layout()

        if os.path.exists(files_path):
            plt.savefig(os.path.join(files_path, "images.png"))
        else:
            print("Cannot be saved the images".capitalize())

        plt.show()

    @staticmethod
    def dataset_details():
        """Write sample counts and tensor shapes of both dataloaders to ``dataset_details.csv``.

        The CSV is written into the configured ``files_path``; when that
        directory does not exist a message is printed instead.
        """
        config_files = config()

        files_path = config_files["path"]["files_path"]

        train_dataloader = load(os.path.join(
            config_files["path"]["processed_path"], "train_dataloader.pkl"
        ))
        test_dataloader = load(os.path.join(
            config_files["path"]["processed_path"], "test_dataloader.pkl"
        ))

        # The dataset length equals the summed batch sizes, without having to
        # iterate (and shuffle) each loader several times.
        train_total = len(train_dataloader.dataset)
        test_total = len(test_dataloader.dataset)

        details = pd.DataFrame(
            {
                "train_data(total)": str(train_total),
                "test_data(total)": str(test_total),
                "total_data": str(train_total + test_total),
                "train_data_shape": str(train_dataloader.dataset[0][0].shape),
                "test_data_shape": str(test_dataloader.dataset[0][0].shape)
            },
            index=["Quantity"]
        ).T

        if os.path.exists(files_path):
            details.to_csv(os.path.join(files_path, "dataset_details.csv"))
        else:
            # Previously this message was passed to to_csv as the output path.
            print("Cannot be saved the dataset into csv format".capitalize())


if __name__ == "__main__":
    # Driver: build and pickle the dataloaders, then produce the preview
    # figure and the dataset summary from those pickles.
    # NOTE(review): image_path is an absolute, machine-specific location; it is
    # only read by unzip_folder(), which is commented out below.
    loader = Loader(
        image_path="/Users/shahmuhammadraditrahman/Desktop/images.zip",
        unpaired_images=True,
        split_size=0.50
    )
    # Extraction is disabled; create_dataloader() reads from the configured raw_path.
    #loader.unzip_folder()
    loader.create_dataloader()
    loader.plot_images()
    loader.dataset_details()

#### Create the Generator

In [None]:
from collections import OrderedDict

class Encoder(nn.Module):
    """Down-sampling stage: 4x4 stride-2 convolution, LeakyReLU(0.2), optional batch norm.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        use_batch_norm: Append a ``BatchNorm2d`` after the activation.
    """

    def __init__(self, in_channels = 3, out_channels = 64, use_batch_norm = True):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.batch_norm = use_batch_norm

        # Kernel 4 / stride 2 / padding 1 halves the spatial size.
        self.kernel, self.stride, self.padding = 4, 2, 1

        self.encoder = self.block()

    def block(self):
        """Assemble the named layer sequence for this stage."""
        convolution = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
        )

        stages = [
            ("conv", nn.Sequential(convolution)),
            ("leaky_relu", nn.LeakyReLU(negative_slope=0.2, inplace=True)),
        ]

        if self.batch_norm:
            stages.append(("batch_norm", nn.BatchNorm2d(num_features=self.out_channels)))

        return nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Apply the stage to ``x``; raises ``Exception`` for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Please provide the tensor for further process".capitalize())

        return self.encoder(x)

    @staticmethod
    def total_params(model):
        """Count the trainable parameters of an ``Encoder`` instance."""
        if not isinstance(model, Encoder):
            raise Exception("Please provide the model for further process".capitalize())

        trainable = (weights.numel() for weights in model.parameters() if weights.requires_grad)
        return sum(trainable)

In [None]:
class Decoder(nn.Module):
    """Up-sampling stage: 4x4 stride-2 transposed convolution, ReLU, batch norm.

    ``forward`` concatenates the up-sampled map with a skip tensor along the
    channel axis, so the result has more channels than ``out_channels``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.
    """

    def __init__(self, in_channels = 512, out_channels = 512):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels

        # Kernel 4 / stride 2 / padding 1 doubles the spatial size.
        self.kernel, self.stride, self.padding = 4, 2, 1

        self.decoder = self.block()

    def block(self):
        """Assemble the named layer sequence for this stage."""
        upsample = nn.ConvTranspose2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
        )

        return nn.Sequential(
            OrderedDict(
                [
                    ("deconv", nn.Sequential(upsample)),
                    ("relu", nn.ReLU(inplace=True)),
                    ("batch_norm", nn.BatchNorm2d(num_features=self.out_channels)),
                ]
            )
        )

    def forward(self, x, skip_info):
        """Up-sample ``x`` and append ``skip_info`` on dim 1; raises ``Exception`` for non-tensor ``x``."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Please provide the tensor for further process".capitalize())

        upsampled = self.decoder(x)
        return torch.cat((upsampled, skip_info), dim=1)

    @staticmethod
    def total_params(model):
        """Count the trainable parameters of a ``Decoder`` instance."""
        if not isinstance(model, Decoder):
            raise Exception("Please provide the model for further process".capitalize())

        trainable = (weights.numel() for weights in model.parameters() if weights.requires_grad)
        return sum(trainable)

In [None]:
class Generator(nn.Module):
    """U-Net style generator: seven ``Encoder`` stages, six ``Decoder`` stages with skips.

    Each decoder output is concatenated with the matching encoder output, so
    the input channel count of every decoder after the first is doubled. A
    final transposed convolution followed by ``Tanh`` produces the image.

    Args:
        in_channels: Channels of the input image. The output image has the
            same number of channels (3 by default, as before).
    """

    def __init__(self, in_channels = 3):
        super(Generator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = 64
        self.kernel = 4
        self.stride = 2
        self.padding = 1

        base = self.out_channels

        self.encoder1 = Encoder(in_channels=self.in_channels, out_channels=base)
        self.encoder2 = Encoder(in_channels=base, out_channels=base * 2)
        self.encoder3 = Encoder(in_channels=base * 2, out_channels=base * 4)
        self.encoder4 = Encoder(in_channels=base * 4, out_channels=base * 8)
        self.encoder5 = Encoder(in_channels=base * 8, out_channels=base * 8)
        self.encoder6 = Encoder(in_channels=base * 8, out_channels=base * 8)
        self.encoder7 = Encoder(in_channels=base * 8, out_channels=base * 8)

        # From decoder2 on, the input is the previous decoder output plus its
        # skip connection, hence the factor 2 on in_channels.
        self.decoder1 = Decoder(in_channels=base * 8, out_channels=base * 8)
        self.decoder2 = Decoder(in_channels=base * 8 * 2, out_channels=base * 8)
        self.decoder3 = Decoder(in_channels=base * 8 * 2, out_channels=base * 8)
        self.decoder4 = Decoder(in_channels=base * 8 * 2, out_channels=base * 4)
        self.decoder5 = Decoder(in_channels=base * 4 * 2, out_channels=base * 2)
        self.decoder6 = Decoder(in_channels=base * 2 * 2, out_channels=base)

        self.output = nn.Sequential(
            nn.ConvTranspose2d(
                # decoder6 emits `base` channels concatenated with encoder1's
                # `base` channels (128 for base == 64, the former literal).
                in_channels=base * 2,
                # Emit as many channels as the input image has, so generators
                # with in_channels != 3 can be chained X -> Y -> X.
                out_channels=self.in_channels,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding
            ),
            nn.Tanh()
        )

    def forward(self, x):
        """Translate image batch ``x``; raises ``Exception`` for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Please provide the tensor for further process".capitalize())

        encoder1 = self.encoder1(x)
        encoder2 = self.encoder2(encoder1)
        encoder3 = self.encoder3(encoder2)
        encoder4 = self.encoder4(encoder3)
        encoder5 = self.encoder5(encoder4)
        encoder6 = self.encoder6(encoder5)
        encoder7 = self.encoder7(encoder6)

        # Each decoder receives the mirrored encoder output as its skip tensor.
        decoder1 = self.decoder1(encoder7, encoder6)
        decoder2 = self.decoder2(decoder1, encoder5)
        decoder3 = self.decoder3(decoder2, encoder4)
        decoder4 = self.decoder4(decoder3, encoder3)
        decoder5 = self.decoder5(decoder4, encoder2)
        decoder6 = self.decoder6(decoder5, encoder1)

        return self.output(decoder6)

    @staticmethod
    def total_params(model):
        """Count all parameters of a ``Generator`` instance."""
        if isinstance(model, Generator):
            return sum(p.numel() for p in model.parameters())

        else:
            # The argument is a model, not a tensor - message corrected.
            raise Exception("Please provide the model for further process".capitalize())
        
        
if __name__ == "__main__":
    # Driver: print a layer summary of the generator and render its graph.
    in_channels = 3
    netG = Generator(in_channels=in_channels)

    summary(model=netG, input_size=(3, 256, 256))

    config_files = config()
    files_path = config_files["path"]["files_path"]

    # Only render when the output folder exists. The message used to be passed
    # to render() as the `format` argument instead of being reported.
    if os.path.exists(files_path):
        draw_graph(model=netG, input_data=torch.randn(1, 3, 256, 256)).visual_graph.render(
            filename=os.path.join(files_path, "netG"), format="png"
        )
    else:
        print("Cannot saved the netG model architecture".capitalize())

#### Discriminator

In [None]:
class DiscriminatorBlock(nn.Module):
    """Discriminator stage: 4x4 stride-2 convolution, optional batch norm, LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        use_batch_norm: Insert a ``BatchNorm2d`` between convolution and activation.
    """

    def __init__(self, in_channels = 3, out_channels = 64, use_batch_norm = False):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.batch_norm = use_batch_norm

        # Kernel 4 / stride 2 / padding 1 halves the spatial size.
        self.kernel, self.stride, self.padding = 4, 2, 1

        self.discriminator = self.block()

    def block(self):
        """Assemble the named layer sequence for this stage."""
        convolution = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
        )

        stages = [("conv", nn.Sequential(convolution))]

        if self.batch_norm:
            stages.append(("batch_norm", nn.BatchNorm2d(num_features=self.out_channels)))

        stages.append(("leaky_ReLU", nn.LeakyReLU(negative_slope=0.2, inplace=True)))

        return nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Apply the stage to ``x``; raises ``Exception`` for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Please provide the tensor for further process".capitalize())

        return self.discriminator(x)

In [None]:
class Discriminator(nn.Module):
    """Patch discriminator: three ``DiscriminatorBlock`` stages and a 1-channel output convolution.

    The first stage has no batch norm; the channel count doubles per stage
    starting from 64. After construction ``in_channels`` / ``out_channels``
    hold the values left by the stage loop, not the constructor argument.

    Args:
        in_channels: Channels of the input image.
    """

    def __init__(self, in_channels = 3):
        super().__init__()

        stage_in, stage_out = in_channels, 64

        self.layers = []

        for stage in range(3):
            self.layers.append(
                DiscriminatorBlock(
                    in_channels=stage_in,
                    out_channels=stage_out,
                    use_batch_norm=(stage != 0),
                )
            )
            stage_in, stage_out = stage_out, stage_out * 2

        # Same final attribute values as updating them inside the loop.
        self.in_channels = stage_in
        self.out_channels = stage_out

        self.block = nn.Sequential(*self.layers)

        # Asymmetric zero padding (left and top) followed by an unpadded 4x4 convolution.
        self.output = nn.Sequential(
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0
            )
        )

    def forward(self, x):
        """Score ``x`` patch-wise; raises ``Exception`` for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Please provide the tensor for further process".capitalize())

        features = self.block(x)
        return self.output(features)

    @staticmethod
    def total_params(model = None):
        """Count all parameters of a ``Discriminator`` instance."""
        if not isinstance(model, Discriminator):
            raise Exception("Please provide the model for further process".capitalize())

        return sum(weights.numel() for weights in model.parameters())
        
        
if __name__ == "__main__":
    # Driver: sanity-check the discriminator output shape, print a layer
    # summary and render its graph.
    in_channels = 3
    config_files = config()
    files_path = config_files["path"]["files_path"]

    netD = Discriminator(in_channels=in_channels)

    assert netD(torch.randn(1, 3, 256, 256)).size() == (1, 1, 30, 30)

    summary(model=netD, input_size=(3, 256, 256))

    # Render only when the output folder exists. The message used to be passed
    # to render() as `format`, and the old test only checked that the
    # configured path was non-empty.
    if os.path.exists(files_path):
        draw_graph(model=netD, input_data=torch.randn(1, 3, 256, 256)).visual_graph.render(
            filename=os.path.join(files_path, "netD"), format="png"
        )
    else:
        print("Cannot be saved netD file".capitalize())

#### Define loss

In [None]:
class CycleLoss(nn.Module):
    """L1 distance between two tensors, wrapped as a module.

    Args:
        reduction: Reduction mode forwarded to ``nn.L1Loss``.
    """

    def __init__(self, reduction="mean"):
        super().__init__()

        self.name = "CycleLoss".title()
        self.reduction = reduction

        self.loss = nn.L1Loss(reduction=reduction)

    def forward(self, actual, pred):
        """Return the L1 loss of ``actual`` and ``pred``; raises ``TypeError`` unless both are tensors."""
        both_tensors = isinstance(actual, torch.Tensor) and isinstance(pred, torch.Tensor)

        if not both_tensors:
            raise TypeError("Inputs must be torch.Tensor".capitalize())

        return self.loss(actual, pred)


if __name__ == "__main__":
    # Demo: every element differs by exactly 1, so the mean L1 loss printed is 1.0.
    loss = CycleLoss()
    
    actual = torch.tensor([1.0, 0.0, 1.0, 1.0])
    predicted = torch.tensor([0.0, 1.0, 0.0, 0.0])
    
    print("The loss {}".format(loss(actual, predicted)))

In [None]:
class GradientPenalty(nn.Module):
    """Gradient penalty on random interpolations of two batches.

    For an interpolation ``alpha * X + (1 - alpha) * y`` the penalty is the
    batch mean of ``(||grad_netD(interpolated)||_2 - 1) ** 2``.

    Args:
        in_channels: Image channels (stored for reference).
        batch_size: Batch size used for the initial ``alpha`` attribute.
        device: Device string or ``torch.device`` the penalty is computed on.
    """

    def __init__(self, in_channels=3, batch_size=1, device="mps"):
        super(GradientPenalty, self).__init__()

        self.in_channels = in_channels
        self.batch_size = batch_size

        # An already resolved torch.device is used as is; only strings are
        # resolved through device_init.
        if isinstance(device, torch.device):
            self.device = device
        else:
            self.device = device_init(device=device)

        # One mixing coefficient per sample, broadcast over C, H and W.
        # torch.rand keeps it in [0, 1) so the mix stays between X and y.
        self.alpha = torch.rand(self.batch_size, 1, 1, 1)

    def forward(self, netD, X, y, device=None):
        """Compute the penalty of ``netD`` between batches ``X`` and ``y``.

        Args:
            netD: The ``Discriminator`` to penalise (moved to the target device).
            X: First batch of images.
            y: Second batch, same shape as ``X``.
            device: Optional override of the device chosen at construction.

        Returns:
            torch.Tensor: Scalar penalty, differentiable w.r.t. ``netD``'s parameters.

        Raises:
            TypeError: If ``netD`` is not a ``Discriminator`` or ``X`` / ``y``
                are not tensors.
        """
        if not (
            isinstance(netD, Discriminator)
            and isinstance(X, torch.Tensor)
            and isinstance(y, torch.Tensor)
        ):
            raise TypeError("Inputs must be torch.Tensor".capitalize())

        target = self.device if device is None else device

        X = X.to(target)
        y = y.to(target)

        # Fresh coefficients per call, sized to the actual batch and created on
        # the same device as the data (a fixed CPU tensor breaks on GPU input).
        alpha = torch.rand(X.size(0), 1, 1, 1, device=target)
        self.alpha = alpha

        # detach() makes the interpolation a leaf, so the gradient below is
        # taken w.r.t. the interpolated images only.
        interpolated = (alpha * X + (1 - alpha) * y).detach().requires_grad_(True)

        netD = netD.to(target)
        d_interpolated = netD(interpolated)

        gradients = torch.autograd.grad(
            outputs=d_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(d_interpolated),
            create_graph=True,
            retain_graph=True,
        )[0]

        # Flatten per sample and take the L2 norm of each gradient.
        gradients = gradients.reshape(gradients.size(0), -1)
        gradients = torch.norm(gradients, 2, dim=1)

        return ((gradients - 1) ** 2).mean()


if __name__ == "__main__":
    # Demo: gradient penalty of a fresh discriminator on two random image batches.
    batch_size = 1
    device = device_init(device="mps")
    
    # NOTE(review): `in_channels` is not assigned in this cell; it relies on
    # the module-level value set by an earlier driver cell.
    netD = Discriminator(in_channels=in_channels)

    X = torch.randn(1, 3, 256, 256)
    y = torch.randn(1, 3, 256, 256)

    grad_penalty = GradientPenalty(
        in_channels=in_channels, batch_size=batch_size, device=device
    )

    print(grad_penalty(netD, X, y))

#### Helper

In [None]:
def load_dataloader():
    """Load the pickled train and test dataloaders from the configured ``processed_path``.

    Returns:
        dict: ``train_dataloader`` and ``test_dataloader``.

    Raises:
        Exception: If the processed directory does not exist.
    """
    processed_dir = config()["path"]["processed_path"]

    if not os.path.exists(processed_dir):
        raise Exception("Cannot load the dataloader for further process".capitalize())

    train_split = load(filename=os.path.join(processed_dir, "train_dataloader.pkl"))
    test_split = load(filename=os.path.join(processed_dir, "test_dataloader.pkl"))

    return {"train_dataloader": train_split, "test_dataloader": test_split}


def helpers(**kwargs):
    """Create the networks, optimisers, losses and dataloaders used for training.

    Keyword Args:
        in_channels: Image channels for all four networks.
        lr: Learning rate of every optimiser.
        adam: Use Adam (betas 0.5 / 0.999). Takes precedence over ``SGD``.
        SGD: Use SGD with momentum 0.95.
        device: Device forwarded to ``GradientPenalty``.

    Returns:
        dict: ``optimizerG``, ``optimizerD_X``, ``optimizerD_Y``, ``netG_XtoY``,
        ``netG_YtoX``, ``netD_X``, ``netD_Y``, ``cycle_loss``, ``grad_penalty``,
        ``train_dataloader`` and ``test_dataloader``.

    Raises:
        KeyError: If a keyword argument is missing.
        ValueError: If neither ``adam`` nor ``SGD`` is truthy.
        Exception: Whatever ``load_dataloader`` raises (reported, then re-raised).
    """
    in_channels = kwargs["in_channels"]
    lr = kwargs["lr"]
    adam = kwargs["adam"]
    SGD = kwargs["SGD"]
    device = kwargs["device"]

    # Without this check the optimiser names stay unbound and the return
    # statement fails with an unrelated UnboundLocalError.
    if not (adam or SGD):
        raise ValueError("Please select an optimizer: set adam=True or SGD=True")

    netG_XtoY = Generator(in_channels=in_channels)
    netG_YtoX = Generator(in_channels=in_channels)

    netD_X = Discriminator(in_channels=in_channels)
    netD_Y = Discriminator(in_channels=in_channels)

    # Both generators share one optimiser; each discriminator has its own.
    generator_params = list(netG_XtoY.parameters()) + list(netG_YtoX.parameters())

    if adam:
        optimizerG = optim.Adam(params=generator_params, lr=lr, betas=(0.5, 0.999))
        optimizerD_X = optim.Adam(params=netD_X.parameters(), lr=lr, betas=(0.5, 0.999))
        optimizerD_Y = optim.Adam(params=netD_Y.parameters(), lr=lr, betas=(0.5, 0.999))

    else:
        optimizerG = optim.SGD(params=generator_params, lr=lr, momentum=0.95)
        optimizerD_X = optim.SGD(params=netD_X.parameters(), lr=lr, momentum=0.95)
        optimizerD_Y = optim.SGD(params=netD_Y.parameters(), lr=lr, momentum=0.95)

    try:
        dataloader = load_dataloader()

    except Exception as e:
        print("An error occurred {}".format(e))
        # Re-raise: swallowing the error only postponed the failure to a
        # NameError on `dataloader` below.
        raise

    cycle_loss = CycleLoss(reduction="mean")
    grad_penalty = GradientPenalty(in_channels=in_channels, batch_size=1, device=device)

    return {
        "optimizerG": optimizerG,
        "optimizerD_X": optimizerD_X,
        "optimizerD_Y": optimizerD_Y,
        "netG_XtoY": netG_XtoY,
        "netG_YtoX": netG_YtoX,
        "netD_X": netD_X,
        "netD_Y": netD_Y,
        "cycle_loss": cycle_loss,
        "grad_penalty": grad_penalty,
        "train_dataloader": dataloader["train_dataloader"],
        "test_dataloader": dataloader["test_dataloader"],
    }


if __name__ == "__main__":
    # Demo: build every training component with Adam and print each one.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    init = helpers(
        lr=0.001,
        adam=True,
        SGD=False,
        in_channels=3,
        device=device,
    )

    # Generators (X -> Y and Y -> X).
    print(init["netG_XtoY"])
    print(init["netG_YtoX"])

    # Discriminators, one per domain.
    print(init["netD_X"])
    print(init["netD_Y"])

    # Optimisers.
    print(init["optimizerG"])
    print(init["optimizerD_X"])
    print(init["optimizerD_Y"])

    # Dataloaders loaded from the pickles in processed_path.
    print(init["train_dataloader"])
    print(init["test_dataloader"])

    # Loss modules.
    print(init["cycle_loss"])
    print(init["grad_penalty"])

#### Trainer

In [None]:
class Trainer:
    """Train two generators and two critics with Wasserstein, cycle and pixel losses.

    Args:
        in_channels: Image channels of every network.
        epochs: Number of training epochs.
        lr: Learning rate of every optimiser.
        num_critics: The generators are updated once every ``num_critics``
            batches; the critics are updated on every batch.
        device: Device name resolved through ``device_init``.
        adam: Use Adam optimisers.
        SGD: Use SGD optimisers (when ``adam`` is false).
        is_weight_init: Apply ``weight_init`` to all four networks.
        is_display: Print the mean losses after every epoch.
        lr_scheduler: Halve every learning rate each 20 epochs via ``StepLR``.
    """

    # File names (without extension) of the per-loss history pickles.
    HISTORY_KEYS = ("netG_loss", "netD_X_loss", "netD_Y_loss")

    def __init__(
        self,
        in_channels=3,
        epochs=500,
        lr=2e-3,
        num_critics=5,
        device="cuda",
        adam=True,
        SGD=False,
        is_weight_init=True,
        is_display=True,
        lr_scheduler=False,
    ):
        self.in_channels = in_channels
        self.epochs = epochs
        self.lr = lr
        self.num_critics = num_critics
        self.device = device
        self.adam = adam
        self.SGD = SGD
        self.is_weight_init = is_weight_init
        self.is_display = is_display
        self.lr_scheduler = lr_scheduler

        self.device = device_init(device=self.device)

        # Hand the caller's original device request to helpers() so the
        # gradient penalty resolves it exactly once, the same way as above.
        self.init = helpers(
            in_channels=in_channels,
            lr=lr,
            adam=adam,
            SGD=SGD,
            device=device,
        )

        self.netGX_toY = self.init["netG_XtoY"].to(self.device)
        self.netGY_toX = self.init["netG_YtoX"].to(self.device)

        self.netD_X = self.init["netD_X"].to(self.device)
        self.netD_Y = self.init["netD_Y"].to(self.device)

        self.optimizerG = self.init["optimizerG"]
        self.optimizerD_X = self.init["optimizerD_X"]
        self.optimizerD_Y = self.init["optimizerD_Y"]

        self.train_dataloader = self.init["train_dataloader"]
        self.test_dataloader = self.init["test_dataloader"]

        self.cycle_loss = self.init["cycle_loss"]
        self.grad_penalty = self.init["grad_penalty"]

        # Keep a pre-sampled mixing tensor (if the penalty holds one) on the
        # training device so it can be combined with on-device batches.
        penalty_alpha = getattr(self.grad_penalty, "alpha", None)
        if isinstance(penalty_alpha, torch.Tensor):
            self.grad_penalty.alpha = penalty_alpha.to(self.device)

        if self.is_weight_init:
            # NOTE(review): `weight_init` is not defined in the visible part of
            # this file - confirm it is defined or imported before this runs.
            self.netGX_toY.apply(weight_init)
            self.netGY_toX.apply(weight_init)
            self.netD_X.apply(weight_init)
            self.netD_Y.apply(weight_init)

        if self.lr_scheduler:
            self.schedulerG = StepLR(
                optimizer=self.optimizerG,
                step_size=20,
                gamma=0.5,
            )
            self.schedulerD_X = StepLR(
                optimizer=self.optimizerD_X,
                step_size=20,
                gamma=0.5,
            )
            self.schedulerD_Y = StepLR(
                optimizer=self.optimizerD_Y,
                step_size=20,
                gamma=0.5,
            )

        # A configuration error is allowed to propagate: swallowing it left the
        # path attributes undefined and failed later with AttributeError.
        self.config_files = config()

        self.netG_XtoY_path = self.config_files["path"]["netG_XtoY_path"]
        self.netG_YtoX_path = self.config_files["path"]["netG_YtoX_path"]
        self.best_model_path = self.config_files["path"]["best_model_path"]
        self.train_results = self.config_files["path"]["train_results"]
        self.metrics_path = self.config_files["path"]["metrics_path"]

        # Best (lowest) mean generator loss seen so far.
        self.loss = float("inf")

        self.total_netG_loss = []
        self.total_netD_X_loss = []
        self.total_netD_Y_loss = []

        self.history = {"netG_loss": [], "netD_X_loss": [], "netD_Y_loss": []}

    @staticmethod
    def _mean(values):
        """Return the mean of ``values`` as a float, or NaN when it is empty."""
        return float(np.mean(values)) if len(values) else float("nan")

    def l1(self, model):
        """Sum of the L1 norms of all parameters of a ``Generator``."""
        if isinstance(model, Generator):
            return sum(torch.norm(params, 1) for params in model.parameters())

        else:
            raise ValueError("Model is not a Generator".capitalize())

    def l2(self, model):
        """Sum of the L2 norms of all parameters of a ``Generator``."""
        if isinstance(model, Generator):
            return sum(torch.norm(params, 2) for params in model.parameters())

        else:
            raise ValueError("Model is not a Generator".capitalize())

    def elastic_net(self, model):
        """Sum of ``l1`` and ``l2`` for a ``Generator``."""
        if isinstance(model, Generator):
            return self.l1(model) + self.l2(model)

        else:
            raise ValueError("Model is not a Generator".capitalize())

    def update_netG(self, **kwargs):
        """Run one optimisation step on both generators.

        Keyword Args:
            X: Batch from domain X (already on the training device).
            y: Batch from domain y (already on the training device).

        Returns:
            float: Critic loss + 10 * cycle loss + 10 * pixel loss.
        """
        self.optimizerG.zero_grad()

        fake_y = self.netGX_toY(kwargs["X"])
        fake_y_loss = -torch.mean(self.netD_Y(fake_y))
        reconstructed_X = self.netGY_toX(fake_y)

        fake_x = self.netGY_toX(kwargs["y"])
        fake_x_loss = -torch.mean(self.netD_X(fake_x))
        reconstructed_y = self.netGX_toY(fake_x)

        total_W_loss = 1.0 * (fake_x_loss + fake_y_loss)

        cycle_loss_X = self.cycle_loss(kwargs["X"], reconstructed_X).mean()
        cycle_loss_y = self.cycle_loss(kwargs["y"], reconstructed_y).mean()

        total_cycle_loss = 10.0 * (cycle_loss_X + cycle_loss_y)

        # Direct L1 between each translation and the batch of its target domain.
        pixel_loss_X = torch.abs(fake_x - kwargs["X"]).mean()
        pixel_loss_y = torch.abs(fake_y - kwargs["y"]).mean()

        total_content_loss = 10.0 * (pixel_loss_X + pixel_loss_y)

        total_G_loss = total_W_loss + total_cycle_loss + total_content_loss

        total_G_loss.backward()
        self.optimizerG.step()

        return total_G_loss.item()

    def update_netD_X(self, **kwargs):
        """Run one critic step on ``netD_X`` and return its loss as a float."""
        self.optimizerD_X.zero_grad()

        fake_x = self.netGY_toX(kwargs["y"]).detach()
        real_x_predict = self.netD_X(kwargs["X"])
        fake_x_predict = self.netD_X(fake_x)

        # Penalise between real and generated X-domain samples. The previous
        # call passed an unsupported `device=` keyword and mixed X with y.
        grad_X_loss = self.grad_penalty(self.netD_X, kwargs["X"], fake_x)

        netD_X_loss = (
            -torch.mean(real_x_predict)
            + torch.mean(fake_x_predict)
            + (10.0 * grad_X_loss)
        )

        netD_X_loss.backward()
        self.optimizerD_X.step()

        return netD_X_loss.item()

    def update_netD_Y(self, **kwargs):
        """Run one critic step on ``netD_Y`` and return its loss as a float."""
        self.optimizerD_Y.zero_grad()

        fake_y = self.netGX_toY(kwargs["X"]).detach()
        real_y_predict = self.netD_Y(kwargs["y"])
        fake_y_predict = self.netD_Y(fake_y)

        # Penalise between real and generated y-domain samples (see update_netD_X).
        grad_Y_loss = self.grad_penalty(self.netD_Y, kwargs["y"], fake_y)

        netD_Y_loss = (
            -torch.mean(real_y_predict)
            + torch.mean(fake_y_predict)
            + (10.0 * grad_Y_loss)
        )

        netD_Y_loss.backward()
        self.optimizerD_Y.step()

        return netD_Y_loss.item()

    def show_progress(self, **kwargs):
        """Print the epoch counter and, when ``is_display`` is set, the mean losses.

        Keyword Args:
            epoch: 1-based epoch number.
            netG_loss, netD_X_loss, netD_Y_loss: Per-batch losses of the epoch
                (only read when ``is_display`` is true).
        """
        if self.is_display:
            print(
                "Epochs - [{}/{}] - netG_loss: {} - netD_X_loss: {} - netD_Y_loss: {}".format(
                    kwargs["epoch"],
                    self.epochs,
                    self._mean(kwargs["netG_loss"]),
                    self._mean(kwargs["netD_X_loss"]),
                    self._mean(kwargs["netD_Y_loss"]),
                )
            )
        else:
            print(
                "Epochs - [{}/{}] is completed".format(
                    kwargs["epoch"],
                    self.epochs,
                )
            )

    def saved_checkpoints(self, **kwargs):
        """Save both generators for this epoch and refresh ``best_model.pth`` on improvement.

        Keyword Args:
            epoch: 1-based epoch number, used in the checkpoint file names.
            netG_loss: Mean generator loss of the epoch.

        Raises:
            Exception: If any of the checkpoint directories does not exist.
        """
        if not (
            os.path.exists(self.netG_XtoY_path)
            and os.path.exists(self.netG_YtoX_path)
            and os.path.exists(self.best_model_path)
        ):
            raise Exception("Cannot save the model")

        for filename, model, path in [
            ("netG_XtoY", self.netGX_toY, self.netG_XtoY_path),
            ("netG_YtoX", self.netGY_toX, self.netG_YtoX_path),
        ]:
            torch.save(
                model.state_dict(),
                os.path.join(path, "{}{}.pth".format(filename, kwargs["epoch"])),
            )

        if self.loss > kwargs["netG_loss"]:
            self.loss = kwargs["netG_loss"]

            torch.save(
                {
                    "netG_XtoY": self.netGX_toY.state_dict(),
                    "netG_YtoX": self.netGY_toX.state_dict(),
                    # Plain float rather than a numpy scalar, so the file can
                    # be read back by torch.load in weights-only mode.
                    "netG_loss": float(kwargs["netG_loss"]),
                    "epoch": kwargs["epoch"],
                },
                os.path.join(self.best_model_path, "best_model.pth"),
            )

    def early_stopping(self):
        """Placeholder; early stopping is not implemented."""
        pass

    def saved_training_results(self, **kwargs):
        """Save a translated and a reconstructed grid for one test batch.

        Keyword Args:
            epoch: 1-based epoch number, used in the image file names.
        """
        X, y = next(iter(self.test_dataloader))

        # Inference only: no autograd graph is needed for the preview images.
        with torch.no_grad():
            predict_y = self.netGX_toY(X.to(self.device))
            reconstructed_X = self.netGY_toX(predict_y.to(self.device))

        for filename, samples in [
            ("predict_y", predict_y),
            ("reconstructed_X", reconstructed_X),
        ]:
            save_image(
                samples,
                os.path.join(
                    self.train_results, "{}{}.png".format(filename, kwargs["epoch"])
                ),
                normalize=True,
                nrow=4,
            )

    def train(self):
        """Run the full training loop and pickle the loss history.

        Per batch both critics are updated; the generators are updated on every
        ``num_critics``-th batch. Per epoch the schedulers step (if enabled),
        progress is printed and checkpoints are written; preview images are
        written every 50th epoch. Afterwards each loss list is dumped to
        ``<metrics_path>/<name>.pkl`` and the combined dict to ``history.pkl``.
        """
        for epoch in tqdm(range(self.epochs)):
            netG_loss = []
            netD_X_loss = []
            netD_Y_loss = []

            for idx, (X, y) in enumerate(self.train_dataloader):
                X = X.to(self.device)
                y = y.to(self.device)

                netD_X_loss.append(self.update_netD_X(X=X, y=y))
                netD_Y_loss.append(self.update_netD_Y(X=X, y=y))

                # `== 0` was missing: the bare modulo was truthy on every batch
                # EXCEPT each num_critics-th one.
                if (idx + 1) % self.num_critics == 0:
                    netG_loss.append(self.update_netG(X=X, y=y))

            if self.lr_scheduler:
                self.schedulerG.step()
                self.schedulerD_X.step()
                self.schedulerD_Y.step()

            # NaN when no generator step ran this epoch (fewer batches than
            # num_critics); a NaN never replaces the best loss.
            mean_netG_loss = self._mean(netG_loss)
            mean_netD_X_loss = self._mean(netD_X_loss)
            mean_netD_Y_loss = self._mean(netD_Y_loss)

            self.show_progress(
                epoch=epoch + 1,
                netG_loss=netG_loss,
                netD_X_loss=netD_X_loss,
                netD_Y_loss=netD_Y_loss,
            )

            self.saved_checkpoints(epoch=epoch + 1, netG_loss=mean_netG_loss)

            # Same inverted-modulo fix: write previews every 50th epoch only.
            if (epoch + 1) % 50 == 0:
                self.saved_training_results(epoch=epoch + 1)

            self.total_netG_loss.append(mean_netG_loss)
            self.total_netD_X_loss.append(mean_netD_X_loss)
            self.total_netD_Y_loss.append(mean_netD_Y_loss)

        # Assign instead of extend so a second train() call does not duplicate
        # the epochs that are already part of the running totals.
        self.history["netG_loss"] = list(self.total_netG_loss)
        self.history["netD_X_loss"] = list(self.total_netD_X_loss)
        self.history["netD_Y_loss"] = list(self.total_netD_Y_loss)

        for filename, file in self.history.items():

            dump(
                value=file,
                filename=os.path.join(self.metrics_path, "{}.pkl".format(filename)),
            )

        # plot_history() reads this combined file.
        dump(
            value=self.history,
            filename=os.path.join(self.metrics_path, "history.pkl"),
        )

    @staticmethod
    def plot_history():
        """Plot the three loss curves side by side and save them as ``history.png``.

        Reads ``history.pkl`` from the configured ``metrics_path`` and falls
        back to the per-loss pickles when the combined file is absent.

        Raises:
            Exception: If ``metrics_path`` does not exist.
        """
        config_files = config()
        metrics_path = config_files["path"]["metrics_path"]

        if not os.path.exists(metrics_path):
            raise Exception("Cannot save the metrics".capitalize())

        history_file = os.path.join(metrics_path, "history.pkl")

        if os.path.exists(history_file):
            history = load(history_file)
        else:
            history = {
                name: load(os.path.join(metrics_path, "{}.pkl".format(name)))
                for name in Trainer.HISTORY_KEYS
            }

        plt.figure(figsize=(20, 10))

        for index, filename in enumerate(Trainer.HISTORY_KEYS):
            plt.subplot(1, 3, index + 1)

            # The label gives plt.legend() an entry to show.
            plt.plot(history[filename], label=filename)
            plt.title(filename)

            plt.xlabel("Epochs")
            plt.ylabel("Loss")
            plt.legend()

        plt.tight_layout()
        plt.savefig(os.path.join(metrics_path, "history.png"))
        plt.show()

### Test Model


In [None]:
class TestModel:
    def __init__(
        self,
        in_channels=3,
        dataloader="test",
        best_model=True,
        XtoY=None,
        YtoX=None,
        device="cuda",
        create_gif=False,
    ):
        """
        Set up the two generators and the configuration used for testing.

        Args:
            in_channels: number of input channels passed to both generators.
            dataloader: which pickled dataloader ``load_dataloader`` reads,
                ``"test"`` or ``"train"``.
            best_model: when true, ``select_best_model`` loads ``best_model.pth``
                from the configured ``best_model_path``.
            XtoY: checkpoint path for the X->Y generator, used when
                ``best_model`` is false.
            YtoX: checkpoint path for the Y->X generator, used when
                ``best_model`` is false.
            device: device name resolved through ``device_init``.
            create_gif: when true, ``create_gif`` builds a GIF from the
                training result images.
        """
        self.in_channels = in_channels
        self.dataloader = dataloader
        self.best_model = best_model
        self.XtoY = XtoY
        self.YtoX = YtoX
        self.device = device_init(device=device)
        self.create_gif_images = create_gif

        # Honour the caller's channel count instead of a hard-coded 3.
        self.netG_XtoY = Generator(in_channels=self.in_channels).to(self.device)
        self.netG_YtoX = Generator(in_channels=self.in_channels).to(self.device)

        self.config_files = config()

    def select_best_model(self):
        if self.best_model:
            best_model_path = self.config_files["path"]["best_model_path"]
            if os.path.exists(best_model_path):
                state_dict = torch.load(os.path.join(best_model_path, "best_model.pth"))

                self.netG_XtoY.load_state_dict(state_dict["netG_XtoY"])
                self.netG_YtoX.load_state_dict(state_dict["netG_YtoX"])

        else:
            if isinstance(self.XtoY, str) and isinstance(self.YtoX, str):
                state_XtoY = torch.load(self.XtoY)
                state_YtoX = torch.load(self.YtoX)

                self.netG_XtoY.load_state_dict(state_XtoY["netG_XtoY"])
                self.netG_YtoX.load_state_dict(state_YtoX["netG_YtoX"])

    def load_dataloader(self):
        if self.dataloader == "test":
            dataloader = load(
                filename=os.path.join(
                    self.config_files["path"]["processed_path"], "test_dataloader.pkl"
                )
            )
        elif self.dataloader == "train":
            dataloader = load(
                filename=os.path.join(
                    self.config_files["path"]["processed_path"], "train_dataloader.pkl"
                )
            )
        else:
            raise ValueError("Invalid dataloader")

        return dataloader

    def normalize_image(self, image):
        return (image - image.min()) / (image.max() - image.min())

    def plot(self, dataloader=None):
        """
        Translate one batch with both generators and save a comparison grid.

        Takes the first batch ``(X, y)`` from ``dataloader``, computes
        ``netG_XtoY(X)`` and the reconstruction ``netG_YtoX(netG_XtoY(X))``,
        and draws four panels per sample: real X, pred Y, real Y, revert X.
        The figure is written to ``test_result.jpeg`` inside the configured
        ``test_result`` directory, then shown and closed.

        Args:
            dataloader: a ``torch.utils.data.DataLoader`` yielding ``(X, y)``
                batches. Each sample is permuted with ``(1, 2, 0)``, so it is
                expected to be a 3-D channel-first tensor.

        Raises:
            ValueError: if ``dataloader`` is not a ``DataLoader``.
        """
        if not isinstance(dataloader, DataLoader):
            raise ValueError("Invalid dataloader".capitalize())

        panels_per_sample = 4
        grid_columns = 4 * 2  # two samples side by side per row

        X, y = next(iter(dataloader))

        X = X.to(self.device)
        y = y.to(self.device)

        # Inference only: no autograd graph is needed for plotting.
        with torch.no_grad():
            predict_Y = self.netG_XtoY(X)
            reconstructed_X = self.netG_YtoX(predict_Y)

        sample_count = len(predict_Y)
        total_panels = panels_per_sample * sample_count
        # Keep the original 8-row grid for small batches and grow it for larger
        # ones, so the subplot index never exceeds the number of grid slots.
        grid_rows = max(4 * 2, (total_panels + grid_columns - 1) // grid_columns)

        plt.figure(figsize=(10, 12))

        for index in range(sample_count):
            panels = [
                ("real X", X[index]),
                ("pred Y", predict_Y[index]),
                ("real Y", y[index]),
                ("revert X", reconstructed_X[index]),
            ]

            for idx, (title, tensor) in enumerate(panels):
                # Channel-first tensor -> channel-last array for imshow.
                # No squeeze(): it would drop a size-1 channel axis and
                # break the permute.
                picture = tensor.permute(1, 2, 0).cpu().detach().numpy()
                picture = self.normalize_image(picture)

                plt.subplot(
                    grid_rows, grid_columns, panels_per_sample * index + (idx + 1)
                )
                plt.imshow(picture)
                plt.title(title)
                plt.axis("off")

        plt.savefig(
            os.path.join(
                self.config_files["path"]["test_result"], "test_result.jpeg"
            )
        )
        print(
            "Test images saved in the folder: ",
            self.config_files["path"]["test_result"],
        )
        plt.show()
        plt.close()

    def create_gif(self):
        """
        Assemble the saved training result images into an animated GIF.

        Does nothing unless the instance was created with ``create_gif=True``.
        Frames are read from the configured ``train_results`` directory and the
        animation is written to ``gif.gif`` inside the configured ``gif_path``.

        Raises:
            ValueError: if the training results directory holds no usable frames.
        """
        if not self.create_gif_images:
            return

        train_images = self.config_files["path"]["train_results"]
        gif_path = self.config_files["path"]["gif_path"]

        def frame_order(name):
            # Order by the number embedded in the file name, then by name.
            # NOTE(review): assumes the epoch number is the only digit run in
            # the file name - verify against the code that saves the frames.
            digits = "".join(ch for ch in name if ch.isdigit())
            return (int(digits) if digits else -1, name)

        # os.listdir returns entries in arbitrary order, so sort explicitly.
        # Skip hidden files and sub-directories, which are not frames.
        frame_names = sorted(
            (
                name
                for name in os.listdir(train_images)
                if not name.startswith(".")
                and os.path.isfile(os.path.join(train_images, name))
            ),
            key=frame_order,
        )

        if not frame_names:
            raise ValueError(
                "No training result images found in {}".format(train_images)
            )

        # NOTE(review): imageio is not among the imports visible at the top of
        # the notebook - confirm it is imported before this cell runs.
        images = [
            imageio.imread(os.path.join(train_images, name)) for name in frame_names
        ]

        imageio.mimsave(os.path.join(gif_path, "gif.gif"), images, "GIF")

    def test(self):
        """
        Run the test steps one after another: ``select_best_model``,
        ``load_dataloader`` and ``plot``.

        Each step sits in its own ``try``/``except Exception`` block. When a
        step fails, the error is printed and the method returns early instead
        of propagating the exception to the caller.
        """
        # Step 1: load generator weights.
        try:
            self.select_best_model()
        except Exception as e:
            print(f"An error occurred while selecting the best model: {e}")
            return

        # Step 2: load the pickled dataloader chosen at construction time.
        try:
            dataloader = self.load_dataloader()
        except Exception as e:
            print(f"An error occurred while loading the dataloader: {e}")
            return

        # Step 3: draw and save the comparison grid for one batch.
        try:
            self.plot(dataloader=dataloader)
        except Exception as e:
            print(f"An error occurred while plotting: {e}")
            return