In [2]:
import os
import cv2
import torch
import joblib
import zipfile
import torch.nn as nn
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [3]:
# Paths relative to the notebook's working directory: the zip archive is
# extracted under RAW_DATA_PATH, and the pickled dataloaders are written to
# PROCESSED_DATA_PATH.
RAW_DATA_PATH, PROCESSED_DATA_PATH = (
    "../../data/raw/",
    "../../data/processed/",
)

In [None]:
def dump(value = None, filename = None):
    """Serialize ``value`` to ``filename`` with ``joblib.dump``.

    Args:
        value: any object joblib can pickle.
        filename: destination path.

    Raises:
        ValueError: if ``value`` or ``filename`` is None.
    """
    if value is None or filename is None:
        raise ValueError("Value or filename cannot be None".capitalize())

    joblib.dump(value=value, filename=filename)
    
def load(filename = None):
    """Deserialize and return the object stored in ``filename`` via ``joblib.load``.

    NOTE(review): joblib unpickles, which can execute arbitrary code; only load
    files this project produced itself.

    Raises:
        ValueError: if ``filename`` is None.
    """
    if filename is None:
        raise ValueError("Filename cannot be None".capitalize())

    return joblib.load(filename)

In [None]:
class Loader():
    """Build paired low-/high-resolution image dataloaders for super-resolution.

    The zip archive at ``image_path`` is extracted into ``RAW_DATA_PATH``. Images
    with the same file name in ``<RAW_DATA_PATH>/dataset/LR`` and
    ``<RAW_DATA_PATH>/dataset/HR`` are paired, converted to normalized tensors,
    split into train/validation sets, and pickled as ``DataLoader`` objects into
    ``PROCESSED_DATA_PATH``.

    Args:
        image_path: path of the zip archive containing the ``dataset`` folder.
        image_size: side length of LR tensors; HR tensors are ``4 * image_size``.
        split_size: fraction of pairs held out for validation.
        batch_size: training batch size (validation uses ``8 * batch_size``).
    """

    def __init__(self, image_path = None, image_size = 64, split_size = 0.20, batch_size = 1):
        self.image_path = image_path
        self.image_size = image_size
        self.split_size = split_size
        self.batch_size = batch_size

        # Filled by feature_extraction(): index-aligned LR / HR tensors.
        self.LR = []
        self.HR = []

    def split_dataset(self, X = None, y = None):
        """Split index-aligned lists ``X`` (LR) and ``y`` (HR) into train/test parts.

        Returns:
            dict with keys ``X_train``, ``X_test``, ``y_train``, ``y_test``.

        Raises:
            TypeError: if ``X`` or ``y`` is not a list. Previously None was
                returned, which only failed later when the result was indexed.
        """
        if not (isinstance(X, list) and isinstance(y, list)):
            raise TypeError("X and y must be lists")

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.split_size, random_state=42, shuffle=True
        )

        return {"X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test}

    def transforms(self, type = "lr"):
        """Return the preprocessing pipeline for ``"lr"`` or ``"hr"`` images.

        Images are resized to ``image_size`` (LR) or ``4 * image_size`` (HR) per
        side, converted to tensors and normalized with mean/std 0.5 per channel.

        Raises:
            ValueError: for any other ``type``. Previously None was returned.
        """
        scale = {"lr": 1, "hr": 4}.get(type)
        if scale is None:
            raise ValueError("Type must be either 'lr' or 'hr', got {!r}".format(type))

        size = (self.image_size * scale, self.image_size * scale)

        # The bare name `transforms` here is the torchvision module: class
        # attributes (this method) are not part of a method's lookup scope.
        return transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.CenterCrop(size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ])

    def unzip_folder(self):
        """Extract the archive at ``image_path`` into ``RAW_DATA_PATH``.

        Raises:
            Exception: if ``RAW_DATA_PATH`` does not exist.
        """
        if not os.path.exists(RAW_DATA_PATH):
            raise Exception("RAW data path is not found".capitalize())

        with zipfile.ZipFile(self.image_path, "r") as zip_file:
            zip_file.extractall(RAW_DATA_PATH)

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return an RGB PIL image, or None if it cannot be decoded."""
        image = cv2.imread(path)
        if image is None:
            return None
        return Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

    def feature_extraction(self):
        """Load every LR/HR pair by matching file names and return the train/test split.

        Files that OpenCV cannot decode in either folder are skipped. Pairs are
        visited in sorted file-name order so the seeded split is reproducible
        regardless of filesystem listing order.

        Returns:
            The dict produced by ``split_dataset``.
        """
        self.directory = os.path.join(RAW_DATA_PATH, "dataset")

        self.higher_resolution_images = os.path.join(self.directory, "HR")
        self.low_resolution_images = os.path.join(self.directory, "LR")

        # Start fresh so re-running does not duplicate pairs (duplicates could
        # end up on both sides of the train/validation split).
        self.LR = []
        self.HR = []

        # Build the pipelines and the HR name index once instead of per image.
        lr_transform = self.transforms(type="lr")
        hr_transform = self.transforms(type="hr")
        hr_names = set(os.listdir(self.higher_resolution_images))

        for image in sorted(os.listdir(self.low_resolution_images)):
            if image not in hr_names:
                continue

            lower_resolution_image = self._read_rgb(os.path.join(self.low_resolution_images, image))
            higher_resolution_image = self._read_rgb(os.path.join(self.higher_resolution_images, image))

            if lower_resolution_image is None or higher_resolution_image is None:
                continue

            self.LR.append(lr_transform(lower_resolution_image))
            self.HR.append(hr_transform(higher_resolution_image))

        assert len(self.LR) == len(self.HR)

        print("Total {} images have been captured".format(len(self.LR)).capitalize())

        return self.split_dataset(X=self.LR, y=self.HR)

    def create_dataloader(self):
        """Build the train/valid dataloaders and pickle them into ``PROCESSED_DATA_PATH``.

        Writes ``train_dataloader.pkl`` (batch ``batch_size``) and
        ``valid_dataloader.pkl`` (batch ``8 * batch_size``). Both loaders shuffle.

        Raises:
            Exception: if feature extraction fails. The original error is
                chained as ``__cause__``.
        """
        try:
            dataset = self.feature_extraction()
        except Exception as e:
            # Chain the root cause so the real error stays in the traceback.
            raise Exception("Feature extraction process has been failed".capitalize()) from e

        train_dataloader = DataLoader(
            dataset = list(zip(dataset["X_train"], dataset["y_train"])),
            batch_size = self.batch_size,
            shuffle = True
        )
        valid_dataloader = DataLoader(
            dataset = list(zip(dataset["X_test"], dataset["y_test"])),
            batch_size=self.batch_size*8,
            shuffle=True
        )

        # The output folder may not exist yet on a fresh checkout.
        os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)

        for dataloader, filename in [(train_dataloader, "train_dataloader"), (valid_dataloader, "valid_dataloader")]:
            dump(value=dataloader, filename=os.path.join(PROCESSED_DATA_PATH, filename+".pkl"))

        print("train and valid dataloader has been created in the folder : {}".format(PROCESSED_DATA_PATH).capitalize())

    @staticmethod
    def _to_display(tensor):
        """Convert a (squeezable) CHW tensor to an HWC numpy image min-max scaled to [0, 1]."""
        array = tensor.squeeze().permute(1, 2, 0).detach().cpu().numpy()
        value_range = array.max() - array.min()
        # A constant image would otherwise divide by zero and render as NaNs.
        if value_range == 0:
            return np.zeros_like(array)
        return (array - array.min()) / value_range

    @staticmethod
    def plot_images():
        """Plot LR/HR pairs from one batch of the pickled validation dataloader.

        Reads ``valid_dataloader.pkl`` from ``PROCESSED_DATA_PATH``, so
        ``create_dataloader`` must have run first. Pairs are drawn side by side on
        a 4 x 8 grid, i.e. at most 16 pairs. Extra images in the batch are skipped
        instead of making ``plt.subplot`` fail with an out-of-range index.
        """
        dataloader = load(filename=os.path.join(PROCESSED_DATA_PATH, "valid_dataloader.pkl"))

        data, labels = next(iter(dataloader))

        n_rows, n_cols = 2 * 2, 2 * 4
        max_pairs = (n_rows * n_cols) // 2

        plt.figure(figsize=(20, 10))

        for index, image in enumerate(data[:max_pairs]):
            plt.subplot(n_rows, n_cols, 2 * index + 1)
            plt.imshow(Loader._to_display(image))
            plt.title("LR")
            plt.axis("off")

            plt.subplot(n_rows, n_cols, 2 * index + 2)
            plt.imshow(Loader._to_display(labels[index]))
            plt.title("HR")
            plt.axis("off")

        plt.tight_layout()
        plt.show()


if __name__ == "__main__":
    # 64x64 LR / 256x256 HR pairs, with 40% of the pairs held out for validation.
    loader = Loader(
        image_path="../../data/raw/dataset.zip",
        split_size=0.40,
        image_size=64,
    )

    # Extract the archive, then build and pickle the train/valid dataloaders.
    loader.unzip_folder()
    loader.create_dataloader()

    # Visual sanity check on one validation batch (plot_images is a staticmethod).
    Loader.plot_images()

In [4]:
class DenseBlock(nn.Module):
    """Five densely connected 3x3 convolutions.

    Conv ``k`` (1-based) sees the channel-wise concatenation of the newest
    feature maps and everything before them: ``in_channels + (k - 1) *
    out_channels`` channels. Convs 1-4 are followed by ``LeakyReLU(0.2)``.
    Conv 5 is linear, and its output is returned. That output has
    ``out_channels`` channels and the input's height/width (stride 1, padding 1).
    No skip connection is applied inside this block.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by each convolution and by the block.
    """

    def __init__(self, in_channels = 64, out_channels = 64):
        super(DenseBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel_size = 3
        self.stride = 1
        self.padding = 1
        self.slope = 0.2

        # Input width grows by out_channels per preceding conv. The previous
        # `k * in_channels` formula only agreed with this when
        # in_channels == out_channels. That case gives identical layers and
        # parameter counts.
        self.block1 = self.block(in_channels = self.in_channels + 0 * self.out_channels, out_channels = self.out_channels, use_leaky = True)
        self.block2 = self.block(in_channels = self.in_channels + 1 * self.out_channels, out_channels = self.out_channels, use_leaky = True)
        self.block3 = self.block(in_channels = self.in_channels + 2 * self.out_channels, out_channels = self.out_channels, use_leaky = True)
        self.block4 = self.block(in_channels = self.in_channels + 3 * self.out_channels, out_channels = self.out_channels, use_leaky = True)
        self.block5 = self.block(in_channels = self.in_channels + 4 * self.out_channels, out_channels = self.out_channels, use_leaky = False)

    def block(self, in_channels = 64, out_channels = 64, use_leaky = True):
        """Return a biased 3x3/stride-1/pad-1 ``Conv2d``, optionally followed by ``LeakyReLU``."""
        # Local list: storing it on `self` (as before) left a stale attribute
        # pointing at whichever block was built last.
        layers = [
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride,
                padding=self.padding,
                bias=True
            )
        ]

        if use_leaky:
            # In-place is safe: it only overwrites the conv's freshly produced output.
            layers.append(nn.LeakyReLU(negative_slope=self.slope, inplace=True))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the dense stack on ``x`` and return the output of conv 5.

        Raises:
            TypeError: if ``x`` is not a ``torch.Tensor``.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a tensor".capitalize())

        inputs = x
        for layer in (self.block1, self.block2, self.block3, self.block4):
            # New features go first; this concat order is what the weights expect.
            inputs = torch.concat((layer(inputs), inputs), dim = 1)

        return self.block5(inputs)
        
        
if __name__ == "__main__":
    # Smoke test: five stacked 64-channel dense blocks preserve the input shape.
    layers = [DenseBlock(in_channels = 64, out_channels = 64) for _ in range(5)]

    model = nn.Sequential(*layers)

    assert tuple(model(torch.randn(1, 64, 256, 256)).shape) == (1, 64, 256, 256)

In [None]:
class ResidualInResidual(nn.Module):
    """Three chained ``DenseBlock``s with skip connections around each.

    ``h1 = d1(x) + x``; ``h2 = d2(h1) + h1``; ``out = res_scale * d3(h2) + h2``.
    Only the last dense block's output is scaled by ``res_scale``. The shape
    is preserved because each dense block maps ``in_channels`` to ``in_channels``.

    NOTE(review): the reference ESRGAN RRDB adds the block input ``x`` (not
    ``h2``) after the final scaling. Confirm this variant is intended.

    Args:
        in_channels: feature channels of the input and output.
        res_scale: multiplier applied to the third dense block's output.
    """

    def __init__(self, in_channels=64, res_scale=0.2):
        super(ResidualInResidual, self).__init__()

        self.in_channels = in_channels
        self.res_scale = res_scale

        self.denseblock1 = DenseBlock(
            in_channels=self.in_channels, out_channels=self.in_channels
        )
        self.denseblock2 = DenseBlock(
            in_channels=self.in_channels, out_channels=self.in_channels
        )
        self.denseblock3 = DenseBlock(
            in_channels=self.in_channels, out_channels=self.in_channels
        )

    def forward(self, x):
        """Apply the three residual dense stages to ``x``.

        Raises:
            TypeError: if ``x`` is not a ``torch.Tensor``. Previously None was
                returned silently, unlike the sibling blocks.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a tensor".capitalize())

        input2 = self.denseblock1(x) + x
        input3 = self.denseblock2(input2) + input2

        return torch.mul(self.denseblock3(input3), self.res_scale) + input3


if __name__ == "__main__":
    # Show the module tree of a single 64-channel block (res_scale at its default).
    residual_in_residual = ResidualInResidual(in_channels=64, res_scale=0.2)
    print(residual_in_residual)

In [5]:
class OutputBlock(nn.Module):
    """Upsample feature maps 4x with two conv + ``PixelShuffle(2)`` stages.

    Layout: ``[conv(C -> 4C), PixelShuffle(2)] -> LeakyReLU(0.2) ->
    [conv(C -> 4C), PixelShuffle(2)]`` with ``C = in_channels``. The output
    therefore has ``in_channels`` channels and 4x the height and width.

    Note: ``out_channels`` is stored on the module but no layer uses it.
    """

    def __init__(self, in_channels=64, out_channels=3):
        super(OutputBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels

        self.kernel_size, self.stride_size, self.padding_size = 3, 1, 1
        self.negative_slope = 0.2
        self.upscale_factor = 2

        self.output_block = self.block()

    def _upsample_stage(self):
        """One 3x3 conv to ``4 * in_channels`` followed by ``PixelShuffle`` back to ``in_channels``."""
        return nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.in_channels * 4,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
                bias=True,
            ),
            nn.PixelShuffle(upscale_factor=self.upscale_factor),
        )

    def block(self):
        """Assemble (and keep in ``self.layers``) the two-stage upsampler."""
        # The activation sits only between the stages, not after the second one.
        self.layers = [
            self._upsample_stage(),
            nn.LeakyReLU(negative_slope=self.negative_slope, inplace=True),
            self._upsample_stage(),
        ]
        return nn.Sequential(*self.layers)

    def forward(self, x):
        """Upsample ``x``; raise ``TypeError`` if it is not a ``torch.Tensor``."""
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())
        return self.output_block(x)


if __name__ == "__main__":
    outblock = OutputBlock(out_channels=64, in_channels=64)

    # Two PixelShuffle(2) stages: 64x64 -> 256x256 with the channel count unchanged.
    assert tuple(outblock(torch.randn(1, 64, 64, 64)).shape) == (1, 64, 256, 256)

In [9]:
class Generator(nn.Module):
    """Super-resolution generator: conv stem, RRDB trunk, long skip, 4x upsampler.

    ``x -> conv(in -> C) -> num_blocks x ResidualInResidual -> conv(C -> C)
    -> (+ stem output) -> OutputBlock -> conv(C -> in)``, where
    ``C = out_channels``.

    Args:
        in_channels: image channels, used for both input and output.
        out_channels: internal feature width ``C``.
        num_blocks: number of ``ResidualInResidual`` blocks in the trunk
            (default 16, the value previously hard-coded).
    """

    def __init__(self, in_channels=3, out_channels=64, num_blocks=16):
        super(Generator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_blocks = num_blocks
        self.kernel_size = 3
        self.stride_size = 1
        self.padding_size = 1

        self.input_block = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=True,
        )

        self.residual_in_residual_denseblock = nn.Sequential(
            *[ResidualInResidual(in_channels=self.out_channels) for _ in range(self.num_blocks)]
        )

        self.middle_block = nn.Conv2d(
            in_channels=self.out_channels,
            out_channels=out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=True,
        )

        self.output = nn.Sequential(
            OutputBlock(in_channels=self.out_channels, out_channels=self.out_channels),
            nn.Conv2d(
                in_channels=self.out_channels,
                out_channels=self.in_channels,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
            ),
        )

    def forward(self, x):
        """Upscale ``x``.

        Raises:
            TypeError: if ``x`` is not a ``torch.Tensor``. Previously None was
                returned silently.
        """
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")

        input_block = self.input_block(x)
        residual_block = self.residual_in_residual_denseblock(input_block)
        # Long skip connection around the whole trunk.
        middle_block = torch.add(input_block, self.middle_block(residual_block))

        return self.output(middle_block)

    @staticmethod
    def total_params(model=None):
        """Return the total number of parameters (trainable or not) of ``model``.

        Accepts any ``nn.Module``. Previously anything but a ``Generator``
        returned None.

        Raises:
            TypeError: if ``model`` is not an ``nn.Module``.
        """
        if not isinstance(model, nn.Module):
            raise TypeError("Model must be a torch.nn.Module")
        return sum(params.numel() for params in model.parameters())


if __name__ == "__main__":
    netG = Generator(in_channels=3, out_channels=64)

    # Parameter budget of the 16-block, 64-channel configuration.
    assert Generator.total_params(model=netG) == 26893315

    # 4x upscaling: a 64x64 RGB input comes out as 256x256 RGB.
    assert tuple(netG(torch.randn(1, 3, 64, 64)).shape) == (1, 3, 256, 256)