# Creating an RNN block

In [2]:
import numpy as np

In [35]:
def softmax(logits: np.ndarray, axis: int = 0) -> np.ndarray:
    """Return the softmax of *logits* along *axis*.

    Args:
        logits: Array of unnormalised scores. For a 2-D input with the
            default ``axis=0`` every column is normalised independently.
        axis: Axis along which the probabilities sum to 1. Defaults to 0.

    Returns:
        Array with the same shape as *logits* whose entries along *axis*
        are non-negative and sum to 1.
    """
    logits = np.asarray(logits)
    # Shift by the maximum taken along the SAME axis that is normalised.
    # A single global maximum can push an entire slice to exp(-large) == 0,
    # which turns the division below into 0/0 (nan) for that slice.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)  # computed once and reused for the denominator
    return exps / np.sum(exps, axis=axis, keepdims=True)

In [36]:
# Sanity check: softmax of a small 1-D logit vector (the trailing bare
# expression lets the notebook echo the resulting probabilities)
logits = np.array([2.0, 1.0, 0.1])
softmax(logits)

array([0.65900114, 0.24243297, 0.09856589])

In [85]:
class RNNCell:
    """One step of a vanilla RNN followed by a softmax output layer.

    The cell does not concatenate its inputs itself: the caller passes the
    current input stacked on top of the previous hidden state along axis 0.
    """

    def __init__(self, dim_hidden_units: int, dim_input: int, batch_size: int, dim_output: int, activation=np.tanh) -> None:
        """Initialise the cell's parameters.

        Args:
            dim_hidden_units: Number of hidden units.
            dim_input: Number of input features per time step.
            batch_size: Not used by the cell; kept so existing callers that
                pass it positionally keep working.
            dim_output: Number of output units fed to the softmax.
            activation: Element-wise function applied to the hidden
                pre-activation. Defaults to ``np.tanh``.
        """
        # Wa multiplies the stacked [input; previous hidden state] column(s),
        # so its second axis is dim_input + dim_hidden_units.
        self.Wa = np.random.randn(dim_hidden_units, dim_input + dim_hidden_units)
        self.Wy = np.random.randn(dim_output, dim_hidden_units)
        self.ba = np.zeros((dim_hidden_units, 1))
        # One bias per output unit (previously a single (1, 1) scalar shared
        # by all outputs). Still broadcasts across the batch axis.
        self.by = np.zeros((dim_output, 1))
        self.activation = activation

    def __call__(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Run one time step.

        Args:
            x: Current input stacked on the previous hidden state, with first
                axis of length ``dim_input + dim_hidden_units`` (required by
                the ``Wa @ x`` product).

        Returns:
            ``(hidden_state, y_hat)``: the new hidden state and the softmax
            of the output layer computed from it.
        """
        hidden_state = self.activation(self.Wa @ x + self.ba)
        y_hat = softmax(self.Wy @ hidden_state + self.by)
        return hidden_state, y_hat

    def __repr__(self) -> str:
        # Recover the layer sizes from the weight shapes so the repr stays
        # accurate even if the weights are replaced after construction.
        dim_hidden_units, stacked_dim = self.Wa.shape
        dim_input = stacked_dim - dim_hidden_units
        dim_output = self.Wy.shape[0]
        return (
            f"{type(self).__name__}(dim_hidden_units={dim_hidden_units}, "
            f"dim_input={dim_input}, dim_output={dim_output})"
        )



In [86]:
# Smoke test: unroll a freshly initialised RNNCell over a short random sequence
batch_size = 2
seq_length = 3
input_size = 4
hidden_size = 5
output_size = 3
cell = RNNCell(
    dim_hidden_units=hidden_size,
    dim_input=input_size,
    batch_size=batch_size,
    dim_output=output_size,
)

# Random inputs laid out as (features, batch, time); initial hidden state is zero
input_sequence = np.random.rand(input_size, batch_size, seq_length)
hidden_state = np.zeros((hidden_size, batch_size))

for t in range(seq_length):
    # Slice out time step t and stack it on top of the previous hidden state
    input_t = input_sequence[..., t]
    inp = np.concatenate((input_t, hidden_state), axis=0)
    hidden_state, y_hat = cell(inp)
    print(f"Time step {t + 1}: Hidden state =\n{hidden_state}, \nY_hat = \n{y_hat} ")


Time step 1: Hidden state =
[[-0.32622662  0.00199797]
 [-0.752167   -0.91085786]
 [ 0.92956147  0.95214357]
 [-0.82340251 -0.6967403 ]
 [ 0.26984353  0.72060859]], 
Y_hat = 
[[0.58785492 0.60734337]
 [0.04530643 0.01968375]
 [0.36683865 0.37297289]] 
Time step 2: Hidden state =
[[ 0.20285604  0.04999598]
 [-0.99912931 -0.99962968]
 [ 0.35454741  0.86044744]
 [-0.99502696 -0.99255363]
 [ 0.93684784  0.92218082]], 
Y_hat = 
[[0.08397552 0.35150295]
 [0.01581646 0.01613345]
 [0.90020801 0.63236359]] 
Time step 3: Hidden state =
[[ 0.05282222 -0.3378512 ]
 [-0.99925038 -0.99956888]
 [ 0.52258449  0.58295158]
 [-0.97999171 -0.9883243 ]
 [ 0.87255972  0.57738154]], 
Y_hat = 
[[0.13478624 0.13023652]
 [0.01940477 0.03385962]
 [0.84580899 0.83590386]] 
