# Generative Adversarial Networks (GAN) from Scratch

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/adiel2012/deep-learning-abc/blob/main/gan.ipynb)

A GAN consists of two networks competing against each other:
1. **Generator (G):** Tries to create fake data that looks real.
2. **Discriminator (D):** Tries to distinguish between real data and fake data.

Minimax Game:
$$ \min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{data}(x)}[\log D(x)] + \mathbb{E}_{z \sim p_{z}(z)}[\log(1 - D(G(z)))] $$

In [None]:
# Notebook shell command: install PyTorch, torchvision and matplotlib.
!pip install torch torchvision matplotlib

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 1. Define Networks (fully connected)

We'll use a simple fully connected GAN for MNIST; a convolutional architecture such as DCGAN would give better results.

In [None]:
class Generator(nn.Module):
    """Fully connected generator: maps a latent vector to a flattened image.

    Args:
        z_dim: Size of the input noise vector.
        img_dim: Size of the flattened output image.
    """

    def __init__(self, z_dim=100, img_dim=28*28):
        super().__init__()
        # Widths of the input followed by the two hidden layers.
        widths = (z_dim, 256, 1024)
        stack = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.01))
        # Final projection to image size; Tanh bounds the output to [-1, 1].
        stack.append(nn.Linear(widths[-1], img_dim))
        stack.append(nn.Tanh())
        self.net = nn.Sequential(*stack)

    def forward(self, z):
        """Return the generated batch for the noise batch ``z``."""
        generated = self.net(z)
        return generated

class Discriminator(nn.Module):
    """Fully connected discriminator: scores a flattened image as real or fake.

    Args:
        img_dim: Size of the flattened input image.
    """

    def __init__(self, img_dim=28*28):
        super().__init__()
        # Linear layers are created in network order, then interleaved with
        # their activations below.
        first_hidden = nn.Linear(img_dim, 512)
        second_hidden = nn.Linear(512, 128)
        score = nn.Linear(128, 1)
        self.net = nn.Sequential(
            first_hidden,
            nn.LeakyReLU(0.01),
            second_hidden,
            nn.LeakyReLU(0.01),
            score,
            nn.Sigmoid(),  # squashes the score to a probability in [0, 1]
        )

    def forward(self, x):
        """Return one probability-of-real value per row of ``x``."""
        prob_real = self.net(x)
        return prob_real

## 2. Training Loop

In [None]:
# Hyperparameters
lr = 3e-4          # Adam learning rate, shared by both optimizers
z_dim = 64         # size of the generator's input noise vector
img_dim = 28 * 28  # images are flattened to this length before use
batch_size = 32
epochs = 2  # Keep short for demo

# Data
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # Normalize to [-1, 1], matching the generator's Tanh
])
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Models
gen = Generator(z_dim, img_dim).to(device)
disc = Discriminator(img_dim).to(device)

opt_gen = optim.Adam(gen.parameters(), lr=lr)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

print("Training GAN...")
for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image to a vector of length img_dim.
        real = real.view(-1, img_dim).to(device)
        # Size of *this* batch. Kept under its own name so the `batch_size`
        # hyperparameter above is not overwritten by the loop.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        # Sample the noise directly on the target device (no CPU->device copy).
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)

        # Disc on real: target label 1
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        # Disc on fake: target label 0. detach() stops this loss from
        # back-propagating into the generator.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        # Average the real and fake halves of the discriminator loss.
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # Generator wants Discriminator to output 1 (Real) for fake images.
        # `fake` is reused un-detached so gradients reach the generator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))

        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Report once per epoch, on its first batch, with a 1-based epoch
        # counter so the final epoch reads "[epochs/epochs]".
        if batch_idx == 0:
            print(
                f"Epoch [{epoch + 1}/{epochs}] "
                f"Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
            )

## 3. Visualize Generated Images

In [None]:
# Draw 16 noise vectors and run them through the generator without
# tracking gradients, then move the result to the CPU for plotting.
with torch.no_grad():
    noise = torch.randn(16, z_dim).to(device)
    generated = gen(noise)
    fake_images = generated.reshape(-1, 28, 28).cpu()

print("Autogenerated Digits:")
# 2 x 8 grid: one axis per generated image, shown in grayscale without ticks.
fig, axes = plt.subplots(2, 8, figsize=(10, 3))
for i, ax in enumerate(axes.flatten()):
    image = fake_images[i]
    ax.imshow(image, cmap='gray')
    ax.axis('off')
plt.show()