In [1]:
dataset_choice = "mnist"
epochs = 50
batch_size = 128
noise_dim = 100
learning_rate = 0.0002
save_interval = 5

import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

Device: cuda


In [2]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

img_size = 28
channels = 1

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

if dataset_choice == "mnist":
    dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
else:
    dataset = datasets.FashionMNIST("./data", train=True, download=True, transform=transform)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print("Dataset loaded:", dataset_choice)

100%|██████████| 9.91M/9.91M [00:01<00:00, 5.01MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 131kB/s]
100%|██████████| 1.65M/1.65M [00:01<00:00, 1.24MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 10.5MB/s]

Dataset loaded: mnist





In [3]:
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),

            nn.Linear(1024, img_size * img_size),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(z.size(0), channels, img_size, img_size)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_size * img_size, 512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img.view(img.size(0), -1))


generator = Generator().to(device)
discriminator = Discriminator().to(device)
print("Models initialized")

Models initialized


In [4]:
import torch.optim as optim

criterion = nn.BCELoss()

optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

In [5]:
import os
import torchvision.utils as vutils
import torch.nn.functional as F

os.makedirs("generated_samples", exist_ok=True)

best_g_loss = float("inf")
best_epoch = -1

for epoch in range(1, epochs + 1):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        bs = real_imgs.size(0)

        # ----- LABEL SMOOTHING -----
        real_labels = torch.full((bs, 1), 0.9, device=device)
        fake_labels = torch.zeros(bs, 1, device=device)

        # ----- TRAIN DISCRIMINATOR -----
        optimizer_D.zero_grad()

        real_loss = criterion(discriminator(real_imgs), real_labels)

        z = torch.randn(bs, noise_dim, device=device)
        fake_imgs = generator(z)
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)

        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # ----- TRAIN GENERATOR -----
        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

    # ----- LOG -----
    print(f"Epoch {epoch}/{epochs} | D_loss: {d_loss.item():.2f} | G_loss: {g_loss.item():.2f}")

    # ----- SAVE BEST MODEL -----
    if g_loss.item() < best_g_loss:
        best_g_loss = g_loss.item()
        best_epoch = epoch
        torch.save(generator.state_dict(), "best_generator.pth")
        torch.save(discriminator.state_dict(), "best_discriminator.pth")
        print(f" Best model saved at epoch {epoch}")

    # ----- SAVE BIG, CLEAR IMAGES -----
    if epoch % save_interval == 0:
        with torch.no_grad():
            z = torch.randn(25, noise_dim, device=device)
            samples = generator(z)
            samples = (samples + 1) / 2
            samples = F.interpolate(samples, scale_factor=4, mode="nearest")

            vutils.save_image(
                samples,
                f"generated_samples/epoch_{epoch:02d}.png",
                nrow=5,
                padding=2
            )

Epoch 1/50 | D_loss: 0.63 | G_loss: 1.57
 Best model saved at epoch 1
Epoch 2/50 | D_loss: 0.56 | G_loss: 2.23
Epoch 3/50 | D_loss: 0.46 | G_loss: 1.25
 Best model saved at epoch 3
Epoch 4/50 | D_loss: 0.47 | G_loss: 1.81
Epoch 5/50 | D_loss: 0.59 | G_loss: 0.95
 Best model saved at epoch 5
Epoch 6/50 | D_loss: 0.44 | G_loss: 2.06
Epoch 7/50 | D_loss: 0.44 | G_loss: 1.82
Epoch 8/50 | D_loss: 0.38 | G_loss: 1.97
Epoch 9/50 | D_loss: 0.38 | G_loss: 1.92
Epoch 10/50 | D_loss: 0.45 | G_loss: 1.57
Epoch 11/50 | D_loss: 0.48 | G_loss: 1.08
Epoch 12/50 | D_loss: 0.47 | G_loss: 1.67
Epoch 13/50 | D_loss: 0.53 | G_loss: 1.26
Epoch 14/50 | D_loss: 0.53 | G_loss: 1.64
Epoch 15/50 | D_loss: 0.55 | G_loss: 1.06
Epoch 16/50 | D_loss: 0.59 | G_loss: 1.36
Epoch 17/50 | D_loss: 0.58 | G_loss: 1.07
Epoch 18/50 | D_loss: 0.68 | G_loss: 1.69
Epoch 19/50 | D_loss: 0.61 | G_loss: 0.96
Epoch 20/50 | D_loss: 0.68 | G_loss: 0.72
 Best model saved at epoch 20
Epoch 21/50 | D_loss: 0.60 | G_loss: 1.35
Epoch 22/5

In [6]:
print("Training complete")
print("Best epoch:", best_epoch)
print("Best Generator Loss:", best_g_loss)

Training complete
Best epoch: 48
Best Generator Loss: 0.6044384241104126
