# Trains a conditional GAN from scratch on the MNIST dataset using PyTorch.
# Both the generator and the discriminator are conditioned on the digit label.


In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
class Generator(nn.Module):
    """Conditional MLP generator: (noise vector, class label) -> 1x28x28 image.

    The label is looked up in a learned embedding of width ``num_classes``,
    concatenated with the noise vector and fed through a small MLP whose final
    ``Tanh`` keeps pixel values in [-1, 1].

    Args:
        latent_dim: Length of each noise vector ``z``.
        num_classes: Number of distinct labels the generator is conditioned on.
    """

    def __init__(self, latent_dim, num_classes):
        super().__init__()
        # Learned per-label embedding whose width equals the number of classes.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, 28*28),
            nn.Tanh()
        )

    def forward(self, z, labels):
        """Generate one image per (noise, label) pair.

        Args:
            z: Noise tensor of shape (batch, latent_dim).
            labels: Integer label tensor of shape (batch,).

        Returns:
            Tensor of shape (batch, 1, 28, 28) with values in [-1, 1].
        """
        # Named gen_input (not `input`) to avoid shadowing the builtin.
        gen_input = torch.cat((z, self.label_emb(labels)), dim=1)
        img = self.model(gen_input)
        # Reshape with the explicit batch size so a shape mismatch fails loudly.
        return img.view(z.size(0), 1, 28, 28)

In [None]:
class Discriminator(nn.Module):
    """Conditional MLP discriminator: (image, class label) -> probability "real".

    Each image is flattened to a 28*28 vector, concatenated with a learned
    label embedding of width ``num_classes``, and scored by a small MLP that
    ends in ``Sigmoid``. Its output is therefore suitable for ``nn.BCELoss``.

    Args:
        num_classes: Number of distinct labels the discriminator is conditioned on.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Learned per-label embedding whose width equals the number of classes.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(28*28 + num_classes, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        """Score a batch of (image, label) pairs.

        Args:
            x: Image tensor whose first dimension is the batch; the remaining
                dimensions must hold 28*28 elements per sample (e.g. (batch, 1, 28, 28)).
            labels: Integer label tensor of shape (batch,).

        Returns:
            Tensor of shape (batch, 1) with probabilities in [0, 1].
        """
        x = x.view(x.size(0), -1)
        # Named disc_input (not `input`) to avoid shadowing the builtin.
        disc_input = torch.cat((x, self.label_emb(labels)), dim=1)
        return self.model(disc_input)

In [None]:

os.makedirs("models", exist_ok=True)
latent_dim = 100
num_classes = 10
epochs = 10
batch_size = 64

# ToTensor gives pixels in [0, 1]; Normalize(0.5, 0.5) maps them to [-1, 1],
# the same range as the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
data = datasets.MNIST(root="data", train=True, download=True, transform=transform)
loader = DataLoader(data, batch_size=batch_size, shuffle=True)

G = Generator(latent_dim, num_classes).to(device)
D = Discriminator(num_classes).to(device)
loss_fn = nn.BCELoss()
opt_G = torch.optim.Adam(G.parameters(), lr=0.0002)
opt_D = torch.optim.Adam(D.parameters(), lr=0.0002)

# Explicitly enable training mode (BatchNorm batch statistics) in case a later
# cell switched the models to eval mode and this cell is re-run.
G.train()
D.train()

for epoch in range(epochs):
    # Accumulate detached loss tensors on-device to avoid a host sync every step.
    d_loss_sum = torch.zeros((), device=device)
    g_loss_sum = torch.zeros((), device=device)
    num_batches = 0
    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)
        cur_bs = imgs.size(0)  # the final batch may be smaller than batch_size
        # Allocate targets/noise directly on the device instead of CPU + copy.
        valid = torch.ones(cur_bs, 1, device=device)
        fake = torch.zeros(cur_bs, 1, device=device)

        z = torch.randn(cur_bs, latent_dim, device=device)
        gen_labels = torch.randint(0, num_classes, (cur_bs,), device=device)
        gen_imgs = G(z, gen_labels)

        # Train Discriminator; detach so this loss does not backprop into G.
        real_loss = loss_fn(D(imgs, labels), valid)
        fake_loss = loss_fn(D(gen_imgs.detach(), gen_labels), fake)
        D_loss = (real_loss + fake_loss) / 2
        opt_D.zero_grad(); D_loss.backward(); opt_D.step()

        # Train Generator: push D to classify generated images as real.
        G_loss = loss_fn(D(gen_imgs, gen_labels), valid)
        opt_G.zero_grad(); G_loss.backward(); opt_G.step()

        d_loss_sum += D_loss.detach()
        g_loss_sum += G_loss.detach()
        num_batches += 1

    # Report the epoch mean instead of the noisy last-batch value; guard against
    # an empty loader, which previously raised NameError here.
    denom = max(num_batches, 1)
    print(f"Epoch {epoch+1}/{epochs} - D_loss: {(d_loss_sum / denom).item():.4f} - G_loss: {(g_loss_sum / denom).item():.4f}")

torch.save(G.state_dict(), "models/generator.pth")

In [None]:
# Generate test images
# Switch G to inference mode. BatchNorm then uses its running statistics rather
# than the statistics of this tiny 5-sample batch, and no autograd graph is built.
G.eval()
with torch.no_grad():
    z = torch.randn(5, latent_dim, device=device)
    digit = torch.full((5,), 5, dtype=torch.long, device=device)  # example: generate digit '5'
    samples = G(z, digit)

In [None]:
import matplotlib.pyplot as plt

# Lay the generated samples out in a single row of 5 grayscale panels.
for col in range(5):
    image = samples[col].detach().cpu().squeeze()
    plt.subplot(1, 5, col + 1)
    plt.imshow(image, cmap='gray')
    plt.axis('off')
plt.show()
