In [1]:
import os
import shutil
import zipfile

import joblib as pkl
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

In [17]:
def pickle(value = None, filename = None):
    """Serialize ``value`` to ``filename`` with joblib (imported as ``pkl``).

    Args:
        value: Object to persist. Anything other than ``None`` is accepted,
            including falsy objects such as ``0``, ``[]`` or ``{}``.
        filename: Destination path handed to ``pkl.dump``; must be non-empty.

    Raises:
        ValueError: If ``value`` is ``None`` or ``filename`` is ``None``/empty.
    """
    # Test against None instead of truthiness: an empty mapping or list is a
    # legitimate payload, and some objects refuse to be evaluated as a boolean.
    if value is None or not filename:
        raise ValueError("value and filename are required".capitalize())

    pkl.dump(value=value, filename=filename)
    
def clean(path = None):
    """Delete every entry inside the directory ``path``; the directory itself is kept.

    Regular files and symbolic links are unlinked, sub-directories are removed
    recursively.

    Args:
        path: Directory whose contents should be removed.

    Raises:
        ValueError: If ``path`` is ``None`` or empty.
        OSError: Propagated from the filesystem calls (for example when
            ``path`` does not exist).
    """
    if not path:
        raise ValueError("path is required".capitalize())

    for entry in os.listdir(path):
        target = os.path.join(path, entry)
        # os.remove() cannot delete directories, so a folder inside ``path``
        # would abort the cleanup. Links are unlinked rather than followed.
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        else:
            os.remove(target)
    
def total_params(model = None):
    """Return the total number of scalar elements across ``model.parameters()``.

    Args:
        model: Object exposing ``parameters()``, whose items expose ``numel()``.

    Returns:
        int: Sum of ``numel()`` over every parameter (``0`` if there are none).

    Raises:
        ValueError: If ``model`` is ``None``.
    """
    # The default is None; fail with a clear message (as the sibling helpers
    # do) instead of an AttributeError on ``None.parameters``.
    if model is None:
        raise ValueError("model is required".capitalize())

    return sum(param.numel() for param in model.parameters())

In [13]:
# Notebook-wide data locations (relative to the notebook's working directory).
to_extract, to_save = (
    "../data/raw/",        # the archive is unpacked here
    "../data/processed/",  # pickled artifacts are written here
)

In [34]:
class Loader:
    """Unpack an image archive and wrap the images in a shuffled ``DataLoader``.

    The archive is extracted into the module-level ``to_extract`` directory;
    the resulting ``DataLoader`` and the class-to-index mapping are pickled
    into the module-level ``to_save`` directory.

    Args:
        image_path: Path of the zip archive holding the images.
        batch_size: Batch size handed to the ``DataLoader``.
        image_size: Side length, in pixels, of the square images produced by
            the transform pipeline.
        normalized: When true, pixel values are rescaled from ``[0, 1]`` to
            ``[-1, 1]``; otherwise they are left in ``[0, 1]``.
    """

    def __init__(self, image_path = None, batch_size = 64, image_size = 64, normalized = True):
        self.image_path = image_path
        self.batch_size = batch_size
        self.image_size = image_size
        self.use_normalized = normalized

    def unzip_images(self):
        """Extract the archive at ``self.image_path`` into ``to_extract``.

        Raises:
            Exception: If the ``to_extract`` directory does not exist.
        """
        # Validate the destination before touching the archive.
        if not os.path.exists(to_extract):
            raise Exception("Extracting images failed".capitalize())

        with zipfile.ZipFile(self.image_path, "r") as zip_ref:
            zip_ref.extractall(to_extract)

    def _normalized(self):
        """Build the per-image transform pipeline.

        Images are converted to a single channel, resized to
        ``image_size x image_size`` and turned into tensors. A transform is
        always returned so the ``DataLoader`` can collate the samples even
        when normalisation is switched off.

        Returns:
            transforms.Compose: The composed pipeline.
        """
        steps = [
            transforms.Grayscale(num_output_channels=1),
            # An explicit (height, width) guarantees square output; a bare int
            # would only fix the shorter edge.
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
        ]

        if self.use_normalized:
            # Single-channel statistics: maps [0, 1] onto [-1, 1].
            steps.append(transforms.Normalize(mean=[0.5], std=[0.5]))

        return transforms.Compose(steps)

    @staticmethod
    def class_to_idx(dataset = None):
        """Return ``dataset.class_to_idx``, or ``None`` when no dataset is given."""
        if dataset is not None:
            return dataset.class_to_idx

    def create_dataloader(self):
        """Create the ``DataLoader`` over ``<to_extract>/Dataset`` and persist it.

        Returns:
            tuple: ``(dataloader, class_to_idx)``.

        Raises:
            Exception: If ``to_extract`` or ``to_save`` does not exist.
        """
        if not os.path.exists(to_extract):
            raise Exception("Extracting images failed from the create dataloader method".capitalize())

        datasets = ImageFolder(root=os.path.join(to_extract, "Dataset"), transform=self._normalized())
        dataloader = DataLoader(datasets, batch_size=self.batch_size, shuffle=True)

        if not os.path.exists(to_save):
            raise Exception("Creating dataloader failed".capitalize())

        # Persisting is best-effort: a failure to pickle is reported but does
        # not prevent the in-memory objects from being returned.
        try:
            pickle(value=dataloader, filename=os.path.join(to_save, "dataloader.pkl"))
            pickle(value=Loader.class_to_idx(dataset=datasets), filename=os.path.join(to_save, "dataset.pkl"))
        except Exception as e:
            print(e)

        return dataloader, datasets.class_to_idx


if __name__ == "__main__":
    # Configure the loader, unpack the archive, then build the DataLoader
    # together with the class-to-index mapping.
    loader = Loader("../Desktop/archive.zip", batch_size=64, image_size=64, normalized=True)

    loader.unzip_images()
    dataloader, labels = loader.create_dataloader()

#### Create the Generator model

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

class Generator(nn.Module):
    """Label-conditioned generator built from transposed convolutions.

    A label index is embedded into a ``latent_space``-dimensional vector,
    concatenated with the noise along the channel axis, and upsampled by five
    ``ConvTranspose2d`` stages (1x1 -> 4 -> 8 -> 16 -> 32 -> 64) ending in
    ``Tanh``.

    Args:
        latent_space: Number of noise channels; also the embedding dimension
            of the labels.
        num_labels: Number of distinct labels in the embedding table.
        image_size: Base channel width; the hidden stages use 8x, 4x, 2x and
            1x this value. The 64x64 output resolution is fixed by the layer
            strides and does not depend on it.
        in_channels: Number of channels of the generated image.
    """

    def __init__(self, latent_space = 50, num_labels = 4, image_size = 64, in_channels = 1):
        # nn.Module has to be initialised before attributes are set on it.
        super(Generator, self).__init__()

        self.latent_space = latent_space
        self.num_labels = num_labels
        self.image_size = image_size
        self.in_channels = in_channels

        # (in_channels, out_channels, kernel, stride, padding, batch_norm, bias)
        self.config_layers = [
            (self.latent_space * 2, self.image_size * 8, 4, 1, 0, True, False),
            (self.image_size * 8, self.image_size * 4, 4, 2, 1, True, False),
            (self.image_size * 4, self.image_size * 2, 4, 2, 1, True, False),
            (self.image_size * 2, self.image_size, 4, 2, 1, True, False),
            (self.image_size, self.in_channels, 4, 2, 1, False, False),
        ]
        self.labels = nn.Embedding(num_embeddings=self.num_labels, embedding_dim=self.latent_space)
        self.model = self.connected_layer(config_layers=self.config_layers)

    def connected_layer(self, config_layers = None):
        """Assemble the upsampling stack described by ``config_layers``.

        Every entry but the last becomes ``ConvTranspose2d`` (+ optional
        ``BatchNorm2d``) + ``ReLU``; the last becomes ``ConvTranspose2d`` +
        ``Tanh``.

        Args:
            config_layers: Sequence of 7-tuples
                ``(in, out, kernel, stride, padding, batch_norm, bias)``.

        Returns:
            nn.Sequential: The stack, with named sub-modules.

        Raises:
            ValueError: If ``config_layers`` is ``None``.
        """
        if config_layers is None:
            raise ValueError("config layer should be defined".capitalize())

        layers = OrderedDict()

        for idx, (in_channels, out_channels, kernel_size, stride, padding, batch_norm, bias) in enumerate(config_layers[:-1]):
            layers[f"ConvTranspose{idx+1}"] = nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=bias)

            if batch_norm:
                layers[f"BatchNorm{idx+1}"] = nn.BatchNorm2d(out_channels)

            layers[f"ReLU{idx+1}"] = nn.ReLU(inplace=True)

        # The batch_norm flag of the final entry is ignored: the output stage
        # is always a bare transposed convolution followed by Tanh.
        in_channels, out_channels, kernel_size, stride, padding, _, bias = config_layers[-1]
        layers["outConvTranspose"] = nn.ConvTranspose2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias)
        layers["outLayer"] = nn.Tanh()

        return nn.Sequential(layers)

    def forward(self, noise, labels):
        """Generate images from ``noise`` conditioned on ``labels``.

        Args:
            noise: Tensor of shape ``(batch, latent_space, 1, 1)`` or
                ``(batch, latent_space)``.
            labels: Integer tensor of label indices, one per sample.

        Returns:
            Tensor of shape ``(batch, in_channels, 64, 64)`` with values in
            ``[-1, 1]``.
        """
        embedded = self.labels(labels)
        embedded = embedded.view(embedded.size(0), self.latent_space, 1, 1)
        # Accept flat noise vectors as well; a no-op for 4-D input.
        noise = noise.reshape(noise.size(0), self.latent_space, 1, 1)
        return self.model(torch.cat((noise, embedded), dim=1))


if __name__ == "__main__":
    # Smoke test: push one random batch through the generator and report the
    # output shape and the parameter count.
    net_G = Generator()
    labels = torch.randint(low=0, high=4, size=(64,))
    noise = torch.randn((64, 50, 1, 1))

    print(net_G(noise, labels).size())
    print(total_params(net_G))

torch.Size([64, 1, 64, 64])
3574856


#### Define the Discriminator model

In [31]:
class Discriminator(nn.Module):
    """Convolutional discriminator with an auxiliary label classifier.

    Four stride-2 convolutions reduce a 64x64 input to 4x4; a final 4x4
    convolution yields ``1 + num_labels`` values per image: one real/fake
    logit followed by ``num_labels`` class logits.

    Args:
        in_channels: Number of channels of the input images.
        num_labels: Number of classes predicted by the auxiliary head.
        image_size: Base channel width; the hidden stages use 1x, 2x, 4x and
            8x this value.
    """

    def __init__(self, in_channels = 1, num_labels = 4, image_size = 64):
        # nn.Module has to be initialised before attributes are set on it.
        super(Discriminator, self).__init__()

        self.in_channels = in_channels
        self.num_labels = num_labels
        self.image_size = image_size

        # Hidden entries: (in, out, kernel, stride, padding, batch_norm);
        # the final entry has no batch_norm flag.
        self.config_layers = [
            (self.in_channels, self.image_size, 4, 2, 1, False),
            (self.image_size, self.image_size*2, 4, 2, 1, True),
            (self.image_size*2, self.image_size*4, 4, 2, 1, True),
            (self.image_size*4, self.image_size*8, 4, 2, 1, True),
            (self.image_size*8, 1 + self.num_labels, 4, 1, 0),
        ]

        self.model = self.connected_layer(config_layers=self.config_layers)

    def connected_layer(self, config_layers = None):
        """Assemble the convolutional stack described by ``config_layers``.

        Every entry but the last becomes ``Conv2d`` (+ optional
        ``BatchNorm2d``) + ``LeakyReLU(0.2)``; the last becomes a bare
        ``Conv2d``. All convolutions are created without bias.

        Args:
            config_layers: Sequence of 6-tuples for the hidden stages followed
                by one 5-tuple for the output stage.

        Returns:
            nn.Sequential: The stack, with named sub-modules.

        Raises:
            ValueError: If ``config_layers`` is ``None``.
        """
        if config_layers is None:
            raise ValueError("config layer should be defined".capitalize())

        layers = OrderedDict()

        for idx, (in_channels, out_channels, kernel_size, stride, padding, batch_norm) in enumerate(config_layers[:-1]):
            layers['conv{}'.format(idx+1)] = nn.Conv2d(
                in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)
            if batch_norm:
                layers[f"batchNorm{idx+1}"] = nn.BatchNorm2d(num_features=out_channels)

            layers[f"leaky_relu{idx+1}"] = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        (in_channels, out_channels, kernel_size, stride, padding) = config_layers[-1]
        layers["out"] = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False)

        return nn.Sequential(layers)

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, 64, 64)``.

        Returns:
            tuple: ``(real_or_fake, label_log_probs)`` where ``real_or_fake``
            has shape ``(batch, 1)`` with sigmoid probabilities and
            ``label_log_probs`` has shape ``(batch, num_labels)`` with
            log-softmax scores over the labels.
        """
        output = self.model(x)
        real_or_fake = output[:, 0:1]
        labels = output[:, 1:]
        # ``dim`` is given explicitly: the implicit choice is deprecated and
        # emits a warning on every call.
        return (
            torch.sigmoid(real_or_fake.view(-1, 1)),
            F.log_softmax(labels.view(-1, self.num_labels), dim=1),
        )
    

if __name__ == "__main__":
    # Smoke test: score one random batch and report the shapes of both heads
    # along with the parameter count.
    net_D = Discriminator()

    noise_data = torch.randn((64, 1, 64, 64))
    real_fake, labels = net_D(noise_data)
    print(real_fake.size(), labels.size())
    print(total_params(model=net_D))

torch.Size([64, 1]) torch.Size([64, 4])
2796288


  return torch.sigmoid(real_or_fake.view(-1, 1)), F.log_softmax(labels.view(-1, self.num_labels))


In [None]:
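'''
Shape notes (batch size 64):

(64, 1, 64, 64) -> batch of real images
(64, 1, 64, 64) -> batch of fake images that the Generator produces from noise

target (Discriminator output per batch):
(64, 5), split into
    real/fake score, 0/1 (sigmoid)       -> (64, 1)
    class scores over 4 labels (softmax) -> (64, 4)
'''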