In [None]:
import os
import glob
import chainer
from chainer import Chain
from chainer import Variable
from chainer import training, report
from chainer import optimizers, iterators
from chainer.training import extensions
from chainer import cuda
import chainer.functions as F
import chainer.links as L

import matplotlib.pyplot as plt

from sklearn.datasets import fetch_mldata
import numpy as np
import pandas as pd

import pickle
from PIL import Image

# Make GPU 0 the current CUDA device and alias CuPy (NumPy-compatible GPU arrays) as `xp`.
gpu_id = 0
cuda.get_device_from_id(gpu_id).use()
xp = cuda.cupy

In [None]:
class Generator(Chain):
    """Map latent vectors of length ``z_dim`` to single-channel images in (0, 1).

    ``l1`` projects each latent vector to a 512x3x3 feature map; the four
    deconvolutions then grow the spatial size 3 -> 4 -> 6 -> 10 -> 28
    (out = stride * (in - 1) + ksize - 2 * pad), so every output is (1, 28, 28).
    """

    def __init__(self, z_dim):
        super(Generator, self).__init__()
        with self.init_scope():
            # Latent projection to the initial 512x3x3 map.
            self.l1 = L.Linear(z_dim, 3 * 3 * 512)

            # Upsampling stack; the last layer emits one channel.
            self.dc1 = L.Deconvolution2D(512, 256, 2, stride=2, pad=1)
            self.dc2 = L.Deconvolution2D(256, 128, 2, stride=2, pad=1)
            self.dc3 = L.Deconvolution2D(128, 64, 2, stride=2, pad=1)
            self.dc4 = L.Deconvolution2D(64, 1, 3, stride=3, pad=1)

            # One batch norm in front of each deconvolution's input.
            self.bn1 = L.BatchNormalization(512)
            self.bn2 = L.BatchNormalization(256)
            self.bn3 = L.BatchNormalization(128)
            self.bn4 = L.BatchNormalization(64)
        # Plain attribute (not a link); callers read it to size the noise batch.
        self.z_dim = z_dim

    def __call__(self, z):
        """Generate a batch of images from latent codes ``z`` of shape (N, z_dim)."""
        h = F.reshape(self.l1(z), (z.shape[0], 512, 3, 3))
        h = F.relu(self.bn1(h))
        for deconv, norm in ((self.dc1, self.bn2),
                             (self.dc2, self.bn3),
                             (self.dc3, self.bn4)):
            h = F.relu(norm(deconv(h)))
        # Sigmoid keeps pixel intensities in (0, 1).
        return F.sigmoid(self.dc4(h))

In [None]:
class Discriminator(Chain):
    """Score single-channel 28x28 images with unnormalized real/fake logits.

    The convolutions shrink the spatial size 28 -> 10 -> 6 -> 4 -> 3, which is
    what the 3 * 3 * 512 input size of ``l1`` assumes.

    NOTE(review): the head emits two logits per image; the usual DCGAN choice
    is a single logit. Confirm this is intentional.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        with self.init_scope():
            # Downsampling stack.
            self.c1 = L.Convolution2D(1, 64, 3, stride=3, pad=1)
            self.c2 = L.Convolution2D(64, 128, 2, stride=2, pad=1)
            self.c3 = L.Convolution2D(128, 256, 2, stride=2, pad=1)
            self.c4 = L.Convolution2D(256, 512, 2, stride=2, pad=1)

            # Classification head.
            self.l1 = L.Linear(3 * 3 * 512, 2)

            # Batch norms for the outputs of c2, c3 and c4 (c1 is left unnormalized).
            self.bn1 = L.BatchNormalization(128)
            self.bn2 = L.BatchNormalization(256)
            self.bn3 = L.BatchNormalization(512)

    def __call__(self, x):
        """Return logits of shape (N, 2) for images ``x`` of shape (N, 1, 28, 28)."""
        h = F.relu(self.c1(x))
        for conv, norm in ((self.c2, self.bn1),
                           (self.c3, self.bn2),
                           (self.c4, self.bn3)):
            h = F.relu(norm(conv(h)))
        return self.l1(h)

In [None]:
class GANUpdater(training.StandardUpdater):
    """Updater that runs one discriminator step and one generator step per batch.

    Construct with ``models=(gen, dis)`` and ``optimizer={'gen': ..., 'dis': ...}``;
    every other argument is forwarded to ``StandardUpdater``.
    """

    def __init__(self, *args, **kwargs):
        # `models` is not a StandardUpdater argument, so remove it before delegating.
        self.gen, self.dis = kwargs.pop('models')
        super(GANUpdater, self).__init__(*args, **kwargs)

    def dis_loss(self, y_fake, y_real):
        """Binary cross-entropy loss for the discriminator.

        softplus(-y) == -log(sigmoid(y)) pushes real logits up, and
        softplus(y) == -log(1 - sigmoid(y)) pushes fake logits down. Both terms
        are summed over all output units and averaged over the batch.
        """
        batch_size = len(y_fake)
        real_loss = F.sum(F.softplus(-y_real)) / batch_size
        fake_loss = F.sum(F.softplus(y_fake)) / batch_size
        loss = real_loss + fake_loss
        report({'loss': loss}, self.dis)
        return loss

    def gen_loss(self, y_fake):
        """Non-saturating generator loss: -log(sigmoid(D(G(z)))), averaged over the batch."""
        batch_size = len(y_fake)
        loss = F.sum(F.softplus(-y_fake)) / batch_size
        report({'loss': loss}, self.gen)
        return loss

    def update_core(self):
        """Draw one batch and update the discriminator, then the generator."""
        gen_optimizer = self.get_optimizer('gen')
        dis_optimizer = self.get_optimizer('dis')
        # Always use the models this updater was built with, never notebook
        # globals that merely share the same names.
        gen, dis = self.gen, self.dis

        batch = next(self.get_iterator('main'))
        x_real = Variable(self.converter(batch, self.device))
        batch_size = len(x_real)

        y_real = dis(x_real)

        # Sample latent codes with the generator's own array module (NumPy or CuPy).
        xp = gen.xp
        z = xp.random.uniform(-1, 1, (batch_size, gen.z_dim)).astype(xp.float32)
        x_fake = gen(z)
        y_fake = dis(x_fake)

        # Both steps reuse the same forward pass: discriminator first, then generator.
        dis_optimizer.update(self.dis_loss, y_fake, y_real)
        gen_optimizer.update(self.gen_loss, y_fake)

In [None]:
def image(gen, rows, cols, dst):
    """Build a trainer extension that saves a ``rows`` x ``cols`` grid of generated samples.

    Args:
        gen: Generator with a ``z_dim`` attribute; latent codes are drawn from U(-1, 1).
        rows: Number of image rows in the preview grid.
        cols: Number of image columns in the preview grid.
        dst: Output directory. PNGs are written to
            ``<dst>/preview/image<iteration zero-padded to 8 digits>.png``.

    Returns:
        A ``training.make_extension``-decorated callable for ``trainer.extend``.
    """
    @training.make_extension()
    def make_image(trainer):
        n_images = rows * cols
        # Sample on whatever device the generator lives on (NumPy or CuPy).
        xp = gen.xp
        z = xp.random.uniform(-1, 1, (n_images, gen.z_dim)).astype(xp.float32)
        # Inference only: BatchNorm uses running statistics and no backprop graph is built.
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            x = gen(z)
        x = cuda.to_cpu(x.data)

        # Map generator output (expected in [0, 1]) to 8-bit grayscale.
        x = np.asarray(np.clip(x * 255, 0.0, 255.0), dtype=np.uint8)
        _, _, H, W = x.shape
        # Tile (rows*cols, 1, H, W) into a single (rows*H, cols*W) image.
        x = x.reshape((rows, cols, 1, H, W))
        x = x.transpose(0, 3, 1, 4, 2)
        x = x.reshape((rows * H, cols * W))

        preview_dir = os.path.join(dst, 'preview')
        preview_path = os.path.join(
            preview_dir, 'image{:0>8}.png'.format(trainer.updater.iteration))
        if not os.path.exists(preview_dir):
            os.makedirs(preview_dir)
        Image.fromarray(x).save(preview_path)
    return make_image

In [None]:
# Build the generator (100-dimensional latent input) and the discriminator,
# move both to the current CUDA device, and give each its own default Adam optimizer.
gen = Generator(100)
dis = Discriminator()
for model in (gen, dis):
    model.to_gpu()

gen_opt, dis_opt = optimizers.Adam(), optimizers.Adam()
gen_opt.setup(gen)
dis_opt.setup(dis)

# Load MNIST as float32 images scaled by 1/255 and shaped (N, 1, 28, 28), then wrap
# them in a shuffling iterator with batch size 100.
# `fetch_mldata` relied on mldata.org, which is permanently offline, and it was removed
# in scikit-learn 0.22; `fetch_openml('mnist_784')` is its documented replacement.
from sklearn.datasets import fetch_openml

data = fetch_openml('mnist_784', version=1)
# np.array accepts either return type (ndarray on older scikit-learn, DataFrame on newer).
X = np.array(data['data'], dtype=np.float32)
n_train = X.shape[0]
X /= 255.
X = X.reshape(n_train, 1, 28, 28)
train = iterators.SerialIterator(X, 100)

# Wire the GAN updater into a Trainer: train for 1000 epochs, log losses every
# 1000 iterations, and write a 10x10 grid of samples every 100 epochs.
out_dir = 'result'

updater = GANUpdater(
    models=(gen, dis),
    iterator=train,
    optimizer={'gen': gen_opt, 'dis': dis_opt},
    device=0,
)

trainer = training.Trainer(updater, (1000, 'epoch'), out=out_dir)

preview_interval = (100, 'epoch')
log_interval = (1000, 'iteration')
trainer.extend(extensions.LogReport(trigger=log_interval))
trainer.extend(
    extensions.PrintReport(['epoch', 'iteration', 'gen/loss', 'dis/loss']),
    trigger=log_interval,
)

trainer.extend(image(gen, 10, 10, out_dir), trigger=preview_interval)

trainer.run()