In [77]:
import mxnet as mx
import numpy as np
from matplotlib import pyplot as plt

In [106]:
# Load the training images from disk. flag=1 requests 3-channel colour images
# and no per-sample transform is applied at load time.
dataset = mx.gluon.data.vision.datasets.ImageFolderDataset(
    root='./mxnet/data/train',
    flag=1,
    transform=None,
)

In [107]:
# Preallocate the training array in (sample, channel, height, width) layout.
# The sample count is taken from the dataset instead of being hard-coded, so
# adding or removing images neither overflows the array nor leaves all-zero
# rows at the end. 64x64 is the image size the generator and discriminator
# graphs defined later in this notebook are built for.
# float32 is used directly: the pixel values are small integers, so nothing is
# lost compared with float64 and the array takes half the memory.
np_data = np.zeros((len(dataset), 3, 64, 64), dtype=np.float32)

In [108]:
# Copy every image into the training array. Each sample is converted to NumPy
# once and its axes are reordered from (height, width, channel) to
# (channel, height, width); the folder label is not used.
for i, (data, label) in enumerate(dataset):
    np_data[i] = data.asnumpy().transpose((2, 0, 1))

In [109]:
# Shuffle the samples along the first axis with a fixed seed so that the
# ordering is the same on every fresh run.
np.random.seed(1)
p = np.random.permutation(len(np_data))
np_data = np_data.take(p, axis=0)

In [110]:
# Sanity check: display the shape of the image array
# (allocated above as samples x 3 x 64 x 64).
np_data.shape

(7721, 3, 64, 64)

In [111]:
# Rescale the pixels to float32 with x / 127.5 - 1, which maps 0 -> -1 and
# 255 -> +1. visualize() applies the inverse mapping before plotting.
np_data = (np_data.astype(np.float32, copy=False) / 127.5) - 1.0

In [112]:
# Mini-batch size shared by the image iterator, the noise iterator and the
# discriminator's label shape.
batch_size = 9

# Batched iterator over the in-memory training images.
image_iter = mx.io.NDArrayIter(data=np_data, batch_size=batch_size)

In [113]:
class RandIter(mx.io.DataIter):
    """Endless data iterator that yields batches of Gaussian noise.

    Each batch holds a single array named 'rand' with shape
    (batch_size, ndim, 1, 1); no labels are provided.
    """

    def __init__(self, batch_size, ndim):
        """Remember the batch geometry and advertise the input description.

        batch_size -- number of noise vectors per batch
        ndim       -- length of each noise vector
        """
        noise_shape = (batch_size, ndim, 1, 1)
        self.batch_size, self.ndim = batch_size, ndim
        # Read when binding the generator module to size its 'rand' input.
        self.provide_data = [('rand', noise_shape)]
        self.provide_label = []

    def iter_next(self):
        """Always report that another batch is available."""
        return True

    def getdata(self):
        """Return a one-element list holding a freshly sampled noise batch."""
        shape = (self.batch_size, self.ndim, 1, 1)
        noise = mx.random.normal(0, 1.0, shape=shape)
        return [noise]

In [114]:
# Noise source for the generator: one 100-element vector per sample.
rand_iter = RandIter(batch_size=batch_size, ndim=100)

In [115]:
# Options shared by the layers of both networks.
#   no_bias   -- every (de)convolution layer is created without a bias term
#   fix_gamma -- passed to every BatchNorm layer
#   eps       -- BatchNorm epsilon
#                NOTE(review): the extra 1e-12 presumably nudges the value just
#                above a backend minimum -- confirm before changing it.
no_bias = True
fix_gamma = True
eps = 1e-5 + 1e-12

# Generator input: the noise batch, named 'rand' to match RandIter.provide_data.
rand = mx.sym.Variable('rand')

# Stage 1: 4x4 deconvolution with the operator's default stride and padding.
# For the (batch, 100, 1, 1) noise supplied by rand_iter this gives a
# 1024-channel 4x4 map, followed by BatchNorm and ReLU.
g1 = mx.sym.Deconvolution(rand, kernel = (4, 4), num_filter = 1024, no_bias = no_bias)
gbn1 = mx.sym.BatchNorm(g1, fix_gamma = fix_gamma, eps = eps)
gact1 = mx.sym.Activation(gbn1, act_type = 'relu')

# Stages 2-4: each 4x4 / stride 2 / pad 1 deconvolution doubles the spatial
# size while halving the channel count: 512 @ 8x8, 256 @ 16x16, 128 @ 32x32.
g2 = mx.sym.Deconvolution(gact1, kernel = (4, 4), stride = (2, 2), pad = (1, 1), num_filter = 512, no_bias = no_bias)
gbn2 = mx.sym.BatchNorm(g2, fix_gamma = fix_gamma, eps = eps)
gact2 = mx.sym.Activation(gbn2, act_type = 'relu')

g3 = mx.sym.Deconvolution(gact2, kernel = (4, 4), stride = (2, 2), pad = (1, 1), num_filter = 256, no_bias = no_bias)
gbn3 = mx.sym.BatchNorm(g3, fix_gamma = fix_gamma, eps = eps)
gact3 = mx.sym.Activation(gbn3, act_type = 'relu')

g4 = mx.sym.Deconvolution(gact3, kernel = (4, 4), stride = (2, 2), pad = (1, 1), num_filter = 128, no_bias = no_bias)
gbn4 = mx.sym.BatchNorm(g4, fix_gamma = fix_gamma, eps = eps)
gact4 = mx.sym.Activation(gbn4, act_type = 'relu')

# Output stage: one more doubling to a 3-channel 64x64 image. There is no
# BatchNorm here, and tanh keeps the values in [-1, 1], the same range the
# training images are scaled to.
g5 = mx.sym.Deconvolution(gact4, kernel = (4, 4), stride = (2, 2), pad = (1, 1), num_filter = 3, no_bias = no_bias)
generatorSymbol = mx.sym.Activation(g5, name = 'gact5', act_type = 'tanh')

In [120]:
# Discriminator input: an image batch named 'data'.
data = mx.sym.Variable('data')

# Stage 1: 4x4 / stride 2 / pad 1 convolution halves the spatial size
# (64x64 -> 32x32 for the images used here). This first stage has no BatchNorm.
d1 = mx.sym.Convolution(data, kernel = (4, 4), stride = (2, 2), pad = (1, 1), num_filter = 128, no_bias = no_bias)
dact1 = mx.sym.LeakyReLU(d1, act_type = 'leaky', slope = 0.2)

# Stages 2-4: the same halving convolution, each followed by BatchNorm and a
# leaky ReLU with slope 0.2: 256 @ 16x16, 512 @ 8x8, 1024 @ 4x4.
d2 = mx.sym.Convolution(dact1, kernel = (4, 4), stride = (2, 2), pad = (1, 1), num_filter = 256, no_bias = no_bias)
dbn2 = mx.sym.BatchNorm(d2, fix_gamma = fix_gamma, eps = eps)
dact2 = mx.sym.LeakyReLU(dbn2, act_type = 'leaky', slope = 0.2)

d3 = mx.sym.Convolution(dact2, kernel = (4, 4), stride = (2, 2), pad = (1, 1), num_filter = 512, no_bias = no_bias)
dbn3 = mx.sym.BatchNorm(d3, fix_gamma = fix_gamma, eps = eps)
dact3 = mx.sym.LeakyReLU(dbn3, act_type = 'leaky', slope = 0.2)

# Disabled experiment: an extra 768-filter halving stage between stages 3 and 4.
# NOTE(review): presumably part of the 128x128 attempt mentioned in the
# training loop -- confirm, then delete or restore.
#d4l = mx.sym.Convolution(dact3, kernel = (4, 4), stride = (2, 2), pad = (1, 1), num_filter = 768, no_bias = no_bias)
#dbn4l = mx.sym.BatchNorm(d4l, fix_gamma = fix_gamma, eps = eps)
#dact4l = mx.sym.LeakyReLU(dbn4l, act_type = 'leaky', slope = 0.2)

d4 = mx.sym.Convolution(dact3, kernel = (4, 4), stride = (2, 2), pad = (1, 1), num_filter = 1024, no_bias = no_bias)
dbn4 = mx.sym.BatchNorm(d4, fix_gamma = fix_gamma, eps = eps)
dact4 = mx.sym.LeakyReLU(dbn4, act_type = 'leaky', slope = 0.2)

# Head: an unpadded 4x4 convolution collapses the 4x4 map to a single value per
# image, which is flattened to shape (batch, 1).
d5 = mx.sym.Convolution(dact4, name = 'd5', kernel = (4, 4), num_filter = 1, no_bias = no_bias)
d5 = mx.sym.Flatten(d5)

# Logistic-regression output layer trained against 'label'; the training loop
# feeds 0 for generated images and 1 for real ones.
label = mx.sym.Variable('label')
discriminatorSymbol = mx.sym.LogisticRegressionOutput(data = d5, label = label, name='dloss')

In [121]:
# Training hyperparameters shared by the generator and the discriminator.

# Standard deviation of the Normal initializer used for both networks.
sigma = 0.02

# Adam step size. This was 0.01, which is 50x the 0.0002 recommended for DCGAN
# training; steps that large typically make the generator/discriminator game
# oscillate or collapse instead of converging. beta1 = 0.5 and sigma = 0.02
# already follow the same recommendation.
lr = 0.0002
beta1 = 0.5

# Device for both modules. mx.gpu() needs a CUDA-enabled MXNet build; use
# mx.cpu() to run without a GPU.
ctx = mx.gpu()

In [122]:
#=============Generator Module=============
# Wrap the generator graph in a Module that takes only the 'rand' input
# (no labels), then allocate it for the noise shape advertised by rand_iter.
generator = mx.mod.Module(
    symbol=generatorSymbol,
    data_names=('rand',),
    label_names=None,
    context=ctx,
)
generator.bind(data_shapes=rand_iter.provide_data)

# Weights start from a zero-mean Normal with standard deviation `sigma`;
# training uses Adam with the shared learning rate and beta1.
generator.init_params(initializer=mx.init.Normal(sigma))
generator.init_optimizer(
    optimizer='adam',
    optimizer_params={'learning_rate': lr, 'beta1': beta1},
)

mods = [generator]

In [123]:
#=============Discriminator Module=============
# Wrap the discriminator graph in a Module with one image input and one label
# per sample.
discriminator = mx.mod.Module(symbol = discriminatorSymbol, data_names = ('data',), label_names = ('label',), context = ctx)

# inputs_need_grad=True makes the module compute gradients with respect to its
# input images; the training loop reads them with get_input_grads() and feeds
# them back into the generator.
discriminator.bind(data_shapes = image_iter.provide_data,
                   label_shapes = [('label', (batch_size,))],
                   inputs_need_grad=True)

discriminator.init_params(initializer = mx.init.Normal(sigma))

discriminator.init_optimizer(optimizer = 'adam', optimizer_params = {'learning_rate': lr, 'beta1': beta1,})

# Rebuild the list instead of appending to it, so that re-running this cell
# does not leave stale or duplicate discriminator modules in `mods`.
mods = [generator, discriminator]

In [124]:
def fill_buf(buf, num_images, img, shape):
    """Paste one image into its cell of a tiled image grid, in place.

    Cells are numbered left to right, then top to bottom.

    buf        -- canvas of shape (rows * H, cols * W, C); modified in place
    num_images -- zero-based index of the cell to fill
    img        -- image of shape (H, W, C)
    shape      -- (H, W) of a single image

    The grid geometry is derived with integer arithmetic from the canvas
    width and the image width, and rows/columns are scaled by the image
    height/width respectively, so non-square images and non-square grids
    are placed correctly.
    """
    cell_h, cell_w = int(shape[0]), int(shape[1])
    cols = buf.shape[1] // cell_w
    row, col = divmod(int(num_images), cols)
    top = row * cell_h
    left = col * cell_w
    buf[top:top + cell_h, left:left + cell_w, :] = img

def _tile_batch(batch):
    """Turn a (N, C, H, W) batch with values in [-1, 1] into one uint8 image grid."""
    # Channels-last layout and 0..255 pixel values, as expected by imshow.
    images = batch.transpose((0, 2, 3, 1))
    images = np.clip((images + 1.0) * (255.0 / 2.0), 0, 255).astype(np.uint8)

    count, height, width, channels = images.shape
    # Smallest square grid that holds every image of this batch.
    side = int(np.ceil(np.sqrt(count)))
    canvas = np.zeros((side * height, side * width, channels), dtype=np.uint8)
    for index, image in enumerate(images):
        fill_buf(canvas, index, image, (height, width))
    return canvas


def visualize(fake, real):
    """Show a batch of generated images next to a batch of real images.

    fake -- NumPy array (N, C, H, W) of generator outputs, values in [-1, 1]
    real -- NumPy array (M, C, H, W) of training images, values in [-1, 1]

    Each batch is tiled into its own square grid, so the two batches may
    contain different numbers of images. The left panel shows `fake`, the
    right panel shows `real`.
    """
    fbuff = _tile_batch(fake)
    rbuff = _tile_batch(real)

    fig = plt.figure()
    ax1 = fig.add_subplot(2,2,1)
    ax1.imshow(fbuff)
    ax1.set_title('Generated')

    ax2 = fig.add_subplot(2,2,2)
    ax2.imshow(rbuff)
    ax2.set_title('Real')
    plt.show()

In [None]:
# =============fit===============
# One training step per batch of real images:
#   1. run the discriminator on generated images with target 0 and keep a
#      copy of its parameter gradients;
#   2. run it on real images with target 1, add the saved gradients and
#      update the discriminator;
#   3. run it again on the generated images with target 1 and push the
#      gradients with respect to its input back through the generator,
#      then update the generator.
print('fit')
for epoch in range(1):
    image_iter.reset()
    # Count iterations from 1 so the progress check needs no manual increment.
    for i, batch in enumerate(image_iter, 1):

        rbatch = rand_iter.next()
        generator.forward(rbatch, is_train = True)
        outG = generator.get_outputs()

        # Step 1: generated images, target 0.
        label = mx.nd.zeros((batch_size,), ctx = ctx)
        # Author's note on the next call: it died when the images were 128x128.
        discriminator.forward(mx.io.DataBatch(outG, [label]), is_train = True)
        discriminator.backward()
        # The next backward pass overwrites the gradient buffers, so keep copies.
        gradD = [[grad.copyto(grad.context) for grad in grads] for grads in discriminator._exec_group.grad_arrays]

        # Step 2: real images, target 1; accumulate both gradients, then update.
        label[:] = 1
        batch.label = [label]
        discriminator.forward(batch, is_train = True)
        discriminator.backward()
        for gradsr, gradsf in zip(discriminator._exec_group.grad_arrays, gradD):
            for gradr, gradf in zip(gradsr, gradsf):
                gradr += gradf
        discriminator.update()

        # Step 3: generated images with target 1 (label still holds ones).
        discriminator.forward(mx.io.DataBatch(outG, [label]), is_train=True)
        discriminator.backward()
        diffD = discriminator.get_input_grads()
        generator.backward(diffD)
        generator.update()

        if i % 50 == 0:
            print('epoch:', epoch, 'iter:', i)
            # Blank separator line; a bare `print` does nothing in Python 3.
            print('')
            # Left panel: generator output. Right panel: real training images.
            print("   From generator:        From dataset:")

            visualize(outG[0].asnumpy(), batch.data[0].asnumpy())

fit
