In [None]:
import sys
import os
import yaml
import zipfile
import joblib
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import OrderedDict
from PIL import Image
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary
from torchview import draw_graph

In [None]:
def config():
    """Read the project configuration file and return its parsed contents.

    The file is looked up at ``../../config.yml`` relative to the current
    working directory and parsed with ``yaml.safe_load``.

    Returns
    -------
    object
        Whatever ``yaml.safe_load`` produces for the file.
    """
    with open("../../config.yml", "r") as stream:
        settings = yaml.safe_load(stream)

    return settings

In [None]:
def dump(value = None, filename = None):
    """Serialise ``value`` to ``filename`` with joblib.

    Parameters
    ----------
    value : object
        Object to persist; must not be ``None``.
    filename : str
        Destination path; must not be ``None``.

    Raises
    ------
    ValueError
        If either ``value`` or ``filename`` is ``None``.
    """
    # Guard clause: both arguments are mandatory even though they default to None.
    if value is None or filename is None:
        raise ValueError("Value or filename is not found".capitalize())

    joblib.dump(value=value, filename=filename)
    
def load(filename = None):
    """Deserialise and return the object stored at ``filename`` with joblib.

    Parameters
    ----------
    filename : str
        Path of the file to read; must not be ``None``.

    Returns
    -------
    object
        The object returned by ``joblib.load``.

    Raises
    ------
    ValueError
        If ``filename`` is ``None``.
    """
    # Guard clause: the argument is mandatory even though it defaults to None.
    if filename is None:
        raise ValueError("Filename is not found".capitalize())

    return joblib.load(filename=filename)

In [None]:
class Loader():
    """Build and persist train/test/validation dataloaders from a zipped image folder.

    Directory locations come from the mapping returned by ``config()``; the
    keys read are ``path.raw_path``, ``path.processed_path`` and
    ``path.files_path``.

    Parameters
    ----------
    image_path : str, optional
        Location of the zip archive extracted by ``unzip_folder``.
    image_size : int
        Side length every image is resized and cropped to.
    batch_size : int
        Batch size of the train and validation loaders; the test loader uses
        four times this value.
    split_size : float
        ``test_size`` handed to ``train_test_split`` for both splits.
    """

    def __init__(self, image_path = None, image_size = 64, batch_size = 4, split_size = 0.50):
        self.image_path = image_path
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size

        self.config = config()

        self.images = []
        self.labels = []

    def split_images(self, **kwargs):
        """Split ``kwargs["X"]`` / ``kwargs["y"]`` with a fixed random state.

        Returns
        -------
        dict
            Keys ``X_train``, ``X_test``, ``y_train`` and ``y_test``.
        """
        X_train, X_test, y_train, y_test = train_test_split(
            kwargs["X"], kwargs["y"], test_size=self.split_size, random_state=42)

        return {"X_train": X_train, "X_test": X_test, "y_train": y_train, "y_test": y_test}

    def transforms(self):
        """Return the preprocessing pipeline: resize, centre-crop, to-tensor, normalise."""
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.CenterCrop((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, ], std=[0.5, ])
        ])

    def unzip_folder(self):
        """Extract the archive at ``image_path`` into the configured raw path.

        Raises
        ------
        ValueError
            If no ``image_path`` was supplied.
        FileNotFoundError
            If the configured raw path does not exist.
        """
        # image_path defaults to None; fail with a clear message instead of
        # letting zipfile choke on it.
        if self.image_path is None:
            raise ValueError("Image path is not found".capitalize())

        if os.path.exists(self.config["path"]["raw_path"]):

            with zipfile.ZipFile(self.image_path, "r") as zip_ref:
                zip_ref.extractall(self.config["path"]["raw_path"])

        else:
            raise FileNotFoundError("Raw path is not found for the directory".capitalize())

    def extract_feature(self):
        """Load every image under ``<raw_path>/brain/<category>/`` and split the data.

        Each sub-directory of ``brain`` is one category; its position in
        ``self.categories`` is used as the integer label.

        Returns
        -------
        dict
            ``X_train``/``y_train`` and ``X_test``/``y_test`` come from a
            second split of the first split's training part;
            ``val_train``/``val_test`` hold the images and labels of the first
            split's held-out part.
        """
        self.directory = os.path.join(self.config["path"]["raw_path"], "brain")

        # Only sub-directories are categories. Stray files (for example a
        # ".DS_Store" left behind by the archive) would crash os.listdir below.
        # NOTE(review): label ids follow os.listdir order, which is not
        # guaranteed to be stable across machines - confirm whether sorting is wanted.
        self.categories = [
            entry
            for entry in os.listdir(self.directory)
            if os.path.isdir(os.path.join(self.directory, entry))
        ]

        # Start from a clean slate so calling this twice does not duplicate samples.
        self.images = []
        self.labels = []

        # The pipeline does not depend on the image, so build it once.
        transform = self.transforms()

        for label, category in enumerate(self.categories):
            full_path = os.path.join(self.directory, category)

            for image in os.listdir(full_path):
                full_image_path = os.path.join(full_path, image)

                pixels = cv2.imread(full_image_path)

                # cv2.imread signals an unreadable file by returning None.
                if pixels is None:
                    continue

                # NOTE(review): OpenCV is documented to return BGR channel
                # order while PIL assumes RGB - confirm this is acceptable
                # for the (presumably greyscale) source images.
                image_object = Image.fromarray(pixels)
                self.images.append(transform(image_object))
                self.labels.append(label)

        dataset = self.split_images(X = self.images, y = self.labels)

        sub_dataset = self.split_images(X = dataset["X_train"], y = dataset["y_train"])

        return {
            "X_train": sub_dataset["X_train"],
            "y_train": sub_dataset["y_train"],
            "X_test": sub_dataset["X_test"],
            "y_test": sub_dataset["y_test"],
            "val_train": dataset["X_test"],
            "val_test": dataset["y_test"]
            }

    def create_dataloader(self):
        """Create the three dataloaders and pickle them into the processed path.

        Raises
        ------
        FileNotFoundError
            If the configured processed path does not exist.
        """
        dataset = self.extract_feature()

        self.train_dataloader = DataLoader(
            dataset=list(zip(dataset["X_train"], dataset["y_train"])), batch_size=self.batch_size, shuffle=True
        )

        self.test_dataloader = DataLoader(
            dataset=list(zip(dataset["X_test"], dataset["y_test"])), batch_size=self.batch_size*4, shuffle=True
        )

        self.val_dataloader = DataLoader(
            dataset=list(zip(dataset["val_train"], dataset["val_test"])), batch_size=self.batch_size, shuffle=True
        )

        processed_path = self.config["path"]["processed_path"]

        if not os.path.exists(processed_path):
            raise FileNotFoundError("Processed path is not found for the directory".capitalize())

        for filename, dataloader in (
            ("train_dataloader.pkl", self.train_dataloader),
            ("test_dataloader.pkl", self.test_dataloader),
            ("val_dataloader.pkl", self.val_dataloader),
        ):
            dump(value=dataloader, filename=os.path.join(processed_path, filename))

    @staticmethod
    def dataset_details():
        """Write sample counts, batch counts and batch shapes to ``dataset_details.csv``.

        The pickled dataloaders are read from the processed path and the CSV is
        written into the files path, which is created when missing.

        Raises
        ------
        FileNotFoundError
            If the configured processed path does not exist.
        """
        config_files = config()
        processed_path = config_files["path"]["processed_path"]
        files_path = config_files["path"]["files_path"]

        if not os.path.exists(processed_path):
            raise FileNotFoundError("Processed path is not found for the directory".capitalize())

        train_dataloader = load(filename=os.path.join(processed_path, "train_dataloader.pkl"))
        test_dataloader = load(filename=os.path.join(processed_path, "test_dataloader.pkl"))
        val_dataloader = load(filename=os.path.join(processed_path, "val_dataloader.pkl"))

        # Count each loader once and reuse the numbers for the grand total.
        train_total = sum(data.size(0) for data, _ in train_dataloader)
        test_total = sum(data.size(0) for data, _ in test_dataloader)
        val_total = sum(data.size(0) for data, _ in val_dataloader)

        details = pd.DataFrame(
            {
                "total_data(Train)": train_total,
                "total_data(Test)": test_total,
                "total_data(Val)": val_total,
                "total_batch(Train)": str(len(train_dataloader)),
                "total_batch(Test)": str(len(test_dataloader)),
                "total_data": train_total + test_total + val_total,
                # Only the first batch is needed to report the shape.
                "train_shape": str(next(iter(train_dataloader))[0].size()),
                "test_shape": str(next(iter(test_dataloader))[0].size()),
                "val_shape": str(next(iter(val_dataloader))[0].size()),
            },
            index=["quantity"],
        )

        # Create the folder first, then write. Passing the result of
        # os.makedirs (None) to to_csv would only return the CSV text and
        # never touch the disk.
        os.makedirs(files_path, exist_ok=True)
        details.T.to_csv(os.path.join(files_path, "dataset_details.csv"))

    @staticmethod
    def plot_images():
        """Show one batch of the pickled test dataloader and save it as ``images.png``.

        The figure is saved only when the files path exists; otherwise a
        message is printed and the figure is just displayed.

        Raises
        ------
        FileNotFoundError
            If the configured processed path does not exist.
        """
        config_files = config()
        processed_path = config_files["path"]["processed_path"]
        files_path = config_files["path"]["files_path"]

        # Check before creating the figure so a failure does not leak an empty one.
        if not os.path.exists(processed_path):
            raise FileNotFoundError("Processed path is not found for the directory".capitalize())

        test_dataloader = load(filename=os.path.join(processed_path, "test_dataloader.pkl"))

        images, labels = next(iter(test_dataloader))

        # Keep the 4x4 layout for batches of up to 16 and add rows beyond that.
        columns = 4
        rows = max(4, -(-len(images) // columns))

        plt.figure(figsize=(20, 10))

        for index, image in enumerate(images):
            # Channels-last for matplotlib; a single channel is dropped afterwards
            # so greyscale tensors work as well as three-channel ones.
            image = image.permute(1, 2, 0).squeeze(-1).cpu().detach().numpy()

            # Rescale to [0, 1]; a constant image would otherwise divide by zero.
            span = image.max() - image.min()
            image = (image - image.min()) / span if span > 0 else np.zeros_like(image)

            plt.subplot(rows, columns, index + 1)

            plt.imshow(image, cmap="gray")
            # NOTE(review): label 0 is whichever category os.listdir returned
            # first when the loaders were built - confirm it really means "Yes".
            plt.title("Yes" if labels[index] == 0 else "No")
            plt.axis("off")

        plt.tight_layout()

        if os.path.exists(files_path):
            plt.savefig(os.path.join(files_path, "images.png"))
        else:
            # Saving stays best-effort, but no longer fails silently.
            print("Cannot save the images to {}".format(files_path))

        plt.show()

if __name__ == "__main__":
    # The archive lives at a machine-specific location. Setting TUMOR_ZIP_PATH
    # lets the notebook run elsewhere without editing this cell; the original
    # location remains the default.
    image_path = os.environ.get(
        "TUMOR_ZIP_PATH", "/Users/shahmuhammadraditrahman/Desktop/tumor.zip"
    )

    loader = Loader(image_path=image_path)

    # Extract the archive, then build and pickle the dataloaders.
    loader.unzip_folder()
    loader.create_dataloader()

    # Both reports are static: they read the pickled dataloaders from disk.
    Loader.dataset_details()
    Loader.plot_images()

#### Create the discriminator

In [None]:
class DiscriminatorBlock(nn.Module):
    """One discriminator stage: strided 3x3 convolution, batch norm, LeakyReLU(0.2).

    Parameters
    ----------
    in_channels : int
        Channels of the incoming feature map.
    out_channels : int
        Channels produced by the convolution.
    """

    def __init__(self, in_channels = 1, out_channels = 512):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels

        # Fixed convolution geometry shared by every block.
        self.kernel, self.stride, self.padding = 3, 2, 1

        self.discriminator = self.block()

    def block(self):
        """Assemble the named ``conv`` -> ``batch_norm`` -> ``leaky_relu`` pipeline."""
        convolution = nn.Conv2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            bias=False,
        )

        stages = [
            ("conv", convolution),
            ("batch_norm", nn.BatchNorm2d(self.out_channels)),
            ("leaky_relu", nn.LeakyReLU(0.2, inplace=True)),
        ]

        return nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Run ``x`` through the block; raises ``ValueError`` when ``x`` is ``None``."""
        if x is None:
            raise ValueError("Input is not found".capitalize())

        return self.discriminator(x)

    @staticmethod
    def total_params(model = None):
        """Return the number of parameter elements in ``model``, or ``None`` without a model."""
        if model is None:
            return None

        return sum(params.numel() for params in model.parameters())


if __name__ == "__main__":
    # Stack four blocks; each block's output width feeds the next one and the
    # requested width is halved every step.
    in_channels, out_channels = 1, 512

    layers = []

    for _ in range(4):
        layers += [DiscriminatorBlock(in_channels=in_channels, out_channels=out_channels)]

        in_channels, out_channels = out_channels, out_channels // 2

    model = nn.Sequential(*layers)

    # Layer-by-layer summary for a 128x128 single-channel input.
    print(summary(model = model, input_size = (1, 128, 128)))

    # Computation graph traced with a batch of four 64x64 inputs.
    draw_graph(model = model, input_data = torch.randn(4, 1, 64, 64)).visual_graph

In [None]:
class Discriminator(nn.Module):
    """Four stacked ``DiscriminatorBlock`` stages feeding two linear heads.

    The first block maps ``in_channels`` to ``out_channels``; every following
    block takes the previous width as input and produces half of it. The
    flattened feature map is passed to

    * ``classification`` - ``Linear`` (2 outputs) followed by ``Softmax(dim=1)``
    * ``validity`` - ``Linear`` (1 output) followed by ``Sigmoid``

    Parameters
    ----------
    in_channels : int
        Channels of the input image.
    out_channels : int
        Width of the first block; must be at least 8 so the last block still
        has one channel after three halvings.
    image_size : int
        Height/width of the (square) input image, used to size the heads.

    Raises
    ------
    ValueError
        If ``out_channels`` is too small for four halving blocks.
    """

    def __init__(self, in_channels = 1, out_channels = 512, image_size = 64):
        super(Discriminator, self).__init__()

        # Keep the constructor arguments intact; the loop below works on locals.
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.image_size = image_size

        num_blocks = 4

        if out_channels // (2 ** (num_blocks - 1)) < 1:
            raise ValueError("Out_channels must be at least 8".capitalize())

        self.layers = []

        block_in, block_out = in_channels, out_channels
        feature_size = image_size

        for _ in range(num_blocks):
            block = DiscriminatorBlock(in_channels=block_in, out_channels=block_out)
            self.layers.append(block)

            # Standard convolution output size: floor((H + 2p - k) / s) + 1.
            feature_size = (feature_size + 2 * block.padding - block.kernel) // block.stride + 1

            block_in, block_out = block_out, block_out // 2

        self.model = nn.Sequential(*self.layers)

        # block_in now holds the width of the last block. Deriving the head size
        # from the actual stack (instead of a literal 64 * 4 * 4) keeps
        # non-default widths and image sizes working; the defaults still give 1024.
        in_features = block_in * feature_size * feature_size

        self.classification = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=2, bias=False),
            nn.Softmax(dim = 1),
        )

        self.validity = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(classification, validity)`` for a batch ``x``.

        Raises
        ------
        ValueError
            If ``x`` is ``None``.
        """
        if x is None:
            raise ValueError("Input is not found".capitalize())

        features = self.model(x)

        # Flatten once and share the result between both heads.
        flattened = features.view(features.size(0), -1)

        return self.classification(flattened), self.validity(flattened)
        
        
if __name__ == "__main__":
    netD = Discriminator(in_channels = 1, out_channels = 512)

    # Smoke test: a random batch of four single-channel 64x64 images.
    classification, validity = netD(torch.randn(4, 1, 64, 64))

    print(classification.shape, validity.shape)

    # Layer-by-layer summary for one 64x64 single-channel input.
    print(summary(netD, input_size = (1, 64, 64)))

    # Computation graph traced with a fresh random batch.
    draw_graph(netD, input_data = torch.randn(4, 1, 64, 64)).visual_graph

#### Create the generator

In [None]:
class GeneratorBlock(nn.Module):
    """One generator stage built around a 4x4, stride-2 transposed convolution.

    Intermediate blocks follow the convolution with batch norm and
    LeakyReLU(0.2); the last block follows it with ``Tanh`` instead.

    Parameters
    ----------
    in_channels : int
        Channels of the incoming feature map.
    out_channels : int
        Channels produced by the transposed convolution.
    is_last : bool
        Whether this is the final block of the generator.
    """

    def __init__(self, in_channels = 64, out_channels = 128, is_last = False):
        super().__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.is_last = is_last

        # Fixed transposed-convolution geometry shared by every block.
        self.kernel, self.stride, self.padding = 4, 2, 1

        self.generator = self.block()

    def block(self):
        """Assemble the named sub-modules for this stage."""
        upsample = nn.ConvTranspose2d(
            self.in_channels,
            self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
        )

        if self.is_last:
            stages = [("conv", upsample), ("tanh", nn.Tanh())]
        else:
            stages = [
                ("conv", upsample),
                ("batch_norm", nn.BatchNorm2d(self.out_channels)),
                ("leaky_relu", nn.LeakyReLU(0.2, inplace=True)),
            ]

        return nn.Sequential(OrderedDict(stages))

    def forward(self, x):
        """Run ``x`` through the block; raises ``ValueError`` when ``x`` is ``None``."""
        if x is None:
            raise ValueError("Input is not found".capitalize())

        return self.generator(x)


if __name__ == "__main__":
    latent_space = 100
    out_channels = 64
    kernel = 4
    stride = 1
    padding = 0
    num_repetitive = 4

    layers = []

    # Projection layer: turns the (latent_space, 1, 1) noise vector into the
    # first feature map.
    layers.append(
        nn.ConvTranspose2d(
            in_channels = latent_space,
            out_channels = out_channels,
            kernel_size = kernel,
            stride = stride,
            padding = padding,
            bias = False,)
    )

    in_channels = out_channels

    # Drive the loop from num_repetitive so the "last block" test below and
    # the number of blocks can never disagree.
    for idx in tqdm(range(num_repetitive)):
        is_last = idx == (num_repetitive - 1)

        layers.append(
            GeneratorBlock(
                in_channels=in_channels,
                # Every block doubles the width except the last, which emits one channel.
                out_channels=1 if is_last else in_channels * 2,
                is_last=is_last,
            )
        )

        in_channels *= 2

    model = nn.Sequential(*layers)

    # Reuse latent_space so the demo inputs always match the projection layer.
    print(model(torch.randn(4, latent_space, 1, 1)).size())

    print(summary(model = model, input_size = (latent_space, 1, 1)))

    draw_graph(model = model, input_data = torch.randn(4, latent_space, 1, 1)).visual_graph