# Implementing GAN in PyTorch: Full Example

## Setup Environment
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# In[ ]:
# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Data Preparation
# ToTensor() converts each MNIST image to a float tensor in [0, 1];
# Normalize(mean=0.5, std=0.5) then maps it to [-1, 1], the same range the
# generator's final Tanh produces.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
train_data = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Reshuffled every epoch; drop_last is left at its default, so the final
# batch of an epoch may hold fewer than 128 images.
train_loader = DataLoader(dataset=train_data, batch_size=128, shuffle=True)

## Generator
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to single-channel 28x28 images.

    Args:
        latent_dim: Length of each input noise vector. Defaults to 100, the
            size this script samples noise with.

    The final ``Tanh`` bounds every pixel to [-1, 1], the same range as the
    real images after ``Normalize((0.5,), (0.5,))``.
    """

    def __init__(self, latent_dim: int = 100):
        super().__init__()
        self.latent_dim = latent_dim
        # Attribute name and layer order are kept so state_dict keys
        # ("net.0.weight", ...) stay compatible with existing checkpoints.
        self.net = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 28 * 28),
            nn.Tanh(),
        )

    def forward(self, x):
        """Map noise ``x`` of shape (batch, latent_dim) to images of shape (batch, 1, 28, 28)."""
        # -1 lets the leading (batch) dimension be inferred from the flat 784-pixel output.
        return self.net(x).view(-1, 1, 28, 28)

## Discriminator
class Discriminator(nn.Module):
    """MLP discriminator that scores images as real (towards 1) or fake (towards 0).

    Args:
        in_features: Number of pixels in one flattened input image. Defaults
            to ``28 * 28`` for MNIST.

    ``forward`` returns a (batch, 1) tensor of Sigmoid probabilities, suitable
    for ``nn.BCELoss``.
    """

    def __init__(self, in_features: int = 28 * 28):
        super().__init__()
        self.in_features = in_features
        # Attribute name and layer order are kept so state_dict keys stay
        # compatible with existing checkpoints.
        self.net = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Flatten each image in ``x`` to ``in_features`` pixels and score it."""
        # reshape rather than view: view raises on non-contiguous inputs
        # (e.g. transposed or sliced image batches), while reshape copies
        # only when it has to and is otherwise identical.
        return self.net(x.reshape(-1, self.in_features))

# In[ ]:
# Build both networks (generator first, then discriminator) on the chosen device.
G, D = Generator().to(device), Discriminator().to(device)
# The discriminator ends in a Sigmoid, so plain binary cross-entropy on
# probabilities is the matching loss.
criterion = nn.BCELoss()
# One independent Adam optimizer per network, both with learning rate 2e-4.
# NOTE(review): many GAN setups (e.g. DCGAN) also pass betas=(0.5, 0.999);
# PyTorch's default betas are used here.
opt_G = optim.Adam(params=G.parameters(), lr=2e-4)
opt_D = optim.Adam(params=D.parameters(), lr=2e-4)

## Training Loop
# Loss history for each network, appended to by the training loop and
# plotted at the end. Two distinct list objects.
losses_G, losses_D = [], []

# In[ ]:
# Alternating GAN training. For every batch: one discriminator update on real
# images plus detached fakes, then one generator update that pushes D to label
# the same fakes as real. The mean loss of each network over the epoch is
# recorded (one value per epoch) and printed.
for epoch in range(30):
    # Running sums stay on-device so .item() (a host sync on GPU) runs once
    # per epoch rather than once per batch.
    sum_loss_D = torch.zeros((), device=device)
    sum_loss_G = torch.zeros((), device=device)
    num_batches = 0

    for real, _ in train_loader:
        real = real.to(device)
        bs = real.size(0)  # the last batch of an epoch can be smaller
        # BCE targets: 1 = real, 0 = fake; created directly on the device.
        label_real = torch.ones(bs, 1, device=device)
        label_fake = torch.zeros(bs, 1, device=device)

        # Train Discriminator
        opt_D.zero_grad()
        pred_real = D(real)
        loss_real = criterion(pred_real, label_real)

        noise = torch.randn(bs, 100, device=device)
        fake = G(noise)
        # detach(): the discriminator loss must not backpropagate into G.
        pred_fake = D(fake.detach())
        loss_fake = criterion(pred_fake, label_fake)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        opt_D.step()

        # Train Generator
        opt_G.zero_grad()
        # fake is NOT detached here so gradients reach G through D. This also
        # leaves gradients on D's parameters; opt_D.zero_grad() at the top of
        # the next iteration clears them before D is stepped again.
        pred = D(fake)
        loss_G = criterion(pred, label_real)  # G is rewarded when D calls its fakes real
        loss_G.backward()
        opt_G.step()

        sum_loss_D += loss_D.detach()
        sum_loss_G += loss_G.detach()
        num_batches += 1

    if num_batches == 0:
        raise RuntimeError("train_loader yielded no batches")

    mean_loss_D = sum_loss_D.item() / num_batches
    mean_loss_G = sum_loss_G.item() / num_batches
    losses_D.append(mean_loss_D)
    losses_G.append(mean_loss_G)
    print(f'Epoch {epoch+1}, Loss_D: {mean_loss_D:.4f}, Loss_G: {mean_loss_G:.4f}')

## Evaluation
def show_samples(num_samples=16, nrow=4):
    """Sample images from the module-level generator ``G`` and show them as a grid.

    Args:
        num_samples: Number of latent vectors to draw. Defaults to 16.
        nrow: Images per grid row, passed to ``utils.make_grid``. Defaults to 4.

    Uses the module-level ``G`` and ``device``. Noise vectors have length 100,
    the latent size used throughout this script.
    """
    # Only the generator forward pass needs to skip autograd bookkeeping.
    with torch.no_grad():
        noise = torch.randn(num_samples, 100, device=device)
        samples = G(noise).cpu()
    # normalize=True min-max rescales the grid for display, since G's Tanh
    # output lies in [-1, 1] rather than [0, 1].
    grid = utils.make_grid(samples, nrow=nrow, normalize=True)
    plt.figure(figsize=(5, 5))
    # make_grid returns (C, H, W); imshow expects channels last (H, W, C).
    plt.imshow(grid.permute(1, 2, 0).numpy())
    plt.axis('off')
    plt.show()

# In[ ]:
show_samples()

## Loss Plot
# The training loop appends one value per epoch, so list index k is epoch k+1.
# Plot against 1-based epochs to match the printed "Epoch N" numbering.
plt.plot(range(1, len(losses_D) + 1), losses_D, label='Discriminator')
plt.plot(range(1, len(losses_G) + 1), losses_G, label='Generator')
plt.xlabel('Epoch')
plt.ylabel('BCE loss')
plt.legend()
plt.title('Loss Curves')
plt.show()