In [3]:

# coding: UTF-8
# これは成功！（100 epochで十分） -- This setup works! (100 epochs are enough.)

import argparse
import os

import numpy as np
#from PIL import Image
import chainer
import chainer.functions as F
import chainer.links as L
from chainer import initializers
from chainer import training
from chainer import Variable
from chainer import cuda
from chainer.dataset import iterator as iterator_module
from chainer.training import extensions
from chainer.dataset import convert
from chainer import Variable
#from evaluate import out_generated_image
from chainer import variable
from data import GmmDataset
from utils import out_generated, calcEMD
import itertools as itr


# Array module used by helpers in this notebook that read this global.
# It is hard-wired to CuPy, so any code using it needs a CUDA device
# regardless of the --gpu setting.
xp = cuda.cupy
#random_state = np.random.RandomState(123)


# Updater

class WGANUpdater(training.StandardUpdater):
    """WGAN-GP updater.

    The critic is updated on every iteration; the generator is updated on
    iterations where ``iteration % n_c == 0``.

    Args:
        iterator: Training iterator, or a dict of iterators with a ``'main'`` key.
        generator: Generator chain; must provide ``make_hidden(batchsize)``.
        critic: Critic chain; must expose an ``inter.W`` parameter, which is
            used by the sum-to-one penalty.
        n_c: Generator update interval, in iterations.
        opt_g: Optimizer already set up on ``generator``.
        opt_c: Optimizer already set up on ``critic``.
        lam: Weight of the gradient penalty.
        lam2: Weight of the ``|sum(critic.inter.W) - 1|`` penalty.
        n_hidden: Latent dimensionality (stored only).
        device: Device id (stored only).
    """

    def __init__(self, iterator, generator, critic,
                 n_c, opt_g, opt_c, lam, lam2, n_hidden, device):
        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main': iterator}
        self._iterators = iterator
        self.generator = generator
        self.critic = critic
        self.n_c = n_c
        self._optimizers = {'generator': opt_g, 'critic': opt_c}
        self.lam = lam
        self.lam2 = lam2
        self.device = device
        self.converter = convert.concat_examples
        self.iteration = 0
        self.n_hidden = n_hidden

    def update_core(self):
        """Run one critic step and, every ``n_c`` iterations, one generator step."""
        # Follow the models' device instead of assuming CuPy.
        xp = self.generator.xp
        batch = self.get_iterator('main').next()
        batchsize = len(batch)

        # Step1 Generate fake samples. The critic step must not backpropagate
        # into the generator, so the critic scores a detached copy.
        z = Variable(xp.asarray(self.generator.make_hidden(batchsize)))
        x_gen = self.generator(z)
        y_gen = self.critic(Variable(x_gen.data))

        # Step2 Score real samples.
        x_real = Variable(xp.array(batch))
        y_real = self.critic(x_real)

        # Step3 Critic loss: Wasserstein estimate + gradient penalty on random
        # interpolates between real and fake samples.
        eps = xp.random.uniform(0, 1, (batchsize, 1)).astype("f")
        x_mid_v = Variable(eps * x_real.data + (1.0 - eps) * x_gen.data)
        y_mid = self.critic(x_mid_v)
        dydx = chainer.grad([y_mid], [x_mid_v], enable_double_backprop=True)[0]
        dydx = F.sqrt(F.sum(F.square(dydx), axis=1))
        loss_gp = self.lam * F.mean_squared_error(dydx, xp.ones_like(dydx.data))
        loss_cri = F.sum(-y_real) / batchsize
        loss_cri += F.sum(y_gen) / batchsize

        # Keep the critic's mixing weights summing to (roughly) one.
        loss_sp = self.lam2 * F.absolute(F.sum(self.critic.inter.W) - 1)
        loss_all = loss_cri + loss_gp + loss_sp

        # Step4 Update critic
        self.critic.cleargrads()
        loss_all.backward()
        self._optimizers['critic'].update()

        # Step5 Update generator
        if self.iteration % self.n_c == 0:
            # Re-score the fakes with the freshly updated critic. The critic
            # optimizer has just modified the critic parameter arrays in
            # place. Backpropagating through the old graph would therefore
            # mix pre- and post-update values.
            loss_gen = F.sum(-self.critic(x_gen)) / batchsize
            self.generator.cleargrads()
            loss_gen.backward()
            self._optimizers['generator'].update()
            chainer.reporter.report({'loss/generator': loss_gen})

        # Step6 Report
        chainer.reporter.report({'loss/critic': loss_cri})

# Original (custom) Link class: mixes the outputs of several sub-networks
class Intersection2(chainer.Link):
    """Non-negative learnable mixing of ``numnet`` stacked sub-network outputs.

    Holds one parameter ``W`` of shape ``(numnet, 1)``, initialized to ones.
    The input ``x`` is expected to have ``outdim * numnet`` rows, where rows
    ``i*outdim:(i+1)*outdim`` belong to sub-network ``i``. The call returns
    ``sum_i relu(W[i]) * x_i``, which has ``outdim`` rows.

    Args:
        outdim: Output dimension of each sub-network.
        numnet: Number of sub-networks being mixed.
    """

    def __init__(self, outdim, numnet):
        super(Intersection2, self).__init__()
        self.outdim = outdim
        self.numnet = numnet
        with self.init_scope():
            W = chainer.initializers.One()
            self.W = variable.Parameter(W)
            self.W.initialize((self.numnet, 1))

    def __call__(self, x):
        """Return ``relu(weight) @ x`` for the block weight built from ``W``."""
        if self.outdim == 1:
            weight = F.relu(self.W.T)
        else:
            weight = F.relu(self.make_weight(self.W))

        return F.matmul(weight, x)

    def make_weight(self, array):
        """Build the ``(outdim, outdim * numnet)`` block matrix ``[W_0*I, ..., W_{n-1}*I]``.

        The matrix is built only from differentiable ops, so gradients flow back
        into ``array``. Copying ``array[i, 0].data`` into a fresh array would
        cut the graph and freeze the mixing weights.
        """
        xp = self.xp
        eye = xp.identity(self.outdim, dtype=array.dtype)
        # mask[a, i, b] = eye[a, b] for every sub-network i.
        mask = xp.tile(eye[:, None, :], (1, self.numnet, 1))
        # scale[a, i, b] = array[i, 0]
        scale = F.broadcast_to(F.reshape(array, (1, self.numnet, 1)),
                               (self.outdim, self.numnet, self.outdim))
        # Flattening (i, b) gives column i*outdim + b, i.e. block i = W_i * I.
        return F.reshape(scale * mask, (self.outdim, self.numnet * self.outdim))

    
class Generator(chainer.Chain):
    """Generator built from ``num_nets`` parallel MLPs mixed by an Intersection2 layer.

    Each sub-network maps the latent vector through two ReLU hidden layers of
    width ``netsize`` to a ``dim``-dimensional output. ``self.inter`` then
    combines the sub-network outputs.

    Args:
        dim: Dimensionality of a generated sample.
        num_nets: Number of parallel sub-networks.
        latent: Dimensionality of the latent vector (stored as ``n_hidden``).
        wscale: Standard deviation of the normal initializer for weights and biases.
        netsize: Width of the hidden layers.
    """

    def __init__(self, dim=784, num_nets=784, latent=100, wscale=0.02, netsize = 48):
        super(Generator, self).__init__()
        self.dim = dim
        self.num_nets = num_nets
        self.wscale = wscale
        self.n_hidden = latent
        self.netsize = netsize

        with self.init_scope():
            self.inter = Intersection2(self.dim, self.num_nets)

            for net in range(self.num_nets):
                w = chainer.initializers.Normal(self.wscale)
                b = chainer.initializers.Normal(self.wscale)

                setattr(self, "l1_{}".format(net), L.Linear(None, self.netsize, initialW=w, initial_bias=b))
                setattr(self, "l2_{}".format(net), L.Linear(None, self.netsize, initialW=w, initial_bias=b))
                # Output width must equal ``dim`` so each sub-network fills
                # exactly one dim x dim block of the Intersection2 weight.
                setattr(self, "l3_{}".format(net), L.Linear(None, self.dim, initialW=w, initial_bias=b))

    def make_hidden(self, batchsize):
        """Sample ``(batchsize, n_hidden)`` float32 latents uniformly from [-1, 1) on this link's device."""
        xp = self.xp
        return xp.random.uniform(-1, 1, (batchsize, self.n_hidden)).astype(xp.float32)

    def __call__(self, z, test=False):
        """Map latents ``z`` to samples of shape ``(batch, dim)``. ``test`` is unused."""
        outputs = []
        for net in range(self.num_nets):
            h = F.relu(getattr(self, 'l1_{}'.format(net))(z))
            h = F.relu(getattr(self, 'l2_{}'.format(net))(h))
            outputs.append(getattr(self, 'l3_{}'.format(net))(h))

        # One concat instead of repeated pairwise concats: sub-network i
        # occupies columns i*dim:(i+1)*dim.
        X = F.concat(outputs, axis=1)
        batchsize = X.shape[0]
        X = X.reshape(batchsize, self.num_nets * self.dim)
        x = self.inter(X.T).T
        return x


class Critic(chainer.Chain):
    """Critic built from ``num_nets`` parallel MLPs whose scores are mixed by Intersection2.

    Each sub-network maps a sample through two leaky-ReLU hidden layers of
    width ``netsize`` to a scalar score. ``self.inter`` combines the
    ``num_nets`` scores with weights ``relu(inter.W)``.

    Args:
        num_nets: Number of parallel sub-networks.
        wscale: Standard deviation of the normal initializer for weights and biases.
        netsize: Width of the hidden layers.
    """

    def __init__(self, num_nets=784, wscale=0.02, netsize=48):
        super(Critic, self).__init__()
        self.num_nets = num_nets
        self.wscale = wscale
        self.netsize = netsize

        with self.init_scope():
            self.inter = Intersection2(1, self.num_nets)

            for net in range(self.num_nets):
                w = chainer.initializers.Normal(self.wscale)
                b = chainer.initializers.Normal(self.wscale)

                setattr(self, "l1_{}".format(net), L.Linear(None, self.netsize, initialW=w, initial_bias=b))
                setattr(self, "l2_{}".format(net), L.Linear(None, self.netsize, initialW=w, initial_bias=b))
                setattr(self, "l3_{}".format(net), L.Linear(None, 1, initialW=w, initial_bias=b))

    def __call__(self, x, test=False):
        """Score a batch ``x``. ``test`` is unused.

        Returns:
            Scores with shape ``(1, batch)``; note the transposed layout.
        """
        scores = []
        for net in range(self.num_nets):
            h = F.leaky_relu(getattr(self, 'l1_{}'.format(net))(x))
            h = F.leaky_relu(getattr(self, 'l2_{}'.format(net))(h))
            scores.append(getattr(self, 'l3_{}'.format(net))(h))

        # One concat instead of repeated pairwise concats -> (batch, num_nets).
        Y = F.concat(scores, axis=1)

        y = self.inter(Y.T)

        return y



def main(numnet, netsize, beta1, alpha, weightdecay, combination):
    """Train one WGAN-GP model on the GMM toy dataset and evaluate it with calcEMD.

    The function arguments become argparse defaults. ``parse_args(args=[])``
    deliberately ignores the real command line, so this can be called in a
    loop from a notebook.

    Args:
        numnet: Number of parallel sub-networks in the generator and critic.
        netsize: Hidden-layer width of each sub-network.
        beta1: Adam ``beta1``.
        alpha: Adam learning rate.
        weightdecay: Adam ``weight_decay_rate``.
        combination: Index of this hyper-parameter setting; forwarded to ``calcEMD``.
    """
    parser = argparse.ArgumentParser(description='WGAN')
    parser.add_argument('--batchsize', '-b', type=int, default=64,
                        help='Number of images in each mini-batch')
    parser.add_argument('--datasize', '-d', type=int, default=10000,
                        help='Number of samples')
    # Real-valued hyper-parameters use type=float. With type=int, a value
    # such as "--lam 0.5" on the command line would be rejected.
    parser.add_argument('--lam', '-l', type=float, default=0.5,
                        help='Hyperparameter of gp')
    parser.add_argument('--lam2', '-l2', type=float, default=1.0,
                        help='Weight of the |sum(critic.inter.W) - 1| penalty')
    parser.add_argument('--n_hidden', type=int, default=100,
                        help='latent variable')
    parser.add_argument('--alpha', type=float, default=alpha,
                        help='learning rate')
    parser.add_argument('--beta1', type=float, default=beta1,
                        help='adam parameter')
    parser.add_argument('--weightdecay', type=float, default=weightdecay,
                        help='weight decay rate')
    parser.add_argument('--netsize', type=int, default=netsize,
                        help='number of nodes')
    parser.add_argument('--numnet', type=int, default=numnet,
                        help='number of nets')
    parser.add_argument('--seed', '-s', type=int, default=111,
                        help='seed number')
    parser.add_argument('--epoch', '-e', type=int, default=100,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', type=int, default=0,
                        help='GPU ID (negative value indicates CPU)')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--resume', '-r', default='',
                        help='Resume the training from snapshot')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args(args=[])

    # parameter set
    ## Optimizers
    beta2 = 0.999

    ## Network
    dim = 2
    wscale = 0.02

    # Networks
    generator = Generator(dim, args.numnet, args.n_hidden, wscale, args.netsize)
    critic = Critic(args.numnet, wscale, args.netsize)

    if args.gpu >= 0:
        chainer.cuda.get_device(args.gpu).use()
        generator.to_gpu()
        critic.to_gpu()

    # Optimizer set
    opt_g = chainer.optimizers.Adam(alpha = args.alpha, beta1 = args.beta1, beta2 = beta2, weight_decay_rate=args.weightdecay)
    opt_g.setup(generator)
    opt_g.add_hook(chainer.optimizer.GradientClipping(1))

    opt_c = chainer.optimizers.Adam(alpha = args.alpha, beta1 = args.beta1, beta2 = beta2, weight_decay_rate=args.weightdecay)
    opt_c.setup(critic)
    opt_c.add_hook(chainer.optimizer.GradientClipping(1))

    # Dataset
    dataset = GmmDataset(args.datasize, args.seed, std=0.02, scale=2)
    train_iter = chainer.iterators.SerialIterator(dataset, args.batchsize)

    # Trainer: the generator is updated once every 5 critic iterations.
    updater = WGANUpdater(train_iter, generator, critic, 5, opt_g, opt_c, args.lam, args.lam2, args.n_hidden, device=args.gpu)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

    # Extensions: evaluate once, at the end of training.
    trainer.extend(calcEMD(generator, combination), trigger=(args.epoch, 'epoch'))

    if args.resume:
        chainer.serializers.load_npz(args.resume, trainer)

    # Run
    trainer.run()


In [4]:
# Hyper-parameter grid search. itertools.product varies its first factor
# slowest, so combinations 0-80 (3 * 3 * 3 * 3 = 81 settings) are the
# numnet = 1 runs and 81-161 are the numnet = 2 runs.
numnet = [1, 2]
netsize = [32, 48, 64]
beta1 = [0.0, 0.5, 0.9]
alpha = [0.005, 0.01, 0.05]
weight_decay = [0.0, 0.0001, 0.0005]

combinations = itr.product(numnet, netsize, beta1, alpha, weight_decay)
for i, comb in enumerate(combinations):
    n_net, n_size, b1, lr, wd = comb
    main(n_net, n_size, b1, lr, wd, i)

  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(resul

  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(resul

  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  np.maximum(distances, 0, out=distances)
  return umr_maximum(a, axis, None, out, keepdims, initial)
  check_result(result_code)
  check_result(result_code)
  

  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)
  check_result(result_code)


KeyboardInterrupt: 