In [None]:
import os
import sys
import zipfile
import argparse
import traceback

import cv2
import yaml
import joblib
import pandas as pd
from tqdm import tqdm
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

In [None]:
class CustomException(Exception):
    """Project-specific error that carries a human-readable ``message``.

    Args:
        message: Description of what went wrong. It is stored on
            ``self.message`` and is also what ``str(exc)`` returns.
    """

    def __init__(self, message: str):
        # Initialise the base Exception so ``args``, ``str()`` and pickling
        # behave like any other exception; keep ``message`` for callers that
        # read the attribute directly.
        super().__init__(message)
        self.message = message


def dump(value=None, filename=None):
    """Serialise ``value`` to ``filename`` with ``joblib.dump``.

    Args:
        value: Object to persist. Must not be ``None``.
        filename: Destination path. Must not be ``None``.

    Raises:
        CustomException: If ``value`` or ``filename`` is ``None``.
    """
    if value is None or filename is None:
        # Raise (not merely construct) the error so an invalid call cannot
        # silently skip writing the file.
        raise CustomException("Cannot be possble to dump the value".capitalize())

    joblib.dump(value=value, filename=filename)


def load(filename: str):
    """Load and return a joblib-serialised object from ``filename``.

    ``joblib.load`` is pickle-based, so only load files from a trusted
    source (here: the files written by ``dump``).

    Args:
        filename: Path of the file to read.

    Returns:
        The deserialised object.

    Raises:
        CustomException: If ``filename`` is not a ``str``.
    """
    if not isinstance(filename, str):
        # Raise (not merely construct) the error; otherwise the caller gets
        # ``None`` back and fails later with a confusing message.
        raise CustomException("Cannot be possble to load the value".capitalize())

    return joblib.load(filename=filename)


def config(config_path: str = "../../config.yml"):
    """Read and return the project's YAML configuration.

    Args:
        config_path: Location of the YAML file. The default
            ``"../../config.yml"`` is resolved against the current working
            directory, so pass an explicit path when running from elsewhere.

    Returns:
        The parsed YAML document (via ``yaml.safe_load``).
    """
    with open(config_path, "r", encoding="utf-8") as file:
        return yaml.safe_load(file)


def validate_path(path: str):
    """Return whether ``path`` exists on the filesystem.

    Args:
        path: File or directory path to check.

    Returns:
        ``True`` if ``path`` exists, otherwise ``False``.

    Raises:
        CustomException: If ``path`` is not a ``str``.
    """
    if not isinstance(path, str):
        # Raise (not merely construct) the error; returning None here would
        # be treated as "path missing" by every ``if validate_path(...)``.
        raise CustomException("Cannot be possble to validate the path".capitalize())

    return os.path.exists(path)

In [None]:
class Loader:
    """Build paired full-size / low-resolution image dataloaders for the CCGAN.

    Images are read from ``<RAW_IMAGE_DATA_PATH>/dataset/X`` and
    ``<RAW_IMAGE_DATA_PATH>/dataset/y``; files with the same name in both
    folders form one pair. Every image is converted into a full-size tensor
    (``image_size``) and a low-resolution tensor (``image_size // 4``), both
    normalised to [-1, 1]. The resulting train/valid ``DataLoader`` objects
    are pickled into ``PROCESSED_IMAGE_DATA_PATH``.

    Args:
        image_path: Path of the zip archive extracted by ``unzip_folder``.
        image_size: Side length of the full-size images.
        channels: Number of image channels. Stored only; the transforms below
            assume 3 (RGB) channels.
        batch_size: Batch size of the created dataloaders.
        split_size: Fraction of the pairs used for the validation split.
        seed_value: ``random_state`` passed to ``train_test_split``.
    """

    def __init__(
        self,
        image_path=None,
        image_size: int = 128,
        channels: int = 3,
        batch_size: int = 8,
        split_size: float = 0.20,
        seed_value: int = 0,
    ):
        self.image_path = image_path
        self.image_size = image_size
        self.channels = channels
        self.batch_size = batch_size
        self.split_size = split_size
        self.seed_value = seed_value

        self.CONFIG = config()

        # Filled by feature_extractor(): full-size and low-resolution tensors
        # of the X (independent) and y (dependent) images, in the same order.
        self.independent: list = []
        self.dependent: list = []
        self.lr_independent: list = []
        self.lr_dependent: list = []

    def unzip_folder(self):
        """Extract the ``image_path`` zip archive into RAW_IMAGE_DATA_PATH.

        Raises:
            CustomException: If RAW_IMAGE_DATA_PATH does not exist.
        """
        raw_path = self.CONFIG["path"]["RAW_IMAGE_DATA_PATH"]
        if not validate_path(path=raw_path):
            raise CustomException("Raw data path cannot be found".capitalize())

        with zipfile.ZipFile(file=self.image_path, mode="r") as zip_file:
            zip_file.extractall(path=raw_path)

    def transforms(self, type="lr"):
        """Return the torchvision preprocessing pipeline for one resolution.

        Args:
            type: ``"hr"`` for ``image_size`` x ``image_size`` output or
                ``"lr"`` for ``image_size // 4`` x ``image_size // 4``.

        Returns:
            A ``transforms.Compose`` that resizes, center-crops, converts to a
            tensor and normalises each of the 3 channels to [-1, 1].

        Raises:
            CustomException: If ``type`` is neither ``"hr"`` nor ``"lr"``.
        """
        # Inside this body ``transforms`` is the torchvision module: a method's
        # own name is a class attribute, not a name in its local/global scope.
        if type == "hr":
            size = self.image_size
        elif type == "lr":
            size = self.image_size // 4
        else:
            raise CustomException("type should be either 'hr' or 'lr'".capitalize())

        return transforms.Compose(
            [
                transforms.Resize((size, size)),
                transforms.CenterCrop((size, size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )

    def split_dataset(self, X: list, y: list):
        """Split ``X``/``y`` into train and test parts with ``train_test_split``.

        Uses ``test_size=self.split_size`` and ``random_state=self.seed_value``,
        so two calls on equally long lists select the same indices.

        Returns:
            Dict with keys ``X_train``, ``X_test``, ``y_train``, ``y_test``.

        Raises:
            CustomException: If ``X`` or ``y`` is not a list.
        """
        if not (isinstance(X, list) and isinstance(y, list)):
            raise CustomException("X and y should be list".capitalize())

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=self.split_size, random_state=self.seed_value
        )

        return {
            "X_train": X_train,
            "X_test": X_test,
            "y_train": y_train,
            "y_test": y_test,
        }

    def feature_extractor(self):
        """Load every X/y image pair and split it into train and test sets.

        A pair is two files with the same name in ``dataset/X`` and
        ``dataset/y``. Files that OpenCV cannot decode are skipped.

        Returns:
            ``(dataset, lr_dataset)``: two ``split_dataset`` dicts holding the
            full-size and the low-resolution tensors. Both splits use the same
            seed on equally long, identically ordered lists, so index ``i``
            refers to the same image pair in both.

        Raises:
            CustomException: If no image pair could be loaded.
        """
        self.directory = os.path.join(
            self.CONFIG["path"]["RAW_IMAGE_DATA_PATH"], "dataset"
        )
        # os.path.basename instead of split("/") so the check also holds on
        # Windows, where os.path.join uses a backslash separator.
        assert (
            os.path.basename(self.directory) == "dataset"
        ), "Directory name should be dataset"

        self.X = os.path.join(self.directory, "X")
        self.y = os.path.join(self.directory, "y")

        print(self.X, self.y)

        assert (
            os.path.basename(self.X) == "X" and os.path.basename(self.y) == "y"
        ), "Directory name should be X and y"

        # Build the pipelines and the y-folder index once, not per image; the
        # set makes each "has a matching y file" check O(1).
        hr_transform = self.transforms(type="hr")
        lr_transform = self.transforms(type="lr")
        y_filenames = set(os.listdir(self.y))

        # Sorted because os.listdir order is arbitrary; a stable order keeps
        # the seeded train/test split reproducible across machines.
        for image in tqdm(sorted(os.listdir(self.X))):
            if image not in y_filenames:
                continue

            imageX = cv2.imread(os.path.join(self.X, image))
            imageY = cv2.imread(os.path.join(self.y, image))

            # cv2.imread returns None (instead of raising) for files it
            # cannot decode; cvtColor would then fail with a cryptic error.
            if imageX is None or imageY is None:
                print("Skipping unreadable image pair: {}".format(image))
                continue

            # OpenCV loads BGR; PIL/torchvision expect RGB.
            imageX = Image.fromarray(cv2.cvtColor(imageX, cv2.COLOR_BGR2RGB))
            imageY = Image.fromarray(cv2.cvtColor(imageY, cv2.COLOR_BGR2RGB))

            self.independent.append(hr_transform(imageX))
            self.dependent.append(hr_transform(imageY))

            self.lr_independent.append(lr_transform(imageX))
            self.lr_dependent.append(lr_transform(imageY))

        assert (
            len(self.independent)
            == len(self.dependent)
            == len(self.lr_independent)
            == len(self.lr_dependent)
        ), "Length of independent and dependent should be equal"

        if not self.independent:
            raise CustomException(
                "No image pairs were found in the dataset folder".capitalize()
            )

        # Errors propagate to the caller instead of being printed and turned
        # into an implicit ``None`` that breaks tuple unpacking downstream.
        dataset = self.split_dataset(X=self.independent, y=self.dependent)
        lr_dataset = self.split_dataset(X=self.lr_independent, y=self.lr_dependent)

        return dataset, lr_dataset

    def create_dataloader(self):
        """Build, pickle and return the train and validation dataloaders.

        Each sample is ``(X, y, lr_y)``: the full-size X and y tensors plus
        the low-resolution version of y. Both loaders are written to
        ``PROCESSED_IMAGE_DATA_PATH`` as ``train_dataloader.pkl`` and
        ``valid_dataloader.pkl``.

        Returns:
            ``(train_dataloader, valid_dataloader)``.

        Raises:
            CustomException: If PROCESSED_IMAGE_DATA_PATH does not exist.
        """
        processed_path = self.CONFIG["path"]["PROCESSED_IMAGE_DATA_PATH"]

        # Check the destination before the (slow) image extraction.
        if not validate_path(path=processed_path):
            raise CustomException(
                "Cannot be created the dataloader and processed path is not found".capitalize()
            )

        dataset, lr_dataset = self.feature_extractor()

        train_dataloader = DataLoader(
            dataset=list(
                zip(dataset["X_train"], dataset["y_train"], lr_dataset["y_train"])
            ),
            batch_size=self.batch_size,
            shuffle=True,
        )
        valid_dataloader = DataLoader(
            dataset=list(
                zip(dataset["X_test"], dataset["y_test"], lr_dataset["y_test"])
            ),
            batch_size=self.batch_size,
            shuffle=True,
        )

        for value, filename in [
            (train_dataloader, "train_dataloader.pkl"),
            (valid_dataloader, "valid_dataloader.pkl"),
        ]:
            dump(value=value, filename=os.path.join(processed_path, filename))

        print(
            "Train and valid dataloader is saved in the folder {}".format(
                processed_path
            )
        )

        return train_dataloader, valid_dataloader

    @staticmethod
    def _to_displayable(tensor):
        """Turn a ``(C, H, W)`` tensor into an ``(H, W, C)`` array scaled to [0, 1].

        A constant image (zero value range) is returned as all zeros instead
        of the NaNs that a plain min-max division would produce.
        """
        image = tensor.permute(1, 2, 0).detach().cpu().numpy()
        shifted = image - image.min()
        peak = shifted.max()
        return shifted / peak if peak > 0 else shifted

    @staticmethod
    def plot_images():
        """Plot one training batch as rows of (X, Y, low-resolution Y) and save it.

        The figure is written to ``FILES_PATH/image.jpeg`` and then shown.

        Raises:
            CustomException: If PROCESSED_IMAGE_DATA_PATH does not exist.
        """
        CONFIG = config()
        processed_data_path = CONFIG["path"]["PROCESSED_IMAGE_DATA_PATH"]

        if not validate_path(path=processed_data_path):
            raise CustomException(
                "Cannot be imported processed path as it is not found".capitalize()
            )

        files_path = CONFIG["path"]["FILES_PATH"]

        train_dataloader = load(
            filename=os.path.join(processed_data_path, "train_dataloader.pkl")
        )

        X, y, lr = next(iter(train_dataloader))

        # One row per sample, three panels per row. Works for any batch size,
        # including 1 (a "batch // 2 rows" layout divides by zero there).
        num_of_rows = X.size(0)

        plt.figure(figsize=(20, 30))

        for index in range(num_of_rows):
            panels = (("X", X[index]), ("Y", y[index]), ("lowerY", lr[index]))
            for column, (title, tensor) in enumerate(panels, start=1):
                plt.subplot(num_of_rows, 3, 3 * index + column)
                plt.imshow(Loader._to_displayable(tensor))
                plt.title(title)
                plt.axis("off")

        plt.tight_layout()
        plt.savefig(os.path.join(files_path, "image.jpeg"))
        plt.show()

        print("Image is saved in the folder {}".format(files_path))

    @staticmethod
    def dataset_details():
        """Write the sample count and batch tensor shapes to a CSV file.

        Output: ``FILES_PATH/dataset_details.csv``.

        Raises:
            CustomException: If PROCESSED_IMAGE_DATA_PATH does not exist
                (consistent with ``plot_images`` instead of silently
                doing nothing).
        """
        CONFIG = config()
        processed_data_path = CONFIG["path"]["PROCESSED_IMAGE_DATA_PATH"]

        if not validate_path(path=processed_data_path):
            raise CustomException(
                "Cannot be imported processed path as it is not found".capitalize()
            )

        files_path = CONFIG["path"]["FILES_PATH"]

        train_dataloader = load(
            filename=os.path.join(processed_data_path, "train_dataloader.pkl")
        )
        valid_dataloader = load(
            filename=os.path.join(processed_data_path, "valid_dataloader.pkl")
        )

        trainX, trainY, train_lr_Y = next(iter(train_dataloader))
        validX, validY, valid_lr_Y = next(iter(valid_dataloader))

        # create_dataloader wraps plain lists, so len(dataset) is the sample
        # count without iterating over every batch.
        total_dataset = len(train_dataloader.dataset) + len(valid_dataloader.dataset)

        pd.DataFrame(
            {
                "total_dataset": str(total_dataset),
                "trainX(shape)": str(trainX.size()),
                "trainY(shape)": str(trainY.size()),
                "validX(shape)": str(validX.size()),
                "validY(shape)": str(validY.size()),
                "train_lr_Y(shape)": str(train_lr_Y.size()),
                "valid_lr_Y(shape)": str(valid_lr_Y.size()),
            },
            index=["Deatils".title()],
        ).T.to_csv(os.path.join(files_path, "dataset_details.csv"))

        print("Dataset details are saved in the folder {}".format(files_path))

if __name__ == "__main__":
    # Read the YAML config once and reuse it for every default.
    dataloader_config = config()["dataloader"]

    parser = argparse.ArgumentParser(description="Dataloader for the CCGAN".title())
    parser.add_argument(
        "--image_path",
        type=str,
        default=dataloader_config["image_path"],
        help="Path of the zipped image dataset".capitalize(),
    )
    parser.add_argument(
        "--channels",
        type=int,
        default=dataloader_config["channels"],
        help="Number of channels".capitalize(),
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=dataloader_config["image_size"],
        help="Image size".capitalize(),
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=dataloader_config["batch_size"],
        help="Batch size".capitalize(),
    )
    parser.add_argument(
        "--split_size",
        type=float,
        default=dataloader_config["split_size"],
        help="Split ratio".capitalize(),
    )
    # parse_known_args() tolerates extra arguments such as the
    # "-f <connection file>" an IPython/Jupyter kernel is started with;
    # parse_args() would exit with "unrecognized arguments" there.
    args, _ = parser.parse_known_args()

    loader = Loader(
        image_path=args.image_path,
        channels=args.channels,
        image_size=args.image_size,
        batch_size=args.batch_size,
        split_size=args.split_size,
    )

    # Uncomment to extract the zipped dataset into RAW_IMAGE_DATA_PATH first.
    # try:
    #     loader.unzip_folder()
    # except CustomException as e:
    #     print(e)
    # except Exception as e:
    #     print(e)

    try:
        loader.create_dataloader()
    except Exception as e:  # includes CustomException; report, don't crash
        print("An error occurred: ", e)
        traceback.print_exc()

#### Encoder Block

In [None]:
import torch
import argparse
import torch.nn as nn

class EncoderBlock(nn.Module):
    """Downsampling unit: Conv2d(k=4, s=2, p=1, no bias) -> [BatchNorm2d] -> [LeakyReLU].

    With these convolution settings an even spatial size is halved
    (e.g. 128 -> 64).

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        batchnorm: Append a ``BatchNorm2d`` after the convolution.
        leakyrelu: Append an in-place ``LeakyReLU(0.2)`` as the last layer.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 64,
        batchnorm: bool = True,
        leakyrelu: bool = True,
    ):
        super().__init__()

        # Constructor options.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.batchnorm = batchnorm
        self.leakyrelu = leakyrelu

        # Fixed hyper-parameters shared by every encoder block.
        self.kernel_size, self.stride_size, self.padding_size = 4, 2, 1
        # NOTE(review): PyTorch's BatchNorm momentum is the weight of the *new*
        # batch statistic (default 0.1); 0.8 looks like a Keras-style value.
        # Confirm this is intended.
        self.momentum = 0.8
        self.negative_slope = 0.2

        self.encoder_block = self.block()

    def block(self):
        """Assemble the enabled layers (also kept in ``self.layers``) into a Sequential."""
        convolution = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel_size,
            stride=self.stride_size,
            padding=self.padding_size,
            bias=False,
        )
        normalisation = (
            [nn.BatchNorm2d(num_features=self.out_channels, momentum=self.momentum)]
            if self.batchnorm
            else []
        )
        activation = (
            [nn.LeakyReLU(negative_slope=self.negative_slope, inplace=True)]
            if self.leakyrelu
            else []
        )

        self.layers = [convolution] + normalisation + activation
        return nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Apply the block to ``x``; raise ``TypeError`` for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())
        return self.encoder_block(x)


if __name__ == "__main__":

    def _str_to_bool(value):
        """Convert a CLI string such as "True"/"false"/"1"/"0" to a bool.

        argparse's ``type=bool`` would turn any non-empty string, including
        "False", into True.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ("true", "t", "yes", "y", "1"):
            return True
        if normalized in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "Expected a boolean value, got {!r}".format(value)
        )

    parser = argparse.ArgumentParser(description="Encoder Block for Generator".title())
    parser.add_argument(
        "--in_channels", type=int, default=3, help="Input Channels".title()
    )
    parser.add_argument(
        "--out_channels", type=int, default=64, help="Output Channels".title()
    )
    parser.add_argument(
        "--batchnorm", type=_str_to_bool, default=True, help="Batch Normalization".title()
    )
    parser.add_argument(
        "--leakyrelu", type=_str_to_bool, default=False, help="Leaky ReLU".title()
    )

    # parse_known_args() tolerates extra arguments such as the
    # "-f <connection file>" an IPython/Jupyter kernel is started with.
    args, _ = parser.parse_known_args()

    encoder = EncoderBlock(
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        batchnorm=args.batchnorm,
        leakyrelu=args.leakyrelu,
    )

    # Build the probe from the CLI channel counts so non-default values are
    # checked correctly; kernel 4 / stride 2 / padding 1 halves 128 -> 64.
    output = encoder(torch.randn(1, args.in_channels, 128, 128))
    assert output.size() == torch.Size(
        [1, args.out_channels, 64, 64]
    ), "Encoder Block is not working properly".capitalize()

#### Decoder Block

In [None]:
class DecoderBlock(nn.Module):
    """Upsampling unit: ConvTranspose2d(k=4, s=2, p=1, no bias) -> [BatchNorm2d] -> activation.

    The transposed convolution doubles each spatial dimension (e.g. 2 -> 4).
    The activation is an in-place ``LeakyReLU(0.2)`` when ``leakyrelu`` is
    true, otherwise ``Tanh``.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the transposed convolution.
        batchnorm: Insert a ``BatchNorm2d`` before the activation.
        leakyrelu: Choose LeakyReLU (True) or Tanh (False) as activation.
    """

    def __init__(
        self,
        in_channels: int = 512,
        out_channels: int = 512,
        batchnorm: bool = True,
        leakyrelu: bool = True,
    ):
        super().__init__()

        # Constructor options.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.batchnorm = batchnorm
        self.leakyrelu = leakyrelu

        # Fixed hyper-parameters shared by every decoder block.
        self.kernel_size, self.stride_size, self.padding_size = 4, 2, 1
        # NOTE(review): same Keras-style BatchNorm momentum as EncoderBlock;
        # PyTorch's default is 0.1. Confirm this is intended.
        self.momentum = 0.8

        self.decoder_block = self.block()

    def block(self):
        """Assemble the layers (also kept in ``self.layers``) into a Sequential."""
        stack = [
            nn.ConvTranspose2d(
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=self.kernel_size,
                stride=self.stride_size,
                padding=self.padding_size,
                bias=False,
            )
        ]

        if self.batchnorm:
            stack.append(
                nn.BatchNorm2d(num_features=self.out_channels, momentum=self.momentum)
            )

        stack.append(nn.LeakyReLU(0.2, inplace=True) if self.leakyrelu else nn.Tanh())

        self.layers = stack
        return nn.Sequential(*self.layers)

    def forward(self, x: torch.Tensor):
        """Apply the block to ``x``; raise ``TypeError`` for non-tensor input."""
        if not isinstance(x, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor".capitalize())
        return self.decoder_block(x)


if __name__ == "__main__":

    def _str_to_bool(value):
        """Convert a CLI string such as "True"/"false"/"1"/"0" to a bool.

        argparse's ``type=bool`` would turn any non-empty string, including
        "False", into True.
        """
        if isinstance(value, bool):
            return value
        normalized = value.strip().lower()
        if normalized in ("true", "t", "yes", "y", "1"):
            return True
        if normalized in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError(
            "Expected a boolean value, got {!r}".format(value)
        )

    parser = argparse.ArgumentParser(description="Decoder Block for Generator".title())
    parser.add_argument(
        "--in_channels",
        type=int,
        default=512,
        help="Number of input channels".capitalize(),
    )
    parser.add_argument(
        "--out_channels",
        type=int,
        default=512,
        help="Number of output channels".capitalize(),
    )
    parser.add_argument(
        "--batchnorm",
        type=_str_to_bool,
        default=True,
        help="Batch Normalization".capitalize(),
    )
    parser.add_argument(
        "--leakyrelu", type=_str_to_bool, default=True, help="Leaky ReLU".capitalize()
    )

    # parse_known_args() tolerates extra arguments such as the
    # "-f <connection file>" an IPython/Jupyter kernel is started with.
    args, _ = parser.parse_known_args()

    decoder = DecoderBlock(
        in_channels=args.in_channels,
        out_channels=args.out_channels,
        batchnorm=args.batchnorm,
        leakyrelu=args.leakyrelu,
    )

    # Build the probe from the CLI channel counts so non-default values are
    # checked correctly; the transposed convolution doubles 2 -> 4.
    output = decoder(torch.randn(1, args.in_channels, 2, 2))
    assert output.size() == torch.Size(
        [1, args.out_channels, 4, 4]
    ), "Decoder Block is not working properly".capitalize()

#### Generator

In [None]:
class Generator(nn.Module):
    """U-Net-style CCGAN generator conditioned on a low-resolution image.

    ``x`` passes through six stride-2 ``EncoderBlock``s. ``lr_image`` is
    concatenated (along channels) to the output of the second encoder, so it
    must have ``image_height // 4`` x ``image_width // 4`` spatial size and
    ``channels`` channels. Six ``DecoderBlock``s upsample back to the input
    size, with skip connections concatenating the matching encoder outputs.

    Channel widths are derived from ``image_height`` (64/128/256/512 for a
    128-pixel input), so the architecture changes with the image size.

    Args:
        image_size: ``(batch_size, channels, height, width)``. ``forward``
            checks only ``channels``, ``height`` and ``width``; the batch
            dimension follows the input.
    """

    def __init__(self, image_size: tuple = (1, 3, 128, 128)):
        super(Generator, self).__init__()
        self.batch_size, self.channels, self.image_height, self.image_width = image_size

        self.encoder1 = EncoderBlock(
            in_channels=self.channels,
            out_channels=self.image_height // 2,
            batchnorm=False,
        )
        self.encoder2 = EncoderBlock(
            in_channels=self.image_height // 2,
            out_channels=self.image_height,
            batchnorm=True,
        )
        # "+ self.channels": encoder2's output is concatenated with lr_image.
        self.encoder3 = EncoderBlock(
            in_channels=self.image_height + self.channels,
            out_channels=self.image_height * 2,
            batchnorm=True,
        )
        self.encoder4 = EncoderBlock(
            in_channels=self.image_height * 2,
            out_channels=self.image_height * 4,
            batchnorm=True,
        )
        self.encoder5 = EncoderBlock(
            in_channels=self.image_height * 4,
            out_channels=self.image_height * 4,
            batchnorm=True,
        )
        self.encoder6 = EncoderBlock(
            in_channels=self.image_height * 4,
            out_channels=self.image_height * 4,
            batchnorm=True,
        )

        # Decoder input widths include the concatenated skip connections.
        self.decoder1 = DecoderBlock(
            in_channels=self.image_height * 4, out_channels=self.image_height * 4
        )
        self.decoder2 = DecoderBlock(
            in_channels=self.image_height * 8, out_channels=self.image_height * 4
        )
        self.decoder3 = DecoderBlock(
            in_channels=self.image_height * 8, out_channels=self.image_height * 2
        )
        self.decoder4 = DecoderBlock(
            in_channels=self.image_height * 4, out_channels=self.image_height
        )
        self.decoder5 = DecoderBlock(
            in_channels=self.image_height * 2 + self.channels,
            out_channels=self.image_height // 2,
        )
        # Output layer: leakyrelu=False selects DecoderBlock's Tanh, bounding
        # the output to [-1, 1], the same range as the Normalize(0.5, 0.5)
        # training images. BatchNorm is kept so state_dict keys (and therefore
        # existing checkpoints) are unchanged.
        self.decoder6 = DecoderBlock(
            in_channels=self.image_height,
            out_channels=self.channels,
            leakyrelu=False,
        )

    def forward(self, x: torch.Tensor, lr_image: torch.Tensor):
        """Generate an image from ``x`` conditioned on ``lr_image``.

        Args:
            x: ``(N, channels, image_height, image_width)`` tensor.
            lr_image: ``(N, channels, image_height // 4, image_width // 4)``
                tensor.

        Returns:
            ``(N, channels, image_height, image_width)`` tensor in [-1, 1].
            ``N`` may differ from the configured ``batch_size``, e.g. for the
            last, smaller batch of a DataLoader.

        Raises:
            TypeError: If either input is not a ``torch.Tensor``.
        """
        if not (isinstance(x, torch.Tensor) and isinstance(lr_image, torch.Tensor)):
            raise TypeError("Inputs must be torch.Tensor".capitalize())

        encoder1 = self.encoder1(x)

        encoder2 = self.encoder2(encoder1)

        # Inject the low-resolution condition at encoder2's resolution.
        _encoder2 = torch.cat((encoder2, lr_image), dim=1)

        encoder3 = self.encoder3(_encoder2)
        encoder3 = torch.dropout(input=encoder3, p=0.5, train=self.training)

        encoder4 = self.encoder4(encoder3)
        encoder4 = torch.dropout(input=encoder4, p=0.5, train=self.training)

        encoder5 = self.encoder5(encoder4)
        encoder5 = torch.dropout(input=encoder5, p=0.5, train=self.training)

        encoder6 = self.encoder6(encoder5)
        encoder6 = torch.dropout(input=encoder6, p=0.5, train=self.training)

        decoder1 = torch.cat((self.decoder1(encoder6), encoder5), dim=1)
        decoder1 = torch.dropout(input=decoder1, p=0.5, train=self.training)

        decoder2 = torch.cat((self.decoder2(decoder1), encoder4), dim=1)
        decoder2 = torch.dropout(input=decoder2, p=0.5, train=self.training)

        decoder3 = torch.cat((self.decoder3(decoder2), encoder3), dim=1)
        decoder3 = torch.dropout(input=decoder3, p=0.5, train=self.training)

        decoder4 = torch.cat((self.decoder4(decoder3), _encoder2), dim=1)

        decoder5 = torch.cat((self.decoder5(decoder4), encoder1), dim=1)

        output = self.decoder6(decoder5)

        # Compare only (C, H, W): the batch dimension follows the input, so a
        # smaller final DataLoader batch must not trip this check.
        assert output.size()[1:] == (
            self.channels,
            self.image_height,
            self.image_width,
        ), "Image size is incorrect in Generator".capitalize()

        return output

    @staticmethod
    def total_params(model=None):
        """Return the total number of parameters of a ``Generator``.

        Raises:
            TypeError: If ``model`` is not a ``Generator``.
        """
        if isinstance(model, Generator):
            return sum(params.numel() for params in model.parameters())

        else:
            raise TypeError("Model must be of type Generator".capitalize())


if __name__ == "__main__":

    def parse_tuple(value):
        """Parse "(1,3,128,128)" / "1,3,128,128" into a 4-tuple of ints.

        Tuples (e.g. the config-derived default) are returned unchanged.
        """
        if isinstance(value, tuple):
            return value
        stripped = value.strip().strip("()[]")
        try:
            parsed = tuple(int(part) for part in stripped.split(",") if part.strip())
        except ValueError as error:
            raise argparse.ArgumentTypeError(
                "Image size must contain integers, got {!r}".format(value)
            ) from error
        if len(parsed) != 4:
            raise argparse.ArgumentTypeError(
                "Image size must have 4 values (batch, channels, height, width), "
                "got {!r}".format(value)
            )
        return parsed

    # Read the YAML config once and reuse it for every default.
    dataloader_config = config()["dataloader"]
    image = (
        dataloader_config["batch_size"],
        dataloader_config["channels"],
        dataloader_config["image_size"],
        dataloader_config["image_size"],
    )

    parser = argparse.ArgumentParser(description="Generator for CCGAN".title())
    parser.add_argument(
        "--image_size",
        type=parse_tuple,
        default=image,
        help="Image size (e.g., '(1,3,128,128)')".capitalize(),
    )

    # parse_known_args() tolerates extra arguments such as the
    # "-f <connection file>" an IPython/Jupyter kernel is started with.
    args, _ = parser.parse_known_args()

    batch_size, channels, height, width = args.image_size

    netG = Generator(image_size=args.image_size)

    # The condition image is a quarter of the full resolution.
    output = netG(
        torch.randn(args.image_size),
        torch.randn(batch_size, channels, height // 4, width // 4),
    )
    assert output.size() == torch.Size(
        args.image_size
    ), "Image size is incorrect in Generator".capitalize()

    print("Total params of the netG = {}".format(Generator.total_params(netG)))