In [None]:
%matplotlib inline
from __future__ import absolute_import
from __future__ import print_function
import logging
import numpy as np
import matplotlib.pyplot as plt

from theano import tensor, function

#from keras.datasets import mnist
from fuel.datasets.binarized_mnist import BinarizedMNIST

from keras.models import Graph, Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.utils import np_utils, generic_utils
from keras.optimizers import Adam
from keras.initializations import normal


from seya.layers.draw import DRAW
from seya.layers.draw2 import DRAW2
from seya.layers.base import Lambda

from agnez import video_grid2d
from agnez.app_callbacks import SendImgur, SendFigHTML

from IPython import display

In [None]:
# Route INFO-and-above records from the root logger to ./draw.log.
# Re-running this cell must not attach a second handler for the same file;
# otherwise every record would be written to draw.log once per attached copy.
logger = logging.getLogger()
hdlr = logging.FileHandler('./draw.log')
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')
hdlr.setFormatter(formatter)
if any(isinstance(h, logging.FileHandler) and h.baseFilename == hdlr.baseFilename
       for h in logger.handlers):
    hdlr.close()  # a handler for draw.log from an earlier run is already attached
else:
    logger.addHandler(hdlr)
logger.setLevel(logging.INFO)

In [None]:
# Training configuration.
batch_size = 100
nb_epoch = 100
n_steps = 64      # noise vectors per example: the noise input is (batch, n_steps, z_dim)
h_dim = 300       # passed to DRAW as h_dim
z_dim = 100       # width of each noise vector
input_shape = (1, 28, 28)
n_train = 50000   # number of examples taken from the 'train' split
n_test = 1000     # number of examples taken from the 'test' split
seed = 1234

# Seed numpy's global RNG, which draws the noise fed to DRAW during training
# (and presumably the Keras weight initialisation), so runs are reproducible.
np.random.seed(seed)

data_train = BinarizedMNIST(which_sets=['train'], sources=['features'])
data_test = BinarizedMNIST(which_sets=['test'], sources=['features'])
X_train = data_train.get_data(request=slice(0, n_train))[0].astype('float32')
X_test = data_test.get_data(request=slice(0, n_test))[0].astype('float32')
# The graph's 'input' is 4-D and DRAW is built for input_shape; fail early
# if the dataset layout differs.
assert X_train.shape[1:] == input_shape, X_train.shape

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')

In [None]:
def myinit(shape):
    """Small-scale weight initializer handed to DRAW as `init` and `inner_init`.

    Delegates to the Keras `normal` initializer with a scale of 1e-2.
    """
    return normal(shape, scale=1e-2)

# DRAW layer with GRU inner RNNs, returning the whole output sequence.
# Reuse the configured input_shape rather than repeating the literal.
# NOTE(review): N_enc / N_dec are presumably the read / write attention grid
# sizes - confirm against seya.layers.draw.
draw = DRAW(h_dim=h_dim, z_dim=z_dim, input_shape=input_shape, N_enc=2, N_dec=5,
            return_sequences=True, inner_rnn='gru', init=myinit, inner_init=myinit)

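# Agnez clients

Callbacks that publish the training-cost plot (`PlotGen`) and the generated canvas animation (`draw.gif`) to the Agnez web app; the training loop triggers them once per epoch.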

In [None]:
def gifgen():
    """Return the path of the canvas animation handed to the Agnez GIF sender.

    The training loop writes this file with ``video_grid2d`` every epoch.
    """
    gif_path = 'draw.gif'
    return gif_path

class PlotGen(object):
    """Accumulate scalar training losses and plot them on demand.

    Instances are callable: ``SendFigHTML`` receives one as ``generate_plot``,
    and each call returns a fresh matplotlib figure of every value recorded
    so far.
    """

    def __init__(self):
        # One entry per add_value() call, in call order.
        self.values = []

    def add_value(self, new_value):
        """Append one loss value to the history."""
        self.values.append(new_value)

    def __call__(self):
        """Return a new figure plotting the recorded values against their index.

        Every call opens a new figure; the caller is responsible for closing it.
        """
        fig, ax = plt.subplots()
        ax.plot(np.asarray(self.values))
        ax.set_title('cost function')
        ax.set_xlabel('epoch')
        ax.set_ylabel('training loss')
        return fig

# Both callbacks post to the same Agnez endpoint; keep the URL in one place.
agnez_url = 'https://agnez.herokuapp.com/values'

plotgen = PlotGen()  # collects per-epoch losses; called to render the cost plot
send_plot = SendFigHTML(generate_plot=plotgen, app_url=agnez_url)
send_gif = SendImgur(generate_img=gifgen, app_url=agnez_url)

# Train model

In [None]:
def bce(y_true, y_pred):
    """Binary cross-entropy summed over the last axis (one value per example).

    Predictions are clipped to [1e-7, 1 - 1e-7] first so the logarithms
    inside the cross-entropy stay finite. Used as the Keras loss of the
    'out' node.
    """
    eps = 1e-7
    safe_pred = tensor.clip(y_pred, eps, 1.0 - eps)
    per_unit = tensor.nnet.binary_crossentropy(safe_pred, y_true)
    return per_unit.sum(axis=-1)

In [None]:
def myreshape(x):
    """Return DRAW's last canvas, flattened per example and passed through a sigmoid.

    ``x[0]`` is indexed as a 5-D tensor: the last entry along axis 1
    (presumably the step axis) is kept, and the remaining three axes are
    flattened into 28*28 values, so a single-channel 28x28 input is assumed.
    """
    final_canvas = x[0][:, -1, :, :, :]
    n_examples = final_canvas.shape[0]
    flat_canvas = final_canvas.reshape((n_examples, 28 * 28))
    return tensor.nnet.sigmoid(flat_canvas)

# Assemble the training graph: (images, noise) -> DRAW -> final canvas -> 'out'.
model = Graph()

# 'input': 4-D image batch (the training loop feeds slices of X_train).
model.add_input(name='input', ndim=4)
# 'noise': 3-D noise; the training loop feeds it as (batch, n_steps, z_dim).
model.add_input(name='noise', ndim=3)
# DRAW consumes both graph inputs (merge_mode='join').
model.add_node(draw, name='draw', inputs=['input', 'noise'], merge_mode='join')
# Graph output: DRAW's final canvas, flattened and squashed by myreshape.
# The 'out' loss compares it with the flattened input images.
model.add_node(Lambda(lambda x: myreshape(x)),
               name='out', input='draw', create_output=True)

# NOTE(review): the two lines below are disabled, so DRAW's own regularizers
# (presumably the KL term of the reparametrization trick) are only part of
# the loss if Graph.add_node already collected them - check the printed list.
# model.get_output()  # make sure the regularizer is generated
# model.regularizers += draw.regularizers  # add reparametrization trick regularizer to the list
print(model.regularizers)

In [None]:
# Adam with a small learning rate; clipnorm=10 rescales gradients whose
# norm exceeds 10. The 'out' node is trained with the summed BCE above.
adam = Adam(clipnorm=10, lr=3e-4)
model.compile(optimizer=adam, loss={'out': bce})

In [None]:
# Compile a Theano function (images, noise) -> sigmoid of DRAW's first output,
# used to visualise the canvas sequence.
X = model.get_input()  # dict-like: graph input name -> Theano variable
Y = model.nodes['draw'].get_output()
# Callers pass (images, noise) positionally, so list the inputs by name:
# X.values() follows dict iteration order, which is arbitrary on Python 2
# and need not be ['input', 'noise'].
do_draw = function([X['input'], X['noise']], tensor.nnet.sigmoid(Y[0]),
                   allow_input_downcast=True)

In [None]:
# Floor division: on Python 3, `/` gives a float and range() rejects it.
# A trailing partial batch is skipped.
n_batches = X_train.shape[0] // batch_size
assert n_batches > 0, 'X_train has fewer examples than batch_size'

for e in range(nb_epoch):
    print('-'*40)
    print('Epoch', e)
    print('-'*40)
    print("Training...")
    # Sequential minibatch pass over X_train (no shuffling, no augmentation).
    progbar = generic_utils.Progbar(n_batches * batch_size)
    epoch_losses = []
    for i in range(n_batches):
        s = i * batch_size
        l = (i+1) * batch_size
        X_batch = X_train[s:l]
        # Fresh Gaussian noise for every batch: n_steps vectors of size z_dim per example.
        eps = np.random.normal(0, 1, (X_batch.shape[0], n_steps, z_dim)).astype('float32')
        loss = model.train_on_batch({'input': X_batch,
                                     'noise': eps,
                                     # target: the same images, flattened per example
                                     'out': X_batch.reshape(X_batch.shape[0], -1)})
        epoch_losses.append(loss)
        progbar.add(X_batch.shape[0], values=[("train loss", loss)])

    # Plot and log the mean loss over the epoch rather than the last minibatch only.
    epoch_loss = float(np.mean(epoch_losses))
    plotgen.add_value(epoch_loss)
    logger.info('epoch: {0} | train-loss: {1}'.format(e, epoch_loss))

    # TODO: evaluation on X_test (model.test_on_batch over X_test minibatches)
    # is currently disabled.

    # Animate the canvas sequence for (up to) 100 images of the last batch.
    rec = do_draw(X_batch[:100], eps[:100])
    rec = rec.reshape(rec.shape[0], rec.shape[1], -1)
    vg = video_grid2d(rec.transpose(1, 0, 2), filepath=gifgen(), rescale=True)
    display.clear_output(wait=True)

    send_plot.on_epoch_end()
    send_gif.on_epoch_end()
    # Every plotgen() call opens a new figure; close them once sent so one
    # figure per epoch does not pile up for the whole run of this cell.
    plt.close('all')

In [19]:
# Persist the trained graph's weights to the file 'draw_agnez'.
# NOTE(review): Keras save_weights defaults to overwrite=False, which may
# prompt before replacing an existing file - confirm for this Keras version.
model.save_weights(filepath='draw_agnez')

In [None]:
# Reconstruct the first n_show training images and animate DRAW's canvases.
n_show = 100
# float32 noise, matching what the training loop feeds to the 'noise' input.
eps = np.random.normal(0, 1, (n_show, n_steps, z_dim)).astype('float32')
rec = do_draw(X_train[:n_show], eps)
rec = rec.reshape(rec.shape[0], rec.shape[1], -1)
display.clear_output(wait=True)
# Write to the same path the Agnez GIF sender reads from.
vg = video_grid2d(rec.transpose(1, 0, 2), filepath=gifgen(), rescale=True)
display.display(vg)

In [None]:
    send_plot.on_epoch_end()
    send_gif.on_epoch_end()