In [1]:
import mxnet as mx
import numpy as np
import matplotlib.pyplot as plt
import sys
import cv2
import glob
import os
sys.path.append('../')
from util.visualizer import *
%matplotlib inline

In [2]:
def make_symG(data, ndf=64, use_dropout=False, n_blocks=9, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):
    
    BatchNorm = mx.sym.BatchNorm
    
    ## Convolution BatchNorm RELU layer
    c1 = mx.sym.Convolution(data, name='c1', kernel=(7, 7), pad=(3, 3), stride=(1, 1), num_filter=ndf, no_bias=no_bias)
    cbn1 = BatchNorm(c1, name='cbn1', fix_gamma=fix_gamma, eps=eps)
    cact1 = mx.sym.Activation(cbn1, name='cact1', act_type='relu')
    
    c2 = mx.sym.Convolution(cact1, name='c2', kernel=(3, 3), pad=(1, 1), stride=(2, 2), num_filter=ndf*2, no_bias=no_bias)
    cbn2 = BatchNorm(c2, name='cbn2', fix_gamma=fix_gamma, eps=eps)
    cact2 = mx.sym.Activation(cbn2, name='cact2', act_type='relu')
    
    c3 = mx.sym.Convolution(cact2, name='c3', kernel=(3, 3), pad=(1, 1), stride=(2, 2), num_filter=ndf*4, no_bias=no_bias)
    cbn3 = BatchNorm(c3, name='cbn3', fix_gamma=fix_gamma, eps=eps)
    cact3 = mx.sym.Activation(cbn3, name='cact3', act_type='relu')
    
    ## Resnet Block
    reslayer_out = cact3
    for i in range(n_blocks):
        reslayer = mx.sym.Convolution(reslayer_out, name='resc1_%d'%i, kernel=(3, 3), pad=(1, 1), num_filter=ndf*4, no_bias=no_bias)
        reslayer = BatchNorm(reslayer, name='resbn1_%d'%i, fix_gamma=fix_gamma, eps=eps)
        reslayer = mx.sym.Activation(reslayer, name='resact%d'%i, act_type='relu')
        reslayer = mx.sym.Convolution(reslayer, name='resc2_%d'%i, kernel=(3, 3), pad=(1, 1), num_filter=ndf*4, no_bias=no_bias)
        reslayer = BatchNorm(reslayer, name='resbn2_%d'%i, fix_gamma=fix_gamma, eps=eps)
        reslayer_out = reslayer_out + reslayer
    
    ## Deconvolution Layer
    d1 = mx.sym.Deconvolution(reslayer_out, name='d1', kernel=(3, 3), pad=(1, 1), stride=(2, 2), adj=(1, 1), num_filter=ndf*2, no_bias=no_bias)
    dbn1 = BatchNorm(d1, name='dbn1', fix_gamma=fix_gamma, eps=eps)
    dact1 = mx.sym.Activation(dbn1, name='dact1', act_type='relu')
    
    d2 = mx.sym.Deconvolution(dact1, name='d2', kernel=(3, 3), pad=(1, 1), stride=(2, 2), adj=(1, 1), num_filter=ndf, no_bias=no_bias)
    dbn2 = BatchNorm(d2, name='dbn2', fix_gamma=fix_gamma, eps=eps)
    dact2 = mx.sym.Activation(dbn2, name='dact2', act_type='relu')
    
    c4 = mx.sym.Convolution(dact2, name='c4', kernel=(7, 7), pad=(3, 3), stride=(1, 1), num_filter=3, no_bias=no_bias)
    gout = mx.sym.Activation(c4, name='gout', act_type='tanh')
    
    return gout

In [3]:
def make_symD(data, label, ndf=64, no_bias=True, fix_gamma=True, eps=1e-5 + 1e-12):
    BatchNorm = mx.sym.BatchNorm
    d1 = mx.sym.Convolution(data, name='d1', kernel=(4, 4), stride=(
        2, 2), pad=(2, 2), num_filter=ndf, no_bias=no_bias)
    dact1 = mx.sym.LeakyReLU(d1, name='dact1', act_type='leaky', slope=0.2)

    d2 = mx.sym.Convolution(dact1, name='d2', kernel=(4, 4), stride=(
        2, 2), pad=(2, 2), num_filter=ndf * 2, no_bias=no_bias)
    dbn2 = BatchNorm(d2, name='dbn2', fix_gamma=fix_gamma, eps=eps)
    dact2 = mx.sym.LeakyReLU(dbn2, name='dact2', act_type='leaky', slope=0.2)

    d3 = mx.sym.Convolution(dact2, name='d3', kernel=(4, 4), stride=(
        2, 2), pad=(2, 2), num_filter=ndf * 4, no_bias=no_bias)
    dbn3 = BatchNorm(d3, name='dbn3', fix_gamma=fix_gamma, eps=eps)
    dact3 = mx.sym.LeakyReLU(dbn3, name='dact3', act_type='leaky', slope=0.2)

    d4 = mx.sym.Convolution(dact3, name='d4', kernel=(4, 4), stride=(
        1, 1), pad=(2, 2), num_filter=ndf * 8, no_bias=no_bias)
    dbn4 = BatchNorm(d4, name='dbn4', fix_gamma=fix_gamma, eps=eps)
    dact4 = mx.sym.LeakyReLU(dbn4, name='dact4', act_type='leaky', slope=0.2)

    d5 = mx.sym.Convolution(dact4, name='d5', kernel=(4, 4), stride=(
        1, 1), pad=(2, 2), num_filter=1, no_bias=no_bias)
    
    # fc = mx.sym.FullyConnected(data = d5, num_hidden=1, name='output')
    
    mseloss_ = mx.sym.mean(mx.sym.square(d5 - label), name='mean')
    mseloss = mx.sym.MakeLoss(data=mseloss_, name='mean_square_loss')
    
    return mseloss

In [4]:
def make_cycleGAN():
    # Generator
    dataA = mx.sym.Variable('dataA')
    dataB = mx.sym.Variable('dataB')
    symG_A = make_symG(dataA, 32)
    symG_B = make_symG(dataB, 32)
    
    # Discriminator
    dataC = mx.sym.Variable('dataC')
    dataD = mx.sym.Variable('dataD')
    labelC = mx.sym.Variable('labelC')
    labelD = mx.sym.Variable('labelD')
    symD_A = make_symD(dataC, labelC)
    symD_B = make_symD(dataD, labelD)
    
    return symG_A, symG_B, symD_A, symD_B

----

In [5]:
ndf = 32
ngf = 32
nc = 3
batch_size = 1
Z = 100
lr = 0.0002
beta1 = 0.5
ctx = mx.gpu(0)
check_point = False
label = mx.nd.zeros((batch_size, 1, 35, 35), ctx=ctx)

In [6]:
symG_A, symG_B, symD_A, symD_B = make_cycleGAN()
# Generator A
modG_A = mx.mod.Module(symbol=symG_A, data_names=(
    'dataA',), label_names=None, context=ctx)
modG_A.bind(data_shapes=[('dataA', (batch_size, 3, 256, 256))],
           inputs_need_grad=True)
modG_A.init_params(initializer=mx.init.Normal(0.02))
modG_A.init_optimizer(
    optimizer='adam',
    optimizer_params={
        'learning_rate': lr,
        'wd': 0.,
        'beta1': beta1,
    })

# Generator B
modG_B = mx.mod.Module(symbol=symG_B, data_names=(
    'dataB',), label_names=None, context=ctx)
modG_B.bind(data_shapes=[('dataB', (batch_size, 3, 256, 256))],
           inputs_need_grad=True)
modG_B.init_params(initializer=mx.init.Normal(0.02))
modG_B.init_optimizer(
    optimizer='adam',
    optimizer_params={
        'learning_rate': lr,
        'wd': 0.,
        'beta1': beta1,
    })

In [7]:
# Discriminator A
modD_A = mx.mod.Module(symbol=symD_A, data_names=(
    'dataC',), label_names=('labelC',), context=ctx)
modD_A.bind(data_shapes=[('dataC', (batch_size, 3, 256, 256))], 
            label_shapes=[('labelC', (batch_size, 1, 35, 35))],
            inputs_need_grad=True)
modD_A.init_params(initializer=mx.init.Normal(0.02))
modD_A.init_optimizer(
    optimizer='adam',
    optimizer_params={
        'learning_rate': lr,
        'wd': 0.,
        'beta1': beta1,
    })

# Discriminator B
modD_B = mx.mod.Module(symbol=symD_B, data_names=(
    'dataD',), label_names=('labelD',), context=ctx)
modD_B.bind(data_shapes=[('dataD', (batch_size, 3, 256, 256))], 
            label_shapes=[('labelD', (batch_size, 1, 35, 35))],
            inputs_need_grad=True)
modD_B.init_params(initializer=mx.init.Normal(0.02))
modD_B.init_optimizer(
    optimizer='adam',
    optimizer_params={
        'learning_rate': lr,
        'wd': 0.,
        'beta1': beta1,
    })

----

In [8]:
## Image Iterator
class ImagenetIter(mx.io.DataIter):

    def __init__(self, path, batch_size, data_shape):
        self.internal = mx.image.ImageIter(
            imglist=[[1, img] for img in path],
            data_shape=data_shape,
            batch_size=batch_size,
            path_root='./',
            )
        self.provide_data = [('data', (batch_size,) + data_shape)]
        self.provide_label = []

    def reset(self):
        self.internal.reset()

    def iter_next(self):
        return self.internal.iter_next()

    def getdata(self):
        data = self.internal.next().data[0]
        data = data * (2.0 / 255.0)
        data -= 1
        return [data]

In [9]:
## load train data to iterator
horses = glob.glob('../../datasets/vangogh2photo/trainA/*.jpg')
zebras = glob.glob('../../datasets/vangogh2photo/trainB/*.jpg')
horses_iter = ImagenetIter(horses, batch_size, (3, 256, 256))
zebras_iter = ImagenetIter(zebras, batch_size, (3, 256, 256))

loading image list...
loading image list...


In [10]:
## load test data to iterator
horses = glob.glob('../../datasets/vangogh2photo/testA/*.jpg')
zebras = glob.glob('../../datasets/vangogh2photo/testB/*.jpg')
horses_test_iter = ImagenetIter(horses, batch_size, (3, 256, 256))
zebras_test_iter = ImagenetIter(zebras, batch_size, (3, 256, 256))

loading image list...
loading image list...


In [11]:
inputA = horses_iter.getdata()
inputB = zebras_iter.getdata()

In [12]:
horses_iter.reset()
zebras_iter.reset()

In [14]:
def get_l1grad(cycle, real, lamb=150, sigma=100):
    l1_g = (cycle.asnumpy() - real.asnumpy())
    l1_loss = np.mean(np.abs(l1_g))
    l1_g =  l1_g * sigma * sigma     
    l1_g[l1_g > 1] = 1
    l1_g[l1_g < -1] = -1
    grad = mx.ndarray.array(l1_g * lamb *(1.0/(256.0*256.0*3)), ctx=ctx)
    return l1_loss, grad

In [15]:
def update_generator(inputA, inputB):
    # calculate loss for inputA
    modG_A.forward(mx.io.DataBatch(data=inputA, label=None), is_train=True)
    fakeB = modG_A.get_outputs()
    modG_B.forward(mx.io.DataBatch(data=fakeB, label=None), is_train=True)
    cycleA = modG_B.get_outputs()

    # backward for cycle L1 loss for inputA and cycleA
    l1lossA, grad = get_l1grad(cycleA[0], inputA[0])
    modG_B.backward([grad])

    # backward for GAN loss
    label[:] = 1
    modD_B.forward(mx.io.DataBatch(data=fakeB, label=[label]), is_train=True)
    modD_B.backward()

    modG_A.backward([modG_B.get_input_grads()[0] + modD_B.get_input_grads()[0]])

    # save gradients for future update
    gradG_A = [[grad.copyto(grad.context) for grad in grads] for grads in modG_A._exec_group.grad_arrays]
    gradG_B = [[grad.copyto(grad.context) for grad in grads] for grads in modG_B._exec_group.grad_arrays]

    ## Calculate loss for inputB
    modG_B.forward(mx.io.DataBatch(data=inputB, label=None), is_train=True)
    fakeA = modG_B.get_outputs()
    modG_A.forward(mx.io.DataBatch(data=fakeA, label=None), is_train=True)
    cycleB = modG_A.get_outputs()

    # backward for cycle L1 loss for inputB and cycleB
    l1lossB, grad = get_l1grad(cycleB[0], inputB[0])
    modG_A.backward([grad])

    # backward for GAN loss
    label[:] = 1
    modD_A.forward(mx.io.DataBatch(data=fakeA, label=[label]), is_train=True)
    modD_A.backward()

    modG_B.backward([modG_A.get_input_grads()[0] + modD_A.get_input_grads()[0]])

    # update Generator A and Generator B
    for gradsr, gradsf in zip(modG_A._exec_group.grad_arrays, gradG_A):
        for gradr, gradf in zip(gradsr, gradsf):
            gradr += gradf
    modG_A.update()

    for gradsr, gradsf in zip(modG_B._exec_group.grad_arrays, gradG_B):
        for gradr, gradf in zip(gradsr, gradsf):
            gradr += gradf
    modG_B.update()
    
    return l1lossA, l1lossB

In [16]:
def update_discriminator(modD, real, fake):
    # train with real data
    label[:] = 1
    modD.forward(mx.io.DataBatch(data=real, label=[label]), is_train=True)
    modD.backward()
    # save gradient for future use
    gradD = [[grad.copyto(grad.context) for grad in grads] for grads in modD._exec_group.grad_arrays]
    
    # loss of discriminator
    loss = modD.get_outputs()[0].asnumpy()[0]
    
    # train with fake data
    label[:] = 0
    modD.forward(mx.io.DataBatch(data=fake, label=[label]), is_train=True)
    modD.backward()
    loss += modD.get_outputs()[0].asnumpy()[0]
    for gradsr, gradsf in zip(modD._exec_group.grad_arrays, gradD):
        for gradr, gradf in zip(gradsr, gradsf):
            gradr += gradf
    modD.update()
    return loss

In [17]:
testA = horses_test_iter.getdata()
testB = zebras_test_iter.getdata()

In [18]:
dirname = './cycle_output_new'
if not os.path.exists(dirname):
    os.makedirs(dirname)

In [None]:
for i in range(20):
    horses_iter.reset()
    zebras_iter.reset()
    if (i % 1 == 0):
        inputA = horses_iter.getdata()
        inputB = zebras_iter.getdata()
        # visualize A-B-A
        modG_A.forward(mx.io.DataBatch(data=inputA, label=None), is_train=True)
        fakeB = modG_A.get_outputs()
        modG_B.forward(mx.io.DataBatch(data=fakeB, label=None), is_train=True)
        cycleA = modG_B.get_outputs()
        A_B_A = np.concatenate((inputA[0].asnumpy(), fakeB[0].asnumpy(), cycleA[0].asnumpy()))
        visual(os.path.join(dirname, 'A_B_A'+str(i)+'.jpg'), A_B_A)
        
        # visualize B-A-B
        modG_B.forward(mx.io.DataBatch(data=inputB, label=None), is_train=True)
        fakeA = modG_B.get_outputs()
        modG_A.forward(mx.io.DataBatch(data=fakeA, label=None), is_train=True)
        cycleB = modG_A.get_outputs()
        B_A_B = np.concatenate((inputB[0].asnumpy(), fakeA[0].asnumpy(), cycleB[0].asnumpy()))
        visual(os.path.join(dirname, 'B_A_B'+str(i)+'.jpg'), B_A_B)
    
    # reset image iterator
    horses_iter.reset()
    zebras_iter.reset()
    for j in range(399):
        inputA = horses_iter.getdata()
        inputB = zebras_iter.getdata()
        l1lossA, l1lossB = update_generator(inputA, inputB)
        modG_A.forward(mx.io.DataBatch(data=inputA, label=None), is_train=True)
        fakeB = modG_A.get_outputs()
        modG_B.forward(mx.io.DataBatch(data=inputB, label=None), is_train=True)
        fakeA = modG_B.get_outputs()
        lossD_A = update_discriminator(modD_A, inputA, fakeA)
        lossD_B = update_discriminator(modD_B, inputB, fakeB)
        if (j % 200 == 0):
            print 'epoch: ' + str(i) + '_' + str(j) + ' lossD_A: ' + str(lossD_A) + ' lossD_B: ' + str(lossD_B) + ' l1loss_a:' + str(l1lossA) + ' l1loss_b:' + str(l1lossB)
            visual(os.path.join(dirname, 'fakeA'+str(i)+'_'+str(j/100)+'.jpg'), fakeA[0].asnumpy())
            visual(os.path.join(dirname, 'fakeB'+str(i)+'_'+str(j/100)+'.jpg'), fakeB[0].asnumpy())
            modG_A.forward(mx.io.DataBatch(data=testA, label=None), is_train=True)
            modG_B.forward(mx.io.DataBatch(data=testB, label=None), is_train=True)
            visual(os.path.join(dirname, 'testB'+str(i)+'_'+str(j/100)+'.jpg'), modG_A.get_outputs()[0].asnumpy())
            visual(os.path.join(dirname, 'testA'+str(i)+'_'+str(j/100)+'.jpg'), modG_B.get_outputs()[0].asnumpy())

  buff = np.zeros((n*X.shape[1], n*X.shape[2], X.shape[3]), dtype=np.uint8)
  exec_.forward(is_train=is_train)


epoch: 0_0 lossD_A: 3.96132 lossD_B: 2.77213 l1loss_a:0.665736 l1loss_b:0.516695


----

# Below are used for test function or so

In [30]:
c = np.concatenate((cycleA[0].asnumpy(), inputA[0].asnumpy(), fakeB[0].asnumpy()))

In [31]:
visual(os.path.join(dirname, 'aaa.jpg'), c)

In [39]:
inputA = horses_iter.getdata()
inputB = zebras_iter.getdata()
l1lossA, l1lossB = update_generator(inputA, inputB)
modG_A.forward(mx.io.DataBatch(data=inputA, label=None), is_train=True)
fakeB = modG_A.get_outputs()
modG_B.forward(mx.io.DataBatch(data=inputB, label=None), is_train=True)
fakeA = modG_B.get_outputs()
lossA = update_discriminator(modD_A, inputA, fakeA)
lossB = update_discriminator(modD_B, inputB, fakeB)

In [25]:
modG_A.forward(mx.io.DataBatch(data=inputA, label=None), is_train=True)
fakeB = modG_A.get_outputs()
modG_B.forward(mx.io.DataBatch(data=fakeB, label=None), is_train=True)
cycleA = modG_B.get_outputs()

In [31]:
modG_B.forward(mx.io.DataBatch(data=inputB, label=None), is_train=True)
fakeA = modG_B.get_outputs()
modG_A.forward(mx.io.DataBatch(data=fakeA, label=None), is_train=True)
cycleB = modG_A.get_outputs()

In [32]:
visual(os.path.join(dirname, 'B_2.jpg'), cycleB[0].asnumpy())

In [114]:
label[:] = 0
modD_B.forward(mx.io.DataBatch(data=fakeB, label=[label]), is_train=True)

In [49]:
output = symD_B.get_internals()['output_output']

In [71]:
# Discriminator B
modD_O = mx.mod.Module(symbol=output, data_names=(
    'dataD',), label_names=None, context=ctx)
modD_O.bind(data_shapes=[('dataD', (batch_size, 3, 256, 256))], 
            inputs_need_grad=True)
modD_O.init_params(initializer=mx.init.Normal(0.02))
modD_O.init_optimizer(
    optimizer='adam',
    optimizer_params={
        'learning_rate': lr,
        'wd': 0.,
        'beta1': beta1,
    })

In [77]:
modD_O.forward(mx.io.DataBatch(data=fakeB, label=[label]), is_train=True)
c = modD_O.get_outputs()[0]

In [105]:
output.bind(ctx=ctx, args=new)

ValueError: Length of aux_states do not match number of arguments

In [103]:
new = dict({'dataD' : fakeB[0]})
new.update(modD_B.get_params()[0])

In [104]:
new

{'d1_weight': <NDArray 64x3x4x4 @cpu(0)>,
 'd2_weight': <NDArray 128x64x4x4 @cpu(0)>,
 'd3_weight': <NDArray 256x128x4x4 @cpu(0)>,
 'd4_weight': <NDArray 512x256x4x4 @cpu(0)>,
 'd5_weight': <NDArray 1x512x4x4 @cpu(0)>,
 'dataD': <NDArray 9x3x256x256 @gpu(0)>,
 'dbn2_beta': <NDArray 128 @cpu(0)>,
 'dbn2_gamma': <NDArray 128 @cpu(0)>,
 'dbn3_beta': <NDArray 256 @cpu(0)>,
 'dbn3_gamma': <NDArray 256 @cpu(0)>,
 'dbn4_beta': <NDArray 512 @cpu(0)>,
 'dbn4_gamma': <NDArray 512 @cpu(0)>,
 'output_bias': <NDArray 1 @cpu(0)>,
 'output_weight': <NDArray 1x1225 @cpu(0)>}

-----

-----