# Wasserstein GAN in PyTorch using the LSUN Dataset

In [None]:
%matplotlib inline
import importlib

In [None]:
import torch_utils; importlib.reload(torch_utils)
from torch_utils import *

In [None]:
import os, random

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
bs = 1000  # number of samples per batch (first dimension of input/noise tensors)
sz = 64    # image size passed to the models (presumably height = width)
nz = 100   # size of the latent z vector

In [None]:
# Fix the random seeds so runs are reproducible. To draw a fresh seed instead,
# set manual_seed = random.randint(1, 10000).
manual_seed = 5164
print(manual_seed)
random.seed(manual_seed)
# NumPy is imported by this notebook too; seed it so any NumPy-based
# randomness is reproducible as well.
np.random.seed(manual_seed)
# Kept last so the cell still displays torch.manual_seed's return value.
torch.manual_seed(manual_seed)

In [None]:
# Turn on cuDNN benchmark mode (autotuning of convolution algorithms).
# NOTE(review): `cudnn` comes from the torch_utils star import; presumably it
# is torch.backends.cudnn — confirm.
cudnn.benchmark = True

In [None]:
def show(img, fs=(6,6)):
    """Display an image tensor in a new matplotlib figure.

    Args:
        img: a rank-3 tensor laid out as (channels, height, width). Values are
            assumed to lie roughly in [-1, 1]; they are mapped to [0, 1] with
            ``img / 2 + 0.5`` and clamped before display. The tensor may live
            on the GPU; it is copied to host memory before conversion.
        fs: figure size in inches, passed to ``plt.figure(figsize=...)``.
    """
    plt.figure(figsize=fs)
    # Undo the [-1, 1] scaling, clamp stray values, and move to host memory so
    # that .numpy() also works for CUDA tensors (.cpu() is a no-op on CPU).
    pixels = (img / 2 + 0.5).clamp(0, 1).cpu().numpy()
    # imshow expects channels-last (H, W, C); move axis 0 to the end.
    plt.imshow(np.transpose(pixels, (1, 2, 0)), interpolation='nearest')

## Create model

The CNN definitions are a little big for a notebook, so we import them from the `dcgan` module.

In [None]:
import dcgan; importlib.reload(dcgan)
from dcgan import DCGAN_D, DCGAN_G

PyTorch's `module.apply(fn)` calls `fn` on every submodule, which we use to apply a custom weight initializer.

In [None]:
def weights_init(m):
    """Initialize a single module's parameters; use via ``net.apply(weights_init)``.

    - Conv2d / ConvTranspose2d: weights drawn from N(0, 0.02).
    - BatchNorm2d: weights drawn from N(1, 0.02), biases set to 0.
    All other module types are left untouched. BatchNorm layers built with
    ``affine=False`` have no weight/bias tensors (they are None), so they are
    skipped instead of raising AttributeError.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        m.weight.data.normal_(0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [None]:
# Generator, built on the GPU; apply() then runs weights_init on every submodule.
netG = DCGAN_G(
    sz,  # presumably the image size
    nz,  # size of the latent z vector
    3,   # nc: input image channels
    64,  # ngf: number of generator filters
    1,   # ngpu: number of GPUs to use
    1,   # n_extra_layers: number of extra layers on gen and disc
).cuda()
netG.apply(weights_init)

In [None]:
# Discriminator/critic, built on the GPU; apply() then runs weights_init on
# every submodule.
netD = DCGAN_D(
    sz,  # presumably the image size, as for the generator
    3,   # nc: input image channels
    64,  # presumably the number of discriminator filters
    1,   # ngpu: number of GPUs to use
    1,   # n_extra_layers: number of extra layers
).cuda()
netD.apply(weights_init)

Shortcuts for creating tensors and variables are defined right after the checkpoint-loading cell below.

### Continue Training from Saved Checkpoints (custom code)

In [None]:
netG_checkpoint = 'netG_epoch_1.pth'
netD_checkpoint = 'netD_epoch_1.pth'

# Paths to resume training from; set to '' (or None) to train from scratch.
netG_model = netG_checkpoint
netD_model = netD_checkpoint


def _remap_legacy_keys(state_dict):
    """Return a copy of *state_dict* with the middle key components collapsed.

    Each key is split on '.'; the first and last components are kept and all
    components in between are concatenated with no separator, e.g.
    'main.a.b.weight' -> 'main.ab.weight'. Keys with fewer than three
    components are kept unchanged (previously they were mangled to 'a..b').
    """
    # NOTE(review): presumably this adapts checkpoints saved by an older
    # dcgan.py whose module names contained dots — confirm against dcgan.py.
    remapped = {}
    for key, value in state_dict.items():
        parts = key.split('.')
        if len(parts) >= 3:
            key = '.'.join((parts[0], ''.join(parts[1:-1]), parts[-1]))
        remapped[key] = value
    return remapped


def _load_checkpoint(net, path):
    """Load the state dict stored at *path* into *net*, remapping legacy keys."""
    # torch.load unpickles the file, which can execute arbitrary code:
    # only load checkpoints from a trusted source.
    net.load_state_dict(_remap_legacy_keys(torch.load(path)))


if netG_model:
    _load_checkpoint(netG, netG_model)
    print('continue training generator/actor')

if netD_model:
    _load_checkpoint(netD, netD_model)
    print('continue training discriminator/critic')

In [None]:
from torch import FloatTensor as FT

In [None]:
def Var(*params):
    """Return a Variable wrapping a new CUDA float tensor of size *params*.

    The tensor is created with ``FT(*params)`` without being filled, then moved
    to the GPU; callers are expected to write into it before reading it.
    """
    host_tensor = FT(*params)
    return Variable(host_tensor.cuda())

In [None]:
def create_noise(b):
    """Sample *b* latent vectors for the generator.

    Returns a Variable holding a CUDA tensor of shape (b, nz, 1, 1) whose
    entries are drawn from a normal distribution with mean 0 and std 1.
    """
    noise = FT(b, nz, 1, 1).cuda()
    noise.normal_(0, 1)  # in-place fill on the GPU tensor
    return Variable(noise)

In [None]:
# Input placeholder: a batch of bs 3-channel images, each sz x sz.
# (Kept under the name `input` in case other cells use it, although it
# shadows the builtin.)
input = Var(bs, 3, sz, sz)

# Fixed noise used just for visualizing images when done
fixed_noise = create_noise(bs)

# The numbers 1 and -1 as one-element CUDA tensors.
# NOTE(review): presumably used as gradient scales for backward() in the
# training loop — confirm.
one = torch.FloatTensor([1]).cuda()
mone = one * -1

## Generate Samples from the Fixed Noise and Save Them

In [None]:
# Run the generator on the fixed noise and copy the underlying output tensor
# (.data) to host memory.
# NOTE(review): no volatile/no_grad is used, so autograd may record a graph for
# all bs samples on the GPU here — confirm memory is acceptable.
fake = netG(fixed_noise).data.cpu()

In [None]:
# Convert the CPU tensor of generated samples to a NumPy array for saving.
npfake = fake.numpy()

In [None]:
import numpy as np

In [None]:
# Flatten each generated sample into one row (one row per sample, however
# many there are) and write the array to wgan_lsun_gen.bin as raw binary in
# C order, with no header. The element dtype is whatever npfake holds.
npfake.reshape(npfake.shape[0], -1).tofile('wgan_lsun_gen.bin')