# Chapter 6: Recurrent Neural Networks (RNNs)

This notebook covers **Chapter 6** of the Deep Learning in Hebrew book, focusing on Recurrent Neural Networks (RNNs). We'll learn how to build networks that can process sequential data and maintain memory of past information.

## Overview

The advantage of convolutional layers over FC layers is the utilization of the spatial relationship between different elements in the data, such as the relationship between nearby pixels in an image. However, there are other types of data where the relationship is not spatial but **temporal** - the order of elements is important, such as in text, speech, DNA sequences, stock prices, etc.

Unlike spatial data, temporal data forms a **sequence**: the order of elements matters, and the same element can appear multiple times at different positions. For this reason, standard convolutional networks are generally less suitable for this type of data, and a different architecture is needed.

Recurrent Neural Networks (RNNs) are designed to handle sequential data. In the basic setting, they receive a sequence of vectors and output a single vector that encodes the relationships within the original sequence. This vector can then be passed to an FC layer for the task at hand.

---

## Table of Contents

### 6.1 Sequence Models
- 6.1.1 [Vanilla Recurrent Neural Networks](#611-vanilla-recurrent-neural-networks)
- 6.1.2 [Learning Parameters](#612-learning-parameters)

### 6.2 RNN Architectures
- 6.2.1 [Long Short-Term Memory (LSTM)](#621-long-short-term-memory-lstm)
- 6.2.2 [Gated Recurrent Units (GRU)](#622-gated-recurrent-units-gru)
- 6.2.3 [Deep RNN](#623-deep-rnn)
- 6.2.4 [Bidirectional RNN](#624-bidirectional-rnn)

## Setup and Imports

Let's start by importing the necessary libraries for this chapter.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Set matplotlib style
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 12

# 6.1 Sequence Models

## 6.1.1 Vanilla Recurrent Neural Networks

Recurrent networks are a generalization of deep neural networks: they contain weights that give meaning to the **order of input elements**. We can view these weights as an internal memory component, where in addition to the input vector there is a state that depends on past values.

An input vector $x_t$ is multiplied by the weight $w_{xh}$ and enters the memory component $h_t$, which is a function of $x_t$ and the previous state $h_{t-1}$:

$$h_t = f(h_{t-1}, x_t)$$

In addition to the weights operating on the input vector, there are weights operating on the internal state (the memory component), $w_{hh}$, and weights operating on the output of this component, $w_{hy}$. The weights $w_{xh}, w_{hh}, w_{hy}$ are shared across all time steps and updated together.

The activation function $f_w$ is also the same at every time step, for example tanh, sigmoid, or ReLU. Formally, we can write:

$$h_t = f_w(w_{hh}h_{t-1} + w_{xh}x_t), \quad f_w = \tanh/\text{ReLU}/\text{sigmoid}$$

$$y_t = w_{hy}h_t$$

### RNN Architectures

There are several different RNN models, suitable for different problems:

- **One to Many**: There is a single input and we want to output a sequence of outputs, for example - generating a description for an image (Image captioning).
- **Many to One**: A sequence of inputs and a single output - for example, to classify a sentence and determine its sentiment - whether it is positive or negative.
- **Many to Many**: For each input sequence there is an output sequence, for example translation from one language to another - input sequence and output sequence.

In [None]:
# Vanilla RNN Implementation from Scratch
class VanillaRNN:
    """
    Simple Vanilla RNN implementation from scratch.
    """
    def __init__(self, input_size, hidden_size, output_size):
        # Initialize weights
        self.W_xh = np.random.randn(input_size, hidden_size) * 0.01
        self.W_hh = np.random.randn(hidden_size, hidden_size) * 0.01
        self.W_hy = np.random.randn(hidden_size, output_size) * 0.01
        
        # Initialize biases
        self.b_h = np.zeros((1, hidden_size))
        self.b_y = np.zeros((1, output_size))
        
        self.hidden_size = hidden_size
        
    def forward(self, x_seq):
        """
        Forward pass through the RNN.
        
        Parameters:
        -----------
        x_seq : list of T arrays, each of shape (1, input_size)
            Input sequence
        
        Returns:
        --------
        h_states : list of arrays
            Hidden states at each time step
        y_outputs : list of arrays
            Outputs at each time step
        """
        T = len(x_seq)
        h_states = []
        y_outputs = []
        
        # Initialize hidden state
        h_prev = np.zeros((1, self.hidden_size))
        
        for t in range(T):
            # Compute hidden state
            h_t = np.tanh(np.dot(x_seq[t], self.W_xh) + 
                         np.dot(h_prev, self.W_hh) + self.b_h)
            
            # Compute output
            y_t = np.dot(h_t, self.W_hy) + self.b_y
            
            h_states.append(h_t)
            y_outputs.append(y_t)
            h_prev = h_t
        
        return h_states, y_outputs

# Example: Many-to-One RNN
input_size = 3
hidden_size = 4
output_size = 2

rnn = VanillaRNN(input_size, hidden_size, output_size)

# Create a sequence of inputs
T = 5
x_sequence = [np.random.randn(1, input_size) for _ in range(T)]

# Forward pass
h_states, y_outputs = rnn.forward(x_sequence)

print("Vanilla RNN - Many-to-One Example:")
print("=" * 50)
print(f"Input sequence length: {T}")
print(f"Input size: {input_size}")
print(f"Hidden size: {hidden_size}")
print(f"Output size: {output_size}")
print(f"\nHidden states shape at each step: {h_states[0].shape}")
print(f"Outputs shape at each step: {y_outputs[0].shape}")
print(f"\nFinal output (for classification): {y_outputs[-1]}")

In [None]:
# Visualize RNN Architecture
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Many-to-One
ax1 = axes[0]
ax1.set_xlim(-0.5, 5.5)
ax1.set_ylim(-0.5, 3.5)
ax1.set_aspect('equal')
ax1.axis('off')
ax1.set_title('Many-to-One RNN', fontsize=14, weight='bold', pad=20)

# Input sequence
for i in range(4):
    circle = plt.Circle((i, 2.5), 0.3, color='lightblue', ec='black', linewidth=2)
    ax1.add_patch(circle)
    ax1.text(i, 2.5, f'$x_{i+1}$', ha='center', va='center', fontsize=10)
    if i < 3:
        ax1.arrow(i+0.3, 2.5, 0.4, 0, head_width=0.1, head_length=0.1, fc='black', ec='black')

# Hidden states
for i in range(4):
    circle = plt.Circle((i, 1.5), 0.3, color='lightgreen', ec='black', linewidth=2)
    ax1.add_patch(circle)
    ax1.text(i, 1.5, f'$h_{i+1}$', ha='center', va='center', fontsize=10)
    if i < 3:
        ax1.arrow(i+0.3, 1.5, 0.4, 0, head_width=0.1, head_length=0.1, fc='black', ec='black')
    # Recurrent connection
    if i > 0:
        ax1.plot([i-0.3, i-0.3, i-0.3], [1.8, 2.2, 1.8], 'r-', linewidth=1.5, alpha=0.5)
        ax1.arrow(i-0.3, 1.8, 0, -0.2, head_width=0.08, head_length=0.08, fc='red', ec='red', alpha=0.5)

# Output
circle = plt.Circle((3, 0.5), 0.3, color='lightyellow', ec='black', linewidth=2)
ax1.add_patch(circle)
ax1.text(3, 0.5, '$y$', ha='center', va='center', fontsize=12, weight='bold')
ax1.arrow(3, 1.2, 0, -0.4, head_width=0.1, head_length=0.1, fc='black', ec='black')

# Many-to-Many
ax2 = axes[1]
ax2.set_xlim(-0.5, 5.5)
ax2.set_ylim(-0.5, 3.5)
ax2.set_aspect('equal')
ax2.axis('off')
ax2.set_title('Many-to-Many RNN', fontsize=14, weight='bold', pad=20)

# Input sequence
for i in range(4):
    circle = plt.Circle((i, 2.5), 0.3, color='lightblue', ec='black', linewidth=2)
    ax2.add_patch(circle)
    ax2.text(i, 2.5, f'$x_{i+1}$', ha='center', va='center', fontsize=10)
    if i < 3:
        ax2.arrow(i+0.3, 2.5, 0.4, 0, head_width=0.1, head_length=0.1, fc='black', ec='black')

# Hidden states
for i in range(4):
    circle = plt.Circle((i, 1.5), 0.3, color='lightgreen', ec='black', linewidth=2)
    ax2.add_patch(circle)
    ax2.text(i, 1.5, f'$h_{i+1}$', ha='center', va='center', fontsize=10)
    if i < 3:
        ax2.arrow(i+0.3, 1.5, 0.4, 0, head_width=0.1, head_length=0.1, fc='black', ec='black')
    # Recurrent connection
    if i > 0:
        ax2.plot([i-0.3, i-0.3, i-0.3], [1.8, 2.2, 1.8], 'r-', linewidth=1.5, alpha=0.5)
        ax2.arrow(i-0.3, 1.8, 0, -0.2, head_width=0.08, head_length=0.08, fc='red', ec='red', alpha=0.5)

# Output sequence
for i in range(4):
    circle = plt.Circle((i, 0.5), 0.3, color='lightyellow', ec='black', linewidth=2)
    ax2.add_patch(circle)
    ax2.text(i, 0.5, f'$y_{i+1}$', ha='center', va='center', fontsize=10)
    ax2.arrow(i, 1.2, 0, -0.4, head_width=0.1, head_length=0.1, fc='black', ec='black')

plt.tight_layout()
plt.show()

print("RNN Architecture Types:")
print("  - Many-to-One: Sequence → Single output (e.g., sentiment classification)")
print("  - Many-to-Many: Sequence → Sequence (e.g., machine translation)")
print("  - One-to-Many: Single input → Sequence (e.g., image captioning)")

## 6.1.2 Learning Parameters

The network is trained in a similar way to the networks in previous chapters. For inputs $x = (x_1, \ldots, x_n)$ with targets $y = (y_1, \ldots, y_n)$, we define the cost function:

$$L(\theta) = \frac{1}{n}\sum_i L(\hat{y}_i, y_i, \theta)$$

where the function $L(\hat{y}_i, y_i, \theta)$ depends on the task - for classification tasks we use cross entropy, and for regression tasks we use the MSE criterion.

Training is performed using gradient descent (GD), but regular backpropagation cannot be applied directly because each weight is used at multiple time steps: $w_{xh}$ operates on all inputs and $w_{hh}$ operates on all memory components. To update the weights, we use **Backpropagation Through Time (BPTT)**: we unroll the network over time and treat it as one large network, calculate the gradient for each use of a weight, and then sum or average these gradients.

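If the input has $n$ elements, then there are $n$ memory components and $n-1$ uses of $w_{hh}$. Denoting the $t$-th use by $w_{hh}^{(t)}$, the overall gradient is the sum (or the average) of the gradients of the individual uses: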

$$\frac{\partial L}{\partial w_{hh}} = \sum_{t=1}^{n-1} \frac{\partial L}{\partial w_{hh}^{(t)}} \quad \text{or} \quad \frac{\partial L}{\partial w_{hh}} = \frac{1}{n-1}\sum_{t=1}^{n-1} \frac{\partial L}{\partial w_{hh}^{(t)}}$$

Since the weights are identical throughout the network, $w_{hh}^{(t)} = w_{hh}$ for every $t$. The weight is updated only once, after BPTT has processed the entire sequence, and the update applies to all time steps.
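
The summation over uses can be checked numerically with autograd. The sketch below assumes a scalar RNN with made-up weights and a toy loss $h_n^2$; the list `w_uses` is a hypothetical device that gives every use of $w_{hh}$ its own copy, so the per-use gradients $\partial L / \partial w_{hh}^{(t)}$ become visible.

```python
import torch

torch.manual_seed(0)

n = 6                                   # sequence length (assumed)
x = torch.randn(n)
w_xh = torch.tensor(0.7)                # fixed input weight (assumed value)
w_hh = torch.tensor(0.9, requires_grad=True)

# One differentiable copy of w_hh per use, i.e. w_hh^{(t)} in the text
w_uses = [w_hh.clone() for _ in range(n - 1)]
for w in w_uses:
    w.retain_grad()                     # keep gradients of non-leaf tensors

h = torch.tanh(w_xh * x[0])             # h_1 has no previous state
for t in range(1, n):
    h = torch.tanh(w_uses[t - 1] * h + w_xh * x[t])

loss = h ** 2                           # toy loss on the last hidden state
loss.backward()

per_use = torch.stack([w.grad for w in w_uses])
print("per-use gradients:", per_use)
print("sum of per-use gradients:", per_use.sum().item())
print("gradient w.r.t. shared w_hh:", w_hh.grad.item())
```

The last two printed values agree, which is the summed form of the BPTT gradient; dividing by $n-1$ would give the averaged form.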

### The Vanishing/Exploding Gradient Problem

The simple form of BPTT creates a problem with the gradient. Suppose we write the hidden state explicitly:

$$h_t = f(z_t) = f(w_{hh}h_{t-1} + w_{xh}x_t + b_h)$$

According to the chain rule:

$$\frac{\partial h_n}{\partial x_1} = \frac{\partial h_n}{\partial h_{n-1}} \times \frac{\partial h_{n-1}}{\partial h_{n-2}} \times \cdots \times \frac{\partial h_2}{\partial h_1} \times \frac{\partial h_1}{\partial x_1}$$

Since $w_{hh}$ does not change over time while a single input sequence is processed, we get:

$$\frac{\partial h_t}{\partial h_{t-1}} = f'(z_t) \cdot w_{hh}$$

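If we substitute this into the chain rule, we get that calculating the derivative $\frac{\partial h_n}{\partial x_1}$ involves multiplying $n-1$ times by $f'(z_t) \cdot w_{hh}$. Therefore, roughly speaking, if $||w_{hh}|| > 1$ the gradient tends to explode, and if $||w_{hh}|| < 1$ the gradient tends to vanish. Since $|f'(z_t)| \le 1$ for tanh and sigmoid, the activation can only shrink this product further.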
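
As a sanity check of the chain rule, the following sketch compares the autograd value of $\frac{\partial h_n}{\partial x_1}$ with the explicit product of the factors $f'(z_t) \cdot w_{hh}$. It assumes a scalar tanh RNN with $h_0 = 0$, no bias, and $w_{xh} = 1$; the function name `dh_n_dx1` is hypothetical.

```python
import torch

def dh_n_dx1(w_hh, n, w_xh=1.0, seed=0):
    """Return (autograd value, manual chain-rule product) of d h_n / d x_1."""
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(n, generator=g, dtype=torch.float64).requires_grad_(True)

    h = torch.zeros((), dtype=torch.float64)   # h_0 = 0 (assumption)
    zs = []
    for t in range(n):
        z = w_hh * h + w_xh * x[t]
        zs.append(z)
        h = torch.tanh(z)

    # Gradient of the final state with respect to the first input
    autograd_value = torch.autograd.grad(h, x)[0][0]

    # Manual product: f'(z_1) * w_xh, then one factor f'(z_t) * w_hh per step
    fprime = [1 - torch.tanh(z.detach()) ** 2 for z in zs]
    manual = fprime[0] * w_xh
    for t in range(1, n):
        manual = manual * fprime[t] * w_hh
    return autograd_value.item(), manual.item()

for w in (0.5, 0.9, 1.5):
    a, m = dh_n_dx1(w, n=20)
    print(f"w_hh={w}: autograd={a:.3e}, manual product={m:.3e}")
```

The two numbers match for every weight. Note that with a saturating activation the product may stay small even for $w_{hh} > 1$, because the factors $f'(z_t)$ can be much smaller than 1.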

For the problem of exploding gradient, we can perform **gradient clipping** - if the gradient is too large, we normalize it:

$$\text{if } ||g|| > c, \text{ then } g = c\frac{g}{||g||}$$
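
The clipping rule can be written in a few lines of NumPy. This is a minimal sketch with a hypothetical helper `clip_by_norm`; in PyTorch, the utility `torch.nn.utils.clip_grad_norm_` performs norm-based clipping on the gradients of a model's parameters.

```python
import numpy as np

def clip_by_norm(g, c):
    """Rescale g to have norm c if its norm exceeds c; otherwise return it as is."""
    norm = np.linalg.norm(g)
    if norm > c:
        return c * g / norm
    return g

g = np.array([3.0, 4.0])          # ||g|| = 5

print(clip_by_norm(g, 1.0))       # rescaled to norm 1, direction preserved
print(clip_by_norm(g, 10.0))      # norm below threshold, left unchanged
```

Clipping changes only the length of the gradient, not its direction, so the update still points the same way but with a bounded step size.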

The problem of vanishing gradient does not cause calculations of huge numbers, but it actually prevents learning long-term dependencies. If the sequence is long, then the gradient through BPTT will be very small, and there is almost no influence of the first word on the last word. In other words - the vanishing gradient causes a problem of **Long-term dependency**, meaning it is difficult to learn data that changes slowly or sequences that are long.

Because of this problem, we do not use the classic RNN (also called **Vanilla RNN**), but rather more advanced architectures that will be detailed in the following sections.

In [None]:
# Demonstrate Vanishing/Exploding Gradient Problem
def compute_gradient_through_time(w_hh, n_steps, activation='tanh'):
    """
    Compute gradient through time for RNN.
    
    Shows how gradient changes as we backpropagate through time.
    """
    gradients = []
    
    if activation == 'tanh':
        # For tanh, derivative is bounded by 1
        # But we multiply by w_hh each step
        grad = 1.0
        for t in range(n_steps):
            grad = grad * w_hh  # Simplified: assuming f'(z) ≈ 1
            gradients.append(grad)
    elif activation == 'sigmoid':
        # For sigmoid, derivative is bounded by 0.25
        grad = 1.0
        for t in range(n_steps):
            grad = grad * 0.25 * w_hh
            gradients.append(grad)
    
    return gradients

# Test different weight values
n_steps = 20
w_values = [0.5, 0.9, 1.0, 1.1, 1.5]

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Vanishing gradient (w < 1)
ax1 = axes[0]
for w in [0.5, 0.9]:
    grads = compute_gradient_through_time(w, n_steps, 'tanh')
    ax1.plot(range(1, n_steps+1), grads, label=f'$w_{{hh}} = {w}$', linewidth=2)

ax1.set_xlabel('Time Steps', fontsize=12)
ax1.set_ylabel('Gradient Magnitude', fontsize=12)
ax1.set_title('Vanishing Gradient Problem ($w_{hh} < 1$)', fontsize=14, weight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)
ax1.set_yscale('log')

# Exploding gradient (w > 1)
ax2 = axes[1]
for w in [1.1, 1.5]:
    grads = compute_gradient_through_time(w, n_steps, 'tanh')
    ax2.plot(range(1, n_steps+1), grads, label=f'$w_{{hh}} = {w}$', linewidth=2)

ax2.set_xlabel('Time Steps', fontsize=12)
ax2.set_ylabel('Gradient Magnitude', fontsize=12)
ax2.set_title('Exploding Gradient Problem ($w_{hh} > 1$)', fontsize=14, weight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
ax2.set_yscale('log')

plt.tight_layout()
plt.show()

print("Gradient Problems in RNNs:")
print("=" * 50)
print("Vanishing Gradient (w_hh < 1):")
print("  - Gradient becomes exponentially small")
print("  - Cannot learn long-term dependencies")
print("  - Early time steps have negligible influence")
print("\nExploding Gradient (w_hh > 1):")
print("  - Gradient becomes exponentially large")
print("  - Causes numerical instability")
print("  - Can be fixed with gradient clipping")

# 6.2 RNN Architectures

## 6.2.1 Long Short-Term Memory (LSTM)

To overcome the problem of vanishing gradient that prevents the network from using long-term memory, we need to modify the memory component so that it does not only contain information about the past, but also has control over how and how much to use the information.

In the simple RNN, the memory component has two inputs - $h_{t-1}, x_t$, and with their help we compute the hidden state using $f_w(h_{t-1}, x_t)$. In fact, the memory component is fixed and learning is performed only in the weights.

In **LSTM**, there are several important components - in addition to the regular inputs, there is another input called **cell state memory** and denoted by $c_{t-1}$, and in addition $h_t$ is calculated in a more complex way. In this way, the element $c_t$ takes care of long-term memory of things, and $h_t$ is responsible for short-term memory.

### LSTM Architecture

The pair $[x_t, h_{t-1}]$ enters the cell, is multiplied by a weight matrix $w$, and then passes separately through four gates (note that $x_t$ and $h_{t-1}$ are concatenated rather than merged beforehand, and each gate applies its own weights and bias to the same pair).

The first gate, $f_t = \sigma(w_f \cdot [x_t, h_{t-1}] + b_f)$, is the **forget gate** and is responsible for erasing part of the memory. If, for example, there is a new topic, then this gate should erase the topic that was stored in memory.

The second gate $i_t$ is the **input gate** (memory gate) and is responsible for how much to remember the new information for the long term. For example, if there is indeed a new topic in a certain sentence, then the gate will decide to remember this information.

The third gate $o_t$ is the **output gate** and is responsible for how much of the information is relevant to the current data $x_t$. The output of the cell is calculated as a function of the cell state given past information.

These three gates act as **masks**: they take values between 0 and 1, since they pass through a sigmoid.

There is another gate $\tilde{c}_t$ (also called $g$ - the candidate for memory) that prepares the information to enter the next cell. This gate weighs the information together with $i_t$. The new information is written to memory only if it is relevant.

In this way, we get both $h_t$ which is responsible for short-term memory like in Vanilla RNN, and $c_t$ which is responsible for memory of all the past.

### LSTM Equations

Formally, we can formulate the operations as follows:

$$\begin{pmatrix} i_t \\ f_t \\ o_t \\ \tilde{c}_t \end{pmatrix} = \begin{pmatrix} \sigma \\ \sigma \\ \sigma \\ \tanh \end{pmatrix} \left( W \begin{pmatrix} x_t \\ h_{t-1} \end{pmatrix} + b \right) = \begin{pmatrix} \sigma(w_i \cdot [x_t, h_{t-1}] + b_i) \\ \sigma(w_f \cdot [x_t, h_{t-1}] + b_f) \\ \sigma(w_o \cdot [x_t, h_{t-1}] + b_o) \\ \tanh(w_{\tilde{c}} \cdot [x_t, h_{t-1}] + b_{\tilde{c}}) \end{pmatrix}$$

$$c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$$

$$h_t = o_t \odot \tanh(c_t)$$

where the operator $\odot$ symbolizes element-wise multiplication.

In [None]:
# LSTM Implementation using PyTorch
class LSTM(nn.Module):
    """
    LSTM cell implementation.
    """
    def __init__(self, input_size, hidden_size):
        super(LSTM, self).__init__()
        self.hidden_size = hidden_size
        
        # Input gate
        self.W_ii = nn.Linear(input_size, hidden_size, bias=False)
        self.W_hi = nn.Linear(hidden_size, hidden_size, bias=False)
        self.b_i = nn.Parameter(torch.zeros(hidden_size))
        
        # Forget gate
        self.W_if = nn.Linear(input_size, hidden_size, bias=False)
        self.W_hf = nn.Linear(hidden_size, hidden_size, bias=False)
        self.b_f = nn.Parameter(torch.ones(hidden_size))  # Initialize to 1 to remember
        
        # Output gate
        self.W_io = nn.Linear(input_size, hidden_size, bias=False)
        self.W_ho = nn.Linear(hidden_size, hidden_size, bias=False)
        self.b_o = nn.Parameter(torch.zeros(hidden_size))
        
        # Cell gate
        self.W_ig = nn.Linear(input_size, hidden_size, bias=False)
        self.W_hg = nn.Linear(hidden_size, hidden_size, bias=False)
        self.b_g = nn.Parameter(torch.zeros(hidden_size))
    
    def forward(self, x, hidden):
        """
        Forward pass through LSTM cell.
        
        Parameters:
        -----------
        x : tensor, shape (batch, input_size)
            Input at time t
        hidden : tuple (h, c)
            h: hidden state, shape (batch, hidden_size)
            c: cell state, shape (batch, hidden_size)
        
        Returns:
        --------
        h_new : tensor
            New hidden state
        c_new : tensor
            New cell state
        """
        h_prev, c_prev = hidden
        
        # Input gate
        i_t = torch.sigmoid(self.W_ii(x) + self.W_hi(h_prev) + self.b_i)
        
        # Forget gate
        f_t = torch.sigmoid(self.W_if(x) + self.W_hf(h_prev) + self.b_f)
        
        # Output gate
        o_t = torch.sigmoid(self.W_io(x) + self.W_ho(h_prev) + self.b_o)
        
        # Cell gate (candidate)
        g_t = torch.tanh(self.W_ig(x) + self.W_hg(h_prev) + self.b_g)
        
        # Update cell state
        c_new = f_t * c_prev + i_t * g_t
        
        # Update hidden state
        h_new = o_t * torch.tanh(c_new)
        
        return h_new, c_new

# Test LSTM
input_size = 5
hidden_size = 10
lstm = LSTM(input_size, hidden_size)

# Create a sequence
seq_length = 8
batch_size = 2
x_seq = torch.randn(seq_length, batch_size, input_size)

# Initialize hidden and cell states
h = torch.zeros(batch_size, hidden_size)
c = torch.zeros(batch_size, hidden_size)

print("LSTM Forward Pass:")
print("=" * 50)
print(f"Input sequence length: {seq_length}")
print(f"Batch size: {batch_size}")
print(f"Input size: {input_size}")
print(f"Hidden size: {hidden_size}")

# Process sequence
for t in range(seq_length):
    h, c = lstm(x_seq[t], (h, c))
    print(f"Step {t+1}: h.shape={h.shape}, c.shape={c.shape}")

print(f"\nFinal hidden state shape: {h.shape}")
print(f"Final cell state shape: {c.shape}")

In [None]:
# Visualize LSTM Cell Architecture
fig, ax = plt.subplots(figsize=(14, 10))
ax.set_xlim(-1, 10)
ax.set_ylim(-1, 8)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('LSTM Cell Architecture', fontsize=16, weight='bold', pad=20)

# Input
ax.text(0, 6, '$x_t$', fontsize=14, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='black', linewidth=2))
ax.text(0, 4, '$h_{t-1}$', fontsize=14, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))

# Concatenate
ax.text(2, 5, 'Concat', fontsize=10, ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black'))
ax.arrow(0.5, 6, 0.8, -0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(0.5, 4, 0.8, 0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Gates
gate_y = [6.5, 5, 3.5, 2]
gate_names = ['$i_t$ (Input)', '$f_t$ (Forget)', '$o_t$ (Output)', '$\\tilde{c}_t$ (Candidate)']
gate_colors = ['lightcoral', 'lightpink', 'lightcyan', 'wheat']

for i, (y, name, color) in enumerate(zip(gate_y, gate_names, gate_colors)):
    ax.text(4.5, y, name, fontsize=10, ha='center', va='center',
            bbox=dict(boxstyle='round', facecolor=color, edgecolor='black'))
    ax.arrow(2.5, 5, 1.5, y-5, head_width=0.15, head_length=0.1, fc='black', ec='black', alpha=0.6)

# Cell state
ax.text(7, 6, '$c_{t-1}$', fontsize=12, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='orange', edgecolor='black', linewidth=2))

# Operations
ax.text(7, 4, '$\\times$', fontsize=16, ha='center', va='center')
ax.text(7, 2.5, '$+$', fontsize=16, ha='center', va='center')
ax.text(7, 1, '$\\times$', fontsize=16, ha='center', va='center')

# New cell state
ax.text(9, 3.5, '$c_t$', fontsize=14, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='orange', edgecolor='black', linewidth=2))

# New hidden state
ax.text(9, 1, '$h_t$', fontsize=14, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))

# Arrows
ax.arrow(5.2, 6.5, 1.3, -0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(5.2, 5, 1.3, -1, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(5.2, 3.5, 1.3, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(5.2, 2, 1.3, 1.5, head_width=0.15, head_length=0.1, fc='black', ec='black')

ax.arrow(7.5, 3.5, 1, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(7.5, 3.5, 0, -1, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(7.5, 1, 1, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')

# tanh
ax.text(8.5, 2.5, 'tanh', fontsize=10, ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black'))

plt.tight_layout()
plt.show()

print("LSTM Key Components:")
print("=" * 50)
print("1. Forget Gate ($f_t$): Decides what to forget from $c_{t-1}$")
print("2. Input Gate ($i_t$): Decides what new information to store")
print("3. Output Gate ($o_t$): Decides what parts of $c_t$ to output")
print("4. Cell State ($c_t$): Long-term memory")
print("5. Hidden State ($h_t$): Short-term memory")

### LSTM Variants

There are different variations of the LSTM components. For example, we can feed $c_{t-1}$ not only into the cell state update but also into the gates. Such connections are called **peepholes**, because they allow the gates to observe $c_{t-1}$ directly.

There are architectures that connect $c_{t-1}$ to all gates, and there are architectures that connect it only to some gates. Connecting all gates to $c_{t-1}$ changes the gate equations from $\sigma(w \cdot [x_t, h_{t-1}] + b)$ to $\sigma(w \cdot [c_{t-1}, x_t, h_{t-1}] + b)$.
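
A gate with a peephole connection can be sketched as follows. The class name `PeepholeGate` is hypothetical, and the sketch follows the concatenated form $\sigma(w \cdot [c_{t-1}, x_t, h_{t-1}] + b)$ written above rather than any particular library implementation.

```python
import torch
import torch.nn as nn

class PeepholeGate(nn.Module):
    """A single sigmoid gate that also observes the previous cell state."""
    def __init__(self, input_size, hidden_size):
        super().__init__()
        # One linear map over the concatenation [c_{t-1}, x_t, h_{t-1}]
        self.linear = nn.Linear(hidden_size + input_size + hidden_size, hidden_size)

    def forward(self, c_prev, x, h_prev):
        combined = torch.cat([c_prev, x, h_prev], dim=-1)
        return torch.sigmoid(self.linear(combined))

torch.manual_seed(0)
gate = PeepholeGate(input_size=5, hidden_size=10)
g_t = gate(torch.zeros(2, 10), torch.randn(2, 5), torch.zeros(2, 10))
print(g_t.shape)  # one mask value in (0, 1) per hidden unit
```

An LSTM with peepholes on all gates would use one such module for each of $i_t$, $f_t$ and $o_t$, while a partial variant would use it only for some of them.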

Another variation of LSTM unifies the forget gate $f_t$ with the input gate $i_t$, so that the decision of how much to forget from memory is obtained together with the decision of how much new information to write. This change affects the memory cell, where instead of $c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t$, the update equation becomes $c_t = f_t \odot c_{t-1} + (1 - f_t) \odot \tilde{c}_t$.
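
The coupled update is easy to inspect at its two extremes. The helper `coupled_cell_update` and the example values below are assumptions made only for illustration.

```python
import numpy as np

def coupled_cell_update(f_t, c_prev, c_tilde):
    """Cell update with a coupled forget/input gate: i_t is replaced by (1 - f_t)."""
    return f_t * c_prev + (1.0 - f_t) * c_tilde

c_prev = np.array([1.0, -2.0, 0.5])    # previous cell state (made-up values)
c_tilde = np.array([0.0, 1.0, -1.0])   # candidate memory (made-up values)

keep_all = coupled_cell_update(np.ones(3), c_prev, c_tilde)     # f_t = 1
replace_all = coupled_cell_update(np.zeros(3), c_prev, c_tilde) # f_t = 0
mixed = coupled_cell_update(np.full(3, 0.5), c_prev, c_tilde)   # f_t = 0.5

print(keep_all)
print(replace_all)
print(mixed)
```

With $f_t = 1$ the old memory is kept and nothing new is written, with $f_t = 0$ the memory is fully overwritten by the candidate, and intermediate values blend the two. The cell can no longer forget and ignore new input at the same time, which is the price of removing one gate.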

## 6.2.2 Gated Recurrent Units (GRU)

There is another architecture of memory cell called **Gated Recurrent Units (GRU)**, which has fewer components than LSTM.

### GRU Architecture

The significant change from LSTM is that there is no **cell state memory**, and all gates are based only on the current input and the output of the previous cell, $h_{t-1}$. To enable memory both for the long term and the short term, there are two gates, the **Reset gate** and the **Update gate**, computed as follows:

**Update gate**: $z_t = \sigma(w_z \cdot [x_t, h_{t-1}])$

**Reset gate**: $r_t = \sigma(w_r \cdot [x_t, h_{t-1}])$

Using the reset gate, we compute the **Candidate hidden state**:

$$\tilde{h}_t = \tanh(w \cdot [x_t, r_t \odot h_{t-1}])$$

First, note that $r_t \in [0, 1]$ since it is the result of a sigmoid. Now compare $\tilde{h}_t$ to the simple hidden state of Vanilla RNN, $h_t = f_w(w_{hh}h_{t-1} + w_{xh}x_t)$. If $r_t$ is close to 1:

$$\tilde{h}_t = \tanh(w \cdot [x_t, r_t \odot h_{t-1}]) \approx \tanh(w \cdot [x_t, h_{t-1}]) = \tanh(w_{xh}x_t + w_{hh}h_{t-1})$$

This means that if $r_t \to 1$, we get the classic memory component. If $r_t \to 0$, then $\tilde{h}_t \approx \tanh(w \cdot [x_t, 0 \odot h_{t-1}]) = \tanh(w_{xh}x_t)$, and in effect the short-term memory is reset.
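
Both limits of the reset gate can be verified numerically. The sketch below assumes random weights and splits $w$ into the blocks $w_{xh}$ and $w_{hh}$; the function name `candidate` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
d, h = 3, 4                              # input and hidden sizes (assumed)
W_xh = rng.normal(size=(d, h))           # block of w acting on x_t
W_hh = rng.normal(size=(h, h))           # block of w acting on h_{t-1}

def candidate(x, h_prev, r):
    """Candidate hidden state tanh(w . [x_t, r_t * h_{t-1}]), without bias."""
    return np.tanh(x @ W_xh + (r * h_prev) @ W_hh)

x = rng.normal(size=(1, d))
h_prev = rng.normal(size=(1, h))

vanilla = np.tanh(x @ W_xh + h_prev @ W_hh)   # Vanilla RNN step
input_only = np.tanh(x @ W_xh)                # no memory at all

r_one = candidate(x, h_prev, np.ones((1, h)))
r_zero = candidate(x, h_prev, np.zeros((1, h)))

print(np.allclose(r_one, vanilla))      # r_t = 1 recovers the Vanilla RNN step
print(np.allclose(r_zero, input_only))  # r_t = 0 discards the previous state
```

Since $r_t$ is a vector, the gate can also reset only some coordinates of the hidden state while keeping others.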

To combine $\tilde{h}_t$ with the previous hidden state, we use $z_t$, which also receives values between 0 and 1:

$$h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$$

If $z_t \to 0$, then $h_t \approx h_{t-1}$, meaning we do not consider $\tilde{h}_t$, and in fact we pass the previous state as is. If, on the other hand, $z_t \to 1$, then $h_t \approx \tilde{h}_t$, meaning we ignore the previous state as is and take the Candidate hidden state, which is the weighting of the previous hidden state with the current input element. For any other value of $z_t$, we get a combination of the previous hidden and the Candidate hidden state.

This architecture mitigates the gradient problem: when $z_t$ is close to 0, the state is passed through time almost unchanged, so $\frac{\partial h_t}{\partial h_{t-1}}$ stays close to 1 and the gradient neither explodes nor vanishes along that path.
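
The effect of the update gate on the gradient can be illustrated with a deliberately simplified scalar model. The sketch assumes that $z_t$ is held at a constant value instead of being computed from the data, and that there is no reset gate; the function name `grad_through_time` is hypothetical.

```python
import torch

def grad_through_time(z, w_hh=0.5, n_steps=50, seed=0):
    """d h_T / d h_0 for h_t = (1 - z) * h_{t-1} + z * tanh(w_hh * h_{t-1} + x_t)."""
    g = torch.Generator().manual_seed(seed)
    x = torch.randn(n_steps, generator=g, dtype=torch.float64)

    h0 = torch.zeros((), dtype=torch.float64, requires_grad=True)
    h = h0
    for t in range(n_steps):
        h_tilde = torch.tanh(w_hh * h + x[t])   # candidate state
        h = (1 - z) * h + z * h_tilde           # gated combination
    return torch.autograd.grad(h, h0)[0].item()

vanilla_grad = grad_through_time(z=1.0)    # z = 1: plain tanh recurrence
gated_grad = grad_through_time(z=0.01)     # small z: state mostly copied forward

print(f"z = 1.00 (vanilla-like): {vanilla_grad:.3e}")
print(f"z = 0.01 (gated):        {gated_grad:.3e}")
```

With $z = 1$ every step multiplies the gradient by at most $w_{hh} = 0.5$, so after 50 steps it is negligible. With a small $z$ every step multiplies it by a factor close to 1, so the gradient survives across the whole sequence. In a trained GRU the gate values are learned, so this behavior is available to the network rather than guaranteed.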

In [None]:
# GRU Implementation using PyTorch
class GRU(nn.Module):
    """
    GRU cell implementation.
    """
    def __init__(self, input_size, hidden_size):
        super(GRU, self).__init__()
        self.hidden_size = hidden_size
        
        # Update gate
        self.W_iz = nn.Linear(input_size, hidden_size, bias=False)
        self.W_hz = nn.Linear(hidden_size, hidden_size, bias=False)
        self.b_z = nn.Parameter(torch.zeros(hidden_size))
        
        # Reset gate
        self.W_ir = nn.Linear(input_size, hidden_size, bias=False)
        self.W_hr = nn.Linear(hidden_size, hidden_size, bias=False)
        self.b_r = nn.Parameter(torch.zeros(hidden_size))
        
        # Candidate hidden state
        self.W_in = nn.Linear(input_size, hidden_size, bias=False)
        self.W_hn = nn.Linear(hidden_size, hidden_size, bias=False)
        self.b_n = nn.Parameter(torch.zeros(hidden_size))
    
    def forward(self, x, h_prev):
        """
        Forward pass through GRU cell.
        
        Parameters:
        -----------
        x : tensor, shape (batch, input_size)
            Input at time t
        h_prev : tensor, shape (batch, hidden_size)
            Previous hidden state
        
        Returns:
        --------
        h_new : tensor
            New hidden state
        """
        # Update gate
        z_t = torch.sigmoid(self.W_iz(x) + self.W_hz(h_prev) + self.b_z)
        
        # Reset gate
        r_t = torch.sigmoid(self.W_ir(x) + self.W_hr(h_prev) + self.b_r)
        
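        # Candidate hidden state: the reset gate scales h_prev before the
        # linear map, matching tanh(w . [x_t, r_t * h_{t-1}]) in the text
        n_t = torch.tanh(self.W_in(x) + self.W_hn(r_t * h_prev) + self.b_n)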
        
        # Update hidden state
        h_new = (1 - z_t) * h_prev + z_t * n_t
        
        return h_new

# Test GRU
input_size = 5
hidden_size = 10
gru = GRU(input_size, hidden_size)

# Create a sequence
seq_length = 8
batch_size = 2
x_seq = torch.randn(seq_length, batch_size, input_size)

# Initialize hidden state
h = torch.zeros(batch_size, hidden_size)

print("GRU Forward Pass:")
print("=" * 50)
print(f"Input sequence length: {seq_length}")
print(f"Batch size: {batch_size}")
print(f"Input size: {input_size}")
print(f"Hidden size: {hidden_size}")

# Process sequence
for t in range(seq_length):
    h = gru(x_seq[t], h)
    print(f"Step {t+1}: h.shape={h.shape}")

print(f"\nFinal hidden state shape: {h.shape}")

# Compare parameters
lstm_params = sum(p.numel() for p in LSTM(input_size, hidden_size).parameters())
gru_params = sum(p.numel() for p in GRU(input_size, hidden_size).parameters())

print(f"\nParameter Comparison:")
print(f"  LSTM parameters: {lstm_params}")
print(f"  GRU parameters: {gru_params}")
print(f"  GRU has {lstm_params - gru_params} fewer parameters")

In [None]:
# Visualize GRU Cell Architecture
fig, ax = plt.subplots(figsize=(14, 10))
ax.set_xlim(-1, 10)
ax.set_ylim(-1, 8)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('GRU Cell Architecture', fontsize=16, weight='bold', pad=20)

# Input
ax.text(0, 6, '$x_t$', fontsize=14, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='black', linewidth=2))
ax.text(0, 4, '$h_{t-1}$', fontsize=14, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))

# Concatenate
ax.text(2, 5, 'Concat', fontsize=10, ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black'))
ax.arrow(0.5, 6, 0.8, -0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(0.5, 4, 0.8, 0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Gates
ax.text(4.5, 6.5, '$z_t$ (Update)', fontsize=10, ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='lightcoral', edgecolor='black'))
ax.text(4.5, 5, '$r_t$ (Reset)', fontsize=10, ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='lightpink', edgecolor='black'))

ax.arrow(2.5, 5, 1.5, 1.5, head_width=0.15, head_length=0.1, fc='black', ec='black', alpha=0.6)
ax.arrow(2.5, 5, 1.5, 0, head_width=0.15, head_length=0.1, fc='black', ec='black', alpha=0.6)

# Candidate
ax.text(4.5, 3, '$\\tilde{h}_t$', fontsize=12, weight='bold', ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='wheat', edgecolor='black'))
ax.arrow(2.5, 5, 1.5, -2, head_width=0.15, head_length=0.1, fc='black', ec='black', alpha=0.6)

# Operations
ax.text(7, 6.5, '$\\times$', fontsize=16, ha='center', va='center')
ax.text(7, 4.5, '$1-z_t$', fontsize=10, ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black'))
ax.text(7, 3, '$z_t$', fontsize=10, ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black'))
ax.text(7, 1.5, '$+$', fontsize=16, ha='center', va='center')

# New hidden state
ax.text(9, 3.5, '$h_t$', fontsize=14, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))

# Arrows
ax.arrow(5.2, 6.5, 1.3, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(5.2, 5, 1.3, -0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(5.2, 3, 1.3, 0.5, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(7.5, 3.5, 1, 0, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Reset gate connection
ax.plot([4.5, 4.5, 6.5], [5, 2.5, 2.5], 'r--', linewidth=1.5, alpha=0.5)
ax.text(5.5, 3.5, '$r_t$', fontsize=9, color='red', alpha=0.7)

plt.tight_layout()
plt.show()

print("GRU Key Components:")
print("=" * 50)
print("1. Reset Gate (r_t): Controls how much of the previous hidden state feeds the candidate")
print("2. Update Gate (z_t): Controls the mix between the previous hidden state and the candidate")
print("3. Candidate Hidden State (h~_t): Proposed new content for the hidden state")
print("4. Hidden State (h_t): Gated combination of old and new information")
print("\nAdvantages over LSTM:")
print("  - Simpler architecture (no cell state)")
print("  - Fewer parameters")
print("  - Often performs similarly to LSTM")
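
The four components listed above can be traced in a single GRU step. The following is a minimal sketch, not the chapter's own implementation: it reuses the weights of a PyTorch `nn.GRUCell` and follows the gate ordering and update rule documented for that class, $h_t = (1 - z_t) \odot \tilde{h}_t + z_t \odot h_{t-1}$. Some texts swap the roles of $z_t$ and $1 - z_t$, so treat the convention as an assumption. The function name `manual_gru_step` and the sizes are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
cell = nn.GRUCell(input_size=4, hidden_size=3)
x_t = torch.randn(2, 4)      # (batch, input_size)
h_prev = torch.randn(2, 3)   # (batch, hidden_size)

def manual_gru_step(cell, x_t, h_prev):
    """One GRU step written out gate by gate (PyTorch convention)."""
    # Input and hidden projections; rows are stacked in the order r, z, n
    gi = x_t @ cell.weight_ih.T + cell.bias_ih
    gh = h_prev @ cell.weight_hh.T + cell.bias_hh
    i_r, i_z, i_n = gi.chunk(3, dim=1)
    h_r, h_z, h_n = gh.chunk(3, dim=1)

    r_t = torch.sigmoid(i_r + h_r)         # reset gate
    z_t = torch.sigmoid(i_z + h_z)         # update gate
    h_tilde = torch.tanh(i_n + r_t * h_n)  # candidate hidden state
    return (1 - z_t) * h_tilde + z_t * h_prev

with torch.no_grad():
    h_manual = manual_gru_step(cell, x_t, h_prev)
    h_ref = cell(x_t, h_prev)

print("Max abs difference:", (h_manual - h_ref).abs().max().item())
print("Matches nn.GRUCell:", torch.allclose(h_manual, h_ref, atol=1e-5))
```

Comparing against `nn.GRUCell` keeps the sketch honest: if the gate order or the update rule were wrong, the two results would diverge.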

## 6.2.3 Deep RNN

So far we have discussed single memory components, which are chained over time to form one layer that can analyze temporal data. We can also stack several such layers, so that the hidden states of one layer become the input sequence of the next, to get a deep network.

### Deep RNN Architecture

Let us describe the network formally. At each time step $t$ there is an input $x_t \in \mathbb{R}^{n \times d}$ (a batch of $n$ elements, where each element is of dimension $d$). The input sequence has $T$ time steps and enters a network with $L$ layers, so at each time step $t$ there are $L$ hidden states. For layer $\ell$ at time step $t$, we denote the hidden state by $H_t^{(\ell)} \in \mathbb{R}^{n \times h}$.

At each time step there is also an output $o_t \in \mathbb{R}^{n \times o}$, with one row per element of the batch. We denote $H_t^{(0)} = x_t$, and we assume that there are no direct connections between one layer and another, only connections through the activation function $\phi_\ell$. Using this notation, we can write:

$$H_t^{(\ell)} = \phi_\ell(H_t^{(\ell-1)}w_{xh}^{(\ell)} + H_{t-1}^{(\ell)}w_{hh}^{(\ell)} + b_h^{(\ell)})$$

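where $w_{xh}^{(\ell)} \in \mathbb{R}^{h \times h}$ (for the first layer, $\ell = 1$, whose input is $x_t$, the shape is $\mathbb{R}^{d \times h}$), $w_{hh}^{(\ell)} \in \mathbb{R}^{h \times h}$, and $b_h^{(\ell)} \in \mathbb{R}^{1 \times h}$ are the parameters of the hidden layer $\ell$.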

The output $o_t$ depends directly only on layer $L$, and is calculated by:

$$o_t = H_t^{(L)}w_{hy}^{(L)} + b_y^{(L)}$$

where $w_{hy}^{(L)} \in \mathbb{R}^{h \times o}, b_y^{(L)} \in \mathbb{R}^{1 \times o}$ are the parameters of the output layer.

Of course, the plain recurrent update in each layer can be replaced by a GRU or LSTM memory component, which gives a Deep Gated RNN.
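
The recurrence above can be checked numerically. The sketch below assumes $\phi_\ell = \tanh$ for every layer and zero initial hidden states. It unrolls the equations by hand using the weights of a PyTorch `nn.RNN` and compares the result with the module's own output. PyTorch stores the weights transposed and keeps two bias vectors per layer, so $w_{xh}^{(\ell)}$ corresponds to `weight_ih_l{k}.T` and $b_h^{(\ell)}$ to the sum of both biases. The helper name `manual_deep_rnn` and the sizes are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n, d, h, L, T = 4, 10, 20, 3, 15   # batch, input dim, hidden dim, layers, time steps
rnn = nn.RNN(d, h, num_layers=L)   # tanh activation by default
xs = torch.randn(T, n, d)

def manual_deep_rnn(rnn, xs):
    """Unroll H_t^(l) = tanh(H_t^(l-1) w_xh + H_{t-1}^(l) w_hh + b_h) by hand."""
    num_steps, batch, _ = xs.shape
    # One hidden state per layer, initialised to zero
    H = [torch.zeros(batch, rnn.hidden_size) for _ in range(rnn.num_layers)]
    top_states = []
    for t in range(num_steps):
        layer_input = xs[t]                                # H_t^(0) = x_t
        for layer in range(rnn.num_layers):
            w_xh = getattr(rnn, f"weight_ih_l{layer}").T   # (d, h) for layer 0, else (h, h)
            w_hh = getattr(rnn, f"weight_hh_l{layer}").T   # (h, h)
            b_h = getattr(rnn, f"bias_ih_l{layer}") + getattr(rnn, f"bias_hh_l{layer}")
            # H[layer] on the right-hand side still holds the value from step t-1
            H[layer] = torch.tanh(layer_input @ w_xh + H[layer] @ w_hh + b_h)
            layer_input = H[layer]                         # feeds the next layer
        top_states.append(H[-1])                           # H_t^(L)
    return torch.stack(top_states)

with torch.no_grad():
    manual_top = manual_deep_rnn(rnn, xs)
    ref_top, _ = rnn(xs)

print(tuple(rnn.weight_ih_l0.T.shape))   # (10, 20): first layer maps d -> h
print(tuple(rnn.weight_ih_l1.T.shape))   # (20, 20): deeper layers map h -> h
print("Matches nn.RNN:", torch.allclose(manual_top, ref_top, atol=1e-5))
```

The output $o_t$ would then be a linear map applied to each row of `manual_top`, exactly as in the output equation above.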

In [None]:
# Deep RNN Implementation using PyTorch
class DeepRNN(nn.Module):
    """
    Deep RNN with multiple stacked layers.
    """
    def __init__(self, input_size, hidden_size, num_layers, output_size, cell_type='LSTM'):
        super(DeepRNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        
        # Create stacked RNN layers
        if cell_type == 'LSTM':
            self.rnn = nn.LSTM(input_size, hidden_size, num_layers, batch_first=False)
        elif cell_type == 'GRU':
            self.rnn = nn.GRU(input_size, hidden_size, num_layers, batch_first=False)
        else:
            self.rnn = nn.RNN(input_size, hidden_size, num_layers, batch_first=False)
        
        # Output layer
        self.fc = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        """
        Forward pass through deep RNN.
        
        Parameters:
        -----------
        x : tensor, shape (seq_len, batch, input_size)
            Input sequence
        
        Returns:
        --------
        output : tensor, shape (seq_len, batch, output_size)
            Output sequence
        """
        # RNN forward pass
        rnn_out, _ = self.rnn(x)
        
        # Apply output layer to each time step
        output = self.fc(rnn_out)
        
        return output

# Test Deep RNN
input_size = 10
hidden_size = 20
num_layers = 3
output_size = 5
seq_length = 15
batch_size = 4

# Create model
deep_lstm = DeepRNN(input_size, hidden_size, num_layers, output_size, cell_type='LSTM')
deep_gru = DeepRNN(input_size, hidden_size, num_layers, output_size, cell_type='GRU')

# Create input sequence
x = torch.randn(seq_length, batch_size, input_size)

# Forward pass
with torch.no_grad():
    output_lstm = deep_lstm(x)
    output_gru = deep_gru(x)

print("Deep RNN Architecture:")
print("=" * 50)
print(f"Input size: {input_size}")
print(f"Hidden size: {hidden_size}")
print(f"Number of layers: {num_layers}")
print(f"Output size: {output_size}")
print(f"\nInput shape: {x.shape}")
print(f"LSTM output shape: {output_lstm.shape}")
print(f"GRU output shape: {output_gru.shape}")

# Count parameters
lstm_params = sum(p.numel() for p in deep_lstm.parameters())
gru_params = sum(p.numel() for p in deep_gru.parameters())

print(f"\nParameters:")
print(f"  Deep LSTM: {lstm_params:,}")
print(f"  Deep GRU: {gru_params:,}")

# Visualize architecture
fig, ax = plt.subplots(figsize=(12, 8))
ax.set_xlim(-0.5, 6)
ax.set_ylim(-0.5, 4)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('Deep RNN Architecture (3 Layers)', fontsize=14, weight='bold', pad=20)

# Input: x_t enters only the first layer (H_t^{(0)} = x_t)
ax.text(0, 2.5, '$x_t$', fontsize=10, ha='center',
        bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='black'))

# Layers and their hidden states
for layer in range(3):
    y_pos = 2.5 - layer * 0.8
    ax.text(2, y_pos, f'Layer {layer+1}', fontsize=10, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black'))
    ax.text(4, y_pos, f'$H_t^{{({layer+1})}}$', fontsize=10, ha='center',
            bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black'))

# Output: o_t depends directly only on the last layer
ax.text(5.5, 0.9, '$o_t$', fontsize=12, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='orange', edgecolor='black', linewidth=2))

# Arrows
ax.arrow(0.4, 2.5, 0.9, 0, head_width=0.1, head_length=0.1, fc='black', ec='black')  # x_t -> Layer 1
for i in range(3):
    y_pos = 2.5 - i * 0.8
    # Layer -> its hidden state
    ax.arrow(2.5, y_pos, 1, 0, head_width=0.1, head_length=0.1, fc='black', ec='black')
    if i < 2:
        # Hidden state of this layer is the input of the next layer
        ax.arrow(3.8, y_pos - 0.2, -1.3, -0.4, head_width=0.1, head_length=0.1, fc='black', ec='black')
ax.arrow(4.5, 0.9, 0.5, 0, head_width=0.1, head_length=0.1, fc='black', ec='black')  # H_t^{(3)} -> o_t

# Recurrent connections
for layer in range(3):
    y_pos = 2.5 - layer * 0.8
    ax.plot([2, 2, 1.5, 1.5, 2], [y_pos, y_pos+0.3, y_pos+0.3, y_pos, y_pos], 
            'r--', linewidth=1.5, alpha=0.5)

plt.tight_layout()
plt.show()
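
The `DeepRNN` class above discards the second return value of the recurrent module. For many-to-one tasks, where the whole sequence is summarized by a single vector that is then passed to an FC layer, that second value is useful. The following sketch, with sizes copied from the test above, shows how PyTorch lays out the outputs of a stacked `nn.LSTM`. The variable name `summary` is illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size, num_layers = 15, 4, 10, 20, 3
lstm = nn.LSTM(input_size, hidden_size, num_layers)  # batch_first=False, as above
x = torch.randn(seq_len, batch, input_size)

with torch.no_grad():
    out, (h_n, c_n) = lstm(x)

# out: hidden states of the TOP layer at every time step
print(tuple(out.shape))   # (15, 4, 20)
# h_n, c_n: final hidden and cell states of EVERY layer
print(tuple(h_n.shape))   # (3, 4, 20)
print(tuple(c_n.shape))   # (3, 4, 20)

# The top layer's final hidden state is the last time step of `out`
print("out[-1] equals h_n[-1]:", torch.allclose(out[-1], h_n[-1]))

# One vector per sequence, ready for a classification head
summary = h_n[-1]         # shape (batch, hidden_size)
```

Using `h_n[-1]` rather than `out[-1]` makes the intent explicit, although both select the same tensor for a unidirectional network.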

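## 6.2.4 Bidirectional RNN

A bidirectional RNN runs one recurrent layer over the sequence from start to end and a second one from end to start, then concatenates the two hidden states at each time step. The output layer therefore receives $2h$ features per time step.

In [None]: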
# Bidirectional RNN Implementation using PyTorch
class BidirectionalRNN(nn.Module):
    """
    Bidirectional RNN implementation.
    """
    def __init__(self, input_size, hidden_size, output_size, cell_type='LSTM'):
        super(BidirectionalRNN, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        
        # Bidirectional RNN
        if cell_type == 'LSTM':
            self.rnn = nn.LSTM(input_size, hidden_size, num_layers=1, 
                              batch_first=False, bidirectional=True)
        elif cell_type == 'GRU':
            self.rnn = nn.GRU(input_size, hidden_size, num_layers=1,
                             batch_first=False, bidirectional=True)
        else:
            self.rnn = nn.RNN(input_size, hidden_size, num_layers=1,
                             batch_first=False, bidirectional=True)
        
        # Output layer (input is 2*hidden_size because of bidirectional)
        self.fc = nn.Linear(2 * hidden_size, output_size)
    
    def forward(self, x):
        """
        Forward pass through bidirectional RNN.
        
        Parameters:
        -----------
        x : tensor, shape (seq_len, batch, input_size)
            Input sequence
        
        Returns:
        --------
        output : tensor, shape (seq_len, batch, output_size)
            Output sequence
        """
        # Bidirectional RNN forward pass
        rnn_out, _ = self.rnn(x)  # Shape: (seq_len, batch, 2*hidden_size)
        
        # Apply output layer
        output = self.fc(rnn_out)
        
        return output

# Test Bidirectional RNN
input_size = 10
hidden_size = 20
output_size = 5
seq_length = 8
batch_size = 3

# Create models
bi_lstm = BidirectionalRNN(input_size, hidden_size, output_size, cell_type='LSTM')
bi_gru = BidirectionalRNN(input_size, hidden_size, output_size, cell_type='GRU')

# Create input sequence
x = torch.randn(seq_length, batch_size, input_size)

# Forward pass
with torch.no_grad():
    output_lstm = bi_lstm(x)
    output_gru = bi_gru(x)

print("Bidirectional RNN Architecture:")
print("=" * 50)
print(f"Input size: {input_size}")
print(f"Hidden size (per direction): {hidden_size}")
print(f"Total hidden size: {2 * hidden_size} (forward + backward)")
print(f"Output size: {output_size}")
print(f"\nInput shape: {x.shape}")
print(f"Bidirectional LSTM output shape: {output_lstm.shape}")
print(f"Bidirectional GRU output shape: {output_gru.shape}")

# Count parameters
lstm_params = sum(p.numel() for p in bi_lstm.parameters())
gru_params = sum(p.numel() for p in bi_gru.parameters())

print(f"\nParameters:")
print(f"  Bidirectional LSTM: {lstm_params:,}")
print(f"  Bidirectional GRU: {gru_params:,}")

# Visualize architecture
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(-0.5, 7)
ax.set_ylim(-0.5, 5)
ax.set_aspect('equal')
ax.axis('off')
ax.set_title('Bidirectional RNN Architecture', fontsize=14, weight='bold', pad=20)

# Input sequence (drawn on both sides so the backward arrows also start at an input)
for i in range(3):
    for x_pos in (0, 6):
        ax.text(x_pos, 4-i*1.2, f'$x_{i+1}$', fontsize=11, ha='center',
                bbox=dict(boxstyle='round', facecolor='lightblue', edgecolor='black', linewidth=2))

# Forward RNN
ax.text(2, 3.5, '$\\overrightarrow{H}_t$', fontsize=12, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightgreen', edgecolor='black', linewidth=2))
ax.arrow(0.5, 4, 1, -0.5, head_width=0.15, head_length=0.1, fc='green', ec='green')
ax.arrow(0.5, 2.8, 1, 0.5, head_width=0.15, head_length=0.1, fc='green', ec='green')
ax.arrow(0.5, 1.6, 1, 0.5, head_width=0.15, head_length=0.1, fc='green', ec='green')

# Backward RNN
ax.text(4, 3.5, '$\\overleftarrow{H}_t$', fontsize=12, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='lightcoral', edgecolor='black', linewidth=2))
ax.arrow(5.5, 1.6, -1, 0.5, head_width=0.15, head_length=0.1, fc='red', ec='red')
ax.arrow(5.5, 2.8, -1, 0.5, head_width=0.15, head_length=0.1, fc='red', ec='red')
ax.arrow(5.5, 4, -1, -0.5, head_width=0.15, head_length=0.1, fc='red', ec='red')

# Concatenate
ax.text(3, 1.5, 'Concat', fontsize=11, ha='center', va='center',
        bbox=dict(boxstyle='round', facecolor='lightyellow', edgecolor='black'))
ax.arrow(2.3, 3.5, 0.4, -1.8, head_width=0.15, head_length=0.1, fc='black', ec='black')
ax.arrow(4.3, 3.5, -0.4, -1.8, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Output
ax.text(3, 0.5, '$o_t$', fontsize=12, weight='bold', ha='center',
        bbox=dict(boxstyle='round', facecolor='orange', edgecolor='black', linewidth=2))
ax.arrow(3, 1.3, 0, -0.6, head_width=0.15, head_length=0.1, fc='black', ec='black')

# Labels
ax.text(2, 4.5, 'Forward', fontsize=10, ha='center', color='green', weight='bold')
ax.text(4, 4.5, 'Backward', fontsize=10, ha='center', color='red', weight='bold')

plt.tight_layout()
plt.show()

print("\nBidirectional RNN Benefits:")
print("=" * 50)
print("  - Each output can use information from both earlier and later time steps")
print("  - Better for tasks like:")
print("    * Machine translation")
print("    * Named entity recognition")
print("    * Sentiment analysis")
print("    * Other tasks where the full sequence is available before predicting")
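
The claim that a bidirectional network uses both past and future information can be checked directly. The sketch below assumes PyTorch's documented layout for `bidirectional=True`: the last dimension of the output holds the forward states first and the backward states second. It perturbs only the last time step and measures how the output at the first time step reacts. Sizes are copied from the test above, and the variable names are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size = 8, 3, 10, 20
bi_gru = nn.GRU(input_size, hidden_size, bidirectional=True)
x = torch.randn(seq_len, batch, input_size)

with torch.no_grad():
    out, h_n = bi_gru(x)   # out: (seq_len, batch, 2 * hidden_size)

    # Split the concatenated features into the two directions
    fwd = out[..., :hidden_size]
    bwd = out[..., hidden_size:]

    # The forward pass ends at the last step, the backward pass at the first
    fwd_final_ok = torch.allclose(h_n[0], fwd[-1])
    bwd_final_ok = torch.allclose(h_n[1], bwd[0])

    # Change only the LAST time step and look at the output at the FIRST one
    x_perturbed = x.clone()
    x_perturbed[-1] += 1.0
    out_perturbed, _ = bi_gru(x_perturbed)
    fwd_change = (out_perturbed[0, :, :hidden_size] - fwd[0]).abs().max().item()
    bwd_change = (out_perturbed[0, :, hidden_size:] - bwd[0]).abs().max().item()

print("Forward final state is fwd[-1]:", fwd_final_ok)
print("Backward final state is bwd[0]:", bwd_final_ok)
print("Change in forward half at t=0: ", fwd_change)
print("Change in backward half at t=0:", bwd_change)
```

The forward half at the first step sees only $x_1$, so it should not react to the perturbation, while the backward half has already read the entire sequence by the time it reaches the first step.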