In [1]:
import torch 
from torch import optim,nn
from torchvision import datasets,transforms as T
from torch.utils.data import DataLoader
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from tqdm.auto import tqdm

In [2]:
# MNIST training split, converted to tensors; batches are reshuffled every epoch.
data = datasets.MNIST(
    root='MNIST_data/',
    train=True,
    download=True,
    transform=T.ToTensor(),
)
trainloader = DataLoader(dataset=data, batch_size=128, shuffle=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz



HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw
Processing...
Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [3]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5):
    """Plot the first ``num_images`` images of a batch as a grid and show it.

    Args:
        image_tensor: batch whose elements can be reshaped to ``size``. It is
            detached and copied to the CPU first, so tensors that require grad
            or live on another device can be passed directly.
        num_images: how many images from the front of the batch to draw.
        size: (channels, height, width) each image is reshaped to.
        nrow: images per grid row, forwarded to ``make_grid``; the default 5
            gives a 5x5 grid for the default ``num_images``.
    """
    images = image_tensor.detach().cpu().view(-1, *size)
    grid = make_grid(images[:num_images], nrow=nrow)
    # make_grid returns a channels-first (C, H, W) tensor; imshow wants channels last.
    plt.imshow(grid.permute(1, 2, 0).squeeze())
    plt.show()

In [4]:
class generator(nn.Module):
  """Fully connected generator mapping noise rows to flattened images.

  Four hidden stages, each Linear -> LeakyReLU(0.2) -> BatchNorm1d, with widths
  hidden_dim, 2*hidden_dim, 4*hidden_dim and 8*hidden_dim. A final Linear to
  output_dim is passed through a sigmoid, so every output value is in [0, 1].
  The default output_dim of 784 equals one flattened 28x28 image.
  """

  def __init__(self, input_dim=64, hidden_dim=128, output_dim=784):
    super(generator, self).__init__()

    # Linear layers are registered before the norms, so the printed module
    # lists l1..l5 and then b1..b4.
    self.l1 = nn.Linear(input_dim, hidden_dim)
    self.l2 = nn.Linear(hidden_dim, 2 * hidden_dim)
    self.l3 = nn.Linear(2 * hidden_dim, 4 * hidden_dim)
    self.l4 = nn.Linear(4 * hidden_dim, 8 * hidden_dim)
    self.l5 = nn.Linear(8 * hidden_dim, output_dim)

    # One BatchNorm1d per hidden stage, sized to that stage's width.
    self.b1 = nn.BatchNorm1d(hidden_dim)
    self.b2 = nn.BatchNorm1d(2 * hidden_dim)
    self.b3 = nn.BatchNorm1d(4 * hidden_dim)
    self.b4 = nn.BatchNorm1d(8 * hidden_dim)

  def forward(self, x):
    """Map a batch of noise rows to a batch of outputs in [0, 1]."""
    stages = (
      (self.l1, self.b1),
      (self.l2, self.b2),
      (self.l3, self.b3),
      (self.l4, self.b4),
    )
    hidden = x
    for linear, norm in stages:
      # Normalization is applied after the activation in this network.
      hidden = norm(F.leaky_relu(linear(hidden), 0.2))
    return torch.sigmoid(self.l5(hidden))




In [5]:
class discriminator(nn.Module):
  """Fully connected discriminator that scores each flattened input.

  Three hidden stages, each Linear -> LeakyReLU(0.2) -> Dropout(0.3), narrowing
  from 4*hidden_dim to 2*hidden_dim to hidden_dim. A final Linear to output_dim
  has no activation, so the outputs are raw scores (logits) suitable for
  nn.BCEWithLogitsLoss. Dropout is active only in training mode.
  """

  def __init__(self, input_dim=784, hidden_dim=128, output_dim=1):
    super(discriminator, self).__init__()
    self.fc1 = nn.Linear(input_dim, hidden_dim*4)
    self.fc2 = nn.Linear(hidden_dim*4, hidden_dim*2)
    self.fc3 = nn.Linear(hidden_dim*2, hidden_dim)
    self.fc4 = nn.Linear(hidden_dim, output_dim)

    # Shared dropout applied after every hidden activation in forward().
    self.dropout = nn.Dropout(0.3)

  def forward(self, x):
    """Return one row of output_dim raw scores per input sample.

    Any input whose elements can be regrouped into rows of ``input_dim``
    values is accepted, e.g. (N, 1, 28, 28) images or (N, 784) rows for the
    default size.
    """
    # Flatten using the configured input size rather than a hard-coded 28*28,
    # so non-default input_dim values work.
    x = x.view(-1, self.fc1.in_features)

    a1 = self.dropout(F.leaky_relu(self.fc1(x), 0.2))
    a2 = self.dropout(F.leaky_relu(self.fc2(a1), 0.2))
    a3 = self.dropout(F.leaky_relu(self.fc3(a2), 0.2))

    # No activation: the loss function applies the sigmoid itself.
    return self.fc4(a3)

In [6]:
# Build both networks at their default sizes, each with its own Adam optimizer.
Gen, Disc = generator(), discriminator()
Gen_opt = optim.Adam(Gen.parameters(), lr=1e-4)
Disc_opt = optim.Adam(Disc.parameters(), lr=1e-4)
print(Gen)
print(Disc)

generator(
  (l1): Linear(in_features=64, out_features=128, bias=True)
  (l2): Linear(in_features=128, out_features=256, bias=True)
  (l3): Linear(in_features=256, out_features=512, bias=True)
  (l4): Linear(in_features=512, out_features=1024, bias=True)
  (l5): Linear(in_features=1024, out_features=784, bias=True)
  (b1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (b2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (b3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (b4): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)
discriminator(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (fc3): Linear(in_features=256, out_features=128, bias=True)
  (fc4): Linear(in_features=128, out_features=1, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)


In [7]:
def get_noise(n_sample, z_dim, device=None):
  """Sample an (n_sample, z_dim) batch of standard-normal noise vectors.

  Args:
    n_sample: number of noise vectors (rows).
    z_dim: length of each vector; must match the generator's input size.
    device: where to allocate the tensor. None (the default) uses torch's
      default device, the same as before; pass e.g. 'cuda' when training
      on a GPU.

  Returns:
    A float tensor of shape (n_sample, z_dim) drawn from N(0, 1).
  """
  return torch.randn(n_sample, z_dim, device=device)

In [8]:
def get_gen_loss(Gen, Disc, n_sample, z_dim, criterion):
  """Generator loss: how well a fresh batch of fakes fools the discriminator.

  Samples n_sample noise vectors of length z_dim, passes them through Gen and
  then Disc, and scores Disc's predictions against an all-ones ("real") target
  with criterion. Nothing is detached, so backward() reaches Gen through Disc.
  """
  predictions = Disc(Gen(get_noise(n_sample, z_dim)))
  # Gen succeeds when Disc labels its output as real (target 1).
  return criterion(predictions, torch.ones_like(predictions))

In [9]:
def get_disc_loss(Gen, Disc, n_sample, z_dim, criterion, real_img):
  """Discriminator loss: mean of its losses on a fake batch and a real batch.

  Generates n_sample fakes from fresh noise of length z_dim. Disc's predictions
  on the fakes are scored against zeros, and its predictions on real_img
  against ones, both with criterion. The two losses are averaged.
  """
  fakes = Gen(get_noise(n_sample, z_dim))

  # detach() keeps this loss from backpropagating into the generator.
  fake_scores = Disc(fakes.detach())
  fake_term = criterion(fake_scores, torch.zeros_like(fake_scores))

  real_scores = Disc(real_img)
  real_term = criterion(real_scores, torch.ones_like(real_scores))

  return (fake_term + real_term) / 2

In [10]:
# Training configuration.
criterion = nn.BCEWithLogitsLoss()  # the discriminator outputs raw scores; the sigmoid lives in this loss
z_dim = 64           # noise length; must match the generator's input_dim
n_epochs = 200
display_step = 500   # report mean losses and show samples every this many batches


In [11]:
# GAN training: for each batch, one discriminator update followed by one
# generator update, with periodic loss reports and sample images.
cur_step = 0
mean_generator_loss = 0
mean_discriminator_loss = 0

for epoch in range(n_epochs):
  for real, _ in tqdm(trainloader):
    n_sample = len(real)
    # One flattened row per image, as fed to the discriminator.
    real_img = real.view(n_sample, -1)

    # Discriminator step. get_disc_loss detaches the fakes, and get_gen_loss
    # builds its own graph below, so nothing reuses this graph. A plain
    # backward() lets it be freed; retain_graph=True only held on to memory.
    Disc_opt.zero_grad()
    disc_loss = get_disc_loss(Gen, Disc, n_sample, z_dim, criterion, real_img)
    disc_loss.backward()
    Disc_opt.step()

    # Generator step. This backward also writes gradients into Disc's
    # parameters; Disc_opt.zero_grad() clears them before the next Disc step.
    Gen_opt.zero_grad()
    gen_loss = get_gen_loss(Gen, Disc, n_sample, z_dim, criterion)
    gen_loss.backward()
    Gen_opt.step()

    # Running means over the current display window.
    mean_discriminator_loss += disc_loss.item() / display_step
    mean_generator_loss += gen_loss.item() / display_step

    # Count this step before checking, so every window averages exactly
    # display_step losses. Checking first would fold display_step + 1 losses
    # into the first window.
    cur_step += 1

    ### Visualization code ###
    if cur_step % display_step == 0:
      print(f"Step {cur_step}: Generator loss: {mean_generator_loss}, discriminator loss: {mean_discriminator_loss}")
      # Sampling is for display only; no_grad avoids building an autograd graph.
      with torch.no_grad():
        fake_noise = get_noise(n_sample, z_dim)
        fake = Gen(fake_noise)
      show_tensor_images(fake)
      show_tensor_images(real)
      mean_generator_loss = 0
      mean_discriminator_loss = 0

  

Output hidden; open in https://colab.research.google.com to view.