# Creating a RNN block

In [1]:
import numpy as np
import functools

In [2]:
def softmax(logits: np.ndarray) -> np.ndarray:
    """ Returns probabilities """
    return np.exp(logits - np.max(logits))/np.sum(np.exp(logits - np.max(logits)), axis=0, keepdims=True)

In [3]:
def relu(x: np.ndarray) -> np.ndarray:
    """ Rectified Linear Unit activation function """
    return np.fmax(0, x)

In [4]:
# test softmax
logits = np.array([2.0, 1.0, 0.1])
softmax(logits)

array([0.65900114, 0.24243297, 0.09856589])

In [5]:
def cross_entropy(y_true: np.ndarray, y_hat: np.ndarray) -> float:
    """ Cross entropy loss """
    return -np.sum(y_true * np.log(y_hat))

In [6]:
# test loss function
y_true = np.array([0, 1, 0, 0, 0])              # True distribution
y_pred = np.array([0.1, 0.6, 0.1, 0.15, 0.05])  # Predicted distribution

print(f"Cross Entropy: {cross_entropy(y_true, y_pred):.2f}")

Cross Entropy: 0.51


In [7]:
class Module:

    def __init__(self, cls) -> None:
        functools.update_wrapper(self, cls)

    def __call__(self, *args):
        return self.forward(*args)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

In [66]:
class RNNCell(Module):

    def __init__(
        self, 
        dim_hidden_units: int, 
        dim_input: int, 
        batch_size: int, 
        dim_output: int, 
        activation=np.tanh, 
        loss=cross_entropy,
        bool_pred = True
    ) -> None:
        super().__init__(self)
        # input size is ((n_x), m) where n_x is input dimensions
        self.Wxh = np.random.randn(dim_hidden_units, dim_input) * 0.01 # we assume the last shape is T_x
        self.Whh = np.random.randn(dim_hidden_units, dim_hidden_units) * 0.01
        self.Wy = np.random.randn(dim_output, dim_hidden_units) * 0.01 
        self.ba = np.zeros((dim_hidden_units, 1))
        self.by = np.zeros((dim_output, 1))
        self.activation = activation
        self.loss = loss
        self.caches = []
        self.bool_pred = bool_pred

    def forward(self, x: np.ndarray, hidden_state_prev: np.ndarray) -> tuple[np.ndarray, np.ndarray]: 
        
        stack = np.vstack((x, hidden_state_prev)) # stack the inputs together
        Wa = np.hstack((self.Wxh, self.Whh)) # stack the matrices together
        z = Wa @ stack + self.ba
        hidden_state = self.activation(z)
        logits = self.Wy @ hidden_state + self.by
        if self.bool_pred: y_hat = softmax(logits)

        cache = {}
        cache['x'] = stack
        cache['z'] = z
        cache['hidden_state'] = hidden_state
        if self.bool_pred: cache['y_hat'] = y_hat 

        self.caches.append(cache)

        return hidden_state, y_hat if self.bool_pred else hidden_state

    def reset_sequence() -> None:
        self.outputs = []
        self.hidden_states = []

    def compute_loss(self, y_true: np.ndarray | list[float | int]) -> float:
        outputs = [cache['y_hat'] for cache in self.caches]
        y_hats = np.stack(tuple(outputs), axis=-1)
        return np.sum(self.loss(y_true, y_hats)) / len(outputs)
    
    def parameters(self) -> list:
        return [self.Wxh, self.Whh, self.ba, self.Wy, self.by]

    def bptt(self, y_pred, y_true) -> None:
        T = len(self.caches)
        
        dl_dhts = []
        
        dL_dWy = np.zeros_like(self.Wy)
        dL_dby = np.zeros_like(self.by)
        
        for t, cache in enumerate(self.caches):
            # each cache represents a time step
            curr_hidden_state = cache['hidden_state']
            z = cache['z']
            
            dL_dyhat = (y_pred - y_true[:, :, t])
            dL_dWy += dL_dyhat @ curr_hidden_state.T
            dL_dby += np.sum(dL_dyhat, axis=-1, keepdims=True)
            dL_ht = (self.Whh.T)**(T-t+1) @ self.Wy.T @ dL_dyhat
            dL_dtanh = dL_ht @ (1 - self.activation(z)**2).T

        print(f"dL_dyhat: {dL_dyhat}, y_hat shape: {y_pred.shape}, dL_dyhat shape: {dL_dyhat.shape}")
        print(f"dL_dWy: {dL_dWy}, Wy shape: {self.Wy.shape}, dL_dWy shape: {dL_dWy.shape}")
        print(f"dL_dby: {dL_dby}, by shape: {self.by.shape}, dL_dby shape: {dL_dby.shape}")
        print(f"dL_ht: {dL_ht}, ht shape: {curr_hidden_state.shape}, dL_ht shape: {dL_ht.shape}")
        print(f"dL_dtanh: {dL_dtanh}, z shape: {curr_hidden_state.shape}, dL_dtanh shape: {dL_dtanh.shape}")

In [65]:
# test rnn forward and backward
np.random.seed(42)

batch_size = 2
seq_length = 3
input_size = 4
hidden_size = 5
output_size = 3
cell = RNNCell(hidden_size, input_size, batch_size, output_size)

input_sequence = np.random.rand(input_size, batch_size, seq_length)
hidden_state = np.zeros((hidden_size, batch_size))

for t in range(seq_length):
    input_t = input_sequence[:, :, t] 
    hidden_state, y_hat = cell(input_t, hidden_state)
    print(f"Time step {t + 1}: Hidden state =\n{hidden_state.shape}, \nY_hat = \n{y_hat.shape} ")

print(f'{'=' * 100}')

y_true = np.array([[[1, 0, 0], [0, 1, 0]], [[1, 0, 0], [0, 0, 1]], [[0, 1, 0], [0, 0, 1]]])
print(f"Loss at end of sequence: {cell.compute_loss(y_true)}")
cell.bptt(y_hat, y_true)

Time step 1: Hidden state =
(5, 2), 
Y_hat = 
(3, 2) 
Time step 2: Hidden state =
(5, 2), 
Y_hat = 
(3, 2) 
Time step 3: Hidden state =
(5, 2), 
Y_hat = 
(3, 2) 
Loss at end of sequence: 2.1971992325361147
dL_dyhat: [[ 0.33336584  0.33336301]
 [ 0.33331019 -0.66669399]
 [ 0.33332397 -0.66666902]], y_hat shape: (3, 2), dL_dyhat shape: (3, 2)
dL_dWy: [[-0.00139308  0.00110765 -0.0009676   0.00908511  0.00219662]
 [ 0.00821907  0.00413507 -0.00278707  0.0033507  -0.00675695]
 [ 0.00102602  0.00217657 -0.00064715 -0.0056532  -0.00240306]], Wy shape: (3, 5), dL_dWy shape: (3, 5)
dL_dby: [[ 1.86556336e-04]
 [-1.51406478e-04]
 [-3.51498581e-05]], by shape: (3, 1), dL_dby shape: (3, 1)
dL_ht: [[ 0.00178467 -0.0107689 ]
 [-0.00561644  0.00662652]
 [ 0.00023716  0.01009848]
 [ 0.00428847 -0.00514094]
 [ 0.00081078 -0.01925459]], ht shape: (5, 2), dL_ht shape: (5, 2)
dL_dtanh: [[-0.00898381 -0.00898439 -0.00898424 -0.00898284 -0.00898365]
 [ 0.00101028  0.0010106   0.00101017  0.00100989  0.00101

In [64]:
# test parameters are returned
np.random.seed(42)

batch_size = 2
seq_length = 3
input_size = 4
hidden_size = 5
output_size = 3
cell2 = RNNCell(hidden_size, input_size, batch_size, output_size)
params = cell2.parameters()

assert np.allclose(cell2.Wxh, params[0])
assert np.allclose(cell2.Whh, params[1])
assert np.allclose(cell2.ba, params[2])
assert np.allclose(cell2.Wy, params[3])
assert np.allclose(cell2.by, params[4])


In [281]:
from collections.abc import Iterable

class OneHotEncoder(Module):

    def __init__(self, num_classes: int) -> None:
        super().__init__(self)
        self.num_classes = num_classes

    def forward(self, indices: int | np.ndarray | list[int]) -> np.ndarray:
        if isinstance(indices, Iterable):
            if len(indices) > self.num_classes:
                raise ValueError("Cannot have more 1s than number of classes")
        encoding = np.zeros((self.num_classes,))
        encoding[indices] = 1 
        return encoding


In [282]:
# test onehot encoder class

encoder = OneHotEncoder(5)
encoder(3), encoder([1, 2])

(array([0., 0., 0., 1., 0.]), array([0., 1., 1., 0., 0.]))