In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from pylab import rcParams
rcParams['figure.figsize'] = 30,5

# Load Dataset

In [None]:
num_epoch = 100
batch_size = 50

# Each clip is split into an input sequence (frames 0..18) and a target
# sequence shifted one frame ahead (frames 1..19), i.e. next-frame prediction.
in_timesteps  = range(0,19)
out_timesteps = range(1,20)

# The first num_train clips form the training split, the rest the validation split.
num_train = 10000

data = np.load( "../datasets/shapes_images.npy" )
x_tr = data[:num_train, in_timesteps]
y_tr = data[:num_train, out_timesteps]

x_te = data[num_train:, in_timesteps]
y_te = data[num_train:, out_timesteps]

# Full clips (all timesteps) of each split.
tr_set = data[:num_train]
te_set = data[num_train:]

# Same output as the former Python 2 print statement, but also valid Python 3.
print(" ".join(str(shape) for shape in (np.shape(x_tr), np.shape(y_tr), np.shape(x_te), np.shape(y_te))))

# Define Parameters, Networks and Loss Functions

In [None]:
lstm_units = 1024      # LSTM width; its hidden state h (lstm_state[1]) feeds the prior, posterior and f_decoder
feature_vector = 1024  # width of the encoder()/f_decoder() features; must equal 4*4*64, the flattened conv shape encoder() produces and decoder() reshapes back
latent_dim = 256       # dimensionality of the latent variable z

# placeholders to hold each frame
# Both hold whole clips shaped (batch, timesteps, 64, 64, 3); frames are
# channels-last, as consumed by tf.nn.conv2d in encoder(). The model divides
# pixel values by 255.0 before use -- presumably raw 0..255 intensities (confirm).
x_ = tf.placeholder("float", shape= (None, len(in_timesteps),  64, 64, 3))
y_ = tf.placeholder("float", shape= (None, len(out_timesteps), 64, 64, 3))

# encoder
# Convolution filters, laid out [filter_height, filter_width, in_channels,
# out_channels] as tf.nn.conv2d expects; channels grow 3 -> 16 -> 32 -> 48 -> 64.
# NOTE(review): no initializer is passed, so tf.get_variable uses its default
# (glorot_uniform). That also applies to the biases below, which are therefore
# not zero-initialised -- confirm this is intended.
encoder_conv1_w = tf.get_variable("encoder_conv1_w", shape=[7, 7, 3, 16])
encoder_conv2_w = tf.get_variable("encoder_conv2_w", shape=[5, 5, 16, 32])
encoder_conv3_w = tf.get_variable("encoder_conv3_w", shape=[5, 5, 32, 48])
encoder_conv4_w = tf.get_variable("encoder_conv4_w", shape=[3, 3, 48, 64])

# One bias per output channel of the matching conv layer.
encoder_conv1_b = tf.get_variable("encoder_conv1_b", shape=[16])
encoder_conv2_b = tf.get_variable("encoder_conv2_b", shape=[32])
encoder_conv3_b = tf.get_variable("encoder_conv3_b", shape=[48])
encoder_conv4_b = tf.get_variable("encoder_conv4_b", shape=[64])

def encoder(x):
    """Encode a batch of frames into flat feature vectors.

    Four 'SAME'-padded, stride-2 convolutions, each followed by ReLU, halve the
    spatial size at every layer (64 -> 32 -> 16 -> 8 -> 4); the final 4x4x64
    activation is flattened.

    Args:
        x: frame batch, presumably (batch, 64, 64, 3); the callers in this
           notebook divide pixel values by 255.0 first.

    Returns:
        Tensor of shape (batch, 4*4*64).
    """
    conv_stack = (
        (encoder_conv1_w, encoder_conv1_b),
        (encoder_conv2_w, encoder_conv2_b),
        (encoder_conv3_w, encoder_conv3_b),
        (encoder_conv4_w, encoder_conv4_b),
    )
    hidden = x
    for kernel, bias in conv_stack:
        pre_activation = tf.nn.conv2d(input=hidden, filter=kernel, strides=[1, 2, 2, 1], padding='SAME') + bias
        hidden = tf.nn.relu(pre_activation)
    return tf.reshape(hidden, shape=[-1, 4*4*64])

# decoder
# Transposed-convolution filters for tf.nn.conv2d_transpose, laid out
# [filter_height, filter_width, out_channels, in_channels] -- the channel order
# is reversed relative to conv2d, so these mirror the encoder filters while
# channels shrink 64 -> 48 -> 32 -> 16 -> 3.
# NOTE(review): default tf.get_variable initializer (glorot_uniform) is used for
# both filters and biases -- confirm non-zero bias initialisation is intended.
decoder_conv1_w = tf.get_variable("decoder_conv1_w", shape=[3, 3, 48, 64])
decoder_conv2_w = tf.get_variable("decoder_conv2_w", shape=[5, 5, 32, 48])
decoder_conv3_w = tf.get_variable("decoder_conv3_w", shape=[5, 5, 16, 32])
decoder_conv4_w = tf.get_variable("decoder_conv4_w", shape=[7, 7, 3, 16])

# One bias per output channel of the matching transposed-conv layer.
decoder_conv1_b = tf.get_variable("decoder_conv1_b", shape=[48])
decoder_conv2_b = tf.get_variable("decoder_conv2_b", shape=[32])
decoder_conv3_b = tf.get_variable("decoder_conv3_b", shape=[16])
decoder_conv4_b = tf.get_variable("decoder_conv4_b", shape=[3])

def decoder(x):
    """Decode flat feature vectors back into frames.

    Mirrors the encoder with four stride-2 transposed convolutions
    (4x4x64 -> 8x8x48 -> 16x16x32 -> 32x32x16 -> 64x64x3). ReLU is applied
    between layers, and a sigmoid on the last layer so outputs lie in (0, 1).

    The batch dimension of every ``output_shape`` is read from the input at run
    time instead of the global ``batch_size``. The decoder therefore works for
    any batch size rather than only for exactly ``batch_size`` examples.

    Args:
        x: tensor reshapeable to (batch, 4, 4, 64), e.g. (batch, feature_vector).

    Returns:
        Tensor of shape (batch, 64, 64, 3) with values in (0, 1).
    """
    out = tf.reshape(x, shape=[-1, 4, 4, 64])
    dynamic_batch = tf.shape(out)[0]

    # (filter, bias, output spatial size, output channels) per transposed conv
    deconv_layers = (
        (decoder_conv1_w, decoder_conv1_b, 8, 48),
        (decoder_conv2_w, decoder_conv2_b, 16, 32),
        (decoder_conv3_w, decoder_conv3_b, 32, 16),
        (decoder_conv4_w, decoder_conv4_b, 64, 3),
    )
    last_layer = len(deconv_layers) - 1
    for depth, (kernel, bias, size, channels) in enumerate(deconv_layers):
        # tf.stack keeps the static (size, size, channels) part known to shape inference
        output_shape = tf.stack([dynamic_batch, size, size, channels])
        out = tf.nn.conv2d_transpose(out, filter=kernel, strides=[1, 2, 2, 1],
                                     output_shape=output_shape, padding='SAME') + bias
        out = tf.nn.sigmoid(out) if depth == last_layer else tf.nn.relu(out)
    return out

# f_posterior
# Two dense hidden layers whose input width is feature_vector + lstm_units (the
# model feeds the LSTM hidden state concatenated with the encoded target frame),
# followed by separate linear heads for the Gaussian mean and the pre-softplus
# standard deviation.
f_posterior_fc1_w = tf.get_variable("f_posterior_fc1_w", shape=[feature_vector+lstm_units, latent_dim])
f_posterior_fc2_w = tf.get_variable("f_posterior_fc2_w", shape=[latent_dim, latent_dim])

f_posterior_fc1_b = tf.get_variable("f_posterior_fc1_b", shape=[latent_dim])
f_posterior_fc2_b = tf.get_variable("f_posterior_fc2_b", shape=[latent_dim])

# mean head
f_posterior_mu_w = tf.get_variable("f_posterior_mu_w", shape=[latent_dim, latent_dim])
f_posterior_mu_b = tf.get_variable("f_posterior_mu_b", shape=[latent_dim])

# standard-deviation head (softplus is applied in f_posterior)
f_posterior_sigma_w = tf.get_variable("f_posterior_sigma_w", shape=[latent_dim, latent_dim])
f_posterior_sigma_b = tf.get_variable("f_posterior_sigma_b", shape=[latent_dim])

def f_posterior(out):
    """Posterior network: map features to a diagonal Gaussian over z.

    Args:
        out: (batch, feature_vector + lstm_units) tensor; in this notebook the
             concatenation of the LSTM hidden state and the encoded target frame.

    Returns:
        (mean, std): two (batch, latent_dim) tensors; std is kept positive by
        a softplus.
    """
    hidden = tf.nn.relu(tf.matmul(out, f_posterior_fc1_w) + f_posterior_fc1_b)
    hidden = tf.nn.relu(tf.matmul(hidden, f_posterior_fc2_w) + f_posterior_fc2_b)

    mean = tf.matmul(hidden, f_posterior_mu_w) + f_posterior_mu_b
    std_pre_activation = tf.matmul(hidden, f_posterior_sigma_w) + f_posterior_sigma_b
    return mean, tf.nn.softplus(std_pre_activation)

# f_decoder
# Two dense layers mapping [LSTM hidden state (lstm_units), f_z output
# (latent_dim)] to a feature_vector-wide code, which decoder() reshapes to
# 4x4x64 (hence feature_vector must equal 1024).
f_decoder_fc1_w = tf.get_variable("f_decoder_fc1_w", shape=[latent_dim+lstm_units, feature_vector])
f_decoder_fc2_w = tf.get_variable("f_decoder_fc2_w", shape=[feature_vector, feature_vector])

f_decoder_fc1_b = tf.get_variable("f_decoder_fc1_b", shape=[feature_vector])
f_decoder_fc2_b = tf.get_variable("f_decoder_fc2_b", shape=[feature_vector])

def f_decoder(out):
    """Project [LSTM hidden state, f_z(z)] to the feature vector consumed by decoder().

    Two dense layers, each followed by ReLU.

    Args:
        out: presumably (batch, lstm_units + latent_dim); the caller concatenates
             lstm_state[1] with the f_z output.

    Returns:
        (batch, feature_vector) tensor.
    """
    hidden = out
    for weight, bias in ((f_decoder_fc1_w, f_decoder_fc1_b),
                         (f_decoder_fc2_w, f_decoder_fc2_b)):
        hidden = tf.nn.relu(tf.matmul(hidden, weight) + bias)
    return hidden

# f_z
# Two latent_dim x latent_dim dense layers applied to the sampled latent z.
f_z_fc1_w = tf.get_variable("f_z_fc1_w", shape=[latent_dim, latent_dim])
f_z_fc2_w = tf.get_variable("f_z_fc2_w", shape=[latent_dim, latent_dim])

f_z_fc1_b = tf.get_variable("f_z_fc1_b", shape=[latent_dim])
f_z_fc2_b = tf.get_variable("f_z_fc2_b", shape=[latent_dim])

def f_z(out):
    """Feature extractor for the sampled latent z.

    Args:
        out: (batch, latent_dim) latent sample.

    Returns:
        (batch, latent_dim) tensor after two ReLU dense layers.
    """
    first_layer = tf.nn.relu(tf.matmul(out, f_z_fc1_w) + f_z_fc1_b)
    return tf.nn.relu(tf.matmul(first_layer, f_z_fc2_w) + f_z_fc2_b)

# f_prior
# Prior network parameters: LSTM hidden state (lstm_units) -> two dense hidden
# layers of width latent_dim -> separate mean and pre-softplus std heads.
f_prior_fc1_w = tf.get_variable("f_prior_fc1_w", shape=[lstm_units, latent_dim])
f_prior_fc2_w = tf.get_variable("f_prior_fc2_w", shape=[latent_dim, latent_dim])

f_prior_fc1_b = tf.get_variable("f_prior_fc1_b", shape=[latent_dim])
f_prior_fc2_b = tf.get_variable("f_prior_fc2_b", shape=[latent_dim])

# output heads: mean and (softplus-activated in f_prior) standard deviation
f_prior_mu_w = tf.get_variable("f_prior_mu_w", shape=[latent_dim, latent_dim])
f_prior_std_w = tf.get_variable("f_prior_std_w", shape=[latent_dim, latent_dim])

f_prior_mu_b = tf.get_variable("f_prior_mu_b", shape=[latent_dim])
f_prior_std_b = tf.get_variable("f_prior_std_b", shape=[latent_dim])

def f_prior(out):
    """Prior network: map the LSTM hidden state to a diagonal Gaussian over z.

    Args:
        out: (batch, lstm_units) tensor; in this notebook lstm_state[1].

    Returns:
        (mean, std): two (batch, latent_dim) tensors; std is kept positive by
        a softplus.
    """
    trunk = out
    for weight, bias in ((f_prior_fc1_w, f_prior_fc1_b), (f_prior_fc2_w, f_prior_fc2_b)):
        trunk = tf.nn.relu(tf.matmul(trunk, weight) + bias)

    prior_mean = tf.matmul(trunk, f_prior_mu_w) + f_prior_mu_b
    prior_std = tf.nn.softplus(tf.matmul(trunk, f_prior_std_w) + f_prior_std_b)
    return prior_mean, prior_std

def tf_kl_gaussgauss(mu_1, sigma_1, mu_2, sigma_2):
    """KL( N(mu_1, sigma_1^2) || N(mu_2, sigma_2^2) ) for diagonal Gaussians.

    Per dimension the divergence is
        log(sigma_2 / sigma_1) + (sigma_1^2 + (mu_1 - mu_2)^2) / (2 sigma_2^2) - 1/2,
    and the terms are summed over axis 1.

    Args:
        mu_1, sigma_1: mean and standard deviation of the first Gaussian.
        mu_2, sigma_2: mean and standard deviation of the second Gaussian.
            Presumably all (batch, latent_dim); sigmas must be > 0.

    Returns:
        Tensor with axis 1 reduced, e.g. (batch,) for 2-D inputs.
    """
    log_std_ratio = tf.log(sigma_2) - tf.log(sigma_1)
    scaled_sq_term = (sigma_1**2 + (mu_1 - mu_2)**2) / (2*((sigma_2)**2))
    return tf.reduce_sum(log_std_ratio + scaled_sq_term - 0.5, axis=1)

def cross_entropy(y_prediction, y, epsilon=1e-10):
    """Binary cross-entropy per example, summed over axes 1-3.

    Args:
        y_prediction: predicted pixel probabilities in [0, 1] (decoder() output),
            rank 4.
        y: target pixel values in [0, 1], same shape as ``y_prediction``.
        epsilon: small constant added inside both logs to keep them finite.

    Returns:
        Tensor of shape (batch,): negative log-likelihood summed over axes 1-3.
    """
    # (1 - y_prediction) must be parenthesised. In ``epsilon + 1 - y_prediction``
    # Python first folds epsilon + 1 into 1.0000000001, which rounds to exactly
    # 1.0 in float32, so the epsilon vanishes. A saturated sigmoid output of 1.0
    # then gives log(0) = -inf, and (1 - y) * -inf is NaN even where (1 - y) == 0.
    log_p = tf.log(epsilon + y_prediction)
    log_one_minus_p = tf.log(epsilon + (1 - y_prediction))
    prediction_loss = y * log_p + (1 - y) * log_one_minus_p
    return -tf.reduce_sum(prediction_loss, axis=[1, 2, 3])

def batch_data(source, target, batch_size):
    """Yield shuffled, index-aligned (source, target) mini-batches.

    Both arrays are permuted with the same random order (global numpy RNG) and
    then cut into consecutive slices of exactly ``batch_size`` rows. Trailing
    rows that do not fill a whole batch are dropped.

    Args:
        source: array indexable by an integer index array (e.g. x_tr).
        target: array with the same leading length as ``source`` (e.g. y_tr).
        batch_size: number of rows per yielded batch.

    Yields:
        (source_batch, target_batch) numpy arrays of ``batch_size`` rows each.
    """
    # Shuffle data
    order = np.random.permutation(np.arange(len(target)))
    source, target = source[order], target[order]

    full_batches = len(source) // batch_size
    for start in range(0, full_batches * batch_size, batch_size):
        stop = start + batch_size
        yield np.array(source[start:stop]), np.array(target[start:stop])

# lstm
# A single LSTMCell, called once per timestep in the model loop. With
# state_is_tuple=True its state is a (c, h) pair; the model reads
# lstm_state[1], the hidden state h.
# zero_state is built for the fixed batch_size, so every fed batch must contain
# exactly batch_size examples (batch_data only yields full batches).
lstm  = tf.nn.rnn_cell.LSTMCell(num_units = lstm_units, state_is_tuple=True)
lstm_state = lstm.zero_state(batch_size, tf.float32)

# Define Architecture

In [None]:
# Unroll the model over every input timestep. At step i:
#   - the prior over z comes from the current LSTM hidden state,
#   - the posterior additionally sees the encoded target frame y_[:, i],
#   - the target is reconstructed from [LSTM hidden state, f_z(z)],
#   - the LSTM then consumes [encoded input frame x_[:, i], f_z(z)].
# NOTE(review): the state used for prior/decoding has not yet consumed x_[:, i]
# (the transition happens afterwards), so y_[:, i] is predicted without seeing
# the current input frame directly -- confirm this is intended.
num_timesteps = len(in_timesteps)
assert num_timesteps == len(out_timesteps), "input and target sequences must have the same length"

loss_list = [None] * num_timesteps
kl_divergence_list = [None] * num_timesteps
reconstruction_loss_list = [None] * num_timesteps

for i in range(num_timesteps):

    # scale both frames by 1/255 (presumably raw 0..255 pixel values)
    x_frame = tf.divide(x=x_[:, i, :, :, :], y=255.0)
    y_frame = tf.divide(x=y_[:, i, :, :, :], y=255.0)

    # encode input image
    encoder_out = encoder(x_frame)

    # compute prior from the LSTM hidden state h (index 1 of the (c, h) tuple)
    f_prior_out_mu, f_prior_out_sigma = f_prior(lstm_state[1])

    # compute posterior from [h, encoded target frame]
    f_posterior_out_mu, f_posterior_out_sigma = f_posterior(
        tf.concat(values=(lstm_state[1], encoder(y_frame)), axis=1))

    # Sample from the posterior (reparameterisation). The noise has the full
    # (batch, latent_dim) shape of mu so each example gets its own draw; a single
    # [latent_dim] draw would be broadcast and shared by the whole batch.
    epsilon = tf.random_normal(shape=tf.shape(f_posterior_out_mu), mean=0.0, stddev=1.0)
    z = f_posterior_out_mu + f_posterior_out_sigma * epsilon
    f_z_out = f_z(z)

    # decode [lstm, latent information]
    f_decoder_out = f_decoder(tf.concat(values=(lstm_state[1], f_z_out), axis=1))
    y_hat = decoder(f_decoder_out)

    # lstm state transition on [encoded input frame, f_z(z)]
    lstm_out, lstm_state = lstm(inputs=tf.concat(values=(encoder_out, f_z_out), axis=1), state=lstm_state)

    # track divergence of current timestep (one value per example)
    kl_divergence_list[i] = tf_kl_gaussgauss(f_posterior_out_mu, f_posterior_out_sigma, f_prior_out_mu, f_prior_out_sigma)
    tf.summary.scalar("kl_divergence_loss_" + str(i), tf.reduce_mean(kl_divergence_list[i]))

    # track reconstruction loss of current timestep (one value per example)
    reconstruction_loss_list[i] = cross_entropy(y_hat, y_frame)
    tf.summary.scalar("reconstruction_loss_" + str(i), tf.reduce_mean(reconstruction_loss_list[i]))

    # track total loss of current timestep (batch mean)
    loss_list[i] = tf.reduce_mean(kl_divergence_list[i] + reconstruction_loss_list[i])
    tf.summary.scalar("total_loss_" + str(i), loss_list[i])
    
# Collect the per-timestep terms into single tensors:
#   loss                : (num_timesteps,)       batch-mean total loss per step
#   kl_divergence       : (num_timesteps, batch) per-example KL per step
#   reconstruction_loss : (num_timesteps, batch) per-example cross-entropy per step
loss = tf.stack(loss_list)
kl_divergence = tf.stack(kl_divergence_list)
reconstruction_loss = tf.stack(reconstruction_loss_list)

# `loss` is not a scalar: minimize() back-propagates ones through it, i.e. it
# optimises the SUM of the per-timestep losses (the summaries below log means).
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(loss)

# Scalar summaries: means across all timesteps (and examples).
tf.summary.scalar("total_loss_mean", tf.reduce_mean(loss))
tf.summary.scalar("kl_divergence_loss_mean", tf.reduce_mean(kl_divergence))
tf.summary.scalar("reconstruction_loss_mean", tf.reduce_mean(reconstruction_loss))
    
# train and validation counters
# Non-trainable step counters. Each increment op returns the new value, which
# the training loop passes to add_summary() as the summary step.
train_counter = tf.Variable(0, name="train_counter", trainable=False)
increment_train_counter = tf.assign(train_counter, train_counter+1)
validation_counter = tf.Variable(0, name="validation_counter", trainable=False)
# Increment from the validation counter's own value. Deriving it from
# train_counter (constant during a validation pass) would write every validation
# summary of an epoch at the same step.
increment_validation_counter = tf.assign(validation_counter, validation_counter+1)

# Begin Training

In [None]:
# Summaries
# merge_all() bundles every tf.summary.scalar created while building the model.
# The same merged op is evaluated for training and validation batches; the two
# writers send the results to separate log directories (each also logs the graph).
merged_summary_op = tf.summary.merge_all()
train_summary_writer = tf.summary.FileWriter('./train/', graph=tf.get_default_graph())
validation_summary_writer = tf.summary.FileWriter('./validation/',graph=tf.get_default_graph())

# initialize all variables
# (created after the optimizer and counters, so it also covers RMSProp's slot
# variables and the step counters)
init = tf.global_variables_initializer()

# to save variables (at most the 10 most recent checkpoints are kept)
saver = tf.train.Saver(max_to_keep=10)

# Start a new TF session
sess = tf.Session()

# Run the initializer
sess.run(init)

# Train for num_epoch epochs. Each epoch runs one optimisation pass over the
# training split and one evaluation-only pass over the validation split.
for epoch in range(1, num_epoch + 1):

    print('Epoch {}'.format(epoch))

    # Train: optimiser step + summaries; the counter value is the summary step.
    for x_tr_batch, y_tr_batch in batch_data(x_tr, y_tr, batch_size=batch_size):
        _, step, summary = sess.run([optimizer, increment_train_counter, merged_summary_op],
                                    feed_dict={x_: x_tr_batch, y_: y_tr_batch})
        train_summary_writer.add_summary(summary, step)

    # Validate: the optimizer op is not fetched, so weights are not updated.
    for x_te_batch, y_te_batch in batch_data(x_te, y_te, batch_size=batch_size):
        step, summary = sess.run([increment_validation_counter, merged_summary_op],
                                 feed_dict={x_: x_te_batch, y_: y_te_batch})
        validation_summary_writer.add_summary(summary, step)

    # Make this epoch's summaries visible to TensorBoard immediately.
    train_summary_writer.flush()
    validation_summary_writer.flush()

    # Save model every 10 epochs; the checkpoint prefix becomes 'epoch-<epoch>'.
    if epoch % 10 == 0:
        save_path = saver.save(sess, 'epoch', global_step=epoch)