In [None]:
dataset_choice = "mnist"
epochs = 50
batch_size = 128
noise_dim = 100
learning_rate = 0.0002
save_interval = 5

import torch
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

Device: cuda


In [None]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

img_size = 28
channels = 1

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

if dataset_choice == "mnist":
    dataset = datasets.MNIST("./data", train=True, download=True, transform=transform)
else:
    dataset = datasets.FashionMNIST("./data", train=True, download=True, transform=transform)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print("Dataset loaded:", dataset_choice)

100%|██████████| 9.91M/9.91M [00:00<00:00, 19.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 468kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.49MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 13.7MB/s]

Dataset loaded: mnist





In [None]:
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),

            nn.Linear(1024, img_size * img_size),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(z.size(0), channels, img_size, img_size)


class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_size * img_size, 512),
            nn.LeakyReLU(0.2),

            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),

            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, img):
        return self.model(img.view(img.size(0), -1))


generator = Generator().to(device)
discriminator = Discriminator().to(device)
print("Models initialized")

Models initialized


In [None]:
import torch.optim as optim

criterion = nn.BCELoss()

optimizer_G = optim.Adam(generator.parameters(), lr=learning_rate, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.999))

In [None]:
import os
import torchvision.utils as vutils
import torch.nn.functional as F

os.makedirs("generated_samples", exist_ok=True)

best_g_loss = float("inf")
best_epoch = -1

for epoch in range(1, epochs + 1):
    for real_imgs, _ in dataloader:
        real_imgs = real_imgs.to(device)
        bs = real_imgs.size(0)

        real_labels = torch.full((bs, 1), 0.9, device=device)
        fake_labels = torch.zeros(bs, 1, device=device)

        optimizer_D.zero_grad()

        real_loss = criterion(discriminator(real_imgs), real_labels)

        z = torch.randn(bs, noise_dim, device=device)
        fake_imgs = generator(z)
        fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)

        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        optimizer_G.zero_grad()
        g_loss = criterion(discriminator(fake_imgs), real_labels)
        g_loss.backward()
        optimizer_G.step()

        print(f"Epoch {epoch}/{epochs} | D_loss: {d_loss.item():.2f} | G_loss: {g_loss.item():.2f}")

        if g_loss.item() < best_g_loss:
            best_g_loss = g_loss.item()
            best_epoch = epoch
            torch.save(generator.state_dict(), "best_generator.pth")
            torch.save(discriminator.state_dict(), "best_discriminator.pth")
            print(f" Best model saved at epoch {epoch}")

        if epoch % save_interval == 0:
            with torch.no_grad():
                z = torch.randn(25, noise_dim, device=device)
                samples = generator(z)
                samples = (samples + 1) / 2
                samples = F.interpolate(samples, scale_factor=4, mode="nearest")

                vutils.save_image(
                    samples,
                    f"generated_samples/epoch_{epoch:02d}.png",
                    nrow=5,
                    padding=2
                )

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch 40/50 | D_loss: 0.54 | G_loss: 1.08
Epoch 40/50 | D_loss: 0.56 | G_loss: 1.34
Epoch 40/50 | D_loss: 0.58 | G_loss: 1.01
Epoch 40/50 | D_loss: 0.55 | G_loss: 1.19
Epoch 40/50 | D_loss: 0.58 | G_loss: 1.34
Epoch 40/50 | D_loss: 0.53 | G_loss: 1.04
Epoch 40/50 | D_loss: 0.59 | G_loss: 1.15
Epoch 40/50 | D_loss: 0.57 | G_loss: 1.06
Epoch 40/50 | D_loss: 0.57 | G_loss: 1.35
Epoch 40/50 | D_loss: 0.56 | G_loss: 0.95
Epoch 40/50 | D_loss: 0.56 | G_loss: 1.45
Epoch 40/50 | D_loss: 0.58 | G_loss: 0.80
Epoch 40/50 | D_loss: 0.59 | G_loss: 1.75
Epoch 40/50 | D_loss: 0.62 | G_loss: 0.71
Epoch 40/50 | D_loss: 0.63 | G_loss: 1.70
Epoch 40/50 | D_loss: 0.59 | G_loss: 0.86
Epoch 40/50 | D_loss: 0.58 | G_loss: 1.23
Epoch 40/50 | D_loss: 0.55 | G_loss: 1.23
Epoch 40/50 | D_loss: 0.58 | G_loss: 1.08
Epoch 40/50 | D_loss: 0.58 | G_loss: 1.15
Epoch 40/50 | D_loss: 0.53 | G_loss: 1.20
Epoch 40/50 | D_loss: 0.59 | G_loss: 0.96
Epoch 40/50

In [None]:
print("Training complete")
print("Best epoch:", best_epoch)
print("Best Generator Loss:", best_g_loss)


Training complete
Best epoch: 8
Best Generator Loss: 0.43837398290634155
