# Recursive Style Seq2Seq Network

In [1]:
import numpy as np

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell

In [2]:
# Model / decoding configuration. Some names are misspelled (OUPUT, RESCURISVE),
# but other cells reference them as written.
NUM_INPUT_UNITS = 1  # not referenced in the visible cells
NUM_STATE_INPUT_UNITS = 11  # not referenced in the visible cells
NUM_HIDDEN_UNITS = 7  # LSTMCell size and width of the inp_hidden / inp_cell placeholders
MAX_OUPUT_LEN = 3  # bound checked by the decoder while_loop `cond`; initial output TensorArray size
MAX_RESCURISVE_DEPTH = 3  # not referenced in the visible cells; presumably the recursion limit (TODO confirm)
NUM_OUTPUTS = 2  # units of the action-probability layer; one-hot inputs use NUM_OUTPUTS + 1 slots

Things to track in the recursive while loop:
- the current recursion depth of the "frame"
- the current recursive accumulator for the frame
- the (frame_idx, time_idx) from which the current frame was spawned
- The two recursion trackers above can be combined into a single tensor of shape (BATCH_SIZE, 2) and type float32
    - each row is (recursive accumulator, current depth), the same order in which the graph concatenates `rec_accum` and `rec_count`
    - e.g., [[0.4, 0], [0.7, 1], [0.8, 2]]
    
- The (frame_idx, time_idx) reference above can likewise be stored as a single tensor of shape (BATCH_SIZE, 2) and type int32
    - each row is (frame_idx, time_idx)
    - e.g., [[0, 0], [0, 2], [1, 1]]

In [71]:
# NOTE(review): `rec_accum` is defined in the later graph-construction cell, so this
# cell fails on a fresh kernel unless that cell has run first. The saved output
# below (rank-1 shape) also predates the current `tf.expand_dims` version (rank 2).
rec_accum.shape

TensorShape([Dimension(None)])

In [75]:
# Reset graph so re-running this cell does not duplicate ops / variables
tf.reset_default_graph()

# Specify inputs: initial LSTM hidden state (h) and cell state (c),
# each (batch, NUM_HIDDEN_UNITS) with a dynamic batch dimension.
inp_hidden = tf.placeholder(tf.float32, shape=(None, NUM_HIDDEN_UNITS))
inp_cell = tf.placeholder(tf.float32, shape=(None, NUM_HIDDEN_UNITS))


# Constant for decoding start: one row per batch element, one-hot at index 0
# over NUM_OUTPUTS + 1 slots.
start_sentinel = tf.one_hot(
    tf.zeros(shape=(tf.shape(inp_hidden)[0],), dtype=tf.int32),
    NUM_OUTPUTS+1, dtype=tf.float32
)

# Initial values for the recursive trackers, each (batch, 1) of zeros
rec_accum = tf.expand_dims(tf.zeros_like(inp_hidden[:, 0]), axis=-1)
rec_count = tf.expand_dims(tf.zeros_like(inp_hidden[:, 0], dtype=tf.float32), axis=-1)

# Create recursive TensorArray holders
rec_idx = tf.constant(0, dtype=tf.int32)
recursive_ta = tf.TensorArray(tf.float32, size=1, dynamic_size=True)
# (frame_idx, time_idx) references are integer indices, so store them as int32
frame_ref_ta = tf.TensorArray(tf.int32, size=1, dynamic_size=True)

# Initialize the recursive TensorArray with (accumulator, count) rows: (batch, 2)
recursive_ta = recursive_ta.write(rec_idx, tf.concat([rec_accum, rec_count], axis=-1))

# ========================= #
# Initialize Recurrent Cell #
# ========================= #

cell = LSTMCell(NUM_HIDDEN_UNITS)
# LSTMCell unpacks its state as LSTMStateTuple(c, h); a plain
# (inp_hidden, inp_cell) tuple would silently feed hidden as c and cell as h.
test_init = cell(
    start_sentinel,
    tf.contrib.rnn.LSTMStateTuple(c=inp_cell, h=inp_hidden),
)

# =========================== #
# Define Inner Function Calls #
# =========================== #

# full recurrent step including action probs
def network(prev_output, states):
    """Run one LSTM step and project its output to action probabilities.

    Args:
        prev_output: cell input for this step (previous action or start token).
        states: LSTM state, passed unchanged to the shared global `cell`.

    Returns:
        (action_probs, output, states): softmax over NUM_OUTPUTS actions,
        the raw cell output, and the new LSTM state returned by `cell`.
    """
    output, states = cell(prev_output, states)

    # Share the projection weights across every trace of this function (e.g. one
    # per recursive frame), mirroring the single shared `cell`. Without a fixed
    # name and reuse, each trace would create a fresh `dense_N` layer.
    action_probs = tf.layers.dense(
        output,
        NUM_OUTPUTS,
        activation=tf.nn.softmax,
        name="action_probs",
        reuse=tf.AUTO_REUSE,
    )

    return action_probs, output, states

# ==================== #
# Inner While Loop Ops #
# ==================== #

def cond(index, input_ta, output_ta, *states):
    """while_loop predicate for the decoder.

    `input_ta` receives the previous action / start token (second loop var);
    only `index` is inspected. Using `<` (not `<=`) runs exactly MAX_OUPUT_LEN
    steps, writing indices 0..MAX_OUPUT_LEN-1. This matches the size the
    output TensorArray is created with.
    """
    return index < MAX_OUPUT_LEN
    
def step(index, prev_out, output_ta, *states):
    """Body of the decoder while_loop: one LSTM step plus greedy action choice.

    Args:
        index: current time step; also the slot written in `output_ta`.
        prev_out: previous action (or the start sentinel) as a one-hot of depth
            NUM_OUTPUTS + 1, fed as the cell input.
        output_ta: TensorArray collecting per-step action probabilities.
        *states: LSTM state carried by the loop as (hidden, cell), the order in
            which the caller seeds it (inp_hidden, inp_state).

    Returns:
        (index + 1, next one-hot action, updated output_ta, hidden, cell).
    """
    # LSTMCell unpacks its state as (c, h). Build the tuple explicitly so the
    # (hidden, cell) loop vars are not swapped on every iteration.
    h_prev, c_prev = states
    action_probs, _, state_tuple = network(
        prev_out, tf.contrib.rnn.LSTMStateTuple(c=c_prev, h=h_prev)
    )

    # Greedy action. Slot 0 of the one-hot space is the start sentinel, so shift
    # actions up by one to keep action 0 distinguishable from "start".
    action_max = tf.argmax(action_probs, axis=1)
    action_max_one_hot = tf.one_hot(action_max + 1, depth=NUM_OUTPUTS + 1)

    # Record this step's action distribution
    output_ta = output_ta.write(index, action_probs)

    # Return the state in the same (hidden, cell) order it was received
    return (index + 1, action_max_one_hot, output_ta, state_tuple.h, state_tuple.c)


# ==================== #
# Outer While Loop Ops #
# ==================== #

def recursive_func(inp_hidden, inp_state, rec_accum, rec_count, output_ta_list=None):
    """
    Decode one "frame": unroll the LSTM decoder with tf.while_loop from the
    given initial state, starting from a start-sentinel input.

    Args:
        inp_hidden: initial LSTM hidden state, (batch, NUM_HIDDEN_UNITS).
        inp_state: initial LSTM cell state, same shape as inp_hidden.
        rec_accum: recursive accumulator for this frame (not used yet).
        rec_count: recursion count / depth for this frame (not used yet).
        output_ta_list: optional list. The frame's final output TensorArray is
            appended to it so callers can collect outputs across frames.
            Defaults to None (a fresh list) to avoid sharing a mutable default
            across calls.

    Returns:
        (final_index, final_action, output_ta_final) from the while_loop.
    """
    if output_ta_list is None:
        output_ta_list = []

    # Instantiate a time index
    time = tf.constant(0, dtype=tf.int32)

    # Start token built from this frame's own input (one-hot at index 0), so the
    # function does not depend on the module-level placeholder.
    batch_size = tf.shape(inp_hidden)[0]
    start = tf.one_hot(
        tf.zeros(shape=(batch_size,), dtype=tf.int32),
        NUM_OUTPUTS + 1, dtype=tf.float32
    )

    # Create an internal output_ta
    output_ta = tf.TensorArray(tf.float32, size=MAX_OUPUT_LEN, dynamic_size=True)

    # while loop; the state is carried as (hidden, cell)
    final_index, final_action, output_ta_final, _, _ = tf.while_loop(
        cond,
        step,
        loop_vars=[time, start, output_ta, inp_hidden, inp_state]
    )

    output_ta_list.append(output_ta_final)

    return final_index, final_action, output_ta_final

In [132]:
# Trace a single decoder frame from the placeholder states.
# `output_list` is handed to recursive_func through `output_ta_list`.
output_list = []
final_index, final_action, output_ta_final = recursive_func(
    inp_hidden,
    inp_cell,
    rec_accum=0,
    rec_count=2,
    output_ta_list=output_list,
)

recursive count is 2
recursive count is 1
recursive count is 0
recursive count is -1
recursive count is -2
recursive count is -3
recursive count is -4
recursive count is -5
recursive count is -6
recursive count is -7
recursive count is -8
recursive count is -9
recursive count is -10
recursive count is -11
recursive count is -12
recursive count is -13
recursive count is -14
recursive count is -15
recursive count is -16
recursive count is -17
recursive count is -18
recursive count is -19
recursive count is -20
recursive count is -21
recursive count is -22
recursive count is -23
recursive count is -24
recursive count is -25
recursive count is -26
recursive count is -27
recursive count is -28
recursive count is -29
recursive count is -30
recursive count is -31
recursive count is -32
recursive count is -33
recursive count is -34
recursive count is -35
recursive count is -36
recursive count is -37
recursive count is -38
recursive count is -39
recursive count is -40
recursive count is -41
rec

KeyboardInterrupt: 

In [92]:
# Stack each collected TensorArray into a dense tensor. Build a real list: on
# Python 3 `map` returns a lazy, single-use iterator, which is not a fetchable
# structure for `sess.run`.
output_list_mat = [ta.stack() for ta in output_list]

In [79]:
# Open an interactive (default) session, then initialize every graph variable through it.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [63]:
# Scratch check of a fixed-size int32 TensorArray:
# element 0 holds [[1, 1]], element 1 holds [[-1, -1]].
int_ta = (
    tf.TensorArray(tf.int32, size=2)
    .write(0, tf.ones(shape=(1, 2), dtype=tf.int32))
    .write(1, tf.fill([1, 2], -1))
)

In [80]:
BATCH_SIZE = 10
# Feed random initial LSTM states and fetch the initial recursion trackers,
# the start token, and the concatenated recursive TensorArray.
res = sess.run(
    fetches=[
        rec_accum,
        rec_count,
        start_sentinel,
        recursive_ta.concat(),
    ],
    feed_dict={
        inp_hidden: np.random.rand(BATCH_SIZE, NUM_HIDDEN_UNITS),
        inp_cell: np.random.rand(BATCH_SIZE, NUM_HIDDEN_UNITS),
    },
)

In [82]:
# Last fetch of the previous cell: recursive_ta.concat(), i.e. the
# (accumulator, count) rows written at rec_idx 0.
# NOTE(review): the saved output below has 2 rows while BATCH_SIZE is 10, so it
# looks stale; re-run to refresh.
res[-1]

array([[ 0.,  0.],
       [ 0.,  0.]], dtype=float32)