# RNNs and LSTMs from Scratch

This notebook builds up Recurrent Neural Networks and LSTMs from first principles, matching the stardust article's Architectures tab.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

# Set seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 1. Simple RNN from Scratch

The core RNN equation:
$$h_t = \tanh(W_{xh} \cdot x_t + W_{hh} \cdot h_{t-1} + b)$$

Let's implement this manually.

In [None]:
class SimpleRNNCell:
    """Vanilla tanh RNN cell written with plain NumPy.

    Holds an input-to-hidden matrix ``W_xh``, a hidden-to-hidden matrix
    ``W_hh`` and a bias ``b``, and applies
    ``h_t = tanh(x_t @ W_xh + h_{t-1} @ W_hh + b)`` one step at a time.
    """

    def __init__(self, input_size, hidden_size):
        self.hidden_size = hidden_size

        # Glorot-style scaling: std = sqrt(2 / (fan_in + fan_out)).
        fan_sum = input_size + hidden_size
        scale = np.sqrt(2.0 / fan_sum)

        # W_xh is drawn before W_hh; the order matters for seeded runs.
        self.W_xh = scale * np.random.randn(input_size, hidden_size)
        self.W_hh = scale * np.random.randn(hidden_size, hidden_size)
        self.b = np.zeros(hidden_size)

    def forward(self, x, h_prev):
        """Advance the cell by one timestep.

        Args:
            x: Input at the current timestep (batch_size, input_size)
            h_prev: Hidden state from the previous timestep (batch_size, hidden_size)

        Returns:
            The new hidden state (batch_size, hidden_size)
        """
        from_input = x @ self.W_xh
        from_state = h_prev @ self.W_hh
        return np.tanh(from_input + from_state + self.b)

    def init_hidden(self, batch_size):
        """Return an all-zero hidden state of shape (batch_size, hidden_size)."""
        shape = (batch_size, self.hidden_size)
        return np.zeros(shape)

# Smoke-test the cell: one sample with 10 features, 20 hidden units
rnn_cell = SimpleRNNCell(input_size=10, hidden_size=20)
x = np.random.randn(1, 10)  # (batch_size=1, input_size=10)
h = rnn_cell.init_hidden(1)

# Advance the cell by a single timestep and inspect the result
h_new = rnn_cell.forward(x, h)
first_five = h_new[0, :5]
print(f"Input shape: {x.shape}")
print(f"Hidden state shape: {h_new.shape}")
print(f"Hidden state values (first 5): {first_five}")

## 2. Processing a Sequence

Now let's process an entire sequence, keeping the hidden state across timesteps.

In [None]:
class SimpleRNN:
    """Sequence-level wrapper that steps a SimpleRNNCell through time."""

    def __init__(self, input_size, hidden_size):
        self.hidden_size = hidden_size
        self.cell = SimpleRNNCell(input_size, hidden_size)

    def forward(self, sequence):
        """Run the cell over every timestep, starting from a zero state.

        Args:
            sequence: (seq_len, batch_size, input_size)

        Returns:
            outputs: Hidden state after each timestep (seq_len, batch_size, hidden_size)
            h_final: Hidden state after the last timestep (batch_size, hidden_size)
        """
        _, batch_size, _ = sequence.shape
        state = self.cell.init_hidden(batch_size)

        history = []
        for x_t in sequence:  # iterates over the leading (time) axis
            state = self.cell.forward(x_t, state)
            history.append(state)

        return np.stack(history), state

# Run a 5-step, batch-of-1 random sequence through the full RNN
rnn = SimpleRNN(input_size=10, hidden_size=20)
sequence = np.random.randn(5, 1, 10)  # (seq_len, batch_size, input_size)

outputs, h_final = rnn.forward(sequence)
print("Sequence length: 5")
print(f"All outputs shape: {outputs.shape}")
print(f"Final hidden state shape: {h_final.shape}")

## 3. Visualizing the Unrolled RNN

When we process a sequence, the RNN "unrolls" into a deep network.

In [None]:
def visualize_hidden_states(outputs, title="Hidden State Evolution"):
    """Draw a heatmap of hidden-unit activations against time.

    Args:
        outputs: Hidden states indexed as (seq_len, batch_size, hidden_size);
            only the first batch element is plotted.
        title: Figure title.
    """
    first_sample = outputs[:, 0, :]
    heat = first_sample.T  # rows = hidden units, columns = timesteps

    plt.figure(figsize=(12, 4))
    # The colour range is pinned to [-1, 1].
    plt.imshow(heat, aspect='auto', cmap='RdBu_r', vmin=-1, vmax=1)
    plt.colorbar(label='Activation')
    plt.title(title)
    plt.xlabel('Timestep')
    plt.ylabel('Hidden Unit')
    plt.show()

# Feed a 20-step random sequence through the RNN and plot its hidden states
long_sequence = np.random.randn(20, 1, 10)  # (seq_len, batch_size, input_size)
outputs, _ = rnn.forward(long_sequence)
plot_title = "Hidden States Over 20 Timesteps"
visualize_hidden_states(outputs, plot_title)

## 4. Backpropagation Through Time (BPTT)

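The gradient for the RNN weights sums over every timestep, and each term contains a product of Jacobians that reaches back through time:

$$\frac{\partial L}{\partial W} = \sum_{t=1}^{T} \frac{\partial L_t}{\partial W} = \sum_{t=1}^{T} \sum_{k=1}^{t} \frac{\partial L_t}{\partial h_t} \left( \prod_{j=k+1}^{t} \frac{\partial h_j}{\partial h_{j-1}} \right) \frac{\partial^{+} h_k}{\partial W}$$

Here $\frac{\partial^{+} h_k}{\partial W}$ is the immediate derivative of $h_k$ with respect to $W$, treating $h_{k-1}$ as a constant.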

Let's implement BPTT for our simple RNN.

In [None]:
class RNNWithBPTT:
    """Tanh RNN with a hand-written backpropagation-through-time pass.

    ``forward`` caches the inputs, pre-activations and hidden states of the
    most recent sequence; ``backward`` consumes that cache to fill
    ``dW_xh``, ``dW_hh`` and ``db``. Call ``forward`` before ``backward``.
    """

    def __init__(self, input_size, hidden_size):
        self.hidden_size = hidden_size
        self.input_size = input_size

        # Glorot-style scaling: std = sqrt(2 / (fan_in + fan_out))
        scale = np.sqrt(2.0 / (input_size + hidden_size))
        self.W_xh = np.random.randn(input_size, hidden_size) * scale
        self.W_hh = np.random.randn(hidden_size, hidden_size) * scale
        self.b = np.zeros(hidden_size)

        self._zero_grads()

    def _zero_grads(self):
        """Reset the accumulated parameter gradients to zero."""
        self.dW_xh = np.zeros_like(self.W_xh)
        self.dW_hh = np.zeros_like(self.W_hh)
        self.db = np.zeros_like(self.b)

    def forward(self, sequence):
        """Run the RNN over a sequence and cache what backward() needs.

        Args:
            sequence: (seq_len, batch_size, input_size)

        Returns:
            outputs: Hidden state after each timestep (seq_len, batch_size, hidden_size)
            h_final: Hidden state after the last timestep (batch_size, hidden_size)
        """
        seq_len, batch_size, _ = sequence.shape

        # hiddens[0] is the zero initial state, so hiddens[t + 1] is the
        # state produced at timestep t and hiddens[t] is its predecessor.
        self.inputs = sequence
        self.hiddens = [np.zeros((batch_size, self.hidden_size))]
        self.pre_activations = []

        for t in range(seq_len):
            pre_act = sequence[t] @ self.W_xh + self.hiddens[-1] @ self.W_hh + self.b
            self.pre_activations.append(pre_act)
            self.hiddens.append(np.tanh(pre_act))

        return np.stack(self.hiddens[1:]), self.hiddens[-1]

    def backward(self, d_outputs):
        """Full BPTT backward pass over the sequence cached by forward().

        Overwrites ``dW_xh``, ``dW_hh`` and ``db`` with the gradients summed
        over all timesteps (and over the batch).

        Args:
            d_outputs: Gradient of loss w.r.t. outputs (seq_len, batch_size, hidden_size)

        Returns:
            List with one entry per timestep t (in forward order): the norm of
            the gradient passed from timestep t to the previous hidden state.
        """
        seq_len = len(self.pre_activations)
        batch_size = d_outputs.shape[1]

        self._zero_grads()

        # Gradient arriving from the future; nothing follows the last step.
        dh_next = np.zeros((batch_size, self.hidden_size))
        grad_magnitudes = []

        for t in reversed(range(seq_len)):
            # Direct loss gradient plus what flowed back from timestep t + 1
            dh = d_outputs[t] + dh_next

            # d/dz tanh(z) = 1 - tanh(z)^2; tanh(z_t) is already cached as
            # hiddens[t + 1], so there is no need to recompute it.
            d_pre_act = dh * (1 - self.hiddens[t + 1] ** 2)

            # Accumulate parameter gradients (hiddens[t] is h_{t-1})
            self.dW_xh += self.inputs[t].T @ d_pre_act
            self.dW_hh += self.hiddens[t].T @ d_pre_act
            self.db += d_pre_act.sum(axis=0)

            # Gradient handed to the previous hidden state
            dh_next = d_pre_act @ self.W_hh.T
            grad_magnitudes.append(np.linalg.norm(dh_next))

        grad_magnitudes.reverse()
        return grad_magnitudes

# Exercise BPTT on a 10-step, batch-of-1 random sequence
rnn_bptt = RNNWithBPTT(input_size=10, hidden_size=20)
sequence = np.random.randn(10, 1, 10)
outputs, h_final = rnn_bptt.forward(sequence)

# Stand-in upstream gradient: a constant 0.1 for every output element
d_outputs = 0.1 * np.ones_like(outputs)
grad_mags = rnn_bptt.backward(d_outputs)

print("Gradient magnitudes through time:")
for t, mag in enumerate(grad_mags):
    line = f"  t={t}: {mag:.4f}"
    print(line)

## 5. The Vanishing Gradient Problem

Let's visualize why gradients vanish in RNNs.

In [None]:
def demonstrate_vanishing_gradients(seq_lengths=(10, 25, 50, 100)):
    """Plot how back-propagated gradient norms change along a sequence.

    For every length a freshly initialised RNNWithBPTT is run on a small
    random sequence, back-propagated with a constant upstream gradient of
    0.1, and the per-timestep gradient norms are drawn on a log-scale axis
    after being divided by their maximum.

    Args:
        seq_lengths: Iterable of sequence lengths to compare. An immutable
            tuple is used as the default so it is never shared mutable state.
    """
    plt.figure(figsize=(12, 4))

    for seq_len in seq_lengths:
        net = RNNWithBPTT(input_size=10, hidden_size=20)
        sequence = np.random.randn(seq_len, 1, 10) * 0.1

        outputs, _ = net.forward(sequence)
        d_outputs = np.ones_like(outputs) * 0.1
        grad_mags = np.array(net.backward(d_outputs))

        # Normalize for comparison; skip the division if every norm is zero
        # so the curve does not turn into NaNs.
        peak = grad_mags.max() if grad_mags.size else 0.0
        if peak > 0:
            grad_mags = grad_mags / peak
        plt.plot(range(seq_len), grad_mags, label=f'Length {seq_len}', alpha=0.7)

    plt.xlabel('Timestep')
    plt.ylabel('Relative Gradient Magnitude')
    plt.title('Vanishing Gradients: Gradient magnitude decreases with distance')
    plt.legend()
    plt.yscale('log')
    plt.grid(True, alpha=0.3)
    plt.show()

# Draw the comparison using the function's default sequence lengths
demonstrate_vanishing_gradients()

In [None]:
# Mathematical demonstration: repeatedly multiplying by a factor < 1
decay_factor = 0.9
timesteps = 100

# Fraction of the gradient left after 0, 1, ..., timesteps - 1 steps back
gradient_remaining = [decay_factor ** t for t in range(timesteps)]

plt.figure(figsize=(10, 4))

# Left panel: linear scale
plt.subplot(1, 2, 1)
plt.plot(gradient_remaining)
plt.xlabel('Timesteps back')
plt.ylabel('Gradient remaining')
plt.title(f'Gradient decay with factor = {decay_factor}')
plt.grid(True, alpha=0.3)

# Right panel: same curve on a logarithmic y-axis
plt.subplot(1, 2, 2)
plt.semilogy(gradient_remaining)
plt.xlabel('Timesteps back')
plt.ylabel('Gradient remaining (log scale)')
plt.title('Same data, log scale')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Derive the summary from `timesteps` so it stays in sync with the plot
# when the constant above is edited.
final_fraction = decay_factor ** timesteps
print(f"After {timesteps} timesteps with decay factor {decay_factor}:")
print(f"  Gradient remaining: {final_fraction:.2e}")
print(f"  That's {final_fraction * 100:.6f}% of the original gradient!")

## 6. Gradient Clipping

For exploding gradients, we clip the gradient norm.

In [None]:
def clip_gradients(gradients, max_norm=1.0):
    """Rescale gradient arrays so their joint L2 norm is at most ``max_norm``.

    The norm is taken over all arrays together. When it exceeds ``max_norm``
    every array is multiplied by ``max_norm / norm``; otherwise the input is
    handed back untouched.

    Returns:
        (gradients, was_clipped, total_norm) where ``gradients`` is a new list
        if clipping happened and the original object otherwise, and
        ``total_norm`` is the norm measured before clipping.
    """
    squared_total = 0
    for g in gradients:
        squared_total = squared_total + np.sum(g**2)
    total_norm = np.sqrt(squared_total)

    # Already within budget: nothing to do
    if not total_norm > max_norm:
        return gradients, False, total_norm

    shrink = max_norm / total_norm
    return [g * shrink for g in gradients], True, total_norm

# Build three deliberately oversized random gradient arrays
np.random.seed(42)
large_gradients = [10 * np.random.randn(10, 20) for _ in range(3)]

clipped, was_clipped, original_norm = clip_gradients(large_gradients, max_norm=5.0)

# Re-measure the joint norm after clipping
clipped_sq_sum = sum(np.sum(g**2) for g in clipped)
clipped_norm = np.sqrt(clipped_sq_sum)

print(f"Original gradient norm: {original_norm:.2f}")
print(f"Was clipped: {was_clipped}")
print(f"Clipped gradient norm: {clipped_norm:.2f}")

## 7. LSTM from Scratch

The LSTM adds a cell state $C_t$ and three gates:

**Forget Gate:** $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$

**Input Gate:** $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$

**Candidate Values:** $\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$

**Cell State Update:** $C_t = f_t \odot C_{t-1} + i_t \odot \tilde{C}_t$

**Output Gate:** $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$

**Hidden State:** $h_t = o_t \odot \tanh(C_t)$

In [None]:
def sigmoid(x):
    """Logistic function 1 / (1 + e^-x), with x clamped to [-500, 500] first."""
    # Clamping keeps np.exp from overflowing on extreme inputs.
    bounded = np.clip(x, -500, 500)
    return 1 / (1 + np.exp(-bounded))

class LSTMCell:
    """One LSTM step in plain NumPy, with a separate weight matrix per gate."""

    def __init__(self, input_size, hidden_size):
        self.hidden_size = hidden_size
        combined_size = input_size + hidden_size
        scale = np.sqrt(2.0 / combined_size)

        def new_weights():
            # Every gate maps the concatenated [x, h_prev] to hidden_size units.
            return np.random.randn(combined_size, hidden_size) * scale

        # Draw order is forget, input, candidate, output (matters when seeded).

        # Forget gate; bias starts at 1 so the gate is biased towards keeping
        # the previous cell state early on.
        self.W_f = new_weights()
        self.b_f = np.ones(hidden_size)

        # Input gate
        self.W_i = new_weights()
        self.b_i = np.zeros(hidden_size)

        # Candidate cell values
        self.W_c = new_weights()
        self.b_c = np.zeros(hidden_size)

        # Output gate
        self.W_o = new_weights()
        self.b_o = np.zeros(hidden_size)

    def forward(self, x, h_prev, c_prev):
        """Advance the LSTM by one timestep.

        Args:
            x: Current input (batch_size, input_size)
            h_prev: Previous hidden state (batch_size, hidden_size)
            c_prev: Previous cell state (batch_size, hidden_size)

        Returns:
            h_new: New hidden state
            c_new: New cell state
            gates: Dict with keys 'forget', 'input', 'output', 'candidate'
                holding the activations computed at this step
        """
        # Input features first, then the previous hidden state
        combined = np.concatenate([x, h_prev], axis=1)

        forget_gate = sigmoid(combined @ self.W_f + self.b_f)
        input_gate = sigmoid(combined @ self.W_i + self.b_i)
        candidate = np.tanh(combined @ self.W_c + self.b_c)
        output_gate = sigmoid(combined @ self.W_o + self.b_o)

        # Keep part of the old cell state, add part of the candidate
        c_new = forget_gate * c_prev + input_gate * candidate
        # Expose a gated, squashed view of the cell state
        h_new = output_gate * np.tanh(c_new)

        gates = {
            'forget': forget_gate,
            'input': input_gate,
            'output': output_gate,
            'candidate': candidate,
        }
        return h_new, c_new, gates

    def init_state(self, batch_size):
        """Return zero-filled (hidden, cell) states, each (batch_size, hidden_size)."""
        shape = (batch_size, self.hidden_size)
        return np.zeros(shape), np.zeros(shape)

# Smoke-test the LSTM cell on a single 10-feature sample
lstm_cell = LSTMCell(input_size=10, hidden_size=20)
x = np.random.randn(1, 10)
h, c = lstm_cell.init_state(1)

# One step from the zero state
h_new, c_new, gates = lstm_cell.forward(x, h, c)
print(f"Hidden state shape: {h_new.shape}")
print(f"Cell state shape: {c_new.shape}")
print("\nGate activations (first 5 values):")
for name, values in gates.items():
    head = values[0, :5]
    print(f"  {name}: {head}")

## 8. Full LSTM for Sequences

In [None]:
class LSTM:
    """Sequence-level wrapper that steps an LSTMCell through time."""

    def __init__(self, input_size, hidden_size):
        self.hidden_size = hidden_size
        self.cell = LSTMCell(input_size, hidden_size)

    def forward(self, sequence, return_all_gates=False):
        """Run the cell over every timestep, starting from zero states.

        Args:
            sequence: (seq_len, batch_size, input_size)
            return_all_gates: Whether to also return per-step gate
                activations and cell states

        Returns:
            outputs: All hidden states (seq_len, batch_size, hidden_size)
            (h_final, c_final): Final hidden and cell states
            all_gates: List of per-step gate dictionaries
                (only if return_all_gates=True)
            cell_states: Stacked cell state after each step
                (only if return_all_gates=True)
        """
        _, batch_size, _ = sequence.shape
        h, c = self.cell.init_state(batch_size)

        hidden_history = []
        cell_history = []
        gate_history = []
        for x_t in sequence:  # iterates over the leading (time) axis
            h, c, step_gates = self.cell.forward(x_t, h, c)
            hidden_history.append(h)
            cell_history.append(c)
            gate_history.append(step_gates)

        outputs = np.stack(hidden_history)
        if not return_all_gates:
            return outputs, (h, c)
        return outputs, (h, c), gate_history, np.stack(cell_history)

# Run a 15-step, batch-of-1 random sequence through the LSTM,
# also collecting per-step gates and cell states
lstm = LSTM(input_size=10, hidden_size=20)
sequence = np.random.randn(15, 1, 10)

lstm_result = lstm.forward(sequence, return_all_gates=True)
outputs, (h_final, c_final), all_gates, cell_states = lstm_result
print(f"Outputs shape: {outputs.shape}")
print(f"Cell states shape: {cell_states.shape}")

## 9. Visualizing LSTM Gates

In [None]:
def visualize_lstm_gates(all_gates, cell_states):
    """Plot the mean activation of each LSTM gate, and the cell state, over time.

    Args:
        all_gates: Sequence of per-timestep dicts with 'forget', 'input' and
            'output' arrays (the sequence is traversed once per gate).
        cell_states: Per-timestep cell states; the mean absolute value of
            each entry is plotted.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))

    # (axis, gate key, line style, panel title) for the three gate panels
    gate_panels = [
        (axes[0, 0], 'forget', 'b-', 'Forget Gate (average)'),
        (axes[0, 1], 'input', 'g-', 'Input Gate (average)'),
        (axes[1, 0], 'output', 'r-', 'Output Gate (average)'),
    ]
    for ax, key, line_style, panel_title in gate_panels:
        # Average across everything in the gate array for that timestep
        mean_per_step = [g[key].mean() for g in all_gates]
        ax.plot(mean_per_step, line_style, linewidth=2)
        ax.set_title(panel_title)
        ax.set_ylim(0, 1)
        # Dashed reference line at the gate midpoint
        ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)

    # Only the outer edges of the grid carry axis labels
    axes[0, 0].set_ylabel('Gate value')
    axes[1, 0].set_ylabel('Gate value')
    axes[1, 0].set_xlabel('Timestep')

    cell_magnitude = [np.abs(c).mean() for c in cell_states]
    axes[1, 1].plot(cell_magnitude, 'm-', linewidth=2)
    axes[1, 1].set_title('Cell State Magnitude (average)')
    axes[1, 1].set_xlabel('Timestep')

    plt.tight_layout()
    plt.suptitle('LSTM Gate Activations Over Time', y=1.02)
    plt.show()

# Plot the gate activations and cell states collected from the LSTM run
visualize_lstm_gates(all_gates, cell_states)

## 10. Activation Functions Comparison

LSTMs use sigmoid for gates (0 to 1) and tanh for values (-1 to 1).

In [None]:
def compare_activations():
    """Plot ReLU, sigmoid and tanh side by side over the interval [-5, 5]."""
    x = np.linspace(-5, 5, 200)

    relu = np.maximum(0, x)
    sigmoid_vals = 1 / (1 + np.exp(-x))
    tanh_vals = np.tanh(x)

    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    ax_relu, ax_sigmoid, ax_tanh = axes

    # Shared look for the dashed grey guide lines
    guide = dict(color='gray', linestyle='--')

    # --- ReLU panel ---
    ax_relu.plot(x, relu, 'g-', linewidth=2)
    ax_relu.axhline(y=0, alpha=0.5, **guide)
    ax_relu.axvline(x=0, alpha=0.5, **guide)
    ax_relu.set_title('ReLU: max(0, x)')
    ax_relu.set_xlabel('x')
    ax_relu.set_ylabel('f(x)')
    ax_relu.set_ylim(-1, 5)
    ax_relu.text(2, 3.5, 'Range: [0, inf)\nFast, but can "die"', fontsize=10)

    # --- Sigmoid panel: guides at the midpoint and at both asymptotes ---
    ax_sigmoid.plot(x, sigmoid_vals, 'b-', linewidth=2)
    ax_sigmoid.axhline(y=0.5, alpha=0.5, **guide)
    ax_sigmoid.axhline(y=0, alpha=0.3, **guide)
    ax_sigmoid.axhline(y=1, alpha=0.3, **guide)
    ax_sigmoid.axvline(x=0, alpha=0.5, **guide)
    ax_sigmoid.set_title('Sigmoid: 1/(1+e^-x)')
    ax_sigmoid.set_xlabel('x')
    ax_sigmoid.set_ylim(-0.2, 1.2)
    ax_sigmoid.text(-4.5, 0.8, 'Range: [0, 1]\nPerfect for gates', fontsize=10)

    # --- Tanh panel: guides at zero and at both asymptotes ---
    ax_tanh.plot(x, tanh_vals, 'r-', linewidth=2)
    ax_tanh.axhline(y=0, alpha=0.5, **guide)
    ax_tanh.axhline(y=-1, alpha=0.3, **guide)
    ax_tanh.axhline(y=1, alpha=0.3, **guide)
    ax_tanh.axvline(x=0, alpha=0.5, **guide)
    ax_tanh.set_title('Tanh: (e^x - e^-x)/(e^x + e^-x)')
    ax_tanh.set_xlabel('x')
    ax_tanh.set_ylim(-1.4, 1.4)
    ax_tanh.text(-4.5, 0.8, 'Range: [-1, 1]\nZero-centered, for values', fontsize=10)

    plt.tight_layout()
    plt.show()

# Render the three-panel activation comparison
compare_activations()

## 11. He Initialization

Proper initialization prevents vanishing/exploding activations in early training.

In [None]:
def compare_initializations(hidden_size=256, num_layers=10, batch_size=32, seed=42):
    """Print activation statistics for three weight-initialization scales.

    A random batch is pushed through ``num_layers`` bias-free tanh layers
    whose square weight matrices are drawn from a zero-mean Gaussian with
    one of three standard deviations:

    * ``Random``: std 1 (unscaled)
    * ``Xavier``: std sqrt(1 / hidden_size)
    * ``He``:     std sqrt(2 / hidden_size)

    The mean, standard deviation and largest absolute value of the final
    activations are printed for each scheme.

    Args:
        hidden_size: Width of every layer (and of the input batch).
        num_layers: Number of stacked tanh layers.
        batch_size: Number of random input rows.
        seed: Value passed to ``np.random.seed`` before any sampling; note
            that this reseeds NumPy's global generator.
    """
    # Scheme name -> standard deviation of its weights, in reporting order
    weight_stds = {
        'Random': 1.0,
        'Xavier': np.sqrt(1.0 / hidden_size),
        'He': np.sqrt(2.0 / hidden_size),
    }

    def propagate(x, std):
        """Push x through num_layers freshly sampled tanh layers."""
        for _ in range(num_layers):
            W = np.random.randn(hidden_size, hidden_size) * std
            x = np.tanh(x @ W)
        return x

    np.random.seed(seed)
    x = np.random.randn(batch_size, hidden_size)

    # propagate() never modifies x in place, so every scheme sees the same batch.
    results = {name: propagate(x, std) for name, std in weight_stds.items()}

    print(f"Activation statistics after {num_layers} layers:")
    print("-" * 50)
    for name, output in results.items():
        print(f"{name}:")
        print(f"  Mean: {output.mean():.6f}")
        print(f"  Std:  {output.std():.6f}")
        print(f"  Max:  {np.abs(output).max():.6f}")
        print()

# Print the activation statistics for the three initialization schemes
compare_initializations()

## 12. RNN vs LSTM: Gradient Flow Comparison

In [None]:
def compare_gradient_retention(seq_lengths=(10, 25, 50, 100, 200)):
    """Compare how much gradient reaches the first input in an RNN vs an LSTM.

    For every sequence length a freshly initialised ``nn.RNN`` and
    ``nn.LSTM`` are each run on their own random input; the loss is the sum
    of the final-timestep output, and the mean absolute gradient at the
    first input is recorded. Results are divided by the value for the first
    length, plotted on a log axis and printed.

    Because models and inputs are re-sampled for every length (and no seed
    is set here), the curves vary from run to run.

    Args:
        seq_lengths: Sequence lengths to evaluate. An immutable tuple is used
            as the default so it is never shared mutable state.
    """
    # Use PyTorch for cleaner gradient computation
    import torch.nn as nn

    def first_input_grad(model, seq_len):
        """Mean |d loss / d x_0| for one random (seq_len, 1, 10) input."""
        x = torch.randn(seq_len, 1, 10, requires_grad=True)
        outputs, _ = model(x)
        outputs[-1].sum().backward()  # Loss at final timestep
        return x.grad[0].abs().mean().item()

    # Materialise once so any iterable can be indexed and plotted below
    seq_lengths = list(seq_lengths)

    rnn_retention = []
    lstm_retention = []
    for seq_len in seq_lengths:
        rnn = nn.RNN(input_size=10, hidden_size=20, batch_first=False)
        rnn_retention.append(first_input_grad(rnn, seq_len))

        lstm = nn.LSTM(input_size=10, hidden_size=20, batch_first=False)
        lstm_retention.append(first_input_grad(lstm, seq_len))

    # Normalize relative to the first sequence length
    rnn_retention = np.array(rnn_retention) / rnn_retention[0]
    lstm_retention = np.array(lstm_retention) / lstm_retention[0]

    plt.figure(figsize=(10, 5))
    plt.plot(seq_lengths, rnn_retention, 'b-o', label='RNN', linewidth=2)
    plt.plot(seq_lengths, lstm_retention, 'r-s', label='LSTM', linewidth=2)
    plt.xlabel('Sequence Length')
    plt.ylabel('Relative Gradient at First Input')
    plt.title('Gradient Retention: RNN vs LSTM')
    plt.legend()
    plt.yscale('log')
    plt.grid(True, alpha=0.3)
    plt.show()

    print("\nGradient retention at different sequence lengths:")
    for i, seq_len in enumerate(seq_lengths):
        print(f"  Length {seq_len:3d}: RNN={rnn_retention[i]:.4f}, LSTM={lstm_retention[i]:.4f}")

# Run the RNN-vs-LSTM comparison with the function's default sequence lengths
compare_gradient_retention()

## 13. Using PyTorch's Built-in LSTM

In practice, we use optimized implementations.

In [None]:
# Built-in PyTorch LSTM: two stacked layers, batch-first tensors
lstm_config = dict(
    input_size=10,
    hidden_size=20,
    num_layers=2,
    dropout=0.1,
    bidirectional=False,
    batch_first=True,
)
pytorch_lstm = nn.LSTM(**lstm_config)

# Total number of scalar parameters across all weight and bias tensors
param_counts = [p.numel() for p in pytorch_lstm.parameters()]
total_params = sum(param_counts)
print(f"Total parameters: {total_params}")

# Forward pass on a random batch laid out as (batch, seq_len, input_size)
x = torch.randn(4, 15, 10)
outputs, (h_n, c_n) = pytorch_lstm(x)

print(f"\nInput shape: {x.shape}")
print(f"Output shape: {outputs.shape}")
print(f"Final hidden shape: {h_n.shape}")
print(f"Final cell shape: {c_n.shape}")

## 14. Character-Level Language Model Demo

A simple example of what RNNs/LSTMs can do.

In [None]:
class CharLSTM(nn.Module):
    """Simple character-level LSTM language model.

    Pipeline: embedding -> batch-first LSTM -> linear layer producing one
    logit per vocabulary entry at every position.
    """

    def __init__(self, vocab_size, embed_size, hidden_size, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, num_layers, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, hidden=None):
        """Compute next-character logits for a batch of index sequences.

        Args:
            x: Integer character indices, laid out batch-first.
            hidden: Optional LSTM state to continue from; ``None`` lets the
                LSTM start from its default initial state.

        Returns:
            logits: One vocabulary-sized score vector per input position.
            hidden: The LSTM state after the last position.
        """
        embedded = self.embedding(x)
        outputs, hidden = self.lstm(embedded, hidden)
        logits = self.fc(outputs)
        return logits, hidden

    def generate(self, start_char, char_to_idx, idx_to_char, length=100, temperature=1.0):
        """Sample text one character at a time, starting from ``start_char``.

        Args:
            start_char: First character; must be a key of ``char_to_idx``.
            char_to_idx: Mapping from character to vocabulary index.
            idx_to_char: Mapping from vocabulary index to character.
            length: Number of characters to sample after ``start_char``.
            temperature: Divisor applied to the logits before the softmax.

        Returns:
            A string of ``length + 1`` characters beginning with ``start_char``.

        Raises:
            ValueError: If ``temperature`` is not positive.
            KeyError: If ``start_char`` is not in ``char_to_idx``.
        """
        if temperature <= 0:
            # Dividing logits by zero or a negative value would produce
            # NaNs or an inverted distribution; fail with a clear message.
            raise ValueError("temperature must be positive")

        # Build index tensors on the model's device so generation also works
        # after the module has been moved off the CPU.
        device = next(self.parameters()).device
        was_training = self.training

        generated = [start_char]
        x = torch.tensor([[char_to_idx[start_char]]], device=device)
        hidden = None

        self.eval()
        try:
            with torch.no_grad():
                for _ in range(length):
                    # Feed only the latest character; `hidden` carries the history
                    logits, hidden = self(x, hidden)
                    probs = F.softmax(logits[0, -1] / temperature, dim=0)
                    next_idx = torch.multinomial(probs, 1).item()
                    generated.append(idx_to_char[next_idx])
                    x = torch.tensor([[next_idx]], device=device)
        finally:
            # Restore the caller's train/eval mode instead of silently
            # leaving the module in eval mode.
            self.train(was_training)

        return ''.join(generated)

# Tiny corpus; the vocabulary is its set of distinct characters, sorted
text = "hello world hello neural network hello lstm hello deep learning"
chars = sorted(set(text))

# Two-way lookup between characters and their integer ids
idx_to_char = dict(enumerate(chars))
char_to_idx = {c: i for i, c in idx_to_char.items()}

print(f"Vocabulary: {chars}")
print(f"Vocabulary size: {len(chars)}")

# Untrained model: it would need training before generating sensible text
vocab_size = len(chars)
model = CharLSTM(vocab_size=vocab_size, embed_size=16, hidden_size=32)
print("\nModel architecture:")
print(model)

## 15. Pretrained Models from Hugging Face

For real applications, use pretrained models. AWD-LSTM was state-of-the-art before transformers.

In [None]:
# Note: the import below is left commented out; enabling it requires
# internet access and the transformers library.

# from transformers import AutoModelForCausalLM, AutoTokenizer

# Pointers for AWD-LSTM style models:
# - salesforce/awd-lstm-lm (Language modeling)
# - fastai's ULMFiT implementation

pointer_lines = (
    "For pretrained LSTM models, check:",
    "  - salesforce/awd-lstm-lm",
    "  - fastai's language model pretraining",
    "  - https://huggingface.co/models?filter=lstm",
)
for pointer_line in pointer_lines:
    print(pointer_line)
print("\nHowever, most production NLP now uses transformers (BERT, GPT, etc.)")

## Summary

What we covered:
1. **Simple RNN**: Process sequences with shared weights across time
2. **BPTT**: Backpropagation through the unrolled network
3. **Vanishing Gradients**: Why RNNs struggle with long sequences
4. **Gradient Clipping**: Prevent exploding gradients
5. **LSTM**: Cell state + gates = long-term memory
6. **Activation Functions**: Sigmoid for gates, tanh for values
7. **Initialization**: He/Xavier for stable training

LSTMs were dominant in NLP from 2014-2017, enabling:
- Machine translation (seq2seq)
- Text generation
- Sentiment analysis
- Speech recognition

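But they have limitations: sequential processing (timesteps can't be computed in parallel) and a fixed-size hidden state that must compress the entire history, which makes very long-range dependencies hard to retain.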

**What came next?** The Transformer architecture (2017), which uses attention to look at all positions simultaneously.