<a href="https://colab.research.google.com/github/aakashpaul-2/computer-vision/blob/main/mnist_simple_GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms, models
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import sys, os
from glob import glob
import imageio

In [None]:
# Preprocessing pipeline: ToTensor yields values in [0, 1]; normalising with
# mean 0.5 / std 0.5 then maps them to [-1, 1], the same range as the
# generator's final Tanh.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

In [None]:
# MNIST training split, downloaded into the current directory when it is not
# already present, with the preprocessing pipeline applied to every image.
train_dataset = datasets.MNIST(
    root=".",
    train=True,
    download=True,
    transform=transform,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!




In [None]:
# Sanity check: number of samples in the training split.
len(train_dataset)

60000

In [None]:
# Mini-batch size used for both networks. The final batch of an epoch may hold
# fewer samples, which the training loop accounts for.
batch_size = 128

# Reshuffle the training set at the start of every epoch.
data_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
#Discriminator

D = nn.Sequential(nn.Linear(784, 512), 
                  nn.LeakyReLU(0.2),
                  nn.Linear(512, 256),
                  nn.LeakyReLU(0.2),
                  nn.Linear(256, 1))

In [None]:
# Generator
#
# Fully connected network that expands a latent vector of size `latent_dim`
# into a flattened 28x28 image (784 values). The closing Tanh bounds every
# output to [-1, 1].
#
# NOTE(review): in PyTorch, BatchNorm `momentum` is the weight given to the
# *current* batch statistic when updating the running estimate, so 0.7 makes
# the running statistics follow recent batches closely - confirm this is the
# intended value.
latent_dim = 100
G = nn.Sequential(
    nn.Linear(in_features=latent_dim, out_features=256),
    nn.LeakyReLU(negative_slope=0.2),
    nn.BatchNorm1d(num_features=256, momentum=0.7),
    nn.Linear(in_features=256, out_features=512),
    nn.LeakyReLU(negative_slope=0.2),
    nn.BatchNorm1d(num_features=512, momentum=0.7),
    nn.Linear(in_features=512, out_features=1024),
    nn.LeakyReLU(negative_slope=0.2),
    nn.BatchNorm1d(num_features=1024, momentum=0.7),
    nn.Linear(in_features=1024, out_features=784),
    nn.Tanh(),
)

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Move both networks onto the selected device.
D, G = D.to(device), G.to(device)

cuda:0


In [None]:
# Binary cross-entropy that takes raw logits (the sigmoid is applied inside
# the loss), matching the discriminator's un-activated output.
criterion = nn.BCEWithLogitsLoss()

# One Adam optimizer per network, both with the same learning rate and betas.
d_optimizer = torch.optim.Adam(
    D.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)
g_optimizer = torch.optim.Adam(
    G.parameters(),
    lr=2e-4,
    betas=(0.5, 0.999),
)

In [None]:
def scale_image(img):
  """Return ``(img + 1) / 2``, which maps values in [-1, 1] onto [0, 1]."""
  return (img + 1) / 2

In [None]:
# Output directory for the sample grids written during training.
# exist_ok=True keeps the cell safely re-runnable and avoids the race between
# a separate existence check and the directory creation.
os.makedirs("gan_images", exist_ok=True)

In [None]:
# Training loop
#
# For every mini-batch: one discriminator update followed by
# `g_steps_per_batch` generator updates. After each epoch the losses of the
# last mini-batch are printed and the last generated batch is saved as an
# image grid in gan_images/<epoch>.png.

n_epochs = 200
g_steps_per_batch = 2  # generator updates per discriminator update

# Full-size target tensors, allocated once directly on the training device and
# sliced to the actual batch size below (the last batch can be smaller).
ones_ = torch.ones(batch_size, 1, device=device)
zeros_ = torch.zeros(batch_size, 1, device=device)

# Per-mini-batch loss history.
d_losses = []
g_losses = []

for epoch in range(n_epochs):
  for inputs, _ in data_loader:
    n = inputs.size(0)
    # Flatten each image to a 784-vector, as expected by D's first layer.
    inputs = inputs.reshape(n, 784).to(device)

    ones = ones_[:n]
    zeros = zeros_[:n]

    ### training discriminator ###

    # real images -> target 1
    real_outputs = D(inputs)
    d_loss_real = criterion(real_outputs, ones)

    # fake images -> target 0
    # Latent vectors come from a standard normal (torch.randn); torch.rand
    # would give non-centred U[0, 1) samples. They are created directly on the
    # device to avoid a host-to-device copy per step.
    noise = torch.randn(n, latent_dim, device=device)
    # detach(): the discriminator update must not backpropagate into G, and
    # skipping that part of the graph saves the corresponding backward work.
    fake_images = G(noise).detach()
    fake_outputs = D(fake_images)
    d_loss_fake = criterion(fake_outputs, zeros)

    # gradient descent step on D only
    d_loss = 0.5 * (d_loss_real + d_loss_fake)
    # Also clears any D gradients accumulated by the previous generator steps.
    d_optimizer.zero_grad()
    d_loss.backward()
    d_optimizer.step()

    ### train generator ###

    for _ in range(g_steps_per_batch):
      # Fresh fake images; this time gradients must flow through D into G.
      noise = torch.randn(n, latent_dim, device=device)
      fake_images = G(noise)
      fake_outputs = D(fake_images)

      # The generator is rewarded when D labels its samples as real.
      g_loss = criterion(fake_outputs, ones)

      g_optimizer.zero_grad()
      g_loss.backward()
      g_optimizer.step()

    # Record the discriminator loss and the last generator loss of this batch.
    d_losses.append(d_loss.item())
    g_losses.append(g_loss.item())

  print(f"epoch: {epoch}, d_loss: {d_loss.item()}, g_loss: {g_loss.item()}")

  # Save the last generated batch of the epoch, rescaled from [-1, 1] to [0, 1].
  fake_images = fake_images.detach().reshape(-1, 1, 28, 28)
  save_image(scale_image(fake_images), f"gan_images/{epoch+1}.png")


epoch: 0, d_loss: 0.6915111541748047, g_loss: 0.7321199178695679
epoch: 1, d_loss: 0.6817969679832458, g_loss: 0.6892649531364441
epoch: 2, d_loss: 0.6906196475028992, g_loss: 0.6711704134941101
epoch: 3, d_loss: 0.683830201625824, g_loss: 0.7287164926528931
epoch: 4, d_loss: 0.6675338745117188, g_loss: 0.759560227394104
epoch: 5, d_loss: 0.6826400756835938, g_loss: 0.7208020091056824
epoch: 6, d_loss: 0.6924548745155334, g_loss: 0.7565855979919434
epoch: 7, d_loss: 0.6957719326019287, g_loss: 0.704592227935791
epoch: 8, d_loss: 0.6866205930709839, g_loss: 0.7063879370689392
epoch: 9, d_loss: 0.674645185470581, g_loss: 0.7176911234855652
epoch: 10, d_loss: 0.6851933598518372, g_loss: 0.713244616985321
epoch: 11, d_loss: 0.6893501281738281, g_loss: 0.7463488578796387
epoch: 12, d_loss: 0.6851441860198975, g_loss: 0.6878588199615479
epoch: 13, d_loss: 0.6765113472938538, g_loss: 0.7466691732406616
epoch: 14, d_loss: 0.680479884147644, g_loss: 0.7275842428207397
epoch: 15, d_loss: 0.67864