## 諸々インポート

In [1]:
import os
import chainer
from chainer import Chain
from chainer import Variable
from chainer import training, report
from chainer import optimizers, iterators
from chainer.training import extensions
from chainer import cuda
import chainer.functions as F
import chainer.links as L

import matplotlib.pyplot as plt

from sklearn.datasets import fetch_mldata
import numpy as np
import pandas as pd

import pickle
from PIL import Image

# Make GPU 0 the current CUDA device for the rest of the notebook.
cuda.get_device_from_id(0).use()
# Module-level array module (CuPy); some cells below sample random numbers with it.
xp = cuda.cupy

In [2]:
import struct
from PIL import Image

## Generatorのクラス

In [3]:
class Generator(Chain):
    """DCGAN-style generator that maps a latent vector to a 3x32x32 image.

    Args:
        z_dim: Length of the latent vector passed to ``__call__``.
        wscale: Standard deviation of the normal initializer applied to the
            weights of every Linear/Deconvolution layer.
    """

    def __init__(self, z_dim, wscale=0.02):
        w = chainer.initializers.Normal(wscale)
        super(Generator, self).__init__()
        with self.init_scope():
            # Project z to a 512-channel 3x3 feature map. Uses the same
            # initializer as the deconvolutions so ``wscale`` applies uniformly.
            self.l1 = L.Linear(z_dim, 3 * 3 * 512, initialW=w)

            # Deconvolution output size = stride * (in - 1) + ksize - 2 * pad:
            # 3 -> 4 -> 6 -> 10 -> 32.
            self.dc1 = L.Deconvolution2D(512, 256, 2, stride=2, pad=1, initialW=w)
            self.dc2 = L.Deconvolution2D(256, 128, 2, stride=2, pad=1, initialW=w)
            self.dc3 = L.Deconvolution2D(128, 64, 2, stride=2, pad=1, initialW=w)
            self.dc4 = L.Deconvolution2D(64, 3, 7, stride=3, pad=1, initialW=w)

            self.bn1 = L.BatchNormalization(512)
            self.bn2 = L.BatchNormalization(256)
            self.bn3 = L.BatchNormalization(128)
            self.bn4 = L.BatchNormalization(64)
        self.z_dim = z_dim

    def __call__(self, z):
        """Generate images from latent vectors.

        Args:
            z: Batch of latent vectors with shape ``(N, z_dim)``.

        Returns:
            Variable of shape ``(N, 3, 32, 32)`` with values in (0, 1)
            (sigmoid output).
        """
        h = self.l1(z)
        h = F.reshape(h, (z.shape[0], 512, 3, 3))

        h1 = F.relu(self.bn1(h))
        h2 = F.relu(self.bn2(self.dc1(h1)))
        h3 = F.relu(self.bn3(self.dc2(h2)))
        h4 = F.relu(self.bn4(self.dc3(h3)))
        x = F.sigmoid(self.dc4(h4))
        return x


## Discriminatorのクラス

In [4]:
def add_noise(h, sigma=0.2):
    """Add zero-mean Gaussian noise to ``h`` in training mode.

    Args:
        h: Variable or array (NumPy or CuPy) to perturb.
        sigma: Standard deviation of the noise.

    Returns:
        ``h + sigma * N(0, 1)`` when ``chainer.config.train`` is True,
        otherwise ``h`` unchanged.
    """
    if not chainer.config.train:
        return h
    # Pass h itself (not h.data) so both Variables and raw CuPy arrays
    # resolve to the right array module.
    xp = cuda.get_array_module(h)
    # randn yields float64; match h's dtype explicitly.
    noise = xp.random.randn(*h.shape).astype(h.dtype)
    return h + sigma * noise

class Discriminator(Chain):
    """DCGAN-style discriminator for 3x32x32 images.

    Args:
        wscale: Standard deviation of the normal initializer applied to the
            weights of every Convolution/Linear layer.
    """

    def __init__(self, wscale=0.02):
        w = chainer.initializers.Normal(wscale)
        super(Discriminator, self).__init__()
        with self.init_scope():
            # Convolution output size = (in + 2 * pad - ksize) // stride + 1:
            # 32 -> 11 -> 6 -> 4 -> 3.
            self.c1 = L.Convolution2D(3, 64, 3, stride=3, pad=1, initialW=w)
            self.c2 = L.Convolution2D(64, 128, 2, stride=2, pad=1, initialW=w)
            self.c3 = L.Convolution2D(128, 256, 2, stride=2, pad=1, initialW=w)
            self.c4 = L.Convolution2D(256, 512, 2, stride=2, pad=1, initialW=w)

            # Same initializer as the convolutions so ``wscale`` applies uniformly.
            # NOTE(review): two output logits; the GANUpdater losses sum softplus
            # over both, so they act as two independent real/fake scores.
            self.l1 = L.Linear(3 * 3 * 512, 2, initialW=w)

            self.bn1 = L.BatchNormalization(128)
            self.bn2 = L.BatchNormalization(256)
            self.bn3 = L.BatchNormalization(512)

    def __call__(self, x):
        """Score a batch of images.

        Args:
            x: Images with shape ``(N, 3, 32, 32)``.

        Returns:
            Unnormalized logits of shape ``(N, 2)``.
        """
        # Input noise (training mode only) to stabilise GAN training.
        h = add_noise(x)
        h1 = F.leaky_relu(self.c1(h))
        h2 = F.leaky_relu(self.bn1(self.c2(h1)))
        h3 = F.leaky_relu(self.bn2(self.c3(h2)))
        h4 = F.leaky_relu(self.bn3(self.c4(h3)))
        y = self.l1(h4)
        return y


## Updaterのクラス

In [5]:
class GANUpdater(training.StandardUpdater):
    """Updater that performs one discriminator and one generator step per batch.

    Requires the keyword argument ``models=(gen, dis)`` and optimizers
    registered under ``'gen'`` and ``'dis'``. All other arguments are
    forwarded to ``StandardUpdater``.
    """

    def __init__(self, *args, **kwargs):
        self.gen, self.dis = kwargs.pop('models')
        super(GANUpdater, self).__init__(*args, **kwargs)

    def dis_loss(self, y_fake, y_real):
        """Discriminator loss: batch-mean of softplus(-y_real) + softplus(y_fake).

        The value is reported as ``loss`` under the discriminator's observer.
        """
        batch_size = len(y_fake)
        real_loss = F.sum(F.softplus(-y_real)) / batch_size
        fake_loss = F.sum(F.softplus(y_fake)) / batch_size
        loss = real_loss + fake_loss
        report({'loss': loss}, self.dis)
        return loss

    def gen_loss(self, y_fake):
        """Generator loss: batch-mean of softplus(-y_fake).

        The value is reported as ``loss`` under the generator's observer.
        """
        batch_size = len(y_fake)
        loss = F.sum(F.softplus(-y_fake)) / batch_size
        report({'loss': loss}, self.gen)
        return loss

    def update_core(self):
        gen_optimizer = self.get_optimizer('gen')
        dis_optimizer = self.get_optimizer('dis')
        # Use the models held by this updater, not notebook-level globals.
        gen, dis = self.gen, self.dis

        batch = self.get_iterator('main').next()
        x_real = Variable(self.converter(batch, self.device))
        # Sample z on the same device as the data instead of a global ``xp``.
        xp = cuda.get_array_module(x_real)
        batch_size = len(x_real)

        y_real = dis(x_real)

        z = xp.random.uniform(-1, 1, (batch_size, gen.z_dim))
        z = z.astype(dtype=xp.float32)
        x_fake = gen(z)
        y_fake = dis(x_fake)

        dis_optimizer.update(self.dis_loss, y_fake, y_real)
        gen_optimizer.update(self.gen_loss, y_fake)
        

## 表示用の関数（プレビュー画像を保存するTrainer拡張）

In [6]:
def image(gen, rows, cols, dst):
    """Build a trainer extension that saves a grid of generated samples.

    Args:
        gen: Generator exposing ``z_dim`` and returning ``(N, 3, H, W)``
            images with values in [0, 1].
        rows: Number of grid rows.
        cols: Number of grid columns.
        dst: Output directory; images go to ``{dst}/preview``.

    Returns:
        An extension that, on each call, samples ``rows * cols`` latent vectors
        uniformly from [-1, 1). It writes the tiled image to
        ``{dst}/preview/image{iteration:0>8}.png``.
    """
    @training.make_extension()
    def make_image(trainer):
        n_images = rows * cols
        # Sample on the generator's own device rather than a global ``xp``.
        xp = gen.xp
        z = xp.random.uniform(-1, 1, (n_images, gen.z_dim))
        z = z.astype(dtype=xp.float32)
        # Test-mode forward pass; no autograd graph is needed for a preview.
        with chainer.using_config('train', False), chainer.no_backprop_mode():
            x = gen(z)
        x = cuda.to_cpu(x.data)

        x = np.asarray(np.clip(x * 255, 0.0, 255.0), dtype=np.uint8)
        _, _, H, W = x.shape
        # (rows*cols, 3, H, W) -> (rows, H, cols, W, 3) -> (rows*H, cols*W, 3)
        x = x.reshape((rows, cols, 3, H, W))
        x = x.transpose(0, 3, 1, 4, 2)
        x = x.reshape((rows * H, cols * W, 3))

        preview_dir = '{}/preview'.format(dst)
        preview_path = preview_dir +\
            '/image{:0>8}.png'.format(trainer.updater.iteration)
        if not os.path.exists(preview_dir):
            os.makedirs(preview_dir)
        Image.fromarray(x).save(preview_path)
    return make_image

## 実行する

In [None]:
gen = Generator(100)
dis = Discriminator()
gen.to_gpu()
dis.to_gpu()

gen_opt = optimizers.Adam()
gen_opt.setup(gen)
dis_opt = optimizers.Adam()
dis_opt.setup(dis)

# Load the CIFAR-10 binary training batches.
# Each record is 3073 unsigned bytes: one label byte, then 3072 pixel bytes
# stored as the red, green and blue 32x32 planes, each in row-major order.
# So the pixel part reshapes directly to (3, 32, 32) = (C, H, W).
dataset_root = os.environ.get("DATA_SET_DIR")
if dataset_root is None:
    raise RuntimeError("Set the DATA_SET_DIR environment variable to the dataset root")
dataset_dir = os.path.join(dataset_root, 'cifar10/cifar-10-batches-bin')
n_train = 10000  # records taken from each of the 5 training batches
record_size = 1 + 3 * 32 * 32

batches = []
for batch_idx in range(1, 6):
    batch_path = os.path.join(dataset_dir, 'data_batch_{}.bin'.format(batch_idx))
    # Read as uint8: pixel values are 0..255, so a signed read would corrupt them.
    records = np.fromfile(batch_path, dtype=np.uint8).reshape(-1, record_size)
    # Drop the label column and keep the native (N, C, H, W) layout.
    batches.append(records[:n_train, 1:].reshape(-1, 3, 32, 32))

X = np.concatenate(batches).astype(np.float32) / 255.

train = iterators.SerialIterator(X, 100)

updater = GANUpdater(models=(gen, dis),
                     iterator=train,
                     optimizer={'gen': gen_opt, 'dis': dis_opt},
                     device=0)

trainer = training.Trainer(updater, (1000, 'epoch'), out='result')

snapshot_interval = (10, 'epoch')
display_interval = (1000, 'iteration')
trainer.extend(extensions.LogReport(trigger=display_interval))
trainer.extend(extensions.PrintReport([
    'epoch', 'iteration', 'gen/loss', 'dis/loss', 'elapsed_time'
]), trigger=display_interval)

trainer.extend(image(gen, 10, 10, 'result'), trigger=snapshot_interval)

trainer.run()

epoch       iteration   gen/loss    dis/loss    elapsed_time
[J2           1000        13.3756     0.267901    20.786        
[J4           2000        12.7655     0.149555    41.255        
[J6           3000        14.9482     0.14894     61.7251       
[J8           4000        15.342      0.11814     82.1239       
[J10          5000        16.2367     0.119011    102.618       
[J12          6000        16.6022     0.174672    123.764       
[J14          7000        16.0269     0.195657    144.691       
[J16          8000        16.5522     0.145074    165.941       
[J18          9000        14.9671     0.1396      186.859       
[J20          10000       18.3422     0.15351     207.442       
[J22          11000       16.2164     0.129818    228.309       
[J24          12000       15.551      0.13664     249.102       
[J26          13000       18.3611     0.188278    270.192       
[J28          14000       17.0132     0.168595    290.929       
[J30          

In [1]:
# For display: preview one training image from X.
# NOTE(review): X must come from the data-loading cell in the same kernel; the
# recorded output below shows a NameError because this ran as In [1].
# NOTE(review): the reshape to (32, 32, 3) plus ROTATE_270 appears tailored to
# the loader's reshape(32, 32, 3, order='F') layout. If X is stored as
# (N, 3, 32, 32) in CHW order, use X[0].transpose(1, 2, 0) and drop the rotation.
deff = X[0] * 255
a = Image.fromarray(np.uint8(deff.reshape(32,32,3)))
a = a.transpose(Image.ROTATE_270)
a.show()

NameError: name 'X' is not defined