In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision


# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# --- Optimisation settings ---
LEARNING_RATE = 0.00005  # step size passed to both RMSprop optimizers
BATCH_SIZE = 64          # samples per DataLoader batch / noise vectors per step
NUM_EPOCHS = 5           # full passes over the training set
CRITIC_ITERATIONS = 5    # critic updates performed per generator update
WEIGHT_CLIP = 1e-2       # critic parameters are clamped to [-WEIGHT_CLIP, WEIGHT_CLIP]

# --- Data / model dimensions ---
IMAGE_SIZE = 64          # target size handed to transforms.Resize
CHANNELS_IMG = 1         # image channels (single-channel input)
Z_DIM = 100              # length of the latent noise vector
FEATURES_CRITIC = 16     # base channel width of the critic
FEATURES_GEN = 16        # base channel width of the generator


class Critic(nn.Module):
    """DCGAN-style WGAN critic for 64x64 images.

    Four stride-2 convolutions halve the spatial size each time
    (64 -> 32 -> 16 -> 8 -> 4); a final 4x4 convolution without padding then
    collapses the 4x4 map to a single unbounded score per sample. There is no
    sigmoid: a WGAN critic outputs a raw score, not a probability.

    Args:
        channels_img: Number of channels in the input images.
        features_d: Base channel width; deeper layers use multiples of it.
    """

    def __init__(self, channels_img, features_d):
        super().__init__()
        self.critic = nn.Sequential(
            # Input: N x channels_img x 64 x 64
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # Block 2: 32x32 -> 16x16
            nn.Conv2d(features_d, features_d * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features_d * 2),
            nn.LeakyReLU(0.2),
            # Block 3: 16x16 -> 8x8
            nn.Conv2d(features_d * 2, features_d * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features_d * 4),
            nn.LeakyReLU(0.2),
            # Block 4: 8x8 -> 4x4. Without this stage the final convolution
            # sees an 8x8 map and emits a 3x3 grid of scores instead of one.
            nn.Conv2d(features_d * 4, features_d * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features_d * 8),
            nn.LeakyReLU(0.2),
            # Output: N x 1 x 1 x 1 scalar score (No Sigmoid for WGAN)
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor of shape N x channels_img x 64 x 64.

        Returns:
            Tensor of shape N x 1 x 1 x 1 holding one score per image.
        """
        return self.critic(x)

class Generator(nn.Module):
    """DCGAN-style generator mapping latent noise to 64x64 images.

    The first transposed convolution expands the 1x1 latent to 4x4; four
    stride-2 transposed convolutions then double the size each time
    (4 -> 8 -> 16 -> 32 -> 64). The final Tanh bounds pixels to [-1, 1].

    Args:
        z_dim: Number of channels of the latent noise input.
        channels_img: Number of channels in the generated images.
        features_g: Base channel width; earlier layers use multiples of it.
    """

    def __init__(self, z_dim, channels_img, features_g):
        super().__init__()
        self.gen = nn.Sequential(
            # Input: N x z_dim x 1 x 1 -> 4x4
            nn.ConvTranspose2d(z_dim, features_g * 16, kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(features_g * 16),
            nn.ReLU(),
            # Block 2: 4x4 -> 8x8
            nn.ConvTranspose2d(features_g * 16, features_g * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features_g * 8),
            nn.ReLU(),
            # Block 3: 8x8 -> 16x16
            nn.ConvTranspose2d(features_g * 8, features_g * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features_g * 4),
            nn.ReLU(),
            # Block 4: 16x16 -> 32x32. Without this stage the network stops
            # at 32x32 and never reaches the documented 64x64 output.
            nn.ConvTranspose2d(features_g * 4, features_g * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(features_g * 2),
            nn.ReLU(),
            # Output: N x channels_img x 64 x 64
            nn.ConvTranspose2d(features_g * 2, channels_img, kernel_size=4, stride=2, padding=1),
            nn.Tanh(), # Output: [-1, 1]
        )

    def forward(self, x):
        """Generate a batch of images.

        Args:
            x: Latent tensor of shape N x z_dim x 1 x 1.

        Returns:
            Tensor of shape N x channels_img x 64 x 64 with values in [-1, 1].
        """
        return self.gen(x)


# Instantiate the generator first and the critic second, then move each
# network onto the selected device.
gen = Generator(z_dim=Z_DIM, channels_img=CHANNELS_IMG, features_g=FEATURES_GEN)
gen = gen.to(device)
critic = Critic(channels_img=CHANNELS_IMG, features_d=FEATURES_CRITIC)
critic = critic.to(device)

# One RMSprop optimizer per network, both using the same learning rate.
opt_gen, opt_critic = (
    optim.RMSprop(net.parameters(), lr=LEARNING_RATE) for net in (gen, critic)
)


# Per-channel mean and std of 0.5, so ToTensor's [0, 1] range is mapped to
# [-1, 1] by Normalize.
_channel_mean = [0.5] * CHANNELS_IMG
_channel_std = [0.5] * CHANNELS_IMG

# Preprocessing pipeline: resize, convert to tensor, then normalize.
my_transforms = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_channel_mean, _channel_std),
])

# MNIST training split, downloaded into dataset/ when not already present.
dataset = datasets.MNIST(
    root="dataset/",
    train=True,
    transform=my_transforms,
    download=True,
)
# Shuffled mini-batches of BATCH_SIZE samples.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)


print(f"Starting training on {device}...")

# WGAN training with weight clipping: for every batch of real images the
# critic is updated CRITIC_ITERATIONS times, then the generator once.
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.to(device)
        # The last batch of an epoch may hold fewer than BATCH_SIZE samples;
        # draw the same number of fakes so both means cover equal counts.
        cur_batch_size = real.shape[0]

        # --- TRAIN CRITIC ---
        for _ in range(CRITIC_ITERATIONS):
            noise = torch.randn(cur_batch_size, Z_DIM, 1, 1, device=device)
            fake = gen(noise)

            critic_real = critic(real).reshape(-1)
            # Detach so the critic's backward pass stops at the fake images:
            # no gradients are computed for the generator here, and the
            # generator graph stays intact for the generator step below
            # (which is why retain_graph is not needed).
            critic_fake = critic(fake.detach()).reshape(-1)

            # WGAN Loss: Minimize -(Mean(Real) - Mean(Fake))
            loss_critic = -(torch.mean(critic_real) - torch.mean(critic_fake))

            critic.zero_grad()
            loss_critic.backward()
            opt_critic.step()

            # Weight Clipping: clamp every critic parameter in place, outside
            # of autograd tracking.
            with torch.no_grad():
                for p in critic.parameters():
                    p.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # --- TRAIN GENERATOR ---
        # Re-score the fakes from the last critic iteration with the updated
        # critic; the generator minimises the negated mean critic score.
        output = critic(fake).reshape(-1)
        loss_gen = -torch.mean(output)

        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # --- PRINT UPDATES ---
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch+1}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \t"
                f"Loss C: {loss_critic:.4f} Loss G: {loss_gen:.4f}"
            )

print("Training Finished Successfully!")

100%|██████████| 9.91M/9.91M [00:00<00:00, 12.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 344kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.17MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.93MB/s]


Starting training on cuda...
Epoch [1/5] Batch 0/938 	Loss C: -0.0007 Loss G: -0.0093
Epoch [1/5] Batch 100/938 	Loss C: -0.0416 Loss G: 0.0338
Epoch [1/5] Batch 200/938 	Loss C: -0.0558 Loss G: 0.0575
Epoch [1/5] Batch 300/938 	Loss C: -0.0611 Loss G: 0.0630
Epoch [1/5] Batch 400/938 	Loss C: -0.0627 Loss G: 0.0618
Epoch [1/5] Batch 500/938 	Loss C: -0.0634 Loss G: 0.0604
Epoch [1/5] Batch 600/938 	Loss C: -0.0640 Loss G: 0.0606
Epoch [1/5] Batch 700/938 	Loss C: -0.0644 Loss G: 0.0605
Epoch [1/5] Batch 800/938 	Loss C: -0.0643 Loss G: 0.0605
Epoch [1/5] Batch 900/938 	Loss C: -0.0649 Loss G: 0.0598
Epoch [2/5] Batch 0/938 	Loss C: -0.0649 Loss G: 0.0598
Epoch [2/5] Batch 100/938 	Loss C: -0.0651 Loss G: 0.0594
Epoch [2/5] Batch 200/938 	Loss C: -0.0648 Loss G: 0.0590
Epoch [2/5] Batch 300/938 	Loss C: -0.0648 Loss G: 0.0589
Epoch [2/5] Batch 400/938 	Loss C: -0.0648 Loss G: 0.0587
Epoch [2/5] Batch 500/938 	Loss C: -0.0647 Loss G: 0.0582
Epoch [2/5] Batch 600/938 	Loss C: -0.0650 Los