In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Lazily enumerate the full path of every file found beneath the Kaggle
# input directory, then print one path per line in os.walk order.
input_files = (
    os.path.join(dirname, filename)
    for dirname, _, filenames in os.walk('/kaggle/input')
    for filename in filenames
)
for input_path in input_files:
    print(input_path)

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

# ========== Hyperparameters ==========
latent_dim = 100    # length of the noise vector fed to the generator
img_size = 28       # images are img_size x img_size pixels
channels = 1        # number of image channels
batch_size = 64     # samples per DataLoader batch
lr = 0.0002         # learning rate handed to both Adam optimizers
epochs = 5          # full passes over the training set
lambda_gp = 10      # weight of the gradient-penalty term in the critic loss
n_critic = 5        # critic updates performed per generator update
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ========== DataLoader ==========
# Folder that receives one grid of generated samples per epoch.
os.makedirs("wgan_gp_images", exist_ok=True)

# ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) rescales them to
# [-1, 1], the same range as the generator's Tanh output.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# MNIST training split, downloaded into ./data if it is not already there.
mnist_train = datasets.MNIST("./data", train=True, download=True, transform=transform)
dataloader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

# ========== Generator ==========
class Generator(nn.Module):
    """Fully connected generator: noise vector in, image tensor out.

    Takes a batch of ``latent_dim``-sized vectors and returns a tensor of
    shape ``(batch, channels, img_size, img_size)``. The final ``Tanh``
    bounds every output value to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        n_pixels = channels * img_size * img_size
        # Layers are listed in the exact order they are applied.
        layers = [
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, n_pixels),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Generate one image per row of ``z``."""
        flat_pixels = self.model(z)
        # Fold the flat pixel vector back into image layout.
        return flat_pixels.view(z.shape[0], channels, img_size, img_size)

# ========== Critic ==========
class Critic(nn.Module):
    """Fully connected critic producing one unbounded score per image.

    The network ends in a plain ``Linear`` layer (no sigmoid) and contains
    no normalisation layers, so each sample is scored independently of the
    rest of the batch.
    """

    def __init__(self):
        super().__init__()
        n_pixels = channels * img_size * img_size
        layers = [
            nn.Linear(n_pixels, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Score a batch of images; returns a ``(batch, 1)`` tensor."""
        # Collapse everything after the batch dimension into one vector.
        flat_pixels = img.view(img.shape[0], -1)
        return self.model(flat_pixels)

# ========== Gradient Penalty ==========
def gradient_penalty(critic, real_imgs, fake_imgs):
    """Compute the WGAN-GP gradient penalty for a critic.

    Random points are drawn on the straight lines between paired real and
    fake samples, the critic is evaluated there, and the squared deviation
    of the per-sample input-gradient L2 norm from 1 is averaged.

    Args:
        critic: Callable mapping a batch of samples to one score per sample.
        real_imgs: Batch of real samples; the first dimension is the batch.
        fake_imgs: Batch of generated samples with the same shape as
            ``real_imgs``.

    Returns:
        A scalar tensor. It stays attached to the autograd graph
        (``create_graph=True``) so it can be backpropagated into the
        critic's parameters as part of the critic loss.
    """
    n_samples = real_imgs.size(0)

    # One mixing coefficient per sample, broadcast over every non-batch
    # dimension. Device and dtype follow the inputs so the function does not
    # depend on any module-level `device`. Epsilon is a constant of the
    # penalty, so it must not require grad.
    eps_shape = (n_samples,) + (1,) * (real_imgs.dim() - 1)
    epsilon = torch.rand(eps_shape, device=real_imgs.device, dtype=real_imgs.dtype)

    # Detach both endpoints: the penalty differentiates the critic with
    # respect to the interpolated points only, never through the generator.
    interpolates = (
        epsilon * real_imgs.detach() + (1 - epsilon) * fake_imgs.detach()
    ).requires_grad_(True)

    d_interpolates = critic(interpolates)

    # grad_outputs of ones makes this the gradient of the summed scores, i.e.
    # each sample's own score gradient, since samples are scored independently.
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,   # keep the graph so the penalty is differentiable
        retain_graph=True,
    )[0]

    gradients = gradients.reshape(n_samples, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

# ========== Initialize ==========
# Instantiate both networks (generator first, then critic) and move their
# parameters to the selected device.
generator, critic = Generator().to(device), Critic().to(device)

# Both optimizers use the same Adam settings; each one only sees the
# parameters of its own network.
adam_settings = {"lr": lr, "betas": (0.5, 0.9)}
optimizer_G = optim.Adam(generator.parameters(), **adam_settings)
optimizer_C = optim.Adam(critic.parameters(), **adam_settings)

# ========== Training ==========
# Fixed noise for the per-epoch sample grid. Reusing the same 25 latent
# points every epoch makes the saved snapshots directly comparable, and the
# snapshot no longer depends on a variable left over from the last batch.
fixed_z = torch.randn(25, latent_dim, device=device)

for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):

        real_imgs = imgs.to(device)
        n_samples = real_imgs.size(0)

        # ================== Train Critic ==================
        # NOTE(review): all n_critic updates reuse the same real batch and
        # only resample the noise. Algorithm 1 of the WGAN-GP paper draws a
        # fresh real sample per critic step -- confirm this is intentional.
        for _ in range(n_critic):
            # Sample noise directly on the target device (no host->device copy).
            z = torch.randn(n_samples, latent_dim, device=device)
            # Detach so the critic update does not backpropagate into the generator.
            fake_imgs = generator(z).detach()

            real_validity = critic(real_imgs)
            fake_validity = critic(fake_imgs)

            gp = gradient_penalty(critic, real_imgs, fake_imgs)

            # Critic loss: raise scores on real samples, lower them on fakes,
            # plus the weighted gradient penalty.
            loss_C = -torch.mean(real_validity) + torch.mean(fake_validity) + lambda_gp * gp

            # zero_grad also discards the critic gradients that the previous
            # generator step's backward pass left behind.
            optimizer_C.zero_grad()
            loss_C.backward()
            optimizer_C.step()

        # ================== Train Generator ==================
        z = torch.randn(n_samples, latent_dim, device=device)
        gen_imgs = generator(z)
        # The generator is trained to raise the critic's score on its samples.
        loss_G = -torch.mean(critic(gen_imgs))

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # Print log
        if i % 200 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {loss_C.item():.4f}] [G loss: {loss_G.item():.4f}]")

    # Per-epoch snapshot: render the fixed noise in eval mode, so BatchNorm
    # uses its running statistics, and without tracking gradients. Then
    # switch the generator back to train mode for the next epoch.
    generator.eval()
    with torch.no_grad():
        samples = generator(fixed_z)
    generator.train()
    save_image(samples, f"wgan_gp_images/{epoch}.png", nrow=5, normalize=True)


100%|██████████| 9.91M/9.91M [00:00<00:00, 58.7MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.76MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.4MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.29MB/s]


[Epoch 0/5] [Batch 0/938] [D loss: 6.7242] [G loss: 0.0778]
[Epoch 0/5] [Batch 200/938] [D loss: -5.2687] [G loss: -1.9145]
[Epoch 0/5] [Batch 400/938] [D loss: -4.5642] [G loss: -3.5412]
[Epoch 0/5] [Batch 600/938] [D loss: -4.0072] [G loss: -3.5285]
[Epoch 0/5] [Batch 800/938] [D loss: -3.1602] [G loss: -4.9549]
[Epoch 1/5] [Batch 0/938] [D loss: -3.4199] [G loss: -4.1138]
[Epoch 1/5] [Batch 200/938] [D loss: -3.4154] [G loss: -4.9841]
[Epoch 1/5] [Batch 400/938] [D loss: -3.0889] [G loss: -4.3835]
[Epoch 1/5] [Batch 600/938] [D loss: -3.1961] [G loss: -4.9686]
[Epoch 1/5] [Batch 800/938] [D loss: -2.6893] [G loss: -5.9364]
[Epoch 2/5] [Batch 0/938] [D loss: -2.7351] [G loss: -4.3204]
[Epoch 2/5] [Batch 200/938] [D loss: -2.8105] [G loss: -5.3265]
[Epoch 2/5] [Batch 400/938] [D loss: -2.8542] [G loss: -5.7246]
[Epoch 2/5] [Batch 600/938] [D loss: -2.6555] [G loss: -3.9416]
[Epoch 2/5] [Batch 800/938] [D loss: -2.8085] [G loss: -4.2201]
[Epoch 3/5] [Batch 0/938] [D loss: -2.7496] [G l