<a href="https://colab.research.google.com/github/SarveshD7/Vanilla-GAN-Pytorch/blob/main/GAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
from torch.nn.modules.activation import LeakyReLU

In [9]:
class Discriminator(nn.Module):
  """Fully connected binary classifier over flattened images.

  Maps a batch of flattened images of width ``img_dim`` to one score per
  sample in (0, 1), produced by a final sigmoid.

  Args:
    img_dim: Number of features in one flattened input image.
  """

  def __init__(self, img_dim):
    super().__init__()
    hidden_units = 128
    layers = [
        nn.Linear(img_dim, hidden_units),
        nn.LeakyReLU(0.1),
        nn.Linear(hidden_units, 1),
        nn.Sigmoid(),  # squash the single logit into (0, 1)
    ]
    self.disc = nn.Sequential(*layers)

  def forward(self, x):
    """Return the sigmoid score for every row of ``x`` (shape ``(N, 1)``)."""
    scores = self.disc(x)
    return scores


In [10]:
class Generator(nn.Module):
  """Fully connected generator turning noise vectors into flattened images.

  Args:
    z_dim: Dimension of the input noise vector.
    img_dim: Dimension of the flattened output image
      (for MNIST: 28x28x1 = 784).
  """

  def __init__(self, z_dim, img_dim):
    super().__init__()
    hidden_units = 256
    project_up = nn.Linear(z_dim, hidden_units)
    to_image = nn.Linear(hidden_units, img_dim)
    self.gen = nn.Sequential(
        project_up,
        nn.LeakyReLU(0.1),
        to_image,
        nn.Tanh(),  # bound every output value to [-1, 1]
    )

  def forward(self, x):
    """Return one flattened image of width ``img_dim`` per noise row in ``x``."""
    images = self.gen(x)
    return images

In [11]:
# Seed PyTorch's global RNG first so that weight initialisation, noise
# sampling and DataLoader shuffling repeat on a fresh "Restart & Run All".
seed = 42
torch.manual_seed(seed)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters a reader is likely to tune.
lr = 3e-4            # learning rate shared by both Adam optimizers
z_dim = 64           # size of the generator's noise vector; can also be 128, 256
image_dim = 28*28*1  # flattened MNIST image: 28 x 28 pixels, 1 channel = 784
batch_size = 32      # samples per mini-batch
num_epochs = 50      # full passes over the dataset

In [12]:
# Build the two adversaries (discriminator first, then generator) and place
# each one on the selected device.
disc = Discriminator(image_dim)
disc = disc.to(device)
gen = Generator(z_dim, image_dim)
gen = gen.to(device)

# One fixed batch of noise vectors, sampled once and kept on the device.
fixed_noise = torch.randn(batch_size, z_dim)
fixed_noise = fixed_noise.to(device)

In [24]:
# Preprocessing pipeline: convert each image to a tensor, then normalise its
# single channel with mean 0.5 and std 0.5.
# NOTE(review): this rebinds `transforms`, which the import cell bound to the
# torchvision.transforms module; the fully qualified names below keep the cell
# safe to re-run.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

In [14]:
# MNIST, downloaded under /content/dataset/ if not already present; every
# image goes through the preprocessing pipeline defined above.
dataset = datasets.MNIST(
    root="/content/dataset/",
    transform=transforms,
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 144333046.04it/s]


Extracting /content/dataset/MNIST/raw/train-images-idx3-ubyte.gz to /content/dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 33007001.04it/s]


Extracting /content/dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 37140279.24it/s]

Extracting /content/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/dataset/MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 14860006.84it/s]


Extracting /content/dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/dataset/MNIST/raw



In [15]:
# Shuffled mini-batch iterator over the dataset.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# One Adam optimizer per network, both using the same learning rate.
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# TensorBoard writers (separate run directories for fake and real images)
# and a global step counter.
writer_fake = SummaryWriter("/content/runs/GAN_MNIST/fake")
writer_real = SummaryWriter("/content/runs/GAN_MNIST/real")
step = 0


In [None]:
# Adversarial training loop: for every mini-batch, take one discriminator step
# followed by one generator step.
# NOTE(review): fixed_noise, writer_fake, writer_real and step are created in
# earlier cells but are not used by this loop.
for epoch in range(num_epochs):
  print(f"Started Epoch-> {epoch+1}")
  for real, _ in loader:  # Labels are ignored: only real-vs-fake matters here
    # Flatten each image and move it to the same device as the networks;
    # without the move, disc(real) fails with a device mismatch on CUDA.
    real = real.view(-1, image_dim).to(device)
    # Local name on purpose: do not overwrite the global `batch_size`.
    cur_batch_size = real.shape[0]

    # Training the discriminator ---> Maximize log(D(real)) + log(1 - D(G(z)))
    noise = torch.randn(cur_batch_size, z_dim).to(device)
    fake = gen(noise)
    disc_real = disc(real).view(-1)  # Discriminator scores on real images, flattened
    # BCE: l(x, y) = -[y * log(x) + (1 - y) * log(1 - x)]. With y = 1 only the
    # log(D(real)) term remains: the first component of the discriminator loss.
    lossD_real = criterion(disc_real, torch.ones_like(disc_real))
    # detach() stops the discriminator loss from backpropagating into the
    # generator, so the generator graph is still intact for its own step below
    # and retain_graph is not needed.
    disc_fake = disc(fake.detach()).view(-1)
    # With y = 0 only the log(1 - D(G(z))) term remains: the second component.
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    lossD = (lossD_real + lossD_fake) / 2
    disc.zero_grad()
    lossD.backward()
    opt_disc.step()

    # Training the generator ---> Minimize log(1 - D(G(z))) <---> Maximize log(D(G(z)))
    # Re-score the same fakes with the freshly updated discriminator.
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    opt_gen.step()
  print(f"Completed Epoch-> {epoch+1}")


# ***Important Points to note***
***GANs are sensitive to their hyperparameters, so better results may be obtained by tuning them and experimenting with different settings. Results can also be improved by using CNNs in the Discriminator and the Generator, an architecture known as DCGAN (Deep Convolutional GAN).***
---

