In [None]:
import os
import cv2
import yaml
import torch
import joblib
import zipfile
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from collections import OrderedDict
from torchvision import transforms
from torchview import draw_graph
from torchsummary import summary
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

In [4]:
def dump(value = None, filename = None):
    """Serialize ``value`` to ``filename`` with joblib.

    Args:
        value: Object to persist; must not be None.
        filename: Destination path; must not be None.

    Raises:
        Exception: If either ``value`` or ``filename`` is None.
    """
    # Guard first so the happy path reads straight through.
    if value is None or filename is None:
        raise Exception("Please provide the value and filename to dump".capitalize())

    joblib.dump(value = value, filename = filename)


def load(filename):
    """Read back an object previously written with joblib.

    Args:
        filename: Path of the joblib file, or None.

    Returns:
        The deserialized object, or None when ``filename`` is None.
    """
    if filename is None:
        return None

    # NOTE: joblib files are pickle-based; only load files from a trusted source.
    return joblib.load(filename = filename)


def device_init(device="mps"):
    """Resolve a device request to a ``torch.device`` that is usable on this machine.

    Args:
        device: Device name (``"mps"``, ``"cuda"``, ``"cpu"``, ...) or an existing
            ``torch.device`` whose ``type`` is used as the request.

    Returns:
        torch.device: The requested accelerator when it is available, otherwise
        the CPU device. Unknown names also resolve to the CPU.
    """
    # Accept an already-built torch.device: comparing it to a string is never
    # equal, so it used to fall through to the unconditional fallback.
    if isinstance(device, torch.device):
        device = device.type

    if device == "mps":
        return torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    if device == "cuda":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # "cpu" and anything unrecognised: the CPU is the only device guaranteed to
    # exist (the previous fallback returned CUDA even on machines without it).
    return torch.device("cpu")

In [None]:
def config():
    """Parse the project's YAML configuration.

    Returns:
        The object produced by ``yaml.safe_load`` for ``../../config.yml``
        (the path is relative to the current working directory).
    """
    with open("../../config.yml", "r") as stream:
        return yaml.safe_load(stream)

In [None]:
class Loader():
    """Build, persist and inspect the train/test dataloaders of the image dataset.

    Images are read from ``<raw_path>/images/X`` and ``<raw_path>/images/y``;
    ``raw_path``, ``processed_path`` and ``files_path`` come from the ``path``
    section of the mapping returned by ``config()``.

    Args:
        image_path: Path of the zip archive consumed by ``unzip_folder``.
        channels: Number of image channels (stored; not read by this class).
        image_size: Side length, in pixels, every image is resized to.
        batch_size: Batch size of the train loader; the test loader uses
            ``batch_size * 8``.
        split_size: Fraction of the samples placed in the test split.
        paired_images: When True, only file names present in both ``X`` and
            ``y`` are loaded.
        unpaired_images: Stored for callers; not read by this class.
    """

    def __init__(
        self,
        image_path = None,
        channels = 3,
        image_size = 256,
        batch_size = 1,
        split_size = 0.20,
        paired_images = False,
        unpaired_images = True
        ):
        self.image_path = image_path
        self.channels = channels
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size
        self.paired_images = paired_images
        self.unpaired_images = unpaired_images

        self.X = []
        self.y = []

        # Parse the YAML once instead of once per key.
        paths = config()["path"]
        self.raw_path = paths["raw_path"]
        self.processed_path = paths["processed_path"]

    def unzip_folder(self):
        """Extract the archive at ``image_path`` into ``raw_path``.

        Raises:
            ValueError: If no ``image_path`` was given.
            Exception: If ``raw_path`` does not exist.
        """
        if self.image_path is None:
            raise ValueError("Please provide the path of the zip file to unzip")

        if not os.path.exists(self.raw_path):
            raise Exception("Cannot unzip the folder for further process".capitalize())

        # NOTE: extractall trusts the member names stored in the archive; only
        # use archives that come from a trusted source.
        with zipfile.ZipFile(self.image_path, "r") as file:
            file.extractall(self.raw_path)

    def transforms(self):
        """Return the preprocessing pipeline: resize, to-tensor, crop, scale to [-1, 1]."""
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.CenterCrop((self.image_size, self.image_size)),
            transforms.Normalize([0.5, 0.5, 0.5],[0.5, 0.5, 0.5])
        ])

    def image_split(self, X, y):
        """Split ``X`` and ``y`` into train/test parts with a fixed seed.

        Returns:
            dict: ``X_train``, ``X_test``, ``y_train`` and ``y_test``.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.split_size, random_state=42
        )

        return {
            "X_train": X_train,
            "X_test": X_test,
            "y_train": y_train,
            "y_test": y_test,
        }

    def feature_extractor(self):
        """Load and preprocess every image of both categories, then split them.

        Returns:
            dict: The mapping produced by ``image_split``.

        Raises:
            Exception: If a listed file cannot be decoded as an image.
        """
        self.directory = os.path.join(self.raw_path, "images")
        self.categories = ["X", "y"]

        # os.listdir order is arbitrary; sorting keeps the X/y ordering
        # deterministic and aligned when both folders share file names.
        listings = {
            category: sorted(os.listdir(os.path.join(self.directory, category)))
            for category in self.categories
        }
        self.paired_check = listings["y"]

        if self.paired_images:
            # Keep only names present on BOTH sides. Previously a missing pair
            # left image_path unbound (or pointing at the previous file), and
            # the y side was never filtered at all.
            common = set(listings["X"]) & set(listings["y"])
            listings = {
                category: [name for name in names if name in common]
                for category, names in listings.items()
            }

        # Start from empty lists so a second call does not duplicate samples.
        self.X, self.y = [], []
        transform = self.transforms()  # build the pipeline once, not per image

        for category in tqdm(self.categories):
            folder_path = os.path.join(self.directory, category)
            target = self.X if category == "X" else self.y

            for name in listings[category]:
                image_path = os.path.join(folder_path, name)

                image = cv2.imread(image_path)
                if image is None:
                    # cv2.imread signals failure by returning None.
                    raise Exception("Cannot read the image file: {}".format(image_path))

                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                target.append(transform(Image.fromarray(image)))

        return self.image_split(X = self.X, y = self.y)

    def create_dataloader(self):
        """Create the train/test dataloaders and pickle them into ``processed_path``.

        Raises:
            Exception: If ``processed_path`` does not exist. Any error raised
                while extracting the images is logged and re-raised.
        """
        try:
            data = self.feature_extractor()

        except Exception as e:
            print("An error occurred while creating the dataloader".capitalize(), e)
            # Re-raise: swallowing the error let later steps run against stale
            # (or missing) pickles without any hint of what went wrong.
            raise

        if not os.path.exists(self.processed_path):
            raise Exception("Cannot create the dataloader for further process".capitalize())

        train_dataloader = DataLoader(
            dataset=list(zip(data["X_train"], data["y_train"])),
            batch_size=self.batch_size,
            shuffle=True
        )
        test_dataloader = DataLoader(
            dataset=list(zip(data["X_test"], data["y_test"])),
            batch_size=self.batch_size*8,
            shuffle=True
        )

        dump(
            value=train_dataloader,
            filename=os.path.join(self.processed_path, "train_dataloader.pkl")
        )
        dump(
            value=test_dataloader,
            filename=os.path.join(self.processed_path, "test_dataloader.pkl")
        )

    @staticmethod
    def plot_images():
        """Plot one batch of the saved test dataloader as side-by-side X/y pairs.

        The figure is saved as ``images.png`` in ``files_path`` when that
        directory exists, and shown in any case.

        Raises:
            Exception: If the pickled test dataloader is missing.
        """
        config_files = config()
        files_path = config_files["path"]["files_path"]
        processed_path = config_files["path"]["processed_path"]

        # Guard on the file that is actually loaded (the old check looked at
        # files_path, which has nothing to do with the pickle).
        dataloader_path = os.path.join(processed_path, "test_dataloader.pkl")
        if not os.path.exists(dataloader_path):
            raise Exception("Cannot load the dataloader for further process".capitalize())

        test_dataloader = load(filename=dataloader_path)
        X, y = next(iter(test_dataloader))

        def to_display(tensor):
            """Convert a CHW tensor to an HWC array rescaled to [0, 1]."""
            array = tensor.squeeze().permute(1, 2, 0).cpu().detach().numpy()
            shifted = array - array.min()
            span = array.max() - array.min()
            # A constant image has zero span; avoid dividing by zero.
            return shifted / span if span > 0 else shifted

        # Four X/y pairs per row; the row count follows the batch so batches
        # larger than the old fixed 4x8 grid no longer overflow it.
        pairs_per_row = 4
        columns = 2 * pairs_per_row
        rows = max(1, (len(X) + pairs_per_row - 1) // pairs_per_row)

        plt.figure(figsize=(20, 10))

        for index, image in enumerate(X):
            plt.subplot(rows, columns, 2 * index + 1)
            plt.imshow(to_display(image))
            plt.axis("off")
            plt.title("X")

            plt.subplot(rows, columns, 2 * index + 2)
            plt.imshow(to_display(y[index]))
            plt.axis("off")
            plt.title("y")

        plt.tight_layout()

        if os.path.exists(files_path):
            plt.savefig(os.path.join(files_path, "images.png"))
        else:
            # The message used to be a discarded expression; actually report it.
            print("Cannot be saved the images".capitalize())

        plt.show()

    @staticmethod
    def dataset_details():
        """Write the sample counts and tensor shapes of both dataloaders to CSV.

        The file ``dataset_details.csv`` is created inside ``files_path``.

        Raises:
            Exception: If ``files_path`` does not exist.
        """
        config_files = config()

        files_path = config_files["path"]["files_path"]

        # The message used to be passed to to_csv as the output path, which
        # created a file named after the error text in the working directory.
        if not os.path.exists(files_path):
            raise Exception("Cannot be saved the dataset into csv format".capitalize())

        train_dataloader = load(os.path.join(
            config_files["path"]["processed_path"], "train_dataloader.pkl"
        ))
        test_dataloader = load(os.path.join(
            config_files["path"]["processed_path"], "test_dataloader.pkl"
        ))

        # Same totals as summing batch sizes (no batch is dropped), without
        # iterating each loader twice.
        train_total = len(train_dataloader.dataset)
        test_total = len(test_dataloader.dataset)

        pd.DataFrame(
            {
                "train_data(total)": str(train_total),
                "test_data(total)": str(test_total),
                "total_data": str(train_total + test_total),
                "train_data_shape": str(train_dataloader.dataset[0][0].shape),
                "test_data_shape": str(test_dataloader.dataset[0][0].shape)
            },
            index=["Quantity"]
        ).T.to_csv(os.path.join(files_path, "dataset_details.csv"))


# Driver cell: build the dataloaders, then render the sample grid and the CSV summary.
if __name__ == "__main__":
    # NOTE(review): absolute, machine-specific path; it is only read by
    # unzip_folder(), which is disabled below. Consider moving it to config.yml.
    loader = Loader(
        image_path="/Users/shahmuhammadraditrahman/Desktop/images.zip",
        unpaired_images=True,
        split_size=0.50
    )
    # Extraction step is disabled; the calls below read <raw_path>/images directly.
    #loader.unzip_folder()
    # Order matters: plot_images/dataset_details load the pickles written here.
    loader.create_dataloader()
    loader.plot_images()
    loader.dataset_details()

#### Create the Generator

In [None]:
from collections import OrderedDict

class Encoder(nn.Module):
    """Down-sampling stage: 4x4 stride-2 convolution, LeakyReLU(0.2), optional BatchNorm.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        use_batch_norm: Append a ``BatchNorm2d`` after the activation when True.
    """

    def __init__(self, in_channels = 3, out_channels = 64, use_batch_norm = True):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.batch_norm = use_batch_norm

        # Kernel 4 / stride 2 / padding 1 halves the spatial size of even inputs.
        self.kernel, self.stride, self.padding = 4, 2, 1

        self.encoder = self.block()

    def block(self):
        """Assemble the named layer stack for this stage."""
        convolution = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding
        )

        stages = [
            ("conv", nn.Sequential(convolution)),
            ("leaky_relu", nn.LeakyReLU(negative_slope=0.2, inplace=True)),
        ]

        if self.batch_norm:
            stages.append(("batch_norm", nn.BatchNorm2d(num_features=self.out_channels)))

        return nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Run the stage on ``x``; raises ``Exception`` for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Please provide the tensor for further process".capitalize())

        return self.encoder(x)

    @staticmethod
    def total_params(model):
        """Count the trainable parameters of an ``Encoder`` instance.

        Raises:
            Exception: If ``model`` is not an ``Encoder``.
        """
        if not isinstance(model, Encoder):
            raise Exception("Please provide the model for further process".capitalize())

        trainable = (weights.numel() for weights in model.parameters() if weights.requires_grad)
        return sum(trainable)

In [None]:
class Decoder(nn.Module):
    """Up-sampling stage: 4x4 stride-2 transposed convolution, ReLU, BatchNorm.

    The stage output is concatenated with a skip tensor along the channel axis.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.
    """

    def __init__(self, in_channels = 512, out_channels = 512):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels

        # Kernel 4 / stride 2 / padding 1 doubles the spatial size.
        self.kernel, self.stride, self.padding = 4, 2, 1

        self.decoder = self.block()

    def block(self):
        """Assemble the named layer stack for this stage."""
        upsample = nn.ConvTranspose2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding
        )

        stages = [
            ("deconv", nn.Sequential(upsample)),
            ("relu", nn.ReLU(inplace=True)),
            ("batch_norm", nn.BatchNorm2d(num_features=self.out_channels)),
        ]

        return nn.Sequential(OrderedDict(stages))

    def forward(self, x, skip_info):
        """Upsample ``x`` and append ``skip_info`` on the channel dimension.

        Raises:
            Exception: If ``x`` is not a tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise Exception("Please provide the tensor for further process".capitalize())

        upsampled = self.decoder(x)
        return torch.cat((upsampled, skip_info), dim=1)

    @staticmethod
    def total_params(model):
        """Count the trainable parameters of a ``Decoder`` instance.

        Raises:
            Exception: If ``model`` is not a ``Decoder``.
        """
        if not isinstance(model, Decoder):
            raise Exception("Please provide the model for further process".capitalize())

        trainable = (weights.numel() for weights in model.parameters() if weights.requires_grad)
        return sum(trainable)

In [None]:
class Generator(nn.Module):
    """U-Net style generator: seven ``Encoder`` stages, six ``Decoder`` stages
    with skip connections, and a final transposed convolution followed by Tanh.

    Every encoder halves the spatial size and every decoder doubles it, so the
    skip tensors only line up when the input survives seven clean halvings
    (e.g. the 256x256 inputs used in this notebook).

    Args:
        in_channels: Channels of the input image. The output always has 3 channels.
    """

    def __init__(self, in_channels = 3):
        super(Generator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = 64  # base feature width; deeper stages use multiples of it
        self.kernel = 4
        self.stride = 2
        self.padding = 1

        self.encoder1 = Encoder(in_channels=self.in_channels, out_channels = self.out_channels)
        self.encoder2 = Encoder(in_channels=self.out_channels, out_channels = self.out_channels*2)
        self.encoder3 = Encoder(in_channels=self.out_channels*2, out_channels=self.out_channels*4)
        self.encoder4 = Encoder(in_channels=self.out_channels*4, out_channels=self.out_channels*8)
        self.encoder5 = Encoder(in_channels=self.out_channels*8, out_channels=self.out_channels*8)
        self.encoder6 = Encoder(in_channels=self.out_channels*8, out_channels=self.out_channels*8)
        self.encoder7 = Encoder(in_channels=self.out_channels*8, out_channels=self.out_channels*8)

        # From decoder2 on, the input is the previous decoder output
        # concatenated with a skip tensor, hence the doubled channel counts.
        self.decoder1 = Decoder(in_channels=self.out_channels*8, out_channels=self.out_channels*8)
        self.decoder2 = Decoder(in_channels=self.out_channels*8*2, out_channels=self.out_channels*8)
        self.decoder3 = Decoder(in_channels=self.out_channels*8*2, out_channels=self.out_channels*8)
        self.decoder4 = Decoder(in_channels=self.out_channels*8*2, out_channels=self.out_channels*4)
        self.decoder5 = Decoder(in_channels=self.out_channels*4*2, out_channels=self.out_channels*2)
        self.decoder6 = Decoder(in_channels=self.out_channels*2, out_channels=self.out_channels) if False else Decoder(in_channels=self.out_channels*2*2, out_channels=self.out_channels)

        self.output = nn.Sequential(
            nn.ConvTranspose2d(
                # decoder6 output (64) concatenated with the encoder1 skip (64);
                # derived from the base width instead of a bare 128.
                in_channels=self.out_channels*2,
                out_channels=3,
                kernel_size=self.kernel,
                stride=self.stride,
                padding=self.padding
            ),
            nn.Tanh()
        )

    def forward(self, x):
        """Translate the image batch ``x``.

        Returns:
            torch.Tensor: 3-channel output squashed to [-1, 1] by Tanh.

        Raises:
            Exception: If ``x`` is not a tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise Exception("Please provide the tensor for further process".capitalize())

        encoder1 = self.encoder1(x)
        encoder2 = self.encoder2(encoder1)
        encoder3 = self.encoder3(encoder2)
        encoder4 = self.encoder4(encoder3)
        encoder5 = self.encoder5(encoder4)
        encoder6 = self.encoder6(encoder5)
        encoder7 = self.encoder7(encoder6)

        # Each decoder receives the mirrored encoder output as its skip tensor.
        decoder1 = self.decoder1(encoder7, encoder6)
        decoder2 = self.decoder2(decoder1, encoder5)
        decoder3 = self.decoder3(decoder2, encoder4)
        decoder4 = self.decoder4(decoder3, encoder3)
        decoder5 = self.decoder5(decoder4, encoder2)
        decoder6 = self.decoder6(decoder5, encoder1)

        return self.output(decoder6)

    @staticmethod
    def total_params(model):
        """Count all parameters of a ``Generator`` instance.

        Raises:
            Exception: If ``model`` is not a ``Generator``.
        """
        if isinstance(model, Generator):
            return sum(p.numel() for p in model.parameters())

        # The argument is a model, not a tensor: match the wording used by
        # Encoder.total_params and Decoder.total_params.
        raise Exception("Please provide the model for further process".capitalize())
        
        
# Smoke test for the generator: print a layer summary and, when the output
# directory exists, render the architecture graph to <files_path>/netG.png.
if __name__ == "__main__":
    in_channels = 3
    netG = Generator(in_channels=in_channels)

    summary(model=netG, input_size=(in_channels, 256, 256))

    config_files = config()
    files_path = config_files["path"]["files_path"]

    if os.path.exists(files_path):
        draw_graph(model=netG, input_data=torch.randn(1, in_channels, 256, 256)).visual_graph.render(
            filename=os.path.join(files_path, "netG"), format="png"
        )
    else:
        # The message used to be passed as the graphviz `format`, which fails
        # with an unrelated "unknown format" error; report it instead.
        print("Cannot saved the netG model architecture".capitalize())

#### Discriminator

In [None]:
class DiscriminatorBlock(nn.Module):
    """Discriminator stage: 4x4 stride-2 convolution, optional BatchNorm, LeakyReLU(0.2).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        use_batch_norm: Insert a ``BatchNorm2d`` before the activation when True.
    """

    def __init__(self, in_channels = 3, out_channels = 64, use_batch_norm = False):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.batch_norm = use_batch_norm

        # Kernel 4 / stride 2 / padding 1 halves the spatial size of even inputs.
        self.kernel, self.stride, self.padding = 4, 2, 1

        self.discriminator = self.block()

    def block(self):
        """Assemble the named layer stack for this stage."""
        convolution = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding
        )

        stages = [("conv", nn.Sequential(convolution))]

        if self.batch_norm:
            stages.append(("batch_norm", nn.BatchNorm2d(num_features=self.out_channels)))

        stages.append(("leaky_ReLU", nn.LeakyReLU(negative_slope=0.2, inplace=True)))

        return nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Run the stage on ``x``; raises ``Exception`` for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise Exception("Please provide the tensor for further process".capitalize())

        return self.discriminator(x)

In [None]:
class Discriminator(nn.Module):
    """Patch discriminator: three ``DiscriminatorBlock`` stages (64, 128, 256
    channels; no BatchNorm on the first) followed by a zero-padded 4x4
    convolution that produces a single-channel score map.

    Args:
        in_channels: Channels of the input image.
    """

    def __init__(self, in_channels = 3):
        super(Discriminator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = 64  # width of the first stage; doubled at every stage

        self.layers = []

        # Track the running widths in locals so self.in_channels and
        # self.out_channels keep describing the configuration instead of
        # ending up as 256 / 512 after construction.
        block_in, block_out = self.in_channels, self.out_channels

        for idx in range(3):
            self.layers.append(
                DiscriminatorBlock(in_channels=block_in, out_channels=block_out, use_batch_norm=idx != 0)
            )
            block_in, block_out = block_out, block_out * 2

        self.block = nn.Sequential(*self.layers)

        self.output = nn.Sequential(
            # Pad one pixel on the left and top before the unpadded 4x4 conv.
            nn.ZeroPad2d((1, 0, 1, 0)),
            nn.Conv2d(
                in_channels=block_in,
                out_channels=1,
                kernel_size=4,
                stride=1,
                padding=0
            )
        )

    def forward(self, x):
        """Score the image batch ``x``.

        Returns:
            torch.Tensor: Single-channel score map (30x30 for 256x256 inputs).

        Raises:
            Exception: If ``x`` is not a tensor.
        """
        if not isinstance(x, torch.Tensor):
            raise Exception("Please provide the tensor for further process".capitalize())

        return self.output(self.block(x))

    @staticmethod
    def total_params(model = None):
        """Count all parameters of a ``Discriminator`` instance.

        Raises:
            Exception: If ``model`` is not a ``Discriminator``.
        """
        if not isinstance(model, Discriminator):
            raise Exception("Please provide the model for further process".capitalize())

        return sum(p.numel() for p in model.parameters())
        
        
# Smoke test for the discriminator: check the score-map shape, print a layer
# summary and, when the output directory exists, render <files_path>/netD.png.
if __name__ == "__main__":
    in_channels = 3
    config_files = config()
    files_path = config_files["path"]["files_path"]

    netD = Discriminator(in_channels=in_channels)

    assert netD(torch.randn(1, in_channels, 256, 256)).size() == (1, 1, 30, 30)

    summary(model=netD, input_size=(in_channels, 256, 256))

    # Test for the directory itself (a non-empty path string is always truthy)
    # and keep the message out of the graphviz `format` argument.
    if os.path.exists(files_path):
        draw_graph(model=netD, input_data=torch.randn(1, in_channels, 256, 256)).visual_graph.render(
            filename=os.path.join(files_path, "netD"), format="png"
        )
    else:
        print("Cannot be saved netD file".capitalize())

#### Define loss

In [None]:
class CycleLoss(nn.Module):
    """L1 distance between two tensors, wrapped as a module.

    Args:
        reduction: Reduction mode forwarded to ``nn.L1Loss``.
    """

    def __init__(self, reduction="mean"):
        super().__init__()

        self.name = "CycleLoss".title()
        self.reduction = reduction

        self.loss = nn.L1Loss(reduction=self.reduction)

    def forward(self, actual, pred):
        """Return the L1 loss of ``actual`` against ``pred``.

        Raises:
            TypeError: If either argument is not a tensor.
        """
        both_tensors = all(isinstance(item, torch.Tensor) for item in (actual, pred))

        if not both_tensors:
            raise TypeError("Inputs must be torch.Tensor".capitalize())

        return self.loss(actual, pred)


# Demo cell: every element differs by exactly 1, so the default mean-reduced
# L1 loss printed below is 1.0.
if __name__ == "__main__":
    loss = CycleLoss()
    
    actual = torch.tensor([1.0, 0.0, 1.0, 1.0])
    predicted = torch.tensor([0.0, 1.0, 0.0, 0.0])
    
    print("The loss {}".format(loss(actual, predicted)))

In [None]:
class GradientPenalty(nn.Module):
    """Gradient penalty on random interpolations between two batches.

    For each sample a coefficient ``alpha`` is drawn uniformly from [0, 1],
    the critic is evaluated on ``alpha * X + (1 - alpha) * y`` and the squared
    distance of the input-gradient L2 norm from 1 is averaged over the batch.

    Args:
        in_channels: Stored for callers; not used by the computation.
        batch_size: Only sizes the initial ``alpha``; ``forward`` resamples it
            to match the batch it actually receives.
        device: Device name resolved through ``device_init``, or a
            ``torch.device`` that is used as is.
    """

    def __init__(self, in_channels=3, batch_size=1, device="mps"):
        super(GradientPenalty, self).__init__()

        self.in_channels = in_channels
        self.batch_size = batch_size

        # A torch.device never compares equal to the strings device_init
        # checks, so only route names through it.
        self.device = (
            device if isinstance(device, torch.device) else device_init(device=device)
        )

        # Uniform in [0, 1): the interpolation must stay between X and y
        # (randn produced unbounded, possibly negative coefficients).
        self.alpha = torch.rand(self.batch_size, 1, 1, 1)

    def forward(self, netD, X, y):
        """Compute the penalty for critic ``netD`` on batches ``X`` and ``y``.

        Args:
            netD: Critic module; it is moved to ``self.device``.
            X: First batch (batch dimension first).
            y: Second batch, broadcast-compatible with ``X``.

        Returns:
            torch.Tensor: Scalar penalty, differentiable w.r.t. the critic
            parameters (the gradient graph is kept).

        Raises:
            TypeError: If ``netD`` is not a module or ``X``/``y`` are not tensors.
        """
        if not (
            isinstance(netD, nn.Module)
            and isinstance(X, torch.Tensor)
            and isinstance(y, torch.Tensor)
        ):
            raise TypeError("netD must be an nn.Module and X, y must be torch.Tensor")

        X = X.to(self.device)
        y = y.to(self.device)

        # Fresh coefficients per call, one per sample of the actual batch and
        # created on the same device as the data.
        alpha_shape = (X.size(0),) + (1,) * (X.dim() - 1)
        self.alpha = torch.rand(
            alpha_shape,
            device=self.device,
            dtype=X.dtype if X.is_floating_point() else None,
        )

        interpolated = (self.alpha * X) + ((1 - self.alpha) * y)
        interpolated = interpolated.requires_grad_(True)

        netD = netD.to(self.device)

        d_interpolated = netD(interpolated)

        gradients = torch.autograd.grad(
            outputs=d_interpolated,
            inputs=interpolated,
            grad_outputs=torch.ones_like(d_interpolated),
            create_graph=True,  # the penalty itself must be back-propagated
            retain_graph=True,
        )[0]

        # One flat gradient vector per sample; reshape also copes with
        # non-contiguous gradients.
        gradients = gradients.reshape(gradients.size(0), -1)
        gradient_norms = gradients.norm(2, dim=1)

        return ((gradient_norms - 1) ** 2).mean()


# Demo cell: evaluate the gradient penalty of a fresh discriminator on random
# batches.
if __name__ == "__main__":
    # Defined here so the cell does not depend on `in_channels` leaking from
    # an earlier cell.
    in_channels = 3
    batch_size = 1
    requested_device = "mps"
    device = device_init(device=requested_device)

    netD = Discriminator(in_channels=in_channels)

    X = torch.randn(batch_size, in_channels, 256, 256)
    y = torch.randn(batch_size, in_channels, 256, 256)

    # Pass the device *name*: GradientPenalty resolves it itself, and an
    # already-built torch.device does not match the string comparisons used
    # during that resolution.
    grad_penalty = GradientPenalty(
        in_channels=in_channels, batch_size=batch_size, device=requested_device
    )

    print(grad_penalty(netD, X, y))