In [0]:
import argparse

import numpy as np
import chainer
from chainer import cuda, Function, gradient_check, report, training, utils, Variable
from chainer import datasets, iterators, optimizers, serializers, reporter
from chainer import Link, Chain, ChainList
from chainer.dataset import convert
from chainer.dataset import iterator as iterator_module
import chainer.functions as F
import chainer.links as L
from chainer.training import extensions

from scipy.misc import imsave




In [0]:
class Generator(Chain):
    """Deconvolutional generator that maps latent codes to 28x28 images.

    Args:
        z_dim (int): Number of channels of the latent input ``z``.

    For a ``(batch, z_dim, 1, 1)`` input, the five stride-2 deconvolutions
    grow the spatial size 1 -> 3 -> 5 -> 9 -> 15 -> 28, and a final sigmoid
    maps the single output channel into (0, 1).
    """

    def __init__(self, z_dim):
        super(Generator, self).__init__()
        with self.init_scope():
            # Deconvolution2D(in_ch, out_ch, ksize, stride, pad)
            self.l1 = L.Deconvolution2D(z_dim, 128, 3, 2, 0)
            self.bn1 = L.BatchNormalization(128)
            self.l2 = L.Deconvolution2D(128, 128, 3, 2, 1)
            self.bn2 = L.BatchNormalization(128)
            self.l3 = L.Deconvolution2D(128, 128, 3, 2, 1)
            self.bn3 = L.BatchNormalization(128)
            self.l4 = L.Deconvolution2D(128, 128, 3, 2, 2)
            self.bn4 = L.BatchNormalization(128)
            # With stride 2 / pad 2 a 15x15 input can map to 27 or 28 pixels;
            # outsize pins the result to the 28x28 MNIST size.
            self.l5 = L.Deconvolution2D(128, 1, 3, 2, 2, outsize=(28, 28))
        # Plain boolean flag, matching Discriminator. Layer behavior such as
        # BatchNormalization follows ``chainer.config.train``, so this is
        # informational only. A ``chainer.using_config(...)`` object stored
        # here would never be entered and therefore would not switch modes.
        self.train = True

    def __call__(self, z):
        """Generate images from latent codes ``z``.

        Each hidden stage applies deconvolution -> ReLU -> BatchNormalization
        (normalization comes after the activation in this design).
        """
        h = self.bn1(F.relu(self.l1(z)))
        h = self.bn2(F.relu(self.l2(h)))
        h = self.bn3(F.relu(self.l3(h)))
        h = self.bn4(F.relu(self.l4(h)))
        x = F.sigmoid(self.l5(h))
        return x




In [0]:
class Discriminator(Chain):
    """Convolutional discriminator emitting one unnormalized logit per image.

    For the 1x28x28 images used in this notebook, the five stride-2
    convolutions shrink the spatial size 28 -> 14 -> 8 -> 4 -> 2 -> 1, so the
    output has shape ``(batch, 1, 1, 1)``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        with self.init_scope():
            # Convolution2D(in_ch, out_ch, ksize, stride, pad); in_ch=None
            # defers weight allocation until the first forward pass.
            self.l1 = L.Convolution2D(None, 32, 3, 2, 1)
            self.bn1 = L.BatchNormalization(32)
            self.l2 = L.Convolution2D(None, 32, 3, 2, 2)
            self.bn2 = L.BatchNormalization(32)
            self.l3 = L.Convolution2D(None, 32, 3, 2, 1)
            self.bn3 = L.BatchNormalization(32)
            self.l4 = L.Convolution2D(None, 32, 3, 2, 1)
            self.bn4 = L.BatchNormalization(32)
            self.l5 = L.Convolution2D(None, 1, 3, 2, 1)
        self.train = True

    def __call__(self, x):
        """Score a batch of images; larger logits mean "generated" here."""
        h = x
        # Four hidden stages: conv -> leaky ReLU -> BatchNormalization.
        for stage in range(1, 5):
            conv = getattr(self, 'l%d' % stage)
            norm = getattr(self, 'bn%d' % stage)
            h = norm(F.leaky_relu(conv(h)))
        # Final conv produces the raw logit (no activation).
        return self.l5(h)




In [0]:
class GAN_Updater(training.StandardUpdater):
    """Updater that trains a generator and a discriminator from one loss.

    Each iteration concatenates a generated batch with a real batch, scores
    both with the discriminator, and back-propagates a single
    sigmoid-cross-entropy loss (generated -> label 1, real -> label 0) into
    both networks. Every optimizer in ``optimizers`` is then stepped with
    those same gradients, so the generator's optimizer must use a negative
    step size (as configured in ``main``) to work against the discriminator.

    NOTE(review): ``StandardUpdater.__init__`` is not called; the attributes
    it would normally set are assigned by hand in ``__init__`` -- confirm
    this matches the installed Chainer version.
    """

    def __init__(self, iterator, generator, discriminator, optimizers,
                 converter=convert.concat_examples, device=None, z_dim=2,):
        """
        Args:
            iterator: A dataset iterator, or a dict of them keyed by name;
                the ``'main'`` entry supplies real images.
            generator: Generator link, called as ``generator(z)``.
            discriminator: Discriminator link producing one logit per sample.
            optimizers (dict): Optimizers to step after each backward pass.
            converter: Turns a list of examples into a batch array.
            device: Device argument forwarded to ``converter``.
            z_dim (int): Channel count of the sampled latent codes.
        """
        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main': iterator}
        self._iterators = iterator
        self.gen = generator
        self.dis = discriminator
        self._optimizers = optimizers
        self.converter = converter
        self.device = device

        self.iteration = 0

        self.z_dim = z_dim

    def update_core(self):
        """Run one joint generator/discriminator update and report losses."""
        batch = self._iterators['main'].next()
        x_data = self.converter(batch, self.device)
        batchsize = x_data.shape[0]

        # Sample latent codes with the generator's array module (numpy on
        # CPU, cupy on GPU) instead of hard-wiring cupy, so the updater also
        # works when the models live on the CPU. numpy's normal() takes no
        # dtype argument, hence the explicit float32 cast.
        xp = self.gen.xp
        z = Variable(xp.random.normal(
            size=(batchsize, self.z_dim, 1, 1)).astype(np.float32))
        # The latest generated batch is published as a module-level global,
        # presumably for inspection from other cells.
        global x_gen
        x_gen = self.gen(z)

        # Run generated and real samples through the discriminator as ONE
        # batch: passing them separately would let BatchNormalization
        # compute different statistics for each half.
        x = F.concat((x_gen, x_data), 0)
        y = self.dis(x)
        y_gen, y_data = F.split_axis(y, 2, 0)

        # sigmoid_cross_entropy(x, 0) == softplus(x)
        # sigmoid_cross_entropy(x, 1) == softplus(-x)
        loss_gen = F.sum(F.softplus(-y_gen))
        loss_data = F.sum(F.softplus(y_data))
        loss = (loss_gen + loss_data) / batchsize

        for optimizer in self._optimizers.values():
            optimizer.target.cleargrads()

        # A single backward pass fills the gradients of both networks.
        loss.backward()

        for optimizer in self._optimizers.values():
            optimizer.update()

        reporter.report(
            {'loss': loss,
             'loss_gen': loss_gen / batchsize,
             'loss_data': loss_data / batchsize})




In [0]:
def save_x(x_gen, filename, ncols=15, size=28):
    """Tile a batch of images into a single grid image and save it.

    Args:
        x_gen: Chainer Variable whose ``.data`` holds the images (on CPU or
            GPU); presumably shaped ``(n, 1, size, size)`` as produced by
            ``Generator`` -- TODO confirm for other callers.
        filename (str): Destination path handed to ``imsave``.
        ncols (int): Number of tiles per grid row (default 15).
        size (int): Side length of each square image in pixels (default 28).

    Only the first ``n // ncols * ncols`` images are kept so every grid row
    is full; images fill the grid column by column.

    Raises:
        ValueError: If fewer than ``ncols`` images are supplied, since no
            complete grid row can be formed.

    NOTE(review): ``imsave`` comes from ``scipy.misc``, which was removed in
    SciPy 1.2; running this requires an older SciPy or a replacement writer.
    """
    x_gen_img = cuda.to_cpu(x_gen.data)
    total = x_gen_img.shape[0]
    n = total // ncols * ncols
    if n == 0:
        raise ValueError(
            'save_x needs at least %d images, got %d' % (ncols, total))
    # (n, ...) -> (ncols, rows, size, size) -> (rows, size, ncols, size):
    # image k lands in grid column k // rows and grid row k % rows.
    grid = (x_gen_img[:n]
            .reshape(ncols, -1, size, size)
            .transpose(1, 2, 0, 3)
            .reshape(-1, ncols * size))
    imsave(filename, grid)
    
def save_example(trainer):
    """Trainer extension: record generator output for the fixed latent batch.

    Appends ``trainer.updater.gen(fixed_z)`` to the module-level ``snapshots``
    list (both globals are created in ``main``). The generator runs in test
    mode and without building a computational graph, so this logging step
    neither updates BatchNormalization running statistics nor keeps each
    epoch's graph alive inside ``snapshots``.
    """
    with chainer.using_config('train', False), chainer.no_backprop_mode():
        snapshots.append(trainer.updater.gen(fixed_z))
    
def save_images(directory='timeline/pictures'):
    """Write every recorded snapshot to ``<directory>/epoch<N>.png``.

    ``N`` counts from 1 in the order entries were appended to the
    module-level ``snapshots`` list. The output directory is created first
    if it does not already exist.

    Args:
        directory (str): Folder receiving the images
            (default ``'timeline/pictures'``).
    """
    import os

    if directory and not os.path.isdir(directory):
        os.makedirs(directory)
    for epoch, snapshot in enumerate(snapshots, 1):
        save_x(snapshot, os.path.join(directory, 'epoch' + str(epoch) + '.png'))




In [0]:
def main(z_dim=2, gpu=0, nb_epoch=30, batchsize=64, train_batchsize=200,
         resume=None):
    """Train the GAN on MNIST, then write one preview image per epoch.

    Args:
        z_dim (int): Dimensionality of the latent code fed to the generator.
        gpu (int): ID of the CUDA device to train on (a GPU is required).
        nb_epoch (int): Number of training epochs; the trainer snapshot is
            taken at the final epoch.
        batchsize (int): Number of fixed latent vectors sampled once and
            rendered after every epoch (despite the name, this is not the
            training minibatch size).
        train_batchsize (int): Minibatch size of the MNIST training iterator.
        resume (str or None): Path of a trainer snapshot (``.npz``) to resume
            from; nothing is loaded when falsy.

    Side effects: (re)binds the module-level globals ``snapshots`` and
    ``fixed_z`` read by ``save_example``/``save_images``; the trainer writes
    its log, loss plot and snapshot under ``result/``.
    """
    global snapshots
    snapshots = []

    # Make the requested GPU current so the cupy allocation below and the
    # model transfers target it rather than device 0.
    cuda.get_device_from_id(gpu).use()

    global fixed_z
    fixed_z = Variable(cuda.cupy.random.normal(
        size=(batchsize, z_dim, 1, 1), dtype=np.float32))

    gen = Generator(z_dim)
    dis = Discriminator()
    gen.to_gpu(gpu)
    dis.to_gpu(gpu)

    # Both optimizers step on the gradients of one shared loss. The sign of
    # alpha matters: the negative step makes the generator ascend the loss
    # that the discriminator descends.
    opt = {'gen': optimizers.Adam(alpha=-0.001, beta1=0.5),
           'dis': optimizers.Adam(alpha=0.001, beta1=0.5)}
    opt['gen'].setup(gen)
    opt['dis'].setup(dis)

    train, _ = datasets.get_mnist(withlabel=False, ndim=3)
    train_iter = iterators.SerialIterator(train, batch_size=train_batchsize)

    updater = GAN_Updater(train_iter, gen, dis, opt,
                          device=gpu, z_dim=z_dim)
    trainer = training.Trainer(updater, (nb_epoch, 'epoch'), out='result')

    trainer.extend(extensions.LogReport())
    trainer.extend(extensions.snapshot(), trigger=(nb_epoch, 'epoch'))
    trainer.extend(extensions.PlotReport(['loss_gen', 'loss_data'],
                                         file_name='loss.png'))
    trainer.extend(extensions.PrintReport(
        ['epoch', 'loss', 'loss_gen', 'loss_data']))
    trainer.extend(save_example, trigger=(1, 'epoch'))

    if resume:
        chainer.serializers.load_npz(resume, trainer)

    trainer.run()
    save_images()
    

    

In [78]:
# Train with the default configuration. PrintReport prints the per-epoch
# losses below; trainer output goes to result/, and save_images() writes the
# per-epoch generator previews once training finishes.
main()

epoch       loss        loss_gen    loss_data 
[J1           1.22873     0.701407    0.52732     
[J2           1.03184     0.527082    0.504759    
[J3           0.800337    0.395721    0.404615    
[J4           0.702052    0.337219    0.364833    
[J5           0.859973    0.408609    0.451365    
[J6           0.955718    0.457965    0.497753    
[J7           1.03378     0.495346    0.538436    
[J8           0.950061    0.450014    0.500048    
[J9           1.04263     0.496393    0.546242    
[J10          1.04526     0.493772    0.551493    
[J11          1.07545     0.514326    0.561121    
[J12          1.08063     0.517324    0.563301    
[J13          1.01669     0.481861    0.534827    
[J14          0.914082    0.431521    0.482561    
[J15          0.862713    0.400382    0.462331    
[J16          0.842553    0.39064     0.451914    
[J17          0.792323    0.362424    0.429898    
[J18          0.769924    0.353797    0.416128    
[J19          0.