In [None]:
import os
import math
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reproducibility: seed PyTorch's global RNG once, up front, so that weight
# initialisation, latent-noise sampling and DataLoader shuffling start from
# the same state on every "Restart & Run All".
seed = 42
torch.manual_seed(seed)

# Hyperparameters
lr_gen = 1e-4        # Generator learning rate
lr_disc = 4e-5       # Discriminator learning rate
z_dim = 128          # Length of the latent noise vector fed to the generator
img_dim = 28 * 28    # Number of pixels in one flattened image
batch_size = 128
epochs = 200

# Output dir: defaults to the Kaggle working directory; set the GAN_OUTPUT_DIR
# environment variable to write images/checkpoints somewhere else.
output_dir = os.environ.get("GAN_OUTPUT_DIR", "/kaggle/working/generated_images")
os.makedirs(output_dir, exist_ok=True)


**Creating the discriminator and generator**

In [None]:
class Discriminator(nn.Module):
    """
    Fully connected discriminator over flattened images.

    Input:  tensor whose last dimension is img_dim
    Output: one raw logit per sample (no sigmoid is applied here)
    """
    def __init__(self, img_dim):
        super().__init__()
        # Widths of the layers that are each followed by LeakyReLU + Dropout.
        widths = [img_dim, 128, 256, 512, 256]
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Linear(n_in, n_out),
                nn.LeakyReLU(0.2),
                nn.Dropout(0.3),
            ])
        # Head: one more hidden layer without dropout, then the logit.
        layers.extend([
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),  # logits, no sigmoid
        ])
        # Kept as one flat Sequential so module indices (disc.0, disc.1, ...)
        # follow the order in which the layers were appended.
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.disc(x)
        return logits

class Generator(nn.Module):
    """
    Fully connected generator mapping latent noise to flattened images.

    Input:  (batch, z_dim)
    Output: (batch, img_dim), bounded to [-1, 1] by the final Tanh
    """
    def __init__(self, z_dim=128, img_dim=28*28):
        super().__init__()
        hidden_widths = (256, 512, 1024, 512, 256)
        stack = []
        fan_in = z_dim
        for fan_out in hidden_widths:
            # The Linear layers carry no bias: each is followed directly by a
            # BatchNorm1d, which has its own learnable shift.
            stack.append(nn.Linear(fan_in, fan_out, bias=False))
            stack.append(nn.BatchNorm1d(fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            fan_in = fan_out

        # Output projection (with bias) followed by Tanh.
        stack.append(nn.Linear(fan_in, img_dim))
        stack.append(nn.Tanh())

        # One flat Sequential, so module indices follow append order.
        self.gen = nn.Sequential(*stack)

    def forward(self, x):
        generated = self.gen(x)
        return generated


**Initialising weights**

In [None]:
def weights_init(m):
    """
    Re-initialise one module in place; intended for use with ``Module.apply``.

    Conv2d / ConvTranspose2d / Linear: weight ~ N(0.0, 0.02), bias = 0.
    BatchNorm1d / BatchNorm2d:         weight ~ N(1.0, 0.02), bias = 0.
    Any other module type is left untouched. A ``weight`` or ``bias`` that is
    missing or None (e.g. ``bias=False``) is skipped.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        weight_mean = 0.0
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d)):
        weight_mean = 1.0
    else:
        return

    weight = getattr(m, "weight", None)
    if weight is not None:
        nn.init.normal_(weight, weight_mean, 0.02)

    bias = getattr(m, "bias", None)
    if bias is not None:
        nn.init.constant_(bias, 0.0)

# Build both networks and place them on the target device.
disc = Discriminator(img_dim)
disc = disc.to(device)
gen = Generator(z_dim=z_dim, img_dim=img_dim)
gen = gen.to(device)

# Module.apply calls weights_init on every submodule (and on the model itself),
# replacing PyTorch's default initialisation.
disc.apply(weights_init)
gen.apply(weights_init)

# Latent batch sampled once and reused, so successive visualisations are
# generated from the same noise vectors.
fixed_noise = torch.randn((batch_size, z_dim), device=device)

disc_param_count = sum(w.numel() for w in disc.parameters())
gen_param_count = sum(w.numel() for w in gen.parameters())

print("Models initialized and moved to device")
print(f"Discriminator parameters: {disc_param_count:,}")
print(f"Generator parameters: {gen_param_count:,}")


**Prepping the data**

In [None]:
# ToTensor scales pixels to [0, 1]; Normalize with mean 0.5 and std 0.5 then
# maps them to [-1, 1].
norm_mean, norm_std = (0.5,), (0.5,)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(norm_mean, norm_std)]
)

# Downloads MNIST into ./dataset/ on first use.
dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

print(f"Dataset size: {len(dataset)}")
print(f"Number of batches: {len(loader)}")

# Pull a single batch as a sanity check on tensor shape and value range.
sample_batch = next(iter(loader))
sample_images = sample_batch[0]
print(f"Sample batch shape: {sample_images.shape}")
print(f"Sample data range: [{sample_images.min():.3f}, {sample_images.max():.3f}]")


**Defining losses and optimizer**

In [None]:
# Loss function: binary cross-entropy computed directly on raw logits.
criterion = nn.BCEWithLogitsLoss()

# Both optimizers use the same Adam betas; only the learning rates differ.
adam_betas = (0.5, 0.999)
opt_disc = optim.Adam(disc.parameters(), lr=lr_disc, betas=adam_betas)
opt_gen = optim.Adam(gen.parameters(), lr=lr_gen, betas=adam_betas)

# Each scheduler halves its optimizer's learning rate once the metric passed to
# .step() has stopped decreasing for longer than `patience` calls.
plateau_kwargs = dict(mode='min', factor=0.5, patience=5)
scheduler_disc = optim.lr_scheduler.ReduceLROnPlateau(opt_disc, **plateau_kwargs)
scheduler_gen = optim.lr_scheduler.ReduceLROnPlateau(opt_gen, **plateau_kwargs)


**Training loop**

In [None]:
# Training loop.
#
# Per batch: one discriminator step on (real, detached fake), then one generator
# step that tries to make the discriminator label the same fakes as real.
# Per epoch: step both plateau schedulers on the epoch-average losses, optionally
# save a real-vs-fake image grid, and checkpoint the generator whenever its
# average loss reaches a new minimum.
print(f"Starting training on {device}")
print("=" * 50)

best_lossG = float('inf')

# Side length of the square image that img_dim pixels encode (28 for MNIST);
# used only to un-flatten tensors for display.
img_side = int(math.sqrt(img_dim))

for epoch in range(epochs):
    loop = tqdm(loader, leave=True, desc=f"Epoch {epoch+1}/{epochs}")
    epoch_lossD = 0.0
    epoch_lossG = 0.0
    batches = 0

    for batch_idx, (real, _) in enumerate(loop):
        # Flatten each image to a vector of img_dim pixels.
        real = real.view(-1, img_dim).to(device)
        bsz = real.size(0)

        # Label smoothing within [0,1] for BCEWithLogitsLoss:
        # real targets are drawn from [0.7, 1.0], fake targets from [0.0, 0.3].
        real_labels = torch.empty(bsz, device=device).uniform_(0.7, 1.0)
        fake_labels = torch.empty(bsz, device=device).uniform_(0.0, 0.3)

        # Train Discriminator
        noise = torch.randn(bsz, z_dim, device=device)
        fake = gen(noise)

        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, real_labels)

        # detach() keeps the discriminator loss from back-propagating into the generator.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, fake_labels)

        lossD = 0.5 * (lossD_real + lossD_fake)
        opt_disc.zero_grad()
        lossD.backward()
        torch.nn.utils.clip_grad_norm_(disc.parameters(), max_norm=1.0)
        opt_disc.step()

        # Train Generator (wants disc(fake) to be 'real' = 1s)
        # This backward pass also leaves gradients on the discriminator's
        # parameters; they are cleared by opt_disc.zero_grad() on the next batch.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        lossG.backward()
        torch.nn.utils.clip_grad_norm_(gen.parameters(), max_norm=1.0)
        opt_gen.step()

        # Read each scalar loss once and reuse it for both the running sums
        # and the progress bar.
        lossD_value = lossD.item()
        lossG_value = lossG.item()
        epoch_lossD += lossD_value
        epoch_lossG += lossG_value
        batches += 1

        loop.set_postfix(lossD=lossD_value, lossG=lossG_value)

    avg_lossD = epoch_lossD / batches
    avg_lossG = epoch_lossG / batches

    scheduler_disc.step(avg_lossD)
    scheduler_gen.step(avg_lossG)

    # Visualize and save images every 10 epochs (or first epoch)
    if (epoch + 1) % 10 == 0 or epoch == 0:
        # Sample in eval mode: BatchNorm then uses its running statistics, so the
        # fixed_noise batch neither normalises itself nor updates the running
        # stats that get written into the checkpoint. The previous mode is
        # restored afterwards.
        gen_was_training = gen.training
        gen.eval()
        with torch.no_grad():
            fake_img = gen(fixed_noise).view(-1, 1, img_side, img_side).cpu()
            fake_img = (fake_img + 1) / 2  # [0,1] for display
            # `real` is the last batch seen in this epoch.
            real_img = real[:64].view(-1, 1, img_side, img_side).cpu()
            real_img = (real_img + 1) / 2  # [0,1]
        gen.train(gen_was_training)

        grid_fake = torchvision.utils.make_grid(fake_img[:64], nrow=8, normalize=False)
        grid_real = torchvision.utils.make_grid(real_img[:64], nrow=8, normalize=False)

        fig, (ax_fake, ax_real) = plt.subplots(1, 2, figsize=(12, 6))
        ax_fake.set_title(f"Fake Images - Epoch {epoch+1}")
        ax_fake.axis("off")
        ax_fake.imshow(grid_fake.permute(1, 2, 0).squeeze(), cmap="gray")
        ax_real.set_title("Real Images")
        ax_real.axis("off")
        ax_real.imshow(grid_real.permute(1, 2, 0).squeeze(), cmap="gray")
        fig.tight_layout()
        save_path_img = os.path.join(output_dir, f"epoch_{epoch+1:03d}_real_vs_fake.png")
        fig.savefig(save_path_img)
        plt.show()
        # Release the figure so figures do not accumulate across epochs.
        plt.close(fig)

    # Save best generator model checkpoint
    # ("best" here means lowest epoch-average generator loss so far).
    if avg_lossG < best_lossG:
        best_lossG = avg_lossG
        model_save_path = os.path.join(output_dir, f"best_gen_epoch_{epoch+1:03d}.pth")
        torch.save(gen.state_dict(), model_save_path)
        print(f"Saved best generator model at epoch {epoch+1} with loss {best_lossG:.4f}")

print("Training completed!")
