- Modified from https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/wgan/wgan.py

In [1]:
import math
import os

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# --- image geometry -------------------------------------------------------
img_size = 28          # side length (pixels) the images are resized to
channels = 1           # number of image channels

# --- model / optimisation settings ----------------------------------------
latent_dim = 100       # length of the noise vector fed to the generator
batch_size = 64        # samples per mini-batch
n_epochs = 200         # default epoch count (the training cell reassigns it)

# --- logging --------------------------------------------------------------
sample_interval = 400  # NOTE(review): not referenced by the code visible here

In [3]:
# Make sure the download target exists before the dataset is constructed.
data_root = "../data/mnist"
os.makedirs(data_root, exist_ok=True)

# Per-image preprocessing: resize to img_size, convert to a tensor, then
# normalise with mean 0.5 and std 0.5.
preprocess = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)

# MNIST training split, downloaded on first use.
mnist_train = datasets.MNIST(
    data_root,
    train=True,
    download=True,
    transform=preprocess,
)

# Shuffled mini-batches; drop_last=True discards a trailing short batch so
# every batch holds exactly batch_size samples.
dataloader = DataLoader(
    mnist_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

In [4]:
# Pull a single mini-batch from the loader so its contents can be inspected.
first_batch = next(iter(dataloader))
sample_img, label = first_batch

In [10]:
class GeneratorLinear(nn.Module):
    """Fully connected generator: noise vector in, image tensor out.

    The network is a stack of Linear -> (BatchNorm1d) -> LeakyReLU(0.1)
    stages with widths latent_dim -> 128 -> 256 -> 512 -> 1024, followed by
    a Linear projection to ``img_size * img_size * channels`` values and a
    Tanh. The first stage has no batch normalisation.
    """

    def __init__(self):
        super().__init__()

        widths = [latent_dim, 128, 256, 512, 1024]
        stack = []
        for stage, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            stack.append(nn.Linear(n_in, n_out))
            # Every stage except the very first one is batch-normalised.
            if stage > 0:
                stack.append(nn.BatchNorm1d(n_out))
            stack.append(nn.LeakyReLU(0.1))

        # Output head: one value per pixel, squashed by Tanh.
        stack.append(nn.Linear(widths[-1], img_size * img_size * channels))
        stack.append(nn.Tanh())

        self.model = nn.Sequential(*stack)

    def forward(self, z):
        """Map a batch of latent vectors ``z`` to images.

        Returns a tensor reshaped to
        ``(z.shape[0], channels, img_size, img_size)``.
        """
        flat = self.model(z)
        n_samples = flat.shape[0]
        return flat.view(n_samples, channels, img_size, img_size)

In [11]:
class DiscriminatorLinear(nn.Module):
    """Fully connected WGAN critic producing one unbounded score per image.

    Layers: Linear(pixels, 512) -> LeakyReLU(0.1) -> Linear(512, 256) ->
    LeakyReLU(0.1) -> Linear(256, 1). There is deliberately no sigmoid at
    the end, because a WGAN critic outputs a raw score, not a probability.
    """

    def __init__(self):
        super().__init__()
        n_pixels = img_size * img_size * channels
        critic_layers = [
            nn.Linear(n_pixels, 512),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.1, inplace=True),
            nn.Linear(256, 1),
            # no nn.Sigmoid() here -- WGAN uses the raw critic output
        ]
        self.model = nn.Sequential(*critic_layers)

    def forward(self, img):
        """Flatten each image in ``img`` and return its critic score.

        The result has one row per input sample and a single column.
        """
        n_samples = img.size(0)
        return self.model(img.view(n_samples, -1))

In [12]:
# Networks, optimizers and loss object.

# Build the generator first, then the critic.
generator = GeneratorLinear()
discriminator = DiscriminatorLinear()

# One RMSprop optimizer per network, both with the same learning rate.
optimizer_G = torch.optim.RMSprop(params=generator.parameters(), lr=5e-4)
optimizer_D = torch.optim.RMSprop(params=discriminator.parameters(), lr=5e-4)

# NOTE(review): the WGAN training loop in this file computes its losses
# directly from critic scores and never calls this BCE loss; it is kept only
# so that any other cell referring to `adversarial_loss` keeps working.
adversarial_loss = nn.BCELoss()

In [None]:
# training
#
# WGAN training: the critic (discriminator) is updated on every batch and its
# weights are clipped afterwards; the generator is updated once every
# `n_critic` batches. After each epoch one generated sample is displayed.
n_epochs = 20
clip_value = 0.01  # critic weights are clamped to [-clip_value, clip_value]
n_critic = 5       # critic updates per generator update

for epoch in range(n_epochs):
    for i, (imgs, _) in enumerate(dataloader):
        # Size the noise batch from the real batch so both critic terms are
        # averaged over the same number of samples.
        n_samples = imgs.size(0)
        z = torch.randn((n_samples, latent_dim))

        # ---- train discriminator (critic) ----
        # Detach the fakes: the critic update must not back-propagate into
        # the generator.
        fake_imgs = generator(z).detach()

        optimizer_D.zero_grad()
        # Critic loss: raise the score on real images, lower it on fakes.
        loss_D = -torch.mean(discriminator(imgs)) + torch.mean(discriminator(fake_imgs))
        loss_D.backward()
        optimizer_D.step()

        # clip weights
        for p in discriminator.parameters():
            p.data.clamp_(-clip_value, clip_value)

        if i % n_critic == 0:
            # ---- train generator every n_critic batches ----
            optimizer_G.zero_grad()

            # Re-run the generator WITH the autograd graph attached. Reusing
            # the detached tensor here would give the generator's parameters
            # no gradient at all, so optimizer_G.step() would be a no-op.
            gen_imgs = generator(z)
            loss_G = -torch.mean(discriminator(gen_imgs))
            loss_G.backward()
            optimizer_G.step()

        # loss_G is the value from the most recent generator step; it is
        # always defined because the generator is trained at i == 0.
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, n_epochs, i, len(dataloader), loss_D.item(), loss_G.item())
        )

    print('Generated image at epoch {}'.format(epoch))
    # Show the first image of the last generated batch, without its
    # singleton dimensions.
    sample_image = fake_imgs[0].squeeze()
    plt.imshow(sample_image)
    plt.show()