<a href="https://colab.research.google.com/github/TridentifyIshaan/SeasonofAI/blob/mainstream/First_Simple_Gan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Importing Libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as T
from torch.utils.tensorboard import SummaryWriter

# Discriminator

In [None]:
class Discriminator(nn.Module):
    """Real/fake classifier over flattened images.

    A one-hidden-layer MLP: Linear(img_dim -> 128), LeakyReLU(0.1),
    Linear(128 -> 1), Sigmoid. The final Sigmoid turns the single output
    unit into a probability in (0, 1).

    Args:
        img_dim: number of input features per sample (the flattened image size).
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(img_dim, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        # The attribute name `disc` is part of the state_dict keys ("disc.0.weight", ...),
        # so it must stay stable for saved checkpoints to load.
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the predicted probability that each sample in ``x`` is real."""
        scores = self.disc(x)
        return scores

# Generator

In [None]:
class Generator(nn.Module):
    """Map latent noise vectors to flattened images.

    A one-hidden-layer MLP: Linear(z_dim -> 256), LeakyReLU(0.1),
    Linear(256 -> img_dim), Tanh. Because of the final Tanh, every output
    value lies in (-1, 1), so real training images must be scaled to the
    same range for the discriminator to compare like with like.

    Args:
        z_dim: size of each latent noise vector.
        img_dim: number of output features per sample (the flattened image size).
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        # The attribute name `gen` is part of the state_dict keys ("gen.0.weight", ...).
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh(),  # outputs in (-1, 1)
        )

    def forward(self, x):
        """Generate one flattened image per latent vector in ``x``."""
        return self.gen(x)


# Backward-compatible alias for the original lowercase class name, so existing
# code such as `generator(z_dim, image_dim)` and previously pickled models still work.
generator = Generator

# Hyperparameters, Setup, and Training Loop

In [None]:
# Training cell: hyperparameters, models, data pipeline, optimizers, TensorBoard
# writers, then the alternating discriminator/generator training loop.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4  # commonly used Adam learning rate for small GANs; worth tuning
z_dim = 64  # latent noise size; 128 or 256 can also be tried
image_dim = 28 * 28 * 1  # 784: a flattened single-channel 28x28 MNIST image
batch_size = 32
num_epochs = 50

disc = Discriminator(image_dim).to(device)
gen = generator(z_dim, image_dim).to(device)
# A fixed latent batch, so the logged fake-image grids are comparable across epochs.
fixed_noise = torch.randn(batch_size, z_dim, device=device)
# Scale pixels from [0, 1] (ToTensor) to [-1, 1] to match the generator's Tanh
# output. Normalizing with the MNIST mean/std (0.1307, 0.3081) instead puts real
# pixels in roughly [-0.42, 2.82], which lets the discriminator tell real from fake
# by value range alone and starves the generator of useful gradient.
transforms = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
writer_fake = SummaryWriter("runs/GAN_MNIST/fake")
writer_real = SummaryWriter("runs/GAN_MNIST/real")
step = 0  # TensorBoard global step, incremented once per epoch (on batch 0)

try:
    for epoch in range(num_epochs):
        for batch_idx, (real, _) in enumerate(loader):
            real = real.view(-1, image_dim).to(device)
            # Use a separate name so the `batch_size` hyperparameter is not overwritten.
            cur_batch_size = real.shape[0]

            # Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
            noise = torch.randn(cur_batch_size, z_dim, device=device)
            fake = gen(noise)
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            # detach(): the discriminator update must not backprop into the generator.
            # This also leaves the generator's graph untouched for its own step below,
            # so retain_graph=True is not needed.
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) / 2
            opt_disc.zero_grad()
            lossD.backward()
            opt_disc.step()

            # Train Generator: min log(1 - D(G(z))) <--> max log(D(G(z)))
            # Implemented as BCE against "real" labels (the non-saturating generator loss).
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            opt_gen.zero_grad()
            lossG.backward()
            opt_gen.step()

            if batch_idx == 0:
                # Implicit string concatenation keeps the message on one clean line
                # (a backslash continuation inside the f-string embedded the indentation).
                print(
                    f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                    f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
                )

                with torch.no_grad():
                    fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                    data = real.reshape(-1, 1, 28, 28)
                    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                    img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                    writer_fake.add_image("MNIST Fake Images", img_grid_fake, global_step=step)
                    writer_real.add_image("MNIST Real Images", img_grid_real, global_step=step)
                step += 1
finally:
    # Flush pending events and release the log files even if training is interrupted.
    writer_fake.close()
    writer_real.close()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to dataset/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 117252079.18it/s]


Extracting dataset/MNIST/raw/train-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 52078974.13it/s]

Extracting dataset/MNIST/raw/train-labels-idx1-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 102154968.93it/s]


Extracting dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1007)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 5762410.40it/s]

Extracting dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to dataset/MNIST/raw






Epoch [0/50] Batch 0/1875                     Loss D: 0.6269, loss G: 0.6843
Epoch [1/50] Batch 0/1875                     Loss D: 0.1487, loss G: 2.2230
Epoch [2/50] Batch 0/1875                     Loss D: 0.1831, loss G: 2.4162
Epoch [3/50] Batch 0/1875                     Loss D: 0.0416, loss G: 4.3018
Epoch [4/50] Batch 0/1875                     Loss D: 0.0282, loss G: 4.5742
Epoch [5/50] Batch 0/1875                     Loss D: 0.0169, loss G: 4.7247
Epoch [6/50] Batch 0/1875                     Loss D: 0.0182, loss G: 4.8843
Epoch [7/50] Batch 0/1875                     Loss D: 0.0182, loss G: 4.7169
Epoch [8/50] Batch 0/1875                     Loss D: 0.0174, loss G: 6.0799
Epoch [9/50] Batch 0/1875                     Loss D: 0.0247, loss G: 5.5348
Epoch [10/50] Batch 0/1875                     Loss D: 0.0176, loss G: 5.3674
Epoch [11/50] Batch 0/1875                     Loss D: 0.0063, loss G: 6.2369
Epoch [12/50] Batch 0/1875                     Loss D: 0.0117, loss G: 5.1