# Cell
TensorFlow模型分三层：Cell/Layer/Model<br/>
其中，Cell 是特殊的 Layer，用于循环神经网络（RNN），负责单个时间步的计算<br/>
通过继承 `keras.layers.Layer` 自定义层，需要实现以下方法：<br/>
- build 定义参数（权重）
- call  定义前向传播过程

这样定义的层可以直接用于 Functional API：层像函数一样作用在张量上（`y = layer(x)`），从而把多个层链式串联成模型。<br/>
目标：可以实现LSTM/GRU/RRU/DCRNN单元

In [None]:
class LSTM(keras.layers.Layer):
    """A naive single-step LSTM cell with optional Mogrifier-style input/state gating.

    Written for readability rather than speed: every gate has its own input
    kernel ``W_*``, recurrent kernel ``U_*`` and bias ``b_*`` instead of one
    fused matrix.

    References:
        tensorflow/tensorflow/python/keras/layers/legacy_rnn/rnn_cell_impl.py#BasicLSTMCell
        https://github.com/piEsposito/pytorch-lstm-by-hand/blob/master/LSTM.ipynb

    Date: MAR-29-2022

    Args:
        units: Size of the hidden state ``h`` and of the cell state ``c``.
        trans: If True, ``x`` and ``h`` gate each other (see ``mogrigy``) before
            the LSTM gates are computed; this also creates the ``Q``/``R`` weights.
        iters: Number of mutual-gating rounds; only used when ``trans`` is True.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, units=32, trans=False, iters=3, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.trans = trans
        self.iters = iters

    def get_config(self):
        """Return the constructor arguments so the layer can be serialized/re-created."""
        config = super().get_config()
        config.update({"units": self.units, "trans": self.trans, "iters": self.iters})
        return config

    def _param(self, shape):
        """Add one trainable, normally-initialised weight of the given shape."""
        return self.add_weight(shape=shape, initializer="random_normal", trainable=True)

    def build(self, input_shape):
        """Create per-gate weights; the input width is taken from ``input_shape[-1]``."""
        input_dim = input_shape[-1]
        # Keep this creation order stable: ``self.weights`` (and any order-based
        # weight loading) follows it -- (W, U, b) for gates i, f, c, o, then Q, R.
        self.W_i = self._param((input_dim, self.units))
        self.U_i = self._param((self.units, self.units))
        self.b_i = self._param((self.units,))

        self.W_f = self._param((input_dim, self.units))
        self.U_f = self._param((self.units, self.units))
        self.b_f = self._param((self.units,))

        self.W_c = self._param((input_dim, self.units))
        self.U_c = self._param((self.units, self.units))
        self.b_c = self._param((self.units,))

        self.W_o = self._param((input_dim, self.units))
        self.U_o = self._param((self.units, self.units))
        self.b_o = self._param((self.units,))

        if self.trans:
            # Q maps h -> x-space, R maps x -> h-space, for the mutual gating.
            self.Q = self._param((self.units, input_dim))
            self.R = self._param((input_dim, self.units))
        super().build(input_shape)

    def mogrigy(self, x_t, h_t):
        """Alternately rescale ``x_t`` and ``h_t`` by each other for ``self.iters`` rounds.

        Odd rounds:  x <- 2 * sigmoid(h @ Q) * x
        Even rounds: h <- 2 * sigmoid(x @ R) * h
        The factor 2 lets each gate scale its target by a value in (0, 2).
        NOTE(review): the same ``Q`` and ``R`` are reused in every round.

        Returns:
            The gated ``(x_t, h_t)`` pair.
        """
        sigmoid = math_ops.sigmoid
        for i in range(1, self.iters + 1):
            if i % 2 == 0:
                h_t = (2 * sigmoid(x_t @ self.R)) * h_t
            else:
                x_t = (2 * sigmoid(h_t @ self.Q)) * x_t
        return x_t, h_t

    # Correctly spelled alias; ``mogrigy`` is kept for existing callers.
    mogrify = mogrigy

    def call(self, inputs, init_states=None):
        """Run one LSTM step.

        Args:
            inputs: Input for this step with features on the last axis (width as
                seen in ``build``); presumably 2-D ``(batch, input_dim)``.
            init_states: Previous state laid out as ``concat([h, c], axis=1)`` --
                the same layout this method returns -- or None to start from zeros.

        Returns:
            ``(h_t, new_state)`` with ``new_state = concat([h_t, c_t], axis=1)``.
        """
        sigmoid = math_ops.sigmoid
        tanh = math_ops.tanh
        if init_states is None:
            # Batch is axis 0 (x_t @ W_* requires features last). Use the dynamic
            # shape so this also works when the static batch size is unknown.
            batch_size = tf.shape(inputs)[0]
            h_t = tf.zeros([batch_size, self.units], dtype=inputs.dtype)
            c_t = tf.zeros([batch_size, self.units], dtype=inputs.dtype)
        else:
            h_t, c_t = array_ops.split(value=init_states, num_or_size_splits=2, axis=1)

        x_t = inputs
        if self.trans:
            x_t, h_t = self.mogrigy(x_t, h_t)

        i_t = sigmoid(x_t @ self.W_i + h_t @ self.U_i + self.b_i)  # input gate
        f_t = sigmoid(x_t @ self.W_f + h_t @ self.U_f + self.b_f)  # forget gate
        g_t = tanh(x_t @ self.W_c + h_t @ self.U_c + self.b_c)     # candidate cell
        o_t = sigmoid(x_t @ self.W_o + h_t @ self.U_o + self.b_o)  # output gate
        c_t = f_t * c_t + i_t * g_t
        h_t = o_t * tanh(c_t)

        new_state = array_ops.concat([h_t, c_t], 1)
        return h_t, new_state