In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

from pylab import rcParams
rcParams['figure.figsize'] = 30,5

num_epoch = 100
batch_size = 50

in_timesteps  = range(0,19)
out_timesteps = range(1,20)

data = np.load( "../../datasets/shapes_images.npy" )
x_tr = data[:9000,in_timesteps]
y_tr = data[:9000,out_timesteps]

x_te = data[9000:,in_timesteps]
y_te = data[9000:,out_timesteps]

tr_set = data[0:9000, :, :, :]
te_set = data[9000:, :, :, :]

print np.shape(x_tr), np.shape(y_tr), np.shape(x_te), np.shape(y_te)

(9000, 19, 64, 64, 3) (9000, 19, 64, 64, 3) (1000, 19, 64, 64, 3) (1000, 19, 64, 64, 3)


In [2]:
# Layer widths used by the fully connected / recurrent parts of the model.
lstm_units = 1024
feature_vector = 1024
latent_dim = 1024

# Graph inputs: input and target sequences of 64x64x3 frames, with a free
# batch dimension and one entry per timestep on axis 1.
x_ = tf.placeholder(tf.float32, shape=(None, len(in_timesteps), 64, 64, 3))
y_ = tf.placeholder(tf.float32, shape=(None, len(out_timesteps), 64, 64, 3))
# Multiplier applied to the KL term of the loss; fed at run time.
kla = tf.placeholder(tf.float32)

# encoder
# Convolution filters for the four encoder layers. tf.nn.conv2d expects
# filters laid out as [height, width, in_channels, out_channels], so the
# channel progression is 3 -> 16 -> 32 -> 48 -> 64.
# No initializer is passed, so tf.get_variable's default initializer applies.
encoder_conv1_w = tf.get_variable("encoder_conv1_w", shape=[7, 7, 3, 16])
encoder_conv2_w = tf.get_variable("encoder_conv2_w", shape=[5, 5, 16, 32])
encoder_conv3_w = tf.get_variable("encoder_conv3_w", shape=[5, 5, 32, 48])
encoder_conv4_w = tf.get_variable("encoder_conv4_w", shape=[3, 3, 48, 64])

# One bias per output channel of the matching convolution.
encoder_conv1_b = tf.get_variable("encoder_conv1_b", shape=[16])
encoder_conv2_b = tf.get_variable("encoder_conv2_b", shape=[32])
encoder_conv3_b = tf.get_variable("encoder_conv3_b", shape=[48])
encoder_conv4_b = tf.get_variable("encoder_conv4_b", shape=[64])

def encoder(x):
    """Encode a batch of frames into flat feature vectors.

    Applies four stride-2 'SAME' convolutions, each followed by a bias add
    and a ReLU, then flattens the result to rows of 4*4*64 values (the
    spatial size reached when the input frames are 64x64).

    Args:
        x: 4-D image tensor accepted by tf.nn.conv2d, with 3 channels.

    Returns:
        2-D tensor of shape [-1, 4*4*64].
    """
    conv_layers = [
        (encoder_conv1_w, encoder_conv1_b),
        (encoder_conv2_w, encoder_conv2_b),
        (encoder_conv3_w, encoder_conv3_b),
        (encoder_conv4_w, encoder_conv4_b),
    ]
    features = x
    for kernel, bias in conv_layers:
        features = tf.nn.conv2d(input=features, filter=kernel, strides=[1, 2, 2, 1], padding='SAME') + bias
        features = tf.nn.relu(features)
    return tf.reshape(features, shape=[-1, 4*4*64])

# decoder
# Filters for the four transposed convolutions. tf.nn.conv2d_transpose
# expects [height, width, output_channels, in_channels], so the channel
# progression is 64 -> 48 -> 32 -> 16 -> 3 (mirroring the encoder filters).
# No initializer is passed, so tf.get_variable's default initializer applies.
decoder_conv1_w = tf.get_variable("decoder_conv1_w", shape=[3, 3, 48, 64])
decoder_conv2_w = tf.get_variable("decoder_conv2_w", shape=[5, 5, 32, 48])
decoder_conv3_w = tf.get_variable("decoder_conv3_w", shape=[5, 5, 16, 32])
decoder_conv4_w = tf.get_variable("decoder_conv4_w", shape=[7, 7, 3, 16])

# One bias per output channel of the matching transposed convolution.
decoder_conv1_b = tf.get_variable("decoder_conv1_b", shape=[48])
decoder_conv2_b = tf.get_variable("decoder_conv2_b", shape=[32])
decoder_conv3_b = tf.get_variable("decoder_conv3_b", shape=[16])
decoder_conv4_b = tf.get_variable("decoder_conv4_b", shape=[3])

def decoder(x):
    """Decode flat feature vectors back into 64x64x3 images.

    The input is reshaped to [-1, 4, 4, 64] and passed through four stride-2
    transposed convolutions (8x8x48, 16x16x32, 32x32x16, 64x64x3). The first
    three are followed by ReLU, the last by a sigmoid.

    Note: every output_shape uses the module-level `batch_size`, so the
    batch dimension of `x` has to equal `batch_size`.

    Args:
        x: tensor that can be reshaped to [-1, 4, 4, 64].

    Returns:
        Tensor of shape [batch_size, 64, 64, 3] with values in (0, 1).
    """
    # (filter, bias, output height/width, output channels) for each layer.
    upsampling = [
        (decoder_conv1_w, decoder_conv1_b, 8, 48),
        (decoder_conv2_w, decoder_conv2_b, 16, 32),
        (decoder_conv3_w, decoder_conv3_b, 32, 16),
        (decoder_conv4_w, decoder_conv4_b, 64, 3),
    ]
    final_idx = len(upsampling) - 1

    image = tf.reshape(x, shape=[-1, 4, 4, 64])
    for idx, (kernel, bias, side, channels) in enumerate(upsampling):
        image = tf.nn.conv2d_transpose(
            image,
            filter=kernel,
            strides=[1, 2, 2, 1],
            output_shape=[batch_size, side, side, channels],
            padding='SAME') + bias
        if idx == final_idx:
            image = tf.nn.sigmoid(image)
        else:
            image = tf.nn.relu(image)
    return image

# phi_enc
# Parameters of the approximate-posterior network: two fully connected layers
# followed by separate linear heads for the mean and the (pre-softplus)
# standard deviation. The first layer takes feature_vector + lstm_units
# inputs; all later layers are latent_dim wide.
# No initializer is passed, so tf.get_variable's default initializer applies.
phi_enc_fc1_w = tf.get_variable("phi_enc_fc1_w", shape=[feature_vector+lstm_units, latent_dim])
phi_enc_fc2_w = tf.get_variable("phi_enc_fc2_w", shape=[latent_dim, latent_dim])

phi_enc_fc1_b = tf.get_variable("phi_enc_fc1_b", shape=[latent_dim])
phi_enc_fc2_b = tf.get_variable("phi_enc_fc2_b", shape=[latent_dim])

# Mean head.
phi_enc_mu_w = tf.get_variable("phi_enc_mu_w", shape=[latent_dim, latent_dim])
phi_enc_mu_b = tf.get_variable("phi_enc_mu_b", shape=[latent_dim])

# Standard-deviation head (softplus is applied inside phi_enc).
phi_enc_sigma_w = tf.get_variable("phi_enc_sigma_w", shape=[latent_dim, latent_dim])
phi_enc_sigma_b = tf.get_variable("phi_enc_sigma_b", shape=[latent_dim])

def phi_enc(out):
    """Map encoder inputs to the mean and std of the approximate posterior.

    Two ReLU fully connected layers feed two heads: a linear head for the
    mean and a softplus head that keeps the standard deviation non-negative.

    Args:
        out: 2-D tensor with feature_vector + lstm_units columns.

    Returns:
        (mean, std) tensors, each with latent_dim columns.
    """
    hidden = tf.nn.relu(tf.matmul(out, phi_enc_fc1_w) + phi_enc_fc1_b)
    hidden = tf.nn.relu(tf.matmul(hidden, phi_enc_fc2_w) + phi_enc_fc2_b)

    mean = tf.matmul(hidden, phi_enc_mu_w) + phi_enc_mu_b
    pre_std = tf.matmul(hidden, phi_enc_sigma_w) + phi_enc_sigma_b

    return mean, tf.nn.softplus(pre_std)

# phi_decoder
# Parameters of the two fully connected layers that turn a
# (latent_dim + lstm_units)-wide input into a feature_vector-wide vector.
# No initializer is passed, so tf.get_variable's default initializer applies.
phi_dec_fc1_w = tf.get_variable("phi_dec_fc1_w", shape=[latent_dim+lstm_units, feature_vector])
phi_dec_fc2_w = tf.get_variable("phi_dec_fc2_w", shape=[feature_vector, feature_vector])

phi_dec_fc1_b = tf.get_variable("phi_dec_fc1_b", shape=[feature_vector])
phi_dec_fc2_b = tf.get_variable("phi_dec_fc2_b", shape=[feature_vector])

def phi_dec(out):
    """Project the decoder input through two ReLU fully connected layers.

    Args:
        out: 2-D tensor with latent_dim + lstm_units columns.

    Returns:
        2-D tensor with feature_vector columns (non-negative, post-ReLU).
    """
    first = tf.nn.relu(tf.matmul(out, phi_dec_fc1_w) + phi_dec_fc1_b)
    second = tf.nn.relu(tf.matmul(first, phi_dec_fc2_w) + phi_dec_fc2_b)
    return second

# phi_z
# Parameters of the two latent_dim -> latent_dim fully connected layers
# applied to the sampled latent vector.
# No initializer is passed, so tf.get_variable's default initializer applies.
phi_z_fc1_w = tf.get_variable("phi_z_fc1_w", shape=[latent_dim, latent_dim])
phi_z_fc2_w = tf.get_variable("phi_z_fc2_w", shape=[latent_dim, latent_dim])

phi_z_fc1_b = tf.get_variable("phi_z_fc1_b", shape=[latent_dim])
phi_z_fc2_b = tf.get_variable("phi_z_fc2_b", shape=[latent_dim])

def phi_z(out):
    """Transform a latent sample with two ReLU fully connected layers.

    Args:
        out: 2-D tensor with latent_dim columns.

    Returns:
        2-D tensor with latent_dim columns (non-negative, post-ReLU).
    """
    first = tf.nn.relu(tf.matmul(out, phi_z_fc1_w) + phi_z_fc1_b)
    second = tf.nn.relu(tf.matmul(first, phi_z_fc2_w) + phi_z_fc2_b)
    return second

# phi_prior
# Parameters of the prior network: two fully connected layers
# (lstm_units -> latent_dim -> latent_dim) followed by separate linear heads
# for the mean and the (pre-softplus) standard deviation.
# No initializer is passed, so tf.get_variable's default initializer applies.
phi_prior_fc1_w = tf.get_variable("phi_prior_fc1_w", shape=[lstm_units, latent_dim])
phi_prior_fc2_w = tf.get_variable("phi_prior_fc2_w", shape=[latent_dim, latent_dim])

phi_prior_fc1_b = tf.get_variable("phi_prior_fc1_b", shape=[latent_dim])
phi_prior_fc2_b = tf.get_variable("phi_prior_fc2_b", shape=[latent_dim])

# Weights of the mean and standard-deviation heads.
phi_prior_mu_w = tf.get_variable("phi_prior_mu_w", shape=[latent_dim, latent_dim])
phi_prior_std_w = tf.get_variable("phi_prior_std_w", shape=[latent_dim, latent_dim])

# Biases of the mean and standard-deviation heads.
phi_prior_mu_b = tf.get_variable("phi_prior_mu_b", shape=[latent_dim])
phi_prior_std_b = tf.get_variable("phi_prior_std_b", shape=[latent_dim])

def phi_prior(out):
    """Map a recurrent state to the mean and std of the prior over the latent.

    Two ReLU fully connected layers feed two heads: a linear head for the
    mean and a softplus head that keeps the standard deviation non-negative.

    Args:
        out: 2-D tensor with lstm_units columns.

    Returns:
        (mean, std) tensors, each with latent_dim columns.
    """
    hidden = tf.nn.relu(tf.matmul(out, phi_prior_fc1_w) + phi_prior_fc1_b)
    hidden = tf.nn.relu(tf.matmul(hidden, phi_prior_fc2_w) + phi_prior_fc2_b)

    mean = tf.matmul(hidden, phi_prior_mu_w) + phi_prior_mu_b
    pre_std = tf.matmul(hidden, phi_prior_std_w) + phi_prior_std_b

    return mean, tf.nn.softplus(pre_std)

def _kl_gaussgauss_elementwise(mu_1, sigma_1, mu_2, sigma_2):
    """Element-wise KL(N(mu_1, sigma_1^2) || N(mu_2, sigma_2^2)) for diagonal Gaussians."""
    log_std_ratio = tf.log(sigma_2) - tf.log(sigma_1)
    scaled_moments = (sigma_1**2 + (mu_1 - mu_2)**2) / (2*((sigma_2)**2))
    return log_std_ratio + scaled_moments - 0.5

def tf_kl_gaussgauss(mu_1, sigma_1, mu_2, sigma_2):
    """KL divergence between two diagonal Gaussians, summed over axis 1.

    Args:
        mu_1, sigma_1: mean and standard deviation of the first Gaussian.
        mu_2, sigma_2: mean and standard deviation of the second Gaussian.

    Returns:
        Tensor holding one KL value per row (axis 1 is reduced by summation).
    """
    return tf.reduce_sum(_kl_gaussgauss_elementwise(mu_1, sigma_1, mu_2, sigma_2), axis=1)

def tf_kl_gaussgauss2(mu_1, sigma_1, mu_2, sigma_2):
    """Same element-wise KL as tf_kl_gaussgauss, but averaged over axis 0.

    Returns:
        Tensor holding one mean KL value per column (axis 0 is reduced by mean).
    """
    return tf.reduce_mean(_kl_gaussgauss_elementwise(mu_1, sigma_1, mu_2, sigma_2), axis=0)

def cross_entropy(y_prediction, y):
    """Per-sample binary cross-entropy between predicted and target images.

    Args:
        y_prediction: predicted probabilities in [0, 1] (4-D tensor).
        y: target values in [0, 1], same shape as `y_prediction`.

    Returns:
        Tensor with one loss value per sample; axes 1, 2 and 3 are summed.
    """
    # An additive epsilon cannot protect log(1 - p): in float32, 1 + 1e-10
    # rounds to exactly 1, so a saturated prediction (p == 1) would still give
    # log(0) = -inf and an inf/NaN loss. Clipping the probabilities keeps both
    # logarithms finite.
    eps = 1e-7
    clipped = tf.clip_by_value(y_prediction, eps, 1.0 - eps)
    log_likelihood = y * tf.log(clipped) + (1 - y) * tf.log(1 - clipped)
    return -tf.reduce_sum(log_likelihood, axis=[1, 2, 3])

def batch_data(source, target, batch_size):
    """Yield shuffled, aligned (source, target) mini-batches.

    Both arrays are shuffled with the same random permutation, so row k of a
    source batch always corresponds to row k of the target batch. Only full
    batches are produced: trailing rows that do not fill a batch are dropped.

    Args:
        source: array indexable by an integer index array along axis 0.
        target: array with the same length as `source`.
        batch_size: number of rows per yielded batch.

    Yields:
        (source_batch, target_batch) tuples of NumPy arrays.
    """
    # One permutation shared by both arrays keeps the pairs aligned.
    order = np.random.permutation(np.arange(len(target)))
    shuffled_source, shuffled_target = source[order], target[order]

    full_batches = len(source) // batch_size
    for begin in range(0, full_batches * batch_size, batch_size):
        end = begin + batch_size
        yield np.array(shuffled_source[begin:end]), np.array(shuffled_target[begin:end])

# Recurrent core: a single LSTM cell plus its all-zero initial state, built
# for batches of exactly `batch_size` sequences.
lstm = tf.nn.rnn_cell.LSTMCell(lstm_units, state_is_tuple=True)
lstm_state = lstm.zero_state(batch_size=batch_size, dtype=tf.float32)

In [3]:
# Unroll the model over time: one prior / posterior / decoder / LSTM step is
# added to the graph per input timestep, and the per-step losses are collected.
num_steps = len(in_timesteps)

loss_list = [None]*num_steps
reconstruction_loss_list = [None]*num_steps
kl_divergence_list = [None]*num_steps
kl_divergence_list2 = [None]*num_steps
y_hat_list = []

for i in range(0, num_steps):

    # Divide by 255 (presumably mapping 8-bit pixel values into [0, 1]).
    x_frame = tf.divide(x=x_[:,i,:,:,:], y=255.0)
    y_frame = tf.divide(x=y_[:,i,:,:,:], y=255.0)

    phi_x_out = encoder(x_frame)

    # Prior over the latent, conditioned on lstm_state[1] only.
    phi_prior_out_mu, phi_prior_out_sigma = phi_prior(lstm_state[1])

    # Approximate posterior, conditioned on lstm_state[1] and the encoded target frame.
    enc_out_mu, enc_out_sigma = phi_enc(tf.concat(values=(lstm_state[1], encoder(y_frame)), axis=1))

    # Reparameterisation trick. The noise is drawn with the full
    # [batch, latent_dim] shape so each sample gets independent noise; a
    # [latent_dim]-shaped draw would be broadcast and shared by the whole batch.
    noise = tf.random_normal(shape=tf.shape(enc_out_mu), mean=0.0, stddev=1.0)
    z = enc_out_mu + enc_out_sigma * noise

    phi_z_out = phi_z(z)

    # Predict the target frame from lstm_state[1] and the latent features.
    phi_dec_out = phi_dec(tf.concat(values=(lstm_state[1], phi_z_out), axis=1))
    y_hat = decoder(phi_dec_out)
    y_hat_list.append(y_hat)

    # Advance the recurrence with the input-frame features and latent features.
    lstm_out, lstm_state = lstm(inputs = tf.concat(values=(phi_x_out, phi_z_out), axis=1), state = lstm_state)

    # Per-sample KL (summed over latent dims) and per-dimension KL (batch mean).
    kl_divergence_list[i] = tf_kl_gaussgauss(enc_out_mu, enc_out_sigma, phi_prior_out_mu, phi_prior_out_sigma)
    kl_divergence_list2[i] = tf_kl_gaussgauss2(enc_out_mu, enc_out_sigma, phi_prior_out_mu, phi_prior_out_sigma)
    tf.summary.scalar("kl_divergence_loss_" + str(i), tf.reduce_mean(kl_divergence_list[i]))

    reconstruction_loss_list[i] = cross_entropy(y_hat, y_frame)
    tf.summary.scalar("reconstruction_loss_" + str(i), tf.reduce_mean(reconstruction_loss_list[i]))

    # Step loss: batch mean of (KL weight * KL + reconstruction error).
    loss_list[i] = tf.reduce_mean(kla * kl_divergence_list[i] + reconstruction_loss_list[i])
    tf.summary.scalar("total_loss_" + str(i), loss_list[i])

# Loss
# Stack the per-timestep tensors so axis 0 indexes the timestep.
loss = tf.stack(loss_list)
kl_divergence = tf.stack(kl_divergence_list)
kl_divergence2 = tf.stack(kl_divergence_list2)
reconstruction_loss = tf.stack(reconstruction_loss_list)
# Predictions are stacked time-major, then transposed to [batch, time, height, width, channels].
y_hat_out = tf.transpose(tf.stack(y_hat_list), [1, 0, 2, 3, 4])

tf.summary.scalar("total_loss_mean", tf.reduce_mean(loss))
tf.summary.scalar("kl_divergence_loss_mean", tf.reduce_mean(kl_divergence))
tf.summary.scalar("reconstruction_loss_mean", tf.reduce_mean(reconstruction_loss))

# `loss` is a vector with one entry per timestep; gradients of a non-scalar
# tensor are summed over its elements, so this minimizes the sum of step losses.
optimizer = tf.train.RMSPropOptimizer(learning_rate=0.001).minimize(loss)

# Step counters used as the global step of the training / validation summaries.
train_counter = tf.Variable(0, name="train_counter", trainable=False)
increment_train_counter = tf.assign(train_counter, train_counter+1)

# The validation counter must advance from its own value: assigning
# train_counter+1 would give every validation batch of an epoch the same step.
validation_counter = tf.Variable(0, name="validation_counter", trainable=False)
increment_validation_counter = tf.assign(validation_counter, validation_counter+1)

In [4]:
# Summaries: a single merged op, written to separate event directories for
# the training and validation passes.
merged_summary_op = tf.summary.merge_all()
default_graph = tf.get_default_graph()
train_summary_writer = tf.summary.FileWriter('./train/', graph=default_graph)
validation_summary_writer = tf.summary.FileWriter('./validation/', graph=default_graph)

# Op that initializes every variable in the graph.
init = tf.global_variables_initializer()

# Checkpoint writer that retains up to 100 checkpoints.
saver = tf.train.Saver(max_to_keep=100)

# Open a new TF session and run the initializer in it.
sess = tf.Session()
sess.run(init)

# KL values recorded once per epoch for the training and test data.
kl_tr_all, kl_te_all = [], []

# Starting value of the KL weight fed through the `kla` placeholder.
weight = 0.2

# Train
for i in range(1, 101):
    
    print 'Epoch ', i
    
    # Train
    for x_tr_batch, y_tr_batch in batch_data(x_tr, y_tr, batch_size=batch_size):
        _, step, summary, kl = sess.run([optimizer, increment_train_counter, merged_summary_op, kl_divergence2], feed_dict={x_: x_tr_batch, y_: y_tr_batch, kla: weight})
        train_summary_writer.add_summary(summary, step) 
        if(weight < 1.0):
            weight = weight + 0.0002
    kl_tr_all.append(kl)    
              
    for x_te_batch, y_te_batch in batch_data(x_te, y_te, batch_size=batch_size):
        step, summary, kl = sess.run([increment_validation_counter, merged_summary_op, kl_divergence2], feed_dict={x_: x_te_batch, y_: y_te_batch, kla: weight}) 
        validation_summary_writer.add_summary(summary, step)
    kl_te_all.append(kl)

    if(i%10 == 0):
        save_path = saver.save(sess, 'epoch', i)
        
np.save("kl_tr_all",kl_tr_all)
np.save("kl_te_all",kl_te_all)

Epoch  1
Epoch  2
Epoch  3
Epoch  4
Epoch  5
Epoch  6
Epoch  7
Epoch  8
Epoch  9
Epoch  10
Epoch  11
Epoch  12
Epoch  13
Epoch  14
Epoch  15
Epoch  16
Epoch  17
Epoch  18
Epoch  19
Epoch  20
Epoch  21
Epoch  22
Epoch  23
Epoch  24
Epoch  25
Epoch  26
Epoch  27
Epoch  28
Epoch  29
Epoch  30
Epoch  31
Epoch  32
Epoch  33
Epoch  34
Epoch  35
Epoch  36
Epoch  37
Epoch  38
Epoch  39
Epoch  40
Epoch  41
Epoch  42
Epoch  43
Epoch  44
Epoch  45
Epoch  46
Epoch  47
Epoch  48
Epoch  49
Epoch  50
Epoch  51
Epoch  52
Epoch  53
Epoch  54
Epoch  55
Epoch  56
Epoch  57
Epoch  58
Epoch  59
Epoch  60
Epoch  61
Epoch  62
Epoch  63
Epoch  64
Epoch  65
Epoch  66
Epoch  67
Epoch  68
Epoch  69
Epoch  70
Epoch  71
Epoch  72
Epoch  73
Epoch  74
Epoch  75
Epoch  76
Epoch  77
Epoch  78
Epoch  79
Epoch  80
Epoch  81
Epoch  82
Epoch  83
Epoch  84
Epoch  85
Epoch  86
Epoch  87
Epoch  88
Epoch  89
Epoch  90
Epoch  91
Epoch  92
Epoch  93
Epoch  94
Epoch  95
Epoch  96
Epoch  97
Epoch  98
Epoch  99
Epoch  100
