In [1]:
# --- Setup ---
!pip install torch torchvision



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os

In [3]:
# --- Hyperparams ---
batch_size = 128
epochs = 30
z_dim = 100
lr = 0.0002
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
# --- Data Loader ---
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
train_data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

In [5]:
# --- Generator ---
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(z_dim + 10, 256),
            nn.ReLU(True),
            nn.Linear(256, 512),
            nn.ReLU(True),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    def forward(self, z, labels):
        c = self.label_emb(labels)
        x = torch.cat([z, c], dim=1)
        return self.model(x).view(-1, 1, 28, 28)

In [6]:
# --- Discriminator ---
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.label_emb = nn.Embedding(10, 10)
        self.model = nn.Sequential(
            nn.Linear(784 + 10, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, img, labels):
        c = self.label_emb(labels)
        x = torch.cat([img.view(img.size(0), -1), c], dim=1)
        return self.model(x)

In [7]:
# --- Initialize ---
G = Generator().to(device)
D = Discriminator().to(device)
loss_fn = nn.BCELoss()
opt_G = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
opt_D = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

In [8]:
# --- Train ---
for epoch in range(epochs):
    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        b = imgs.size(0)

        real = torch.ones(b, 1).to(device)
        fake = torch.zeros(b, 1).to(device)

        # --- Train Discriminator ---
        z = torch.randn(b, z_dim).to(device)
        gen_labels = torch.randint(0, 10, (b,)).to(device)
        gen_imgs = G(z, gen_labels)

        real_loss = loss_fn(D(imgs, labels), real)
        fake_loss = loss_fn(D(gen_imgs.detach(), gen_labels), fake)
        d_loss = real_loss + fake_loss

        opt_D.zero_grad()
        d_loss.backward()
        opt_D.step()

        # --- Train Generator ---
        z = torch.randn(b, z_dim).to(device)
        gen_labels = torch.randint(0, 10, (b,)).to(device)
        gen_imgs = G(z, gen_labels)
        g_loss = loss_fn(D(gen_imgs, gen_labels), real)

        opt_G.zero_grad()
        g_loss.backward()
        opt_G.step()

    print(f"[{epoch+1}/{epochs}] D Loss: {d_loss.item():.4f} | G Loss: {g_loss.item():.4f}")

[1/30] D Loss: 0.9765 | G Loss: 1.7571
[2/30] D Loss: 0.2814 | G Loss: 6.2829
[3/30] D Loss: 0.4120 | G Loss: 3.5094
[4/30] D Loss: 0.2844 | G Loss: 3.2379
[5/30] D Loss: 0.2006 | G Loss: 4.2493
[6/30] D Loss: 0.1793 | G Loss: 4.3382
[7/30] D Loss: 0.3601 | G Loss: 3.6300
[8/30] D Loss: 0.6493 | G Loss: 1.3287
[9/30] D Loss: 0.3703 | G Loss: 2.7347
[10/30] D Loss: 0.4665 | G Loss: 2.3637
[11/30] D Loss: 0.3839 | G Loss: 2.8111
[12/30] D Loss: 0.0134 | G Loss: 5.0802
[13/30] D Loss: 0.0056 | G Loss: 6.2151
[14/30] D Loss: 0.0021 | G Loss: 7.4487
[15/30] D Loss: 0.0010 | G Loss: 7.6134
[16/30] D Loss: 0.0012 | G Loss: 7.6192
[17/30] D Loss: 0.0004 | G Loss: 9.1256
[18/30] D Loss: 0.0003 | G Loss: 9.8049
[19/30] D Loss: 0.0001 | G Loss: 10.2498
[20/30] D Loss: 0.0000 | G Loss: 10.7263
[21/30] D Loss: 0.0015 | G Loss: 6.9172
[22/30] D Loss: 0.0002 | G Loss: 9.1520
[23/30] D Loss: 0.0000 | G Loss: 10.2884
[24/30] D Loss: 0.0003 | G Loss: 8.4296
[25/30] D Loss: 0.0004 | G Loss: 8.8055
[26/30

In [9]:
# --- Save model ---
os.makedirs("models", exist_ok=True)
torch.save(G.state_dict(), "models/generator.pth")
print("Model saved!")

Model saved!
