# Step by step reconstruction of the model

In [1]:
import numpy as np
import tensorflow as tf
import os
import cv2

In [2]:
def load_data(img_dir, ending):
    return np.array([cv2.imread(os.path.join(img_dir, img)) for img in os.listdir(img_dir) if img.endswith(ending)])


In [3]:
X = load_data("3Shapes2_large/", "im1.png")
Y = load_data("3Shapes2_large/","im2.png")

In [4]:
X.shape

(8015, 128, 128, 3)

In [5]:
X = X[:10]
Y = Y[:10]

In [26]:
def create_graph(img1, img2):
    img1_64 = tf.nn.max_pool(img1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID')
    img2_64 = tf.nn.max_pool(img2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID')
    
    #Motion encoder

    #First convolution: 5x5x96
    weights = tf.Variable(tf.random_normal([5,5,3,96]))
    bias = tf.Variable(tf.zeros([96,]))
    logits = tf.nn.conv2d(img1_64, filter=weights, strides=[1,1,1,1], padding='SAME')
    logits = tf.add(logits, bias)
    logits = tf.nn.relu(logits)
    
    #Second convolution: 5x5x96
    weights = tf.Variable(tf.random_normal([5,5,96,3]))
    bias = tf.Variable(tf.zeros([3,]))
    logits = tf.nn.conv2d(logits, filter=weights, strides=[1,1,1,1], padding='SAME')
    logits = tf.add(logits, bias)
    #logits = tf.nn.max_pool(logits, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID')
    logits = tf.nn.relu(logits)
    
    #Third convolution: 5x5x128
    weights = tf.Variable(tf.random_normal([5,5,96,128]))
    bias = tf.Variable(tf.zeros([128,]))
    logits = tf.nn.conv2d(logits, filter=weights, strides=[1,1,1,1], padding='SAME')
    logits = tf.add(logits, bias)
    logits = tf.nn.max_pool(logits, ksize=[1,2,2,1], strides=[1,2,2,1], padding='VALID')
    logits = tf.nn.relu(logits)
    
    return img1_64, img2_64, logits


In [27]:
def train(prediction, y):
    l2_loss = tf.reduce_mean(tf.square(prediction - y))
    optimizer = tf.train.AdamOptimizer(0.01)
    train = optimizer.minimize(l2_loss)
    
    return l2_loss, train

In [28]:
# run the model
def run(X, Y, n_epochs = 10, batch_size = 10):
    img1 = tf.placeholder(shape=(batch_size,128,128,3), dtype=tf.float32, name="s1s")
    img2 = tf.placeholder(shape=(batch_size,128,128,3), dtype=tf.float32)
    
    img1_64, img2_64, output = create_graph(img1, img2)
    loss, training = train(output, img2_64)
    
    with tf.Session() as sess:
        for epoch in range(n_epochs):
            
            sess.run(tf.global_variables_initializer())
            
            print('Epoch %i/%i' % (epoch+1, n_epochs))
            cumulative_loss = 0.0
            
            for batch_num in range(int(X.shape[0]/batch_size)-1):
                # get x and y
                x = X[batch_size*batch_num: batch_size*batch_num + batch_size]
                y = Y[batch_size*batch_num: batch_size*batch_num + batch_size]
                
                # run train and loss
                _, batch_loss = sess.run([training,loss], feed_dict={img1:x, img2:y})
                print("\t\tbatch_loss:", batch_loss)
                cumulative_loss += batch_loss
            
            print("\tEpoch's loss:", cumulative_loss)
                

# TODO
1. add the next convolution (working with only one resized img (64x64))

In [29]:
run(X, Y, 10, 2)

Epoch 1/10
		batch_loss: 1.66962e+09
		batch_loss: 2.09131e+08
		batch_loss: 1.58105e+07
		batch_loss: 5.50656e+06
	Epoch's loss: 1900066741.0
Epoch 2/10
		batch_loss: 5.39253e+08
		batch_loss: 3.94068e+07
		batch_loss: 7.72511e+06
		batch_loss: 624703.0
	Epoch's loss: 587009185.375
Epoch 3/10
		batch_loss: 1.8949e+09
		batch_loss: 2.83059e+08
		batch_loss: 2.08557e+07
		batch_loss: 7.34431e+06
	Epoch's loss: 2206161403.5
Epoch 4/10
		batch_loss: 7.25359e+07
		batch_loss: 2.99733e+06
		batch_loss: 365283.0
		batch_loss: 201389.0
	Epoch's loss: 76099896.4531
Epoch 5/10
		batch_loss: 9.85981e+08
		batch_loss: 1.11454e+08
		batch_loss: 1.24701e+07
		batch_loss: 4.75838e+06
	Epoch's loss: 1114663071.5
Epoch 6/10
		batch_loss: 2.0847e+09
		batch_loss: 2.29733e+08
		batch_loss: 2.63024e+07
		batch_loss: 1.67848e+07
	Epoch's loss: 2357520118.0
Epoch 7/10
		batch_loss: 4.67182e+07
		batch_loss: 3.44956e+06
		batch_loss: 84348.6
		batch_loss: 123258.0
	Epoch's loss: 50375399.2578
Epoch 8/10
		b