In [None]:
import sys
import os
import yaml
import zipfile
import joblib
import cv2
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from collections import OrderedDict
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchsummary import summary
from torchview import draw_graph

In [None]:
import warnings

warnings.filterwarnings("ignore")

In [None]:
def config():
    """Read the project configuration file.

    Opens ``../../config.yml`` (resolved against the current working
    directory) and parses it with ``yaml.safe_load``.

    Returns:
        The parsed content of the YAML file.
    """
    with open("../../config.yml", "r") as stream:
        return yaml.safe_load(stream)

In [None]:
def dump(value = None, filename = None):
    """Serialise ``value`` to ``filename`` with joblib.

    The call does nothing when either argument is ``None``.

    Args:
        value: Object to persist.
        filename: Destination path of the pickle file.
    """
    if value is None or filename is None:
        return

    joblib.dump(value=value, filename=filename)
        
def load(filename = None):
    """Load an object previously stored with joblib.

    Args:
        filename: Path of the pickle file to read.

    Returns:
        The deserialised object, or ``None`` when no filename is given.
    """
    if filename is None:
        return None

    return joblib.load(filename=filename)

In [None]:
class Loader():
    """Turn a zipped image archive into pickled train/test DataLoaders.

    The archive is extracted into the configured ``raw_path``; images are then
    read from ``<raw_path>/images/<category>/``, transformed to tensors and
    split into train and test parts. The resulting DataLoaders are stored in
    the configured ``processed_path``.

    Args:
        image_path: Path of the zip archive that contains the images.
        image_size: Side length (pixels) the images are resized to.
        channels: Stored on the instance; not used by the current pipeline.
        batch_size: Batch size of the training DataLoader. The test loader
            uses ``batch_size * 4`` and the full loader ``batch_size * 8``.
        split_size: Fraction of the samples used as test data.
        paired_images: When true, an image is only used if a file with the
            same name exists in ``images/y``; both files form one sample.
        unpaired_images: When true (and ``paired_images`` is false), the
            category folders are assigned alternately to ``X`` and ``y``.
    """

    def __init__(self, image_path = None, image_size = 256, channels = 3, batch_size = 1, split_size = 0.20, paired_images = False, unpaired_images = True):
        self.image_path = image_path
        self.image_size = image_size
        self.channels = channels
        self.batch_size = batch_size
        self.split_size = split_size
        self.paired_image = paired_images
        self.unpaired_image = unpaired_images

        self.config = config()

        self.X = []
        self.y = []

    def transforms(self):
        """Build the preprocessing pipeline applied to every PIL image.

        Returns:
            A ``transforms.Compose`` that resizes (bicubic), converts to a
            tensor, centre-crops and normalises each of the three channels
            with mean 0.5 and std 0.5.
        """
        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size), Image.BICUBIC),
                transforms.ToTensor(),
                # transforms.Grayscale(num_output_channels=self.channels),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    def data_split(self, X, y):
        """Split the samples into train and test parts.

        Args:
            X: Sequence of source samples.
            y: Sequence of target samples, same length as ``X``.

        Returns:
            dict with the keys ``X_train``, ``X_test``, ``y_train`` and
            ``y_test``. The split uses ``random_state=42``.
        """
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=self.split_size, random_state=42)

        return {
            "X_train": X_train,
            "X_test": X_test,
            "y_train": y_train,
            "y_test": y_test,
        }

    def image_info(self):
        """Placeholder; not implemented."""
        pass

    def unzip_folder(self):
        """Extract the archive at ``self.image_path`` into the raw data path.

        Raises:
            Exception: If the configured raw path does not exist.
        """
        raw_path = self.config["path"]["raw_path"]

        if not os.path.exists(raw_path):
            raise Exception("Unable to unzip the folder as the path is not exists".capitalize())

        with zipfile.ZipFile(self.image_path, "r") as zip_ref:
            zip_ref.extractall(raw_path)

    @staticmethod
    def _read_rgb(image_path):
        """Read one image file as an RGB PIL image, or ``None`` if unreadable."""
        pixels = cv2.imread(image_path)

        # cv2.imread signals failure (e.g. a non-image file such as
        # ".DS_Store") by returning None instead of raising.
        if pixels is None:
            return None

        return Image.fromarray(cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB))

    def extract_features(self):
        """Load all images, transform them and split them into train/test.

        ``self.X`` and ``self.y`` are rebuilt from scratch on every call, so
        calling this method repeatedly does not duplicate samples. Files that
        cannot be decoded as images are skipped.

        Returns:
            The dict produced by :meth:`data_split`.
        """
        # Reset the buffers: create_dataloader() calls this method itself, so
        # appending to the previous content would store every sample twice.
        self.X = []
        self.y = []

        self.directory = os.path.join(self.config["path"]["raw_path"])
        images_root = os.path.join(self.directory, "images")

        # Only directories are categories; sorting makes the category order
        # (and with it the X/y assignment and the seeded split) reproducible,
        # because os.listdir returns entries in arbitrary order.
        self.categories = sorted(
            entry
            for entry in os.listdir(images_root)
            if os.path.isdir(os.path.join(images_root, entry))
        )

        # The pipeline is the same for every image, build it once.
        transform = self.transforms()

        if self.paired_image:
            target_dir = os.path.join(images_root, "y")
            # Set lookup instead of listing the target folder once per image.
            target_names = set(os.listdir(target_dir))

            # The source images come from the first category that is not the
            # target folder "y" (iterating over ``self.categories[0]`` itself
            # would walk over the characters of the folder name).
            source_categories = [name for name in self.categories if name != "y"][:1]

            for category in source_categories:
                path = os.path.join(images_root, category)

                for image in sorted(os.listdir(path)):
                    if image not in target_names:
                        continue

                    image_X = self._read_rgb(os.path.join(path, image))
                    image_y = self._read_rgb(os.path.join(target_dir, image))

                    # Keep X and y aligned: drop the pair if either side fails.
                    if image_X is None or image_y is None:
                        continue

                    self.X.append(transform(image_X))
                    self.y.append(transform(image_y))

        elif self.unpaired_image:
            for index, category in enumerate(self.categories):
                path = os.path.join(images_root, category)

                # Categories are sorted, so the first one feeds X, the second
                # one y, and so on alternately.
                bucket = self.X if index % 2 == 0 else self.y

                for image in sorted(os.listdir(path)):
                    picture = self._read_rgb(os.path.join(path, image))

                    if picture is None:
                        continue

                    bucket.append(transform(picture))

            # train_test_split needs X and y of equal length; the DataLoaders
            # zip both lists anyway, so the surplus would never be used.
            usable = min(len(self.X), len(self.y))
            self.X = self.X[:usable]
            self.y = self.y[:usable]

        return self.data_split(self.X, self.y)

    def create_dataloader(self):
        """Build the train, test and full DataLoaders and pickle them.

        Writes ``train_dataloader.pkl``, ``test_dataloader.pkl`` and
        ``dataloader.pkl`` into the configured processed path.

        Raises:
            Exception: If the configured processed path does not exist.
        """
        self.data = self.extract_features()

        processed_path = self.config["path"]["processed_path"]

        if not os.path.exists(processed_path):
            raise Exception("Unable to create the pickle file".capitalize())

        train_dataloader = DataLoader(
            dataset=list(zip(self.data["X_train"], self.data["y_train"])),
            batch_size=self.batch_size,
            shuffle=True,)

        test_dataloader = DataLoader(
            dataset=list(zip(self.data["X_test"], self.data["y_test"])),
            batch_size=self.batch_size*4,
            shuffle=True,)

        dataloader = DataLoader(
            dataset=list(zip(self.X, self.y)),
            batch_size=self.batch_size*8,
            shuffle=True,)

        dump(
            value=train_dataloader, filename=os.path.join(processed_path, "train_dataloader.pkl"))

        dump(
            value=test_dataloader, filename=os.path.join(processed_path, "test_dataloader.pkl"))

        dump(
            value=dataloader, filename=os.path.join(processed_path, "dataloader.pkl"))

    @staticmethod
    def plot_images():
        """Plot one batch of the pickled full DataLoader as X/y image pairs.

        The figure is saved as ``images.png`` in the configured files path and
        then shown.

        Raises:
            Exception: If the processed path or the files path does not exist.
        """
        config_files = config()

        if not os.path.exists(config_files["path"]["processed_path"]):
            raise Exception("Unable to plot the images as the path is not exists".capitalize())

        dataloader = load(os.path.join(config_files["path"]["processed_path"], "dataloader.pkl"))

        X, y = next(iter(dataloader))

        columns = 2 * 2
        # Two cells per sample. Keep the 8-row grid for small batches and grow
        # it for larger ones so that plt.subplot never runs out of cells.
        rows = max(2 * 4, (2 * len(X) + columns - 1) // columns)

        plt.figure(figsize=(10, 10))

        for index, image in enumerate(X):
            image_X = image.squeeze().permute(1, 2, 0).cpu().detach().numpy()
            image_y = y[index].squeeze().permute(1, 2, 0).cpu().detach().numpy()

            # Min-max rescale to [0, 1] for display.
            image_X = (image_X - image_X.min())/(image_X.max() - image_X.min())
            image_y = (image_y - image_y.min())/(image_y.max() - image_y.min())

            plt.subplot(rows, columns, 2 * index + 1)
            plt.imshow(image_X)
            plt.axis("off")

            plt.subplot(rows, columns, 2 * index + 2)
            plt.imshow(image_y)
            plt.axis("off")

        plt.tight_layout()

        if os.path.exists(config_files["path"]["files_path"]):
            plt.savefig(os.path.join(config_files["path"]["files_path"], "images.png"))
        else:
            raise Exception("Unable to save the images as the path is not exists".capitalize())

        plt.show()

    @staticmethod
    def dataset_details():
        """Write sample counts, batch counts and shapes to a CSV file.

        Reads the three pickled DataLoaders from the processed path and stores
        the summary as ``dataset_details.csv`` in the configured files path.

        Raises:
            Exception: If the processed path or the files path does not exist.
        """
        config_files = config()

        if not os.path.exists(config_files["path"]["processed_path"]):
            raise Exception("Unable to create the pickle file".capitalize())

        files_path = config_files["path"]["files_path"]

        if not os.path.exists(files_path):
            raise Exception("Unable to save the dataset details as the path is not exists".capitalize())

        path = config_files["path"]["processed_path"]

        train_dataloader = load(filename=os.path.join(path, "train_dataloader.pkl"))
        test_dataloader = load(filename=os.path.join(path, "test_dataloader.pkl"))
        dataloader = load(filename=os.path.join(path, "dataloader.pkl"))

        details = pd.DataFrame(
            {
                "train_data(total)": str(sum(data.size(0) for data, _ in train_dataloader)),
                "test_data(total)": str(sum(data.size(0) for data, _ in test_dataloader)),
                "data(total)": str(sum(data.size(0) for data, _ in dataloader)),
                "train_data(batch)": str(len(train_dataloader)),
                "test_data(batch)": str(len(test_dataloader)),
                "train_data(shape)": str(train_dataloader.dataset[0][0].shape),
                "test_data(shape)": str(test_dataloader.dataset[0][0].shape),
                "data(shape)": str(dataloader.dataset[0][0].shape),
            },
            index=["Details dataset".capitalize()],
        )

        details.T.to_csv(os.path.join(files_path, "dataset_details.csv"))


if __name__ == "__main__":
    # NOTE: the archive location is machine specific; point it at the local
    # copy of images.zip before running.
    loader = Loader(
        image_path="/Users/shahmuhammadraditrahman/Desktop/images.zip", paired_images=True)
    loader.unzip_folder()
    # create_dataloader() already calls extract_features(), so a separate call
    # here would only load every image a second time.
    loader.create_dataloader()
    loader.dataset_details()
    loader.plot_images()

#### Create the Generator

In [None]:
class InputBlock(nn.Module):
    """First stage of the generator: 7x7 convolution, instance norm, ReLU.

    The convolution uses stride 1 and reflect padding of 3, so the spatial
    size of the input is preserved.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of feature maps produced.
    """

    def __init__(self, in_channels = 3, out_channels = 64):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel, self.stride, self.padding = 7, 1, 3

        self.input_block = self.block()

    def block(self):
        """Assemble the named conv / instance_norm / ReLU sequence."""
        convolution = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            padding_mode="reflect",
            bias=False,
        )
        normalisation = nn.InstanceNorm2d(num_features=self.out_channels)

        return nn.Sequential(
            OrderedDict(
                [
                    ("conv", convolution),
                    ("instance_norm", normalisation),
                    ("ReLU", nn.ReLU(inplace=True)),
                ]
            )
        )

    def forward(self, x):
        """Apply the block to ``x``; raises ``Exception`` for non-tensors."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Unable to process the input".capitalize())

        return self.input_block(x)
        
if __name__ == "__main__":
    # Quick check of InputBlock: print the output size and the layer summary,
    # then render the computation graph into the configured files path.
    config_files = config()
    in_channels, out_channels = 3, 64

    input_block = InputBlock(in_channels=in_channels, out_channels=out_channels)

    print(input_block(torch.randn(1, 3, 256, 256)).size())
    print(summary(model=input_block, input_size=(3, 256, 256)))

    graph = draw_graph(model=input_block, input_data=torch.randn(1, 3, 256, 256))
    graph.visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netG_input_block"),
        format="jpeg",
    )

In [None]:
class EncoderBlock(nn.Module):
    """Downsampling stage: 3x3 stride-2 convolution, instance norm, ReLU.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of feature maps produced.
    """

    def __init__(self, in_channels = 64, out_channels = 128):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel, self.stride, self.padding = 3, 2, 1

        self.encoder_block = self.block()

    def block(self):
        """Assemble the named conv / instance_norm / ReLU sequence."""
        convolution = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            bias=False,
        )
        normalisation = nn.InstanceNorm2d(num_features=self.out_channels)

        return nn.Sequential(
            OrderedDict(
                [
                    ("conv", convolution),
                    ("instance_norm", normalisation),
                    ("ReLU", nn.ReLU(inplace=True)),
                ]
            )
        )

    def forward(self, x):
        """Apply the block to ``x``; raises ``Exception`` for non-tensors."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Unable to process the input".capitalize())

        return self.encoder_block(x)


if __name__ == "__main__":
    # Load the configuration here so this cell can run on its own instead of
    # depending on a `config_files` global left behind by another cell.
    config_files = config()

    in_channels = 64
    out_channels = 128
    num_repetitive = 2

    layers = []

    # Stack the encoder stages, doubling the channel count at every step.
    for _ in tqdm(range(num_repetitive)):
        layers.append(EncoderBlock(in_channels=in_channels, out_channels=out_channels))

        in_channels = out_channels
        out_channels *= 2

    model = nn.Sequential(*layers)

    print(model(torch.randn(1, 64, 256, 256)).size())
    print(summary(model=model, input_size=(64, 256, 256)))
    draw_graph(model=model, input_data=torch.randn(1, 64, 256, 256)).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netG_encoder_block"), format="jpeg"
    )

In [None]:
class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + instance norm stages with a skip add.

    A ReLU follows the first normalisation only; the block output is the
    input plus the result of the two stages.

    Args:
        in_channels: Input channel count of both convolutions.
        out_channels: Output channel count of both convolutions.
    """

    def __init__(self, in_channels = 256, out_channels = 256):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel, self.stride, self.padding = 3, 1, 1

        self.residual_block = self.block()

    def block(self):
        """Assemble conv1 / instance_norm1 / ReLU / conv2 / instance_norm2."""

        def convolution():
            return nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel,
                padding=self.padding,
                bias=False,
            )

        def normalisation():
            return nn.InstanceNorm2d(num_features=self.out_channels)

        return nn.Sequential(
            OrderedDict(
                [
                    ("conv1", convolution()),
                    ("instance_norm1", normalisation()),
                    ("ReLU", nn.ReLU(inplace=True)),
                    ("conv2", convolution()),
                    ("instance_norm2", normalisation()),
                ]
            )
        )

    def forward(self, x):
        """Return ``x`` plus the block output; non-tensors raise ``Exception``."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Unable to process the input".capitalize())

        return x + self.residual_block(x)

if __name__ == "__main__":
    # Load the configuration here so this cell can run on its own instead of
    # depending on a `config_files` global left behind by another cell.
    config_files = config()

    in_channels = 256
    num_repetitive = 9

    layers = []

    # Every residual block keeps the channel count unchanged.
    for idx in tqdm(range(num_repetitive)):
        layers+=[
            ResidualBlock(
            in_channels=in_channels, out_channels=in_channels)
        ]

    model = nn.Sequential(*layers)
    print(model(torch.randn(1, 256, 64, 64)).size())
    print(summary(model=model, input_size=(256, 64, 64)))
    draw_graph(model=model, input_data=torch.randn(1, 256, 64, 64)).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netG_residual_block"), format="jpeg"
    )

In [None]:
class DecoderBlock(nn.Module):
    """Upsampling stage: 3x3 stride-2 transposed conv, instance norm, ReLU.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of feature maps produced.
    """

    def __init__(self, in_channels = 256, out_channels = 128):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel, self.stride, self.padding = 3, 2, 1
        self.output_padding = 1

        self.decoder_block = self.block()

    def block(self):
        """Assemble the named convTranspose / instance_norm / ReLU sequence."""
        upsample = nn.ConvTranspose2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.output_padding,
            bias=False,
        )
        normalisation = nn.InstanceNorm2d(num_features=self.out_channels)

        return nn.Sequential(
            OrderedDict(
                [
                    ("convTranspose", upsample),
                    ("instance_norm", normalisation),
                    ("ReLU", nn.ReLU(inplace=True)),
                ]
            )
        )

    def forward(self, x):
        """Apply the block to ``x``; raises ``Exception`` for non-tensors."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Unable to process the input".capitalize())

        return self.decoder_block(x)


if __name__ == "__main__":
    # Load the configuration here so this cell can run on its own instead of
    # depending on a `config_files` global left behind by another cell.
    config_files = config()

    in_channels = 256
    out_channels = 128
    num_repetitive = 2

    layers = []

    # Stack the decoder stages, halving the channel count at every step.
    for _ in tqdm(range(num_repetitive)):
        layers.append(
            DecoderBlock(in_channels=in_channels, out_channels=out_channels)
        )
        in_channels = out_channels
        out_channels //= 2

    model = nn.Sequential(*layers)

    print(model(torch.randn(1, 256, 64, 64)).size())
    print(summary(model=model, input_size=(256, 64, 64)))
    draw_graph(model=model, input_data=torch.randn(1, 256, 64, 64)).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netG_decoder_block"), format="jpeg"
    )

In [None]:
class Generator(nn.Module):
    """Image-to-image generator built from the blocks defined in this file.

    Layout: one ``InputBlock`` (in_channels -> 64), two ``EncoderBlock``
    stages (64 -> 128 -> 256), ``num_residual_blocks`` ``ResidualBlock``
    stages at 256 channels, two ``DecoderBlock`` stages (256 -> 128 -> 64)
    and a final 7x7 convolution to 3 channels followed by ``Tanh``.

    Note that ``self.in_channels`` and ``self.out_channels`` are used as
    running counters while the layers are built; after construction they hold
    64 and 32 respectively, not the constructor arguments.

    Args:
        in_channels: Number of channels of the input image.
        num_residual_blocks: Number of residual blocks between encoder and
            decoder. Defaults to 9, the previously hard-coded value.
    """

    def __init__(self, in_channels = 3, num_residual_blocks = 9):
        super(Generator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = 64
        self.num_residual_blocks = num_residual_blocks

        # Geometry of the final output convolution.
        self.kernel = 7
        self.stride = 1
        self.padding = 3

        self.layers = []

        self.layers.append(
            InputBlock(in_channels=self.in_channels, out_channels=self.out_channels)
        )
        self.in_channels = self.out_channels
        self.out_channels *= 2

        # Encoder: each stage doubles the channel count.
        for _ in tqdm(range(2)):
            self.layers.append(
                EncoderBlock(in_channels=self.in_channels, out_channels=self.out_channels)
            )

            self.in_channels = self.out_channels
            self.out_channels *= 2

        # The loop above doubled out_channels once more than needed.
        self.out_channels //= 2

        # Residual stages keep the channel count unchanged.
        for _ in tqdm(range(self.num_residual_blocks)):
            self.layers.append(
                ResidualBlock(in_channels=self.in_channels, out_channels=self.in_channels)
            )

        self.out_channels //= 2

        # Decoder: each stage halves the channel count.
        for _ in tqdm(range(2)):
            self.layers.append(
                DecoderBlock(in_channels=self.in_channels, out_channels=self.out_channels)
            )

            self.in_channels = self.out_channels
            self.out_channels //= 2

        self.model = nn.Sequential(*self.layers)

        self.output = nn.Sequential(
            nn.Conv2d(
                # out_channels was halved once after the last decoder stage,
                # so twice its value is the decoder's real output width.
                in_channels=self.out_channels*2,
                out_channels=3,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
                bias=False,
            ),
            nn.Tanh(),
        )

    def forward(self, x):
        """Run the generator on ``x``.

        Args:
            x: Input tensor.

        Returns:
            The output of the final convolution passed through ``Tanh``.

        Raises:
            Exception: If ``x`` is not a ``torch.Tensor``.
        """
        if isinstance(x, torch.Tensor):
            x = self.model(x)
            return self.output(x)

        else:
            raise Exception("Unable to process the input".capitalize())


if __name__ == "__main__":
    # Load the configuration here so this cell can run on its own instead of
    # depending on a `config_files` global left behind by another cell.
    config_files = config()

    in_channels = 3

    netG = Generator(in_channels=in_channels)

    print(netG(torch.randn((1, 3, 256, 256))).size())

    print(summary(model=netG, input_size=(3, 256, 256)))

    draw_graph(model=netG, input_data=torch.randn((1, 3, 256, 256))).visual_graph.render(
        filename=os.path.join(config_files["path"]["files_path"], "netG_model"), format="jpeg"
    )