In [326]:
import chainer as ch
import chainer.links as L
import chainer.functions as F

import numpy as np
import scipy as sp
import pandas as pd
from matplotlib import pyplot as plt
%matplotlib inline

import os
import PIL

[impl|test]
- activation power: prelu > relu > lrelu
    - generator
        - [x| ] relu
        - [ | ] lrelu
        - [ | ] prelu
    - discriminator
        - [ | ] relu
        - [x| ] lrelu
        - [ ] prelu
TRY WEAKENING DISCRIMIATOR BY REPLACING CONV+POOL W/ ONLY STRIDED CONV
            # does it make sense reducing channels when striding conv? try it both ways
            # guess it makes sense that more infomation is preserved that way
            # but don't you want to weaken the discriminator?
            # REDUCE BEFORE STRIDED CONV - CLOSER ANAGLOGUE TO POOLING
- discriminator strength: (1) > (2) > (3) > (4)
    - [ | ] (1) channel expansion before strided convolution
        32*32*3->32*32*128->16*16*128->16*16*256->
        8*8*256->8*8*512->4*4*512->4*4*512
    - [x| ] (2) channel expansion during strided convolution
        32*32*3->32*32*64->16*16*128->16*16*128->
        8*8*256->8*8*256->4*4*512->4*4*512
    - [x| ] (3) channel expansion after strided convolution
        32*32*3->32*32*64->16*16*64->16*16*128->
        8*8*128->8*8*256->4*4*512->4*4*512    - [x| ] (4) channel expansion during strided convolution, only strided convolution
        32*32*3->16*16*64->8*8*128->4*4*256->
- generator strength: (1) > (2) > (3) > (4)
    - [ | ] (1) channel compression after strided deconv
    - [ | ] (2) channel compression during strided deconv
    - [ | ] (3) channel compression before strided deconv
    - [x| ] (4) channel compression during strided deconv, only strided deconv
        4*4*512->8*8*256->16*16*128->32*32->64->32*32*3

In [365]:
class Generator(ch.Chain):
    def __init__(self,nh=128,d0=4,c0=512,Wscale=0.02):
        self.nh,self.d0,self.c0 = nh,d0,c0
        super(Generator,self).__init__()
        with self.init_scope():
            # WHY NOT ADD NON STRIDED CONVS BETWEEN DECONVS?
            Winit = ch.initializers.Normal(scale=Wscale)
            self.l0 = L.Linear(in_size=nh,out_size=d0*d0*c0,initialW=Winit)
            self.dc1 = L.Deconvolution2D(in_channels=c0,out_channels=c0//2,
                                         ksize=4,stride=2,pad=1,initialW=Winit)
            self.dc2 = L.Deconvolution2D(in_channels=c0//2,out_channels=c0//4,
                                         ksize=4,stride=2,pad=1,initialW=Winit)
            self.dc3 = L.Deconvolution2D(in_channels=c0//4,out_channels=c0//8,
                                         ksize=4,stride=2,pad=1,initialW=Winit)
            self.dc4 = L.Deconvolution2D(in_channels=c0//8,out_channels=3,
                                         ksize=3,stride=1,pad=1,initialW=Winit)
            self.bn0 = L.BatchNormalization(d0*d0*c0)
            self.bn1 = L.BatchNormalization(c0//2)
            self.bn2 = L.BatchNormalization(c0//4)
            self.bn3 = L.BatchNormalization(c0//8)
        
    def hidden_init(self,nb):
        return np.random.uniform(-1,1,(nb,self.nh,1,1)).astype(np.float32)
    
    def __call__(self,z):
        nb,nh,*_ = z.shape
        assert(nh==self.nh)
        h = F.reshape(F.relu(self.bn0(self.l0(z))),
                      (nb,self.c0,self.d0,self.d0))
        # nb*512*4*4 
        h = F.relu(self.bn1(self.dc1(h)))
        # nb*256*8*8
        h = F.relu(self.bn2(self.dc2(h)))
        # nb*128*16*16
        h = F.relu(self.bn3(self.dc3(h)))
        # nb*64*32*32
        return F.sigmoid(self.dc4(h))
        # nb*3*32*32

In [366]:
def noise(h,sigma=0.2):
    xp = ch.backends.cuda.get_array_module(h.array)
    return h + sigma*xp.random.rand(*h.shape) if ch.config.train else h

In [386]:
class Discrimiator(ch.Chain):
    def __init__(self,nh=128,df=4,cf=512,Wscale=0.02):
        Winit = ch.initializers.Normal(scale=Wscale)
        super(Discrimiator,self).__init__()
        with self.init_scope():
            self.c0a = L.Convolution2D(in_channels=3,out_channels=cf//8,
                                       ksize=3,stride=1,pad=1,initialW=Winit)
            self.c0b = L.Convolution2D(in_channels=cf//8,out_channels=cf//4,
                                       ksize=4,stride=2,pad=1,initialW=Winit)
            self.c1a = L.Convolution2D(in_channels=cf//4,out_channels=cf//4,
                                       ksize=3,stride=1,pad=1,initialW=Winit)
            self.c1b = L.Convolution2D(in_channels=cf//4,out_channels=cf//2,
                                       ksize=4,stride=2,pad=1,initialW=Winit)
            self.c2a = L.Convolution2D(in_channels=cf//2,out_channels=cf//2,
                                       ksize=3,stride=1,pad=1,initialW=Winit)
            self.c2b = L.Convolution2D(in_channels=cf//2,out_channels=cf,
                                       ksize=4,stride=2,pad=1,initialW=Winit)
            self.c3a = L.Convolution2D(in_channels=cf,out_channels=cf,
                                       ksize=3,stride=1,pad=1,initialW=Winit)
            self.l4 = L.Linear(df*df*cf,1,initialW=Winit)
            # no batch norm for first conv b/c already normalized
            self.bn0b = L.BatchNormalization(cf//4,use_gamma=False)
            self.bn1a = L.BatchNormalization(cf//4,use_gamma=False)
            self.bn1b = L.BatchNormalization(cf//2,use_gamma=False)
            self.bn2a = L.BatchNormalization(cf//2,use_gamma=False)
            self.bn2b = L.BatchNormalization(cf,use_gamma=False)
            self.bn3a = L.BatchNormalization(cf,use_gamma=False)

    def __call__(self,x):                             
        h = noise(x)
        # 32*32*3
        h = F.leaky_relu(noise(self.c0a(h)))
        # 32*32*64
        h = F.leaky_relu(noise(self.bn0b(self.c0b(h))))
        # 16*16*128
        h = F.leaky_relu(noise(self.bn1a(self.c1a(h))))
        # 16*16*128
        h = F.leaky_relu(noise(self.bn1b(self.c1b(h))))
        # 8*8*256
        h = F.leaky_relu(noise(self.bn2a(self.c2a(h))))
        # 8*8*256
        h = F.leaky_relu(noise(self.bn2b(self.c2b(h))))
        # 4*4*512
        h = F.leaky_relu(noise(self.bn3a(self.c3a(h))))
        # 4*4*512
        return self.l4(h)

In [438]:
class DCGANupdater(ch.training.updaters.StandardUpdater):
    def __init__(self,models=(None,None),*args,**kwargs):
        self.gen,self.dis = models
        assert self.gen != None and self.dis != None
        super(DCGANupdater,self).__init__(*args,**kwargs)
    
    def genLoss(self,yFake):
        nbF,*_ = yFake.shape
        L = F.sum(F.softplus(-yFake))
        L = L / nbF
        ch.report({'loss':L},dis)
        return L
    
    def disLoss(self,yFake,yReal):
        nbF,*_ = yFake.shape
        nbR,*_ = yReal.shape
        L = F.sum(F.softplus(yFake)) + F.sum(F.softplus(-yReal))
#         return L / (nbF + nbR)
        L = L / nbF # in tutorial repo
        ch.report({'loss':L},gen)
        return L
                                             
    def update_core(self):
        genOpt,disOpt = self.get_optimizer('gen'),self.get_optimizer('dis')
        gen,dis = self.gen,self.dis
        
        batch = self.get_iterator('main').next()
        xReal = ch.Variable(self.converter(batch,self.device))/255
        xp = ch.backends.cuda.get_array_module(xReal.array)
        nb,*_ = xReal.shape
        
        z = ch.Variable(xp.asarray(gen.hidden_init(nb)))
        xFake = gen(z)
#         print(xFake.array.shape)
#         print(xReal.array.shape)
        yFake,yReal = dis(xFake),dis(xReal)
        
        genOpt.update(self.genLoss,yFake)
        disOpt.update(self.disLoss,yFake,yReal)

In [450]:
def sample(gen,dis,rows,cols,seed,out):
    @ch.training.make_extension()
    def sampleImg(trainer):
        np.random.seed(seed)
        nb = rows*cols
        xp = gen.xp
        z = ch.Variable(xp.asarray(gen.hidden_init(nb)))
        with ch.using_config('train',False):
            x = gen(z)
        x = ch.backends.cuda.to_cpu(x.array)
        np.random.seed()
        
        x = np.asarray(np.clip(x*255,0.0,255.0),dtype=np.uint8)
        _,_,H,W = x.shape
        x = x.reshape((rows,cols,3,H,W))
        # turn sampled images into grid
        x = x.transpose(0,3,1,4,2).reshape(rows*H,cols*W,3)
        
        sampleDir = os.path.join(out,'samples')
        if not os.path.exists(sampleDir): os.makedirs(sampleDir)
        samplePath = os.path.join(sampleDir,'sample_iter_{:0>8}.png'.format(trainer.updater.iteration))
        PIL.Image.fromarray(x).save(samplePath)
    return sampleImg

In [451]:
train,_ = ch.datasets.get_cifar10(withlabel=False,scale=255.)
# imgFiles = [os.path.join(datasetRoot,f) 
#             for f in os.listdir(datasetRoot) 
#             if ('png' in f or 'jpg' in f)]
# train = ch.datasets.ImageDataset(imgFiles)

nb = 32
trainIter = ch.iterators.SerialIterator(train,nb)
gen = Generator()
dis = Discrimiator()
print(gen.count_params(),dis.count_params())

def makeOpt(model,alpha=0.0002,beta1=0.5):
    opt = ch.optimizers.Adam(alpha=alpha,beta1=beta1)
    opt.setup(model)
    opt.add_hook(ch.optimizer_hooks.WeightDecay(0.0001),'hook_dec')
    return opt
genOpt,disOpt = makeOpt(gen),makeOpt(dis)

3828739 5862657


In [452]:
dev = -1 # CPU
saveEvery = (100,'iteration')
dispEvery = (10,'iteration')
stopAfter = (1,'epoch')
sampleSeed = 0
snapshotPath = ""
updater = DCGANupdater(models=(gen,dis),
                       iterator=trainIter,
                       optimizer={'gen':genOpt,'dis':disOpt},
                       device=dev)
root = '.'
outdir = os.path.join(root,'results')
if not os.path.exists(outdir): os.mkdir(outdir)
trainer = ch.training.Trainer(updater,stop_trigger=stopAfter,out=outdir)
saveExt = [
    ch.training.extensions.snapshot(filename='snapshot_iter_{.updater.iteration}.npz'),
    ch.training.extensions.snapshot_object(gen,'gen_iter_{.updater.iteration}.npz'),
    ch.training.extensions.snapshot_object(dis,'dis_iter_{.updater.iteration}.npz'),
]
dispExt = [
#     ch.training.extensions.LogReport(),
    ch.training.extensions.PrintReport(['epoch','iteration','gen/loss','dis/loss']),
    ch.training.extensions.PlotReport(['gen/loss','dis/loss'],'iteration',file_name='loss.png'),
#     ch.training.extensions.ProgressBar(),
    sample(gen=gen,dis=dis,rows=10,cols=10,seed=sampleSeed,out=outdir),
]
trainer.extend(ch.training.extensions.ProgressBar(update_interval=10))
trainer.extend(ch.training.extensions.LogReport(trigger=dispEvery))
for ext in saveExt: trainer.extend(extension=ext,trigger=saveEvery)
for ext in dispExt: trainer.extend(extension=ext,trigger=dispEvery)
if snapshotPath:
    ch.serializers.load_npz(snapshotPath,trainer)

In [453]:
trainer.run()

[J     total [..................................................]  0.64%
this epoch [..................................................]  0.64%
        10 iter, 0 epoch / 1 epochs
       inf iters/sec. Estimated time to finish: 0:00:00.
[4Aepoch       iteration   gen/loss    dis/loss  
[J0           10          2.59841     1.26126     
[J     total [..................................................]  1.28%
this epoch [..................................................]  1.28%
        20 iter, 0 epoch / 1 epochs
   0.11129 iters/sec. Estimated time to finish: 3:51:00.171706.
[4A[J0           20          1.62074     1.21432     
[J     total [..................................................]  1.92%
this epoch [..................................................]  1.92%
        30 iter, 0 epoch / 1 epochs
   0.11263 iters/sec. Estimated time to finish: 3:46:46.363012.
[4A[J0           30          0.724471    2.55863     
[J     total [#.........................................