## 諸々インポート

In [98]:
import os
import chainer
from chainer import Chain
from chainer import Variable
from chainer import training, report
from chainer import optimizers, iterators
from chainer.training import extensions
from chainer import cuda
import chainer.functions as F
import chainer.links as L

import matplotlib.pyplot as plt

from sklearn.datasets import fetch_mldata
import numpy as np
import pandas as pd

import pickle
from PIL import Image

# Make GPU 0 the current CUDA device, and bind CuPy as `xp` so that later
# cells can create arrays with it (noise sampling below uses `xp.random`).
gpu = cuda.get_device_from_id(0)
gpu.use()
xp = cuda.cupy

## Generatorのクラス

In [99]:
class Generator(Chain):
    """DCGAN-style generator that maps latent vectors to 1x28x28 images.

    ``l1`` projects ``z`` to a 512x3x3 feature map; four deconvolutions then
    upsample it 3 -> 4 -> 6 -> 10 -> 28 pixels per side.

    Args:
        z_dim: Length of the latent vector. Stored as ``self.z_dim`` so that
            callers can sample noise of the matching size.
    """

    def __init__(self, z_dim):
        super(Generator, self).__init__()
        # Child links are registered inside init_scope(); passing them as
        # keyword arguments to Chain.__init__ is deprecated since Chainer v2.
        # Link names and construction order are unchanged.
        with self.init_scope():
            self.l1 = L.Linear(z_dim, 3 * 3 * 512)

            # Deconvolution output side = stride * (in - 1) + ksize - 2 * pad.
            self.dc1 = L.Deconvolution2D(512, 256, 2, stride=2, pad=1)  # 3 -> 4
            self.dc2 = L.Deconvolution2D(256, 128, 2, stride=2, pad=1)  # 4 -> 6
            self.dc3 = L.Deconvolution2D(128, 64, 2, stride=2, pad=1)   # 6 -> 10
            self.dc4 = L.Deconvolution2D(64, 1, 3, stride=3, pad=1)     # 10 -> 28

            self.bn1 = L.BatchNormalization(512)
            self.bn2 = L.BatchNormalization(256)
            self.bn3 = L.BatchNormalization(128)
            self.bn4 = L.BatchNormalization(64)
        self.z_dim = z_dim

    def __call__(self, z):
        """Generate a batch of images from latent vectors.

        Args:
            z: Latent batch of shape ``(N, z_dim)`` (callers in this notebook
                pass float32 noise drawn from U(-1, 1)).

        Returns:
            Images of shape ``(N, 1, 28, 28)`` with values in (0, 1).
        """
        h = self.l1(z)
        h = F.reshape(h, (z.shape[0], 512, 3, 3))

        h1 = F.relu(self.bn1(h))
        h2 = F.relu(self.bn2(self.dc1(h1)))
        h3 = F.relu(self.bn3(self.dc2(h2)))
        h4 = F.relu(self.bn4(self.dc3(h3)))
        # Sigmoid keeps pixels in (0, 1), the same range as the training data
        # after dividing by 255.
        return F.sigmoid(self.dc4(h4))


## Discriminatorのクラス

In [100]:
class Discriminator(Chain):
    """Convolutional discriminator for 1x28x28 images.

    Four strided convolutions shrink the input 28 -> 10 -> 6 -> 4 -> 3 pixels
    per side; ``l1`` then maps the 512x3x3 features to raw logits (no sigmoid
    is applied here -- the updater's softplus losses consume the logits).
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Child links are registered inside init_scope(); passing them as
        # keyword arguments to Chain.__init__ is deprecated since Chainer v2.
        # Link names and construction order are unchanged.
        with self.init_scope():
            # Convolution output side = floor((in + 2 * pad - ksize) / stride) + 1.
            self.c1 = L.Convolution2D(1, 64, 3, stride=3, pad=1)     # 28 -> 10
            self.c2 = L.Convolution2D(64, 128, 2, stride=2, pad=1)   # 10 -> 6
            self.c3 = L.Convolution2D(128, 256, 2, stride=2, pad=1)  # 6 -> 4
            self.c4 = L.Convolution2D(256, 512, 2, stride=2, pad=1)  # 4 -> 3

            # NOTE(review): two output units per sample. GANUpdater's losses
            # sum softplus over every unit, so both act as independent
            # real/fake logits; a single unit (Linear(..., 1)) is the usual
            # choice -- confirm whether 2 is intentional.
            self.l1 = L.Linear(3 * 3 * 512, 2)

            # bn1..bn3 normalize the outputs of c2..c4; c1 has no BatchNorm.
            self.bn1 = L.BatchNormalization(128)
            self.bn2 = L.BatchNormalization(256)
            self.bn3 = L.BatchNormalization(512)

    def __call__(self, x):
        """Score a batch of images.

        Args:
            x: Images of shape ``(N, 1, H, W)``; ``l1``'s input size assumes
                the convolutions leave a 3x3 map, which 28x28 inputs do.

        Returns:
            Raw logits of shape ``(N, 2)``.
        """
        h1 = F.relu(self.c1(x))
        h2 = F.relu(self.bn1(self.c2(h1)))
        h3 = F.relu(self.bn2(self.c3(h2)))
        h4 = F.relu(self.bn3(self.c4(h3)))
        return self.l1(h4)


## Updaterのクラス

In [101]:
class GANUpdater(training.StandardUpdater):
    """Updater that performs one discriminator step and one generator step.

    Takes a ``models=(gen, dis)`` keyword argument in addition to the normal
    ``StandardUpdater`` arguments, and expects optimizers registered under
    the names ``'gen'`` and ``'dis'``.
    """

    def __init__(self, *args, **kwargs):
        # Pop 'models' so only StandardUpdater's own arguments are forwarded.
        self.gen, self.dis = kwargs.pop('models')
        super(GANUpdater, self).__init__(*args, **kwargs)

    def dis_loss(self, y_fake, y_real):
        """Return the discriminator loss and report it under ``self.dis``.

        softplus(-y) == -log(sigmoid(y)) and softplus(y) == -log(1 - sigmoid(y)),
        so this is binary cross-entropy with target "real" for ``y_real`` and
        target "fake" for ``y_fake``, summed over output units and averaged
        over the batch.
        """
        batch_size = len(y_fake)
        real_loss = F.sum(F.softplus(-y_real)) / batch_size
        fake_loss = F.sum(F.softplus(y_fake)) / batch_size
        loss = real_loss + fake_loss
        report({'loss': loss}, self.dis)
        return loss

    def gen_loss(self, y_fake):
        """Return the generator loss and report it under ``self.gen``.

        softplus(-y_fake) == -log(sigmoid(y_fake)): the generator is rewarded
        when the discriminator scores its samples as real.
        """
        batch_size = len(y_fake)
        loss = F.sum(F.softplus(-y_fake)) / batch_size
        report({'loss': loss}, self.gen)
        return loss

    def update_core(self):
        gen_optimizer = self.get_optimizer('gen')
        dis_optimizer = self.get_optimizer('dis')
        # Use the models owned by this updater. The previous version read the
        # module-level globals `gen`/`dis`, which only worked because the
        # notebook happened to define them with the same objects.
        gen, dis = self.gen, self.dis
        # Array module (numpy or cupy) of the generator's device, so the
        # noise lives where the model does instead of always on the GPU.
        xp = gen.xp

        batch = self.get_iterator('main').next()
        x_real = Variable(self.converter(batch, self.device))
        batch_size = len(x_real)

        y_real = dis(x_real)

        # Latent noise z ~ U(-1, 1) with shape (batch_size, z_dim).
        z = xp.random.uniform(-1, 1, (batch_size, gen.z_dim))
        z = z.astype(dtype=xp.float32)
        x_fake = gen(z)
        y_fake = dis(x_fake)

        # Discriminator step first, then generator step; both losses are
        # computed from this same forward pass (y_fake is shared).
        dis_optimizer.update(self.dis_loss, y_fake, y_real)
        gen_optimizer.update(self.gen_loss, y_fake)
        

## 表示用の関数（生成画像のプレビューを保存する Trainer 拡張）

In [110]:
def _tile_images(x, rows, cols):
    """Tile a batch of single-channel images into one 2-D uint8 canvas.

    Args:
        x: Array of shape ``(rows * cols, 1, H, W)`` with values nominally in
            [0, 1]; values outside that range are clipped.
        rows: Number of image rows in the grid.
        cols: Number of image columns in the grid.

    Returns:
        A ``(rows * H, cols * W)`` uint8 array; image ``i`` is placed at grid
        cell ``(i // cols, i % cols)``.
    """
    x = np.asarray(np.clip(x * 255, 0.0, 255.0), dtype=np.uint8)
    _, _, H, W = x.shape
    x = x.reshape((rows, cols, 1, H, W))
    # (rows, cols, 1, H, W) -> (rows, H, cols, W, 1) so that a plain reshape
    # lays the images out row-major on the canvas.
    x = x.transpose(0, 3, 1, 4, 2)
    return x.reshape((rows * H, cols * W))


def image(gen, rows, cols, dst):
    """Return a trainer extension that saves a grid of generated samples.

    Each time the extension fires it draws ``rows * cols`` latent vectors from
    U(-1, 1), runs ``gen`` in inference mode, and writes the tiled result to
    ``<dst>/preview/image<iteration, zero-padded to 8>.png``.

    Args:
        gen: Generator exposing ``z_dim`` and producing ``(N, 1, H, W)``
            images in [0, 1].
        rows: Number of image rows in the saved grid.
        cols: Number of image columns in the saved grid.
        dst: Output directory; the ``preview`` subdirectory is created on
            demand.
    """
    @training.make_extension()
    def make_image(trainer):
        # Sample on the generator's own device (numpy on CPU, cupy on GPU)
        # rather than depending on a module-level `xp`.
        xp = gen.xp
        z = xp.random.uniform(-1, 1, (rows * cols, gen.z_dim))
        z = z.astype(dtype=xp.float32)
        # Inference only: test-mode BatchNorm, and no computational graph is
        # recorded since nothing is backpropagated here.
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            x = gen(z)
        x = cuda.to_cpu(x.data)

        canvas = _tile_images(x, rows, cols)

        preview_dir = '{}/preview'.format(dst)
        preview_path = preview_dir +\
            '/image{:0>8}.png'.format(trainer.updater.iteration)
        if not os.path.exists(preview_dir):
            os.makedirs(preview_dir)
        Image.fromarray(canvas).save(preview_path)
    return make_image

## 実行する

In [112]:

# Build both networks and move their parameters to the GPU.
gen = Generator(100)
dis = Discriminator()
gen.to_gpu()
dis.to_gpu()

# One Adam optimizer per network, with Chainer's default hyperparameters.
# NOTE(review): the log below shows dis/loss falling toward 0 while gen/loss
# keeps climbing (the discriminator is winning); the DCGAN paper's setting
# Adam(alpha=0.0002, beta1=0.5) may be worth trying.
gen_opt = optimizers.Adam()
gen_opt.setup(gen)
dis_opt = optimizers.Adam()
dis_opt.setup(dis)

# sklearn's fetch_mldata('MNIST original') no longer works (mldata.org is
# offline and the function was removed from scikit-learn), so load MNIST via
# Chainer instead. Concatenating the train and test splits yields the same
# 70,000 images. withlabel=False returns bare image arrays; ndim=3 shapes
# them (N, 1, 28, 28), and the default scale maps pixels to [0, 1] as the
# old `X /= 255.` did.
# NOTE(review): the `fetch_mldata` import at the top of the notebook is now
# unused and itself fails on scikit-learn >= 0.22; it can be dropped.
mnist_train, mnist_test = chainer.datasets.get_mnist(withlabel=False, ndim=3)
X = np.concatenate([mnist_train, mnist_test]).astype(np.float32, copy=False)
n_train = len(X)
train = iterators.SerialIterator(X, 100)

updater = GANUpdater(models=(gen, dis),
                     iterator=train,
                     optimizer={'gen': gen_opt, 'dis': dis_opt},
                     device=0)

trainer = training.Trainer(updater, (1000, 'epoch'), out='result')

snapshot_interval = (100, 'epoch')
display_interval = (1000, 'iteration')
trainer.extend(extensions.LogReport(trigger=display_interval))
trainer.extend(extensions.PrintReport([
    'epoch', 'iteration', 'gen/loss', 'dis/loss',
]), trigger=display_interval)

# Save a 10x10 grid of generated samples every snapshot_interval.
trainer.extend(image(gen, 10, 10, 'result'), trigger=snapshot_interval)

trainer.run()

epoch       iteration   gen/loss    dis/loss  
1           1000        16.5661     0.175634
2           2000        12.879      0.0863296
4           3000        13.1816     0.0252714
5           4000        12.7493     0.0686163
7           5000        12.6176     0.052469
8           6000        12.6163     0.0324581
10          7000        13.9916     0.0462195
11          8000        13.6381     0.0782531
12          9000        13.1613     0.0437378
14          10000       16.1907     0.0670317
15          11000       14.2653     0.0998733
17          12000       13.4764     0.0499048
18          13000       14.9891     0.0482959
20          14000       14.3505     0.0554001
21          15000       14.3668     0.0716975
22          16000       15.145      0.0485878
24          17000       17.4748     0.072485
25          18000       17.8919     0.0688703
27          19

225         158000      21.8602     0.00120125
227         159000      27.5523     0.0180425
228         160000      21.0085     0.000724959
230         161000      28.6809     0.0571622
231         162000      23.56       0.00664168
232         163000      25.2247     0.00170635
234         164000      24.8755     0.0418254
235         165000      22.0476     0.00114547
237         166000      28.8246     0.00763419
238         167000      26.4022     0.0110411
240         168000      37.6091     0.0374194
241         169000      29.5827     0.0224025
242         170000      23.4183     0.000899857
244         171000      27.0853     0.000232854
245         172000      32.4326     0.00112034
247         173000      33.6471     0.00614139
248         174000      29.5073     0.0103257
250         175000      28.411      0.00747236
251         176000      35.4677     0.0227466
252

450         315000      25.5821     0.00458547
451         316000      28.7082     0.000621169
452         317000      39.3803     0.000680988
454         318000      27.5003     7.82932e-05
455         319000      41.1147     0.0194668
457         320000      27.8447     0.0244395
458         321000      28.4635     0.0054393
460         322000      35.2076     0.0509256
461         323000      26.9791     0.0258404
462         324000      26.7848     0.0299675
464         325000      28.1988     0.0530028
465         326000      24.2984     0.0587738
467         327000      33.4349     0.00525445
468         328000      31.048      0.0377211
470         329000      22.341      0.00347228
471         330000      29.6172     0.0336097
472         331000      31.8711     0.00981541
474         332000      24.0112     0.00408385
475         333000      26.6179     0.00536662
477

675         473000      51.8656     0.00210234
677         474000      41.6084     0.00146734
678         475000      48.9212     0.0265705
680         476000      30.6244     0.0446708
681         477000      30.6517     0.0134021
682         478000      43.4764     0.00542749
684         479000      32.9173     0.0120441
685         480000      31.5342     0.000885432
687         481000      29.3712     0.00513391
688         482000      36.0817     0.000578843
690         483000      53.5915     0.0481421
691         484000      29.1859     0.000534121
692         485000      50.6289     0.0445426
694         486000      28.5274     0.00949959
695         487000      26.3331     0.00299722
697         488000      29.6632     0.000181699
698         489000      28.3569     0.0338922
700         490000      37.0182     0.00238565
701         491000      33.7247     0.0420057
702

900         630000      25.6329     0.00145187
901         631000      34.2517     0.000627338
902         632000      35.6207     0.000871813
904         633000      29.5928     0.00142254
905         634000      35.9482     0.00536371
907         635000      34.5071     0.00441378
908         636000      60.4216     0.000708412
910         637000      43.1956     0.0415055
911         638000      28.3414     0.00866326
912         639000      35.1822     0.0394872
914         640000      32.1962     0.0059312
915         641000      36.3234     0.00872329
917         642000      33.8589     0.0148965
918         643000      37.6228     0.0137768
920         644000      31.6651     0.0273433
921         645000      28.2245     0.0400762
922         646000      40.1925     0.00933341
924         647000      45.7651     0.00277394
925         648000      46.2363     0.127629
927