<a href="https://colab.research.google.com/github/avrymi-asraf/Garden-of-GAN/blob/main/Basic-And-Principles.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Basics and Principles

## Installation and Imports

In [1]:
!pip install -q plotly torchvision

In [2]:
import torch
from torch import nn, optim
import torch.utils.data.dataloader as dataloader
from torchvision import datasets, transforms
import plotly.express as px
import pandas as pd

## Models

In [3]:
class Generator(nn.Module):
    """Fully connected GAN generator that maps latent vectors to images.

    The network is ``Linear(letant_dim, 256) -> LeakyReLU ->
    Linear(256, len_im) -> Tanh``, so every output value lies in ``[-1, 1]``.

    Args:
        letant_dim: Length of one latent (noise) vector.
        im_dim: Shape of a single generated image, e.g. ``(28, 28)`` or
            ``(1, 28, 28)``. Any number of axes is accepted.
    """

    def __init__(self, letant_dim: int, im_dim):
        super().__init__()
        self.im_dim = im_dim
        # Number of values in one image: the product of *all* axes, so shapes
        # with a channel axis work as well as plain (height, width).
        self.len_im = torch.Size(im_dim).numel()
        self.letant_dim = letant_dim
        self.model = nn.Sequential(
            nn.Linear(letant_dim, 256),
            nn.LeakyReLU(),
            nn.Linear(256, self.len_im),
            nn.Tanh(),
        )

    def forward(self, X):
        """Generate one image per latent vector in ``X``.

        Args:
            X: Tensor whose elements can be regrouped into rows of
                ``letant_dim`` values.

        Returns:
            Tensor of shape ``(batch, *im_dim)``.
        """
        # reshape (unlike view) also accepts non-contiguous inputs.
        latent = X.reshape(-1, self.letant_dim)
        return self.model(latent).reshape(-1, *self.im_dim)


class Discrimnator(nn.Module):
    """Fully connected GAN discriminator that scores images as real or fake.

    The network is ``Linear(len_im, 128) -> LeakyReLU -> Linear(128, 1) ->
    Sigmoid``, so every score lies in ``[0, 1]``.

    Args:
        im_dim: Shape of a single input image, e.g. ``(28, 28)`` or
            ``(1, 28, 28)``. Any number of axes is accepted.
    """

    def __init__(self, im_dim) -> None:
        super().__init__()
        self.im_dim = im_dim
        # Number of values in one image: the product of *all* axes, so shapes
        # with a channel axis work as well as plain (height, width).
        self.len_im = torch.Size(im_dim).numel()
        self.model = nn.Sequential(
            nn.Linear(self.len_im, 128), nn.LeakyReLU(), nn.Linear(128, 1), nn.Sigmoid()
        )

    def forward(self, X):
        """Score a batch of images.

        Args:
            X: Tensor whose elements can be regrouped into rows of ``len_im``
                values (one row per image).

        Returns:
            1-D tensor holding one score per image.
        """
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. the
        # result of a permute.
        flat = X.reshape(-1, self.len_im)
        return self.model(flat).reshape(-1)

In [None]:
# Per-sample preprocessing: tensor conversion followed by normalisation with
# mean 0.5 and std 0.5.
transfomer = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ]
)
# MNIST is downloaded into /dataset when it is not already there.
mnist_data = datasets.MNIST(root="/dataset", download=True, transform=transfomer)

In [5]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [6]:
# Sizes shared by both networks: latent vector length and image shape.
letant_dim, im_dim = 100, (28, 28)

# Build the generator first, then the discriminator, and move each to `device`.
generator = Generator(letant_dim=letant_dim, im_dim=im_dim)
generator = generator.to(device)
discrimnator = Discrimnator(im_dim=im_dim)
discrimnator = discrimnator.to(device)

In [7]:
# Binary cross-entropy is the criterion used for both networks.
loss_f = nn.BCELoss()

# One Adam optimiser per network, both with the same learning rate.
lr = 3e-4
optim_g, optim_d = (
    optim.Adam(net.parameters(), lr=lr) for net in (generator, discrimnator)
)

In [8]:
from math import ceil

num_epochs = 32
batch_size = 64

# Loss log with one pre-allocated, initially empty (NA) row per training step:
# num_epochs * number of batches per epoch (the last batch may be partial).
out_data = pd.DataFrame(
    dict.fromkeys(["epoch", "batch", "loss_g", "loss_d"], pd.NA),
    index=range(ceil(len(mnist_data) / batch_size) * num_epochs),
)

In [None]:
from IPython.display import clear_output

# GAN training loop: alternate one discriminator update and one generator
# update per batch, log both losses, and show samples plus the generator loss
# curve after every epoch.
ind_out_data = 0  # next free row of the pre-allocated `out_data` table
# A single loader is enough: with shuffle=True it draws a new order each time
# it is iterated, i.e. once per epoch.
loader = dataloader.DataLoader(mnist_data, batch_size=batch_size, shuffle=True)
for epoch_ind in range(num_epochs):
    for i_batch, (X, _) in enumerate(loader):
        real = X.to(device)
        n_samples = real.shape[0]  # the last batch may be smaller than batch_size
        real_labels = torch.ones(n_samples, device=device)
        fake_labels = torch.zeros(n_samples, device=device)

        noise = torch.rand(n_samples, letant_dim).to(device)
        fake = generator(noise)

        # --- Discriminator step ---
        # The fakes are detached so that loss_d.backward() does not deposit
        # discriminator-loss gradients in the generator; previously those
        # gradients were added to the generator's own and applied by
        # optim_g.step(), pushing the generator in the wrong direction.
        optim_d.zero_grad()
        loss_d_fake = loss_f(discrimnator(fake.detach()), fake_labels)
        loss_d_real = loss_f(discrimnator(real), real_labels)
        loss_d = (loss_d_fake + loss_d_real) / 2
        loss_d.backward()
        optim_d.step()

        # --- Generator step ---
        # Score the same fakes with the freshly updated discriminator and
        # train the generator to have them labelled as real.
        optim_g.zero_grad()
        loss_g = loss_f(discrimnator(fake), real_labels)
        loss_g.backward()
        optim_g.step()

        out_data.loc[ind_out_data] = [epoch_ind, i_batch, loss_g.item(), loss_d.item()]
        ind_out_data += 1

    # End-of-epoch preview: ten samples drawn from the same uniform noise
    # distribution used during training.
    with torch.no_grad():
        noise = torch.rand(10, letant_dim)
        im = generator(noise.to(device))
    clear_output(wait=True)
    px.imshow(im.cpu().detach(), facet_col=0, facet_col_wrap=5).show()
    px.line(out_data, x="batch", y="loss_g", color="epoch").show()

In [None]:
from math import ceil
from IPython.display import clear_output

# GAN training loop with Gaussian latent noise: alternate one discriminator
# update and one generator update per batch, log both losses, and show samples
# plus the generator loss curve after every epoch.
num_epochs = 20
batch_size = 32
# Loss log with one pre-allocated, initially empty (NA) row per training step.
out_data = pd.DataFrame(
    {"epochs": pd.NA, "batch": pd.NA, "loss_g": pd.NA, "loss_d": pd.NA},
    index=range(num_epochs * ceil(len(mnist_data) / batch_size)),
)
index_data = 0  # next free row of `out_data`
# A single loader is enough: with shuffle=True it draws a new order each time
# it is iterated, i.e. once per epoch.
loader = dataloader.DataLoader(mnist_data, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    for batch_i, (X, _) in enumerate(loader):
        curr_batch_size = X.shape[0]  # the last batch may be smaller

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(curr_batch_size, letant_dim).to(device)
        fake = generator(noise)
        disc_real = discrimnator(X.to(device)).view(-1)
        lossD_real = loss_f(disc_real, torch.ones_like(disc_real))
        # Detach the fakes: the discriminator update needs no gradients with
        # respect to the generator, so no backward pass through it (and no
        # retain_graph) is required here.
        disc_fake = discrimnator(fake.detach()).view(-1)
        lossD_fake = loss_f(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        optim_d.zero_grad()
        lossD.backward()
        optim_d.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = discrimnator(fake).view(-1)
        lossG = loss_f(output, torch.ones_like(output))
        optim_g.zero_grad()
        lossG.backward()
        optim_g.step()

        out_data.loc[index_data] = [epoch, batch_i, lossG.item(), lossD.item()]
        index_data += 1

    # End-of-epoch preview. Sample from the same distribution (randn) and with
    # the same latent size (letant_dim) as training; the previous
    # torch.rand(10, 100) fed the generator inputs it was never trained on.
    with torch.no_grad():
        fake = generator(torch.randn(10, letant_dim).to(device))
    clear_output(wait=True)
    px.imshow(fake.cpu().detach(), facet_col=0, facet_col_wrap=5).show()
    px.line(out_data, x="batch", y=["loss_g"], color="epochs").show()