# P5: GAN Image Generation

**Objective:** Implement a simple GAN (DCGAN-style) to generate small images (e.g., MNIST or CIFAR-10).

In [None]:
# Placeholder cell from the assignment template: students fill in the generator,
# discriminator and adversarial training loop (a worked version follows below).
print(
    'GAN skeleton: define generator, discriminator, '
    'loss functions and train in loop.'
)

In [None]:
# Practical 5: simple GAN on MNIST (PyTorch) with fully-connected (MLP) generator and discriminator, not a convolutional DCGAN
import torch, torch.nn as nn, torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# ToTensor yields pixels in [0, 1]; Normalize((0.5,), (0.5,)) maps them to
# [-1, 1], the same range as the generator's Tanh output.
tfm = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
# MNIST training split, stored under ./data (downloaded if missing).
ds = datasets.MNIST(
    root='./data',
    train=True,
    transform=tfm,
    download=True,
)
# Shuffled mini-batches of 128 images; the final batch may be smaller.
dl = DataLoader(dataset=ds, batch_size=128, shuffle=True)

nz = 64  # length of the latent noise vector fed to the generator
class G(nn.Module):
    """Fully-connected generator mapping latent vectors to 1x28x28 images.

    Args:
        latent_dim: Length of the input noise vector. When omitted, the
            module-level ``nz`` is read at construction time (exactly as the
            original class did), so ``G()`` is unchanged.

    ``forward(z)`` takes ``z`` of shape (batch, latent_dim) and returns images
    of shape (batch, 1, 28, 28) with values in [-1, 1] (Tanh output).
    """

    def __init__(self, latent_dim=None):
        super().__init__()
        if latent_dim is None:
            latent_dim = nz  # fall back to the notebook-wide latent size
        self.latent_dim = latent_dim
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256), nn.ReLU(True),
            nn.Linear(256, 512), nn.ReLU(True),
            nn.Linear(512, 28 * 28), nn.Tanh(),
        )

    def forward(self, z):
        # Reshape each flat 784-vector into a single-channel 28x28 image;
        # -1 also lets a lone 1-D latent vector come back as a batch of one.
        return self.net(z).view(-1, 1, 28, 28)
class D(nn.Module):
    """Fully-connected discriminator scoring images as real vs. generated.

    Inputs are flattened from dim 1 onward and must yield 28*28 = 784
    features per sample (e.g. a (batch, 1, 28, 28) tensor). The output is one
    Sigmoid probability per sample, shape (batch, 1).
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Flatten(),
            nn.Linear(28 * 28, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        probs = self.net(x)
        return probs

# Train on the GPU when one is available (previously the loop was CPU-only).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

Gz = G().to(device)
Dz = D().to(device)
# lr=2e-4 with beta1=0.5: the usual Adam settings for GAN training.
optG = optim.Adam(Gz.parameters(), 2e-4, betas=(0.5, 0.999))
optD = optim.Adam(Dz.parameters(), 2e-4, betas=(0.5, 0.999))
# Dz ends in Sigmoid, so it emits probabilities and plain BCELoss applies.
bce = nn.BCELoss()
# Noise reused to draw the 16-image sample grid after training.
fixed = torch.randn(16, nz, device=device)
num_epochs = 2

for epoch in range(num_epochs):
    # Accumulate detached losses on-device to avoid a host sync every batch.
    sum_lossG = torch.zeros((), device=device)
    sum_lossD = torch.zeros((), device=device)
    n_batches = 0
    for xb, _ in dl:
        xb = xb.to(device)
        bs = xb.size(0)  # the final batch can be smaller than batch_size
        real_target = torch.ones(bs, 1, device=device)
        fake_target = torch.zeros(bs, 1, device=device)

        # Train D: real images -> 1, generated images -> 0. detach() keeps
        # this step from back-propagating into Gz.
        z = torch.randn(bs, nz, device=device)
        fake = Gz(z).detach()
        lossD = bce(Dz(xb), real_target) + bce(Dz(fake), fake_target)
        optD.zero_grad()
        lossD.backward()
        optD.step()

        # Train G with the non-saturating loss: label its samples as real.
        # This backward also writes into Dz's .grad; optD.zero_grad() clears
        # that before the next discriminator step.
        z = torch.randn(bs, nz, device=device)
        gen = Gz(z)
        lossG = bce(Dz(gen), real_target)
        optG.zero_grad()
        lossG.backward()
        optG.step()

        sum_lossG += lossG.detach()
        sum_lossD += lossD.detach()
        n_batches += 1
    # Report epoch means; the last batch alone is one small, noisy sample.
    denom = max(n_batches, 1)
    print('epoch', epoch, 'lossG', float(sum_lossG / denom), 'lossD', float(sum_lossD / denom))

with torch.no_grad():
    # make_grid returns one channels-first image; move it to the CPU for matplotlib.
    grid = utils.make_grid(Gz(fixed), nrow=4, normalize=True, value_range=(-1, 1)).cpu()
plt.figure(figsize=(4, 4))
plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for imshow
plt.axis('off')
plt.show()