In [3]:
from IPython.display import display, Latex
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard

### GAN Architecture

The explanation follows the flow of the original GAN [paper](https://arxiv.org/pdf/1406.2661.pdf) by Ian Goodfellow et al. The images are my own visual summaries, drawn to make the ideas easier to remember.

* General Overview of the GAN
* Value Function
* Global Optimality

* A GAN consists of two networks that compete with each other
    - Generator
    - Discriminator
* These networks are constantly trying to outdo each other
* One network tries to generate fake images and the other tries to catch the fake ones
* The counterfeiter is the *generator*, and the cop who is trying to catch the fake images is the *discriminator*
* The generator keeps generating fake images until it becomes good enough to fool the cop completely
* So the generator is trying to maximize the probability of the discriminator making a mistake

This is the broad overview of the model, diving into the details
* The Generator tries to learn the joint probability of the input and the output, given by the following equation:

    $P(X, Y) = P(X \mid Y)\,P(Y) = P(Y \mid X)\,P(X)$
    
* The Discriminator tries to learn the output given input variable

    $P(Y \mid X = x)$

* The adversarial modelling is applied to multilayer perceptrons

$\theta_g$ and $\theta_d$ are the weights of the generator and the discriminator, respectively. $P_{data}(x)$ is the distribution of the original (real) data, and $P_g(x)$ is the distribution of the samples produced by the generator: noise $z$ is drawn from the noise distribution $P_z(z)$ and mapped to a sample by the generator network with weights $\theta_g$. The discriminator judges each input passed through it and outputs a value between $0$ and $1$, where $y=1$ means "real" and $y=0$ means "fake".


<img src="../../images/image4.png" alt="drawing" width="600"/>


The error function can be explained very intuitively.
The natural logarithm is used to calculate the error. In general, the closer the output $y$ is to the expected value $x$, the smaller the error should be. For example:

* if $x=1$ and $y=0.9$ then the error should be minimal. 
* if $x=0$ and $y=3$ then the error should be maximal.

This behaviour can be observed in natural logarithm as below.

<img src="../../images/image2.png" alt="drawing" width="300"/>

The above error function can be adapted for generator. 


The same idea applies to the discriminator, which acts as the cop. It is expected to do the reverse of what the generator does, as the graph below shows.

<img src="../../images/image3.png" alt="drawing" width="300"/>



So, in general, the forward pass and the loss function for both the Generator and the Discriminator can be explained by the following diagram.


<img src="../../images/image1.png" alt="drawing" width="500"/>



These explanations are very intuitive

A bigger and fancier cost function can be given as below:

$J(\theta)=-\frac{1}{m}\sum_{i=1}^{m}\left[y^{(i)}\log h_{\theta}(x^{(i)})+(1-y^{(i)})\log\left(1-h_{\theta}(x^{(i)})\right)\right]$


### Code Implementations

In [None]:
class Discriminator(nn.Module):
    """Fully connected binary classifier over flattened samples.

    The network is ``Linear -> LeakyReLU(0.01) -> Linear -> Sigmoid`` and
    produces one score between 0 and 1 per input row, which makes the output
    directly usable with ``nn.BCELoss``.

    Args:
        in_features: Number of features in each flattened input sample.
        hidden_dim: Width of the single hidden layer. Defaults to 128, the
            value that was previously hard-coded.
    """

    def __init__(self, in_features, hidden_dim=128):
        super().__init__()
        # The attribute name `disc` is part of the state_dict keys
        # ("disc.0.weight", ...), so it must stay stable for checkpoints.
        self.disc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score a batch of samples.

        Args:
            x: Tensor whose last dimension is ``in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of 1, holding sigmoid scores.
        """
        return self.disc(x)


class Generator(nn.Module):
    """Fully connected generator that maps latent noise to flattened images.

    The network is ``Linear -> LeakyReLU(0.01) -> Linear -> Tanh``.

    Args:
        z_dim: Size of each latent noise vector.
        img_dim: Number of features in each generated (flattened) sample.
        hidden_dim: Width of the single hidden layer. Defaults to 256, the
            value that was previously hard-coded.
    """

    def __init__(self, z_dim, img_dim, hidden_dim=256):
        super().__init__()
        # The attribute name `gen` is part of the state_dict keys
        # ("gen.0.weight", ...), so it must stay stable for checkpoints.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, img_dim),
            # Tanh bounds every output to [-1, 1]; the real images are expected
            # to be normalized to the same range so both look alike to the
            # discriminator.
            nn.Tanh(),
        )

    def forward(self, x):
        """Generate a batch of samples.

        Args:
            x: Tensor whose last dimension is ``z_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``img_dim``, with values in [-1, 1].
        """
        return self.gen(x)

In [None]:
# Hyperparameters and run configuration: everything a reader might tune.
seed = 42  # fixed so that weight init, noise and data shuffling are repeatable
torch.manual_seed(seed)  # seeds the global torch RNG before any model is built

device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when present
lr = 3e-4  # learning rate
z_dim = 64  # size of the latent noise vector
image_dim = 28 * 28 * 1  # 784 = height * width * channels, flattened
batch_size = 32
num_epochs = 50


In [None]:
# Build both networks and move them to the selected device.
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)

# One fixed batch of latent vectors, reused at every logging step so the
# logged fake images come from the same noise each time.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# Bound to its own name on purpose: assigning the pipeline to `transforms`
# would shadow the imported `torchvision.transforms` module, and re-running
# this cell would then fail with AttributeError on `transforms.Compose`.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        # Shift/scale with mean 0.5 and std 0.5 so pixel values in [0, 1]
        # land in [-1, 1], the range of the generator's Tanh output.
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# Downloads MNIST into dataset/ on first use.
dataset = datasets.MNIST(root="dataset/", transform=mnist_transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Separate optimizers: each network is updated on its own loss.
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()

# Two TensorBoard writers so fake and real image grids land in separate runs.
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")
step = 0  # global step counter for the TensorBoard image logs

In [None]:
# Adversarial training: for every batch, update the discriminator once and
# then the generator once. At the first batch of each epoch, print the losses
# and log a grid of fake and real images to TensorBoard.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten every image into a vector for the fully connected networks.
        real = real.view(-1, image_dim).to(device)
        # Local name on purpose: assigning to `batch_size` here would silently
        # overwrite the hyperparameter of the same name.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() cuts the graph at the generator output: the discriminator
        # update needs no generator gradients, so no retain_graph is required
        # and the generator's graph stays intact for the generator step below.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = disc(fake).view(-1)  # fresh pass through the updated discriminator
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            # Adjacent f-strings instead of a backslash continuation, which
            # would embed the next line's indentation in the printed message.
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                # Separate name so the training batch `fake` is not clobbered.
                fake_preview = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_preview = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(
                    fake_preview, normalize=True
                )
                img_grid_real = torchvision.utils.make_grid(
                    real_preview, normalize=True
                )

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

# Make sure the last logged images are written to disk once training ends.
writer_fake.flush()
writer_real.flush()