## 諸々インポート

In [1]:
import os
import chainer
from chainer import Chain
from chainer import Variable
from chainer import training, report
from chainer import optimizers, iterators
from chainer.training import extensions
from chainer import cuda
import chainer.functions as F
import chainer.links as L

import matplotlib.pyplot as plt

from sklearn.datasets import fetch_mldata
import numpy as np
import pandas as pd

import pickle
from PIL import Image

# Module-level array module (CuPy) for creating arrays on the GPU.
xp = cuda.cupy
# Make GPU 0 the current CUDA device; later `to_gpu()` calls without an
# argument place models on this device.
cuda.get_device_from_id(0).use()

In [2]:
import struct
from PIL import Image

## Generatorのクラス

In [41]:
class Generator(Chain):
    """Generator mapping latent vectors to 3-channel 32x32 images in (0, 1).

    ``l1`` projects ``z`` to a 512x3x3 feature map; four deconvolutions then
    grow the spatial size 3 -> 4 -> 6 -> 10 -> 32 while reducing channels
    512 -> 256 -> 128 -> 64 -> 3.

    Args:
        z_dim: length of each latent vector (kept as ``self.z_dim``).
        wscale: standard deviation of the normal initializer used for the
            deconvolution weights.
    """

    def __init__(self, z_dim, wscale=0.02):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        with self.init_scope():
            init_w = chainer.initializers.Normal(wscale)
            # Links are created in a fixed order so weight initialization
            # consumes the random generator in the same sequence.
            # NOTE: l1 uses Chainer's default initializer, not init_w.
            self.l1 = L.Linear(z_dim, 3 * 3 * 512)

            self.dc1 = L.Deconvolution2D(512, 256, ksize=2, stride=2, pad=1, initialW=init_w)
            self.dc2 = L.Deconvolution2D(256, 128, ksize=2, stride=2, pad=1, initialW=init_w)
            self.dc3 = L.Deconvolution2D(128, 64, ksize=2, stride=2, pad=1, initialW=init_w)
            self.dc4 = L.Deconvolution2D(64, 3, ksize=7, stride=3, pad=1, initialW=init_w)

            self.bn1 = L.BatchNormalization(512)
            self.bn2 = L.BatchNormalization(256)
            self.bn3 = L.BatchNormalization(128)
            self.bn4 = L.BatchNormalization(64)

    def __call__(self, z):
        """Return a batch of images of shape (len(z), 3, 32, 32)."""
        h = F.reshape(self.l1(z), (z.shape[0], 512, 3, 3))
        h = F.relu(self.bn1(h))
        # Each stage: deconvolve, then normalize with the BN sized for its
        # output channels, then ReLU.
        for deconv, norm in ((self.dc1, self.bn2),
                             (self.dc2, self.bn3),
                             (self.dc3, self.bn4)):
            h = F.relu(norm(deconv(h)))
        return F.sigmoid(self.dc4(h))


In [59]:
class Generator(chainer.Chain):
    """Generator producing 3-channel images in (0, 1) from latent vectors.

    ``l0`` projects ``z`` to a ``ch`` x ``bottom_width`` x ``bottom_width``
    map; three stride-2 deconvolutions each double the spatial size (4 -> 8
    -> 16 -> 32 with the defaults) and a final 3x3, stride-1 deconvolution
    maps to 3 channels without changing the size.

    Args:
        z_dim: length of each latent vector.
        bottom_width: spatial size of the first feature map.
        ch: channel count of the first feature map (halved at each stage).
        wscale: standard deviation of the normal weight initializer.
    """

    def __init__(self, z_dim, bottom_width=4, ch=512, wscale=0.02):
        super(Generator, self).__init__()
        self.z_dim, self.ch, self.bottom_width = z_dim, ch, bottom_width
        n_projected = bottom_width * bottom_width * ch

        with self.init_scope():
            init_w = chainer.initializers.Normal(wscale)
            # Creation order is kept fixed so random initialization draws
            # happen in the same sequence.
            self.l0 = L.Linear(z_dim, n_projected, initialW=init_w)
            self.dc1 = L.Deconvolution2D(ch, ch // 2, ksize=4, stride=2, pad=1, initialW=init_w)
            self.dc2 = L.Deconvolution2D(ch // 2, ch // 4, ksize=4, stride=2, pad=1, initialW=init_w)
            self.dc3 = L.Deconvolution2D(ch // 4, ch // 8, ksize=4, stride=2, pad=1, initialW=init_w)
            self.dc4 = L.Deconvolution2D(ch // 8, 3, ksize=3, stride=1, pad=1, initialW=init_w)
            self.bn0 = L.BatchNormalization(n_projected)
            self.bn1 = L.BatchNormalization(ch // 2)
            self.bn2 = L.BatchNormalization(ch // 4)
            self.bn3 = L.BatchNormalization(ch // 8)

    def __call__(self, z):
        """Return a batch of generated images for latent batch ``z``."""
        seed_shape = (len(z), self.ch, self.bottom_width, self.bottom_width)
        h = F.reshape(F.relu(self.bn0(self.l0(z))), seed_shape)
        for deconv, norm in ((self.dc1, self.bn1),
                             (self.dc2, self.bn2),
                             (self.dc3, self.bn3)):
            h = F.relu(norm(deconv(h)))
        return F.sigmoid(self.dc4(h))

## Discriminatorのクラス

In [42]:
class Discriminator(Chain):
    """Discriminator scoring 3-channel 32x32 images.

    Four convolutions reduce the spatial size 32 -> 11 -> 6 -> 4 -> 3; a
    linear layer maps the 512x3x3 features to 2 outputs per sample.

    NOTE(review): GANUpdater's losses apply softplus to every output and sum
    them, so the 2 outputs here act as two independent logits; the later
    Discriminator definition emits a single value per sample instead.

    Args:
        wscale: standard deviation of the normal initializer for the
            convolution weights.
    """

    def __init__(self, wscale=0.02):
        init_w = chainer.initializers.Normal(wscale)
        super(Discriminator, self).__init__()
        with self.init_scope():
            # Creation order is kept fixed so random initialization draws
            # happen in the same sequence.
            self.c1 = L.Convolution2D(3, 64, ksize=3, stride=3, pad=1, initialW=init_w)
            self.c2 = L.Convolution2D(64, 128, ksize=2, stride=2, pad=1, initialW=init_w)
            self.c3 = L.Convolution2D(128, 256, ksize=2, stride=2, pad=1, initialW=init_w)
            self.c4 = L.Convolution2D(256, 512, ksize=2, stride=2, pad=1, initialW=init_w)

            # NOTE: uses Chainer's default initializer, not init_w.
            self.l1 = L.Linear(3 * 3 * 512, 2)

            self.bn1 = L.BatchNormalization(128)
            self.bn2 = L.BatchNormalization(256)
            self.bn3 = L.BatchNormalization(512)

    def __call__(self, x):
        """Return raw (pre-sigmoid) scores of shape (len(x), 2)."""
        # add_noise perturbs the input only while chainer.config.train is on.
        h = F.leaky_relu(self.c1(add_noise(x)))
        for conv, norm in ((self.c2, self.bn1),
                           (self.c3, self.bn2),
                           (self.c4, self.bn3)):
            h = F.leaky_relu(norm(conv(h)))
        return self.l1(h)


In [48]:
def add_noise(h, sigma=0.2):
    """Return ``h`` plus Gaussian noise while training, else ``h`` unchanged.

    Args:
        h: a ``chainer.Variable`` or a NumPy/CuPy array.
        sigma: standard deviation of the added noise.

    Returns:
        ``h + sigma * N(0, 1)`` when ``chainer.config.train`` is true,
        otherwise ``h`` itself.
    """
    if not chainer.config.train:
        return h
    # get_array_module unwraps Variables itself; passing `h.data` would pick
    # NumPy for a raw CuPy array (its `.data` is a memory pointer).
    xp = cuda.get_array_module(h)
    # randn yields float64; match h's dtype so the sum stays e.g. float32.
    noise = xp.random.randn(*h.shape).astype(h.dtype, copy=False)
    return h + sigma * noise

class Discriminator(chainer.Chain):
    """Discriminator emitting one raw score per 3-channel image.

    With the defaults, a 32x32 input passes through pairs of convolutions
    (3x3 stride 1, then 4x4 stride 2) that shrink it 32 -> 16 -> 8 -> 4
    while widening channels ch/8 -> ch/4 -> ch/2 -> ch; a final linear layer
    maps the ch x bottom_width x bottom_width features to one value.
    Gaussian noise (see ``add_noise``) is injected after every layer while
    training.

    Args:
        bottom_width: spatial size expected after the last convolution.
        ch: channel count of the deepest feature map.
        wscale: standard deviation of the normal weight initializer.
    """

    def __init__(self, bottom_width=4, ch=512, wscale=0.02):
        init_w = chainer.initializers.Normal(wscale)
        super(Discriminator, self).__init__()
        with self.init_scope():
            # Creation order is kept fixed so random initialization draws
            # happen in the same sequence.
            self.c0_0 = L.Convolution2D(3, ch // 8, ksize=3, stride=1, pad=1, initialW=init_w)
            self.c0_1 = L.Convolution2D(ch // 8, ch // 4, ksize=4, stride=2, pad=1, initialW=init_w)
            self.c1_0 = L.Convolution2D(ch // 4, ch // 4, ksize=3, stride=1, pad=1, initialW=init_w)
            self.c1_1 = L.Convolution2D(ch // 4, ch // 2, ksize=4, stride=2, pad=1, initialW=init_w)
            self.c2_0 = L.Convolution2D(ch // 2, ch // 2, ksize=3, stride=1, pad=1, initialW=init_w)
            self.c2_1 = L.Convolution2D(ch // 2, ch, ksize=4, stride=2, pad=1, initialW=init_w)
            self.c3_0 = L.Convolution2D(ch, ch, ksize=3, stride=1, pad=1, initialW=init_w)
            self.l4 = L.Linear(bottom_width * bottom_width * ch, 1, initialW=init_w)
            # BatchNorm without a learned scale (beta only).
            self.bn0_1 = L.BatchNormalization(ch // 4, use_gamma=False)
            self.bn1_0 = L.BatchNormalization(ch // 4, use_gamma=False)
            self.bn1_1 = L.BatchNormalization(ch // 2, use_gamma=False)
            self.bn2_0 = L.BatchNormalization(ch // 2, use_gamma=False)
            self.bn2_1 = L.BatchNormalization(ch, use_gamma=False)
            self.bn3_0 = L.BatchNormalization(ch, use_gamma=False)

    def __call__(self, x):
        """Return raw (pre-sigmoid) scores of shape (len(x), 1)."""
        h = add_noise(x)
        # The first convolution has no BatchNorm.
        h = F.leaky_relu(add_noise(self.c0_0(h)))
        normalized_stages = (
            (self.c0_1, self.bn0_1),
            (self.c1_0, self.bn1_0),
            (self.c1_1, self.bn1_1),
            (self.c2_0, self.bn2_0),
            (self.c2_1, self.bn2_1),
            (self.c3_0, self.bn3_0),
        )
        for conv, norm in normalized_stages:
            h = F.leaky_relu(add_noise(norm(conv(h))))
        return self.l4(h)

## Updaterのクラス

In [54]:
class GANUpdater(training.StandardUpdater):
    """Updater performing one discriminator step and one generator step.

    In addition to the usual ``StandardUpdater`` arguments it requires the
    keyword argument ``models=(gen, dis)``. The optimizers must be registered
    under the names ``'gen'`` and ``'dis'``, and the generator must expose
    ``z_dim``.
    """

    def __init__(self, *args, **kwargs):
        self.gen, self.dis = kwargs.pop('models')
        super(GANUpdater, self).__init__(*args, **kwargs)

    def dis_loss(self, y_fake, y_real):
        """Discriminator loss: -log D(real) - log(1 - D(fake)), batch-averaged.

        softplus(-y) == -log(sigmoid(y)) and softplus(y) == -log(1 - sigmoid(y)),
        so the raw scores are used directly. The value is reported as
        ``loss`` for the discriminator observer.
        """
        batch_size = len(y_fake)
        real_loss = F.sum(F.softplus(-y_real)) / batch_size
        fake_loss = F.sum(F.softplus(y_fake)) / batch_size
        loss = real_loss + fake_loss
        report({'loss': loss}, self.dis)
        return loss

    def gen_loss(self, y_fake):
        """Non-saturating generator loss -log D(fake), batch-averaged.

        The value is reported as ``loss`` for the generator observer.
        """
        batch_size = len(y_fake)
        loss = F.sum(F.softplus(-y_fake)) / batch_size
        report({'loss': loss}, self.gen)
        return loss

    def update_core(self):
        gen_optimizer = self.get_optimizer('gen')
        dis_optimizer = self.get_optimizer('dis')
        # Use the models handed to this updater (not notebook globals), and
        # draw latents with the generator's own array module (NumPy/CuPy).
        gen, dis = self.gen, self.dis
        xp = gen.xp

        batch = self.get_iterator('main').next()
        x_real = Variable(self.converter(batch, self.device))
        batch_size = len(x_real)

        y_real = dis(x_real)

        z = xp.random.uniform(-1, 1, (batch_size, gen.z_dim))
        z = z.astype(xp.float32)
        x_fake = gen(z)
        y_fake = dis(x_fake)

        # y_fake's graph is built once and used for both updates.
        # Optimizer.update clears the target's grads before backprop, so
        # gradients that dis_loss pushes into the generator are discarded by
        # the generator step.
        dis_optimizer.update(self.dis_loss, y_fake, y_real)
        gen_optimizer.update(self.gen_loss, y_fake)
        

## 表示用の関数（生成画像のプレビューを保存するTrainer拡張）

In [55]:
def image(gen, rows, cols, dst):
    """Build a trainer extension that saves a grid of generated samples.

    Args:
        gen: generator link; must expose ``z_dim`` and accept a float32
            array of shape ``(n, z_dim)``.
        rows: number of grid rows.
        cols: number of grid columns (``rows * cols`` samples per call).
        dst: output directory; each call writes
            ``{dst}/preview/image{iteration:08}.png``.

    Returns:
        A function decorated with ``training.make_extension()``.
    """
    @training.make_extension()
    def make_image(trainer):
        # Draw latents with the generator's own array module (NumPy/CuPy).
        xp = gen.xp
        n_images = rows * cols
        z = xp.random.uniform(-1, 1, (n_images, gen.z_dim))
        z = z.astype(dtype=xp.float32)
        # Inference only: BatchNorm uses running statistics and no graph is kept.
        with chainer.using_config('train', False), \
                chainer.using_config('enable_backprop', False):
            x = gen(z)
        x = cuda.to_cpu(x.data)

        x = np.asarray(np.clip(x * 255, 0.0, 255.0), dtype=np.uint8)
        # Take the sizes from the generated batch itself.
        _, ch, H, W = x.shape
        # NOTE(review): each flat sample is read back as (H, W, ch), not
        # (ch, H, W). This mirrors the dataset-loading cell, which reshapes
        # the CIFAR bytes with order='F' and then reinterprets them as
        # (3, 32, 32). For genuine NCHW data the tiling would be
        # reshape((rows, cols, ch, H, W)).transpose(0, 3, 1, 4, 2).
        x = x.reshape((rows, cols, H, W, ch))
        x = x.transpose(0, 3, 1, 2, 4)
        x = x.reshape((rows * H, cols * W, ch))

        preview_dir = '{}/preview'.format(dst)
        preview_path = preview_dir +\
            '/image{:0>8}.png'.format(trainer.updater.iteration)
        if not os.path.exists(preview_dir):
            os.makedirs(preview_dir)
        Image.fromarray(x).save(preview_path)
    return make_image

# Datasetの準備

In [57]:
imgs = []
dataset_root = os.environ.get("DATA_SET_DIR")
if dataset_root is None:
    raise RuntimeError("DATA_SET_DIR environment variable is not set")
dataset_dir = os.path.join(dataset_root, 'cifar10/cifar-10-batches-bin')
n_train = 10000  # records read from each data_batch_N.bin file
record_size = 1 + 3 * 32 * 32  # 1 label byte followed by 3072 pixel bytes
for batch_no in range(1, 6):
    batch_path = os.path.join(dataset_dir, 'data_batch_{}.bin'.format(batch_no))
    with open(batch_path, 'rb') as cifar10:
        raw = cifar10.read(record_size * n_train)
    if len(raw) != record_size * n_train:
        raise ValueError('{} is truncated: expected {} bytes, got {}'.format(
            batch_path, record_size * n_train, len(raw)))
    # Pixel bytes are unsigned (0-255); decoding them as signed chars would
    # make every value >= 128 negative.
    records = np.frombuffer(raw, dtype=np.uint8).reshape(n_train, record_size)
    for record in records:
        # record[0] is the class label (unused here).
        img = record[1:].reshape(32, 32, 3, order='F') / 255.
        imgs.append(img)

imgs = np.asarray(imgs, dtype=np.float32)
# NOTE(review): with order='F' above, each img is indexed (column, row,
# channel), and this reshape reinterprets that memory as (3, 32, 32) without
# transposing. X is therefore NOT standard NCHW: the convolutions see a
# spatially scrambled image. The preview tiling in `image`/`raw_image`
# undoes exactly this layout, so previews still look correct. Switching to
# record[1:].reshape(3, 32, 32) requires changing that tiling at the same time.
X = imgs.reshape(n_train * 5, 3, 32, 32)
train = iterators.SerialIterator(X, 100)

## 実行する

In [60]:
# Instantiate the networks (latent size 100) and move them to the current
# CUDA device.
gen = Generator(100)
dis = Discriminator()
gen.to_gpu()
dis.to_gpu()

# One Adam optimizer per network, using Chainer's default hyperparameters.
# NOTE(review): DCGAN-style training commonly uses Adam(alpha=2e-4,
# beta1=0.5); given how gen/loss keeps climbing in the log below, these
# settings may be worth trying.
gen_opt = optimizers.Adam()
gen_opt.setup(gen)
dis_opt = optimizers.Adam()
dis_opt.setup(dis)

updater = GANUpdater(
    iterator=train,
    optimizer={'gen': gen_opt, 'dis': dis_opt},
    device=0,
    models=(gen, dis),
)
trainer = training.Trainer(updater, stop_trigger=(1000, 'epoch'), out='result')

# Log and print losses every 1000 iterations; save a 10x10 grid of generated
# samples every 10 epochs.
snapshot_interval = (10, 'epoch')
display_interval = (1000, 'iteration')
trainer.extend(extensions.LogReport(trigger=display_interval))
trainer.extend(
    extensions.PrintReport(
        ['epoch', 'iteration', 'gen/loss', 'dis/loss', 'elapsed_time']),
    trigger=display_interval,
)
trainer.extend(image(gen, 10, 10, 'result'), trigger=snapshot_interval)

trainer.run()

epoch       iteration   gen/loss    dis/loss    elapsed_time
[J2           1000        7.76178     0.220708    190.632       
[J4           2000        9.37549     0.165556    380.336       
[J6           3000        9.03522     0.137917    569.928       
[J8           4000        9.48489     0.150576    759.463       
[J10          5000        10.2463     0.142548    949.123       
[J12          6000        9.78042     0.192017    1138.47       
[J14          7000        9.36257     0.183196    1328.76       
[J16          8000        10.8717     0.235048    1518.33       
[J18          9000        10.0213     0.195054    1708          
[J20          10000       10.1593     0.19002     1897.56       
[J22          11000       9.34098     0.19644     2087.3        
[J24          12000       8.49097     0.203287    2277.02       
[J26          13000       9.03608     0.195798    2466.69       
[J28          14000       8.63164     0.251899    2656.38       
[J30          

[J250         125000      5.58755     0.319555    23723.2       
[J252         126000      5.78716     0.313835    23912.8       
[J254         127000      5.78945     0.256407    24102.3       
[J256         128000      6.0171      0.285311    24291.8       
[J258         129000      5.80303     0.2994      24481.5       
[J260         130000      5.97929     0.261046    24671         
[J262         131000      6.07783     0.273318    24860.6       
[J264         132000      6.03834     0.257192    25050.3       
[J266         133000      6.19351     0.28797     25240         
[J268         134000      5.89282     0.309133    25429.8       
[J270         135000      6.2732      0.279121    25619.6       
[J272         136000      6.19814     0.266637    25809.4       
[J274         137000      6.17065     0.27228     25999.1       
[J276         138000      6.11775     0.268054    26188.8       
[J278         139000      6.36635     0.332565    26378.6       
[J280    

[J500         250000      7.7933      0.125364    47447.6       
[J502         251000      8.08099     0.201545    47637.4       
[J504         252000      7.98115     0.199076    47827.3       
[J506         253000      8.16319     0.177535    48017.1       
[J508         254000      7.89945     0.194336    48206.9       
[J510         255000      8.34736     0.247431    48396.7       
[J512         256000      7.82816     0.164706    48586.6       
[J514         257000      8.13501     0.203096    48776.3       
[J516         258000      8.11666     0.161201    48966.2       
[J518         259000      8.10223     0.13487     49156.1       
[J520         260000      8.43017     0.219943    49345.9       
[J522         261000      8.10636     0.194305    49535.8       
[J524         262000      8.1549      0.158615    49725.5       
[J526         263000      8.49231     0.165175    49915.2       
[J528         264000      8.16367     0.166413    50104.9       
[J530    

[J750         375000      10.7774     0.110891    71157.5       
[J752         376000      10.9934     0.142777    71347.1       
[J754         377000      10.6459     0.0760422   71536.5       
[J756         378000      10.9455     0.155664    71725.9       
[J758         379000      10.8633     0.107857    71915.4       
[J760         380000      11.5847     0.154656    72104.9       
[J762         381000      11.1327     0.0904665   72294.4       
[J764         382000      11.0914     0.0715527   72483.9       
[J766         383000      11.22       0.102205    72673.4       
[J768         384000      11.2828     0.130514    72862.9       
[J770         385000      10.8999     0.156671    73052.4       
[J772         386000      11.0089     0.0800804   73242         
[J774         387000      11.3487     0.106303    73431.4       
[J776         388000      11.4991     0.0899509   73620.8       
[J778         389000      11.1032     0.124115    73810.1       
[J780    

[J1000        500000      13.2793     0.10861     94811.9       


# 表示の確認

In [62]:
def raw_image(x, rows, cols):
    """Tile ``rows * cols`` samples from ``x`` into one image and display it.

    Args:
        x: float array of shape ``(rows * cols, ch, H, W)`` with values
            expected in [0, 1] (they are clipped after scaling by 255).
        rows: number of grid rows.
        cols: number of grid columns.

    Prints the shape of ``x`` and opens the result with ``Image.show()``.
    """
    x = np.asarray(np.clip(x * 255, 0.0, 255.0), dtype=np.uint8)
    print(x.shape)
    # Take the sizes from the argument, not from the global dataset.
    _, ch, H, W = x.shape
    # NOTE(review): each flat sample is read as (H, W, ch), matching the
    # layout the dataset-loading cell produces (order='F' reshape followed
    # by reinterpretation as (3, 32, 32)); genuine NCHW data would need
    # reshape((rows, cols, ch, H, W)).transpose(0, 3, 1, 4, 2) instead.
    x = x.reshape((rows, cols, H, W, ch))
    x = x.transpose(0, 3, 1, 2, 4)
    x = x.reshape((rows * H, cols * W, ch))
    Image.fromarray(x).show()
raw_image(X[:100], 10, 10)

(100, 3, 32, 32)
