In [3]:
from chainer import Chain
from chainer import Variable, optimizers
import chainer.functions as F
import chainer.links as L

import matplotlib.pyplot as plt

from sklearn.datasets import fetch_mldata
import numpy as np
import pandas as pd

import pickle


In [4]:
class Generator(Chain):
    """Deconvolutional generator mapping a latent batch to single-channel images.

    A linear layer expands each latent vector to a 512 x 3 x 3 feature map,
    which four transposed convolutions upsample to 4, 6, 10 and finally
    28 pixels per side (out = stride * (in - 1) + ksize - 2 * pad).
    Batch normalisation followed by ReLU is applied before every transposed
    convolution; the last layer's output is returned without any activation.
    """

    def __init__(self, z_dim):
        """Create the layers.

        Args:
            z_dim: Size of the latent vector fed to the first linear layer.
                Also stored as ``self.z_dim``.
        """
        super(Generator, self).__init__(
            l1=L.Linear(z_dim, 3 * 3 * 512),

            dc1=L.Deconvolution2D(512, 256, 2, stride=2, pad=1),
            dc2=L.Deconvolution2D(256, 128, 2, stride=2, pad=1),
            dc3=L.Deconvolution2D(128, 64, 2, stride=2, pad=1),
            dc4=L.Deconvolution2D(64, 1, 3, stride=3, pad=1),

            bn1=L.BatchNormalization(512),
            bn2=L.BatchNormalization(256),
            bn3=L.BatchNormalization(128),
            bn4=L.BatchNormalization(64),
        )
        self.z_dim = z_dim

    def __call__(self, z):
        """Generate a batch of images.

        Args:
            z: Latent batch with ``z_dim`` columns; either a raw array or a
                Variable.

        Returns:
            The output of the last transposed convolution, one channel per
            sample, with no final nonlinearity applied.
        """
        h = self.l1(z)
        # Take the batch size from the layer output rather than from `z`:
        # `h` is always a Variable, whereas `z` may be a plain ndarray, whose
        # `.data` attribute is a raw buffer rather than an array.
        h = F.reshape(h, (h.data.shape[0], 512, 3, 3))

        h1 = F.relu(self.bn1(h))
        h2 = F.relu(self.bn2(self.dc1(h1)))
        h3 = F.relu(self.bn3(self.dc2(h2)))
        h4 = F.relu(self.bn4(self.dc3(h3)))
        x = self.dc4(h4)
        return x


In [5]:
class Discriminator(Chain):
    """Convolutional two-class classifier over single-channel images.

    Four strided convolutions reduce the input to a 512-channel feature map
    that the final linear layer expects to be 3 x 3 (a 28 x 28 input shrinks
    to 10, 6, 4 and then 3 pixels per side). The linear layer emits two
    unnormalised scores per sample.
    """

    def __init__(self):
        """Create the convolution, batch-normalisation and output layers."""
        # Every convolution after the first shares this geometry.
        halve = dict(stride=2, pad=1)
        super(Discriminator, self).__init__(
            c1=L.Convolution2D(1, 64, 3, stride=3, pad=1),
            c2=L.Convolution2D(64, 128, 2, **halve),
            c3=L.Convolution2D(128, 256, 2, **halve),
            c4=L.Convolution2D(256, 512, 2, **halve),

            l1=L.Linear(3 * 3 * 512, 2),

            bn1=L.BatchNormalization(128),
            bn2=L.BatchNormalization(256),
            bn3=L.BatchNormalization(512),
        )

    def __call__(self, x):
        """Score a batch of images.

        Args:
            x: Image batch with a single channel.

        Returns:
            Two scores per sample, taken straight from the linear layer.
        """
        # The first convolution is not batch-normalised; the other three are.
        h = F.relu(self.c1(x))
        stages = (
            (self.c2, self.bn1),
            (self.c3, self.bn2),
            (self.c4, self.bn3),
        )
        for conv, norm in stages:
            h = F.relu(norm(conv(h)))
        return self.l1(h)


In [6]:
class Trainer(object):
    """Alternating training loop for a generator/discriminator pair.

    The discriminator is trained as a two-class classifier in which label 0
    marks real samples and label 1 marks generated ones. The generator is
    trained so that the discriminator assigns label 0 to its output.
    """

    def __init__(self, gen, dis):
        """Keep references to both networks.

        Args:
            gen: Generator; must expose ``z_dim`` and be callable on a latent
                batch.
            dis: Discriminator; callable on an image batch, returning two
                scores per sample.
        """
        self.gen = gen
        self.dis = dis
        self.z_dim = gen.z_dim

    def fit(self, X, epochs=100, batch_size=1000, plotting=True):
        """Train both networks with Adam and record the per-epoch losses.

        Each epoch shuffles the sample indices and processes
        ``n_train // batch_size`` full mini-batches; leftover samples that do
        not fill a batch are skipped for that epoch.

        Args:
            X: Training images indexed along the first axis (presumably
                float32 with shape (N, 1, 28, 28) -- confirm against caller).
            epochs: Number of passes over the data.
            batch_size: Number of real (and generated) samples per update.
            plotting: When true, a grid of generated samples is written to
                ``epoch-<n>.png`` after every epoch.

        Side effects:
            Stores the arguments on ``self`` and fills ``self.loss`` with one
            ``[generator_loss_sum, discriminator_loss_sum]`` pair per epoch.
        """
        self.X = X
        self.epochs = epochs
        self.batch_size = batch_size
        self.plotting = plotting

        n_train = X.shape[0]
        n_batches = n_train // batch_size

        o_gen = optimizers.Adam()
        o_dis = optimizers.Adam()

        o_gen.setup(self.gen)
        o_dis.setup(self.dis)

        self.loss = []
        for epoch in range(epochs):
            perm = np.random.permutation(n_train)
            sum_loss_of_dis = np.float32(0)
            sum_loss_of_gen = np.float32(0)

            for batch in range(n_batches):
                z = np.random.uniform(-1, 1, (batch_size, self.z_dim))
                z = z.astype(dtype=np.float32)

                # Generate a batch of samples.
                x = self.gen(z)
                # Classify the generated samples.
                y1 = self.dis(x)

                # Generator wants its samples labelled 0 (real); the
                # discriminator wants the same samples labelled 1 (fake).
                loss_gen = F.softmax_cross_entropy(
                    y1, Variable(np.zeros(batch_size, dtype=np.int32)))
                loss_dis = F.softmax_cross_entropy(
                    y1, Variable(np.ones(batch_size, dtype=np.int32)))

                idx = perm[batch * batch_size:(batch + 1) * batch_size]
                x_data = self.X[idx]

                # Classify the real samples; they should be labelled 0.
                y2 = self.dis(x_data)

                loss_dis += F.softmax_cross_entropy(
                    y2, Variable(np.zeros(batch_size, dtype=np.int32)))

                self.dis.cleargrads()
                loss_dis.backward()
                o_dis.update()

                # The discriminator backward pass above also wrote gradients
                # into the generator; clear them before the generator step.
                self.gen.cleargrads()
                loss_gen.backward()
                o_gen.update()

                sum_loss_of_dis += loss_dis.data
                sum_loss_of_gen += loss_gen.data

            print('epoch-{epoch}\tloss\tdiscreminator-{sum_loss_of_dis:.3f}\tgenerator-{sum_loss_of_gen:.3f}'.format(
                epoch=epoch,
                sum_loss_of_dis=sum_loss_of_dis,
                sum_loss_of_gen=sum_loss_of_gen,
            ))

            self.loss.append([sum_loss_of_gen, sum_loss_of_dis])

            if plotting:
                self._plot_samples(epoch)
        print(self.loss)

    def _plot_samples(self, epoch, n_row=3):
        """Save an ``n_row`` x ``n_row`` grid of generated samples for one epoch.

        Each tile is titled with the discriminator's softmax probability for
        label 0 (the "real" class). The figure is written to
        ``epoch-<epoch>.png`` and then closed.
        """
        plt.figure(figsize=(12, 12))
        n_samples = n_row ** 2
        # The latent width follows the generator instead of a fixed constant.
        z = Variable(
            np.random.uniform(-1, 1, (n_samples, self.z_dim)).astype(np.float32))
        x = self.gen(z)
        y = F.softmax(self.dis(x)).data
        # Generated samples are displayed as 28 x 28 images.
        images = x.data.reshape(-1, 28, 28)
        for k, image in enumerate(images):
            plt.subplot(n_row, n_row, k + 1)
            plt.imshow(image, interpolation="nearest", cmap="gray")
            plt.axis('off')
            plt.title('True Prob {0:.3f}'.format(y[k][0]))
        plt.tight_layout()
        plt.savefig('epoch-{epoch}.png'.format(epoch=epoch), dpi=100)
        # Close the figure so memory does not grow with the epoch count.
        plt.close('all')
        

In [None]:
# Build the two networks; the generator consumes 100-dimensional latent vectors.
gen = Generator(100)
dis = Discriminator()

# Load MNIST, convert to float32, divide by 255 (presumably mapping 0-255
# pixel intensities into [0, 1]) and reshape to (N, 1, 28, 28).
# NOTE(review): fetch_mldata appears to be unavailable in recent scikit-learn
# releases -- confirm the installed version still provides it.
data = fetch_mldata('MNIST original')
n_train = data['data'].shape[0]
X = np.asarray(data['data'], dtype=np.float32) / 255.
X = X.reshape(n_train, 1, 28, 28)

print("Start!")
trainer = Trainer(gen, dis)
trainer.fit(X, epochs=1000, batch_size=1000)

# Persist the per-epoch [generator, discriminator] loss sums.
df_loss = pd.DataFrame(trainer.loss)
df_loss.to_csv('loss.csv')

# Move both models to the CPU before pickling them.
gen.to_cpu()
dis.to_cpu()

for path, model in (('generator.model', gen), ('discriminator.model', dis)):
    with open(path, 'wb') as w:
        pickle.dump(model, w)

Start!
epoch-0	loss	discreminator-6.483	generator-394.438
epoch-1	loss	discreminator-0.275	generator-482.044
epoch-2	loss	discreminator-0.126	generator-537.606
epoch-3	loss	discreminator-0.050	generator-580.688
epoch-4	loss	discreminator-0.038	generator-588.821
epoch-5	loss	discreminator-0.026	generator-609.267
