<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Load-Images" data-toc-modified-id="Load-Images-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Load Images</a></span></li><li><span><a href="#Identity-Transform" data-toc-modified-id="Identity-Transform-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Identity Transform</a></span></li><li><span><a href="#Rotation" data-toc-modified-id="Rotation-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Rotation</a></span></li></ul></div>

In [1]:
import numpy as np
from PIL import Image
import tensorflow as tf

from utils import img2array, array2img
from stn import spatial_transformer_network as transformer

## Load Images

In [3]:
DIMS = (600, 600)
data_dir = './data/'

# Load the four cat images. Each is expanded to a single-image batch, so the
# next cell can concatenate them along axis 0.
img1, img2, img3, img4 = [
    img2array(data_dir + fname, DIMS, expand=True)
    for fname in ('cat1.jpg', 'cat2.jpg', 'cat3.jpg', 'cat4.jpg')
]

In [8]:
# Stack the four single-image batches into one batch, then unpack its
# (batch, height, width, channels) dimensions for the graph cells below.
input_img = np.concatenate((img1, img2, img3, img4), axis=0)
batch_shape = input_img.shape
B, H, W, C = batch_shape
print("Input Img Shape: {}".format(batch_shape))

Input Img Shape: (4, 600, 600, 3)


## Identity Transform

We'll use the identity transform $\theta = \begin{bmatrix}1 & 0 & 0\\ 0 & 1 & 0\end{bmatrix}$ as a sanity check: the output images should look the same as the input images.

In [12]:
# Switch TensorFlow into TF1-style graph mode. The cells below build the graph
# with `tf.compat.v1.placeholder` and evaluate it through `tf.compat.v1.Session`,
# and both require graph mode. This must run before any of those cells.
tf.compat.v1.disable_eager_execution()

In [13]:
# Identity affine transform: the 2x3 matrix [[1, 0, 0], [0, 1, 0]], i.e. no
# rotation, scaling or translation.
theta = np.eye(2, 3)

In [15]:
# Build the graph: a degenerate localization net whose output is always
# `theta`, followed by the spatial transformer.
x = tf.compat.v1.placeholder(tf.float32, [None, H, W, C])

with tf.compat.v1.variable_scope('spatial_transformer'):
    # Flattened 2x3 affine matrix used as the loc-net bias. A new name is used,
    # so `theta` itself is left unmodified and re-running this cell is harmless.
    theta_flat = theta.astype('float32').flatten()

    # define loc net weight and bias
    loc_in = H*W*C
    loc_out = 6
    # Zero weights mean fc_loc == b_loc for every input image, so each image in
    # the batch is transformed by `theta`.
    W_loc = tf.Variable(tf.zeros([loc_in, loc_out]), name='W_loc')
    b_loc = tf.Variable(initial_value=theta_flat, name='b_loc')

    # Tie everything together. Feed the flattened input rather than a fixed
    # tf.zeros([B, loc_in]), so fc_loc's batch dimension follows `x` instead of
    # being hard-coded to B. With zero weights the result is unchanged for
    # finite inputs.
    # NOTE(review): this assumes `transformer` reads the batch size dynamically
    # (TODO confirm in stn.py).
    x_flat = tf.reshape(x, [-1, loc_in])
    fc_loc = tf.matmul(x_flat, W_loc) + b_loc
    h_trans = transformer(x, fc_loc)

In [18]:
# Run the graph. The context manager closes the session (releasing its
# resources) once the batch has been transformed; previously it was left open.
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    y = sess.run(h_trans, feed_dict={x: input_img})
print("y: {}".format(y.shape))

# Display the first transformed image inline as the cell's output.
# Image.show() opens an external viewer, so nothing would be rendered in or
# saved with the notebook.
# NOTE(review): this assumes array2img returns a PIL Image, which Jupyter
# renders through its rich repr.
array2img(y[0])

y: (4, 600, 600, 3)


## Rotation

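Let's try rotating the images by 45 degrees, using the affine matrix $\theta = \begin{bmatrix}\cos\alpha & -\sin\alpha & 0\\ \sin\alpha & \cos\alpha & 0\end{bmatrix}$ with $\alpha = 45^\circ$ and no translation.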

In [20]:
from utils import deg2rad

In [21]:
# Build the 2x3 affine matrix `theta` for a pure rotation by `degree` degrees
# (the last column, the translation, is zero).
degree = 45
angle = deg2rad(degree)
cos_a, sin_a = np.cos(angle), np.sin(angle)
theta = np.array([
    [cos_a, -sin_a, 0],
    [sin_a,  cos_a, 0],
])

In [24]:
# Build the graph for the rotation: the same degenerate localization net as in
# the identity section, but with the loc-net bias set to the rotation `theta`.
# NOTE(review): this duplicates the identity-transform graph cell. Both sets of
# variables live in the same default graph, so the next
# global_variables_initializer() initialises both large W_loc tensors.
# Consider a shared helper plus a fresh graph per experiment.
x = tf.compat.v1.placeholder(tf.float32, [None, H, W, C])

with tf.compat.v1.variable_scope('spatial_transformer'):
    # Flattened 2x3 affine matrix used as the loc-net bias. A new name is used,
    # so `theta` itself is left unmodified and re-running this cell is harmless.
    theta_flat = theta.astype('float32').flatten()

    # define loc net weight and bias
    loc_in = H*W*C
    loc_out = 6
    # Zero weights mean fc_loc == b_loc for every input image, so each image in
    # the batch is transformed by `theta`.
    W_loc = tf.Variable(tf.zeros([loc_in, loc_out]), name='W_loc')
    b_loc = tf.Variable(initial_value=theta_flat, name='b_loc')

    # Tie everything together. Feed the flattened input rather than a fixed
    # tf.zeros([B, loc_in]), so fc_loc's batch dimension follows `x` instead of
    # being hard-coded to B. With zero weights the result is unchanged for
    # finite inputs.
    # NOTE(review): this assumes `transformer` reads the batch size dynamically
    # (TODO confirm in stn.py).
    x_flat = tf.reshape(x, [-1, loc_in])
    fc_loc = tf.matmul(x_flat, W_loc) + b_loc
    h_trans = transformer(x, fc_loc)

In [26]:
# Run the rotation graph. The context manager closes the session (releasing
# its resources) once the batch has been transformed; previously it was left
# open, as was the identity section's session.
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    y = sess.run(h_trans, feed_dict={x: input_img})
print("y: {}".format(y.shape))

# Display the first rotated image inline as the cell's output.
# Image.show() opens an external viewer, so nothing would be rendered in or
# saved with the notebook.
# NOTE(review): this assumes array2img returns a PIL Image, which Jupyter
# renders through its rich repr.
array2img(y[0])

y: (4, 600, 600, 3)
