In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import make_grid
import os

# Seed PyTorch's random number generator so repeated runs draw the same values
torch.manual_seed(42)

# Use the GPU when PyTorch reports one as available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
epochs = 100              # passes over the training set
batch_size = 64           # images per optimisation step
lr = 2e-4                 # learning rate handed to both Adam optimizers
latent_dim = 100          # length of the noise vector fed to the generator
img_shape = (1, 28, 28)   # shape each generated image is reshaped to

# Folder that receives sample grids and the loss plot (kept if already present)
os.makedirs("gan_samples", exist_ok=True)

# Generator Network
class Generator(nn.Module):
    """Fully connected generator that maps latent noise vectors to images.

    The input passes through four Linear + LeakyReLU(0.2) stages (the last
    three are batch-normalised), then a final Linear + Tanh. The flat output
    is reshaped to ``(batch, *img_shape)``; Tanh keeps values in [-1, 1].

    Reads the module-level ``latent_dim`` and ``img_shape`` settings.
    """

    def __init__(self):
        super().__init__()

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear -> [BatchNorm1d] -> LeakyReLU stage."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional parameter is `eps`, so the
                # former `BatchNorm1d(out_feat, 0.8)` added 0.8 to the batch
                # variance instead of configuring the running statistics.
                # Pass the value by keyword so it lands on `momentum`.
                # NOTE(review): in PyTorch, momentum is the weight of the
                # newest batch statistic; if 0.8 was meant as the decay of the
                # running average, the equivalent setting is momentum=0.2.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # No normalisation on the first stage
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            # One output unit per image element
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate one image per row of ``z``.

        Args:
            z: Noise tensor whose last dimension is ``latent_dim``.

        Returns:
            Tensor of shape ``(z.size(0), *img_shape)``.
        """
        img = self.model(z)
        img = img.view(img.size(0), *img_shape)
        return img

# Discriminator Network
class Discriminator(nn.Module):
    """Fully connected classifier that scores each image with a value in [0, 1].

    Images are flattened and passed through two Linear + LeakyReLU(0.2)
    hidden layers (512 and 256 units), then a single-unit Linear + Sigmoid.

    Reads the module-level ``img_shape`` setting for the input width.
    """

    def __init__(self):
        super().__init__()

        # Assemble the stack layer by layer, tracking the running input width
        width_in = int(np.prod(img_shape))
        stack = []
        for width_out in (512, 256):
            stack.append(nn.Linear(width_in, width_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            width_in = width_out
        stack.append(nn.Linear(width_in, 1))
        stack.append(nn.Sigmoid())

        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return a ``(batch, 1)`` tensor of sigmoid scores for ``img``."""
        # Collapse everything after the batch dimension into one vector
        flattened = img.view(img.shape[0], -1)
        return self.model(flattened)

# Build both networks, then move them (in place) onto the selected device
generator = Generator()
generator.to(device)
discriminator = Discriminator()
discriminator.to(device)

# One Adam optimizer per network, both using the same learning rate
optimizer_G, optimizer_D = (
    optim.Adam(net.parameters(), lr=lr) for net in (generator, discriminator)
)

# Binary cross-entropy between discriminator scores and the 0/1 targets
adversarial_loss = nn.BCELoss()

# Data pipeline: convert to tensors, then shift/scale with mean 0.5 and
# std 0.5 so real images share the [-1, 1] range of the generator's Tanh
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# MNIST training split, downloaded into ./data when it is not there yet
dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
dataloader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

# Sample visualization function
def sample_images(epoch, n_row=4):
    """Save an ``n_row`` x ``n_row`` grid of generated digits as a PNG.

    The image is written to ``gan_samples/epoch_{epoch}.png``.

    Args:
        epoch: Value embedded in the output file name.
        n_row: Number of images per row and per column of the grid.
    """
    # Sampling is inference only: disable autograd so no graph is built
    # (and kept alive) for the generated batch.
    # NOTE: the generator's train/eval mode is left untouched here.
    with torch.no_grad():
        z = torch.randn(n_row**2, latent_dim, device=device)
        gen_imgs = generator(z)
    gen_imgs = 0.5 * gen_imgs + 0.5  # Rescale from [-1, 1] to [0, 1]

    # make_grid returns channels-first; imshow expects channels-last
    grid = make_grid(gen_imgs.cpu(), nrow=n_row).permute(1, 2, 0)

    fig, ax = plt.subplots(figsize=(4, 4))
    try:
        ax.set_xticks([])
        ax.set_yticks([])
        ax.imshow(grid)
        fig.savefig(f"gan_samples/epoch_{epoch}.png")
    finally:
        # Close this specific figure, even if saving failed, so figures
        # do not accumulate across calls.
        plt.close(fig)

# Training loop
G_losses = []  # generator loss, one entry per training step
D_losses = []  # discriminator loss, one entry per training step

# Save sample grids after the first, middle and last epoch. Deriving the set
# from `epochs` keeps the schedule valid when the epoch count changes
# (for epochs = 100 this is epochs 0, 50 and 99).
sample_epochs = {0, epochs // 2, epochs - 1}
num_batches = len(dataloader)

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # The last batch of an epoch may be smaller than batch_size
        cur_batch = imgs.size(0)

        # Adversarial ground truths: 1 for real, 0 for generated
        valid = torch.ones(cur_batch, 1, device=device)
        fake = torch.zeros(cur_batch, 1, device=device)

        # Configure input
        real_imgs = imgs.to(device)

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_loss = adversarial_loss(discriminator(real_imgs), valid)

        # Generate fake images
        z = torch.randn(cur_batch, latent_dim, device=device)
        gen_imgs = generator(z)

        # Loss for fake images; detach so this backward pass does not
        # reach the generator's parameters
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)

        # Total discriminator loss
        d_loss = (real_loss + fake_loss) / 2
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Generate a fresh batch of fake images
        z = torch.randn(cur_batch, latent_dim, device=device)
        gen_imgs = generator(z)

        # Generator wants discriminator to think these are real
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)
        g_loss.backward()
        optimizer_G.step()

        # Read each loss scalar once and reuse it for the history and the log
        g_loss_value = g_loss.item()
        d_loss_value = d_loss.item()
        G_losses.append(g_loss_value)
        D_losses.append(d_loss_value)

        # Print progress
        if i % 200 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{num_batches}] "
                  f"[D loss: {d_loss_value:.4f}] [G loss: {g_loss_value:.4f}]")

    # Save sample images at the scheduled epochs
    if epoch in sample_epochs:
        sample_images(epoch)

# Draw both per-step loss histories on one axes and write the chart to disk
loss_fig, loss_ax = plt.subplots(figsize=(10, 5))
loss_ax.set_title("Generator and Discriminator Loss During Training")
for history, legend_name in ((G_losses, "Generator"), (D_losses, "Discriminator")):
    loss_ax.plot(history, label=legend_name)
loss_ax.set_xlabel("Iterations")
loss_ax.set_ylabel("Loss")
loss_ax.legend()
loss_fig.savefig("gan_samples/loss_plot.png")
plt.close(loss_fig)

print("Training complete!")

100%|██████████| 9.91M/9.91M [00:00<00:00, 56.4MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.67MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.3MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 6.28MB/s]


[Epoch 0/100] [Batch 0/938] [D loss: 0.7061] [G loss: 0.6776]
[Epoch 0/100] [Batch 200/938] [D loss: 0.0432] [G loss: 2.5940]
[Epoch 0/100] [Batch 400/938] [D loss: 0.3094] [G loss: 2.1380]
[Epoch 0/100] [Batch 600/938] [D loss: 0.1768] [G loss: 2.8983]
[Epoch 0/100] [Batch 800/938] [D loss: 0.2819] [G loss: 2.8275]
[Epoch 1/100] [Batch 0/938] [D loss: 0.1696] [G loss: 3.7236]
[Epoch 1/100] [Batch 200/938] [D loss: 0.1874] [G loss: 4.4389]
[Epoch 1/100] [Batch 400/938] [D loss: 0.1335] [G loss: 3.9873]
[Epoch 1/100] [Batch 600/938] [D loss: 0.1258] [G loss: 5.0432]
[Epoch 1/100] [Batch 800/938] [D loss: 0.1508] [G loss: 4.0957]
[Epoch 2/100] [Batch 0/938] [D loss: 0.0832] [G loss: 4.4976]
[Epoch 2/100] [Batch 200/938] [D loss: 0.2101] [G loss: 4.3432]
[Epoch 2/100] [Batch 400/938] [D loss: 0.0556] [G loss: 4.7394]
[Epoch 2/100] [Batch 600/938] [D loss: 0.1160] [G loss: 4.2406]
[Epoch 2/100] [Batch 800/938] [D loss: 0.0454] [G loss: 5.6513]
[Epoch 3/100] [Batch 0/938] [D loss: 0.0944] [