# GANs: a fully connected generator/discriminator pair trained on MNIST

In [1]:
import argparse
import os
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

In [2]:
# torch.cuda.is_available() already returns a bool, so no ternary is needed.
cuda = torch.cuda.is_available()
print("cuda: ", cuda)

cuda:  True


In [3]:
# Length of the noise vector z that the generator consumes.
LATENT_DIM = 100
# Target image side length: MNIST is resized to this and models use (1, SHAPE, SHAPE).
SHAPE = 28

In [4]:
class Generator(nn.Module):
    """Fully connected generator that maps a noise vector to an image.

    Architecture: Linear stages widening 128 -> 256 -> 512 -> 1024 (all but the
    first followed by BatchNorm1d), each with LeakyReLU(0.2), then a Linear
    projection to ``prod(img_shape)`` values and a Tanh, so every output pixel
    lies in [-1, 1].

    Args:
        latent_dim: Length of the input noise vector ``z``.
        img_shape: Shape of a single generated image, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, latent_dim=100, img_shape=(1, 28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            """Return the [Linear, (BatchNorm1d), LeakyReLU] layers of one stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional argument is ``eps``; the previous
                # ``nn.BatchNorm1d(out_feat, 0.8)`` therefore set eps=0.8 (default
                # 1e-5), which dampens the normalisation. The 0.8 was presumably a
                # Keras-style momentum; Keras momentum=0.8 corresponds to PyTorch
                # momentum=0.2, because PyTorch's momentum weights the new batch
                # statistic rather than the running one.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.2))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        # Layer order is unchanged, so state_dict keys (model.0.*, model.2.*, ...)
        # stay compatible with previously saved checkpoints.
        self.model = nn.Sequential(
            # No BatchNorm on the first stage, which consumes the raw noise.
            *block(self.latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images.

        Args:
            z: Noise tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, *img_shape) with values in [-1, 1].
        """
        img = self.model(z)
        # The MLP output is flat; reshape it to (batch, *img_shape).
        return img.view(img.size(0), *self.img_shape)

class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or fake.

    Each image is flattened, passed through two hidden Linear layers (512 and
    256 units) with LeakyReLU(0.2), and mapped by a final Linear + Sigmoid to a
    single score in [0, 1] per sample, as expected by ``nn.BCELoss``.

    Args:
        img_shape: Shape of a single input image, e.g. ``(1, 28, 28)``.
    """

    def __init__(self, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = img_shape

        n_pixels = int(np.prod(self.img_shape))
        widths = (n_pixels, 512, 256)

        stack = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2, inplace=True)]
        stack += [nn.Linear(widths[-1], 1), nn.Sigmoid()]

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return a (batch, 1) tensor of Sigmoid scores for ``img``."""
        batch_size = img.shape[0]
        flat = img.view(batch_size, -1)
        return self.model(flat)

# Binary cross-entropy on the discriminator's Sigmoid scores; "mean" is BCELoss's default reduction.
adversarial_loss = nn.BCELoss(reduction="mean")

In [6]:

# --------------------------
# Data
# --------------------------
# Normalize(0.5, 0.5) maps ToTensor's [0, 1] pixels to [-1, 1], the same range
# as the generator's final Tanh.
data = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.Resize(SHAPE),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]),
)

dataloader = DataLoader(data, batch_size=64, shuffle=True)

# One device handle for models, labels, noise and real batches. This replaces
# casting through the legacy torch.cuda.FloatTensor / torch.FloatTensor types.
device = torch.device("cuda" if cuda else "cpu")

# --------------------------
# Initialize Model
# --------------------------
generator = Generator(latent_dim=LATENT_DIM, img_shape=(1, SHAPE, SHAPE)).to(device)
discriminator = Discriminator(img_shape=(1, SHAPE, SHAPE)).to(device)

# BCELoss holds no parameters or buffers, so it does not need moving to the device.
adversarial_loss = nn.BCELoss()

# --------------------------
# Optimizers
# --------------------------
optimizer_G = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Kept only so later cells that reference `Tensor` keep working; the loop below
# uses `device` instead.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# --------------------------
# Training Loop
# --------------------------
n_epochs = 5  # for example, just run 5 epochs to test
log_interval = 1  # print losses every `log_interval` batches (1 = every batch)
sample_interval = 400  # save a 5x5 grid of generated samples every N batches
sample_dir = "data/MNIST"
os.makedirs(sample_dir, exist_ok=True)

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        batch_size = imgs.size(0)  # the last batch of an epoch may be smaller

        # Ground-truth labels: 1 = real, 0 = fake.
        valid = torch.ones(batch_size, 1, device=device)
        fake = torch.zeros(batch_size, 1, device=device)

        real_imgs = imgs.to(device=device, dtype=torch.float32)

        # -----------------
        #  Train Generator
        # -----------------
        # The generator is rewarded when the discriminator labels its samples
        # as valid. This backward pass also writes gradients into the
        # discriminator's parameters; optimizer_D.zero_grad() below clears them
        # before the discriminator update.
        optimizer_G.zero_grad()
        z = torch.randn(batch_size, LATENT_DIM, device=device)
        gen_imgs = generator(z)
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------
        # detach() stops the discriminator loss from backpropagating into the generator.
        optimizer_D.zero_grad()
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        if i % log_interval == 0:
            print(
                f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "
                f"[D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}]"
            )

        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            save_image(
                gen_imgs.detach()[:25],
                f"{sample_dir}/{batches_done}.png",
                nrow=5,
                normalize=True,
            )


[Epoch 0/5] [Batch 0/938] [D loss: 0.705510] [G loss: 0.695471]
[Epoch 0/5] [Batch 1/938] [D loss: 0.607280] [G loss: 0.692139]
[Epoch 0/5] [Batch 2/938] [D loss: 0.533808] [G loss: 0.689332]
[Epoch 0/5] [Batch 3/938] [D loss: 0.477295] [G loss: 0.686353]
[Epoch 0/5] [Batch 4/938] [D loss: 0.439688] [G loss: 0.682618]
[Epoch 0/5] [Batch 5/938] [D loss: 0.413984] [G loss: 0.678105]
[Epoch 0/5] [Batch 6/938] [D loss: 0.395792] [G loss: 0.672899]
[Epoch 0/5] [Batch 7/938] [D loss: 0.385443] [G loss: 0.665978]
[Epoch 0/5] [Batch 8/938] [D loss: 0.385233] [G loss: 0.658104]
[Epoch 0/5] [Batch 9/938] [D loss: 0.387294] [G loss: 0.646904]
[Epoch 0/5] [Batch 10/938] [D loss: 0.387542] [G loss: 0.638136]
[Epoch 0/5] [Batch 11/938] [D loss: 0.392790] [G loss: 0.627510]
[Epoch 0/5] [Batch 12/938] [D loss: 0.398846] [G loss: 0.614694]
[Epoch 0/5] [Batch 13/938] [D loss: 0.403742] [G loss: 0.606288]
[Epoch 0/5] [Batch 14/938] [D loss: 0.410139] [G loss: 0.595644]
[Epoch 0/5] [Batch 15/938] [D loss: