# Wasserstein GAN in Pytorch using LSUN Dataset

In [None]:
%matplotlib inline
import importlib
import utils2; importlib.reload(utils2)
from utils2 import *

In [None]:
import torch_utils; importlib.reload(torch_utils)
from torch_utils import *

## Download and Process Dataset

Download the LSUN scene classification dataset bedroom category, unzip it, and convert it to JPG files (the scripts folder is here in the `dl2_2017` folder):

In [None]:
%mkdir ~/lsun_data

In [None]:
# Download the LSUN bedroom training set (LMDB zip archive) into ~/lsun_data.
!curl 'http://lsun.cs.princeton.edu/htbin/download.cgi?tag=latest&category=bedroom&set=train' -o ~/lsun_data/bedroom_train_lmdb.zip

In [None]:
# Check that the archive arrived and inspect its size.
!ls -lh ~/lsun_data

In [None]:
# Extract the archive; no -d option is given, so unzip writes into the current working directory.
!unzip ~/lsun_data/bedroom_train_lmdb.zip

In [None]:
# List the current directory to confirm the extraction.
!ls -lh

The good news is that in the last month the GAN training problem has been solved! [This paper](https://arxiv.org/abs/1701.07875) shows that a minor change to the loss function, together with constraining the weights, allows a GAN to learn reliably while following a consistent loss schedule.

First, we set up the batch size, the image size, and the size of the noise vector:

In [None]:
bs = 64   # batch size
sz = 64   # image size
nz = 100  # size of the latent z vector

In [None]:
# Added by the notebook author (not part of the original lesson code).
# Directory where sample images and model checkpoints are written.
experiment_path = 'wgan_samples'

In [None]:
# Create the output directory in-process instead of shelling out to `mkdir`:
# exist_ok=True keeps the cell idempotent on re-run, and a real failure
# (e.g. a permission error) raises instead of being silently ignored.
os.makedirs(experiment_path, exist_ok=True)

In [None]:
# Draw a random seed, print it so this run can be repeated later by
# hard-coding the printed value, then seed Python's and PyTorch's RNGs with it.
manual_seed = random.randint(1, 10000)
print(manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)

In [None]:
# Enable cuDNN's benchmark mode.
# NOTE(review): `cudnn` is presumably torch.backends.cudnn, brought in by a star import — confirm.
cudnn.benchmark = True

PyTorch has the handy [torchvision](https://github.com/pytorch/vision) library, which makes handling images fast and easy. Note that the cell below uses it to load CIFAR-10 rather than the LSUN bedrooms downloaded above.

In [None]:
# Load CIFAR-10 from PATH (download=True fetches it when missing). Each image is
# resized to `sz`, converted to a tensor, and normalized per channel with
# mean 0.5 and std 0.5.
# NOTE(review): newer torchvision releases deprecate `transforms.Scale` in favor
# of `transforms.Resize` — confirm against the installed version.
PATH = 'data/cifar10'
data = datasets.CIFAR10(root=PATH, download=True,
    transform=transforms.Compose([
        transforms.Scale(sz),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
)

Even parallel processing is handled automatically by torchvision.

In [None]:
# Shuffled mini-batches of `bs` images, loaded by 8 worker processes.
dataloader = torch.utils.data.DataLoader(
    data,
    batch_size=bs,
    shuffle=True,
    num_workers=8,
)
n = len(dataloader)  # number of batches per epoch
n

Our activation function will be `tanh`, so we need to do some processing to view the generated images.

In [None]:
def show(img, fs=(6,6)):
    """Plot a channel-first image tensor whose values lie in [-1, 1].

    img: 3-D CPU tensor with the channel axis first.
    fs:  matplotlib figure size in inches.
    """
    plt.figure(figsize=fs)
    # Map [-1, 1] back to [0, 1] and clip anything that falls outside.
    unit_range = (img / 2 + 0.5).clamp(0, 1)
    # imshow expects the channel axis last, so move axis 0 to the end.
    channels_last = np.transpose(unit_range.numpy(), (1, 2, 0))
    plt.imshow(channels_last, interpolation='nearest')

## Create model

The CNN definitions are a little big for a notebook, so we import them.

In [None]:
import dcgan; importlib.reload(dcgan)
from dcgan import DCGAN_D, DCGAN_G

PyTorch's `module.apply(fn)` calls `fn` on every submodule, which we use to apply a custom weight initializer.

In [None]:
def weights_init(m):
    """Initializer meant to be passed to `module.apply()`.

    Conv2d / ConvTranspose2d weights are drawn from N(0, 0.02).
    BatchNorm2d weights are drawn from N(1, 0.02) and its biases set to 0.
    Any other module type is left untouched.
    """
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        m.weight.data.normal_(0.0, 0.02)
        return
    if not isinstance(m, nn.BatchNorm2d):
        return
    scale = m.weight.data
    scale.normal_(1.0, 0.02)
    shift = m.bias.data
    shift.fill_(0)

In [None]:
# Build the generator on the GPU and apply the custom weight initializer.
# NOTE(review): DCGAN_G lives in dcgan.py (not visible here); the argument order
# is presumably (image size, nz, nc, ngf, ngpu, n_extra_layers) — confirm.
# nc is input image channels = 3
# ngf is number of generator filter = 64
# ngpu is number of GPUs to use = 1
# n_extra_layers is number of extra layers on gen and disc = 1
netG = DCGAN_G(sz, nz, 3, 64, 1, 1).cuda()
netG.apply(weights_init)

In [None]:
# Build the discriminator/critic on the GPU and apply the custom weight initializer.
# NOTE(review): DCGAN_D lives in dcgan.py (not visible here); the arguments are
# presumably (image size, nc, ndf, ngpu, n_extra_layers) — confirm.
netD = DCGAN_D(sz, 3, 64, 1, 1).cuda()
netD.apply(weights_init)

Just some shortcuts to create tensors and variables.

### Continue Training (my own codes)

In [None]:
# Checkpoint files written by the training loop for epoch 1.
netG_checkpoint = '{0}/netG_epoch_1.pth'.format(experiment_path)
netD_checkpoint = '{0}/netD_epoch_1.pth'.format(experiment_path)

# To resume training, assign the checkpoint paths above to these two names;
# an empty string means "start from the freshly initialized weights".
netG_model = ''  # netG_checkpoint
netD_model = ''  # netD_checkpoint

if netG_model:
    generator_state = torch.load(netG_model)
    netG.load_state_dict(generator_state)
    print('continue training generator/actor')

if netD_model:
    critic_state = torch.load(netD_model)
    netD.load_state_dict(critic_state)
    print('continue training discriminator/critic')

In [None]:
from torch import FloatTensor as FT

In [None]:
def Var(*params):
    """Wrap a new FloatTensor built from `params`, moved to the GPU, in a Variable."""
    tensor = FT(*params)
    return Variable(tensor.cuda())

In [None]:
def create_noise(b):
    """Return a GPU Variable of shape (b, nz, 1, 1) filled with N(0, 1) samples."""
    z = FT(b, nz, 1, 1).cuda()
    z.normal_(0, 1)  # fill in place with standard-normal noise
    return Variable(z)

In [None]:
# Input placeholder for one batch of images: (batch, channels, height, width).
# The last dimension is the image size `sz`, not the latent size `nz`; the
# training loop resizes this buffer to each real batch's shape before use.
input = Var(bs, 3, sz, sz)

# Fixed noise used just for visualizing images when done
fixed_noise = create_noise(bs)

# The numbers 1 and -1, used as the initial gradients passed to backward()
one = torch.FloatTensor([1]).cuda()
mone = one * -1

An optimizer needs to be told what variables to optimize. A module automatically keeps track of its variables.

In [None]:
# One RMSprop optimizer per network, each with learning rate 1e-4.
optimizerD = optim.RMSprop(netD.parameters(), lr = 1e-4)
optimizerG = optim.RMSprop(netG.parameters(), lr = 1e-4)

One forward step and one backward step for D

In [None]:
def step_D(v, init_grad):
    """Run `netD` forward on `v`, then backpropagate from its output.

    `init_grad` is handed to `backward()` as the gradient of the output.
    Returns the output of `netD`.
    """
    critic_out = netD(v)
    critic_out.backward(init_grad)
    return critic_out

In [None]:
def make_trainable(net, val):
    """Set `requires_grad` to `val` on every parameter of `net`.

    The training loop uses this to switch the critic's parameters off for the
    generator update and back on for the critic update.
    """
    for param in net.parameters():
        param.requires_grad = val

In [None]:
def train(niter, first=True, d_iters=5, d_iters_boost=100, clamp=0.01, save_every=8):
    """Train the critic (`netD`) and generator (`netG`) for `niter` epochs.

    Each round runs several critic updates followed by one generator update.
    Before every critic update the critic's parameters are clamped to
    [-clamp, clamp]. After each epoch both networks are checkpointed under
    `experiment_path`.

    Relies on notebook-level globals: `dataloader`, `n`, `netD`, `netG`,
    `optimizerD`, `optimizerG`, `input`, `one`, `mone`, `experiment_path`.

    Args:
        niter: number of epochs (full passes over `dataloader`).
        first: when True, the first 25 generator iterations use
            `d_iters_boost` critic updates instead of `d_iters`.
        d_iters: critic updates per generator update in the normal case.
        d_iters_boost: critic updates per generator update during the warm-up
            and on every 500th generator iteration.
        clamp: bound of the cube the critic's parameters are clamped to.
        save_every: save real/fake sample images every this many generator
            iterations; a falsy value disables sample saving.
    """
    gen_iterations = 0

    for epoch in range(niter):
        data_iter = iter(dataloader)
        i = 0  # number of batches consumed in this epoch

        while i < n:
            ###########################
            # (1) Update D network
            ###########################
            make_trainable(netD, True)

            # Train the critic harder during the warm-up and periodically afterwards.
            boost = (first and gen_iterations < 25) or (gen_iterations % 500 == 0)
            critic_steps = d_iters_boost if boost else d_iters

            j = 0
            while j < critic_steps and i < n:
                j += 1
                i += 1

                # clamp parameters to a cube
                for p in netD.parameters():
                    p.data.clamp_(-clamp, clamp)

                batch = next(data_iter)

                # train with real (a single host-to-GPU copy, reused for sample saving)
                real_batch = batch[0].cuda()
                real = Variable(real_batch)
                netD.zero_grad()
                errD_real = step_D(real, one)

                # train with fake; copying into the `input` placeholder detaches
                # the fake batch from the generator's graph
                fake = netG(create_noise(real.size()[0]))
                input.data.resize_(real.size()).copy_(fake.data)
                errD_fake = step_D(input, mone)
                errD = errD_real - errD_fake
                optimizerD.step()

            ###########################
            # (2) Update G network
            ###########################
            # Freeze the critic so only generator gradients are computed.
            make_trainable(netD, False)
            netG.zero_grad()
            errG = step_D(netG(create_noise(bs)), one)
            optimizerG.step()
            gen_iterations += 1

            print('[%d/%d][%d/%d][%d] Loss_D: %f Loss_G: %f Loss_D_real: %f Loss_D_fake %f'
                % (epoch, niter, i, n, gen_iterations,
                errD.data[0], errG.data[0], errD_real.data[0], errD_fake.data[0]))

            if save_every and gen_iterations % save_every == 0:
                print('saving real and fake images...')
                # Undo the (0.5, 0.5) normalization before writing images.
                real_imgs = real_batch.mul(0.5).add(0.5)
                vutils.save_image(real_imgs, '{0}/real_samples.png'.format(experiment_path))
                fake_imgs = netG(create_noise(bs)).data.mul(0.5).add(0.5)
                vutils.save_image(fake_imgs, '{0}/fake_samples_{1}.png'.format(experiment_path, gen_iterations))

        # do checkpointing
        torch.save(netG.state_dict(), '{0}/netG_epoch_{1}.pth'.format(experiment_path, epoch))
        torch.save(netD.state_dict(), '{0}/netD_epoch_{1}.pth'.format(experiment_path, epoch))

In [None]:
# Train for 200 epochs with the initial critic warm-up enabled (first=True); %time reports how long it took.
%time train(200, True)

## View

In [None]:
# Generate a batch of images from the fixed noise and move the result to the CPU for plotting.
fake = netG(fixed_noise).data.cpu()

Generated images by Generator.

In [None]:
# Tile the generated batch into a single grid image and display it.
show(vutils.make_grid(fake))

Real images from dataset.

In [None]:
# Take one batch of real images and display it as a grid. The built-in next()
# is used because the `.next()` method is the Python 2 iterator protocol and is
# not guaranteed to exist on the loader's iterator.
show(vutils.make_grid(next(iter(dataloader))[0]))