In [43]:
import numpy as np

# Generate "test" data
z_dim = 32
a_dim = 3
input_dim = z_dim + a_dim

z_t0 = np.ones((1, z_dim))
z_t1 = np.ones((1, z_dim))
z_t2 = np.ones((1, z_dim))

a_t0 = np.ones((1, 3))
a_t1 = np.ones((1, 3))
a_t2 = np.ones((1, 3))

za_t0 = np.hstack([z_t0, a_t0])
za_t1 = np.hstack([z_t1, a_t1])

X = np.vstack([za_t0, za_t1])
Y = np.vstack([z_t1, z_t1])

X = np.expand_dims(X, axis=1)
Y = np.expand_dims(Y, axis=1)
X.shape

(2, 1, 35)

In [61]:
# https://worldmodels.github.io/
# https://arxiv.org/abs/1704.03477
# http://blog.otoro.net/2015/11/24/mixture-density-networks-with-tensorflow/

"""
TODO:
    - reset LSTM states between sequences
    - handle batches
    - teacher forcing

X - z_0 + a_0
Y - z_1
"""
    
import math

import tensorflow as tf
from keras.layers import Input, LSTM, Dense
from keras.models import Model
import keras.backend as K

lstm_units = 256
gaussian_mixtures = 5
mdn_units = gaussian_mixtures * 3 * z_dim
num_epochs = 20

def get_mixture_coef(output):
    """
    output should be size of num_mixtures * 3
    """
    pi, sigma, mu = tf.split(output, 3, axis=1)
    
    max_pi = tf.reduce_max(pi, 1, keepdims=True)
    pi = tf.subtract(pi, max_pi)
    pi = tf.exp(pi)

    normalize_pi = tf.reciprocal(tf.reduce_sum(pi, 1, keepdims=True))
    pi = tf.multiply(normalize_pi, pi) 
    
    sigma = tf.exp(sigma)
    
    return pi, sigma, mu
    
one_div_sqrt_2_pi = 1 / math.sqrt(2*math.pi)
def pdf(z, mu, sigma):
    z = tf.tile(z, [1, 1, gaussian_mixtures])
    result = tf.subtract(z, mu)
    result = tf.multiply(result,tf.reciprocal(sigma))
    result = -tf.square(result)/2
    return tf.multiply(tf.exp(result),tf.reciprocal(sigma))*one_div_sqrt_2_pi

def get_loss(y, output):
    z = tf.slice(y, [0, 0, 0], [-1, -1, 32])
    pi, sigma, mu = get_mixture_coef(output)
    result = pdf(z, mu, sigma)
    result = tf.multiply(result, pi)
    result = tf.reduce_sum(result, 1, keep_dims=True)
    result = -tf.log(result)
    return tf.reduce_mean(result) 
    

sess = tf.Session()
K.set_session(sess)
    
x = tf.placeholder(tf.float32, shape=(None, 1, input_dim), name='x')
y = tf.placeholder(tf.float32, shape=(None, 1, z_dim), name='y')
    
lstm = LSTM(lstm_units)(x)
outputs = Dense(mdn_units, name='rnn_mdn_out')(lstm)

loss = get_loss(y, outputs)
train_op = tf.train.AdamOptimizer().minimize(loss)

sess.run(tf.global_variables_initializer())

loss_hist = np.zeros(num_epochs)
for i in range(num_epochs):
    sess.run(train_op, feed_dict={x: X, y: Y})
    loss_hist[i] = sess.run(loss, feed_dict={x: X, y: Y})
    print('Epoch %d: %f' % (i, loss_hist[i]))

sess.close()

Epoch 0: 5.790408
Epoch 1: 5.769345
Epoch 2: 5.747546
Epoch 3: 5.723546
Epoch 4: 5.696173
Epoch 5: 5.664383
Epoch 6: 5.627168
Epoch 7: 5.583515
Epoch 8: 5.532384
Epoch 9: 5.472695
Epoch 10: 5.403333
Epoch 11: 5.323205
Epoch 12: 5.231420
Epoch 13: 5.127666
Epoch 14: 5.012881
Epoch 15: 4.890414
Epoch 16: 4.767797
Epoch 17: 4.658590
Epoch 18: 4.581739
Epoch 19: 4.551516
Epoch 20: 4.548862
Epoch 21: 4.516861
Epoch 22: 4.426867
Epoch 23: 4.298341
Epoch 24: 4.168011
Epoch 25: 4.065253
Epoch 26: 3.999393
Epoch 27: 3.958487
Epoch 28: 3.917975
Epoch 29: 3.850672
Epoch 30: 3.736504
Epoch 31: 3.576889
Epoch 32: 3.414780
Epoch 33: 3.323528
Epoch 34: 3.260284
Epoch 35: 3.094219
Epoch 36: 2.924816
Epoch 37: 2.848580
Epoch 38: 2.698646
Epoch 39: 2.535625
Epoch 40: 2.441678
Epoch 41: 2.246463
Epoch 42: 2.168273
Epoch 43: 2.033905
Epoch 44: 1.839057
Epoch 45: 1.815007
Epoch 46: 1.757649
Epoch 47: 2.070607
Epoch 48: 1.951027
Epoch 49: 1.841599
Epoch 50: 1.643091
Epoch 51: 1.641840
Epoch 52: 1.512681
Epo

Epoch 454: nan
Epoch 455: nan
Epoch 456: nan
Epoch 457: nan
Epoch 458: nan
Epoch 459: nan
Epoch 460: nan
Epoch 461: nan
Epoch 462: nan
Epoch 463: nan
Epoch 464: nan
Epoch 465: nan
Epoch 466: nan
Epoch 467: nan
Epoch 468: nan
Epoch 469: nan
Epoch 470: nan
Epoch 471: nan
Epoch 472: nan
Epoch 473: nan
Epoch 474: nan
Epoch 475: nan
Epoch 476: nan
Epoch 477: nan
Epoch 478: nan
Epoch 479: nan
Epoch 480: nan
Epoch 481: nan
Epoch 482: nan
Epoch 483: nan
Epoch 484: nan
Epoch 485: nan
Epoch 486: nan
Epoch 487: nan
Epoch 488: nan
Epoch 489: nan
Epoch 490: nan
Epoch 491: nan
Epoch 492: nan
Epoch 493: nan
Epoch 494: nan
Epoch 495: nan
Epoch 496: nan
Epoch 497: nan
Epoch 498: nan
Epoch 499: nan
Epoch 500: nan
Epoch 501: nan
Epoch 502: nan
Epoch 503: nan
Epoch 504: nan
Epoch 505: nan
Epoch 506: nan
Epoch 507: nan
Epoch 508: nan
Epoch 509: nan
Epoch 510: nan
Epoch 511: nan
Epoch 512: nan
Epoch 513: nan
Epoch 514: nan
Epoch 515: nan
Epoch 516: nan
Epoch 517: nan
Epoch 518: nan
Epoch 519: nan
Epoch 520: