# DRAW

This code uses `seya.layers.draw`. It has been tested successfully on the binarized MNIST data.
Here we also show how to train the model on CIFAR-10, although we didn't have the patience to
wait for the model to converge. If you have used this code successfully on any dataset other than MNIST,
please let us know (post an issue or send us a message on twitter/edersantana).

This was originally tested under:  
`pip install git+https://github.com/fchollet/keras.git@ef43a271eeb25a59383be66a7079f77d226e0c3b`

In [26]:
%matplotlib inline
from __future__ import absolute_import
from __future__ import print_function
import logging
import numpy as np
import matplotlib.pyplot as plt

from theano import tensor, function

#from keras.datasets import mnist # this is not binary MNIST though
from keras.datasets import cifar10

# Binary MNIST use github.com/mila-udem/fuel
#from fuel.datasets.binarized_mnist import BinarizedMNIST

from keras.models import Graph, Sequential
from keras.layers.core import Dense, Dropout, Activation, Flatten
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.utils import np_utils, generic_utils
from keras.optimizers import Adam
from keras.initializations import normal


from seya.layers.draw import DRAW
from seya.layers.draw2 import DRAW2
from seya.layers.base import Lambda
from agnez import img_grid

from IPython import display

In [3]:
# Mirror log records to a file so that progress is still recorded
# even if the notebook is closed while training is running.
formatter = logging.Formatter('%(asctime)s %(levelname)s %(message)s')

hdlr = logging.FileHandler('./draw.log')
hdlr.setFormatter(formatter)

# Root logger: records at INFO level and above reach the file handler.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
logger.addHandler(hdlr)

In [17]:
# Training schedule used by the training loop.
batch_size = 100
nb_epoch = 100
# Size of the second axis of the noise tensor fed to the model.
n_steps = 64

# Settings handed to the DRAW layer constructor (values used for CIFAR-10).
h_dim = 2 * 256
z_dim = 200
N_enc = N_dec = 5
input_shape = (3, 32, 32)

# Values for MNIST
# h_dim = 256
# z_dim = 100
# N_enc = 2
# N_dec = 5
# input_shape = (1, 28, 28)

# Load CIFAR-10 and divide the pixel values of BOTH splits by 255 so they
# land in [0, 1]. Previously only X_train was rescaled, so enabling the
# commented-out test loop would have fed the model unscaled inputs.
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
X_train = X_train / 255. # range [0, 1]
X_test = X_test / 255. # range [0, 1], kept consistent with X_train
# Load Binary MNIST
#data_train = BinarizedMNIST(which_sets=['train'], sources=['features'])
#data_test = BinarizedMNIST(which_sets=['test'], sources=['features'])
#X_train = data_train.get_data(request=slice(0, 50000))[0].astype('float32')
#X_test = data_test.get_data(request=slice(0, 100))[0].astype('float32')

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')

X_train shape: (50000, 3, 32, 32)
50000 train samples


In [18]:
def myinit(shape):
    """Return initial weights of the given shape from `normal` with scale=0.01."""
    weights = normal(shape, scale=0.01)
    return weights

# Build the DRAW layer from the hyperparameters above. `init` and
# `inner_init` both use myinit, and the layer is asked to return sequences.
draw = DRAW(
    input_shape=input_shape,
    h_dim=h_dim,
    z_dim=z_dim,
    N_enc=N_enc,
    N_dec=N_dec,
    inner_rnn='gru',
    return_sequences=True,
    init=myinit,
    inner_init=myinit,
)

# Train model

In [20]:
# Use this cost function for MNIST
def bce(y_true, y_pred):
    """Binary cross-entropy of `y_pred` against `y_true`, summed over the last axis.

    Predictions are first clipped to [1e-7, 1 - 1e-7]. The per-sample sums
    are returned as-is (no averaging is applied here).
    """
    tiny = 1e-7
    clipped = tensor.clip(y_pred, tiny, 1.0 - tiny)
    elementwise = tensor.nnet.binary_crossentropy(clipped, y_true)
    return elementwise.sum(axis=-1)

def renorm_mse(y_true, y_pred):
    """Return half the squared error between target and prediction, summed over the last axis."""
    residual = y_true - y_pred
    return (0.5 * residual ** 2).sum(axis=-1)


In [29]:
def myreshape(x):
    """Flatten the final step of the first element of `x` and apply a sigmoid.

    Takes `x[0]`, keeps index -1 along axis 1, reshapes every sample to
    `np.prod(input_shape)` values and passes the result through a sigmoid.
    """
    last_step = x[0][:, -1, :, :, :]
    n_samples = last_step.shape[0]
    flat = last_step.reshape((n_samples, np.prod(input_shape)))
    # The sigmoid is for MNIST; return `flat` directly to skip it.
    return tensor.nnet.sigmoid(flat)

# Assemble the Graph model: the 'input' and 'noise' inputs feed the 'draw'
# node, and the 'out' node applies myreshape to the output of 'draw'.
model = Graph()

# Inputs are registered in the same order as before: 'input', then 'noise'.
for input_name, input_ndim in (('input', 4), ('noise', 3)):
    model.add_input(name=input_name, ndim=input_ndim)

model.add_node(draw, name='draw', inputs=['input', 'noise'], merge_mode='join')
model.add_node(
    Lambda(lambda x: myreshape(x)),
    name='out',
    input='draw',
    create_output=True,
)

# model.get_output()  # make sure the regularizer is generated
# model.regularizers += draw.regularizers  # add reparametrization trick regularizer to the list
print(model.regularizers)

[<seya.regularizers.SimpleCost object at 0x7f73e26fa510>]


In [30]:
# Adam optimizer (lr=3e-4, clipnorm=10); the 'out' output is trained with bce.
adam = Adam(clipnorm=10, lr=3e-4)
model.compile(optimizer=adam, loss={'out': bce})

In [31]:
# Compile two helper Theano functions on top of the outputs of the 'draw' node.
X = model.get_input()
Y = model.nodes['draw'].get_output()
graph_inputs = X.values()

# do_draw: sigmoid of the first output of the 'draw' node.
do_draw = function(graph_inputs, tensor.nnet.sigmoid(Y[0]), allow_input_downcast=True)

# get_kl: second output of the 'draw' node, returned unchanged.
get_kl = function(graph_inputs, Y[1], allow_input_downcast=True)

In [None]:
import os
subdir = 'pics'
def savegifs(X_train):
    """Save one PNG per drawing step for up to 100 samples and build a GIF.

    The first 100 rows of `X_train` are passed through `do_draw` together
    with an all-zero noise tensor. For each of the first `n_steps - 1` steps
    a 10x10 image grid is written to `<subdir>/time-XXX.png`, after which
    ImageMagick's `convert` is invoked to assemble `<subdir>/sequence.gif`.

    Parameters
    ----------
    X_train : array
        Input images; only the first 100 entries are used.
    """
    # adapted from github.com/jbornschein/draw
    # plt.savefig does not create missing directories, so make sure it exists.
    if not os.path.isdir(subdir):
        os.makedirs(subdir)

    batch = X_train[:100]
    # Size the zero noise by the actual batch so both inputs always agree.
    samples = do_draw(batch, np.zeros((batch.shape[0], n_steps, z_dim)))
    for i in range(n_steps-1):
        fig = plt.figure()
        plt.imshow(img_grid(samples[:,i,:,:,:], (10, 10)))
        plt.axis('off')
        plt.savefig("{0}/time-{1:03d}.png".format(subdir, i))
        # Close the figure: this runs every epoch and open figures would
        # otherwise accumulate in memory.
        plt.close(fig)

    # NOTE(review): the command also references <subdir>/sample.png, which
    # this function does not create -- confirm it exists before relying on it.
    os.system("convert -delay 5 {0}/time-*.png -delay 300 {0}/sample.png {0}/sequence.gif".format(subdir))

In [None]:
# Main training loop: one pass over X_train per epoch in mini-batches of
# `batch_size`, then log the get_kl value and regenerate the GIF frames.
for e in range(nb_epoch):
    print('-'*40)
    print('Epoch', e)
    print('-'*40)
    print("Training...")
    progbar = generic_utils.Progbar(X_train.shape[0])
    # Floor division keeps the batch count an int on Python 3 as well; a
    # trailing partial batch is skipped, exactly as before.
    for i in range(X_train.shape[0] // batch_size):
        start = i * batch_size
        stop = start + batch_size
        X_batch = X_train[start:stop]
        # Fresh standard-normal noise for every mini-batch.
        eps = np.random.normal(0, 1, (X_batch.shape[0], n_steps, z_dim))
        loss = model.train_on_batch({'input': X_batch,
                                     'noise': eps.astype('float32'),
                                     'out': X_batch.reshape(batch_size, -1)})
        progbar.add(X_batch.shape[0], values=[("train loss", loss)])

    # Evaluated on the last mini-batch of the epoch only.
    kl_train = get_kl(X_batch, eps)

    # No testing, we only want to see the DRAWs :D
    #print("")
    #print("Testing...")
    # test time!
    #progbar = generic_utils.Progbar(X_test.shape[0])
    #for i in range(X_test.shape[0]/batch_size):
    #    s = i * batch_size
    #    l = (i+1) * batch_size
    #    X_batch = X_test[s:l]
    #    eps = np.random.normal(0, 1, (X_batch.shape[0], n_steps+1, z_dim))
    #    score = model.test_on_batch({'input': X_batch,
    #                                 'noise': eps,
    #                                 'out': X_batch.reshape(batch_size, -1)})
    #    progbar.add(X_batch.shape[0], values=[("test loss", score)])

    logging.info('epoch: {0} | train-kl: {1}'.format(e, kl_train))

    #rec = do_draw(X_batch[:100], eps[:100])
    #rec = rec.reshape(rec.shape[0], rec.shape[1], -1)
    display.clear_output(wait=True)
    #vg = video_grid(rec.transpose(1, 0, 2), ani_path='draw.gif', rescale=True)
    #display.display(vg)
    savegifs(X_train[:100])

----------------------------------------
Epoch 0
----------------------------------------
Training...
  100/50000 [..............................] - ETA: 31399s - train loss: 2364.4294

<img src='main_img.gif'>

The figure above was generated with a model only a few epochs old.