In [1]:
# import the libraries
import torch
from torch import nn
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchviz import make_dot
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

In [2]:
# run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [None]:
# uncomment and run the commands below if an IProgress / ipywidgets progress-bar error appears
#!pip install ipywidgets
#!jupyter nbextension enable --py widgetsnbextension

In [32]:
# generator: maps a noise vector to a flattened image
class Generator(nn.Module):
    """Fully-connected generator.

    Four Linear -> BatchNorm1d -> ReLU stages whose widths are
    hidden_dim, 2*hidden_dim, 4*hidden_dim and 8*hidden_dim, followed by a
    Linear projection to img_dim and a Sigmoid, so every output value lies
    in [0, 1].

    Args:
        z_dim: size of the input noise vector.
        img_dim: number of values in one generated (flattened) image.
        hidden_dim: width of the first hidden stage.
    """

    def __init__(self,z_dim = 10, img_dim = 784 , hidden_dim = 128):
        super().__init__()
        # widths of the four hidden stages, each twice the previous one
        width1 = hidden_dim
        width2 = hidden_dim*2
        width3 = hidden_dim*4
        width4 = hidden_dim*8

        self.linear1 = nn.Linear(z_dim, width1)
        self.batch1 = nn.BatchNorm1d(width1)

        self.linear2 = nn.Linear(width1, width2)
        self.batch2 = nn.BatchNorm1d(width2)

        self.linear3 = nn.Linear(width2, width3)
        self.batch3 = nn.BatchNorm1d(width3)

        self.linear4 = nn.Linear(width3, width4)
        self.batch4 = nn.BatchNorm1d(width4)

        self.output = nn.Linear(width4, img_dim)
        # a single ReLU module is shared by all hidden stages
        self.activation = nn.ReLU(inplace=True)
        self.out_act = nn.Sigmoid()

    def forward(self,alpha):
        """Return generated images of shape (batch, img_dim) for noise `alpha` of shape (batch, z_dim)."""
        hidden_stages = (
            (self.linear1, self.batch1),
            (self.linear2, self.batch2),
            (self.linear3, self.batch3),
            (self.linear4, self.batch4),
        )
        hidden = alpha
        for linear, batch_norm in hidden_stages:
            hidden = self.activation(batch_norm(linear(hidden)))

        logits = self.output(hidden)
        return self.out_act(logits)


In [35]:
# discriminator: scores a flattened image as real or generated
class Discriminator(nn.Module):
    """Fully-connected discriminator.

    Three Linear -> LeakyReLU(0.2) stages of widths 4*hidden_dim,
    2*hidden_dim and hidden_dim, followed by a Linear layer that emits one
    raw score (logit) per input row; no sigmoid is applied here.

    Args:
        img_dim: number of values in one flattened input image.
        hidden_dim: width of the last hidden stage.
    """

    def __init__(self, img_dim = 784 , hidden_dim = 128):
        super().__init__()
        # the hidden widths shrink by half at every stage
        wide = hidden_dim*4
        medium = hidden_dim*2
        narrow = hidden_dim

        self.linear1 = nn.Linear(img_dim, wide)
        self.linear2 = nn.Linear(wide, medium)
        self.linear3 = nn.Linear(medium, narrow)
        self.output = nn.Linear(narrow, 1)
        self.activation = nn.LeakyReLU(0.2)

    def forward(self,alpha):
        """Return one logit per row of `alpha` (shape (batch, img_dim)) as a (batch, 1) tensor."""
        hidden = alpha
        for linear in (self.linear1, self.linear2, self.linear3):
            hidden = self.activation(linear(hidden))
        return self.output(hidden)

In [36]:
# global hyperparameters shared by the cells below
# -- data --
img_dim    = 28 * 28   # length of one flattened image (784)
batch_size = 128       # samples per mini-batch

# -- model --
noise_dim  = 64        # size of the noise vector fed to the generator

# -- optimisation --
n_epochs   = 100       # full passes over the dataloader
lr         = 1e-4      # learning rate used by both optimizers

In [None]:
# MNIST is downloaded into the current directory if it is not already there;
# ToTensor converts each image to a float tensor, and the loader serves
# shuffled mini-batches of `batch_size` samples
dataloader = DataLoader(
    dataset=MNIST(root='.', download=True, transform=transforms.ToTensor()),
    batch_size=batch_size,
    shuffle=True,
)

In [37]:
# generator and discriminator networks, moved to the selected device.
# img_dim is passed explicitly so that both networks follow the global setting
# instead of silently relying on their own default value.
gen_network = Generator(z_dim=noise_dim, img_dim=img_dim).to(device)
disc_network = Discriminator(img_dim=img_dim).to(device)

# one Adam optimizer per network, both using the global learning rate
gen_opt  = torch.optim.Adam(gen_network.parameters(),lr = lr)
disc_opt = torch.optim.Adam(disc_network.parameters(),lr = lr)

# loss function: the discriminator returns raw logits (its forward applies no
# sigmoid), so the sigmoid is folded into the loss via BCEWithLogitsLoss
criterion = nn.BCEWithLogitsLoss()

In [38]:
# render the autograd graph of each network to a PNG file with torchviz;
# a random batch is pushed through each model only to obtain an output tensor to trace

# generator graph -> gen_torchviz1.png
temp_inp = torch.rand(batch_size, noise_dim, device=device)
y_hat = gen_network(temp_inp)
make_dot(
    y_hat,
    params=dict(gen_network.named_parameters()),
).render("gen_torchviz1", format="png")

# discriminator graph -> disc_torchviz1.png
temp_inp = torch.rand(batch_size, img_dim, device=device)
y_hat = disc_network(temp_inp)
make_dot(
    y_hat,
    params=dict(disc_network.named_parameters()),
).render("disc_torchviz1", format="png")

'disc_torchviz1.png'

In [39]:
def disp_imgs(img_out):
    """Plot up to the first 25 images of a batch in a grid with 5 images per row.

    img_out: tensor whose values can be reshaped to (-1, 1, 28, 28); it is
        detached from the autograd graph and copied to the CPU before plotting.
    """
    images = img_out.detach().cpu().view(-1,1,28,28)
    first_images = images[:25]
    grid = make_grid(first_images, nrow=5)
    # move the first axis to the end so the layout is (height, width, channels) for imshow
    grid_for_plot = grid.permute(1,2,0).squeeze()
    plt.imshow(grid_for_plot)
    plt.show()

# training loop
# Alternates one discriminator update and one generator update per mini-batch.
# Every `display_step` steps it prints the mean losses accumulated since step 1
# and shows a grid of generated images followed by a grid of real images.
display_step  = 500  # number of steps between progress reports
gen_meanloss  = 0    # running sum of generator losses (divided by cur_step when printed)
disc_meanloss = 0    # running sum of discriminator losses
cur_step  = 0
for epoch in range(n_epochs):
    for real_imgs,_ in dataloader:
        cur_step+=1
        # the last batch of an epoch can be smaller than batch_size
        cur_batch_size = len(real_imgs)
        # flatten every image into a vector and move the batch to the device
        real_imgs = real_imgs.view(cur_batch_size,-1).to(device)

        # training the discriminator
        disc_opt.zero_grad()
        gen_noise  = torch.randn(cur_batch_size,noise_dim,device=device)
        # the generator is not updated in this step, so its forward pass does
        # not need an autograd graph
        with torch.no_grad():
            gen_images = gen_network(gen_noise)
        disc_out1   = disc_network(gen_images)
        disc_loss  = criterion(disc_out1,torch.zeros_like(disc_out1))  # fake -> 0
        disc_out2   = disc_network(real_imgs)
        disc_loss += criterion(disc_out2,torch.ones_like(disc_out2))   # real -> 1
        disc_loss /= 2
        # the generator step below runs its own forward pass, so this graph
        # does not have to be retained
        disc_loss.backward()
        disc_opt.step()

        # training the generator
        gen_opt.zero_grad()
        gen_noise  = torch.randn(cur_batch_size,noise_dim,device=device)
        gen_images = gen_network(gen_noise)
        disc_out   = disc_network(gen_images)
        # the generator is rewarded when the discriminator labels its images as real
        gen_loss  = criterion(disc_out,torch.ones_like(disc_out))
        # this also accumulates gradients in the discriminator parameters; they
        # are cleared by disc_opt.zero_grad() at the start of the next iteration
        gen_loss.backward()
        gen_opt.step()

        gen_meanloss  += gen_loss.item()
        disc_meanloss += disc_loss.item()
        if cur_step%display_step == 0:
            print(" Current Step: "+str(cur_step))
            print(" Generator loss     : "+str(gen_meanloss/cur_step))
            print(" Discriminator loss : "+str(disc_meanloss/cur_step))
            gen_noise  = torch.randn(cur_batch_size,noise_dim,device=device)
            # preview only: no gradients are needed for the sample images
            with torch.no_grad():
                img_out = gen_network(gen_noise)
            disp_imgs(img_out=img_out)
            disp_imgs(img_out=real_imgs)