In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.autograd import Variable
from torchvision.utils import save_image

import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import math
import itertools
from glob import glob

# Generator class
The `Generator` class takes the size of the input (noise) vector, the number of color channels, and the base number of feature maps, and scales the noise up to an image using `ConvTranspose2d` layers. Each successive layer decreases the depth of the feature maps while increasing the spatial resolution.

In [12]:
class Generator(nn.Module):
    """Transposed-convolution generator that maps a noise tensor to an image.

    Args:
        num_ch: Number of channels of the produced image.
        noise_vector: Number of channels of the input noise tensor.
        num_gen_filter: Base number of feature maps; the hidden stages use
            4x, 2x and 1x this width, in that order.

    With a ``(N, noise_vector, 1, 1)`` input the layer geometry below grows the
    spatial size 1 -> 4 -> 7 -> 14 -> 28, so the output is
    ``(N, num_ch, 28, 28)`` with values squashed into [-1, 1] by ``Tanh``.
    """

    def __init__(self, num_ch, noise_vector, num_gen_filter):
        super().__init__()

        # Channel width entering / leaving each hidden up-sampling stage.
        widths = (
            noise_vector,
            num_gen_filter * 4,
            num_gen_filter * 2,
            num_gen_filter,
        )
        # (kernel_size, stride, padding) of each hidden stage, in order.
        geometry = ((4, 1, 0), (3, 2, 1), (4, 2, 1))

        stages = []
        for c_in, c_out, (kernel, stride, pad) in zip(widths, widths[1:], geometry):
            stages.append(
                nn.ConvTranspose2d(
                    c_in, c_out,
                    kernel_size=kernel, stride=stride, padding=pad, bias=False,
                )
            )
            stages.append(nn.BatchNorm2d(c_out))
            stages.append(nn.ReLU(True))

        # Output stage: project to image channels, no normalisation, Tanh range.
        stages.append(
            nn.ConvTranspose2d(
                num_gen_filter, num_ch,
                kernel_size=4, stride=2, padding=1, bias=False,
            )
        )
        stages.append(nn.Tanh())

        # Same layer order as a hand-written Sequential, so state_dict keys
        # ("network.0.weight", ...) stay the same.
        self.network = nn.Sequential(*stages)

    def forward(self, input):
        """Run the noise tensor through the up-sampling stack."""
        return self.network(input)

In [14]:
class Discriminator(nn.Module):
    """Convolutional critic that scores images with a probability per sample.

    Args:
        num_ch: Number of channels of the incoming images.
        num_disc_filter: Base number of feature maps; the hidden stages use
            1x, 2x and 4x this width, in that order.

    For 28x28 inputs the strided convolutions shrink the spatial size
    28 -> 14 -> 7 -> 4 -> 1, and ``forward`` returns a flat ``(N,)`` tensor of
    sigmoid outputs.
    """

    @staticmethod
    def _conv(c_in, c_out, kernel, stride, pad):
        """Bias-free 2-D convolution used by every stage."""
        return nn.Conv2d(
            c_in, c_out,
            kernel_size=kernel, stride=stride, padding=pad, bias=False,
        )

    @staticmethod
    def _leaky():
        """In-place LeakyReLU with negative slope 0.2."""
        return nn.LeakyReLU(0.2, inplace=True)

    def __init__(self, num_ch, num_disc_filter):
        super().__init__()
        base = num_disc_filter
        double, quad = base * 2, base * 4

        self.network = nn.Sequential(
            # Stage 1: no batch norm on the raw image.
            self._conv(num_ch, base, 4, 2, 1),
            self._leaky(),
            # Stage 2
            self._conv(base, double, 4, 2, 1),
            nn.BatchNorm2d(double),
            self._leaky(),
            # Stage 3
            self._conv(double, quad, 3, 2, 1),
            nn.BatchNorm2d(quad),
            self._leaky(),
            # Head: collapse the remaining map to a single score.
            self._conv(quad, 1, 4, 1, 0),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Score a batch of images; returns one value per flattened score."""
        scores = self.network(input)
        column = scores.view(-1, 1)
        return column.squeeze(1)

In [None]:
class DCGAN:
    """Training harness that pairs the notebook's Generator and Discriminator.

    Args:
        noise: Number of channels of the noise tensor fed to the generator.
        num_ch: Number of image channels (default 3, i.e. RGB).
        num_gen_filter: Base feature-map width passed to ``Generator``.
        num_disc_filter: Base feature-map width passed to ``Discriminator``.

    The three trailing arguments are new and defaulted. The previous
    constructor called ``Generator``/``Discriminator`` without their required
    arguments and therefore could not be instantiated. The default width of 64
    is a conventional choice, not a value taken from the original code.
    """

    # Discriminator target for real images (smoothed below 1.0).
    REAL_LABEL = 0.8
    # Discriminator target for generated images.
    FAKE_LABEL = 0.1

    def __init__(self, noise, num_ch=3, num_gen_filter=64, num_disc_filter=64):
        # ``input`` is kept for backward compatibility; ``input_size`` is the
        # attribute that ``noise()`` actually reads.
        self.input = noise
        self.input_size = noise

        self.G = Generator(num_ch, noise, num_gen_filter)
        self.D = Discriminator(num_ch, num_disc_filter)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.D.to(self.device)
        self.G.to(self.device)
        self.D.apply(self.weights_init)
        self.G.apply(self.weights_init)

    def weights_init(self, m):
        """Initialise conv weights ~ N(0, 0.02) and batch-norm weights ~ N(1, 0.02).

        Meant to be passed to ``Module.apply``. Layers are matched by class
        name; layers without a weight (or bias) tensor are left untouched.
        """
        classname = m.__class__.__name__
        weight = getattr(m, "weight", None)
        if "Conv" in classname:
            if weight is not None:
                nn.init.normal_(weight.data, 0.0, 0.02)
        elif "BatchNorm" in classname:
            if weight is not None:
                nn.init.normal_(weight.data, 1.0, 0.02)
            bias = getattr(m, "bias", None)
            if bias is not None:
                nn.init.constant_(bias.data, 0)

    def _bce_to_target(self, dis_output, target):
        """Binary cross-entropy of every discriminator score against ``target``."""
        # reshape(-1) keeps a batch of one as shape (1,). A bare squeeze()
        # would collapse it to a 0-d tensor that no longer matches the labels.
        scores = dis_output.reshape(-1)
        labels = torch.full_like(scores, target)
        return F.binary_cross_entropy(scores, labels)

    def dis_loss(self, dis_output, real=True):
        """Discriminator loss for one batch of scores.

        Args:
            dis_output: Discriminator probabilities for the batch.
            real: ``True`` (default) when the scores belong to real images,
                ``False`` when they belong to generated images.

        Returns:
            Scalar BCE loss against ``REAL_LABEL`` or ``FAKE_LABEL``.
        """
        return self._bce_to_target(
            dis_output, self.REAL_LABEL if real else self.FAKE_LABEL
        )

    def gen_loss(self, dis_output):
        """Generator loss: BCE of the scores on generated images against the
        *real* label, so the generator is rewarded when it fools the
        discriminator.
        """
        return self._bce_to_target(dis_output, self.REAL_LABEL)

    def noise(self, batch_size):
        """Sample a ``(batch_size, input_size, 1, 1)`` standard-normal tensor."""
        return torch.randn(batch_size, self.input_size, 1, 1, device=self.device)

    def train_generator(self, batch_size, gen_optimizer):
        """Run one generator update and return its loss as a Python float."""
        gen_optimizer.zero_grad()

        fake_images = self.G(self.noise(batch_size))
        gen_loss = self.gen_loss(self.D(fake_images))

        gen_loss.backward()
        gen_optimizer.step()

        return gen_loss.item()

    def train_discriminator(self, batch_size, dis_optimizer, real_images):
        """Run one discriminator update on a real batch plus a generated batch.

        Returns:
            Sum of the real-batch and fake-batch losses as a Python float.
        """
        dis_optimizer.zero_grad()

        real_loss = self.dis_loss(self.D(real_images), real=True)

        # The generator is not updated in this step, so build the fake batch
        # without an autograd graph (saves memory and a backward pass through G).
        with torch.no_grad():
            fake_images = self.G(self.noise(batch_size))
        fake_loss = self.dis_loss(self.D(fake_images), real=False)

        dis_loss = real_loss + fake_loss

        dis_loss.backward()
        dis_optimizer.step()

        return dis_loss.item()

    def show(self, tensor, num=25, wandbactive=0, name=''):
        """Plot up to ``num`` images from a ``(N, C, H, W)`` tensor in a square grid.

        Pixel values are assumed to lie in [-1, 1] (the generator ends in
        ``Tanh``) and are rescaled to 0..255. Single-channel images are drawn
        in grayscale.

        Args:
            tensor: Batch of images, channels first.
            num: Maximum number of images to draw; the grid side is
                ``ceil(sqrt(num))`` (5x5 for the default 25).
            wandbactive: Unused; kept for call compatibility.
            name: Unused; kept for call compatibility.
        """
        num = max(int(num), 1)
        data = tensor.detach().cpu()[:num]
        side = math.ceil(math.sqrt(num))

        fig, axes = plt.subplots(
            figsize=(5, 5), nrows=side, ncols=side,
            sharey=True, sharex=True, squeeze=False,
        )
        plt.subplots_adjust(wspace=0, hspace=0, left=0, right=1, bottom=0, top=1)

        cells = axes.flatten()
        for ax in cells:
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

        for ax, img in zip(cells, data):
            # (C, H, W) -> (H, W, C), then map [-1, 1] to 0..255.
            pixels = np.transpose(img.numpy(), (1, 2, 0))
            pixels = np.clip((pixels + 1) * 255 / 2, 0, 255).astype(np.uint8)
            if pixels.shape[2] == 1:
                ax.imshow(pixels[:, :, 0], cmap="gray", vmin=0, vmax=255)
            else:
                ax.imshow(pixels)

        plt.show()

    def show_image_sample(self, noise):
        """Generate images from ``noise`` in eval mode and plot them."""
        was_training = self.G.training
        self.G.eval()
        try:
            with torch.no_grad():
                self.show(self.G(noise))
        finally:
            # Put the generator back in whatever mode it was in before.
            self.G.train(was_training)

    def save_model(self, out_dir='/kaggle/working'):
        """Save both state dicts as ``Generator.pth`` / ``Discriminator.pth``.

        Args:
            out_dir: Target directory (created if missing). The default keeps
                the original Kaggle output location.
        """
        import os

        os.makedirs(out_dir, exist_ok=True)
        torch.save(self.G.state_dict(), os.path.join(out_dir, 'Generator.pth'))
        torch.save(self.D.state_dict(), os.path.join(out_dir, 'Discriminator.pth'))

    def plot_weight_distribution(self, model, model_name):
        """Plot one histogram per parameter whose name contains ``"weight"``."""
        weights = [
            (name, param)
            for name, param in model.named_parameters()
            if "weight" in name
        ]
        if not weights:
            return

        # Three columns; at least three rows so small models keep the 3x3
        # layout, more rows when a model has more than nine weight tensors.
        ncols = 3
        nrows = max(3, math.ceil(len(weights) / ncols))

        plt.figure(figsize=(15, 10 * nrows / 3))
        plt.suptitle(f'Weight Distribution for {model_name}', fontsize=16, y=1.05)

        for position, (name, param) in enumerate(weights, start=1):
            plt.subplot(nrows, ncols, position)
            plt.hist(param.detach().cpu().numpy().flatten(), bins=50)
            plt.title(f"Weight Distribution for {name}")
            plt.xlabel("Weight Value")
            plt.ylabel("Frequency")

        plt.tight_layout()
        plt.show()

    def train(self, batch_size, epochs, gen_optimizer, dis_optimizer, dataloader,
              save_after_epoch=165, sample_every=10):
        """Alternate discriminator and generator updates over ``dataloader``.

        Args:
            batch_size: Size of the generated batches (also used for samples).
            epochs: Number of passes over ``dataloader``.
            gen_optimizer: Optimizer over the generator parameters.
            dis_optimizer: Optimizer over the discriminator parameters.
            dataloader: Iterable yielding ``(images, labels)`` pairs; labels
                are ignored.
            save_after_epoch: First zero-based epoch index at which a new best
                generator loss triggers ``save_model`` (default 165, as before).
            sample_every: Show generated samples every this many epochs;
                a falsy value disables sampling.

        Returns:
            List with one ``(discriminator_loss, generator_loss)`` tuple per
            epoch. Each value is the mean over that epoch's batches.
        """
        losses = []
        # float("inf") instead of np.Inf, which NumPy 2.0 removed.
        dis_loss = gen_loss = min_gen_loss = float("inf")

        for epoch in range(epochs):
            dis_total, gen_total, n_batches = 0.0, 0.0, 0
            for real_images, _ in dataloader:
                real_images = real_images.to(self.device)
                dis_total += self.train_discriminator(batch_size, dis_optimizer, real_images)
                gen_total += self.train_generator(batch_size, gen_optimizer)
                n_batches += 1

            # With an empty dataloader the previous values are carried over.
            if n_batches:
                dis_loss = dis_total / n_batches
                gen_loss = gen_total / n_batches

            if gen_loss < min_gen_loss and epoch >= save_after_epoch:
                min_gen_loss = gen_loss
                self.save_model()

            losses.append((dis_loss, gen_loss))

            print(f"Epoch {epoch + 1}/{epochs}, "
                  f"Discriminator Loss: {dis_loss:.4f}, Generator Loss: {gen_loss:.4f}")
            if sample_every and (epoch + 1) % sample_every == 0:
                self.show_image_sample(self.noise(batch_size))

        self.plot_weight_distribution(self.G, "Generator")
        self.plot_weight_distribution(self.D, "Discriminator")

        return losses
        