# Plan: train this for roughly 300 epochs.

In [3]:
# coding: UTF-8
# This configuration worked! (100 epochs turned out to be enough.)

import argparse
import os

import numpy as np
#from PIL import Image
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import initializers
from chainer import training
from chainer import Variable
from chainer import cuda
from chainer.dataset import iterator as iterator_module
from chainer.training import extensions
from chainer.dataset import convert
from chainer import Variable
#from evaluate import out_generated_image
from chainer import variable
from data import GmmDataset
from utils import out_generated, calcEMD

# np = cuda.cupy
#random_state = np.random.RandomState(123)


# Updater

class WGANUpdater(training.StandardUpdater):
    """Updater that alternates critic and generator steps for a WGAN-GP setup.

    Every call to ``update_core`` performs one critic update; the generator is
    updated only on iterations that are a multiple of ``n_c``.

    Args:
        iterator: A single iterator (registered under the name ``'main'``) or
            a dict of iterators.
        generator: Generator model; must provide ``make_hidden(batchsize)``.
        critic: Critic model; must expose ``inter.W``.
        n_c: The generator is updated once every ``n_c`` iterations.
        opt_g: Optimizer of the generator.
        opt_c: Optimizer of the critic.
        lam: Coefficient of the gradient-penalty term.
        lam2: Coefficient of the ``|sum(critic.inter.W) - 1|`` penalty.
        n_hidden: Stored on the instance; not read by this class.
        device: Stored on the instance; not read by this class (all arrays
            here are built with NumPy).
    """

    def __init__(self, iterator, generator, critic,
                 n_c, opt_g, opt_c, lam, lam2, n_hidden, device):
        # The parent constructor is intentionally not invoked here; the
        # attributes the updater needs are assigned by hand instead.
        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main': iterator}
        self._iterators = iterator
        self._optimizers = {'generator': opt_g, 'critic': opt_c}
        self.converter = convert.concat_examples
        self.device = device
        self.iteration = 0

        self.generator = generator
        self.critic = critic
        self.n_c = n_c
        self.lam = lam
        self.lam2 = lam2
        self.n_hidden = n_hidden

    def _gradient_penalty(self, real, fake, n_samples):
        """Return ``lam * MSE(||d critic / d x||, 1)`` at random interpolates.

        The interpolation points are detached from the graph (a new
        ``Variable`` is built from the raw data) before the critic is applied.
        """
        mix = np.random.uniform(0, 1, (n_samples, 1)).astype("f")
        blended = mix * real + (1.0 - mix) * fake
        probe = Variable(blended.data)
        score = self.critic(probe)
        grad = chainer.grad([score], [probe], enable_double_backprop=True)[0]
        grad_norm = F.sqrt(F.sum(F.square(grad), axis=1))
        target = np.ones_like(grad_norm.data)
        return self.lam * F.mean_squared_error(grad_norm, target)

    def update_core(self):
        """Run one training iteration (critic step, optional generator step)."""
        minibatch = self.get_iterator('main').next()
        n_samples = len(minibatch)

        # Step 1: generated samples and their critic scores.
        latent = Variable(np.asarray(self.generator.make_hidden(n_samples)))
        fake = self.generator(latent)
        score_fake = self.critic(fake)

        # Step 2: critic scores of the real samples.
        real = Variable(np.array(minibatch))
        score_real = self.critic(real)

        # Step 3: critic objective = Wasserstein term + gradient penalty
        #         + penalty keeping the sum of the critic's mixing weights at 1.
        penalty_gp = self._gradient_penalty(real, fake, n_samples)
        wasserstein = F.sum(-score_real) / n_samples
        wasserstein += F.sum(score_fake) / n_samples
        penalty_w = self.lam2 * F.absolute(F.sum(self.critic.inter.W) - 1)
        critic_objective = wasserstein + penalty_gp + penalty_w

        # Step 4: critic update.
        self.critic.cleargrads()
        critic_objective.backward()
        self._optimizers['critic'].update()

        # Step 5: generator update, once every ``n_c`` iterations.  It reuses
        # the critic scores computed before the critic update above.
        if self.iteration % self.n_c == 0:
            generator_objective = F.sum(-score_fake) / n_samples
            self.generator.cleargrads()
            generator_objective.backward()
            self._optimizers['generator'].update()
            chainer.reporter.report({'loss/generator': generator_objective})

        # An earlier, disabled variant additionally added
        # ``lam2 * |sum(generator.inter.W) - 1|`` to the generator loss and
        # used a warm-up schedule: one generator step every 100 iterations
        # before iteration 2500, then one every ``n_c`` iterations.

        # Step 6: report the Wasserstein part of the critic loss only.
        chainer.reporter.report({'loss/critic': wasserstein})


def main():
    """Build the models, optimizers and trainer, then run WGAN training.

    Hyperparameters are declared through ``argparse``. The parser is invoked
    with an empty argument list, so the declared defaults are always used
    (this code is run from a notebook cell, not from a shell).
    """
    parser = argparse.ArgumentParser(description='WGAN')
    parser.add_argument('--batchsize', '-b', type=int, default=64,
                        help='Number of images in each mini-batch')
    parser.add_argument('--datasize', '-d', type=int, default=10000,
                        help='Number of samples')
    # The two penalty coefficients are real-valued; they were declared as
    # ``int``, which rejects inputs such as ``--lam 0.5``.
    parser.add_argument('--lam', '-l', type=float, default=0.5,
                        help='Hyperparameter of gp')
    parser.add_argument('--lam2', '-l2', type=float, default=1.0,
                        help='Coefficient of the penalty on the sum of the '
                             'critic intersection weights')
    parser.add_argument('--n_hidden', type=int, default=100,
                        help='latent variable')
    parser.add_argument('--seed', '-s', type=int, default=111,
                        help='seed number')
    parser.add_argument('--epoch', '-e', type=int, default=300,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=-1,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    parser.add_argument('--setting', '-set', type=int, default=10,
                        help='Experiment setting number used in output '
                             'file names')
    # Deliberately ignore sys.argv: only the defaults above are used.
    args = parser.parse_args(args=[])

    # parameter set
    ## Optimizers
    alpha = 0.002
    beta1 = 0.9
    beta2 = 0.999

    ## Network
    dim = 2
    num_nets = 1
    wscale = 0.02

    # Networks
    generator = Generator(dim, num_nets, args.n_hidden, wscale)
    critic = Critic(num_nets, wscale)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        generator.to_gpu()
        critic.to_gpu()

    # Optimizer set: Adam with weight decay plus gradient clipping for both
    # networks.
    opt_g = chainer.optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2,
                                    weight_decay_rate=0.0001)
    #opt_g = chainer.optimizers.RMSprop(lr=0.00005)
    opt_g.setup(generator)
    #opt_g.add_hook(chainer.optimizer.WeightDecay(0.00001), 'hook_dec')
    opt_g.add_hook(chainer.optimizer.GradientClipping(1))

    opt_c = chainer.optimizers.Adam(alpha=alpha, beta1=beta1, beta2=beta2,
                                    weight_decay_rate=0.0001)
    #opt_c = chainer.optimizers.RMSprop(lr=0.00005)
    opt_c.setup(critic)
    #opt_c.add_hook(chainer.optimizer.WeightDecay(0.00001), 'hook_dec')
    opt_c.add_hook(chainer.optimizer.GradientClipping(1))

    # Dataset
    # Disabled alternative (MNIST):
    # train, test = chainer.datasets.get_mnist(withlabel=False, ndim=3, scale=255.)
    # train_iter = chainer.iterators.SerialIterator(train, args.batchsize)

    dataset = GmmDataset(args.datasize, args.seed, std=0.02, scale=2)
    train_iter = chainer.iterators.SerialIterator(dataset, args.batchsize)

    # Trainer (the literal 5 is n_c: one generator step per 5 iterations)
    updater = WGANUpdater(train_iter, generator, critic, 5, opt_g, opt_c,
                          args.lam, args.lam2, args.n_hidden, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Extensions
    # trainer.extend(extensions.dump_graph('wasserstein distance'))
    # Disabled snapshot extensions:
    # epoch_interval = (50, 'epoch')
    # trainer.extend(extensions.snapshot(filename='2dim2snap_epoch_{.updater.epoch}.npz'), trigger=epoch_interval)
    # trainer.extend(extensions.snapshot_object(generator, '2dim2gensnap_epoch_{.updater.epoch}.npz'), trigger=epoch_interval)
    # trainer.extend(extensions.snapshot_object(critic, '2dim2dissnap_epoch_{.updater.epoch}.npz'), trigger=epoch_interval)

    trainer.extend(extensions.LogReport())

    trainer.extend(
        extensions.PlotReport(['loss/critic'],
                              'epoch', file_name='criticloss_2dim%d_3.png' % args.setting))

    trainer.extend(
        extensions.PlotReport(
            ['loss/generator'], 'epoch', file_name='genloss_2dim%d_3.png' % args.setting))

    trainer.extend(extensions.PrintReport(
        ['epoch', 'loss/critic', 'loss/generator', 'elapsed_time']))

    # Directory handed to out_generated (machine-specific path).
    out = "/Users/keiikegami/Dropbox/research/Imaizumi-Sensei/paper_result/2dim/result%d_3" % args.setting

    trainer.extend(calcEMD(generator), trigger=(1, 'epoch'))
    # A second LogReport writing to a setting-specific log file.
    trainer.extend(extensions.LogReport(log_name="log_2dim%d_3" % args.setting))
    trainer.extend(out_generated(generator, args.seed + 10, out, radius=2), trigger=(10, 'epoch'))
    #trainer.extend(extensions.ProgressBar())

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Run
    trainer.run()


# Original (custom) link class
class Intersection2(chainer.Link):
    """Link that mixes the outputs of ``numnet`` sub-networks.

    It owns one parameter ``W`` of shape ``(numnet, 1)`` initialised to ones.
    On call, a weight matrix derived from ``W`` is passed through ReLU and
    left-multiplied onto the input.

    Args:
        outdim: Output dimension of each sub-network.
        numnet: Number of sub-networks being combined.
    """

    def __init__(self, outdim, numnet):
        super(Intersection2, self).__init__()
        self.outdim = outdim
        self.numnet = numnet
        with self.init_scope():
            self.W = variable.Parameter(chainer.initializers.One())
            self.W.initialize((self.numnet, 1))

    def __call__(self, x):
        """Return ``relu(weight) @ x``.

        For ``outdim == 1`` the weight is ``W.T`` (shape ``(1, numnet)``);
        otherwise it is the block matrix produced by ``make_weight``.
        """
        if self.outdim == 1:
            raw_weight = self.W.T
        else:
            raw_weight = self.make_weight(self.W)
        return F.matmul(F.relu(raw_weight), x)

    def make_weight(self, array):
        """Build the ``(outdim, outdim * numnet)`` block matrix ``[w_0*I | w_1*I | ...]``.

        ``array`` is indexed as ``array[i, 0]`` for ``i < numnet``.

        NOTE(review): the matrix is assembled from raw array values and
        wrapped in a brand-new ``Variable``, so it looks like no gradient can
        reach ``W`` through this path -- confirm that this is intended.
        """
        eye = np.identity(self.outdim, dtype=np.float32)
        blocks = np.zeros((self.outdim, self.outdim * self.numnet), dtype=np.float32)
        for idx in range(self.numnet):
            scale = np.array(array[idx, 0].data, dtype=np.float32)
            start = idx * self.outdim
            blocks[:, start:start + self.outdim] = eye * scale
        return Variable(blocks)

    
class Generator(chainer.Chain):
    """Generator made of ``num_nets`` small MLPs mixed by ``Intersection2``.

    Each sub-network is ``Linear(48) -> ReLU -> Linear(48) -> ReLU ->
    Linear(dim)``. The sub-network outputs are concatenated along axis 1 and
    combined by ``self.inter``.

    Args:
        dim: Dimension of a generated sample.
        num_nets: Number of sub-networks.
        latent: Dimension of the latent vector ``z``.
        wscale: Standard deviation of the normal initializer used for both
            weights and biases.
    """

    def __init__(self, dim=784, num_nets=784, latent=100, wscale=0.02):
        super(Generator, self).__init__()
        self.dim = dim
        self.num_nets = num_nets
        self.wscale = wscale
        self.n_hidden = latent

        with self.init_scope():
            self.inter = Intersection2(self.dim, self.num_nets)

            for net in range(self.num_nets):
                w = chainer.initializers.Normal(self.wscale)
                b = chainer.initializers.Normal(self.wscale)

                setattr(self, "l1_{}".format(net), L.Linear(None, 48, initialW=w, initial_bias=b))
                setattr(self, "l2_{}".format(net), L.Linear(None, 48, initialW=w, initial_bias=b))
                # The last layer must emit ``dim`` values: __call__ reshapes
                # the concatenated outputs to (batch, num_nets * dim).  This
                # was hard-coded to 2, which only worked for dim == 2.
                setattr(self, "l3_{}".format(net), L.Linear(None, self.dim, initialW=w, initial_bias=b))

    def make_hidden(self, batchsize):
        """Sample a ``(batchsize, n_hidden)`` float32 latent batch from U(-1, 1)."""
        return np.random.uniform(-1, 1, (batchsize, self.n_hidden)).astype(np.float32)

    def __call__(self, z, test=False):
        """Map latent batch ``z`` to generated samples.

        ``test`` is accepted for interface compatibility and is not used.
        """
        outputs = []
        for net in range(self.num_nets):
            h = F.relu(getattr(self, 'l1_{}'.format(net))(z))
            h = F.relu(getattr(self, 'l2_{}'.format(net))(h))
            # No ReLU on the final layer: the output must not be clipped to
            # non-negative values.
            outputs.append(getattr(self, 'l3_{}'.format(net))(h))

        X = outputs[0] if len(outputs) == 1 else F.concat(outputs, axis=1)

        batchsize = X.shape[0]
        X = X.reshape(batchsize, self.num_nets * self.dim)
        # Intersection2 works on column-major input, hence the transposes.
        return self.inter(X.T).T


class Critic(chainer.Chain):
    """Critic made of ``num_nets`` small MLPs mixed by ``Intersection2``.

    Each sub-network is ``Linear(48) -> leaky ReLU -> Linear(48) ->
    leaky ReLU -> Linear(1)``. The scalar outputs are concatenated along
    axis 1 and combined by ``self.inter``.

    Args:
        num_nets: Number of sub-networks.
        wscale: Standard deviation of the normal initializer used for both
            weights and biases.
        dim: Dimension of one input sample. Defaults to 2, the value that was
            previously hard-coded.
    """

    def __init__(self, num_nets=784, wscale=0.02, dim=2):
        super(Critic, self).__init__()
        self.num_nets = num_nets
        self.wscale = wscale
        self.dim = dim

        with self.init_scope():
            self.inter = Intersection2(1, self.num_nets)

            for net in range(self.num_nets):
                w = chainer.initializers.Normal(self.wscale)
                b = chainer.initializers.Normal(self.wscale)

                setattr(self, "l1_{}".format(net), L.Linear(None, 48, initialW=w, initial_bias=b))
                setattr(self, "l2_{}".format(net), L.Linear(None, 48, initialW=w, initial_bias=b))
                setattr(self, "l3_{}".format(net), L.Linear(None, 1, initialW=w, initial_bias=b))

    def __call__(self, x, test=False):
        """Score a batch of samples.

        ``test`` is accepted for interface compatibility and is not used.
        The result is whatever ``self.inter`` returns for the transposed
        ``(num_nets, batch)`` matrix of sub-network outputs.
        """
        # Flatten to (batch, dim) without assuming a batch size.  This used
        # to be ``reshape(64, 2)``, which broke for any other batch size.
        x = x.reshape(-1, self.dim)

        scores = []
        for net in range(self.num_nets):
            # Debugging note from the author: when ``h`` turns NaN here,
            # everything downstream becomes NaN; ``x`` itself was checked not
            # to be NaN, the cause being W having become NaN in the previous
            # update.
            h = F.leaky_relu(getattr(self, 'l1_{}'.format(net))(x))
            h = F.leaky_relu(getattr(self, 'l2_{}'.format(net))(h))
            # The sub-network output is used as-is (an earlier version summed
            # over axis 1 here).
            scores.append(getattr(self, 'l3_{}'.format(net))(h))

        Y = scores[0] if len(scores) == 1 else F.concat(scores, axis=1)

        return self.inter(Y.T)


In [4]:
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    main()

epoch       loss/critic  loss/generator  elapsed_time
[J1           -2.71618     1.22755         1.8684        


  check_result(result_code)


[J2           -4.55309     3.50314         20.0969       
[J3           -4.29069     7.0341          36.0247       
[J4           -4.2698      8.59974         52.2403       
[J5           -4.19277     7.58923         68.7109       
[J6           -4.22025     7.99906         83.975        
[J7           -4.16781     7.9999          98.3048       
[J8           -4.09938     7.287           112.874       
[J9           -3.87401     7.31019         127.665       
[J10          -3.69061     6.49659         142.918       
[J11          -3.68058     6.13123         159.545       
[J12          -3.60508     5.83281         174.138       
[J13          -3.72637     4.60248         189.537       
[J14          -3.45697     4.8139          205.147       
[J15          -3.25043     3.89281         220.09        
[J16          -3.16224     2.46785         235.178       
[J17          -2.58603     4.02176         250.306       
[J18          -2.27027     4.16576         268.107     

[J141         -0.28331     6.41442         2245.82       
[J142         -0.26303     6.34842         2264.59       
[J143         -0.290231    5.92525         2280.05       
[J144         -0.295708    6.12845         2295.68       
[J145         -0.238486    5.90295         2310.91       
[J146         -0.247227    6.11712         2326.62       
[J147         -0.231495    6.0464          2343.14       
[J148         -0.252192    6.25843         2359.35       
[J149         -0.254691    6.36387         2374.58       
[J150         -0.230079    6.46681         2390.89       
[J151         -0.286694    6.20153         2410.81       
[J152         -0.263402    6.08131         2428.75       
[J153         -0.22626     5.69941         2446.29       
[J154         -0.231535    5.77897         2464.01       
[J155         -0.235381    5.78658         2481.98       
[J156         -0.226489    5.86736         2498.81       
[J157         -0.21942     5.84837         2516.63     

[J280         -0.469536    2.86157         4560.54       
[J281         -0.490476    -3.39757        4577.82       
[J282         -0.473256    4.14746         4593.31       
[J283         -0.435378    -1.85615        4609.28       
[J284         -0.507945    -1.99221        4627.03       
[J285         -0.488461    2.83583         4644.61       
[J286         -0.478113    -0.13022        4661.12       
[J287         -0.454452    -4.6187         4677.08       
[J288         -0.441078    1.30001         4692.79       
[J289         -0.469873    3.10436         4709.41       
[J290         -0.497143    -4.06208        4727.24       
[J291         -0.544508    1.74446         4744.13       
[J292         -0.522735    5.81415         4763.04       
[J293         -0.592171    -1.30681        4780.58       
[J294         -0.303443    1.46271         4796.76       
[J295         -0.523459    2.54663         4813.28       
[J296         -0.605917    7.2442          4831.05     