# MultiRNNCell with shared weights
A MultiRNNCell object that stacks several LSTM cells with weight sharing.

In [1]:
import tensorflow as tf  # 0.12.1
import numpy as np

from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh

from tensorflow.python.ops import variable_scope as vs

In [2]:
def reset_graph():
    """Close the global `sess` if one is open, then reset the default graph."""
    # Look the session up by name so a missing global is not an error.
    open_session = globals().get('sess')
    if open_session:
        open_session.close()
    tf.reset_default_graph()

In [3]:
'''
Custom MultiRNNCell wrapper

Exactly the same as tf.nn.rnn_cell.MultiRNNCell but without the
name scoping. The original wrapper adds /cell# to the name scope,
which would have created new variables instead of reusing existing
ones.
'''
class MultiRNNCell_shared_weights(tf.nn.rnn_cell.RNNCell):
    """RNN cell composed sequentially of multiple simple cells.

    Each wrapped cell is called directly, without opening a per-layer
    variable scope around the call, so every layer looks up its variables
    in whatever scope the caller has active.
    """

    def __init__(self, cells):
        """Store the layers to stack.

        Args:
            cells: non-empty sequence of RNN cells, applied in order.

        Raises:
            ValueError: if `cells` is empty.
        """
        # Fail early: an empty stack would only surface later as an
        # IndexError when `output_size` reads the last cell.
        if not cells:
            raise ValueError('Must specify at least one cell for '
                             'MultiRNNCell_shared_weights.')
        self._cells = cells

    @property
    def state_size(self):
        """Tuple holding the state size of every stacked cell, in order."""
        return tuple(cell.state_size for cell in self._cells)

    @property
    def output_size(self):
        """Output size of the last (top-most) cell in the stack."""
        return self._cells[-1].output_size

    def __call__(self, inputs, state, scope=None):
        """Run this multi-layer cell on inputs, starting from state.

        Args:
            inputs: input fed to the first cell.
            state: indexable collection with one state entry per cell.
            scope: accepted for interface compatibility; not used here.

        Returns:
            A pair `(output, new_states)`: the output of the last cell and
            a tuple with the new state of every cell, in order.
        """
        cur_inp = inputs
        new_states = []
        for i, cell in enumerate(self._cells):
            # The output of one layer becomes the input of the next.
            cur_inp, new_state = cell(cur_inp, state[i])
            new_states.append(new_state)
        return cur_inp, tuple(new_states)

In [4]:
def build_multicell_lstm_graph(
    state_size=8,
    batch_size=8,
    num_steps=2,
    num_classes=2,
    num_cells=2):
    '''
    Build a dynamic RNN graph with multiple LSTM cells stacked on top of
    each other.
    The LSTM cells share weights between them through the modified
    MultiRNNCell wrapper.

    Args:
        state_size: number of units per LSTM cell; also the embedding width.
        batch_size: number of sequences per batch.
        num_steps: number of time steps per sequence.
        num_classes: number of distinct input ids (embedding rows).
        num_cells: number of stacked layers sharing the one LSTM cell.

    Returns:
        dict with the graph handles needed to run the model:
        'x' (the input placeholder), 'rnn_outputs' and 'final_state'
        (the two results of tf.nn.dynamic_rnn).
    '''

    reset_graph()

    x = tf.placeholder(tf.int32, [batch_size, num_steps], name='input_placeholder')

    # shape (batch_size, num_steps, state_size)
    rnn_inputs = tf.nn.embedding_lookup(tf.random_normal([num_classes, state_size]), x)

    # Create variables
    # using the same variable names generated in BasicLSTMCell.
    # The returned handles are not needed: the variables are looked up
    # again by name once the scope is marked for reuse.
    with tf.variable_scope('RNN') as scope:  # hack
        vs.get_variable('BasicLSTMCell/Linear/Matrix',
                        [state_size + state_size, state_size * 4])
        vs.get_variable('BasicLSTMCell/Linear/Bias', [state_size * 4],
                        initializer=tf.constant_initializer(0.0))
    scope.reuse_variables()  # mark scope to reuse from now on

    # The same cell object is repeated for every layer of the stack.
    cell = tf.nn.rnn_cell.BasicLSTMCell(state_size)
    cell = MultiRNNCell_shared_weights([cell] * num_cells)  # custom wrapper

    # Cells are unrolled in tf.nn.dynamic_rnn and weights are shared between them
    rnn_outputs, final_state = tf.nn.dynamic_rnn(cell, rnn_inputs, dtype=tf.float32, scope=scope)

    # Hand the graph handles back so callers can feed `x` and fetch results.
    return {
        'x': x,
        'rnn_outputs': rnn_outputs,
        'final_state': final_state,
    }
    

In [5]:
# Build the graph with the default sizes so its variables can be counted below.
build_multicell_lstm_graph()

def count_parameters_and_variables():
    """Print and count the variables and parameters of the default graph.

    Prints one line per trainable variable (name, shape, parameter count),
    then the parameter total and the size of the 'variables' collection.

    Returns:
        Tuple `(number_of_variables, total_parameters)`.
    """
    param_total = 0
    for var in tf.trainable_variables():
        var_shape = var.get_shape()
        # Parameter count of a variable is the product of its dimensions.
        n_params = 1
        for axis in var_shape:
            n_params *= axis.value
        print('{}, shape={}, params={}'.format(var.name, var_shape, n_params))
        param_total += n_params
    print('total_parameters:', param_total)

    graph_variables = tf.get_default_graph().get_collection('variables')
    num_variables = len(graph_variables)
    print(num_variables, 'variables')

    return num_variables, param_total

# With weight sharing the stacked cells must use one weight matrix and one
# bias in total, so the graph may hold exactly two variables.
assert count_parameters_and_variables()[0] == 2


RNN/BasicLSTMCell/Linear/Matrix:0, shape=(16, 32), params=512
RNN/BasicLSTMCell/Linear/Bias:0, shape=(32,), params=32
total_parameters: 544
2 variables
