In [1]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import os

#  Hyperparameters 
latent_dim = 100
img_size = 28
channels = 1
batch_size = 64
lr = 0.00005  # lower lr for WGAN
epochs = 5
clip_value = 0.01  # weight clipping
n_critic = 5       # train critic more
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

os.makedirs("wgan_images", exist_ok=True)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5])
])
dataloader = DataLoader(
    datasets.MNIST("./data", train=True, download=True, transform=transform),
    batch_size=batch_size,
    shuffle=True
)




100%|██████████| 9.91M/9.91M [00:00<00:00, 11.5MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 334kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.18MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.66MB/s]


In [3]:
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, channels * img_size * img_size),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(z.size(0), channels, img_size, img_size)

In [4]:
class Critic(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(channels * img_size * img_size, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1)  # no sigmoid
        )

    def forward(self, img):
        return self.model(img.view(img.size(0), -1))

In [5]:
generator = Generator().to(device)
critic = Critic().to(device)

optimizer_G = optim.RMSprop(generator.parameters(), lr=lr)
optimizer_C = optim.RMSprop(critic.parameters(), lr=lr)
generator
critic

Critic(
  (model): Sequential(
    (0): Linear(in_features=784, out_features=256, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=128, out_features=1, bias=True)
  )
)

In [6]:
epochs=1
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Real images
        real_imgs = imgs.to(device)

        # Train Critic
        for _ in range(n_critic):
            z = torch.randn(imgs.size(0), latent_dim).to(device)
            fake_imgs = generator(z).detach()

            loss_C = -torch.mean(critic(real_imgs)) + torch.mean(critic(fake_imgs))

            optimizer_C.zero_grad()
            loss_C.backward()
            optimizer_C.step()

            # weight clipping
            for p in critic.parameters():
                p.data.clamp_(-clip_value, clip_value)

        #Train Generator
        z = torch.randn(imgs.size(0), latent_dim).to(device)
        gen_imgs = generator(z)
        loss_G = -torch.mean(critic(gen_imgs))

        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # Print log
        if i % 200 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {loss_C.item():.4f}] [G loss: {loss_G.item():.4f}]")

    save_image(gen_imgs[:25], f"wgan_images/{epoch}.png", nrow=5, normalize=True)

[Epoch 0/1] [Batch 0/938] [D loss: -0.0121] [G loss: -0.0086]
[Epoch 0/1] [Batch 200/938] [D loss: -0.2752] [G loss: -3.9592]
[Epoch 0/1] [Batch 400/938] [D loss: -0.0728] [G loss: -3.7113]
[Epoch 0/1] [Batch 600/938] [D loss: -0.0363] [G loss: -3.3569]
[Epoch 0/1] [Batch 800/938] [D loss: -0.0242] [G loss: -3.0211]
