In [1]:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import sys

import svhn_dataset 
import tensorflow as tf
import numpy as np

In [2]:
# --- Input geometry -------------------------------------------------------
img_sz = 32                # side length used to size the flattened image input
img_len = img_sz ** 2      # width of the flattened image placeholder
patch_size = 8             # side length of glimpse / context patches
n_class = 10               # width of the label placeholder and output layer

# --- Model size -----------------------------------------------------------
lstm_size = 256            # number of units in each LSTM cell
hidden_size = 1024         # width of the glimpse feature vector
T = 10                     # number of unrolled glimpse steps

# --- Training settings ----------------------------------------------------
# NOTE(review): none of the values below are referenced in the cells visible
# here; the meanings given are presumed from the names -- confirm at use site.
batch_size = 128           # presumably the mini-batch size
n_itr = 10000              # presumably the number of training iterations
lr = 0.001                 # presumably the optimiser learning rate
eps = 1e-07                # presumably a numerical-stability constant
std = 0.03                 # presumably a standard deviation (usage not shown)

In [3]:
# Load the dataset through the project-local `svhn_dataset` helper.
# NOTE(review): 'svhn_data' is presumably the directory holding (or receiving)
# the SVHN files, and the returned object's layout is not visible here --
# confirm against svhn_dataset.read_data_sets.
svhn = svhn_dataset.read_data_sets('svhn_data')

In [None]:
# Graph inputs; the leading dimension is left open so any batch size can be fed.
# x: one flattened image per row (img_len values each).
# NOTE(review): img_len = img_sz * img_sz implies a single channel -- confirm
# that the dataset loader delivers grayscale images.
x = tf.placeholder(dtype=tf.float32, shape=[None, img_len])
# y: one row of n_class label values per example.
y = tf.placeholder(dtype=tf.float32, shape=[None, n_class])

In [None]:
def conv2d(x, W, b, strides=1):
    """Apply a SAME-padded 2-D convolution, add a bias, then ReLU.

    Args:
        x: input tensor handed to ``tf.nn.conv2d``.
        W: convolution filter tensor.
        b: bias tensor added to the convolution output.
        strides: stride placed in the two middle positions of the 4-element
            stride vector (the outer two are fixed at 1).

    Returns:
        The rectified, biased convolution output.
    """
    stride_spec = [1, strides, strides, 1]
    pre_activation = tf.nn.bias_add(
        tf.nn.conv2d(x, W, strides=stride_spec, padding='SAME'), b)
    return tf.nn.relu(pre_activation)

In [None]:
def glimpse_network(x, w, b, loc):
    """Encode a glimpse of the image together with where it was taken.

    A ``patch_size`` x ``patch_size`` window is cut out of each image at
    ``loc``, passed through three conv layers and one fully connected layer,
    and multiplied element-wise with a linear embedding of ``loc``.

    Args:
        x: image batch, either flattened to ``[batch, img_sz * img_sz]`` or
            already shaped ``[batch, img_sz, img_sz, 1]``.
            NOTE(review): a single channel is assumed -- confirm against the
            dataset loader.
        w: dict of weights; reads keys 'wg1', 'wg2', 'wg3', 'wg', 'wgl'.
        b: dict of biases; reads keys 'bg1', 'bg2', 'bg3', 'bg', 'bgl'.
        loc: glimpse offsets, one row per image, multiplied with ``w['wgl']``.

    Returns:
        The element-wise product of the image feature and location feature.
    """
    # extract_glimpse needs a 4-D image batch, but the input placeholder is
    # flat, so restore the spatial layout first (a no-op if already 4-D).
    images = tf.reshape(x, shape=[-1, img_sz, img_sz, 1])
    # The glimpse size is a 2-element [height, width] vector. The previous
    # tf.shape(patch_size, patch_size) passed patch_size as the op *name*.
    x_g = tf.image.extract_glimpse(images, [patch_size, patch_size], loc)
    x_g = tf.reshape(x_g, shape=[-1, patch_size, patch_size, 1])

    conv1 = conv2d(x_g, w['wg1'], b['bg1'])
    conv2 = conv2d(conv1, w['wg2'], b['bg2'])
    conv3 = conv2d(conv2, w['wg3'], b['bg3'])

    # Flatten to the row count expected by the fully connected weight matrix.
    fc1 = tf.reshape(conv3, [-1, w['wg'].get_shape().as_list()[0]])
    fc1 = tf.add(tf.matmul(fc1, w['wg']), b['bg'])
    Gimage = tf.nn.relu(fc1)

    # Linear embedding of the location, combined multiplicatively ("what" x "where").
    Gloc = tf.add(tf.matmul(loc, w['wgl']), b['bgl'])
    gn = tf.multiply(Gimage, Gloc)
    return gn
    
def context_network(x, w, b):
    """Encode a down-sampled copy of the whole image.

    The image is resized to ``patch_size`` x ``patch_size``, passed through
    three conv layers and one fully connected layer with ReLU.

    Args:
        x: image batch, either flattened to ``[batch, img_sz * img_sz]`` or
            already shaped ``[batch, img_sz, img_sz, 1]``.
            NOTE(review): a single channel is assumed -- confirm against the
            dataset loader.
        w: dict of weights; reads keys 'wc1', 'wc2', 'wc3', 'wc'.
        b: dict of biases; reads keys 'bc1', 'bc2', 'bc3', 'bc'.

    Returns:
        The rectified fully connected output, one row per image.
    """
    # resize_images needs a 3-D/4-D image tensor, but the input placeholder
    # is flat, so restore the spatial layout first (a no-op if already 4-D).
    images = tf.reshape(x, shape=[-1, img_sz, img_sz, 1])
    # TF 1.x takes the target size as one [height, width] argument; passing
    # height and width separately would bind the second to `method`.
    x_g = tf.image.resize_images(images, [patch_size, patch_size])
    x_g = tf.reshape(x_g, shape=[-1, patch_size, patch_size, 1])

    conv1 = conv2d(x_g, w['wc1'], b['bc1'])
    conv2 = conv2d(conv1, w['wc2'], b['bc2'])
    conv3 = conv2d(conv2, w['wc3'], b['bc3'])

    # Flatten to the row count expected by the fully connected weight matrix.
    fc1 = tf.reshape(conv3, [-1, w['wc'].get_shape().as_list()[0]])
    fc1 = tf.add(tf.matmul(fc1, w['wc']), b['bc'])
    return tf.nn.relu(fc1)

In [None]:
# Trainable weights, all drawn from tf.random_normal.
#   wc*  : context network (three conv filters + fully connected 'wc')
#   wg*  : glimpse network (three conv filters + fully connected 'wg',
#          plus the location embedding 'wgl')
#   wl   : LSTM output -> 2-D location
#   wo   : LSTM output -> 10 class scores
#
# The conv stacks run at stride 1 with SAME padding on a
# patch_size x patch_size input, so the last feature map keeps that spatial
# size and flattens to patch_size*patch_size*128 values per example. The
# fully connected matrices must have exactly that many rows; otherwise the
# flatten reshape folds several examples together or fails outright.
W = {
    'wc1': tf.Variable(tf.random_normal([5, 5, 1, 64])),
    'wc2': tf.Variable(tf.random_normal([3, 3, 64, 64])),
    'wc3': tf.Variable(tf.random_normal([3, 3, 64, 128])),
    'wc': tf.Variable(tf.random_normal([patch_size*patch_size*128, lstm_size])),
    
    'wg1': tf.Variable(tf.random_normal([5, 5, 1, 64])),
    'wg2': tf.Variable(tf.random_normal([3, 3, 64, 64])),
    'wg3': tf.Variable(tf.random_normal([3, 3, 64, 128])),
    'wg': tf.Variable(tf.random_normal([patch_size*patch_size*128, hidden_size])),
    'wgl': tf.Variable(tf.random_normal([2, hidden_size])),
    
    'wl': tf.Variable(tf.random_normal([lstm_size, 2])),
    'wo': tf.Variable(tf.random_normal([lstm_size, 10])),
}

# Zero-initialised bias vectors, created in the order listed.
# Key groups: 'bc*' context network, 'bg*' glimpse network (with 'bgl' for
# the location embedding), 'bl' location output, 'bo' class-score output.
b = dict(
    (key, tf.Variable(tf.zeros([width])))
    for key, width in [
        ('bc1', 64),
        ('bc2', 64),
        ('bc3', 128),
        ('bc', lstm_size),

        ('bg1', 64),
        ('bg2', 64),
        ('bg3', 128),
        ('bg', hidden_size),
        ('bgl', hidden_size),

        ('bl', 2),
        ('bo', 10),
    ]
)


In [None]:
# Unroll the two-LSTM attention loop for T glimpse steps.
#   rnn1 consumes the glimpse feature and produces the class scores y[t].
#   rnn2 consumes rnn1's output and produces the next glimpse location.
#
# NOTE(review): `y` here rebinds the name of the label placeholder defined in
# an earlier cell, so that placeholder is no longer reachable as `y`. The name
# is kept because cells outside this view may index y[t]; consider renaming
# one of the two.
y = [0]*T          # y[t]: class scores after glimpse t
loc = [0]*(T+1)    # loc[t]: location used for glimpse t (loc[T] is the final proposal)

rnn1 = tf.nn.rnn_cell.LSTMCell(lstm_size)
rnn2 = tf.nn.rnn_cell.LSTMCell(lstm_size)

# The weight dict is named `W`; the functions' parameter is merely called `w`.
context = context_network(x, W, b)

# The batch dimension of `x` is only known at run time, so the initial zeros
# are derived from an existing tensor instead of a literal None dimension.
h1 = tf.zeros_like(context)

# LSTMCell expects its state as a (c, h) tuple rather than a bare tensor.
# NOTE(review): seeding both components with the context vector is a design
# choice -- confirm which part of the state the context should initialise.
state2 = tf.nn.rnn_cell.LSTMStateTuple(c=context, h=context)

with tf.variable_scope("rnn2", reuse=False):
    h2, state2 = rnn2(h1, state2)
    loc[0] = tf.add(tf.matmul(h2, W['wl']), b['bl'])

# zero_state accepts a scalar tensor, which keeps the batch size dynamic.
state1 = rnn1.zero_state(tf.shape(x)[0], tf.float32)
for t in range(T):
    gn = glimpse_network(x, W, b, loc[t])
    # Variables are created on the first step and reused afterwards.
    with tf.variable_scope("rnn1", reuse=(t != 0)):
        h1, state1 = rnn1(gn, state1)
        y[t] = tf.add(tf.matmul(h1, W['wo']), b['bo'])
    # rnn2's variables already exist from the loc[0] computation above.
    with tf.variable_scope("rnn2", reuse=True):
        h2, state2 = rnn2(h1, state2)
        loc[t+1] = tf.add(tf.matmul(h2, W['wl']), b['bl'])
        