<a href="https://colab.research.google.com/github/AlexNedyalkov/GANS/blob/main/01_BasicGAN_AP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter 

# 1. Discriminator and Generator

In [None]:
class Discriminator(nn.Module):
    """Three-layer MLP that scores flattened images with a probability in (0, 1).

    In this notebook it is trained to output ~1 for real MNIST images and ~0
    for generator samples.

    Args:
        img_dim: length of each flattened input vector.
        hidden_layers: width of both hidden layers (default 128).
        output_layer: number of output units (default 1).
    """

    def __init__(self, img_dim, hidden_layers = 128, output_layer = 1):
        super().__init__()
        # Layers are constructed in a fixed order so random initialisation and
        # state_dict keys (linear1.*, linear2.*, final_layer.*) stay stable.
        self.linear1 = nn.Linear(img_dim, hidden_layers)
        self.linear2 = nn.Linear(hidden_layers, hidden_layers)
        self.final_layer = nn.Linear(hidden_layers, output_layer)
        # Despite the attribute name this is a LeakyReLU with slope 0.1.
        self.relu = nn.LeakyReLU(0.1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, X):
        """Map inputs of shape (..., img_dim) to probabilities of shape (..., output_layer)."""
        hidden = X
        for layer in (self.linear1, self.linear2):
            hidden = self.relu(layer(hidden))
        return self.sigmoid(self.final_layer(hidden))

class Generator(nn.Module):
    """MLP that maps latent vectors to flattened images with values in [-1, 1].

    Architecture: Linear -> LeakyReLU(0.1) -> Linear -> LeakyReLU(0.1) ->
    Linear -> Tanh, stored as ``self.gen`` (indices 0..5).

    Args:
        z_dim: length of each latent input vector.
        hidden_shape: width of both hidden layers (default 256).
        img_dim: length of each flattened output image (default 784 = 28*28).
    """

    def __init__(self, z_dim, hidden_shape = 256, img_dim = 784) -> None:
        super().__init__()
        widths = [z_dim, hidden_shape, hidden_shape]
        layers = []
        # Two hidden blocks: (z_dim -> hidden) and (hidden -> hidden).
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.1)]
        # Tanh output range matches images normalised to [-1, 1].
        layers += [nn.Linear(hidden_shape, img_dim), nn.Tanh()]
        self.gen = nn.Sequential(*layers)

    def forward(self, X):
        """Map latent inputs of shape (..., z_dim) to images of shape (..., img_dim)."""
        return self.gen(X)

# 2. Hyperparameters

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
lr = 3e-4                 # Adam learning rate, shared by both networks
z_dim = 64                # length of the latent noise vector fed to the generator
image_dim = 28 * 28 * 1   # flattened MNIST image: height * width * channels
batch_size = 32
epochs = 50

# Seed torch's global RNG so weight initialisation and torch.randn noise draws
# are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# 3. Initializations

In [None]:
disc = Discriminator(img_dim=image_dim).to(device)
# Use the configured sizes rather than repeating literals, so the
# hyperparameter cell stays the single source of truth.
gen = Generator(z_dim=z_dim, img_dim=image_dim).to(device)

# Constant latent batch used only for TensorBoard snapshots, so successive
# image grids show how the same inputs evolve during training.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
assert fixed_noise.shape == (batch_size, z_dim)

# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) maps them to
# [-1, 1], matching the generator's Tanh output range. MNIST has one channel,
# hence the 1-tuples.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

opt_disc = optim.Adam(disc.parameters(), lr)
opt_gen = optim.Adam(gen.parameters(), lr)

# The discriminator ends in a Sigmoid, so plain BCELoss (not BCEWithLogitsLoss) applies.
loss_fn = nn.BCELoss()

writer_fake = SummaryWriter(log_dir='runs/GAN_MNIST/fake')
writer_real = SummaryWriter(log_dir='runs/GAN_MNIST/real')

step = 0  # TensorBoard global_step, advanced once per logged snapshot

# 4. Load Data

In [None]:
# MNIST training split, downloaded under ./data on first run and passed
# through the `transform` pipeline on access.
dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffle every epoch; drop_last=True discards the final partial batch so
# every batch holds exactly `batch_size` samples.
data_loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



# 5. Training

In [None]:
for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(data_loader):
        # Flatten images to vectors for the MLP discriminator.
        real = real.view(-1, image_dim).to(device)

        # Draw fresh latent samples every step. (Previously the generator was
        # trained on `fixed_noise`, i.e. the same 32 vectors forever, and this
        # `noise` was never used.) `fixed_noise` is reserved for logging below.
        noise = torch.randn(real.shape[0], z_dim, device=device)
        fake = gen(noise)

        # --- Discriminator step: real -> 1, fake -> 0 ---
        # detach() so this update does not backpropagate into the generator.
        disc_fake = disc(fake.detach()).view(-1)
        disc_real = disc(real).view(-1)
        loss_fake = loss_fn(disc_fake, torch.zeros_like(disc_fake))
        loss_real = loss_fn(disc_real, torch.ones_like(disc_real))
        loss_D = (loss_fake + loss_real) / 2

        opt_disc.zero_grad()
        loss_D.backward()
        opt_disc.step()

        # --- Generator step: non-saturating loss, push D(G(z)) -> 1 ---
        # Re-score `fake` with the just-updated discriminator. Gradients also land
        # on disc's parameters, but opt_disc.zero_grad() clears them before the
        # next discriminator backward pass.
        output = disc(fake).view(-1)
        lossG = loss_fn(output, torch.ones_like(output))
        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            # Implicit concatenation: the old backslash continuation inside the
            # f-string embedded the next line's indentation in the log message.
            print(
                f"Epoch [{epoch}/{epochs}] Batch {batch_idx}/{len(data_loader)} "
                f"Loss D: {loss_D.item():.4f}, loss G: {lossG.item():.4f}"
            )

            # Once per epoch, log generator output for the constant fixed_noise
            # batch next to the current real batch.
            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

# Make sure the last snapshots reach disk even if the writers stay open.
writer_fake.flush()
writer_real.flush()



Epoch [0/50] Batch 0/1875                       Loss D: 0.7098, loss G: 0.7207
Epoch [1/50] Batch 0/1875                       Loss D: 0.6433, loss G: 0.9907
Epoch [2/50] Batch 0/1875                       Loss D: 0.3161, loss G: 1.6248
Epoch [3/50] Batch 0/1875                       Loss D: 0.2975, loss G: 1.7935
Epoch [4/50] Batch 0/1875                       Loss D: 0.5441, loss G: 1.4565
Epoch [5/50] Batch 0/1875                       Loss D: 0.6613, loss G: 1.5466
Epoch [6/50] Batch 0/1875                       Loss D: 0.3716, loss G: 1.5027
Epoch [7/50] Batch 0/1875                       Loss D: 0.5366, loss G: 1.2274
Epoch [8/50] Batch 0/1875                       Loss D: 0.5854, loss G: 1.7338
Epoch [9/50] Batch 0/1875                       Loss D: 0.5875, loss G: 0.9997
Epoch [10/50] Batch 0/1875                       Loss D: 0.4863, loss G: 1.4572
Epoch [11/50] Batch 0/1875                       Loss D: 0.4173, loss G: 1.5321
Epoch [12/50] Batch 0/1875                       L