# Recursive Style Seq2Seq Network

In [1]:
import numpy as np

import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell

In [2]:
# Model / decoding configuration (misspelled names are kept: other cells use them).
NUM_INPUT_UNITS, NUM_STATE_INPUT_UNITS = 1, 11  # not referenced in this section — TODO confirm where they are used
NUM_HIDDEN_UNITS = 7       # LSTM size; also the width of the fed h / c state placeholders
MAX_OUPUT_LEN = 3          # step limit for the inner (per-frame) decoding loop
MAX_RESCURISVE_DEPTH = 3   # recursion-depth cap consulted when deciding whether to spawn frames
NUM_OUTPUTS = 2            # number of actions; one-hot vectors use NUM_OUTPUTS + 1 slots

Things I need to track in the recursive while loop:
- the current recursion depth of each "frame"
- the current recursive accumulator for the frame
- the (frame_idx, time_idx) from which the current frame was spawned
- The two recursion trackers above can be combined: elements will look like (BATCH_SIZE, 2) of type float32
    - elements will be (recursive accumulator, current depth)
    - e.g., [[0.4, 0], [0.7, 1], [0.8, 2]]
    
- The frame_idx and time_idx above can be combined as well: each element is a (frame_idx, time_idx) pair of type int32
    - the root frame is not spawned from anywhere, so it uses (-1, -1)
    - e.g., [[-1, -1], [0, 1], [0, 2]]

In [223]:
# Fresh graph each time this cell runs, so ops are not duplicated.
tf.reset_default_graph()

# --- Inputs: initial LSTM hidden (h) and cell (c) state, one row per batch element ---
inp_hidden = tf.placeholder(tf.float32, shape=(None, NUM_HIDDEN_UNITS))
inp_cell = tf.placeholder(tf.float32, shape=(None, NUM_HIDDEN_UNITS))

# --- Batch-sized constants ---
# Decoder start token: one-hot of index 0 over NUM_OUTPUTS + 1 classes.
start_sentinel = tf.one_hot(
    indices=tf.zeros(shape=(tf.shape(inp_hidden)[0],), dtype=tf.int32),
    depth=NUM_OUTPUTS + 1,
    dtype=tf.float32,
)

# Row-wise [0., 1.]: adding it to an (accumulator, depth) row bumps column 1.
add_1_tensor = tf.concat(
    [tf.zeros_like(inp_hidden[:, :1]), tf.ones_like(inp_hidden[:, :1])],
    axis=1,
)

# --- Root frame: zero accumulator and zero depth, each of shape (batch, 1) ---
rec_accum = tf.zeros_like(inp_hidden[:, :1])
rec_count = tf.zeros_like(inp_hidden[:, :1])

# --- Frame bookkeeping ---
# recursive_ta: one (batch, 2) [accumulator, depth] entry per frame.
# frame_ref_ta: the (frame_idx, time_idx) each frame was spawned from; root is (-1, -1).
rec_idx = tf.constant(0, dtype=tf.int32)
recursive_ta = tf.TensorArray(tf.float32, size=1, dynamic_size=True).write(
    rec_idx, tf.concat([rec_accum, rec_count], axis=-1)
)
frame_ref_ta = tf.TensorArray(tf.int32, size=1, dynamic_size=True).write(
    rec_idx, tf.constant([-1, -1], dtype=tf.int32)
)

# --- Per-frame outputs, filled in by the outer loop ---
final_probs_ta = tf.TensorArray(tf.float32, size=1, dynamic_size=True)
final_actions_ta = tf.TensorArray(tf.float32, size=1, dynamic_size=True)

# ========================= #
# Initialize Recurrent Cell #
# ========================= #

cell = LSTMCell(NUM_HIDDEN_UNITS)
# LSTMCell (state_is_tuple=True by default) unpacks its state as (c, h).
# A bare (inp_hidden, inp_cell) tuple would therefore feed the hidden state in
# as the cell state and vice versa, so build an explicit LSTMStateTuple.
# Calling the cell once here also builds its weights for an input depth of
# NUM_OUTPUTS + 1 (the width of start_sentinel).
test_init = cell(
    start_sentinel,
    tf.nn.rnn_cell.LSTMStateTuple(c=inp_cell, h=inp_hidden),
)

# =========================== #
# Define Inner Function Calls #
# =========================== #

# full recurrent step including action probs
def network(prev_output, states):
    """Run one LSTM step and project its output to action probabilities.

    Args:
        prev_output: input for this step (start sentinel or previous one-hot action).
        states: ``(h, c)`` pair, in the order the decoding loop carries them.

    Returns:
        ``(action_probs, output, new_state)``: a softmax over NUM_OUTPUTS actions,
        the raw LSTM output, and the cell's new LSTMStateTuple.
    """
    hidden, memory = states
    # LSTMCell unpacks its state as (c, h); passing the (h, c) pair directly
    # would swap hidden and cell state on every step.
    output, new_state = cell(
        prev_output, tf.nn.rnn_cell.LSTMStateTuple(c=memory, h=hidden)
    )

    # A fixed layer name plus AUTO_REUSE makes every call share one set of
    # projection weights instead of creating dense, dense_1, ... per call.
    action_probs = tf.layers.dense(
        output,
        NUM_OUTPUTS,
        activation=tf.nn.softmax,
        name="action_probs",
        reuse=tf.AUTO_REUSE,
    )

    return action_probs, output, new_state

# ==================== #
# Inner While Loop Ops #
# ==================== #

def cond(time, prev_out, prev_recursive, probs_ta, actions_ta, recursive_ta, frame_ref_ta, *states):
    """Inner-loop predicate: keep decoding until MAX_OUPUT_LEN steps are done.

    ``time`` starts at 0, so the strict ``<`` yields exactly MAX_OUPUT_LEN
    steps (``<=`` would produce one extra output per frame).
    """
    return time < MAX_OUPUT_LEN
    
def step(time, prev_out, prev_recursive, probs_ta, actions_ta, recursive_ta, frame_ref_ta, *states):
    """Body of the inner while loop: decode one time step of the current frame.

    Args:
        time: int32 scalar step counter.
        prev_out: one-hot input for this step (start sentinel or previous action).
        prev_recursive: (accumulator, depth) rows of the frame being decoded;
            passed through unchanged.
        probs_ta, actions_ta: per-step action probabilities / one-hot actions.
        recursive_ta, frame_ref_ta: frame bookkeeping; a new frame may be appended.
        *states: the LSTM ``(h, c)`` pair.

    Returns:
        The loop variables in the same order, with ``time`` advanced by one.
    """
    # Run the LSTM; carry the new state in (h, c) order to the next step.
    action_probs, _, state_tuple = network(prev_out, states)
    states = (state_tuple.h, state_tuple.c)

    # Greedy action. tf.argmax only supports int32/int64 output types, and
    # tf.one_hot needs integer indices.
    action_max = tf.argmax(action_probs, axis=1, output_type=tf.int32)
    action_max_one_hot = tf.one_hot(action_max, depth=NUM_OUTPUTS + 1)

    # Record this step's outputs.
    probs_ta = probs_ta.write(time, action_probs)
    actions_ta = actions_ta.write(time, action_max_one_hot)

    # Rows that spawn a nested frame: they chose a non-zero action and their
    # depth (column 1) is still below the limit. The mask is rank-1 (batch,)
    # so tf.where selects whole rows, and it is reused for the write decision
    # so both agree on which rows spawn.
    spawn_mask = tf.logical_and(
        tf.greater(action_max, 0),
        tf.less(prev_recursive[:, 1], MAX_RESCURISVE_DEPTH),
    )
    # tf.where takes tensors, not callables; non-spawning rows keep their depth.
    nested_recursive = tf.where(spawn_mask, prev_recursive + add_1_tensor, prev_recursive)

    # Append a new frame only if at least one row actually spawns.
    write_new_bool = tf.reduce_any(spawn_mask)

    recursive_ta = tf.cond(
        write_new_bool,
        lambda: recursive_ta.write(recursive_ta.size(), nested_recursive),
        lambda: recursive_ta,
    )

    # NOTE(review): `rec_idx` here is the module-level constant 0, not the outer
    # loop's current frame index, so every reference records frame 0. Recording
    # the real parent frame requires threading that index through the loop vars.
    frame_ref_ta = tf.cond(
        write_new_bool,
        lambda: frame_ref_ta.write(frame_ref_ta.size(), tf.stack([rec_idx, time], axis=0)),
        lambda: frame_ref_ta,
    )

    return (
        time + 1,
        action_max_one_hot,
        prev_recursive,
        probs_ta,
        actions_ta,
        recursive_ta,
        frame_ref_ta,
    ) + tuple(states)


# ==================== #
# Outer While Loop Ops #
# ==================== #

def recursive_cond(rec_idx, recursive_ta, frame_ref_ta, final_probs_ta, final_actions_ta):
    """Outer-loop predicate: continue while frame ``rec_idx`` has been written.

    ``recursive_ta`` grows whenever the inner loop spawns a frame, so the outer
    loop runs until every recorded frame has been decoded. The depth limit is
    applied where frames are spawned (in ``step``), not here.
    """
    frames_written = recursive_ta.size()
    return tf.less(rec_idx, frames_written)

def recursive_func(rec_idx, recursive_ta, frame_ref_ta, final_probs_ta, final_actions_ta):
    """Outer-loop body: decode the single frame at index ``rec_idx``.

    Reads the frame's (accumulator, depth) rows from ``recursive_ta``, then runs
    the inner ``cond``/``step`` while loop starting from ``start_sentinel`` and
    the fed initial LSTM state (``inp_hidden``, ``inp_cell``). The inner loop
    may append new frames to ``recursive_ta`` / ``frame_ref_ta``. The stacked
    per-step probabilities and one-hot actions are stored at ``rec_idx``.

    Returns:
        ``(rec_idx + 1, recursive_ta, frame_ref_ta, final_probs_ta,
        final_actions_ta)`` for the next outer iteration.
    """
    inner_vars = [
        tf.constant(0, dtype=tf.int32),                          # time
        start_sentinel,                                          # prev_out
        recursive_ta.read(rec_idx),                              # prev_recursive
        tf.TensorArray(tf.float32, size=1, dynamic_size=True),   # probs_ta
        tf.TensorArray(tf.float32, size=1, dynamic_size=True),   # actions_ta
        recursive_ta,
        frame_ref_ta,
        inp_hidden,
        inp_cell,
    ]
    loop_out = tf.while_loop(cond, step, loop_vars=inner_vars)

    # Only the TensorArrays are needed after the loop.
    step_probs_ta, step_actions_ta, recursive_ta, frame_ref_ta = loop_out[3:7]

    stacked_probs = step_probs_ta.stack()
    stacked_actions = step_actions_ta.stack()

    final_probs_ta = final_probs_ta.write(rec_idx, stacked_probs)
    final_actions_ta = final_actions_ta.write(rec_idx, stacked_actions)

    return rec_idx + 1, recursive_ta, frame_ref_ta, final_probs_ta, final_actions_ta

In [210]:
# Flatten every frame written to recursive_ta so far into one tensor:
# each frame entry is (batch, 2), so the result is (num_frames * batch, 2).
rta = recursive_ta.concat()

In [211]:
# Interactive session (installs itself as the default session), then
# initialize every global variable currently in the graph.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

In [185]:
# Scratch check: an int32 TensorArray with dynamic_size=True accepts writes at
# explicit indices and appends at its current size().
int_ta = tf.TensorArray(tf.int32, size=0, dynamic_size=True)
int_ta = int_ta.write(0, tf.constant([[1, 1]], dtype=tf.int32))
int_ta = int_ta.write(1, tf.constant(-1, shape=(1, 2), dtype=tf.int32))
# Append twice more: size() is the index just past the last write.
int_ta = int_ta.write(int_ta.size(), tf.constant(-1, shape=(1, 2), dtype=tf.int32))
int_ta = int_ta.write(int_ta.size(), tf.constant(-1, shape=(1, 2), dtype=tf.int32))

In [213]:
BATCH_SIZE = 10

# Random initial LSTM state for BATCH_SIZE rows; the fetched tensors below
# depend only on the batch size, not on these values.
feed = {
    inp_hidden: np.random.rand(BATCH_SIZE, NUM_HIDDEN_UNITS),
    inp_cell: np.random.rand(BATCH_SIZE, NUM_HIDDEN_UNITS),
}
fetches = [rec_accum, rec_count, rta, add_1_tensor]
res = sess.run(fetches, feed_dict=feed)

In [215]:
# Display the last value fetched by the run above (add_1_tensor).
res[-1]

array([[ 0.,  1.],
       [ 0.,  1.],
       [ 0.,  1.],
       [ 0.,  1.],
       [ 0.,  1.],
       [ 0.,  1.],
       [ 0.,  1.],
       [ 0.,  1.],
       [ 0.,  1.],
       [ 0.,  1.]], dtype=float32)