In [3]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import warnings

sns.set_theme(style="darkgrid")
warnings.filterwarnings("ignore")

## RNN

In [4]:
class RNN:
    def __init__(self, x_embd, h_embd, y_embd, lr=1e-3):
        self.x_embd = x_embd
        self.h_embd = h_embd
        self.y_embd = y_embd
        self._initialize_parameters(x_embd, h_embd, y_embd)
        self.lr = lr

    def _initialize_parameters(self, 
                               x_embd,
                               h_embd,
                               y_embd):
        """
        Initializing parameters for our RNN
        """
        Wxh = np.random.randn(h_embd, x_embd) * 0.01
        Whh = np.random.randn(h_embd, h_embd) * 0.01
        Why = np.random.randn(y_embd, h_embd) * 0.01
        bh = np.random.randn(h_embd, 1)
        by = np.random.randn(y_embd, 1)

        self.parameters = {
            "Wxh": Wxh,
            "Whh": Whh,
            "Why": Why,
            "bh": bh,
            "by": by,
        }

        self.gradients = dict()

        self.gradients["Wxh"] = np.zeros_like(self.parameters["Wxh"])
        self.gradients["Why"] = np.zeros_like(self.parameters["Why"])
        self.gradients["Whh"] = np.zeros_like(self.parameters["Whh"])
        self.gradients["by"] = np.zeros_like(self.parameters["by"])
        self.gradients["bh"] = np.zeros_like(self.parameters["bh"])
        self.gradients["dnext_h"] = np.zeros_like(self.parameters["bh"])


    def _softmax(self, x):
        """
        x: numpy array of shape (-1, 1)
        """
        temp = np.exp(x - np.max(x))
        return temp/ temp.sum(axis=0)

    def _mse(self, x):
        pass
    
    def __one_step_forward(self, h_prev, x):
        """
        Does one step forward for the RNN
        Returns: logits
        """
        h = np.tanh(np.dot(self.parameters["Wxh"], x) + np.dot(self.parameters["Whh"], h_prev) + self.parameters["bh"]) # (h_embd, 1)
        logits = np.dot(self.parameters["Why"], h) + self.parameters["by"] # (y_embd, 1)
        # preds = self._softmax(logits)

        return h, logits
    
    def forward(self, X, Y):
        """
        Implementation of a full forward pass of the RNN
        Args:
            X: list of indices (C, 1), Here C is the nubmer of examples
            Y: list of indices (C, 1)

        """
        self.x, self.y_hat, self.h, self.y = {}, {}, {}, {}

        # zeroth hidden state (h0)
        self.h[-1] = np.zeros((self.h_embd, 1))

        for i in range(len(X)):
            self.x[i] = np.zeros((self.x_embd, 1))

            # Make it a one-hot vector now
            self.x[i][X[i]] = 1


            # ----------------
            # same for y
            self.y[i] = np.zeros((self.y_embd, 1))
            self.y[i][Y[i]] = 1
            # -----------------


            self.h[i], self.y_hat[i] = self.__one_step_forward(h_prev=self.h[i-1],
                                                               x=self.x[i])
            
            # loss +=  Optional...

    def _one_step_backward(self, x, h, h_prev, dy):
        """
        Args:
            dy: dL/d y_pred
        One backward step, and accumulation of the needed gradients
        equations are:
        
        _h(t) = Wxh @ x + Whh @ h_prev + bh
        h(t) = tanh(_h(t)) + dnext_h

        # --------------------
        ### IMPORTANT HERE
        The dnext_h is the contribution of the next hidden state i.e., h(t+1)
        Now the equations...
        h(t+1) = Whh X h(t)
        now we already know dL/ dh(t+1) which is simply dnext_h
        so dh(t+1)/ dh(t) = Whh
        now dL/ dh(t) = Whh X dnext_h (which is form the succeeding hidden layers.)
        that's why we have to add it like this :).
        # --------------------

        y_pred(t) = Why @ h(t) + by

        Loss = (y - y_pred) ** 2/2 (For simplicity rn)

        Now for gradients:
        
        dL/d y_pred = (y - y_pred) -> (y_embd, 1)
        dL/d Why = (y - y_pred) @ h(t).T -> (y_embd, 1) X (1, h_embd) -> (y_embd, h_embd)
        dL/d by = (y - y_pred)

        dL/d h(t) = Why.T @ (y - y_pred) -> (h_embd, y_embd) X (y_embd, 1) -> (h_embd, 1)
        dL/d _h(t) = dL/ d h(t) * (1 - h(t) ** 2) -> (h_embd, 1)
        dL/d Wxh = dL/d _h(t) @ x.T -> (h_embd, 1) @ (1, x_embd) -> (h_embd, x_embd)
        dL/d bh = dL/d _h(t) -> (h_embd, 1)

        dL/d Whh = dL/d _h(t) @ h_prev.T -> (h_embd, 1)

        """
        self.gradients["Why"] += dy @ h.T
        self.gradients["by"] += dy

        dL_d_h = self.parameters["Why"].T @ dy + self.parameters["Whh"] @ self.gradients["dnext_h"]
        dL_d_h_ = dL_d_h * (1 - h**2) # _h(t)
        self.gradients["Wxh"] += dL_d_h_ @ x.T
        self.gradients["Whh"] += dL_d_h_ @ h_prev.T
        self.gradients["bh"] += dL_d_h_
        self.gradients["dnext_h"] = dL_d_h_

    def rnn_backward(self, X, Y):
        """
        Performs full BPT (Backpropagation through time)
        Args:
            X: list of indices (C, 1), Here C is the nubmer of examples
            Y: list of indices (C, 1)
        """
        for t in reversed(range(len(X))):
            dy = self.y[t] - self.y_hat[t]
            #  dy = dL/ dy_pred = (y - y_pred)

            self._one_step_backward(x=self.x[t],
                                    h=self.h[t],
                                    h_prev=self.h[t-1],
                                    dy=dy)
            
    def update_parameters(self, lr=None):
        """
        Does simple SGD optimization
        """
        if lr == None:
            lr = self.lr
        self.parameters["Wxh"] += -lr * self.gradients["Wxh"]
        self.parameters["Whh"] += -lr * self.gradients["Whh"]
        self.parameters["Why"] += -lr * self.gradients["Why"]
        self.parameters["by"] += -lr * self.gradients["by"]
        self.parameters["bh"] += -lr * self.gradients["bh"]


In [5]:
m = RNN(x_embd=6,
        h_embd=7,
        y_embd=5)

In [6]:
X = [1, 3, 3]
Y = [0, 2, 1]
m.forward(X, Y)

In [7]:
m.rnn_backward(X, Y)

In [8]:
m.gradients["Whh"]

array([[ 0.01020338, -0.01228859,  0.01589812,  0.01646996, -0.0136547 ,
         0.01322151,  0.0056276 ],
       [-0.0126856 ,  0.01523843, -0.01980315, -0.02049088,  0.01698445,
        -0.01647989, -0.00700893],
       [-0.0042678 ,  0.00512953, -0.00665962, -0.00689269,  0.00571348,
        -0.00554126, -0.00235711],
       [-0.00230374,  0.00275613, -0.00360689, -0.00372522,  0.00308666,
        -0.00300465, -0.00127631],
       [-0.00224441,  0.0027467 , -0.0034559 , -0.00360724,  0.00299489,
        -0.00286218, -0.00122439],
       [ 0.00460515, -0.00558828,  0.00713577,  0.00741846, -0.00615449,
         0.00592295,  0.00252694],
       [-0.02656744,  0.03174418, -0.04163379, -0.04297473,  0.03560432,
        -0.03469316, -0.01473129]])