In [39]:
import sys
import os
import yaml
import zipfile
import joblib
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import OrderedDict
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from torchsummary import summary
from torchview import draw_graph

In [3]:
def config():
    with open("../../config.yml", "r") as file:
        config = yaml.safe_load(file)
        
    return config

In [14]:
def dump(value = None, filename = None):
    if (value is not None) and (filename is not None):
        joblib.dump(value=value, filename=filename)
        
    else:
        raise ValueError("Value or filename is not found".capitalize())
    
def load(filename = None):
    if filename is not None:
        return joblib.load(filename=filename)
        
    else:
        raise ValueError("Filename is not found".capitalize())
    
def weight_init(m):
    classname = m.__class__.__name__
    
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        
    if classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
        
        
def device_init(device = "mps"):
    if device == "mps":
        return torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    
    elif device == "cuda":
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    else:
        return torch.device("cpu")

In [36]:
class Loader:
    def __init__(self, image_path=None, image_size=64, batch_size=4, split_size=0.50):
        self.image_path = image_path
        self.image_size = image_size
        self.batch_size = batch_size
        self.split_size = split_size

        self.config = config()

        self.images = []
        self.labels = []

    def split_images(self, **kwargs):
        X_train, X_test, y_train, y_test = train_test_split(
            kwargs["X"], kwargs["y"], test_size=self.split_size, random_state=42
        )

        return {
            "X_train": X_train,
            "X_test": X_test,
            "y_train": y_train,
            "y_test": y_test,
        }

    def transforms(self):
        return transforms.Compose(
            [
                transforms.Resize((self.image_size, self.image_size)),
                transforms.CenterCrop((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[
                        0.5,
                    ],
                    std=[
                        0.5,
                    ],
                ),
            ]
        )

    def unzip_folder(self):
        if os.path.exists(self.config["path"]["raw_path"]):

            with zipfile.ZipFile(self.image_path, "r") as zip_ref:
                zip_ref.extractall(self.config["path"]["raw_path"])

        else:
            raise FileNotFoundError(
                "Raw path is not found for the directory".capitalize()
            )

    def extract_feature(self):

        self.directory = os.path.join(self.config["path"]["raw_path"], "brain")
        self.categories = [file for file in os.listdir(self.directory)]

        for category in self.categories:
            full_path = os.path.join(self.directory, category)

            for image in os.listdir(full_path):
                full_image_path = os.path.join(full_path, image)

                image_object = Image.fromarray(cv2.imread(full_image_path))
                self.images.append(self.transforms()(image_object))
                self.labels.append(self.categories.index(category))

        dataset = self.split_images(X=self.images, y=self.labels)

        X_limit = round(len(dataset["X_train"]) * 1.0)
        y_limit = round(len(dataset["y_train"]) * 1.0)

        sub_dataset = self.split_images(
            X=dataset["X_train"][0:X_limit], y=dataset["y_train"][0:y_limit]
        )

        return {
            "X_train": sub_dataset["X_train"],
            "y_train": sub_dataset["y_train"],
            "X_test": sub_dataset["X_test"],
            "y_test": sub_dataset["y_test"],
            "val_train": dataset["X_test"],
            "val_test": dataset["y_test"],
        }

    def create_dataloader(self):
        dataset = self.extract_feature()

        self.train_dataloader = DataLoader(
            dataset=list(zip(dataset["X_train"], dataset["y_train"])),
            batch_size=self.batch_size,
            shuffle=True,
        )

        self.test_dataloader = DataLoader(
            dataset=list(zip(dataset["X_test"], dataset["y_test"])),
            batch_size=self.batch_size * 4,
            shuffle=True,
        )

        self.val_dataloader = DataLoader(
            dataset=list(zip(dataset["val_train"], dataset["val_test"])),
            batch_size=self.batch_size,
            shuffle=True,
        )
        if os.path.exists(self.config["path"]["processed_path"]):

            dump(
                value=self.train_dataloader,
                filename=os.path.join(
                    self.config["path"]["processed_path"], "train_dataloader.pkl"
                ),
            )

            dump(
                value=self.test_dataloader,
                filename=os.path.join(
                    self.config["path"]["processed_path"], "test_dataloader.pkl"
                ),
            )

            dump(
                value=self.val_dataloader,
                filename=os.path.join(
                    self.config["path"]["processed_path"], "val_dataloader.pkl"
                ),
            )

        else:
            raise FileNotFoundError(
                "Processed path is not found for the directory".capitalize()
            )

    @staticmethod
    def dataset_details():
        config_files = config()

        if os.path.exists(config_files["path"]["processed_path"]):

            train_dataloader = load(
                filename=os.path.join(
                    config_files["path"]["processed_path"], "train_dataloader.pkl"
                ),
            )

            test_dataloader = load(
                filename=os.path.join(
                    config_files["path"]["processed_path"], "test_dataloader.pkl"
                ),
            )

            val_dataloader = load(
                filename=os.path.join(
                    config_files["path"]["processed_path"], "val_dataloader.pkl"
                ),
            )

            pd.DataFrame(
                {
                    "total_data(Train)": sum(
                        data.size(0) for data, _ in train_dataloader
                    ),
                    "total_data(Test)": sum(
                        data.size(0) for data, _ in test_dataloader
                    ),
                    "total_data(Val)": sum(data.size(0) for data, _ in val_dataloader),
                    "total_batch(Train)": str(len(train_dataloader)),
                    "total_batch(Test)": str(len(test_dataloader)),
                    "total_data": sum(
                        [
                            sum([data.size(0) for data, _ in dataloader])
                            for dataloader in [
                                train_dataloader,
                                test_dataloader,
                                val_dataloader,
                            ]
                        ]
                    ),
                    "train_shape": str(
                        [data.size() for _, (data, _) in enumerate(train_dataloader)][0]
                    ),
                    "test_shape": str(
                        [data.size() for _, (data, _) in enumerate(test_dataloader)][0]
                    ),
                    "val_shape": str(
                        [data.size() for _, (data, _) in enumerate(val_dataloader)][0]
                    ),
                },
                index=["quantity"],
            ).T.to_csv(
                os.path.join(config_files["path"]["files_path"], "dataset_details.csv")
                if os.path.exists(config_files["path"]["files_path"])
                else os.makedirs(config_files["path"]["files_path"])
            )

        else:
            raise FileNotFoundError(
                "Processed path is not found for the directory".capitalize()
            )

    @staticmethod
    def plot_images():
        config_files = config()

        plt.figure(figsize=(20, 10))

        if os.path.exists(config_files["path"]["processed_path"]):
            test_dataloader = load(
                filename=os.path.join(
                    config_files["path"]["processed_path"], "test_dataloader.pkl"
                )
            )

            images, labels = next(iter(test_dataloader))

            for index, image in enumerate(images):
                image = image.squeeze().permute(1, 2, 0).cpu().detach().numpy()
                image = (image - image.min()) / (image.max() - image.min())

                plt.subplot(4, 4, index + 1)

                plt.imshow(image, cmap="gray")
                plt.title("Yes" if labels[index] == 0 else "No")
                plt.axis("off")

            plt.tight_layout()

            (
                plt.savefig(
                    os.path.join(config_files["path"]["files_path"], "images.png")
                )
                if os.path.exists(config_files["path"]["files_path"])
                else "Cannot save the images to".capitalize()
            )
            plt.show()

        else:
            raise FileNotFoundError(
                "Processed path is not found for the directory".capitalize()
            )


if __name__ == "__main__":
    loader = Loader(image_path="/Users/shahmuhammadraditrahman/Desktop/tumor.zip")
    # loader.unzip_folder()
    loader.create_dataloader()
    loader.dataset_details()
    loader.plot_images()

FileNotFoundError: [Errno 2] No such file or directory: './data/raw/brain'

#### Create the discriminator

In [6]:
class DiscriminatorBlock(nn.Module):
    def __init__(self, in_channels = 1, out_channels = 512):
        super(DiscriminatorBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.kernel = 3
        self.stride = 2
        self.padding = 1

        self.discriminator = self.block()

    def block(self):
        layers = OrderedDict()

        layers["conv"] = nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding=self.padding,
            bias=False,
        )

        layers["batch_norm"] = nn.BatchNorm2d(self.out_channels)

        layers["leaky_relu"] = nn.LeakyReLU(0.2, inplace=True)

        return nn.Sequential(layers)

    def forward(self, x):
        if x is not None:
            return self.discriminator(x)
        else:
            raise ValueError("Input is not found".capitalize())
        
    @staticmethod
    def total_params(model = None):
        if model is not None:
            return sum(params.numel() for params in model.parameters())


if __name__ == "__main__":
    in_channels = 1
    out_channels = 512

    layers = []

    for _ in range(4):
        layers.append(DiscriminatorBlock(in_channels=in_channels, out_channels=out_channels))
        
        in_channels = out_channels
        out_channels //= 2
        
    model = nn.Sequential(*layers)
    
    print(summary(model = model, input_size = (1, 128, 128)))
    
    draw_graph(model = model, input_data = torch.randn(4, 1, 64, 64)).visual_graph

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 512, 64, 64]           4,608
       BatchNorm2d-2          [-1, 512, 64, 64]           1,024
         LeakyReLU-3          [-1, 512, 64, 64]               0
DiscriminatorBlock-4          [-1, 512, 64, 64]               0
            Conv2d-5          [-1, 256, 32, 32]       1,179,648
       BatchNorm2d-6          [-1, 256, 32, 32]             512
         LeakyReLU-7          [-1, 256, 32, 32]               0
DiscriminatorBlock-8          [-1, 256, 32, 32]               0
            Conv2d-9          [-1, 128, 16, 16]         294,912
      BatchNorm2d-10          [-1, 128, 16, 16]             256
        LeakyReLU-11          [-1, 128, 16, 16]               0
DiscriminatorBlock-12          [-1, 128, 16, 16]               0
           Conv2d-13             [-1, 64, 8, 8]          73,728
      BatchNorm2d-14             [-1, 

In [7]:
class Discriminator(nn.Module):
    def __init__(self, in_channels = 1, out_channels = 512):
        super(Discriminator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.layers = []

        for _ in tqdm(range(4)):
            self.layers.append(
                DiscriminatorBlock(
                    in_channels=self.in_channels, out_channels=self.out_channels
                )
            )
            
            self.in_channels = self.out_channels
            self.out_channels //= 2
            
        self.model = nn.Sequential(*self.layers)
        
        self.classification = nn.Sequential(
            nn.Linear(in_features=64 * 4 * 4, out_features=2, bias=False),
            nn.Softmax(dim = 1),
        )
        
        self.validity = nn.Sequential(
            nn.Linear(in_features=64 * 4 * 4, out_features=1, bias=False),
            nn.Sigmoid(),
        )
        
    def forward(self, x):
        if x is not None:
            x = self.model(x)
            
            classification = self.classification(x.view(x.size(0), -1))
            validity = self.validity(x.view(x.size(0), -1))
            
            return classification, validity
        
        else:
            raise ValueError("Input is not found".capitalize())
        
        
if __name__ == "__main__":
    netD = Discriminator(in_channels = 1, out_channels = 512)
    
    classification, validity = netD(torch.randn(4, 1, 64, 64))
    
    print(classification.size(), validity.size())
    
    print(summary(model = netD, input_size = (1, 64, 64)))
    
    draw_graph(model = netD, input_data = torch.randn(4, 1, 64, 64)).visual_graph

100%|██████████| 4/4 [00:00<00:00, 529.23it/s]

torch.Size([4, 2]) torch.Size([4, 1])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 512, 32, 32]           4,608
       BatchNorm2d-2          [-1, 512, 32, 32]           1,024
         LeakyReLU-3          [-1, 512, 32, 32]               0
DiscriminatorBlock-4          [-1, 512, 32, 32]               0
            Conv2d-5          [-1, 256, 16, 16]       1,179,648
       BatchNorm2d-6          [-1, 256, 16, 16]             512
         LeakyReLU-7          [-1, 256, 16, 16]               0
DiscriminatorBlock-8          [-1, 256, 16, 16]               0
            Conv2d-9            [-1, 128, 8, 8]         294,912
      BatchNorm2d-10            [-1, 128, 8, 8]             256
        LeakyReLU-11            [-1, 128, 8, 8]               0
DiscriminatorBlock-12            [-1, 128, 8, 8]               0
           Conv2d-13             [-1, 64, 4, 4]          73,728





#### Create Generator

In [8]:
class GeneratorBlock(nn.Module):
    def __init__(self, in_channels = 64, out_channels = 128, is_last = False):
        super(GeneratorBlock, self).__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.is_last = is_last
        
        self.kernel = 4
        self.stride = 2
        self.padding = 1
        
        self.generator = self.block()
        
    
    def block(self):
        layers = OrderedDict()
        
        layers["conv"] = nn.ConvTranspose2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=self.kernel,
            stride=self.stride,
            padding = self.padding)
        
        if self.is_last:
            layers["tanh"] = nn.Tanh()
        
        else:
            layers["batch_norm"] = nn.BatchNorm2d(self.out_channels)
            layers["leaky_relu"] = nn.LeakyReLU(0.2, inplace=True)
            
        return nn.Sequential(layers)
        
    
    
    def forward(self, x):
        if x is not None:
            return self.generator(x)
        
        else:
            raise ValueError("Input is not found".capitalize())


if __name__ == "__main__":
    latent_space = 100
    out_channels = 64
    kernel = 4
    stride = 1
    padding = 0
    num_repetitive = 4 

    layers = []

    layers.append(
        nn.ConvTranspose2d(
            in_channels = latent_space,
            out_channels = out_channels,
            kernel_size = kernel,
            stride = stride,
            padding = padding,
            bias = False,)
    )

    in_channels = out_channels

    for idx in tqdm(range(4)):
        layers.append(
            GeneratorBlock(
                in_channels=in_channels,
                out_channels=1 if idx == (num_repetitive - 1) else in_channels * 2,
                is_last=True if idx == (num_repetitive - 1) else False,
            )
        )

        in_channels *= 2

    model = nn.Sequential(*layers)

    print(model(torch.randn(4, 100, 1, 1)).size())
    
    print(summary(model = model, input_size = (100, 1, 1)))
    
    draw_graph(model = model, input_data = torch.randn(4, 100, 1, 1)).visual_graph

100%|██████████| 4/4 [00:00<00:00, 375.61it/s]

torch.Size([4, 1, 64, 64])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1             [-1, 64, 4, 4]         102,400
   ConvTranspose2d-2            [-1, 128, 8, 8]         131,200
       BatchNorm2d-3            [-1, 128, 8, 8]             256
         LeakyReLU-4            [-1, 128, 8, 8]               0
    GeneratorBlock-5            [-1, 128, 8, 8]               0
   ConvTranspose2d-6          [-1, 256, 16, 16]         524,544
       BatchNorm2d-7          [-1, 256, 16, 16]             512
         LeakyReLU-8          [-1, 256, 16, 16]               0
    GeneratorBlock-9          [-1, 256, 16, 16]               0
  ConvTranspose2d-10          [-1, 512, 32, 32]       2,097,664
      BatchNorm2d-11          [-1, 512, 32, 32]           1,024
        LeakyReLU-12          [-1, 512, 32, 32]               0
   GeneratorBlock-13          [-1, 512, 32, 32]               0
  ConvTransp




In [9]:
class Generator(nn.Module):
    def __init__(self, latent_space = 100, out_channels = 1):
        super(Generator, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels * 64

        self.kernel = 4
        self.stride = 1
        self.padding = 0

        self.layers = []

        self.layers.append(
            nn.Sequential(
                nn.ConvTranspose2d(
                    in_channels=latent_space,
                    out_channels=self.out_channels,
                    kernel_size=self.kernel,
                    stride=self.stride,
                    padding=self.padding,
                    bias=False,
                ),
                nn.BatchNorm2d(self.out_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        )
        self.in_channels = self.out_channels

        for idx in tqdm(range(4)):
            self.layers.append(
                GeneratorBlock(
                    in_channels=self.in_channels,
                    out_channels=1 if idx == (4 - 1) else self.in_channels * 2,
                    is_last=True if idx == (4 - 1) else False,
                )
            )
            
            self.in_channels *= 2
            
        self.model = nn.Sequential(*self.layers)

    def forward(self, x):
        if x is not None:
            return self.model(x)

        else:
            raise ValueError("Input is not found".capitalize())
        
    @staticmethod
    def total_params(model = None):
        if model is not None:
            return sum(params.numel() for params in model.parameters())
        
        
if __name__ == "__main__":
    netG = Generator(latent_space = 100, out_channels = 1)
    
    print(netG(torch.randn(4, 100, 1, 1)).size())
    
    print(summary(model = netG, input_size = (100, 1, 1)))
    
    draw_graph(model = netG, input_data = torch.randn(4, 100, 1, 1)).visual_graph

100%|██████████| 4/4 [00:00<00:00, 331.74it/s]

torch.Size([4, 1, 64, 64])
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
   ConvTranspose2d-1             [-1, 64, 4, 4]         102,400
       BatchNorm2d-2             [-1, 64, 4, 4]             128
         LeakyReLU-3             [-1, 64, 4, 4]               0
   ConvTranspose2d-4            [-1, 128, 8, 8]         131,200
       BatchNorm2d-5            [-1, 128, 8, 8]             256
         LeakyReLU-6            [-1, 128, 8, 8]               0
    GeneratorBlock-7            [-1, 128, 8, 8]               0
   ConvTranspose2d-8          [-1, 256, 16, 16]         524,544
       BatchNorm2d-9          [-1, 256, 16, 16]             512
        LeakyReLU-10          [-1, 256, 16, 16]               0
   GeneratorBlock-11          [-1, 256, 16, 16]               0
  ConvTranspose2d-12          [-1, 512, 32, 32]       2,097,664
      BatchNorm2d-13          [-1, 512, 32, 32]           1,024
        Leak




#### Define the loss

In [10]:
class ClassificationLoss(nn.Module):
    def __init__(self):
        super(ClassificationLoss, self).__init__()

        self.name = "adversarial loss".capitalize()

        self.criterion = nn.CrossEntropyLoss()

    def forward(self, predict, actual):
        if (predict is not None) and (actual is not None):
            return self.criterion(predict, actual)

        else:
            raise ValueError("Input is not found".capitalize())


if __name__ == "__main__":
    adversarial_loss = ClassificationLoss()
    
    actual = torch.tensor([0., 1., 0., 1., 0.])
    predict = torch.tensor([1., 0., 1., 0., 1.])
    
    print(adversarial_loss(predict, actual))

tensor(4.6359)


In [11]:
class ValidityLoss(nn.Module):
    def __init__(self):
        super(ValidityLoss, self).__init__()
        
        self.name = "Validity loss".capitalize()
        
        self.criterion = nn.BCELoss()
    def forward(self, predict, actual):
        if (predict is not None) and (actual is not None):
            return self.criterion(predict, actual)
        
        else:
            raise ValueError("Input is not found".capitalize())
        
        
if __name__ == "__main__":
    validity_loss = ValidityLoss()
    
    actual = torch.tensor([1., 0., 1., 1., 1.])
    predict = torch.tensor([1., 0., 1., 0., 1.])
    
    print(validity_loss(predict, actual))

tensor(20.)


#### Helper method

In [55]:
from torch.optim.lr_scheduler import StepLR

def load_dataloader():
    config_files = config()

    if os.path.exists(config_files["path"]["processed_path"]):
        train_dataloader = load(
            os.path.join(config_files["path"]["processed_path"], "train_dataloader.pkl")
        )
        test_dataloader = load(
            os.path.join(config_files["path"]["processed_path"], "test_dataloader.pkl")
        )
        val_dataloader = load(
            os.path.join(config_files["path"]["processed_path"], "val_dataloader.pkl")
        )

        return {
            "train_dataloader": train_dataloader,
            "test_dataloader": test_dataloader,
            "val_dataloader": val_dataloader
        }

    else:
        raise FileNotFoundError("Could not find the dataloader".capitalize())
    
def init_loss():
    adversarial_loss = ClassificationLoss()
    criterion_loss = ValidityLoss()
    
    return {"adversarial_loss": adversarial_loss,"criterion_loss": criterion_loss}
    
def helper(**kwargs):
    lr = kwargs["lr"]
    adam = kwargs["adam"]
    SGD = kwargs["SGD"]
    
    netG = Generator(latent_space= 100 , out_channels=1)
    netD = Discriminator(in_channels = 1, out_channels = 512)
    
    if adam:
        optimizerG = optim.Adam(params=netG.parameters(), lr=lr, betas=(0.5, 0.999))
        optimizerD = optim.Adam(params=netD.parameters(), lr=lr, betas=(0.5, 0.999))
    
    if SGD:
        optimizerG = optim.SGD(params=netG.parameters(), lr=lr, momentum=0.85)
        optimizerD = optim.SGD(params=netD.parameters(), lr=lr, momentum=0.85)
        
    try:
        loss = init_loss()
        dataloader = load_dataloader()
    
    except AttributeError as e:
        print("The exception raised {}".format(e))
        
    finally:
        print("All are extracted...".capitalize())
        
    return {
        "train_dataloader": dataloader["train_dataloader"],
        "test_dataloader": dataloader["test_dataloader"],
        "val_dataloader": dataloader["val_dataloader"],
        "netG": netG,
        "netD": netD,
        "optimizerG": optimizerG,
        "optimizerD": optimizerD,
        "adversarial_loss": loss["adversarial_loss"],
        "criterion_loss": loss["criterion_loss"],
    }


if __name__ == "__main__":
    init = helper(
        lr = 0.0002,
        adam = True,
        SGD = False,
    )

100%|██████████| 4/4 [00:00<00:00, 389.75it/s]
100%|██████████| 4/4 [00:00<00:00, 641.23it/s]

All are extracted...





In [75]:
class Trainer:
    def __init__(
        self,
        epochs=200,
        lr=0.0002,
        latent=100,
        adam=True,
        SGD=False,
        device="mps",
        lr_scheduler=False,
        l1_loss=False,
        l2_loss=False,
        elastic_net=False,
        is_display=False,
        is_weight_init=True,
    ):

        self.epochs = epochs
        self.lr = lr
        self.latent = latent
        self.adam = adam
        self.SGD = SGD
        self.device = device
        self.lr_scheduler = lr_scheduler
        self.l1_loss = l1_loss
        self.l2_loss = l2_loss
        self.elastic = elastic_net
        self.is_display = is_display
        self.weight_init = is_weight_init

        init = helper(
            lr=self.lr,
            adam=self.adam,
            SGD=self.SGD,
        )

        self.train_dataloader = init["train_dataloader"]
        self.test_dataloader = init["test_dataloader"]
        self.val_dataloader = init["val_dataloader"]

        self.netG = init["netG"]
        self.netD = init["netD"]

        if self.weight_init:
            self.netG.apply(weight_init)
            self.netD.apply(weight_init)

        self.netG.to(self.device)
        self.netD.to(self.device)

        self.optimizerG = init["optimizerG"]
        self.optimizerD = init["optimizerD"]

        self.adversarial_loss = init["adversarial_loss"]
        self.criterion_loss = init["criterion_loss"]

        if self.lr_scheduler:
            self.schedulerG = StepLR(self.optimizerG, step_size=10, gamma=0.50)
            self.schedulerD = StepLR(self.optimizerD, step_size=10, gamma=0.50)

        self.config = config()

    def l1(self, model):
        if model is not None:
            return sum(
                torch.norm(params, 1) for params in model.parameters())

        else:
            raise ValueError("Model is not found".capitalize())

    def l2(self, model):
        if model is not None:
            return sum(
                torch.norm(params, 2) for params in model.parameters())
        else:
            raise ValueError("Model is not found".capitalize())

    def elastic_net(self, model):
        l1 = self.l1(model=model)
        l2 = self.l2(model=model)

        return 0.01 * (l1 + l2)

    def saved_checkpoints(self, **kwargs):
        if os.path.exists(self.config["path"]["train_models"]):
            torch.save(
                self.netD.state_dict(), os.path.join(self.config["path"]["train_models"], "model{}.pth".format(kwargs["epoch"]))
            )

        else:
            raise Exception("Cannot be saved the model".capitalize())

    def saved_metrics(self, **kwargs):
        pass

    def saved_train_images(self, **kwargs):
        pass

    def update_discriminator_model(self, **kwargs):
        self.optimizerD.zero_grad()

        fake_clf_pred, fake_validity_pred = self.netD(
            self.netG(kwargs["fake_samples"])
        )
        fake_validity_loss = self.criterion_loss(fake_validity_pred, kwargs["fake_zeros"])
        fake_clf_loss = self.adversarial_loss(fake_clf_pred, kwargs["fake_labels"])

        real_clf_pred, real_validity_pred = self.netD(kwargs["images"])
        real_validity_loss = self.criterion_loss(real_validity_pred, kwargs["real_ones"])
        real_clf_loss = self.adversarial_loss(real_clf_pred, kwargs["labels"])

        total_loss = 0.5 * (
            fake_validity_loss + fake_clf_loss + real_validity_loss + real_clf_loss
        )

        if self.l1_loss:
            total_loss += 0.01 * self.l1(model = self.netD)

        if self.l2_loss:

            total_loss += 0.01 * self.l2(model = self.netD)

        if self.elastic_net:
            total_loss += 0.01 * self.elastic_net(model = self.netD)

        total_loss.backward()
        self.optimizerD.step()

        train_accuracy = accuracy_score(
            torch.argmax(real_clf_pred, dim=1).cpu().detach().numpy(),
            kwargs["labels"].cpu().detach().numpy(),
        )

        return {"loss": total_loss.item(), "accuracy": train_accuracy}

    def update_generator_model(self, **kwargs):
        self.netG.zero_grad()

        _, generated_predict = self.netD(
            self.netG(kwargs["fake_samples"])
        )

        loss = self.criterion_loss(generated_predict, kwargs["real_ones"])

        if self.l1_loss:
            loss += 0.01 * self.l1(model = self.netG)

        if self.l2_loss:
            loss += 0.01 * self.l2(model = self.netG)

        if self.elastic_net:
            loss += 0.01 * self.elastic_net(model = self.netG)

        loss.backward()
        self.optimizerG.step()

        return {"loss": loss.item()}

    def show_progress(self, **kwargs):
        if self.is_display:
            print("Epochs - [{}/{}] - netG_loss: [{:.4f}] - netD_loss: [{:.4f}] - train_acc: [{:.4f}] - test_accu: [{:.4f}]".format(
                kwargs["epoch"],
                self.epochs,
                kwargs["netG_loss"],
                kwargs["netD_loss"],
                kwargs["train_accuracy"],
                kwargs["test_accuracy"],
            ))

        else:
            print(
                "Epoch -[{}/{}] is completed.".format(
                    kwargs["epoch"],
                    self.epochs
                )
            )

    def train(self):
        for epoch in tqdm(range(self.epochs)):
            self.netG_loss = []
            self.netD_loss = []
            self.train_accuracy = []
            self.test_accuracy = []

            for _, (images, labels) in enumerate(self.train_dataloader):
                images = images.to(self.device)
                labels = labels.to(self.device)

                batch = images.size(0)
                fake_samples = torch.randn((batch, self.latent, 1, 1)).to(
                    self.device
                )
                fake_labels = torch.randint(0, 2, (batch, 2)).to(self.device) 
                real_ones = torch.ones((batch, 1)).to(self.device)  
                fake_zeros = torch.zeros((batch, 1)).to(self.device)  

                netD = self.update_discriminator_model(
                    images=images,
                    labels=labels,
                    fake_samples=fake_samples,
                    fake_labels = fake_labels,
                    real_ones=real_ones,
                    fake_zeros=fake_zeros,
                )

                self.netD_loss.append(netD["loss"])
                self.train_accuracy.append(netD["accuracy"])

                netG = self.update_generator_model(
                    images=images,
                    fake_samples=fake_samples,
                    real_ones=real_ones,
                )

                self.netG_loss.append(netG["loss"])

            for _, (images, labels) in enumerate(self.test_dataloader):
                images = images.to(self.device)
                labels = labels.to(self.device)

                clf_pred, _ = self.netD(images)
                clf_labels = torch.argmax(clf_pred, dim=1)
                clf_labels = clf_labels.cpu().detach().numpy()

                self.test_accuracy.append(
                    accuracy_score(clf_labels, labels.cpu().detach().numpy())
                )

            if self.lr_scheduler:
                self.schedulerG.step()
                self.schedulerD.step()

            if self.is_display:  
                self.show_progress(
                    epoch=epoch+1,
                    netG_loss=np.mean(self.netG_loss),
                    netD_loss=np.mean(self.netD_loss),
                    train_accuracy=np.mean(self.train_accuracy),
                    test_accuracy=np.mean(self.test_accuracy)
                )

            self.saved_checkpoints(epoch=epoch+1)

    @staticmethod
    def plot_history():
        pass


if __name__ == "__main__":
    trainer = Trainer(epochs=200, lr = 0.001, is_display=True, lr_scheduler=True)
    trainer.train()

100%|██████████| 4/4 [00:00<00:00, 442.03it/s]
100%|██████████| 4/4 [00:00<00:00, 263.82it/s]


All are extracted...


  0%|          | 1/200 [00:01<05:31,  1.67s/it]

Epochs - [1/200] - netG_loss: [9.2994] - netD_loss: [3.0907] - train_acc: [0.5781] - test_accu: [0.7584]


  1%|          | 2/200 [00:03<05:00,  1.52s/it]

Epochs - [2/200] - netG_loss: [9.0146] - netD_loss: [2.9731] - train_acc: [0.8281] - test_accu: [0.6034]


  2%|▏         | 3/200 [00:04<04:47,  1.46s/it]

Epochs - [3/200] - netG_loss: [9.1927] - netD_loss: [2.6795] - train_acc: [0.8281] - test_accu: [0.6070]


  2%|▏         | 4/200 [00:05<04:42,  1.44s/it]

Epochs - [4/200] - netG_loss: [9.9613] - netD_loss: [2.4929] - train_acc: [0.8594] - test_accu: [0.7392]


  2%|▎         | 5/200 [00:07<04:40,  1.44s/it]

Epochs - [5/200] - netG_loss: [8.7109] - netD_loss: [2.5287] - train_acc: [0.8750] - test_accu: [0.7825]


  3%|▎         | 6/200 [00:08<04:38,  1.43s/it]

Epochs - [6/200] - netG_loss: [6.9274] - netD_loss: [2.5353] - train_acc: [0.8906] - test_accu: [0.7584]


  4%|▎         | 7/200 [00:10<04:35,  1.42s/it]

Epochs - [7/200] - netG_loss: [7.8493] - netD_loss: [2.5373] - train_acc: [0.8438] - test_accu: [0.7512]


  4%|▍         | 8/200 [00:11<04:32,  1.42s/it]

Epochs - [8/200] - netG_loss: [7.4161] - netD_loss: [2.2474] - train_acc: [0.9219] - test_accu: [0.7668]


  4%|▍         | 9/200 [00:12<04:30,  1.42s/it]

Epochs - [9/200] - netG_loss: [8.2207] - netD_loss: [2.0332] - train_acc: [0.9375] - test_accu: [0.7091]


  5%|▌         | 10/200 [00:14<04:29,  1.42s/it]

Epochs - [10/200] - netG_loss: [7.7750] - netD_loss: [2.0795] - train_acc: [0.9688] - test_accu: [0.7512]


  6%|▌         | 11/200 [00:15<04:31,  1.44s/it]

Epochs - [11/200] - netG_loss: [7.4635] - netD_loss: [1.8462] - train_acc: [0.9688] - test_accu: [0.7668]


  6%|▌         | 12/200 [00:17<04:29,  1.43s/it]

Epochs - [12/200] - netG_loss: [7.4239] - netD_loss: [1.8618] - train_acc: [0.9844] - test_accu: [0.8450]


  6%|▋         | 13/200 [00:18<04:27,  1.43s/it]

Epochs - [13/200] - netG_loss: [7.4906] - netD_loss: [1.7470] - train_acc: [0.9844] - test_accu: [0.8053]


  7%|▋         | 14/200 [00:20<04:27,  1.44s/it]

Epochs - [14/200] - netG_loss: [8.2467] - netD_loss: [1.6186] - train_acc: [1.0000] - test_accu: [0.7933]


  8%|▊         | 15/200 [00:21<04:24,  1.43s/it]

Epochs - [15/200] - netG_loss: [8.6118] - netD_loss: [1.5559] - train_acc: [1.0000] - test_accu: [0.8017]


  8%|▊         | 16/200 [00:23<04:21,  1.42s/it]

Epochs - [16/200] - netG_loss: [8.4920] - netD_loss: [1.5197] - train_acc: [1.0000] - test_accu: [0.7897]


  8%|▊         | 17/200 [00:24<04:18,  1.41s/it]

Epochs - [17/200] - netG_loss: [8.4401] - netD_loss: [1.5503] - train_acc: [0.9844] - test_accu: [0.7704]


  9%|▉         | 18/200 [00:25<04:16,  1.41s/it]

Epochs - [18/200] - netG_loss: [8.1658] - netD_loss: [1.5954] - train_acc: [0.9844] - test_accu: [0.7668]


 10%|▉         | 19/200 [00:27<04:14,  1.41s/it]

Epochs - [19/200] - netG_loss: [8.1924] - netD_loss: [1.6200] - train_acc: [0.9844] - test_accu: [0.8053]


 10%|█         | 20/200 [00:28<04:12,  1.40s/it]

Epochs - [20/200] - netG_loss: [7.3764] - netD_loss: [1.5487] - train_acc: [1.0000] - test_accu: [0.8209]


 10%|█         | 21/200 [00:29<04:11,  1.40s/it]

Epochs - [21/200] - netG_loss: [7.1736] - netD_loss: [1.5612] - train_acc: [1.0000] - test_accu: [0.7668]


 11%|█         | 22/200 [00:31<04:12,  1.42s/it]

Epochs - [22/200] - netG_loss: [7.2764] - netD_loss: [1.4159] - train_acc: [1.0000] - test_accu: [0.7825]


 12%|█▏        | 23/200 [00:32<04:09,  1.41s/it]

Epochs - [23/200] - netG_loss: [7.8257] - netD_loss: [1.4104] - train_acc: [1.0000] - test_accu: [0.7356]


 12%|█▏        | 24/200 [00:34<04:07,  1.40s/it]

Epochs - [24/200] - netG_loss: [7.7529] - netD_loss: [1.3756] - train_acc: [1.0000] - test_accu: [0.8137]


 12%|█▎        | 25/200 [00:35<04:05,  1.40s/it]

Epochs - [25/200] - netG_loss: [7.9060] - netD_loss: [1.3798] - train_acc: [1.0000] - test_accu: [0.8017]


 13%|█▎        | 26/200 [00:37<04:04,  1.40s/it]

Epochs - [26/200] - netG_loss: [8.1627] - netD_loss: [1.3345] - train_acc: [1.0000] - test_accu: [0.8053]


 14%|█▎        | 27/200 [00:38<04:05,  1.42s/it]

Epochs - [27/200] - netG_loss: [8.4999] - netD_loss: [1.2812] - train_acc: [1.0000] - test_accu: [0.7740]


 14%|█▍        | 28/200 [00:39<04:02,  1.41s/it]

Epochs - [28/200] - netG_loss: [8.5840] - netD_loss: [1.2442] - train_acc: [1.0000] - test_accu: [0.7933]


 14%|█▍        | 29/200 [00:41<04:00,  1.41s/it]

Epochs - [29/200] - netG_loss: [8.8738] - netD_loss: [1.2196] - train_acc: [1.0000] - test_accu: [0.8017]


 15%|█▌        | 30/200 [00:42<03:59,  1.41s/it]

Epochs - [30/200] - netG_loss: [9.1624] - netD_loss: [1.2408] - train_acc: [1.0000] - test_accu: [0.7825]


 16%|█▌        | 31/200 [00:44<04:03,  1.44s/it]

Epochs - [31/200] - netG_loss: [7.9151] - netD_loss: [1.2136] - train_acc: [1.0000] - test_accu: [0.7861]


 16%|█▌        | 32/200 [00:45<04:02,  1.45s/it]

Epochs - [32/200] - netG_loss: [7.5898] - netD_loss: [1.1743] - train_acc: [1.0000] - test_accu: [0.7861]


 16%|█▋        | 33/200 [00:47<04:06,  1.47s/it]

Epochs - [33/200] - netG_loss: [7.8734] - netD_loss: [1.1619] - train_acc: [1.0000] - test_accu: [0.7704]


 17%|█▋        | 34/200 [00:48<04:08,  1.50s/it]

Epochs - [34/200] - netG_loss: [7.6001] - netD_loss: [1.1352] - train_acc: [1.0000] - test_accu: [0.7897]


 18%|█▊        | 35/200 [00:50<04:08,  1.50s/it]

Epochs - [35/200] - netG_loss: [8.4159] - netD_loss: [1.1248] - train_acc: [1.0000] - test_accu: [0.7861]


 18%|█▊        | 35/200 [00:51<04:02,  1.47s/it]


KeyboardInterrupt: 

In [73]:
dataloader = load(
    filename="/Users/shahmuhammadraditrahman/Desktop/SGAN/data/processed/val_dataloader.pkl"
)

device = device_init(device="mps")
netD = Discriminator().to(device)

state_dict = torch.load(
    "/Users/shahmuhammadraditrahman/Desktop/SGAN/checkpoints/train_models/model200.pth"
)

netD.load_state_dict(state_dict)

real_labels = []
pred_labels = []

for (images, labels) in dataloader:
    images = images.to(device)
    labels = labels.to(device)

    clf_pred, _ = trainer.netD(images)
    clf_labels = torch.argmax(clf_pred, dim=1)
    clf_labels = clf_labels.cpu().detach().flatten().numpy()

    real_labels.extend(labels.cpu().detach().flatten().numpy())
    pred_labels.extend(clf_labels)

100%|██████████| 4/4 [00:00<00:00, 605.06it/s]


In [74]:
accuracy_score(real_labels, pred_labels)

0.7642276422764228