# Day 2: Understanding LSTM Networks

Day 2 of 30 Papers in 30 Days.

Today: **Long Short-Term Memory (LSTM) networks**, one of the most important breakthroughs in sequence modeling. LSTMs largely overcome the vanishing gradient problem that makes it hard for vanilla RNNs to learn long-range dependencies.

## What We'll Do

1. **The Problem**: Why vanilla RNNs struggle with long sequences
2. **The Solution**: How LSTMs use gates to control information flow
3. **The Architecture**: The 4 key components (forget, input, cell, output)
4. **The Implementation**: Building an LSTM from scratch in NumPy
5. **The Visualization**: Seeing what LSTMs "remember" and "forget"

## Colah's Core Insight

The cell state runs parallel to the hidden state, like a conveyor belt (Colah's metaphor). Information flows along it largely unchanged unless the gates intervene: the forget gate removes information, the input gate adds it, and the output gate decides how much of the cell state is exposed as the hidden state.

```python
# Setup
import numpy as np
import matplotlib.pyplot as plt
import sys
import os

# Make local modules (implementation.py, visualization.py) importable.
# In a notebook, abspath('') resolves to the current working directory.
sys.path.insert(0, os.path.abspath(''))

# Import our LSTM implementation
from implementation import LSTM
from visualization import (
    plot_gate_activations,
    plot_cell_state_evolution,
    plot_gradient_flow_comparison,
    analyze_gate_patterns
)

# Set random seed for reproducibility
np.random.seed(42)

print("All imports successful.")
print(f"NumPy version: {np.__version__}")
```

## 1. The Vanishing Gradient Problem

Before LSTMs, vanilla RNNs could learn short-range patterns but struggled with dependencies spanning many time steps.

### Why

In backpropagation through time, gradients flow backward like this:

$$\frac{\partial L}{\partial h_1} = \frac{\partial L}{\partial h_T} \prod_{t=2}^{T} \frac{\partial h_t}{\partial h_{t-1}}$$

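For a vanilla RNN $h_t = \tanh(W h_{t-1} + U x_t)$, each factor in the product is a Jacobian, $\frac{\partial h_t}{\partial h_{t-1}} = \operatorname{diag}(1 - h_t^2)\, W$. The $\tanh$ derivative is at most 1, so with moderate recurrent weights each factor typically shrinks the gradient. If every step scales it by about 0.9:

$$\underbrace{0.9 \times 0.9 \times \cdots \times 0.9}_{50 \text{ times}} = 0.9^{50} \approx 0.005$$

The gradient vanishes, so the network effectively cannot learn from what happened 50 steps ago.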
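The scalar 0.9 stands in for a Jacobian. The minimal NumPy sketch below makes it concrete for a vanilla RNN $h_t = \tanh(W h_{t-1} + U x_t)$, whose step Jacobian is $\operatorname{diag}(1 - h_t^2)\,W$: it multiplies 50 of these and prints the spectral norm of $\partial h_T / \partial h_{T-k}$. The sizes, the weight scale of 0.5, and the random inputs are arbitrary illustrative choices, and `rnn_jacobian_norms` is a hypothetical helper rather than part of this notebook's `implementation` module.

```python
import numpy as np

def rnn_jacobian_norms(hidden_size=16, input_size=8, steps=50, weight_scale=0.5, seed=0):
    """Return ||dh_T/dh_{T-k}|| (spectral norm) for k = 0..steps in a vanilla tanh RNN.

    Assumed model (illustration only): h_t = tanh(W h_{t-1} + U x_t).
    """
    rng = np.random.default_rng(seed)
    # Recurrent weights scaled so the spectral radius is roughly weight_scale
    W = rng.normal(0, weight_scale / np.sqrt(hidden_size), (hidden_size, hidden_size))
    U = rng.normal(0, 1 / np.sqrt(input_size), (hidden_size, input_size))

    # Forward pass: store every hidden state
    h = np.zeros(hidden_size)
    hs = []
    for _ in range(steps):
        x = rng.normal(size=input_size)
        h = np.tanh(W @ h + U @ x)
        hs.append(h)

    # Backward: J accumulates dh_T/dh_{T-k} as a product of diag(1 - h_t^2) W
    J = np.eye(hidden_size)
    norms = [1.0]
    for h_t in reversed(hs):
        J = J @ (np.diag(1 - h_t**2) @ W)
        norms.append(np.linalg.norm(J, 2))
    return norms

norms = rnn_jacobian_norms()
for k in (1, 10, 25, 50):
    print(f"||dh_T/dh_(T-{k})|| = {norms[k]:.2e}")
```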

Colah's example from the post: predicting "sky" in "the clouds are in the ___" is easy (short range). Predicting "French" in "I grew up in France... I speak fluent ___" is hard (long range).

```python
# Demonstrate vanishing gradients with a toy model: each backward step
# multiplies the gradient by a constant per-step factor.
def simulate_gradient_flow(initial_grad=1.0, steps=50, factor=0.9):
    """Simulate a gradient flowing backward through `steps` time steps."""
    gradients = [initial_grad]
    for _ in range(steps):
        gradients.append(gradients[-1] * factor)
    return gradients

# Compare two hand-picked per-step factors (illustrative, not measured)
steps = range(51)
vanilla_rnn = simulate_gradient_flow(1.0, 50, 0.9)
lstm_sim = simulate_gradient_flow(1.0, 50, 0.99)  # e.g. a cell-state path with f_t near 1

plt.figure(figsize=(12, 5))
plt.plot(steps, vanilla_rnn, 'r-', linewidth=2, label='Vanilla RNN (0.9 factor)')
plt.plot(steps, lstm_sim, 'g-', linewidth=2, label='LSTM (0.99 factor)')
plt.axhline(y=0.1, color='orange', linestyle='--', alpha=0.5, label='Illustrative threshold (0.1)')
plt.xlabel('Time Steps Backward', fontsize=12)
plt.ylabel('Gradient Magnitude', fontsize=12)
plt.title('Vanishing Gradients: RNN vs LSTM (toy model)', fontsize=14, fontweight='bold')
plt.yscale('log')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"Vanilla RNN after 50 steps: {vanilla_rnn[-1]:.6f}")
print(f"LSTM after 50 steps: {lstm_sim[-1]:.6f}")
print(f"\nThe 0.99 path keeps {lstm_sim[-1] / vanilla_rnn[-1]:.1f}x more gradient than the 0.9 path.")
```

## 2. The LSTM Solution: The Cell State

LSTMs address vanishing gradients with the **cell state**: a separate path that runs parallel to the hidden state.

Colah's metaphor: the cell state is a "conveyor belt" that information rides along unchanged, unless the LSTM explicitly modifies it through gates.

### The Key Equation

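The cell state is updated **additively**: the previous state is scaled elementwise by the forget gate and new content is added, instead of being pushed through a weight matrix and a squashing nonlinearity as in a vanilla RNN:

$$C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$$

When gradients flow backward along this direct path (treating the gates as fixed with respect to $C_{t-1}$):

$$\frac{\partial C_t}{\partial C_{t-1}} = \operatorname{diag}(f_t)$$

Over many steps this gives a product of forget-gate values rather than a product of weight matrices and $\tanh$ derivatives. When $f_t$ stays close to 1 (the network can learn to hold the gate open, and a common practice is to initialize $b_f$ to a positive value such as 1), that product decays slowly and gradients can travel across many time steps. This is the core reason LSTMs work.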
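A numerical check of the direct-path derivative, as a minimal sketch: with the gate values held fixed, unroll $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$ for 50 steps, nudge $C_0$, and compare the finite-difference derivative with the product of forget gates. The gate sequences are random placeholders, not outputs of a trained LSTM, and `unroll_cell` is a hypothetical helper. The last loop shows how strongly the 50-step product depends on the gate value; 0.73 is roughly $\sigma(1)$, what the forget gate outputs when $b_f = 1$ and the weighted inputs are zero.

```python
import numpy as np

def unroll_cell(C0, f, i, C_tilde):
    """Apply C_t = f_t * C_{t-1} + i_t * C_tilde_t for each step, with gates held fixed."""
    C = C0.copy()
    for f_t, i_t, c_t in zip(f, i, C_tilde):
        C = f_t * C + i_t * c_t
    return C

rng = np.random.default_rng(0)
T, H = 50, 4

# Placeholder gate sequences (random, not from a trained LSTM)
f = rng.uniform(0.85, 0.95, (T, H))      # forget gates near 0.9
i = rng.uniform(0.0, 1.0, (T, H))        # input gates
C_tilde = rng.uniform(-1.0, 1.0, (T, H)) # cell candidates
C0 = rng.uniform(-1.0, 1.0, H)

# Central finite difference of C_T[0] with respect to C_0[0]
eps = 1e-3
e0 = np.zeros(H)
e0[0] = eps
fd = (unroll_cell(C0 + e0, f, i, C_tilde)[0]
      - unroll_cell(C0 - e0, f, i, C_tilde)[0]) / (2 * eps)
analytic = np.prod(f[:, 0])  # product of forget gates along the cell path
print(f"finite difference: {fd:.6e}   product of f_t: {analytic:.6e}")

# How the 50-step product depends on a constant forget-gate value
for f_val in (0.5, 0.73, 0.9, 0.99):
    print(f"f_t = {f_val:<4}: product over {T} steps = {f_val ** T:.2e}")
```

Even a gate value of 0.73 leaves a tiny product after 50 steps, so the forget gate has to learn values much closer to 1 for information that must persist.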

## 3. LSTM Architecture: The 4 Components

An LSTM cell has 4 parts, following Colah's step-by-step walkthrough:

### 1. Forget Gate ($f_t$) — What to throw away from cell state

$$f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$$

- Outputs 0 (forget everything) to 1 (keep everything)
- Colah's example: when you see a new subject, forget the old subject's gender

### 2. Input Gate ($i_t$) — What new information to store

$$i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$$

Controls how much of the new candidate values to write.

### 3. Cell Candidate ($\tilde{C}_t$) — The new values to potentially add

$$\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$$

Outputs -1 to 1. These are the candidate values, scaled by the input gate before being added to cell state.

### 4. Output Gate ($o_t$) — What to output from cell state

$$o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$$

Colah's example: if we just saw a subject, output whether it's singular or plural (relevant for verb conjugation).

### The Update

1. **Forget**: $C_t = f_t \odot C_{t-1}$ (scale down old info)
2. **Add**: $C_t = C_t + i_t \odot \tilde{C}_t$ (add new info)
3. **Output**: $h_t = o_t \odot \tanh(C_t)$ (filter what to expose)

Cell state ($C_t$) is the memory. Hidden state ($h_t$) is the output.
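The four components and the update can be collected into a single time step. This is a minimal NumPy sketch in Colah's notation, assuming each weight matrix has shape `(hidden, hidden + input)` and acts on the concatenation $[h_{t-1}, x_t]$; `lstm_step`, `init_params`, and the parameter dictionary keys are hypothetical and may not match the notebook's `implementation.LSTM`. The positive forget bias in `init_params` is one common initialization choice, not a requirement.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, C_prev, params):
    """One LSTM step in Colah's notation.

    Assumed layout (hypothetical): each W has shape (hidden, hidden + input)
    and acts on the concatenation [h_{t-1}, x_t].
    """
    concat = np.concatenate([h_prev, x_t])
    f_t = sigmoid(params["Wf"] @ concat + params["bf"])      # forget gate
    i_t = sigmoid(params["Wi"] @ concat + params["bi"])      # input gate
    C_tilde = np.tanh(params["Wc"] @ concat + params["bc"])  # cell candidate
    o_t = sigmoid(params["Wo"] @ concat + params["bo"])      # output gate

    C_t = f_t * C_prev + i_t * C_tilde  # forget old info, then add new info
    h_t = o_t * np.tanh(C_t)            # filter what to expose
    return h_t, C_t

def init_params(input_size, hidden_size, seed=0):
    """Small random weights, zero biases, except a positive forget-gate bias."""
    rng = np.random.default_rng(seed)
    shape = (hidden_size, hidden_size + input_size)
    params = {}
    for gate in ("f", "i", "c", "o"):
        params["W" + gate] = rng.normal(0, 0.1, shape)
        params["b" + gate] = np.zeros(hidden_size)
    params["bf"] = np.ones(hidden_size)
    return params

# Usage: run three one-hot inputs through a 4-unit cell
params = init_params(input_size=3, hidden_size=4)
h, C = np.zeros(4), np.zeros(4)
for idx in (0, 1, 2):
    x = np.eye(3)[idx]
    h, C = lstm_step(x, h, C, params)
    print(idx, np.round(h, 3))
```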

```python
# Create a small vocabulary and LSTM for visualization.
# dict.fromkeys drops the duplicate 'l' while keeping first-seen order,
# so every character maps to exactly one index.
chars = list(dict.fromkeys("hello"))
char_to_idx = {ch: i for i, ch in enumerate(chars)}
idx_to_char = {i: ch for i, ch in enumerate(chars)}
vocab_size = len(chars)

print(f"Vocabulary: {chars}")
print(f"Vocab size: {vocab_size}")

# Initialize a small LSTM
hidden_size = 10  # Small for visualization
lstm = LSTM(input_size=vocab_size, hidden_size=hidden_size, output_size=vocab_size)

print("\nLSTM created:")
print(f"  Input size: {vocab_size}")
print(f"  Hidden size: {hidden_size}")
print(f"  Output size: {vocab_size}")
# Counts the weight matrices only; bias vectors are not included
n_weights = sum(p.size for p in [lstm.Wf, lstm.Wi, lstm.Wc, lstm.Wo, lstm.Wy])
print(f"  Weight parameters (excluding biases): {n_weights}")
```
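A worked parameter count, as a sketch: assuming the four gate matrices each have shape `(hidden, hidden + input)` and the output layer `Wy` has shape `(output, hidden)`, the weights total $4H(H + X) + VH$, and biases add $4H + V$. `lstm_param_count` is a hypothetical helper, and the bias layout (including an output bias) is an assumption about `implementation.py`. The example uses 4 input characters, hidden size 10, and 4 outputs.

```python
def lstm_param_count(input_size, hidden_size, output_size, include_biases=False):
    """Parameter count for an LSTM with a linear output layer (assumed layout)."""
    gate_weights = 4 * hidden_size * (hidden_size + input_size)  # Wf, Wi, Wc, Wo
    output_weights = output_size * hidden_size                   # Wy
    total = gate_weights + output_weights
    if include_biases:
        total += 4 * hidden_size + output_size                   # four gate biases + output bias
    return total

print(lstm_param_count(4, 10, 4))                       # 600 (weights only)
print(lstm_param_count(4, 10, 4, include_biases=True))  # 644 (weights + biases)
```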

## 4. Forward Pass: Watching the Gates

Run the sequence "hello" through the LSTM and capture gate activations at each step — forget, input, output — along with cell state evolution.

```python
# Prepare sequence
text = "hello"
inputs = [char_to_idx[ch] for ch in text]
print(f"Input sequence: {text}")
print(f"As indices: {inputs}")

# Storage for gate activations
gates_storage = {
    'forget': [],
    'input': [],
    'output': []
}
cell_states_storage = []

# Initialize hidden and cell states
h_prev = np.zeros(hidden_size)
C_prev = np.zeros(hidden_size)

# Forward pass through sequence
for t, idx in enumerate(inputs):
    # Create one-hot encoded input
    x = np.zeros(vocab_size)
    x[idx] = 1.0

    # Compute all gates manually to capture them
    concat = np.concatenate([h_prev, x])

    # Forget gate
    f = lstm.sigmoid(np.dot(lstm.Wf, concat) + lstm.bf)
    gates_storage['forget'].append(f.copy())

    # Input gate
    i = lstm.sigmoid(np.dot(lstm.Wi, concat) + lstm.bi)
    gates_storage['input'].append(i.copy())

    # Cell candidate
    C_tilde = np.tanh(np.dot(lstm.Wc, concat) + lstm.bc)

    # Update cell state
    C_prev = f * C_prev + i * C_tilde
    cell_states_storage.append(C_prev.copy())

    # Output gate
    o = lstm.sigmoid(np.dot(lstm.Wo, concat) + lstm.bo)
    gates_storage['output'].append(o.copy())

    # Update hidden state
    h_prev = o * np.tanh(C_prev)

    print(f"\nStep {t} ('{text[t]}'):")
    print(f"  Forget gate avg: {f.mean():.3f} (1=keep, 0=forget)")
    print(f"  Input gate avg:  {i.mean():.3f} (1=add, 0=ignore)")
    print(f"  Output gate avg: {o.mean():.3f} (1=show, 0=hide)")

print("\nForward pass complete.")
```

## 5. Visualizing Gate Activations

Heatmaps showing which hidden units are active at each time step, and how the cell state evolves.

```python
# Visualize gate activations
plot_gate_activations(gates_storage, text)

# Visualize cell state evolution
plot_cell_state_evolution(cell_states_storage, text)

# Analyze patterns
analyze_gate_patterns(gates_storage, text)
```

## 6. Training on Real Text

Train the LSTM on Shakespeare text to see how it learns next-character prediction. Same task as Day 1's vanilla RNN, but with the cell state machinery.
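Next-character prediction pairs each input character with the one that follows it and scores the model with cross-entropy over a softmax of its outputs. This minimal sketch shows both pieces, with random logits standing in for the LSTM's per-step outputs. `make_pair` and `sequence_loss` are hypothetical helpers; whether `LSTM.forward` sums or averages the per-step losses is not shown in this notebook, and this sketch assumes a sum.

```python
import numpy as np

def make_pair(text, start, seq_length, char_to_idx):
    """Inputs are text[start:start+L]; targets are the same span shifted by one."""
    inputs = [char_to_idx[ch] for ch in text[start:start + seq_length]]
    targets = [char_to_idx[ch] for ch in text[start + 1:start + seq_length + 1]]
    return inputs, targets

def sequence_loss(logits, targets):
    """Summed cross-entropy: -sum_t log softmax(logits_t)[target_t]."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].sum()

text = "to be or not"
vocab = sorted(set(text))
char_to_idx = {ch: i for i, ch in enumerate(vocab)}
inputs, targets = make_pair(text, 0, 5, char_to_idx)
print([vocab[i] for i in inputs], "->", [vocab[i] for i in targets])

# Random logits stand in for the model's per-step outputs
rng = np.random.default_rng(0)
logits = rng.normal(size=(len(targets), len(vocab)))
print(f"loss = {sequence_loss(logits, targets):.3f}")
print(f"uniform-guess baseline = {len(targets) * np.log(len(vocab)):.3f}")
```

A model that guesses uniformly scores $L \log V$ for a chunk of length $L$ over a vocabulary of size $V$, which is a useful reference point when reading a summed-loss training curve.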

```python
# Training data: a short passage from Hamlet
training_text = """
To be, or not to be, that is the question:
Whether 'tis nobler in the mind to suffer
The slings and arrows of outrageous fortune,
Or to take arms against a sea of troubles
And by opposing end them.
"""

# Create vocabulary
chars_train = sorted(set(training_text))
char_to_idx_train = {ch: i for i, ch in enumerate(chars_train)}
idx_to_char_train = {i: ch for i, ch in enumerate(chars_train)}
vocab_size_train = len(chars_train)

print(f"Training text length: {len(training_text)} characters")
print(f"Vocabulary size: {vocab_size_train}")
# repr() makes the newline and space characters visible
print(f"Unique characters: {''.join(chars_train)!r}")

# Create LSTM
lstm_train = LSTM(input_size=vocab_size_train,
                  hidden_size=64,
                  output_size=vocab_size_train)

print("\nTraining LSTM created.")
```

```python
# Training loop
seq_length = 25
learning_rate = 0.001
num_iterations = 1000

losses = []

print("Training LSTM...")
print("=" * 60)

for iteration in range(num_iterations):
    # Sample a random starting point (the last valid chunk is reachable)
    start_idx = np.random.randint(0, len(training_text) - seq_length)

    # Input and target sequences: the target is the input shifted by one character
    input_seq = training_text[start_idx:start_idx + seq_length]
    target_seq = training_text[start_idx + 1:start_idx + seq_length + 1]

    # Convert to indices
    inputs = [char_to_idx_train[ch] for ch in input_seq]
    targets = [char_to_idx_train[ch] for ch in target_seq]

    # Chunks start at random positions, so the previous chunk's final state
    # says nothing about this one: start each chunk from a zero state.
    h_prev = np.zeros(lstm_train.hidden_size)
    C_prev = np.zeros(lstm_train.hidden_size)

    # Forward pass (returns the loss on this chunk)
    loss = lstm_train.forward(inputs, targets, h_prev, C_prev)
    losses.append(loss)

    # Backward pass, then weight update
    lstm_train.backward()
    lstm_train.update_weights(learning_rate)

    # Print progress (mean of up to the last 100 losses)
    if iteration % 100 == 0:
        smooth_loss = np.mean(losses[-100:])
        print(f"Iteration {iteration:4d} | Loss: {smooth_loss:.4f}")

        # Sample text every 500 iterations
        if iteration % 500 == 0:
            sample = lstm_train.sample(idx_to_char_train, char_to_idx_train['T'], 100)
            print(f"Sample: {sample[:60]}...")
            print()

print("\nTraining complete.")
```

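```python
# Plot training curve
window = 50
plt.figure(figsize=(10, 5))
plt.plot(losses, alpha=0.3, label='Raw loss')
# Moving average over `window` iterations
if len(losses) > window:
    smoothed = np.convolve(losses, np.ones(window) / window, mode='valid')
    plt.plot(range(window - 1, len(losses)), smoothed, linewidth=2, label='Smoothed loss', color='red')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.title('LSTM Training Progress')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Single-iteration losses are noisy, so compare averages over the first
# and last `window` iterations instead of losses[0] and losses[-1].
start_loss = np.mean(losses[:window])
final_loss = np.mean(losses[-window:])
print(f"Average loss, first {window} iterations: {start_loss:.4f}")
print(f"Average loss, last {window} iterations:  {final_loss:.4f}")
print(f"Improvement: {(start_loss - final_loss) / start_loss * 100:.1f}%")
```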

## 7. Temperature Sampling

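Temperature $T$ rescales the model's output distribution before sampling (in the standard formulation, the logits are divided by $T$ before the softmax):

- **Low (below 1, e.g. 0.3 or 0.7)**: Sharpens the distribution toward the most likely characters; more coherent but more repetitive
- **1.0**: Samples from the model's distribution unchanged
- **High (above 1, e.g. 1.5)**: Flattens the distribution so unlikely characters appear more often; more diverse but less coherent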
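A minimal sketch of temperature sampling, assuming the standard formulation (divide the logits by $T$, then apply a softmax). `sample_with_temperature` is a hypothetical helper and the logits are made-up scores, not outputs of the trained LSTM. Printing the probabilities for each temperature shows the distribution sharpening for small $T$ and flattening for large $T$.

```python
import numpy as np

def sample_with_temperature(logits, temperature=1.0, rng=None):
    """Sample an index from softmax(logits / temperature); also return the probabilities."""
    rng = np.random.default_rng() if rng is None else rng
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max()  # numerical stability; does not change the softmax
    probs = np.exp(scaled) / np.exp(scaled).sum()
    return rng.choice(len(probs), p=probs), probs

logits = np.array([2.0, 1.0, 0.5, 0.0])  # made-up scores for 4 characters
for T in (0.3, 0.7, 1.0, 1.5):
    _, probs = sample_with_temperature(logits, T, np.random.default_rng(0))
    print(f"T={T}: {np.round(probs, 3)}")
```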

```python
# Try different temperatures
temperatures = [0.3, 0.7, 1.0, 1.5]
seed_char = 'T'

print("Sampling with different temperatures:")
print("=" * 70)

for temp in temperatures:
    sample = lstm_train.sample(idx_to_char_train,
                               char_to_idx_train[seed_char],
                               length=150,
                               temperature=temp)
    print(f"\nTemperature = {temp}:")
    print(f"{sample[:120]}...")
    print("-" * 70)
```

## 8. Summary

**The problem:** Vanilla RNNs can't learn long-range dependencies because gradients vanish during backpropagation through time.

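**The solution (from Colah's post):** LSTMs add a cell state, a separate path that is updated additively (scaled by the forget gate, then added to) instead of being pushed through a weight matrix at every step. Three gates (forget, input, output) control what gets kept, written, and exposed.

**Why it works:** Along the cell-state path, $\partial C_t / \partial C_{t-1}$ is just the elementwise forget gate $f_t$ (a diagonal Jacobian) rather than a product involving the recurrent weight matrix and $\tanh$ derivatives. When $f_t$ stays close to 1, gradients decay slowly instead of vanishing.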

**Variants (also from the post):** Peephole connections, coupled gates, and GRUs (which merge forget and input gates into a single update gate).
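Two of the variants are compact enough to sketch. Coupled gates tie the input gate to $1 - f_t$, and the GRU (as written in Colah's post, biases omitted) uses an update gate $z_t$ and a reset gate $r_t$ with no separate cell state: $h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t$. This is a minimal NumPy sketch with random weights; `gru_step` and `coupled_cell_update` are hypothetical helpers, and peephole connections are not shown.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def coupled_cell_update(f_t, C_prev, C_tilde):
    """Coupled forget/input gates: write new content only where old content is forgotten."""
    return f_t * C_prev + (1.0 - f_t) * C_tilde

def gru_step(x_t, h_prev, Wz, Wr, Wh):
    """One GRU step in Colah's notation (biases omitted, as in the post).

    Each W has shape (hidden, hidden + input) and acts on a concatenation [h, x].
    """
    hx = np.concatenate([h_prev, x_t])
    z_t = sigmoid(Wz @ hx)                                       # update gate
    r_t = sigmoid(Wr @ hx)                                       # reset gate
    h_tilde = np.tanh(Wh @ np.concatenate([r_t * h_prev, x_t]))  # candidate state
    return (1.0 - z_t) * h_prev + z_t * h_tilde                  # no separate cell state

rng = np.random.default_rng(0)
H, X = 4, 3
Wz, Wr, Wh = (rng.normal(0, 0.1, (H, H + X)) for _ in range(3))
h = np.zeros(H)
for idx in range(X):
    h = gru_step(np.eye(X)[idx], h, Wz, Wr, Wh)
print(np.round(h, 3))
```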

### Next Steps

- Try the exercises in `exercises/` — especially the ablation study (exercise 3) which shows what happens when you remove individual gates
- Compare training curves with Day 1's vanilla RNN on the same data
- Read Colah's original post for the diagrams: http://colah.github.io/posts/2015-08-Understanding-LSTMs/