# Creating a RNN block

In [1]:
import numpy as np
import functools

In [2]:
def softmax(logits: np.ndarray) -> np.ndarray:
    """ Returns probabilities """
    return np.exp(logits - np.max(logits))/np.sum(np.exp(logits - np.max(logits)), axis=0, keepdims=True)

In [3]:
def relu(x: np.ndarray) -> np.ndarray:
    """ Rectified Linear Unit activation function """
    return np.fmax(0, x)

In [4]:
# test softmax
logits = np.array([2.0, 1.0, 0.1])
softmax(logits)

array([0.65900114, 0.24243297, 0.09856589])

In [5]:
def cross_entropy(y_true: np.ndarray, y_hat: np.ndarray) -> float:
    """ Cross entropy loss """
    return -np.sum(y_true * np.log(y_hat))

In [6]:
# test loss function
y_true = np.array([0, 1, 0, 0, 0])              # True distribution
y_pred = np.array([0.1, 0.6, 0.1, 0.15, 0.05])  # Predicted distribution

print(f"Cross Entropy: {cross_entropy(y_true, y_pred):.2f}")

Cross Entropy: 0.51


In [7]:
class Module:

    def __init__(self, cls) -> None:
        functools.update_wrapper(self, cls)

    def __call__(self, *args):
        return self.forward(*args)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'

In [137]:
class RNNCell(Module):

    def __init__(
        self, 
        dim_hidden_units: int, 
        dim_input: int, 
        batch_size: int, 
        dim_output: int, 
        activation=np.tanh, 
        loss=cross_entropy,
        bool_pred = True
    ) -> None:
        super().__init__(self)
        # input size is ((n_x), m) where n_x is input dimensions
        self.Wxh = np.random.randn(dim_hidden_units, dim_input) # we assume the last shape is T_x
        self.Whh = np.random.randn(dim_hidden_units, dim_hidden_units)
        self.Wy = np.random.randn(dim_output, dim_hidden_units) 
        self.ba = np.zeros((dim_hidden_units, 1))
        self.by = np.zeros((dim_output, 1))
        self.activation = activation
        self.loss = loss
        self.caches = []
        self.bool_pred = bool_pred

    def forward(self, x: np.ndarray, hidden_state_prev: np.ndarray) -> tuple[np.ndarray, np.ndarray]: 
        
        stack = np.vstack((x, hidden_state_prev)) # stack the inputs together
        Wa = np.hstack((self.Wxh, self.Whh)) # stack the matrices together
        z = Wa @ stack + self.ba
        hidden_state = self.activation(z)
        logits = self.Wy @ hidden_state + self.by
        if self.bool_pred: y_hat = softmax(logits)

        cache = {}
        cache['x'] = x
        cache['hidden_state_prev'] = hidden_state_prev
        cache['z'] = z
        cache['hidden_state'] = hidden_state
        if self.bool_pred: cache['y_hat'] = y_hat 

        self.caches.append(cache)

        return hidden_state, y_hat if self.bool_pred else hidden_state

    def reset_sequence(self) -> None:
        self.caches = []

    def compute_loss(self, y_true: np.ndarray | list[float | int]) -> float:
        outputs = [cache['y_hat'] for cache in self.caches]
        y_hats = np.stack(tuple(outputs), axis=-1)
        return np.sum(self.loss(y_true, y_hats)) / len(outputs)
    
    def parameters(self) -> list:
        return [self.Wxh, self.Whh, self.ba, self.Wy, self.by]

    def bptt(self, y_true: np.ndarray, verbose=False) -> list:
        T = len(self.caches)

        dL_dWy = np.zeros_like(self.Wy)
        dL_dby = np.zeros_like(self.by)
        dL_dWxh = np.zeros_like(self.Wxh)
        dL_dWhh = np.zeros_like(self.Whh)
        dL_dba = np.zeros_like(self.ba)
        
        dL_dht_next = np.zeros_like(self.caches[0]['hidden_state'])
        
        for t, cache in enumerate(self.caches[::-1]):
            # each cache represents a time step
            curr_hidden_state = cache['hidden_state']
            z = cache['z']
            y_pred = cache['y_hat']
            
            dL_dyhat = y_pred - y_true[:, :, t]
            dL_dht = self.Wy.T @ dL_dyhat + dL_dht_next
            dL_dtanh = dL_dht * (1 - curr_hidden_state ** 2)
                        
            dL_dWy += dL_dyhat @ curr_hidden_state.T
            dL_dby += np.sum(dL_dyhat, axis=-1, keepdims=True)
            
            x = cache['x']
            hidden_state_prev = cache['hidden_state_prev']
            dL_dWxh += dL_dtanh @ x.T  
            dL_dWhh += dL_dtanh @ hidden_state_prev.T
            dL_dba += np.sum(dL_dtanh, axis=-1, keepdims=True)
            
            dL_dht_next = self.Whh.T @ dL_dtanh
            
        return [dL_dWxh, dL_dWhh, dL_dba, dL_dWy, dL_dby]
    
        if verbose:
            print(f"dL_dyhat: {dL_dyhat}, y_hat shape: {y_pred.shape}, dL_dyhat shape: {dL_dyhat.shape}")
            print(f"dL_dWy: {dL_dWy}, Wy shape: {self.Wy.shape}, dL_dWy shape: {dL_dWy.shape}")
            print(f"dL_dby: {dL_dby}, by shape: {self.by.shape}, dL_dby shape: {dL_dby.shape}")
            print(f"dL_ht: {dL_ht}, ht shape: {curr_hidden_state.shape}, dL_ht shape: {dL_ht.shape}")
            print(f"dL_dtanh: {dL_dtanh}, z shape: {curr_hidden_state.shape}, dL_dtanh shape: {dL_dtanh.shape}")
        
    def optimizer_step(self, learning_rate: float, grads: list) -> None:
        params = self.parameters()
        for param, grad in zip(params, grads):
            param -= learning_rate * grad

In [83]:
# test rnn forward and backward
np.random.seed(42)

batch_size = 2
seq_length = 3
input_size = 4
hidden_size = 5
output_size = 3
cell = RNNCell(hidden_size, input_size, batch_size, output_size)

input_sequence = np.random.rand(input_size, batch_size, seq_length)
hidden_state = np.zeros((hidden_size, batch_size))

for t in range(seq_length):
    input_t = input_sequence[:, :, t] 
    hidden_state, y_hat = cell(input_t, hidden_state)
    print(f"Time step {t + 1}: Hidden state =\n{hidden_state.shape}, \nY_hat = \n{y_hat.shape} ")

print(f'{'=' * 100}')

y_true = np.array([[[1, 0, 0], [0, 1, 0]], [[1, 0, 0], [0, 0, 1]], [[0, 1, 0], [0, 0, 1]]])
print(f"Loss at end of sequence: {cell.compute_loss(y_true)}")
grads = cell.bptt(y_true)

Time step 1: Hidden state =
(5, 2), 
Y_hat = 
(3, 2) 
Time step 2: Hidden state =
(5, 2), 
Y_hat = 
(3, 2) 
Time step 3: Hidden state =
(5, 2), 
Y_hat = 
(3, 2) 
Loss at end of sequence: 2.1971992325361147


In [64]:
# test parameters are returned
np.random.seed(42)

batch_size = 2
seq_length = 3
input_size = 4
hidden_size = 5
output_size = 3
cell2 = RNNCell(hidden_size, input_size, batch_size, output_size)
params = cell2.parameters()

assert np.allclose(cell2.Wxh, params[0])
assert np.allclose(cell2.Whh, params[1])
assert np.allclose(cell2.ba, params[2])
assert np.allclose(cell2.Wy, params[3])
assert np.allclose(cell2.by, params[4])


In [281]:
from collections.abc import Iterable

class OneHotEncoder(Module):

    def __init__(self, num_classes: int) -> None:
        super().__init__(self)
        self.num_classes = num_classes

    def forward(self, indices: int | np.ndarray | list[int]) -> np.ndarray:
        if isinstance(indices, Iterable):
            if len(indices) > self.num_classes:
                raise ValueError("Cannot have more 1s than number of classes")
        encoding = np.zeros((self.num_classes,))
        encoding[indices] = 1 
        return encoding


In [282]:
# test onehot encoder class

encoder = OneHotEncoder(5)
encoder(3), encoder([1, 2])

(array([0., 0., 0., 1., 0.]), array([0., 1., 1., 0., 0.]))

In [140]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class TorchRNNCell(nn.Module):
    def __init__(self, dim_hidden_units, dim_input, dim_output):
        super(TorchRNNCell, self).__init__()
        self.Wxh = nn.Parameter(torch.randn(dim_hidden_units, dim_input))
        self.Whh = nn.Parameter(torch.randn(dim_hidden_units, dim_hidden_units))
        self.Wy = nn.Parameter(torch.randn(dim_output, dim_hidden_units))
        self.ba = nn.Parameter(torch.zeros(dim_hidden_units, 1))
        self.by = nn.Parameter(torch.zeros(dim_output, 1))

    def forward(self, x, hidden_state_prev):
        stack = torch.cat((x, hidden_state_prev), dim=0)
        Wa = torch.cat((self.Wxh, self.Whh), dim=1)
        z = Wa @ stack + self.ba
        hidden_state = torch.tanh(z)
        logits = self.Wy @ hidden_state + self.by
        y_hat = F.softmax(logits, dim=0)
        return hidden_state, y_hat

def compute_loss_torch(y_pred, y_true):
    batch_size, seq_len = y_true.size(1), y_true.size(2)
    y_true_reshaped = y_true.permute(1, 2, 0).contiguous().view(-1, y_true.size(0))
    y_pred_repeated = y_pred.unsqueeze(1).expand(-1, seq_len, -1).contiguous().view(-1, y_pred.size(0))
    
    if y_true_reshaped.dtype == torch.long:
        y_true_reshaped = y_true_reshaped.float()
    
    loss_fn = nn.CrossEntropyLoss()
    return loss_fn(y_pred_repeated, y_true_reshaped)


def compare_gradients(custom_rnn, torch_rnn, x_np, hidden_np, y_true_np, learning_rate=0.01):
    custom_rnn.reset_sequence()
    hidden_custom, y_hat_custom = custom_rnn.forward(x_np, hidden_np)
    custom_loss = custom_rnn.compute_loss(y_true_np)
    grads_custom = custom_rnn.bptt(y_true_np)
    
    x_torch = torch.tensor(x_np, dtype=torch.float32, requires_grad=True)
    hidden_torch = torch.tensor(hidden_np, dtype=torch.float32, requires_grad=True)
    y_true_torch = torch.tensor(y_true_np, dtype=torch.long)
    hidden_torch_out, y_hat_torch = torch_rnn(x_torch, hidden_torch)
    loss_torch = compute_loss_torch(y_hat_torch, y_true_torch)
    loss_torch.backward()

    grads_torch = {
        'Wxh': torch_rnn.Wxh.grad.detach().numpy(),
        'Whh': torch_rnn.Whh.grad.detach().numpy(),
        'ba': torch_rnn.ba.grad.detach().numpy(),
        'Wy': torch_rnn.Wy.grad.detach().numpy(),
        'by': torch_rnn.by.grad.detach().numpy()
    }

    grads_custom_dict = {
        'Wxh': grads_custom[0],
        'Whh': grads_custom[1],
        'ba': grads_custom[2],
        'Wy': grads_custom[3],
        'by': grads_custom[4]
    }

    print("Comparing gradients between Custom RNN and Torch RNN:")
    for name, grad_torch in grads_torch.items():
        grad_custom = grads_custom_dict[name]
        print(f"Gradient for {name}:")
        # print(f"Custom RNN grad:\n{grad_custom}")
        # print(f"Torch RNN grad:\n{grad_torch}")
        # print(f"Difference:\n{np.abs(grad_custom - grad_torch)}\n")
        print(f'Allclose:\n{np.allclose(grad_custom, grad_torch)}')

dim_hidden_units = 5
dim_input = 4
dim_output = 3
batch_size = 2
seq_length = 3

custom_rnn = RNNCell(dim_hidden_units, dim_input, batch_size, dim_output)
torch_rnn = TorchRNNCell(dim_hidden_units, dim_input, dim_output)

x_np = np.random.randn(dim_input, batch_size)
hidden_np = np.random.randn(dim_hidden_units, batch_size)
y_true_np = np.array([[[1, 0, 0], [0, 1, 0]], [[1, 0, 0], [0, 0, 1]], [[0, 1, 0], [0, 0, 1]]])

compare_gradients(custom_rnn, torch_rnn, x_np, hidden_np, y_true_np)

Comparing gradients between Custom RNN and Torch RNN:
Gradient for Wxh:
Allclose:
False
Gradient for Whh:
Allclose:
False
Gradient for ba:
Allclose:
False
Gradient for Wy:
Allclose:
False
Gradient for by:
Allclose:
False
