In [4]:
import os
import cv2
import yaml
import torch
import joblib
import zipfile
import torch.nn as nn
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [5]:
# Relative locations (from the notebook's working directory) of the raw
# archive/extracted images and of the pickled DataLoaders.
RAW_DATA_PATH, PROCESSED_DATA_PATH = "../../data/raw/", "../../data/processed/"

In [6]:
def dump(value=None, filename=None):
    """Serialize ``value`` to ``filename`` using joblib.

    Raises:
        ValueError: if either ``value`` or ``filename`` is None.
    """
    if value is None or filename is None:
        raise ValueError("Value or filename cannot be None".capitalize())

    joblib.dump(value=value, filename=filename)

def load(filename=None):
    """Deserialize and return the object stored at ``filename`` via joblib.

    NOTE: joblib uses pickle under the hood; only load files you trust.

    Raises:
        ValueError: if ``filename`` is None.
    """
    if filename is None:
        raise ValueError("Filename cannot be None".capitalize())

    return joblib.load(filename)


def config():
    """Return the parsed contents of the project's YAML configuration.

    The file is read from ``../../config.yml`` relative to the current
    working directory and parsed with ``yaml.safe_load``.
    """
    config_file = "../../config.yml"

    with open(config_file, "r") as stream:
        settings = yaml.safe_load(stream)

    return settings


def device_init(device="cuda"):
    """Resolve a requested backend name to a usable ``torch.device``.

    ``"cuda"`` and ``"mps"`` are honoured only when the corresponding backend
    reports itself available; any other request, or an unavailable backend,
    falls back to the CPU.
    """
    # Lazily-evaluated availability probes, keyed by backend name.
    availability_probe = {
        "cuda": lambda: torch.cuda.is_available(),
        "mps": lambda: torch.backends.mps.is_available(),
    }.get(device)

    if availability_probe is not None and availability_probe():
        return torch.device(device)

    return torch.device("cpu")


def weight_init(m):
    """Re-initialize a module's parameters in place; meant for ``model.apply``.

    * Any module whose class name contains ``"Conv"``: weight ~ N(0, 0.02).
    * Any module whose class name contains ``"BatchNorm"``: weight ~ N(1, 0.02),
      bias = 0.
    * Other modules are left untouched.

    Parameters that are ``None`` (e.g. ``BatchNorm2d(affine=False)`` or a conv
    built with ``bias=False``) are skipped instead of raising AttributeError.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    bias = getattr(m, "bias", None)

    if "Conv" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)

    elif "BatchNorm" in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)

In [None]:
class Loader():
    """Build paired low-/high-resolution DataLoaders from a zipped image set.

    The archive at ``image_path`` is extracted into ``RAW_DATA_PATH``; images
    are then read from ``RAW_DATA_PATH/dataset/LR`` and ``.../HR`` and paired
    by identical file name. LR images are resized to ``image_size`` and HR
    images to ``4 * image_size``; both are normalized with mean = std = 0.5
    per channel. The pairs are split into train/test sets and the resulting
    DataLoaders are pickled into ``PROCESSED_DATA_PATH``.
    """

    def __init__(self, image_path = None, image_size = 64, split_size = 0.20, batch_size = 1):
        self.image_path = image_path
        self.image_size = image_size
        self.split_size = split_size
        self.batch_size = batch_size

        # Transformed LR / HR tensors, index-aligned (LR[i] pairs with HR[i]).
        self.LR = []
        self.HR = []

    def split_dataset(self, X = None, y = None):
        """Split paired samples into a seeded, shuffled train/test partition.

        Returns:
            dict with keys ``X_train``, ``X_test``, ``y_train``, ``y_test``.

        Raises:
            TypeError: if ``X`` or ``y`` is not a list (previously this
                silently returned None and failed later with a vague error).
        """
        if not (isinstance(X, list) and isinstance(y, list)):
            raise TypeError("X and y must both be lists")

        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = self.split_size, random_state=42, shuffle=True)

        return {"X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test}

    def transforms(self, type = "lr"):
        """Return the preprocessing pipeline for ``"lr"`` or ``"hr"`` images.

        Both pipelines resize to a square, convert to a tensor, center-crop
        to the same square and normalize each RGB channel with mean = std = 0.5.
        The HR target side is 4x the LR side.

        Raises:
            ValueError: for any other ``type`` (previously returned None).
        """
        # Inside a method, the bare name ``transforms`` resolves to the
        # module-level torchvision import, not to this method.
        if type == "lr":
            size = self.image_size
        elif type == "hr":
            size = self.image_size * 4
        else:
            raise ValueError("type must be 'lr' or 'hr', got {!r}".format(type))

        return transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.CenterCrop((size, size)),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def unzip_folder(self):
        """Extract the archive at ``image_path`` into ``RAW_DATA_PATH``.

        Raises:
            Exception: if ``RAW_DATA_PATH`` does not exist.
        """
        if os.path.exists(RAW_DATA_PATH):
            with zipfile.ZipFile(self.image_path, "r") as zip_file:
                zip_file.extractall(RAW_DATA_PATH)
        else:
            raise Exception("RAW data path is not found".capitalize())

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it as an RGB PIL image."""
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals an unreadable / non-image file by returning None.
            raise ValueError("Could not read image: {}".format(path))
        return Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

    def feature_extraction(self):
        """Load, pair and transform every LR/HR image, then split them.

        Only LR files that have an identically named HR file are used.

        Returns:
            The dict produced by :meth:`split_dataset`.
        """
        self.directory = os.path.join(RAW_DATA_PATH, "dataset")

        self.higher_resolution_images = os.path.join(self.directory, "HR")
        self.low_resolution_images = os.path.join(self.directory, "LR")

        # Start fresh: re-running must not duplicate samples, which would
        # otherwise leak identical pairs into both train and test splits.
        self.LR = []
        self.HR = []

        # List HR once; re-listing it for every LR file made pairing O(n^2).
        hr_names = set(os.listdir(self.higher_resolution_images))
        lr_transform = self.transforms(type="lr")
        hr_transform = self.transforms(type="hr")

        # os.listdir order is filesystem-dependent; sort so the seeded
        # train/test split is reproducible.
        for image in sorted(os.listdir(self.low_resolution_images)):
            if image not in hr_names:
                continue

            lower_resolution_image = self._read_rgb(os.path.join(self.low_resolution_images, image))
            higher_resolution_image = self._read_rgb(os.path.join(self.higher_resolution_images, image))

            self.LR.append(lr_transform(lower_resolution_image))
            self.HR.append(hr_transform(higher_resolution_image))

        assert len(self.LR) == len(self.HR)

        print("Total {} images have been captured".format(len(self.LR)).capitalize())

        return self.split_dataset(X=self.LR, y=self.HR)

    def create_dataloader(self):
        """Build train/valid DataLoaders and pickle them to ``PROCESSED_DATA_PATH``.

        The valid loader uses ``8 * batch_size`` per batch.

        Raises:
            Exception: if feature extraction fails (the original error is
                chained as ``__cause__``).
        """
        try:
            dataset = self.feature_extraction()

        except Exception as e:
            raise Exception("Feature extraction process has been failed".capitalize()) from e

        else:
            train_dataloader = DataLoader(
                dataset = list(zip(dataset["X_train"], dataset["y_train"])),
                batch_size = self.batch_size,
                shuffle = True
            )
            valid_dataloader = DataLoader(
                dataset = list(zip(dataset["X_test"], dataset["y_test"])),
                batch_size=self.batch_size*8,
                shuffle=True
            )

            for dataloader, filename in [(train_dataloader, "train_dataloader"), (valid_dataloader, "valid_dataloader")]:
                dump(value=dataloader, filename=os.path.join(PROCESSED_DATA_PATH, filename+".pkl"))

            print("train and valid dataloader has been created in the folder : {}".format(PROCESSED_DATA_PATH).capitalize())

    @staticmethod
    def _to_unit_range(array):
        """Min-max scale ``array`` to [0, 1]; a constant array maps to zeros."""
        low = array.min()
        span = array.max() - low
        if span > 0:
            return (array - low) / span
        return np.zeros_like(array)

    @staticmethod
    def plot_images():
        """Plot LR/HR pairs from the first batch of the pickled valid loader."""
        dataloader = load(filename=os.path.join(PROCESSED_DATA_PATH, "valid_dataloader.pkl"))

        data, labels = next(iter(dataloader))

        plt.figure(figsize=(20, 10))

        rows, cols = 2 * 2, 2 * 4
        # Each pair occupies two cells; more pairs than that would make
        # plt.subplot raise for an out-of-range index.
        max_pairs = (rows * cols) // 2

        for index, image in enumerate(data[:max_pairs]):
            X = Loader._to_unit_range(image.squeeze().permute(1, 2, 0).detach().cpu().numpy())
            y = Loader._to_unit_range(labels[index].squeeze().permute(1, 2, 0).detach().cpu().numpy())

            plt.subplot(rows, cols, 2 * index + 1)
            plt.imshow(X)
            plt.title("LR")
            plt.axis("off")

            plt.subplot(rows, cols, 2 * index + 2)
            plt.imshow(y)
            plt.title("HR")
            plt.axis("off")

        plt.tight_layout()
        plt.show()


if __name__ == "__main__":
    # End-to-end data preparation: extract the archive, build and pickle the
    # train/valid DataLoaders, then preview one validation batch.
    loader = Loader(image_path="../../data/raw/dataset.zip", image_size=64, split_size=0.40)

    loader.unzip_folder()
    loader.create_dataloader()
    Loader.plot_images()

In [7]:
class DenseBlock(nn.Module):
    """Five 3x3 convolutions with dense (concatenating) connections.

    Convs 1-4 are each followed by LeakyReLU(0.2); their output is prepended
    (along dim 1) to everything seen so far and fed to the next conv. Conv 5
    has no activation and its output is the block's output. With kernel 3,
    stride 1, padding 1 the spatial size is preserved.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by every conv, hence by the block.
            Conv ``k`` (1-based) receives ``in_channels + (k - 1) * out_channels``
            channels, so ``in_channels`` and ``out_channels`` may differ.
    """

    def __init__(self, in_channels = 64, out_channels = 64):
        super(DenseBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 3
        self.stride = 1
        self.padding = 1
        self.slope = 0.2

        # Each conv sees the input plus all previous conv outputs. The old
        # k * in_channels formula was only correct when in == out.
        growth = self.out_channels
        self.block1 = self.block(in_channels = self.in_channels + 0 * growth, out_channels = growth, use_leaky = True)
        self.block2 = self.block(in_channels = self.in_channels + 1 * growth, out_channels = growth, use_leaky = True)
        self.block3 = self.block(in_channels = self.in_channels + 2 * growth, out_channels = growth, use_leaky = True)
        self.block4 = self.block(in_channels = self.in_channels + 3 * growth, out_channels = growth, use_leaky = True)
        self.block5 = self.block(in_channels = self.in_channels + 4 * growth, out_channels = growth, use_leaky = False)

    def block(self, in_channels = 64, out_channels = 64, use_leaky = True):
        """Return ``Conv2d(3x3, bias)`` optionally followed by ``LeakyReLU``."""
        layers = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                bias=True
            )
        ]

        if use_leaky:
            layers.append(nn.LeakyReLU(negative_slope=self.slope, inplace=True))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Apply the dense block; raises TypeError for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a tensor".capitalize())

        features = x
        for layer in (self.block1, self.block2, self.block3, self.block4):
            # New features go first; this order fixes which input channels
            # each conv's weights apply to, so keep it for checkpoint compat.
            features = torch.concat((layer(features), features), dim = 1)

        return self.block5(features)
        
        
if __name__ == "__main__":
    # Smoke test: five stacked DenseBlocks keep a (1, 64, 256, 256) tensor's shape.
    layers = [DenseBlock(in_channels = 64, out_channels = 64) for _ in range(5)]

    model = nn.Sequential(*layers)

    assert tuple(model(torch.randn(1, 64, 256, 256)).shape) == (1, 64, 256, 256)

In [None]:
class ResidualInResidual(nn.Module):
    """Three DenseBlocks chained with residual connections.

    With ``d1..d3`` the dense blocks:
        h2  = d1(x) + x
        h3  = d2(h2) + h2
        out = res_scale * d3(h3) + h3

    Args:
        in_channels: Channels of the input; every DenseBlock maps
            ``in_channels -> in_channels`` so the residual sums line up.
        res_scale: Multiplier applied to the last dense block's output.
    """

    def __init__(self, in_channels=64, res_scale=0.2):
        super(ResidualInResidual, self).__init__()

        self.in_channels = in_channels
        self.res_scale = res_scale

        self.denseblock1 = DenseBlock(
            in_channels=self.in_channels, out_channels=self.in_channels
        )
        self.denseblock2 = DenseBlock(
            in_channels=self.in_channels, out_channels=self.in_channels
        )
        self.denseblock3 = DenseBlock(
            in_channels=self.in_channels, out_channels=self.in_channels
        )

    def forward(self, x):
        """Apply the block.

        Raises:
            TypeError: if ``x`` is not a tensor (previously returned None
                silently, unlike the sibling modules).
        """
        if isinstance(x, torch.Tensor):
            output1 = self.denseblock1(x)
            input2 = output1 + x

            output2 = self.denseblock2(input2)
            input3 = output2 + input2

            output = self.denseblock3(input3)
            output = torch.mul(output, self.res_scale) + input3

            return output

        else:
            raise TypeError("Input must be a torch.Tensor".capitalize())


if __name__ == "__main__":
    # Build one residual-in-residual block (default res_scale) and print its layers.
    residual_in_residual = ResidualInResidual(in_channels=64, res_scale=0.2)
    print(residual_in_residual)

In [9]:
class OutputBlock(nn.Module):
    """Upsampling head: two (conv -> PixelShuffle(2)) stages with a LeakyReLU between.

    Each stage's 3x3 conv expands ``in_channels`` to ``4 * in_channels`` and
    PixelShuffle(2) folds that back to ``in_channels`` while doubling height
    and width, so the block outputs ``in_channels`` channels at 4x the
    spatial size.

    NOTE(review): ``out_channels`` is stored on the instance but no layer
    uses it.
    """

    def __init__(self, in_channels=64, out_channels=3):
        super(OutputBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 3
        self.stride_size = 1
        self.padding_size = 1
        self.negative_slope = 0.2
        self.upscale_factor = 2

        self.output_block = self.block()

    def _upsample_stage(self):
        """Return one ``Conv2d(C -> 4C) + PixelShuffle`` stage."""
        expand = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.in_channels * 4,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=True,
        )
        return nn.Sequential(expand, nn.PixelShuffle(upscale_factor=self.upscale_factor))

    def block(self):
        """Assemble ``[stage, LeakyReLU, stage]`` into a Sequential."""
        first_stage = self._upsample_stage()
        activation = nn.LeakyReLU(negative_slope=self.negative_slope, inplace=True)
        second_stage = self._upsample_stage()

        self.layers = [first_stage, activation, second_stage]

        return nn.Sequential(*self.layers)

    def forward(self, x):
        """Upsample ``x``; raises TypeError for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())

        return self.output_block(x)


if __name__ == "__main__":
    # Smoke test: the upsampling head turns 64x64 feature maps into 256x256.
    outblock = OutputBlock(in_channels=64, out_channels=64)
    assert tuple(outblock(torch.randn(1, 64, 64, 64)).shape) == (1, 64, 256, 256)

In [10]:
class Generator(nn.Module):
    """Super-resolution generator.

    Pipeline:
        features = input_block(x)                       # 3x3 conv, in -> C
        deep     = middle_block(RRDB x16 (features))    # C -> C
        output   = self.output(features + deep)         # 4x upsample, C -> in

    Args:
        in_channels: Image channels (input and output).
        out_channels: Internal feature width ``C``.
    """

    def __init__(self, in_channels=3, out_channels=64):
        super(Generator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = 3
        self.stride_size = 1
        self.padding_size = 1

        self.input_block = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=True,
        )

        self.residual_in_residual_denseblock = nn.Sequential(
            *[ResidualInResidual(in_channels=self.out_channels) for _ in range(16)]
        )

        self.middle_block = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=True,
        )

        self.output = nn.Sequential(
            OutputBlock(in_channels=self.out_channels, out_channels=self.out_channels),
            nn.Conv2d(
                in_channels=self.out_channels,
                out_channels=self.in_channels,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
            ),
        )

    def forward(self, x):
        """Upscale ``x`` by 4.

        Raises:
            TypeError: if ``x`` is not a tensor (previously returned None
                silently, unlike the other modules in this file).
        """
        if isinstance(x, torch.Tensor):
            input_block = self.input_block(x)
            residual_block = self.residual_in_residual_denseblock(input_block)
            middle_block = self.middle_block(residual_block)
            # Global skip connection around the whole RRDB trunk.
            middle_block = torch.add(input_block, middle_block)
            output = self.output(middle_block)

            return output

        else:
            raise TypeError("Input must be a torch.Tensor".capitalize())

    @staticmethod
    def total_params(model=None):
        """Return the total number of parameters of a Generator.

        Raises:
            TypeError: if ``model`` is not a Generator (matches
                ``Discriminator.total_params``; previously returned None).
        """
        if isinstance(model, Generator):
            return sum(params.numel() for params in model.parameters())

        else:
            raise TypeError("Model must be a Generator".capitalize())
            return sum(params.numel() for params in model.parameters())


if __name__ == "__main__":
    # Smoke test for the full generator: parameter count and 4x upscaling.
    netG = Generator(in_channels=3, out_channels=64)

    assert Generator.total_params(model=netG) == 26893315
    assert tuple(netG(torch.randn(1, 3, 64, 64)).shape) == (1, 3, 256, 256)

In [11]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator producing a map of scores.

    Layout:
        input_block     : conv(in -> C, stride 2) + LeakyReLU
        immediate_block : 3 x [conv(c -> 2c) + BatchNorm + LeakyReLU];
                          the first stage uses stride 1, the others stride 2
        ouput_block     : conv(8C -> 1, stride 1, no padding)

    After construction ``self.out_channels`` holds the final width ``8 * C``.
    (The ``ouput_block`` spelling is kept so state_dict keys stay stable.)
    """

    def __init__(self, in_channels=3, out_channels=64):
        super(Discriminator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 3
        self.stride_size = 2
        self.padding_size = 1
        self.negative_slope = 0.2

        self.layers = []

        self.input_block = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
                bias=True,
            ),
            nn.LeakyReLU(negative_slope=self.negative_slope, inplace=True),
        )

        width = self.out_channels
        for stage in range(3):
            # Only the first stage keeps resolution; the rest halve it.
            stage_stride = self.stride_size // 2 if stage == 0 else self.stride_size
            doubled = width * 2
            self.layers.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels=width,
                        out_channels=doubled,
                        kernel_size=self.kernel_size,
                        stride=stage_stride,
                        padding=self.padding_size,
                    ),
                    nn.BatchNorm2d(num_features=doubled),
                    nn.LeakyReLU(negative_slope=self.negative_slope, inplace=True),
                )
            )
            width = doubled
        self.out_channels = width

        self.immediate_block = nn.Sequential(*self.layers)

        # Single-channel score map, stride 1, no padding.
        self.ouput_block = nn.Sequential(
            nn.Conv2d(
                in_channels=self.out_channels,
                out_channels=1,
                kernel_size=self.kernel_size,
                stride=1,
            )
        )

    def forward(self, x):
        """Score ``x``; raises TypeError for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())

        stem = self.input_block(x)
        trunk = self.immediate_block(stem)
        return self.ouput_block(trunk)

    @staticmethod
    def total_params(model):
        """Count trainable parameters; raises TypeError for non-Discriminators."""
        if not isinstance(model, Discriminator):
            raise TypeError("Model must be a Discriminator".capitalize())

        return sum(p.numel() for p in model.parameters() if p.requires_grad)


if __name__ == "__main__":
    # Smoke test: a 256x256 RGB image yields a (1, 1, 30, 30) score map.
    netD = Discriminator(in_channels=3, out_channels=64)

    assert tuple(netD(torch.randn(1, 3, 256, 256)).shape) == (1, 1, 30, 30)
    assert Discriminator.total_params(netD) == 1557377

In [None]:
import torch
import torch.nn as nn


class MSELoss(nn.Module):
    """``nn.MSELoss`` wrapper that insists both arguments are tensors.

    Args:
        reduction: Forwarded to ``nn.MSELoss`` (``"mean"``, ``"sum"`` or ``"none"``).
    """

    def __init__(self, reduction="mean"):
        super(MSELoss, self).__init__()

        self.reduction = reduction
        self.loss = nn.MSELoss(reduction=self.reduction)

    def forward(self, pred, target):
        """Return the MSE between ``pred`` and ``target``.

        Raises:
            TypeError: if either argument is not a tensor.
        """
        if not all(isinstance(arg, torch.Tensor) for arg in (pred, target)):
            raise TypeError("Input must be torch.Tensor".capitalize())

        return self.loss(pred, target)


if __name__ == "__main__":
    # Sanity check: one mismatching element out of four gives an MSE of 0.25.
    loss = MSELoss(reduction="mean")

    target, predicted = (
        torch.tensor([1.0, 0.0, 1.0, 0.0]),
        torch.tensor([1.0, 0.0, 1.0, 1.0]),
    )

    print(loss(predicted, target))

In [13]:
import torchvision.models as models

import warnings

# Silence only UserWarnings emitted from torchvision modules (for example its
# deprecation notices) rather than every warning for the whole session, which
# would also hide genuine problems reported by torch, numpy, sklearn, etc.
warnings.filterwarnings("ignore", category=UserWarning, module="torchvision")


class VGG19(nn.Module):
    """Frozen VGG19 feature extractor used for the perceptual (content) loss.

    Keeps the first 35 modules of torchvision's ImageNet-pretrained VGG19
    ``features`` stack (in torchvision's layout this ends at the last conv of
    block 5, before its ReLU) with every parameter frozen.

    NOTE(review): the data pipeline normalizes images to roughly [-1, 1]
    (mean = std = 0.5), not with ImageNet statistics — confirm this is intended.
    """

    def __init__(self, name="VGG19"):
        super(VGG19, self).__init__()

        self.name = name

        try:
            # torchvision >= 0.13: explicit weights enum. IMAGENET1K_V1 is the
            # checkpoint the deprecated `pretrained=True` flag used to load.
            weights = models.VGG19_Weights.IMAGENET1K_V1
        except AttributeError:
            # Older torchvision without the weights enum.
            backbone = models.vgg19(pretrained=True)
        else:
            backbone = models.vgg19(weights=weights)

        for params in backbone.parameters():
            params.requires_grad = False

        # children()[0] is the `features` Sequential; keep modules 0..34.
        self.model = nn.Sequential(*list(backbone.children())[0][:35])

    def forward(self, x):
        """Return VGG feature maps for ``x``; raises TypeError for non-tensors."""
        if isinstance(x, torch.Tensor):
            return self.model(x)
        else:
            raise TypeError("Input must be a torch.Tensor".capitalize())


if __name__ == "__main__":
    # Smoke test: a 256x256 RGB input yields (1, 512, 16, 16) VGG features.
    model = VGG19(name="VGG19")
    assert tuple(model(torch.randn(1, 3, 256, 256)).shape) == (1, 512, 16, 16)

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim


def load_dataset():
    """Load the pickled train/valid DataLoaders.

    The directory comes from ``config()["path"]["PROCESSED_DATA_PATH"]``;
    the files are ``train_dataloader.pkl`` and ``valid_dataloader.pkl``.

    Returns:
        dict with keys ``"train_dataloader"`` and ``"valid_dataloader"``.

    Raises:
        Exception: if the configured path is empty / falsy.
    """
    # Parse config.yml once instead of three times.
    processed_path = config()["path"]["PROCESSED_DATA_PATH"]

    if processed_path:
        train_dataloader = os.path.join(processed_path, "train_dataloader.pkl")
        valid_dataloader = os.path.join(processed_path, "valid_dataloader.pkl")

        return {
            "train_dataloader": load(train_dataloader),
            "valid_dataloader": load(valid_dataloader),
        }

    else:
        raise Exception("No processed data found".capitalize())


def helper(**kwargs):
    """Create the models, optimizers, losses and data loaders for training.

    Keyword Args (all required):
        lr: Learning rate for both optimizers.
        adam: Use Adam (takes precedence over ``SGD``).
        SGD: Use SGD with ``momentum`` when ``adam`` is falsy.
        beta1, beta2: Adam betas.
        momentum: SGD momentum.

    Returns:
        dict with ``netG``, ``netD``, ``optimizerG``, ``optimizerD``,
        ``adversarial_loss``, ``perceptual_loss``, ``train_dataloader`` and
        ``valid_dataloader``.

    Raises:
        ValueError: if neither ``adam`` nor ``SGD`` is truthy.
        Exception: whatever ``load_dataset`` raised (printed, then re-raised).
    """
    lr = kwargs["lr"]
    adam = kwargs["adam"]
    SGD = kwargs["SGD"]
    beta1 = kwargs["beta1"]
    beta2 = kwargs["beta2"]
    momentum = kwargs["momentum"]

    netG = Generator(in_channels=3, out_channels=64)
    # The discriminator scores RGB images (real HR or netG output), so its
    # first conv must accept 3 channels; 64 made every forward pass fail.
    netD = Discriminator(in_channels=3, out_channels=64)

    if adam:
        optimizerG = optim.Adam(params=netG.parameters(), lr=lr, betas=(beta1, beta2))
        optimizerD = optim.Adam(params=netD.parameters(), lr=lr, betas=(beta1, beta2))

    elif SGD:
        optimizerG = optim.SGD(params=netG.parameters(), lr=lr, momentum=momentum)
        optimizerD = optim.SGD(params=netD.parameters(), lr=lr, momentum=momentum)

    else:
        # Without this, the return below hit an UnboundLocalError.
        raise ValueError("Either adam or SGD must be True")

    try:
        dataset = load_dataset()

    except Exception as e:
        print(e)
        # Re-raise instead of falling through to a NameError on `dataset`.
        raise

    adversarial_loss = MSELoss(reduction="mean")
    perceptual_loss = VGG19(name="VGG19")

    return {
        "netG": netG,
        "netD": netD,
        "optimizerG": optimizerG,
        "optimizerD": optimizerD,
        "adversarial_loss": adversarial_loss,
        "perceptual_loss": perceptual_loss,
        "train_dataloader": dataset["train_dataloader"],
        "valid_dataloader": dataset["valid_dataloader"],
    }


if __name__ == "__main__":
    # Sanity check that helper() wires up every training component.
    init = helper(
        lr=0.0002,
        beta1=0.5,
        beta2=0.999,
        momentum=0.9,
        adam=True,
        SGD=False,
    )

    assert type(init["netG"]).__name__ == "Generator"
    assert type(init["netD"]).__name__ == "Discriminator"
    assert type(init["optimizerG"]).__name__ == "Adam"
    assert type(init["optimizerD"]).__name__ == "Adam"
    assert type(init["adversarial_loss"]).__name__ == "MSELoss"
    assert type(init["perceptual_loss"]).__name__ == "VGG19"
    assert type(init["train_dataloader"]) is torch.utils.data.DataLoader
    assert type(init["valid_dataloader"]) is torch.utils.data.DataLoader

In [None]:
import argparse
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision.utils import save_image


class Trainer:
    """Adversarial training loop for the super-resolution Generator.

    Generator loss = adversarial MSE (target 1)
                     + ``content_loss`` * L1 distance of VGG features
                     + ``pixel_loss``   * L1 distance of pixels.
    Discriminator loss = 0.5 * (MSE(fake, 0) + MSE(real, 1)).

    Per epoch it prints progress, saves one generated training batch as
    an image, saves netG checkpoints (best-so-far and per-epoch), and records
    mean losses; the history is pickled after the last epoch.

    NOTE(review): ``lr_scheduler`` is stored but not used anywhere in this class.
    """

    def __init__(
        self,
        epochs=100,
        lr=0.0002,
        beta1=0.5,
        beta2=0.999,
        adam=True,
        SGD=False,
        momentum=0.9,
        content_loss=0.01,
        pixel_loss=0.05,
        device="cuda",
        lr_scheduler=False,
        is_weight_init=False,
        verbose=True,
    ):
        self.epochs = epochs
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.adam = adam
        self.SGD = SGD
        self.momentum = momentum
        self.content_loss = content_loss
        self.pixel_loss = pixel_loss
        self.device = device
        self.lr_scheduler = lr_scheduler
        self.is_weight_init = is_weight_init
        self.is_verbose = verbose

        self.init = helper(
            lr=self.lr,
            beta1=self.beta1,
            beta2=self.beta2,
            adam=self.adam,
            SGD=self.SGD,
            momentum=self.momentum,
        )

        self.device = device_init(device=self.device)

        # Module.to() moves parameters in place, so the optimizers built in
        # helper() still reference the right tensors.
        self.netG = self.init["netG"].to(self.device)
        self.netD = self.init["netD"].to(self.device)

        if self.is_weight_init:
            self.netG.apply(weight_init)
            self.netD.apply(weight_init)

        self.optimizerG = self.init["optimizerG"]
        self.optimizerD = self.init["optimizerD"]

        self.adversarial_loss = self.init["adversarial_loss"].to(self.device)
        self.perceptual_loss = self.init["perceptual_loss"].to(self.device)

        self.train_dataloader = self.init["train_dataloader"]
        self.val_dataloader = self.init["valid_dataloader"]

        # Best (lowest) mean netG loss seen so far; drives best-checkpoint saves.
        self.loss = float("inf")

        self.CONFIG = config()

        self.history = {"netG_loss": [], "netD_loss": []}

    def l1_loss(self, model=None):
        """Sum of the L1 norms of all of ``model``'s parameter tensors."""
        if model is not None:
            return sum(torch.norm(params, 1) for params in model.parameters())

    def l2_loss(self, model=None):
        """Sum of the (unsquared) L2 norms of ``model``'s parameter tensors."""
        if model is not None:
            return sum(torch.norm(params, 2) for params in model.parameters())

    def elastic_loss(self, model=None):
        """``l1_loss(model) + l2_loss(model)``."""
        l1 = self.l1_loss(model)
        l2 = self.l2_loss(model)

        return l1 + l2

    def update_netG(self, **kwargs):
        """Run one generator optimization step; return the total loss as a float.

        Keyword Args:
            lr_image, hr_image: Paired input / target batches on ``self.device``.
        """
        self.optimizerG.zero_grad()

        lr_image = kwargs["lr_image"]
        hr_image = kwargs["hr_image"]

        generated_hr_image = self.netG(lr_image)
        predicted_generated_hr_image = self.netD(generated_hr_image)

        real_features = self.perceptual_loss(hr_image)
        generated_features = self.perceptual_loss(generated_hr_image)

        content_loss = torch.abs(real_features - generated_features).mean()
        pixelwise_loss = torch.abs(hr_image - generated_hr_image).mean()

        # The generator wants netD to label its output as real (1).
        loss_generated_hr_image = self.adversarial_loss(
            predicted_generated_hr_image, torch.ones_like(predicted_generated_hr_image)
        )

        total_loss = (
            loss_generated_hr_image
            + self.content_loss * content_loss
            + self.pixel_loss * pixelwise_loss
        )

        # This also leaves gradients on netD's parameters; they are cleared
        # by optimizerD.zero_grad() at the start of the next update_netD.
        total_loss.backward()
        self.optimizerG.step()

        return total_loss.item()

    def update_netD(self, **kwargs):
        """Run one discriminator optimization step; return the loss as a float.

        Keyword Args:
            lr_image, hr_image: Paired input / target batches on ``self.device``.
        """
        self.optimizerD.zero_grad()

        lr_image = kwargs["lr_image"]
        hr_image = kwargs["hr_image"]

        # netG is not being optimized here: generate without autograd so the
        # discriminator's backward pass does not also traverse the generator.
        with torch.no_grad():
            generated_hr_image = self.netG(lr_image)
        predicted_generated_hr_image = self.netD(generated_hr_image)

        predicted_real_hr_image = self.netD(hr_image)

        loss_generated_hr_image = self.adversarial_loss(
            predicted_generated_hr_image, torch.zeros_like(predicted_generated_hr_image)
        )
        loss_real_hr_image = self.adversarial_loss(
            predicted_real_hr_image, torch.ones_like(predicted_real_hr_image)
        )

        total_loss = 0.5 * (loss_generated_hr_image + loss_real_hr_image)

        total_loss.backward()
        self.optimizerD.step()

        return total_loss.item()

    def saved_checkpoints(self, **kwargs):
        """Save netG checkpoints.

        Writes ``BEST_MODEL_CHECKPOINT_PATH/netG.pth`` (state dict, loss and
        epoch) whenever ``netG_loss`` beats the best seen so far, and always
        writes ``TRAIN_MODEL_CHECKPOINT_PATH/netG{epoch + 1}.pth``.
        """
        netG_loss = kwargs["netG_loss"]

        if self.loss > netG_loss:
            self.loss = netG_loss

            torch.save(
                {
                    "netG": self.netG.state_dict(),
                    "loss": netG_loss,
                    "epoch": kwargs["epoch"],
                },
                os.path.join(
                    self.CONFIG["path"]["BEST_MODEL_CHECKPOINT_PATH"], "netG.pth"
                ),
            )
        torch.save(
            self.netG.state_dict(),
            os.path.join(
                self.CONFIG["path"]["TRAIN_MODEL_CHECKPOINT_PATH"],
                "netG{}.pth".format(kwargs["epoch"] + 1),
            ),
        )

    def show_progress(self, **kwargs):
        """Print per-epoch losses (verbose) or a completion line."""
        if self.is_verbose:
            print(
                "Epochs - [{}/{}] - netG_loss: [{:.4f}] - netD_loss: [{:.4f}]".format(
                    kwargs["epoch"] + 1,
                    self.epochs,
                    kwargs["netG_loss"],
                    kwargs["netD_loss"],
                )
            )

        else:
            print(
                "Epochs - [{}/{}] is completed".capitalize().format(
                    kwargs["epoch"] + 1, self.epochs
                )
            )

    def saved_train_images(self, epoch=None):
        """Save netG's output for one training batch as ``image_{epoch + 1}.png``."""
        lr_image, _ = next(iter(self.train_dataloader))
        lr_image = lr_image.to(self.device)

        # Visualization only: no autograd graph needed.
        with torch.no_grad():
            generated_hr_image = self.netG(lr_image)

        # Images were normalized with mean = std = 0.5 (range about [-1, 1]);
        # save_image clamps to [0, 1], so map back first or negative values
        # are written as black.
        save_image(
            generated_hr_image * 0.5 + 0.5,
            os.path.join(
                self.CONFIG["path"]["TRAIN_IMAGES_PATH"],
                "image_{}.png".format(epoch + 1),
            ),
        )

    def train(self):
        """Run ``self.epochs`` epochs of alternating D / G updates."""
        for epoch in tqdm(range(self.epochs)):
            self.netD_loss = []
            self.netG_loss = []

            for lr_image, hr_image in self.train_dataloader:

                lr_image = lr_image.to(self.device)
                hr_image = hr_image.to(self.device)

                self.netD_loss.append(
                    self.update_netD(lr_image=lr_image, hr_image=hr_image)
                )
                self.netG_loss.append(
                    self.update_netG(lr_image=lr_image, hr_image=hr_image)
                )

            mean_netG_loss = np.mean(self.netG_loss)
            mean_netD_loss = np.mean(self.netD_loss)

            self.show_progress(
                epoch=epoch,
                netG_loss=mean_netG_loss,
                netD_loss=mean_netD_loss,
            )

            self.saved_train_images(epoch=epoch)
            self.saved_checkpoints(netG_loss=mean_netG_loss, epoch=epoch)

            self.history["netG_loss"].append(mean_netG_loss)
            self.history["netD_loss"].append(mean_netD_loss)

        dump(
            value=self.history,
            filename=os.path.join(self.CONFIG["path"]["METRICS_PATH"], "history.pkl"),
        )

    @staticmethod
    def plot_history():
        """Plot the pickled per-epoch netG / netD loss history."""
        if os.path.exists(config()["path"]["METRICS_PATH"]):
            history = load(
                filename=os.path.join(config()["path"]["METRICS_PATH"], "history.pkl")
            )

        else:
            raise Exception("Metrics path cannot be extraced".capitalize())

        plt.figure(figsize=(10, 5))

        plt.plot(history["netG_loss"], label="netG_loss")
        plt.plot(history["netD_loss"], label="netD_loss")
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()


if __name__ == "__main__":
    # Full training run followed by a plot of the loss history.
    # Trainer's weight-initialization flag is `is_weight_init`; passing
    # `weight_init=True` raised "unexpected keyword argument".
    trainer = Trainer(
        epochs=100,
        device="cuda",
        lr=0.0002,
        beta1=0.5,
        beta2=0.999,
        pixel_loss=5e-3,
        content_loss=1e-3,
        adam=True,
        SGD=False,
        momentum=0.95,
        verbose=True,
        lr_scheduler=False,
        is_weight_init=True,
    )

    trainer.train()
    trainer.plot_history()

In [None]:
class Tester:
    """Run a trained Generator on one batch and plot LR / generated / HR images.

    Args:
        model: ``"best"`` to load ``BEST_MODEL_CHECKPOINT_PATH/netG.pth``
            (a dict with a ``"netG"`` state dict), otherwise a path to a
            plain state-dict file.
        dataloader: ``"valid"`` for the pickled validation loader, anything
            else for the training loader.
        device: Requested backend, resolved through ``device_init``.
    """

    def __init__(self, model="best", dataloader="valid", device="cuda"):
        self.model = model
        self.dataloader = dataloader
        self.device = device

        self.device = device_init(device=self.device)

        self.netG = Generator(in_channels=3, out_channels=64).to(self.device)

    def select_best_model(self):
        """Load and return the Generator state dict selected by ``self.model``.

        Tensors are mapped onto ``self.device`` so checkpoints saved on a GPU
        can be loaded on a CPU-only machine.
        """
        if self.model == "best":
            state_dict = torch.load(
                os.path.join(config()["path"]["BEST_MODEL_CHECKPOINT_PATH"], "netG.pth"),
                map_location=self.device,
            )

            return state_dict["netG"]

        else:
            return torch.load(self.model, map_location=self.device)

    def load_dataloader(self):
        """Return the pickled valid (or train) DataLoader."""
        if self.dataloader == "valid":
            return load(
                filename=os.path.join(
                    config()["path"]["PROCESSED_DATA_PATH"], "valid_dataloader.pkl"
                )
            )

        else:
            return load(
                filename=os.path.join(
                    config()["path"]["PROCESSED_DATA_PATH"], "train_dataloader.pkl"
                )
            )

    def plot(self, X=None, y=None, generated_hr_image=None):
        """Plot generated / LR / HR triplets and save to ``TEST_IMAGE_PATH/test.png``."""
        plt.figure(figsize=(40, 15))

        for index, image in enumerate(generated_hr_image):
            gen_image = image.permute(1, 2, 0).detach().cpu().numpy()
            real_lr_image = X[index].permute(1, 2, 0).detach().cpu().numpy()
            real_hr_image = y[index].permute(1, 2, 0).detach().cpu().numpy()

            # Min-max scale each image to [0, 1] for display.
            gen_image = (gen_image - gen_image.min()) / (
                gen_image.max() - gen_image.min()
            )
            real_lr_image = (real_lr_image - real_lr_image.min()) / (
                real_lr_image.max() - real_lr_image.min()
            )
            real_hr_image = (real_hr_image - real_hr_image.min()) / (
                real_hr_image.max() - real_hr_image.min()
            )

            plt.subplot(3 * 2, 3 * 4, 3 * index + 2)
            plt.imshow(real_lr_image)
            plt.title("real_LR".capitalize())
            plt.axis("off")

            plt.subplot(3 * 2, 3 * 4, 3 * index + 1)
            plt.imshow(gen_image)
            plt.title("Generated".capitalize())
            plt.axis("off")

            plt.subplot(3 * 2, 3 * 4, 3 * index + 3)
            plt.imshow(real_hr_image)
            plt.title("real_HR".capitalize())
            plt.axis("off")

        plt.tight_layout()
        plt.savefig(os.path.join(config()["path"]["TEST_IMAGE_PATH"], "test.png"))
        plt.show()

        print("Image is saved in the folder: ", config()["path"]["TEST_IMAGE_PATH"])

    def test(self):
        """Load weights, run one batch through netG and plot the results.

        Each stage prints its error and returns early instead of raising.
        """
        try:

            self.netG.load_state_dict(self.select_best_model())

        except Exception as e:
            print("0000", e)
            return

        try:

            dataloader = self.load_dataloader()

            X, y = next(iter(dataloader))

        except Exception as e:
            print(e)
            return

        # Inference only: eval mode and no autograd graph.
        self.netG.eval()
        with torch.no_grad():
            generated_hr_image = self.netG(X.to(self.device))

        try:
            self.plot(X, y, generated_hr_image)

        except Exception as e:
            print(e)
            return


if __name__ == "__main__":
    # Evaluate the best saved checkpoint on the validation loader.
    # `model=None` made Tester call torch.load(None), which failed, so the
    # test printed an error and did nothing; "best" selects the checkpoint
    # saved during training.
    tester = Tester(
        model="best",
        dataloader="valid",
        device="cuda",
    )

    tester.test()