In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Seed TensorFlow's graph-level RNG and NumPy so Restart & Run All reproduces
# the same numbers (must run before any random op is created below).
seed = 42
tf.set_random_seed(seed)
np.random.seed(seed)

max_time = 10  # max time steps
batch_size = 5  # fixed batch size: go_embedding and the initial-state variables are built with it
input_depth = 10  # dimension of the input fed to the cell at each step
output_dim = 20  # number of LSTM units, i.e. size of the cell output
vocab_size = 10  # number of rows in the embedding table
go_token_id = 1  # embedding row used as the GO (start) input at time 0

# Time-major input placeholder: [max_time, batch_size, input_depth].
# NOTE: loop_fn never reads this placeholder; it is only fed in the session cell.
inputs = tf.placeholder(shape=(None, None, input_depth), dtype=tf.float32)
# Length of each sequence in the batch (one int32 per batch element).
sequence_length = tf.placeholder(shape=(None,), dtype=tf.int32)

embeddings = tf.Variable(tf.random_normal([vocab_size, input_depth], dtype=tf.float32))
# GO-token embedding repeated for every batch element -> [batch_size, input_depth]
go_embedding = tf.nn.embedding_lookup(embeddings, tf.fill([batch_size], go_token_id))

In [3]:
# Trainable initial LSTM state: memory (c) and hidden (h) tensors, one
# randomly initialised row of size output_dim per batch element.
lstm_state_shape = [batch_size, output_dim]
state_c = tf.Variable(tf.random_normal(lstm_state_shape))
state_h = tf.Variable(tf.random_normal(lstm_state_shape))

init_state = tf.nn.rnn_cell.LSTMStateTuple(state_c, state_h)

In [4]:
# LSTM cell with output_dim units.
cell = tf.nn.rnn_cell.LSTMCell(num_units=output_dim)

# Dense projection from the cell output (output_dim) back to the input size
# (input_depth); loop_fn applies it to build the next step's input.
w = tf.Variable(tf.random_normal(shape=(output_dim, input_depth)))
b = tf.Variable(tf.random_normal(shape=(input_depth,)))

In [5]:
def loop_fn(time, cell_output, cell_state, loop_state):
    """
    Loop function for tf.nn.raw_rnn that feeds each step's projected output back
    in as the next step's input.

    raw_rnn calls it in two situations:
    - initial call (time == 0, cell_output is None): supplies the initial cell
      state (init_state), the emit structure, and the first input (go_embedding);
    - transition call (time > 0): propagates the cell state, emits the raw cell
      output, and projects that output through the dense layer (w, b) to form
      the next input.

    Args:
        time: current time step (scalar tensor supplied by raw_rnn).
        cell_output: cell output from the previous step, or None on the initial call.
        cell_state: cell state from the previous step (ignored on the initial call).
        loop_state: unused; this function always returns None as the next loop state.

    Returns:
        (element_finished, next_input, next_cell_state, emit_output, next_loop_state),
        the tuple layout tf.nn.raw_rnn expects from its loop_fn.
    """
    if cell_output is None:  # initial call, time == 0
        next_cell_state = init_state
        # On the initial call emit_output only describes the per-example emit
        # structure (a vector of size output_dim); raw_rnn adds the batch dim.
        emit_output = tf.zeros([output_dim])
        next_in = go_embedding
    else:  # transition call, time > 0
        next_cell_state = cell_state
        emit_output = cell_output
        # Project the cell output back to the input dimension to feed it back in.
        next_in = tf.matmul(cell_output, w) + b

    # Per-example boolean vector: True where that sequence has reached its length.
    element_finished = (time >= sequence_length)
    # Scalar: True only once every sequence in the batch has finished.
    finished = tf.reduce_all(element_finished)

    # Once every sequence is done the input is irrelevant, so feed zeros. Using
    # zeros_like keeps both cond branches the same shape/dtype regardless of the
    # batch size, instead of relying on the global batch_size constant.
    next_input = tf.cond(finished,
                         lambda: tf.zeros_like(next_in),
                         lambda: next_in)

    # Restore the static feature dimension for the cell's input.
    next_input.set_shape([None, input_depth])

    next_loop_state = None  # no extra loop state is carried between steps
    return (element_finished, next_input, next_cell_state, emit_output, next_loop_state)

In [6]:
# Unroll the LSTM with the custom feedback loop. raw_rnn returns
# (emitted-outputs TensorArray, final cell state, final loop state).
outputs_ta, final_state, _ = tf.nn.raw_rnn(cell=cell, loop_fn=loop_fn)
# Stack the per-step emits into a single time-major tensor.
outputs = outputs_ta.stack()

In [7]:
# Dummy time-major input. NOTE: loop_fn never reads `inputs`, so this feed
# does not affect the result; it is kept only so the placeholder is fed.
val = np.random.randn(max_time, batch_size, input_depth)

# The context manager closes the session even if a run call raises.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    v, s = sess.run(
        [outputs, final_state],
        feed_dict={inputs: val, sequence_length: [max_time] * batch_size},
    )
print(v)

[[[ 0.2121371   0.0075544   0.06673589 -0.31795973  0.11750382 -0.0698298
    0.37110379  0.34164777 -0.09560113  0.27152398  0.38576198  0.16678673
    0.32420823  0.26710132  0.19369729 -0.07265744  0.27676243 -0.02808553
    0.18484381  0.2332993 ]
  [-0.02399555 -0.42025563 -0.21000592  0.1625838   0.2922546  -0.07136432
   -0.24674603  0.33929935 -0.10335695  0.50097823 -0.05726094 -0.42289534
    0.00320673  0.01587123  0.0826601  -0.05146335  0.48511115  0.31355256
   -0.01188401 -0.02464564]
  [ 0.21417494  0.46720383  0.2124511   0.23273127  0.44537684 -0.00699579
   -0.19553687 -0.00119336  0.21286964  0.64244407  0.65323532 -0.18068221
   -0.18488395  0.28645974 -0.01738501  0.41316068  0.06192502 -0.09771585
   -0.09589762  0.00083089]
  [ 0.08162106 -0.09198731 -0.20811537  0.28658938 -0.00285493  0.03069
   -0.03515945  0.15316597  0.04902809  0.49996462  0.30966544  0.06288894
    0.57645887  0.4764435  -0.28484979  0.72399062 -0.03439834  0.62446761
    0.11203176  0.41