# Visual Dynamics

This notebook is a work-in-progress TensorFlow implementation of *Visual Dynamics: Probabilistic Future Frame Synthesis via Cross Convolutional Networks* (Xue et al., NIPS 2016).

In [2]:
import numpy as np
import tensorflow as tf
import os
import cv2

In [3]:
def load_data(img_dir, ending, reader=None):
    """Load every image in ``img_dir`` whose filename ends with ``ending``.

    Filenames are visited in sorted order. ``os.listdir`` returns names in an
    arbitrary order, so without sorting two calls with different endings
    (e.g. ``"im1.png"`` and ``"im2.png"``) are not guaranteed to return arrays
    whose rows line up pair-by-pair. Sorting keeps them aligned, assuming the
    two files of a pair share a common name prefix.

    Args:
        img_dir: directory to scan (not recursive).
        ending: filename suffix selecting which files to load.
        reader: callable ``path -> array``; defaults to ``cv2.imread``.

    Returns:
        ``np.ndarray`` stacking the loaded images along axis 0.

    Raises:
        IOError: if ``reader`` returns None for a matching file (this is how
            ``cv2.imread`` reports an unreadable/undecodable image).
    """
    if reader is None:
        reader = cv2.imread
    images = []
    for name in sorted(os.listdir(img_dir)):
        if not name.endswith(ending):
            continue
        path = os.path.join(img_dir, name)
        image = reader(path)
        if image is None:
            raise IOError("could not read image: %s" % path)
        images.append(image)
    return np.array(images)

In [4]:
def _conv_layer(x, ksize, in_channels, out_channels, padding='SAME', pool=False, relu=True):
    """Build one conv -> bias -> [2x2 max-pool] -> [ReLU] stage with fresh variables.

    Args:
        x: NHWC input tensor with ``in_channels`` channels.
        ksize: side length of the square convolution kernel.
        in_channels: number of input channels of the filter.
        out_channels: number of output channels of the filter.
        padding: 'SAME' or 'VALID' for the convolution.
        pool: if True, apply a 2x2, stride-2, VALID max-pool after the bias.
        relu: if True, finish with a ReLU (disable for linear output layers).

    Returns:
        The output tensor of the stage.
    """
    # He-normal scaling. With unit-variance weights (tf.random_normal's default)
    # activation magnitudes grow with every layer's fan-in, which quickly
    # overflows downstream exp(std_log).
    stddev = float(np.sqrt(2.0 / (ksize * ksize * in_channels)))
    weights = tf.Variable(tf.random_normal([ksize, ksize, in_channels, out_channels], stddev=stddev))
    bias = tf.Variable(tf.zeros([out_channels]))
    out = tf.nn.conv2d(x, filter=weights, strides=[1, 1, 1, 1], padding=padding)
    out = tf.add(out, bias)
    if pool:
        out = tf.nn.max_pool(out, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
    if relu:
        out = tf.nn.relu(out)
    return out


def _image_encoder(x):
    """Encode a 3-channel image into a 32-channel map at 1/4 of its resolution."""
    x = _conv_layer(x, 5, 3, 64)
    x = _conv_layer(x, 5, 64, 64, pool=True)
    x = _conv_layer(x, 5, 64, 64, pool=True)
    return _conv_layer(x, 5, 64, 32)


def _cross_convolve(features, kernels):
    """Filter each example's feature map with that same example's own kernels.

    ``features`` is (batch, H, W, C) and ``kernels`` is (batch, k, k, C).
    Channel c of example i is convolved with kernel channel c of example i
    (depthwise, channel multiplier 1, SAME padding). The batch dimension must
    be static because the tensors are unstacked per example.
    """
    per_example = [
        tf.nn.depthwise_conv2d(tf.expand_dims(feat, 0), tf.expand_dims(kern, 3),
                               strides=[1, 1, 1, 1], padding='SAME')
        for feat, kern in zip(tf.unstack(features, axis=0), tf.unstack(kernels, axis=0))
    ]
    return tf.concat(per_example, axis=0)


def define_graph(img1, img2, batch_size):
    """Build the cross-convolutional network for a batch of frame pairs.

    Spatial sizes in the comments assume 128x128x3 inputs, as created by the
    placeholders in ``run``. The 5x5 kernel reshape below only works for that
    size.

    Args:
        img1: first frames, NHWC float tensor.
        img2: second frames, same shape as ``img1``.
        batch_size: static batch dimension of ``img1``/``img2``.

    Returns:
        (prediction, img2_64, mean, std_log):
        - prediction: the predicted (batch, 64, 64, 3) frame.
        - img2_64: the 2x max-pooled ``img2`` that the prediction is trained against.
        - mean, std_log: Gaussian parameters of the latent motion code, (batch, 3200) each.
    """
    img1_64 = tf.nn.max_pool(img1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
    img2_64 = tf.nn.max_pool(img2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
    img1_32 = tf.nn.max_pool(img1, ksize=[1, 4, 4, 1], strides=[1, 4, 4, 1], padding='VALID')

    # Motion encoder. It must see BOTH frames; from img1 alone the latent code
    # has no information about the motion it is supposed to describe.
    motion = tf.concat([img1, img2], axis=3)                                 # 128x128x6
    motion = _conv_layer(motion, 5, 6, 96)                                    # 128x128x96
    motion = _conv_layer(motion, 5, 96, 96, pool=True)                        # 64x64x96
    motion = _conv_layer(motion, 5, 96, 128, pool=True)                       # 32x32x128
    motion = _conv_layer(motion, 5, 128, 128, padding='VALID', pool=True)     # 14x14x128
    motion = _conv_layer(motion, 5, 128, 256, padding='VALID', pool=True)     # 5x5x256
    # Linear output layer: a ReLU here would force mean >= 0 and std_log >= 0
    # (sigma >= 1), so the KL term could never reach its minimum.
    motion = _conv_layer(motion, 5, 256, 256, relu=False)                     # 5x5x256

    motion_flat = tf.reshape(motion, [batch_size, -1])                        # (batch, 6400)
    mean, std_log = tf.split(motion_flat, 2, axis=1)                          # (batch, 3200) each
    # Reparameterisation trick: z = mean + sigma * eps with eps ~ N(0, 1).
    epsilon = tf.random_normal(tf.shape(mean), 0, 1, dtype=tf.float32)
    z = mean + tf.multiply(tf.exp(std_log), epsilon)

    # Kernel decoder: z -> 5x5x128, then four 5x5x32 kernel sets.
    kernel = tf.reshape(z, [batch_size, 5, 5, -1])                            # 5x5x128
    kernel = _conv_layer(kernel, 5, 128, 128)
    kernel = _conv_layer(kernel, 5, 128, 128)
    set1, set2, set3, set4 = tf.split(kernel, 4, axis=3)

    # Image encoder at three pyramid levels. The 128 and 64 levels are built as
    # scaffolding for the TODO below but are not yet connected to the output.
    features_128 = _image_encoder(img1)                                      # 32x32x32
    features_64 = _image_encoder(img1_64)                                    # 16x16x32
    features_32 = _image_encoder(img1_32)                                    # 8x8x32

    # TODO: cross-convolve features_128 / features_64 (with set2 / set3) too.
    crossed_32 = _cross_convolve(features_32, set1)                           # 8x8x32

    # Motion decoder: upsample to the target resolution and map to 3 channels.
    target_size = img2_64.shape.as_list()[1:3]                               # [64, 64]
    decoded = tf.image.resize_images(crossed_32, size=target_size)
    decoded = _conv_layer(decoded, 9, 32, 128)
    decoded = _conv_layer(decoded, 1, 128, 128)
    # NOTE(review): the final ReLU is unbounded above, whereas the
    # cross-entropy term in train_model treats predictions as probabilities in [0, 1].
    decoded = _conv_layer(decoded, 1, 128, 3)

    return decoded, img2_64, mean, std_log

def train_model(prediction, y, mean, std_log, learning_rate=0.01):
    """Build the training loss and one SGD update op.

    The loss is the sum of three terms:
    - mean squared error between ``prediction`` and ``y``;
    - binary cross-entropy between ``y`` and ``prediction``, which treats
      both as values in [0, 1];
    - KL divergence of N(mean, exp(std_log)^2) from N(0, 1), averaged over elements.

    Args:
        prediction: predicted frame tensor.
        y: target frame tensor, same shape as ``prediction``.
        mean: latent mean tensor.
        std_log: latent log standard deviation tensor (same shape as ``mean``).
        learning_rate: step size for ``tf.train.GradientDescentOptimizer``.

    Returns:
        (train_op, loss): the minimisation op and the scalar loss tensor.
    """
    # Clip before taking logs so neither log term becomes -inf/NaN when the
    # prediction reaches 0, 1 or beyond. A plain +1e-10 offset is not enough,
    # because 1 - 1e-10 rounds to 1 in float32.
    eps = 1e-6
    prob = tf.clip_by_value(prediction, eps, 1.0 - eps)

    l2_loss = tf.reduce_mean(tf.square(prediction - y))
    reconstr_loss = -tf.reduce_mean(y * tf.log(prob) + (1 - y) * tf.log(1 - prob))
    # With sigma = exp(std_log): 0.5 * (mu^2 + sigma^2 - 2 log sigma - 1).
    kl_loss = 0.5 * tf.reduce_mean(tf.square(mean) + tf.square(tf.exp(std_log)) - 2 * std_log - 1)

    loss = l2_loss + reconstr_loss + kl_loss

    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
    train = optimizer.minimize(loss)

    return train, loss
    

In [5]:
def run(X, Y, n_epochs=10, batch_size=100):

    
    
    n_examples = X.shape[0]
    print("there is ", n_examples, "examples")
    
    img1 = tf.placeholder(shape=(batch_size,128,128,3), dtype=tf.float32)
    img2 = tf.placeholder(shape=(batch_size,128,128,3), dtype=tf.float32)

    output, img2_64, mean, std_log = define_graph(img1, img2, batch_size)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for epoch in range(n_epochs):
            print("Epoch", epoch, "/",n_epochs)
            comulative_loss = 0.0
            for batch_i in range((int(n_examples/batch_size)-1)):
                #load batch
                from_index = batch_i*batch_size
                to_index = from_index + batch_size
                print("\tBatch", batch_i, "from",from_index, "to", to_index)
                
                x = X[from_index:to_index]
                y = Y[from_index:to_index]
                
                train, loss = train_model(output,img2_64, mean, std_log)
                
                _, loss_result = sess.run([train, loss], feed_dict={img1:x, img2:y})
                comulative_loss += loss_result
            print("Epoch", epoch, "/",n_epochs)
            print("\tloss:", comulative_loss, "\n")
            

In [6]:
# Files ending in "im1.png" are the input frames; "im2.png" are the frames to predict.
DATA_DIR = "3Shapes2_large/"
X = load_data(DATA_DIR, "im1.png")
Y = load_data(DATA_DIR, "im2.png")
assert len(X) == len(Y) > 0, "expected one im2.png for every im1.png in DATA_DIR"

In [7]:
# Train on a small subset for quick iteration; set N_SAMPLES = None to use every pair.
N_SAMPLES = 100
X = X[:N_SAMPLES]
Y = Y[:N_SAMPLES]

In [None]:
# Quick smoke run: two epochs over 50-frame batches.
run(X, Y, batch_size=50, n_epochs=2)

there is  1000 examples
Epoch 0 / 3
	Batch 0 from 0 to 250
