텐서플로우 Cell 오브젝트를 모방해서 만들어보자 
상속을 받아서 셀을 구성해보자

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
class cell(object):
    """Toy recurrent cell: a single dense layer over [inputs, states]."""

    def __init__(self):
        # input_hidden is stored but not read anywhere inside this class.
        self.layer_size, self.input_hidden = 10, 15

    def __call__(self, inputs, states):
        """Run one step of the cell.

        Args:
            inputs: tensor for the current step.
            states: tensor holding the previous state; concatenated with
                ``inputs`` along the last axis.

        Returns:
            A pair ``(output, next_state)``; both are the same dense-layer
            result of width ``self.layer_size``.
        """
        # AUTO_REUSE lets repeated calls share the dense layer's variables.
        with tf.variable_scope('cell', reuse=tf.AUTO_REUSE):
            merged = tf.concat([inputs, states], axis=-1)
            projected = tf.layers.dense(merged, units=self.layer_size)
        return projected, projected

In [3]:
# Single shared instance of the toy cell; the loop bodies later in this file call `cell_`.
cell_ = cell()

In [None]:
def make_rnn(cell):
    """Unroll ``cell`` over axis 1 of ``inputs`` using ``tf.while_loop``.

    NOTE: ``inputs`` and ``initial_state`` are not parameters. They are
    looked up in the enclosing (module) scope when this function runs, so
    both must already be defined by the caller's environment.

    Args:
        cell: callable invoked as ``cell(step_input, state)`` and returning
            ``(output, next_state)``.

    Returns:
        outputs: the ``tf.TensorArray`` with one entry written per step
            (call ``.stack()`` on it to obtain a tensor).
        state: the state produced by the final loop iteration.
    """

    def cond(i, inputs, *args):
        # Continue while the step index is below the length of axis 1.
        return tf.less(i, inputs.shape[1])

    def body(i, input_, state_, outputs):
        step_input = tf.gather(input_, i, axis=1)
        # Use the cell handed to make_rnn, not a module-level instance.
        out, states = cell(step_input, state_)
        outputs = outputs.write(i, out)
        i = tf.add(i, 1)
        return i, input_, states, outputs

    outputs = tf.TensorArray(dtype=tf.float32, size=inputs.shape[1])

    _, _, state, outputs = tf.while_loop(
        cond, body, [tf.constant(0), inputs, initial_state, outputs])

    return outputs, state

In [4]:
# layer_size sets the last dimension of both demo tensors below.
layer_size= 10
# NOTE(review): input_hidden is not referenced again in the visible code.
input_hidden = 15

# Constant demo input of shape [2, 3, layer_size], every element 4.0.
inputs = tf.constant(4, shape=[2,3,layer_size], dtype=tf.float32)
# All-zero initial state of shape [2, layer_size].
initial_state = tf.constant(0, shape=[2,layer_size], dtype=tf.float32)

def cond(i, inputs, *args):
    """Loop predicate: true while step index ``i`` is below ``inputs.shape[1]``.

    Any further loop variables are accepted through ``*args`` and ignored.
    """
    n_steps = inputs.shape[1]
    keep_going = tf.less(i, n_steps)
    return keep_going

def body(i, input_, state_, outputs):
    """One loop iteration: feed slice ``i`` (along axis 1) to ``cell_``.

    Returns the incremented index, the untouched input, the new state and
    the TensorArray with this step's output written at position ``i``.
    """
    step_slice = tf.gather(input_, i, axis=1)
    step_out, next_state = cell_(step_slice, state_)
    filled = outputs.write(i, step_out)
    next_i = tf.add(i, 1)
    return next_i, input_, next_state, filled

# One TensorArray slot per index along axis 1 of `inputs`.
outputs = tf.TensorArray(dtype=tf.float32, size=inputs.shape[1])

# Start at step 0; keep only the final state and the filled TensorArray.
_, _, state, outputs = tf.while_loop(cond, body, [tf.constant(0), inputs, initial_state, outputs])

In [5]:
# Initialize the variables, then evaluate the stacked per-step outputs with
# axes 0 and 1 swapped and print the resulting shape.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.transpose(outputs.stack(),[1,0,2])).shape)

(2, 3, 10)


Cell object를 직접 만드는 것보다 TensorFlow의 cell object를 상속받아서 cell을 수정하는 것이, 다양한 Wrapper를 그대로 사용할 수 있기 때문에 더 효율적일 수 있다.

In [1]:
import tensorflow as tf
from collections import namedtuple

In [2]:
# Two-field state container: (cell_state, hidden_state), in that order.
StateTuple = namedtuple(typename='StateTuple',
                        field_names=('cell_state', 'hidden_state'))

In [3]:
class LSTMCell(object):
    """Basic LSTM cell.

    Args:
        num_unit: layer size of the cell; used as both the state width and
            the output width.
        scope: name of the variable scope under which the cell's dense
            layer is created (and re-used on later calls via AUTO_REUSE).
    """

    def __init__(self, num_unit, scope):
        self._state_size = num_unit
        self._output_size = num_unit
        self._scope = scope

    def __call__(self, input_, state_tuple):
        """Run one LSTM step.

        Args:
            input_ : 2-D tensor, [Batch_size, Input_embedding_size]
            state_tuple : namedtuple, which has two properties
                ['cell_state', 'hidden_state']
                cell_state   : 2-D tensor, [Batch_size, state_size]
                hidden_state : 2-D tensor, [Batch_size, state_size]

        Returns:
            output: the next hidden state.
            next_state_tuple: StateTuple(next_cell_state, next_hidden_state).
        """
        cell_state = state_tuple.cell_state
        hidden_state = state_tuple.hidden_state

        concat = tf.concat([input_, hidden_state], axis=-1)

        sigmoid = tf.sigmoid
        tanh = tf.tanh

        with tf.variable_scope(self._scope, reuse=tf.AUTO_REUSE):
            # A single dense layer yields the pre-activations of all four
            # gates at once. The layer name 'kernel' is kept so variable
            # names stay the same as before.
            gates = tf.layers.dense(
                concat,
                units=self.state_size * 4,
                activation=None,
                # Float exponent: `-1/2` would be -1 under Python 2 division.
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=self.state_size ** -0.5),
                bias_initializer=tf.constant_initializer(),
                name='kernel')

            # Split order: forget gate, input gate, candidate, output gate.
            f, i, C, o = tf.split(gates, 4, axis=-1)

            after_forget = tf.multiply(cell_state, sigmoid(f))
            next_cell_state = tf.add(after_forget,
                                     tf.multiply(sigmoid(i), tanh(C)))
            next_hidden_state = tf.multiply(tanh(next_cell_state), sigmoid(o))

            output = next_hidden_state
            next_state_tuple = StateTuple(next_cell_state, next_hidden_state)

        return output, next_state_tuple

    def get_initial_state(self, batch_size, dtype=None):
        """Return an all-zero StateTuple of shape [batch_size, state_size].

        Args:
            batch_size: size of the first dimension of each state tensor.
            dtype: dtype of the state tensors. When None, tf.float32 is used;
                a literal 0 with no dtype would otherwise produce an integer
                constant that cannot be combined with the float gates.
        """
        if dtype is None:
            dtype = tf.float32

        shape = [batch_size, self._state_size]
        cell_state = tf.constant(0, dtype=dtype, shape=shape)
        hidden_state = tf.constant(0, dtype=dtype, shape=shape)

        return StateTuple(cell_state, hidden_state)

    @property
    def state_size(self):
        """Width of each state tensor (``num_unit``)."""
        return self._state_size

    @property
    def output_size(self):
        """Width of the cell output (``num_unit``)."""
        return self._output_size

In [4]:
def build_rnn(cell, input_):
    """Unroll ``cell`` over the time axis of ``input_`` with ``tf.while_loop``.

    Args:
        cell: cell object; must be callable as ``cell(step_input, state)``
            returning ``(output, next_state)`` and must provide
            ``get_initial_state(batch_size, dtype)``.
        input_: 3-D tensor that has shape of [Batch, Time, Embedding_size].
            The batch and time sizes are read from its static shape.

    Returns:
        outputs: 3-D tensor that has shape of [Batch, Time, Layer_hidden_size]
        last_state: the state returned by the cell at the last time step
            (same structure as ``cell.get_initial_state(...)``).
    """
    time_step = input_.shape[1]
    i = tf.constant(0, dtype=tf.int32)

    # clear_after_read=False keeps entries readable after stack().
    _output_tensor_array = tf.TensorArray(dtype=input_.dtype,
                                          size=time_step,
                                          clear_after_read=False)

    def _cond(i, *args):
        return tf.less(i, time_step)

    def _body(i, input_, state_tuple, _output_tensor_array):
        out, next_state = cell(input_[:, i, :], state_tuple)
        _output_tensor_array = _output_tensor_array.write(i, out)
        i = i + 1

        return i, input_, next_state, _output_tensor_array

    initial_vars = [i,
                    input_,
                    cell.get_initial_state(input_.shape[0], input_.dtype),
                    _output_tensor_array]

    _, _, last_state, _outputs = tf.while_loop(_cond, _body, initial_vars)

    # stack() is time-major ([Time, Batch, Hidden]); swap to batch-major.
    # No close() call here: its returned op was previously discarded and
    # therefore never executed.
    outputs = tf.transpose(_outputs.stack(), [1, 0, 2])

    return outputs, last_state

In [5]:
# LSTM cell with num_unit=256 whose variables are created under the 'LSTM_1' scope.
cell = LSTMCell(256,'LSTM_1')

In [7]:
# Unroll the cell over a constant float32 input of shape [10, 12, 15], every element 1.0.
outputs, last_state = build_rnn(cell, tf.constant(1.0, dtype=tf.float32, shape=[10,12,15]))

In [42]:
class MultiCell(object):
    """Stack of cells applied one after another within a single step.

    The output of each cell is fed as the input of the next one. The state
    of the stack is a tuple holding one entry per cell, in the same order
    as ``cells``.

    Args:
        cells: non-empty sequence of cell objects. Each must be callable as
            ``cell(input, state) -> (output, next_state)`` and expose
            ``state_size``, ``output_size`` and
            ``get_initial_state(batch_size, dtype)``.

    Raises:
        ValueError: if ``cells`` is empty.
    """

    def __init__(self, cells):
        cells = list(cells)
        if not cells:
            raise ValueError('MultiCell requires at least one cell')

        self._cells = cells
        self._state_size = sum(cell.state_size for cell in cells)
        # The stack's output is whatever the last cell emits.
        self._output_size = cells[-1].output_size

    def __call__(self, prev_input, prev_state):
        """Run one step through every cell in order.

        Args:
            prev_input: input for the first cell.
            prev_state: sequence with one state per cell.

        Returns:
            output: output of the last cell.
            next_state: tuple of the new per-cell states.

        Raises:
            ValueError: if ``prev_state`` does not hold exactly one state
                per cell.
        """
        prev_states = tuple(prev_state)
        if len(prev_states) != len(self._cells):
            raise ValueError(
                'expected %d states, got %d'
                % (len(self._cells), len(prev_states)))

        next_states = []
        cur_output = prev_input
        for cell, state in zip(self._cells, prev_states):
            cur_output, cur_state = cell(cur_output, state)
            next_states.append(cur_state)

        return cur_output, tuple(next_states)

    def get_initial_state(self, batch_size, dtype=None):
        """Return a tuple with each cell's own initial state, in cell order."""
        return tuple(cell.get_initial_state(batch_size, dtype)
                     for cell in self._cells)

    @property
    def state_size(self):
        """Sum of the ``state_size`` values of all stacked cells."""
        return self._state_size

    @property
    def output_size(self):
        """``output_size`` of the last cell in the stack."""
        return self._output_size