In [None]:
# PyTorch version of basic MNIST GAN (Colab-compatible)
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import os

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Hyperparameters
learning_rate = 2e-4        # Adam step size used for both networks
total_epoch = 100           # number of passes over the training set
batch_size = 100            # images per mini-batch
n_hidden = 256              # width of the single hidden layer in G and D
n_noise = 128               # length of the generator's latent noise vector
n_input = 28 * 28           # flattened MNIST image size (784 pixels)

# MNIST training split, converted to tensors with ToTensor (pixels scaled to
# [0, 1]) and downloaded on first use.
dataset = datasets.MNIST(
    root="./mnist/data/",
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Generator class
class Generator(nn.Module):
    """MLP generator mapping latent noise vectors to flattened images.

    Architecture: Linear -> ReLU -> Linear -> Sigmoid, so every output
    element lies in [0, 1].

    Args:
        noise_dim: Size of the latent input vector. Defaults to the
            module-level ``n_noise`` when omitted.
        hidden_dim: Width of the hidden layer. Defaults to ``n_hidden``.
        output_dim: Size of the flattened output image. Defaults to
            ``n_input``.
    """

    def __init__(self, noise_dim=None, hidden_dim=None, output_dim=None):
        super().__init__()
        # Resolve defaults at construction time so a bare ``Generator()``
        # still reads the module-level hyperparameters, exactly as before.
        noise_dim = n_noise if noise_dim is None else noise_dim
        hidden_dim = n_hidden if hidden_dim is None else hidden_dim
        output_dim = n_input if output_dim is None else output_dim

        # Attribute name and layer order are kept so state_dict keys
        # ("model.0.weight", "model.2.bias", ...) stay compatible.
        self.model = nn.Sequential(
            nn.Linear(noise_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        )

    def forward(self, z):
        """Map noise ``z`` (last dimension ``noise_dim``) to samples whose last dimension is ``output_dim``."""
        return self.model(z)

# Discriminator class
class Discriminator(nn.Module):
    """MLP discriminator scoring flattened images as real (near 1) or fake (near 0).

    Architecture: Linear -> ReLU -> Linear -> Sigmoid, producing one
    probability-like score per input row.

    Args:
        input_dim: Size of the flattened input image. Defaults to the
            module-level ``n_input`` when omitted.
        hidden_dim: Width of the hidden layer. Defaults to ``n_hidden``.
    """

    def __init__(self, input_dim=None, hidden_dim=None):
        super().__init__()
        # Resolve defaults at construction time so a bare ``Discriminator()``
        # still reads the module-level hyperparameters, exactly as before.
        input_dim = n_input if input_dim is None else input_dim
        hidden_dim = n_hidden if hidden_dim is None else hidden_dim

        # Attribute name and layer order are kept so state_dict keys stay
        # compatible with checkpoints from the original class.
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Score ``x`` (last dimension ``input_dim``); output's last dimension is 1, values in [0, 1]."""
        return self.model(x)

# Build both networks (generator first, then discriminator) and move their
# parameters to the selected device.
generator, discriminator = Generator().to(device), Discriminator().to(device)

# Binary cross-entropy on the sigmoid outputs, plus one Adam optimizer per
# network so each step only updates that network's own parameters.
criterion = nn.BCELoss()
optimizer_G = optim.Adam(params=generator.parameters(), lr=learning_rate)
optimizer_D = optim.Adam(params=discriminator.parameters(), lr=learning_rate)

# Directory for the sample strips written during training; reused if present.
os.makedirs(name="samples_ex", exist_ok=True)

# Training loop
for epoch in range(total_epoch):
    # NOTE: the d_loss / g_loss printed after this loop come from the epoch's
    # LAST mini-batch only, not an epoch average.
    for real_imgs, _ in data_loader:
        # Flatten each image to a vector of n_input pixels for the MLPs.
        real_imgs = real_imgs.view(-1, n_input).to(device)
        # Actual size of this mini-batch. It uses its own name so the global
        # `batch_size` hyperparameter is not silently overwritten.
        cur_batch = real_imgs.size(0)

        # BCE targets: 1 for real images, 0 for generated ones.
        real_labels = torch.ones(cur_batch, 1, device=device)
        fake_labels = torch.zeros(cur_batch, 1, device=device)

        # ---- Train Discriminator ----
        z = torch.randn(cur_batch, n_noise, device=device)
        fake_imgs = generator(z)

        real_output = discriminator(real_imgs)
        # detach() keeps the discriminator loss from back-propagating into G.
        fake_output = discriminator(fake_imgs.detach())

        d_loss = criterion(real_output, real_labels) + criterion(fake_output, fake_labels)

        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # ---- Train Generator ----
        # Non-saturating objective: G is rewarded when D labels fresh samples
        # as real. This backward also writes into D's .grad buffers; those
        # are cleared by optimizer_D.zero_grad() before D's next step.
        z = torch.randn(cur_batch, n_noise, device=device)
        fake_imgs = generator(z)
        output = discriminator(fake_imgs)
        g_loss = criterion(output, real_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

    print(f"Epoch [{epoch+1:03d}/{total_epoch}]  D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # Save a strip of 10 generated samples after the first epoch and every
    # 10th epoch. File names use the 0-based epoch index, so 009.png matches
    # the log line "Epoch [010/...]".
    if epoch == 0 or (epoch + 1) % 10 == 0:
        # Inference only: no autograd graph is needed for preview images.
        with torch.no_grad():
            sample_z = torch.randn(10, n_noise, device=device)
            samples = generator(sample_z).view(-1, 28, 28).cpu().numpy()

        fig, axes = plt.subplots(1, 10, figsize=(10, 1))
        for axis, img in zip(axes, samples):
            axis.imshow(img, cmap='gray')
            axis.axis('off')
        # Save this specific figure rather than relying on pyplot's "current" one.
        fig.savefig(f"samples_ex/{epoch:03d}.png", bbox_inches='tight')
        plt.close(fig)



Epoch [001/100]  D Loss: 0.1141, G Loss: 3.1040
Epoch [002/100]  D Loss: 0.1328, G Loss: 3.3868
Epoch [003/100]  D Loss: 0.0940, G Loss: 4.0175
Epoch [004/100]  D Loss: 0.1007, G Loss: 5.1259
Epoch [005/100]  D Loss: 0.0797, G Loss: 3.4724
Epoch [006/100]  D Loss: 0.1729, G Loss: 2.9989
Epoch [007/100]  D Loss: 0.1354, G Loss: 3.2387
Epoch [008/100]  D Loss: 0.1438, G Loss: 3.7516
Epoch [009/100]  D Loss: 0.1320, G Loss: 3.8958
Epoch [010/100]  D Loss: 0.3066, G Loss: 2.6786
Epoch [011/100]  D Loss: 0.1824, G Loss: 3.5589
Epoch [012/100]  D Loss: 0.2730, G Loss: 2.9390
Epoch [013/100]  D Loss: 0.2475, G Loss: 3.0807
Epoch [014/100]  D Loss: 0.2913, G Loss: 3.4162
Epoch [015/100]  D Loss: 0.2575, G Loss: 3.1665
Epoch [016/100]  D Loss: 0.2859, G Loss: 3.7177
Epoch [017/100]  D Loss: 0.1628, G Loss: 4.1763
Epoch [018/100]  D Loss: 0.2766, G Loss: 2.9310
Epoch [019/100]  D Loss: 0.2658, G Loss: 3.8986
Epoch [020/100]  D Loss: 0.2254, G Loss: 3.7210
Epoch [021/100]  D Loss: 0.1731, G Loss: