In [1]:
import numpy as np

# Implement RNN

## Implement RNN CELL

In [2]:
class RnnCell(object):
    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.w = np.random.randn(self.input_size + self.hidden_size, self.hidden_size) * 0.1
        self.b = np.random.randn(self.hidden_size) * 0.1
        self.init_state = [np.zeros((1, self.hidden_size))]
    
    def forward_pass(self, x, h_state):
        """
        x.shape [batch, input_size]
        h_state.shape [batch, hidden_size]
        return : [batch, hidden_size]
        """
        # Concatenate the input x at the current moment and the state h at the previous moment
        x = np.concatenate([x, h_state[0]], axis = 1) 
        hidden = np.tanh(np.dot(x, self.w) + self.b) # Linear mapping + activation
        return (hidden, ) # Tuple

## Implement the RNN operation function

In [5]:
from copy import deepcopy
def rnn(cell, x, bidirectional = False):
    """
    x.shape [sep, batch, feature]
    """
    h = tuple([np.repeat(s, x.shape[1], axis = 0) for s in cell.init_state])
    time = x.shape[0]
    states = []
    for i in range(time): # Perform cyclic calculations along the time dimension
        # Each time, input the current moment's input and the previous moment's state to perform an RNN calculation.
        h = cell.forward_pass(x[i], h) 
        states.append([h]) # Save
    if bidirectional:
        seq_len = x.shape[0]
        bp_cell = deepcopy(cell) # Deep copy means that the parameters of the two cells are different
        h = tuple([np.repeat(s, x.shape[1], axis = 0) for s in cell.init_state])
        for i in range(time): #Perform loop calculations in reverse
            h = cell.forward_pass(x[seq_len-i-1], h)
            states[seq_len-i-1].append(h) # Add the reverse simultaneous state
    return states

In [10]:
x = np.random.random((6, 5, 3))
cell = RnnCell(3, 4)
states = rnn(cell, x, True)

In [11]:
states[-1]

[(array([[ 0.05295692,  0.12319105,  0.08179296, -0.06917353],
         [ 0.06583084,  0.09899176,  0.09616207, -0.05834675],
         [ 0.01434563,  0.14206171,  0.09464099, -0.06203464],
         [ 0.05191536,  0.19296807,  0.03256816, -0.11120403],
         [ 0.00823994,  0.16062523,  0.05338136, -0.07156589]]),),
 (array([[ 0.05696254,  0.11989386,  0.09707795, -0.0730331 ],
         [ 0.07205208,  0.105245  ,  0.09947388, -0.07037611],
         [ 0.01604412,  0.14718786,  0.09916585, -0.07492343],
         [ 0.05709923,  0.18072004,  0.05950846, -0.10578757],
         [ 0.02741419,  0.15821804,  0.07174869, -0.06848138]]),)]

# 实现LSTM cell

In [14]:
def sigmoid(x):
    return 1.0 / (1.0 + 1.0 / np.exp(x))

class LSTMCell(object):
    def __init__(self, input_size, hidden_size):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.w = [np.random.rand((self.input_size + self.hidden_size), self.hidden_size) * 0.1 for i in range(4)]
        self.b = [np.random.rand(self.hidden_size) * 0.1 for i in range(4)]
        self.init_state = [np.zeros((1, cell.hidden_size)), np.zeros((1, cell.hidden_size))]
    
    def forward_pass(self, x, states):
        """
        x.shape [batch, input_size]
        h_state.shape [batch, hidden_size]
        return : [batch, hidden_size]
        """
        h, c = states
        x = np.concatenate([x, h], axis = 1)
        f, i, j, o = [np.dot(x, w) + b for w,b in zip(self.w, self.b)]
        c = sigmoid(f) # Forgetting gate
        c = c + sigmoid(i) * np.tanh(j) # Input gate
        h = sigmoid(o) * np.tanh(c) # output gate
        return h,c

In [15]:
x = np.random.random((6, 2, 3))
cell2=LSTMCell(3,4)
rnn(cell2,x,True)

[[(array([[0.28921798, 0.27993713, 0.29234594, 0.30650183],
          [0.30886608, 0.29316847, 0.30652414, 0.321433  ]]),
   array([[0.63268652, 0.57361247, 0.60852438, 0.64132404],
          [0.67401638, 0.59098911, 0.62642999, 0.6797971 ]])),
  (array([[0.32264645, 0.30116147, 0.33165184, 0.33920808],
          [0.34155844, 0.31450075, 0.34601315, 0.35357404]]),
   array([[0.70712011, 0.6035634 , 0.67629864, 0.69746618],
          [0.74794576, 0.62103623, 0.69443647, 0.73579859]]))],
 [(array([[0.31887863, 0.2974754 , 0.32613862, 0.33440985],
          [0.31534026, 0.29609006, 0.3258635 , 0.33124649]]),
   array([[0.69829523, 0.59801867, 0.66521097, 0.68978974],
          [0.69026393, 0.59733434, 0.66825693, 0.68156015]])),
  (array([[0.31990264, 0.29797051, 0.32729482, 0.33535472],
          [0.31417323, 0.29518722, 0.32443083, 0.33001814]]),
   array([[0.70040576, 0.59856126, 0.6673301 , 0.69132165],
          [0.68749321, 0.59588047, 0.66589602, 0.67933629]]))],
 [(array([[0.30165