In [None]:
import os
import yaml
import cv2
from PIL import Image
import zipfile
import joblib
import pandas as pd
import matplotlib.pyplot as plt
from collections import OrderedDict
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchsummary import summary
from torchview import draw_graph

In [None]:
def dump(value, filename):
    """Serialize ``value`` to ``filename`` using joblib."""
    joblib.dump(filename=filename, value=value)


def load(filename):
    """Deserialize and return the object that joblib stored at ``filename``.

    joblib.load unpickles the file, so only call it on trusted paths.
    """
    return joblib.load(filename=filename)

In [None]:
def params(path = "../../config.yml"):
    """Load the project YAML configuration and return it as parsed Python data.

    Args:
        path: Location of the YAML file, relative to the working directory.
            Defaults to the original hard-coded ``../../config.yml``.
    """
    # YAML is UTF-8 by spec; do not depend on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as file:
        # safe_load refuses arbitrary Python object tags.
        return yaml.safe_load(file)


config = params()

In [None]:
class Loader:
    """Prepare paired image/mask DataLoaders from a zipped dataset.

    Locations come from the module-level ``config`` dict
    (``config["path"]["raw_path" | "processed_path" | "files_path"]``).

    Args:
        image_path: Path to the dataset ``.zip`` archive.
        image_size: Side length that images and masks are resized to.
        batch_size: Base batch size. The full loader uses ``16 *`` this value,
            the test loader ``4 *`` it, and the train loader uses it as is.
        split_size: Fraction of pairs assigned to the test split.
    """

    def __init__(self, image_path = None, image_size = 256, batch_size = 1, split_size = 0.25):
        self.image_path = image_path
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size

        # Transformed input tensors and their matching mask tensors.
        self.X = []
        self.y = []

    def image_transforms(self):
        """Return the resize -> tensor -> crop -> normalize pipeline.

        ``ToTensor`` scales pixels to [0, 1]; ``Normalize(0.5, 0.5)`` then maps
        them to [-1, 1].
        """
        return transforms.Compose(
            [
                # InterpolationMode replaces the deprecated PIL resample constant.
                transforms.Resize(
                    (self.image_size, self.image_size),
                    interpolation=transforms.InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                # The crop size equals the resize size, so this never shifts pixels
                # and cannot misalign an image from its mask.
                transforms.RandomCrop((self.image_size, self.image_size)),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    def image_splits(self, **kwargs):
        """Split ``kwargs["X"]`` / ``kwargs["y"]`` into train/test with a fixed seed."""
        return train_test_split(
            kwargs["X"],
            kwargs["y"],
            test_size=self.split_size,
            random_state=42,
        )

    def unzip_folder(self):
        """Extract the archive at ``image_path`` into ``config["path"]["raw_path"]``.

        Raises:
            Exception: If the archive is missing, or the raw directory does not exist.
        """
        # Validate the archive itself, so a bad image_path fails with a clear message
        # instead of an opaque ZipFile error.
        if self.image_path is None or not os.path.isfile(self.image_path):
            raise Exception("Unable to find the zip file".capitalize())
        if not os.path.exists(config["path"]["raw_path"]):
            raise Exception("Unable to find the raw data directory".capitalize())

        with zipfile.ZipFile(self.image_path, "r") as zip_ref:
            zip_ref.extractall(path=config["path"]["raw_path"])

    @staticmethod
    def _read_rgb(path):
        """Read ``path`` with OpenCV and return it as an RGB PIL image.

        Raises:
            Exception: If OpenCV cannot decode the file.
        """
        array = cv2.imread(path)
        if array is None:
            raise Exception("Unable to read the image file: {}".format(path))
        # cv2.imread returns BGR channel order; PIL and torchvision assume RGB.
        return Image.fromarray(cv2.cvtColor(array, cv2.COLOR_BGR2RGB))

    def extract_features(self):
        """Pair each image with every mask that shares its base name, transform both, and split.

        Reads ``<raw_path>/dataset``, which must contain two sub-directories.
        After sorting by name, the first is used for images and the second for masks.

        Returns:
            dict with ``X_train``, ``y_train``, ``X_test``, ``y_test`` (the split)
            and ``X``, ``y`` (all pairs).

        Raises:
            Exception: If ``config["path"]["raw_path"]`` does not exist.
        """
        if not os.path.exists(config["path"]["raw_path"]):
            raise Exception("Unable to find the zip file".capitalize())

        self.directory = os.path.join(config["path"]["raw_path"], "dataset")
        # os.listdir order is filesystem-dependent; sort so the folder choice is
        # reproducible. NOTE(review): assumes the image folder sorts before the
        # mask folder (e.g. "images" < "masks") - confirm against the dataset.
        sub_dirs = sorted(os.listdir(self.directory))
        self.images = os.path.join(self.directory, sub_dirs[0])
        self.masks = os.path.join(self.directory, sub_dirs[1])

        # Start fresh so a repeated call does not duplicate every pair.
        self.X = []
        self.y = []
        transform = self.image_transforms()

        # Index masks by base name once (O(n + m)) instead of rescanning the mask
        # folder for every image (O(n * m)). Sorting keeps the split reproducible.
        masks_by_name = {}
        for mask in sorted(os.listdir(self.masks)):
            masks_by_name.setdefault(mask.split(".")[0], []).append(mask)

        for image in sorted(os.listdir(self.images)):
            for mask in masks_by_name.get(image.split(".")[0], []):
                self.X.append(transform(self._read_rgb(os.path.join(self.images, image))))
                self.y.append(transform(self._read_rgb(os.path.join(self.masks, mask))))

        X_train, X_test, y_train, y_test = self.image_splits(X=self.X, y=self.y)

        return {"X_train": X_train, "y_train": y_train, "X_test": X_test, "y_test": y_test, "X": self.X, "y": self.y}

    def create_dataloader(self):
        """Build the full/train/test DataLoaders and pickle them to ``processed_path``.

        Writes ``dataloader.pkl`` (all pairs, batch ``16 * batch_size``),
        ``train_dataloader.pkl`` (batch ``batch_size``) and
        ``test_dataloader.pkl`` (batch ``4 * batch_size``); all three shuffle.

        Raises:
            Exception: If ``config["path"]["processed_path"]`` does not exist.
        """
        processed_path = config["path"]["processed_path"]
        # Fail before the slow image extraction if there is nowhere to save.
        if not os.path.exists(processed_path):
            raise Exception("Unable to create the dataloader file".capitalize())

        dataset = self.extract_features()

        loaders = {
            "dataloader.pkl": DataLoader(
                dataset=list(zip(dataset["X"], dataset["y"])),
                batch_size=self.batch_size * 16, shuffle=True),
            "train_dataloader.pkl": DataLoader(
                dataset=list(zip(dataset["X_train"], dataset["y_train"])),
                batch_size=self.batch_size, shuffle=True),
            "test_dataloader.pkl": DataLoader(
                dataset=list(zip(dataset["X_test"], dataset["y_test"])),
                batch_size=self.batch_size * 4, shuffle=True),
        }

        for filename, dataloader in loaders.items():
            dump(value=dataloader, filename=os.path.join(processed_path, filename))

    @staticmethod
    def plot_images():
        """Plot the first batch of ``dataloader.pkl`` as X/y pairs and save it.

        The figure is written to ``<files_path>/images.png`` and then shown.

        Raises:
            Exception: If ``config["path"]["processed_path"]`` does not exist.
        """
        processed_path = config["path"]["processed_path"]
        if not os.path.exists(processed_path):
            raise Exception("Unable to open the dataloader file".capitalize())

        dataloader = load(os.path.join(processed_path, "dataloader.pkl"))
        data, label = next(iter(dataloader))

        def to_unit_range(array):
            # Min-max scale for display. A constant array (e.g. an empty mask)
            # would otherwise divide by zero and render as NaN.
            shifted = array - array.min()
            span = shifted.max()
            return shifted / span if span > 0 else shifted

        # Two panels (image, mask) per sample on an 8-column grid. Rows grow with
        # the batch, so batches of more than 32 samples no longer overflow the grid.
        n_cols = 8
        n_rows = max(1, (2 * len(data) + n_cols - 1) // n_cols)

        plt.figure(figsize=(25, 15))

        for index, image in enumerate(data):
            X = to_unit_range(image.permute(1, 2, 0).numpy())
            y = to_unit_range(label[index].permute(1, 2, 0).numpy())

            plt.subplot(n_rows, n_cols, 2 * index + 1)
            plt.imshow(X)
            plt.title("X")
            plt.axis("off")

            plt.subplot(n_rows, n_cols, 2 * index + 2)
            plt.imshow(y, cmap="gray")
            plt.title("y")
            plt.axis("off")

        # Lay out before saving so the written file matches what is displayed.
        plt.tight_layout()
        plt.savefig(os.path.join(config["path"]["files_path"], "images.png"))
        plt.show()

    @staticmethod
    def dataset_details():
        """Write sample counts and first-batch shapes to ``<files_path>/dataset_details.csv``.

        Raises:
            Exception: If ``config["path"]["processed_path"]`` does not exist.
        """
        processed_path = config["path"]["processed_path"]
        if not os.path.exists(processed_path):
            raise Exception("Unable to find the dataloader file".capitalize())

        dataloader = load(filename=os.path.join(processed_path, "dataloader.pkl"))
        train_dataloader = load(filename=os.path.join(processed_path, "train_dataloader.pkl"))
        test_dataloader = load(filename=os.path.join(processed_path, "test_dataloader.pkl"))

        train_image, _ = next(iter(train_dataloader))
        test_image, _ = next(iter(test_dataloader))

        pd.DataFrame(
            {
                # The loaders wrap plain lists, so len(dataset) is the sample count
                # without iterating (and collating) every batch.
                "total_images": str(len(dataloader.dataset)),
                "train_data": str(len(train_dataloader.dataset)),
                "test_data": str(len(test_dataloader.dataset)),
                "train_data_shape": str(train_image.size()),
                "test_data_shape": str(test_image.size()),
            },
            index=["Quantity"],
        ).T.to_csv(os.path.join(config["path"]["files_path"], "dataset_details.csv"))

In [None]:
# Unpack the raw archive into raw_path, then build and pickle the full,
# train and test dataloaders into processed_path (paths come from config.yml).
loader = Loader("../../data/raw/dataset.zip")
loader.unzip_folder()
loader.create_dataloader()

In [None]:
# Summarize sample counts and batch shapes into files_path/dataset_details.csv.
Loader.dataset_details()

In [None]:
# Preview one batch of image/mask pairs (also saved to files_path/images.png).
Loader.plot_images()

#### Generator

In [None]:
class InputBlock(nn.Module):
    """Generator stem: 7x7 reflect-padded conv -> InstanceNorm -> ReLU.

    Kernel 7, stride 1 and padding 3 keep H and W unchanged.

    Args:
        in_channels: Channels of the incoming tensor.
        out_channels: Channels produced by the convolution.
    """

    def __init__(self, in_channels = 3, out_channels = 64):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel, self.stride, self.padding = 7, 1, 3

        self.layers = OrderedDict()
        # The attribute name sets the state_dict prefix ("input.*"); keep it stable.
        self.input = self.block()

    def block(self):
        """Register conv / norm / activation in ``self.layers`` and sequence them."""
        stem_conv = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            padding_mode="reflect",
        )
        self.layers.update(
            [
                ("conv", stem_conv),
                ("instance_norm", nn.InstanceNorm2d(num_features=self.out_channels)),
                ("ReLU", nn.ReLU(inplace=True)),
            ]
        )
        return nn.Sequential(self.layers)

    def forward(self, x):
        """Apply the stem; raises ``Exception`` when ``x`` is ``None``."""
        if x is None:
            raise Exception("Input to the model cannot be empty".capitalize())
        return self.input(x)

In [None]:
class DownBlock(nn.Module):
    """Encoder stage: stride-2 3x3 conv -> InstanceNorm -> ReLU.

    Doubles the channel count and halves H and W (rounding up for odd sizes).

    Args:
        in_channels: Channels of the incoming tensor; output has twice as many.
    """

    def __init__(self, in_channels = 64):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = 2 * in_channels

        self.kernel = 3
        self.stride = 2
        self.padding = 1

        self.layers = OrderedDict()
        # The attribute name fixes the state_dict prefix ("decoder.*"), so it is
        # kept even though this block downsamples.
        self.decoder = self.down_block()

    def down_block(self):
        """Register conv / norm / activation in ``self.layers`` and sequence them."""
        named_layers = (
            ("conv", nn.Conv2d(self.in_channels, self.out_channels, self.kernel, self.stride, self.padding)),
            ("instance_norm", nn.InstanceNorm2d(num_features=self.out_channels)),
            ("ReLU", nn.ReLU(inplace=True)),
        )
        for layer_name, layer in named_layers:
            self.layers[layer_name] = layer
        return nn.Sequential(self.layers)

    def forward(self, x):
        """Downsample ``x``; raises ``Exception`` when ``x`` is ``None``."""
        if x is None:
            raise Exception("Input to the model cannot be empty".capitalize())
        return self.decoder(x)

if __name__ == "__main__":
    # Smoke test: two DownBlocks take 64 -> 256 channels and shrink H, W by 4x.
    in_channels = 64
    layers = []
    for _ in range(2):
        layers.append(DownBlock(in_channels=in_channels))
        in_channels *= 2

    model = nn.Sequential(*layers)
    assert model(torch.randn(1, 64, 64, 64)).size() == (1, 256, 16, 16)

In [None]:
class ResidualBlock(nn.Module):
    """Shape-preserving residual unit: ``x + (conv-norm-ReLU-conv-norm)(x)``.

    Both convs are 3x3, stride 1, padding 1, and keep the channel count.

    Args:
        in_channels: Channels of the input (and output).
    """

    def __init__(self, in_channels = 256):
        super().__init__()

        self.in_channels = self.out_channels = in_channels
        self.kernel, self.stride, self.padding = 3, 1, 1

        self.layers = OrderedDict()
        self.residual = self.residual_block()

    def residual_block(self):
        """Populate ``self.layers`` with the two conv stages; ReLU only after the first."""

        def make_conv():
            return nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding,
            )

        def make_norm():
            return nn.InstanceNorm2d(num_features=self.out_channels)

        self.layers["conv_0"] = make_conv()
        self.layers["instance_norm_0"] = make_norm()
        self.layers["ReLU_0"] = nn.ReLU(inplace=True)
        # The second stage has no activation; the skip connection is added after it.
        self.layers["conv_1"] = make_conv()
        self.layers["instance_norm_1"] = make_norm()

        return nn.Sequential(self.layers)

    def forward(self, x):
        """Return ``x`` plus the residual branch; raises ``Exception`` on ``None``."""
        if x is None:
            raise Exception("Input to the model cannot be empty".capitalize())
        return x + self.residual(x)
        
        
if __name__ == "__main__":
    # Smoke test: nine residual blocks must preserve the (N, 256, H, W) shape.
    in_channels = 256
    model = nn.Sequential(*[ResidualBlock(in_channels=in_channels) for _ in range(9)])
    assert model(torch.randn(1, in_channels, 16, 16)).size() == (1, in_channels, 16, 16)

In [None]:
class UpsampleBlock(nn.Module):
    """Decoder stage: stride-2 transposed conv -> InstanceNorm -> ReLU.

    With kernel 3, stride 2, padding 1 and output_padding 1, H and W double
    and the channel count halves.

    Args:
        in_channels: Channels of the incoming tensor; output has half as many.
    """

    def __init__(self, in_channels = 256):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = in_channels // 2

        self.kernel, self.stride, self.padding = 3, 2, 1

        self.layers = OrderedDict()
        # The attribute name fixes the state_dict prefix ("encoder.*"), so it is
        # kept even though this block upsamples.
        self.encoder = self.up_block()

    def up_block(self):
        """Register transposed conv / norm / activation and sequence them."""
        transposed = nn.ConvTranspose2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            output_padding=self.padding,
        )
        norm = nn.InstanceNorm2d(num_features=self.out_channels)
        activation = nn.ReLU(inplace=True)

        for layer_name, layer in zip(("conv", "instance_norm", "ReLU"), (transposed, norm, activation)):
            self.layers[layer_name] = layer

        return nn.Sequential(self.layers)

    def forward(self, x):
        """Upsample ``x``; raises ``Exception`` when ``x`` is ``None``."""
        if x is None:
            raise Exception("Input to the model cannot be empty".capitalize())
        return self.encoder(x)


if __name__ == "__main__":
    # Smoke test: two UpsampleBlocks take 256 -> 64 channels and grow H, W by 4x.
    in_channels = 256
    layers = []

    for _ in range(2):
        layers.append(UpsampleBlock(in_channels=in_channels))
        in_channels //= 2

    model = nn.Sequential(*layers)
    assert model(torch.randn(1, 256, 16, 16)).size() == (1, 64, 64, 64)

In [None]:
class Generator(nn.Module):
    """Encoder -> residual bottleneck -> decoder image generator.

    Layout: InputBlock (64 ch) -> 2 DownBlocks (64 -> 128 -> 256) -> 9
    ResidualBlocks at 256 ch -> 2 UpsampleBlocks (256 -> 128 -> 64) -> 7x7
    reflect-padded conv to 3 channels -> Tanh.

    Args:
        in_channels: Channels of the input image.
    """

    def __init__(self, in_channels = 3):
        super().__init__()

        self.in_channels = in_channels
        self.kernel, self.stride, self.padding = 7, 1, 3
        self.num_repetitive = 9

        width = 64
        stages = [InputBlock(in_channels=self.in_channels, out_channels=width)]

        # Encoder: each DownBlock doubles the channel count.
        for _ in range(2):
            stages.append(DownBlock(in_channels=width))
            width *= 2

        # Bottleneck: shape-preserving residual blocks.
        stages.extend(ResidualBlock(in_channels=width) for _ in range(self.num_repetitive))

        # Decoder: each UpsampleBlock halves the channel count.
        for _ in range(2):
            stages.append(UpsampleBlock(in_channels=width))
            width //= 2

        # Output head: project back to 3 channels, squashed by Tanh.
        stages.append(
            nn.Sequential(
                nn.Conv2d(
                    in_channels=width,
                    out_channels=3,
                    kernel_size=self.kernel,
                    stride=self.stride,
                    padding=self.padding,
                    padding_mode="reflect",
                ),
                nn.Tanh(),
            )
        )

        self.out_channels = width
        self.layers = stages
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        """Run the full generator; raises ``Exception`` when ``x`` is ``None``."""
        if x is None:
            raise Exception("Input to the model cannot be empty".capitalize())
        return self.model(x)

In [None]:
if __name__ == "__main__":
    netG = Generator(in_channels=3)
    # torchsummary defaults to a CUDA device, but netG stays on the CPU; pin the
    # device so the summary does not build CUDA inputs when a GPU is present.
    summary(model=netG, input_size=(3, 256, 256), batch_size=1, device="cpu")
    # The last statement of this cell is the `if`, so a bare `.visual_graph`
    # expression is never rendered; display it explicitly (`display` is
    # provided by the IPython kernel).
    display(draw_graph(model=netG, input_data=torch.randn(1, 3, 256, 256)).visual_graph)

#### Discriminator

In [None]:
class DiscriminatorBlock(nn.Module):
    """Conv -> optional InstanceNorm -> optional LeakyReLU(0.2) building block.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        kernel: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        is_instance_norm: Append ``InstanceNorm2d`` after the conv.
        is_lr: Append an in-place ``LeakyReLU(0.2)`` at the end.
    """

    def __init__(self, in_channels = 3, out_channels = 64, kernel = 4, stride = 2, padding = 1, is_instance_norm = False, is_lr = True):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel, self.stride, self.padding = kernel, stride, padding
        self.use_norm = is_instance_norm
        self.use_lr = is_lr

        self.model = self.block()

    def block(self):
        """Assemble the enabled layers; their names fix the state_dict keys."""
        stack = [
            (
                "conv",
                nn.Conv2d(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel,
                    stride=self.stride,
                    padding=self.padding,
                ),
            )
        ]
        if self.use_norm:
            stack.append(("instance_norm", nn.InstanceNorm2d(num_features=self.out_channels)))
        if self.use_lr:
            stack.append(("LeakyReLU", nn.LeakyReLU(negative_slope=0.2, inplace=True)))
        return nn.Sequential(OrderedDict(stack))

    def forward(self, x):
        """Apply the block; raises ``Exception`` when ``x`` is ``None``."""
        if x is None:
            raise Exception("Input to the model cannot be empty".capitalize())
        return self.model(x)

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional (PatchGAN-style) discriminator producing a grid of raw scores.

    Three stride-2 blocks (64 -> 128 -> 256 channels, instance norm on all but
    the first) are followed by a stride-1 512-channel block with norm and
    LeakyReLU. A final stride-1 conv then projects to 1 channel with no norm or
    activation. A 3x256x256 input yields a 1x30x30 output.

    Args:
        in_channels: Channels of the input image.
    """

    def __init__(self, in_channels = 3):
        super().__init__()

        self.layers = []
        self.in_channels = in_channels
        self.out_channels = 64
        self.kernel, self.stride, self.padding = 4, 2, 1

        # Downsampling stages: channels double while H and W halve.
        for idx in range(3):
            self.layers.append(
                DiscriminatorBlock(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel=self.kernel,
                    stride=self.stride,
                    padding=self.padding,
                    is_instance_norm=idx > 0,
                )
            )
            self.in_channels, self.out_channels = self.out_channels, self.out_channels * 2

        # Stride-1 head: one normalized/activated hidden block, then a bare conv.
        for is_hidden in (True, False):
            self.layers.append(
                DiscriminatorBlock(
                    in_channels=self.in_channels,
                    out_channels=self.out_channels,
                    kernel=self.kernel,
                    stride=self.stride // 2,
                    padding=self.padding,
                    is_instance_norm=is_hidden,
                    is_lr=is_hidden,
                )
            )
            # Every head block after the hidden one projects to a single channel.
            self.in_channels, self.out_channels = self.out_channels, 1

        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        """Score ``x`` patch-wise; raises ``Exception`` when ``x`` is ``None``."""
        if x is None:
            raise Exception("Input to the model cannot be empty".capitalize())
        return self.model(x)


if __name__ == "__main__":
    netD = Discriminator(in_channels=3)
    # A 256x256 image maps to a 30x30 grid of patch scores.
    assert netD(torch.randn(1, 3, 256, 256)).size() == (1, 1, 30, 30)

    # summary() prints the table itself; wrapping it in print() only echoed its
    # return value. Pin the device because netD stays on the CPU.
    summary(model=netD, input_size=(3, 256, 256), batch_size=1, device="cpu")
    # A bare `.visual_graph` inside this `if` is never rendered; display it
    # explicitly (`display` is provided by the IPython kernel).
    display(draw_graph(model=netD, input_data=torch.randn(1, 3, 256, 256)).visual_graph)