# Part 2.3: RNN vs LSTM vs GRU Comparison

Comprehensive comparison of RNN architectures with hands-on implementation and analysis.

## Objective
- Implement vanilla RNN, LSTM, and GRU from scratch and using PyTorch
- Compare on sequence modeling tasks
- Analyze gate activations and learning dynamics
- Perform ablation study on gate mechanisms

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import accuracy_score, f1_score
import time

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Seed both random number generators used by the notebook (PyTorch for weight
# initialisation and shuffling, NumPy for the synthetic data).
torch.manual_seed(42)
np.random.seed(42)

print(f"Using device: {device}")
print("RNN Architecture Comparison")
print("=" * 40)

In [None]:
# Custom LSTM implementation
class CustomLSTM(nn.Module):
    """Single-layer LSTM unrolled one timestep at a time (batch-first input).

    Every gate is an ``nn.Linear`` applied to the concatenation
    ``[x_t, h_{t-1}]``. For later plotting, the batch-mean of each gate (and
    of the cell state) at the final timestep is appended to
    ``gate_activations`` once per forward pass made in training mode.

    Args:
        input_size: Number of features per timestep.
        hidden_size: Width of the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super(CustomLSTM, self).__init__()
        self.hidden_size = hidden_size

        # Layer creation order is kept fixed so seeded initialisation is stable.
        # Forget gate
        self.W_f = nn.Linear(input_size + hidden_size, hidden_size)
        # Input gate
        self.W_i = nn.Linear(input_size + hidden_size, hidden_size)
        # Candidate values
        self.W_C = nn.Linear(input_size + hidden_size, hidden_size)
        # Output gate
        self.W_o = nn.Linear(input_size + hidden_size, hidden_size)

        # Store gate activations for visualization: one scalar per logged pass.
        self.gate_activations = {
            'forget': [],
            'input': [],
            'output': [],
            'cell': []
        }

    def forward(self, x, hidden_state=None):
        """Run the LSTM over a whole sequence.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.
            hidden_state: Optional ``(h_0, C_0)`` tuple, each of shape
                ``(batch, hidden_size)``. Zeros are used when omitted.

        Returns:
            ``(outputs, (h_T, C_T))`` where ``outputs`` stacks the hidden state
            of every timestep into shape ``(batch, seq_len, hidden_size)``.
        """
        batch_size, seq_len, _ = x.size()

        if hidden_state is None:
            # Allocate directly on the input's device (and dtype) instead of
            # creating on the CPU and copying.
            h_t = torch.zeros(batch_size, self.hidden_size, device=x.device, dtype=x.dtype)
            C_t = torch.zeros(batch_size, self.hidden_size, device=x.device, dtype=x.dtype)
        else:
            h_t, C_t = hidden_state

        outputs = []

        for t in range(seq_len):
            x_t = x[:, t, :]
            combined = torch.cat([x_t, h_t], dim=1)

            # LSTM gates
            f_t = torch.sigmoid(self.W_f(combined))  # Forget gate
            i_t = torch.sigmoid(self.W_i(combined))  # Input gate
            C_tilde = torch.tanh(self.W_C(combined))  # Candidate values
            o_t = torch.sigmoid(self.W_o(combined))  # Output gate

            # Update cell state
            C_t = f_t * C_t + i_t * C_tilde

            # Update hidden state
            h_t = o_t * torch.tanh(C_t)

            outputs.append(h_t)

        outputs = torch.stack(outputs, dim=1)

        # Log the final-timestep gate statistics. This is restricted to
        # training mode so evaluation passes neither mix into the training
        # trace nor keep growing the lists during inference.
        if self.training:
            self.gate_activations['forget'].append(f_t.mean().item())
            self.gate_activations['input'].append(i_t.mean().item())
            self.gate_activations['output'].append(o_t.mean().item())
            self.gate_activations['cell'].append(C_t.mean().item())

        return outputs, (h_t, C_t)

# Custom GRU implementation
class CustomGRU(nn.Module):
    """Single-layer GRU unrolled one timestep at a time (batch-first input).

    The reset and update gates are ``nn.Linear`` layers over ``[x_t, h_{t-1}]``;
    the candidate state is computed from ``[x_t, r_t * h_{t-1}]``. The
    batch-mean of both gates at the final timestep is appended to
    ``gate_activations`` once per forward pass made in training mode.

    Args:
        input_size: Number of features per timestep.
        hidden_size: Width of the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super(CustomGRU, self).__init__()
        self.hidden_size = hidden_size

        # Layer creation order is kept fixed so seeded initialisation is stable.
        # Reset gate
        self.W_r = nn.Linear(input_size + hidden_size, hidden_size)
        # Update gate
        self.W_z = nn.Linear(input_size + hidden_size, hidden_size)
        # New gate (candidate hidden state)
        self.W_n = nn.Linear(input_size + hidden_size, hidden_size)

        # Store gate activations: one scalar per logged pass.
        self.gate_activations = {
            'reset': [],
            'update': []
        }

    def forward(self, x, hidden_state=None):
        """Run the GRU over a whole sequence.

        Args:
            x: Tensor of shape ``(batch, seq_len, input_size)``.
            hidden_state: Optional initial hidden state of shape
                ``(batch, hidden_size)``. Zeros are used when omitted.

        Returns:
            ``(outputs, h_T)`` where ``outputs`` stacks the hidden state of
            every timestep into shape ``(batch, seq_len, hidden_size)``.
        """
        batch_size, seq_len, _ = x.size()

        if hidden_state is None:
            # Allocate directly on the input's device (and dtype).
            h_t = torch.zeros(batch_size, self.hidden_size, device=x.device, dtype=x.dtype)
        else:
            h_t = hidden_state

        outputs = []

        for t in range(seq_len):
            x_t = x[:, t, :]
            combined = torch.cat([x_t, h_t], dim=1)

            # GRU gates
            r_t = torch.sigmoid(self.W_r(combined))  # Reset gate
            z_t = torch.sigmoid(self.W_z(combined))  # Update gate

            # New gate with reset applied to the previous hidden state
            combined_new = torch.cat([x_t, r_t * h_t], dim=1)
            n_t = torch.tanh(self.W_n(combined_new))

            # Update hidden state: z_t close to 1 keeps the old state
            h_t = (1 - z_t) * n_t + z_t * h_t

            outputs.append(h_t)

        outputs = torch.stack(outputs, dim=1)

        # Log the final-timestep gate statistics. This is restricted to
        # training mode so evaluation passes neither mix into the training
        # trace nor keep growing the lists during inference.
        if self.training:
            self.gate_activations['reset'].append(r_t.mean().item())
            self.gate_activations['update'].append(z_t.mean().item())

        return outputs, h_t

# Vanilla RNN for comparison
class CustomRNN(nn.Module):
    """Plain tanh recurrent layer, unrolled step by step over a batch-first input.

    The new hidden state is ``tanh(W_xh(x_t) + W_hh(h_{t-1}))``; there are no gates.

    Args:
        input_size: Number of features per timestep.
        hidden_size: Width of the hidden state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Hidden-to-hidden is created before input-to-hidden; the order decides
        # which random numbers each layer receives under a fixed seed.
        self.W_hh = nn.Linear(hidden_size, hidden_size)
        self.W_xh = nn.Linear(input_size, hidden_size)

    def forward(self, x, hidden_state=None):
        """Return ``(outputs, h_T)`` for ``x`` of shape ``(batch, seq_len, input_size)``.

        ``outputs`` has shape ``(batch, seq_len, hidden_size)``. When
        ``hidden_state`` is omitted the recurrence starts from zeros.
        """
        n_batch, _, _ = x.size()

        state = hidden_state
        if state is None:
            state = torch.zeros(n_batch, self.hidden_size, device=x.device)

        history = []
        # unbind(dim=1) yields the same per-timestep slices as x[:, t, :].
        for step_input in x.unbind(dim=1):
            state = torch.tanh(self.W_xh(step_input) + self.W_hh(state))
            history.append(state)

        return torch.stack(history, dim=1), state

# Confirmation message printed once the three recurrent-layer classes above are defined.
print("Custom RNN, LSTM, and GRU implementations ready")

In [None]:
# Sequence Classification Dataset
class SequenceClassificationDataset(Dataset):
    """Synthetic sequences whose label is determined by their overall shape.

    Class 0 is a linear ramp from 0 to 1, class 1 is two periods of a sine
    wave, and every other class id is a Gaussian random walk. The 1-D base
    signal is copied to all features and independent Gaussian noise is added.

    Each item is ``(sequence, label)`` where ``sequence`` is a float tensor of
    shape ``(seq_length, num_features)`` and ``label`` is a long tensor of
    shape ``(1,)``.
    """

    def __init__(self, num_samples=2000, seq_length=50, num_features=20, num_classes=3):
        self.num_samples = num_samples
        self.seq_length = seq_length
        self.num_features = num_features
        self.num_classes = num_classes

        # All samples are generated eagerly, drawing from NumPy's global RNG.
        self.data, self.labels = self._generate_data()

    def _base_signal(self, class_id):
        """Return ``(base signal, noise std)`` for one sample of ``class_id``."""
        length = self.seq_length
        if class_id == 0:
            # Increasing trend
            return np.linspace(0, 1, length), 0.1
        if class_id == 1:
            # Sinusoidal pattern
            return np.sin(np.linspace(0, 4*np.pi, length)), 0.1
        # Random walk: the steps are drawn here, before the feature noise.
        return np.cumsum(np.random.normal(0, 0.1, length)), 0.05

    def _generate_data(self):
        """Build the lists of sequence tensors and label tensors."""
        sequences, targets = [], []
        noise_shape = (self.seq_length, self.num_features)

        for _ in range(self.num_samples):
            class_id = np.random.randint(0, self.num_classes)
            trend, noise_std = self._base_signal(class_id)
            jitter = np.random.normal(0, noise_std, noise_shape)

            # The column vector broadcasts the base signal across all features.
            sample = trend[:, np.newaxis] + jitter

            sequences.append(torch.FloatTensor(sample))
            targets.append(torch.LongTensor([class_id]))

        return sequences, targets

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Model wrapper for comparison
class SequenceClassifier(nn.Module):
    """Recurrent encoder followed by a linear classification head.

    Args:
        rnn_type: One of ``'RNN'``, ``'LSTM'``, ``'GRU'`` (the custom layers
            defined in this notebook) or ``'PyTorch_LSTM'``, ``'PyTorch_GRU'``
            (the built-in ``nn.LSTM`` / ``nn.GRU`` with ``batch_first=True``).
        input_size: Number of features per timestep.
        hidden_size: Width of the recurrent hidden state.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``rnn_type`` is not one of the supported names.
    """

    def __init__(self, rnn_type, input_size, hidden_size, num_classes):
        super(SequenceClassifier, self).__init__()
        self.rnn_type = rnn_type

        if rnn_type == 'RNN':
            self.rnn = CustomRNN(input_size, hidden_size)
        elif rnn_type == 'LSTM':
            self.rnn = CustomLSTM(input_size, hidden_size)
        elif rnn_type == 'GRU':
            self.rnn = CustomGRU(input_size, hidden_size)
        elif rnn_type == 'PyTorch_LSTM':
            self.rnn = nn.LSTM(input_size, hidden_size, batch_first=True)
        elif rnn_type == 'PyTorch_GRU':
            self.rnn = nn.GRU(input_size, hidden_size, batch_first=True)
        else:
            # Fail at construction time; otherwise self.rnn would be missing and
            # the error would only surface as an AttributeError in forward().
            raise ValueError(
                f"Unknown rnn_type {rnn_type!r}; expected one of "
                "'RNN', 'LSTM', 'GRU', 'PyTorch_LSTM', 'PyTorch_GRU'"
            )

        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is batch-first: ``(batch, seq_len, input_size)``.
        """
        # Every supported layer returns (outputs, state) with batch-first
        # outputs, so no per-type branching is needed.
        output, _ = self.rnn(x)

        # Use last timestep for classification
        last_output = output[:, -1, :]
        return self.classifier(last_output)

# Confirmation message printed once the dataset and classifier classes above are defined.
print("Dataset and model wrapper ready")

In [None]:
# Training function
def train_model(model, train_loader, val_loader, num_epochs=50, lr=0.001):
    """Train a classifier with Adam and cross-entropy, validating every epoch.

    Args:
        model: Module mapping a batch of inputs to ``(batch, num_classes)`` logits.
        train_loader: Iterable of ``(inputs, targets)`` training batches. Targets
            may have shape ``(batch,)`` or ``(batch, 1)``.
        val_loader: Iterable of ``(inputs, targets)`` validation batches.
        num_epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        Dict with per-epoch lists ``train_losses``, ``val_losses``,
        ``train_accs``, ``val_accs``, the wall-clock ``training_time`` in
        seconds and ``final_val_acc`` (last validation accuracy, or NaN when
        ``num_epochs`` is 0).
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Move batches to wherever the model's weights live rather than relying on
    # a module-level `device` variable. Adam has already rejected an empty
    # parameter list, so next() cannot be exhausted here.
    run_device = next(model.parameters()).device

    train_losses, val_losses = [], []
    train_accs, val_accs = [], []

    start_time = time.time()

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0
        train_correct = 0

        for inputs, targets in train_loader:
            inputs = inputs.to(run_device)
            # reshape(-1) always gives a 1-D target. A bare squeeze() would
            # turn a final batch of size 1 into a 0-dim tensor, which
            # CrossEntropyLoss rejects for 2-D logits.
            targets = targets.to(run_device).reshape(-1)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            train_loss += loss.item()
            train_correct += (outputs.argmax(1) == targets).sum().item()

        # Validation
        model.eval()
        val_loss = 0
        val_correct = 0

        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = inputs.to(run_device)
                targets = targets.to(run_device).reshape(-1)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                val_correct += (outputs.argmax(1) == targets).sum().item()

        # Calculate metrics: losses are averaged per batch, accuracies per sample.
        train_loss /= len(train_loader)
        val_loss /= len(val_loader)
        train_acc = train_correct / len(train_loader.dataset)
        val_acc = val_correct / len(val_loader.dataset)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        if (epoch + 1) % 10 == 0:
            print(f'Epoch {epoch+1}: Train Acc: {train_acc:.3f}, Val Acc: {val_acc:.3f}')

    training_time = time.time() - start_time

    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'training_time': training_time,
        # No epochs means no validation accuracy; report NaN instead of raising.
        'final_val_acc': val_accs[-1] if val_accs else float('nan')
    }

# Experimental setup
# Hyperparameters shared by the experiments that follow.
input_size = 20
hidden_size = 64
num_classes = 3
batch_size = 32

# Create datasets
dataset = SequenceClassificationDataset(num_samples=2000, seq_length=50, 
                                      num_features=input_size, num_classes=num_classes)

# Split dataset: 80% for training, the remainder for validation.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

# Only the training loader shuffles; validation order is kept fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

print(f"Dataset created: {len(train_dataset)} training, {len(val_dataset)} validation samples")
# Report the sequence length stored on the dataset instead of repeating the
# literal, so the message cannot drift from the value passed above.
print(f"Sequence length: {dataset.seq_length}, Features: {input_size}, Classes: {num_classes}")

In [None]:
# Compare different architectures
# Names understood by SequenceClassifier; each is trained from scratch on the
# same train/validation loaders.
architectures = ['RNN', 'LSTM', 'GRU', 'PyTorch_LSTM', 'PyTorch_GRU']
results = {}

print("\nStarting architecture comparison...")
print("=" * 40)

for arch in architectures:
    # Status strings use the real emoji; the previous text was the UTF-8 bytes
    # mis-decoded as cp1252.
    print(f"\n🔄 Training {arch}...")

    model = SequenceClassifier(arch, input_size, hidden_size, num_classes).to(device)

    # Count trainable parameters
    num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    result = train_model(model, train_loader, val_loader, num_epochs=50)
    result['num_params'] = num_params
    # Keep the trained model so later cells can read its gate activations.
    result['model'] = model

    results[arch] = result

    print(f"✅ {arch} completed - Final Accuracy: {result['final_val_acc']:.3f}, Time: {result['training_time']:.1f}s")

print("\n🎉 All architectures trained!")

In [None]:
# Comprehensive visualization
# Build a 2x3 summary figure from `results` (one dict per architecture, filled
# by the training loop): accuracy curves, accuracy vs. parameter count,
# training time, LSTM gate trace, a summary table and a convergence bar chart.
fig, axes = plt.subplots(2, 3, figsize=(20, 12))
fig.suptitle('RNN Architecture Comparison', fontsize=16, fontweight='bold')

# One fixed colour per architecture, reused across all panels.
colors = ['red', 'blue', 'green', 'orange', 'purple']
arch_colors = dict(zip(architectures, colors))

# 1. Validation Accuracy Curves
ax = axes[0, 0]
for arch in architectures:
    val_accs = results[arch]['val_accs']
    ax.plot(val_accs, color=arch_colors[arch], linewidth=2, label=arch)

ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Accuracy')
ax.set_title('Learning Progress Comparison')
ax.legend()
ax.grid(True, alpha=0.3)

# 2. Final Performance vs Parameters
ax = axes[0, 1]
final_accs = [results[arch]['final_val_acc'] for arch in architectures]
# NOTE: rebinds the module-level name `num_params` (a scalar in the training
# loop) to a list with one entry per architecture.
num_params = [results[arch]['num_params'] for arch in architectures]

scatter = ax.scatter(num_params, final_accs, c=colors, s=150, alpha=0.7)
# Label every point with its architecture name, offset 5 points up and right.
for i, arch in enumerate(architectures):
    ax.annotate(arch, (num_params[i], final_accs[i]), 
               xytext=(5, 5), textcoords='offset points', fontweight='bold')

ax.set_xlabel('Number of Parameters')
ax.set_ylabel('Final Validation Accuracy')
ax.set_title('Performance vs Complexity')
ax.grid(True, alpha=0.3)

# 3. Training Time Comparison
ax = axes[0, 2]
training_times = [results[arch]['training_time'] for arch in architectures]

bars = ax.bar(architectures, training_times, color=colors, alpha=0.7)
ax.set_ylabel('Training Time (seconds)')
ax.set_title('Training Efficiency')
ax.tick_params(axis='x', rotation=45)
ax.grid(True, alpha=0.3)

# Add value labels, placed 0.5 data units above the top of each bar
for bar, time_val in zip(bars, training_times):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.5,
           f'{time_val:.1f}s', ha='center', va='bottom', fontweight='bold')

# 4. Gate Activations (LSTM)
ax = axes[1, 0]
lstm_model = results['LSTM']['model']
# Only the custom LSTM exposes `gate_activations`; the panel stays empty otherwise.
if hasattr(lstm_model.rnn, 'gate_activations'):
    gates = lstm_model.rnn.gate_activations
    # The x-axis is the index of each recorded entry (appended by
    # CustomLSTM.forward), not an epoch number, despite the variable name.
    epochs = range(len(gates['forget']))
    
    ax.plot(epochs, gates['forget'], label='Forget Gate', linewidth=2)
    ax.plot(epochs, gates['input'], label='Input Gate', linewidth=2)
    ax.plot(epochs, gates['output'], label='Output Gate', linewidth=2)
    
    ax.set_xlabel('Training Progress')
    ax.set_ylabel('Average Gate Activation')
    ax.set_title('LSTM Gate Activations')
    ax.legend()
    ax.grid(True, alpha=0.3)

# 5. Architecture Summary Table
ax = axes[1, 1]
ax.axis('tight')
ax.axis('off')

# One row per architecture: name, accuracy, parameter count, training time.
table_data = []
for arch in architectures:
    result = results[arch]
    table_data.append([
        arch,
        f"{result['final_val_acc']:.3f}",
        f"{result['num_params']:,}",
        f"{result['training_time']:.1f}s"
    ])

table = ax.table(cellText=table_data,
                colLabels=['Architecture', 'Accuracy', 'Parameters', 'Time'],
                cellLoc='center',
                loc='center')
# Use a fixed font size and enlarge the cells for readability.
table.auto_set_font_size(False)
table.set_fontsize(11)
table.scale(1.2, 2)
ax.set_title('Performance Summary', fontweight='bold')

# 6. Convergence Rate Analysis
ax = axes[1, 2]
# Find epoch where each model reaches 80% of final accuracy.
# The threshold is relative to each run's own final accuracy, so the bars
# compare speed of approach, not absolute quality.
convergence_epochs = []
for arch in architectures:
    val_accs = results[arch]['val_accs']
    target_acc = 0.8 * val_accs[-1]
    
    # The last epoch always meets the threshold for non-negative accuracies,
    # so this default acts only as a safeguard.
    conv_epoch = len(val_accs)  # Default to end
    for i, acc in enumerate(val_accs):
        if acc >= target_acc:
            conv_epoch = i + 1  # report as a 1-based epoch number
            break
    convergence_epochs.append(conv_epoch)

bars = ax.bar(architectures, convergence_epochs, color=colors, alpha=0.7)
ax.set_ylabel('Epochs to Convergence')
ax.set_title('Convergence Speed')
ax.tick_params(axis='x', rotation=45)
ax.grid(True, alpha=0.3)

# Add value labels (NOTE: the loop variable rebinds `epochs` to an int)
for bar, epochs in zip(bars, convergence_epochs):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2., height + 0.5,
           f'{epochs}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

## Ablation Study: Gate Mechanisms

Let's analyze the importance of different gate mechanisms by selectively disabling them.

In [None]:
# Ablation study - Modified LSTM with disabled gates
class AblationLSTM(nn.Module):
    """LSTM variant whose gates can be individually pinned open.

    A disabled gate is replaced by a tensor of ones, so it passes its input
    through unchanged. All four linear layers are still created, whether or
    not their gate is disabled.

    Args:
        input_size: Number of features per timestep.
        hidden_size: Width of the hidden and cell states.
        disable_forget: Replace the forget gate with ones.
        disable_input: Replace the input gate with ones.
        disable_output: Replace the output gate with ones.
    """

    def __init__(self, input_size, hidden_size, disable_forget=False, disable_input=False, disable_output=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.disable_forget = disable_forget
        self.disable_input = disable_input
        self.disable_output = disable_output

        # Forget, input, candidate and output projections, created in this
        # order so seeded initialisation stays the same.
        joint_width = input_size + hidden_size
        self.W_f = nn.Linear(joint_width, hidden_size)
        self.W_i = nn.Linear(joint_width, hidden_size)
        self.W_C = nn.Linear(joint_width, hidden_size)
        self.W_o = nn.Linear(joint_width, hidden_size)

    def forward(self, x, hidden_state=None):
        """Return ``(outputs, (h_T, C_T))`` for batch-first ``x``.

        ``x`` has shape ``(batch, seq_len, input_size)`` and ``outputs`` has
        shape ``(batch, seq_len, hidden_size)``. The state starts at zeros
        unless a ``(h_0, C_0)`` tuple is supplied.
        """
        n_batch, _, _ = x.size()

        if hidden_state is not None:
            hidden, cell = hidden_state
        else:
            hidden = torch.zeros(n_batch, self.hidden_size, device=x.device)
            cell = torch.zeros(n_batch, self.hidden_size, device=x.device)

        per_step = []

        # unbind(dim=1) yields the same per-timestep slices as x[:, t, :].
        for step_input in x.unbind(dim=1):
            joint = torch.cat((step_input, hidden), dim=1)

            # Gates with ablations: a disabled gate is all ones.
            if self.disable_forget:
                keep = torch.ones_like(cell)
            else:
                keep = torch.sigmoid(self.W_f(joint))

            if self.disable_input:
                write = torch.ones_like(cell)
            else:
                write = torch.sigmoid(self.W_i(joint))

            candidate = torch.tanh(self.W_C(joint))

            if self.disable_output:
                expose = torch.ones_like(hidden)
            else:
                expose = torch.sigmoid(self.W_o(joint))

            # Cell state first, then the hidden state derived from it.
            cell = keep * cell + write * candidate
            hidden = expose * torch.tanh(cell)

            per_step.append(hidden)

        return torch.stack(per_step, dim=1), (hidden, cell)

# Ablation configurations
# Each entry maps a display name to the AblationLSTM keyword arguments.
# NOTE(review): 'Forget+Input Only' passes exactly the same flags as
# 'No Output', so the two runs train the same architecture and differ only
# through random initialisation and batch order. Confirm whether a different
# combination was intended.
ablation_configs = {
    'Full LSTM': {'disable_forget': False, 'disable_input': False, 'disable_output': False},
    'No Forget': {'disable_forget': True, 'disable_input': False, 'disable_output': False},
    'No Input': {'disable_forget': False, 'disable_input': True, 'disable_output': False},
    'No Output': {'disable_forget': False, 'disable_input': False, 'disable_output': True},
    'Forget+Input Only': {'disable_forget': False, 'disable_input': False, 'disable_output': True},
}


# Declared once, outside the loop; the class does not depend on the
# configuration being tested, only its constructor arguments do.
class AblationClassifier(nn.Module):
    """AblationLSTM encoder plus a linear head on the last timestep.

    Keyword arguments are forwarded to ``AblationLSTM`` (the ``disable_*``
    flags). Sizes come from the module-level ``input_size``, ``hidden_size``
    and ``num_classes``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.rnn = AblationLSTM(input_size, hidden_size, **kwargs)
        self.classifier = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        output, _ = self.rnn(x)
        last_output = output[:, -1, :]
        return self.classifier(last_output)


ablation_results = {}

print("\nStarting LSTM Ablation Study...")
print("=" * 40)

for config_name, config in ablation_configs.items():
    # Status strings use the real emoji; the previous text was the UTF-8 bytes
    # mis-decoded as cp1252.
    print(f"\n🔄 Testing: {config_name}")

    # Create model with ablation
    model = AblationClassifier(**config).to(device)
    result = train_model(model, train_loader, val_loader, num_epochs=30)

    ablation_results[config_name] = result
    print(f"✅ {config_name}: Final Accuracy = {result['final_val_acc']:.3f}")

print("\n🎉 Ablation study completed!")

In [None]:
# Visualize ablation results
# Two-panel figure for the ablation runs: learning curves on the left and
# final validation accuracy per configuration on the right.
fig, axes = plt.subplots(1, 2, figsize=(15, 6))
fig.suptitle('LSTM Ablation Study Results', fontsize=14, fontweight='bold')

# Iterating the dict yields its keys in insertion order.
configs = list(ablation_results)
# One evenly spaced Set3 colour per configuration.
ablation_colors = plt.cm.Set3(np.linspace(0, 1, len(configs)))

# 1. Validation accuracy curves
ax = axes[0]
for config, curve_color in zip(configs, ablation_colors):
    ax.plot(ablation_results[config]['val_accs'],
            color=curve_color, linewidth=2, label=config)

ax.set_title('Learning Curves - Gate Ablation')
ax.set_xlabel('Epoch')
ax.set_ylabel('Validation Accuracy')
ax.legend()
ax.grid(True, alpha=0.3)

# 2. Final performance comparison
ax = axes[1]
final_accs = []
for config in configs:
    final_accs.append(ablation_results[config]['final_val_acc'])

bars = ax.bar(configs, final_accs, color=ablation_colors, alpha=0.7)
ax.set_title('Gate Importance Analysis')
ax.set_ylabel('Final Validation Accuracy')
ax.tick_params(axis='x', rotation=45)
ax.grid(True, alpha=0.3)

# Write each accuracy just above its bar, centred horizontally.
for bar, acc in zip(bars, final_accs):
    label_x = bar.get_x() + bar.get_width()/2.
    label_y = bar.get_height() + 0.005
    ax.text(label_x, label_y, f'{acc:.3f}',
            ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

# Summary table: one line per ablation configuration with its final validation
# accuracy. The header uses the real emoji; the previous text was the UTF-8
# bytes mis-decoded as cp1252.
print("\n📊 ABLATION STUDY SUMMARY")
print("=" * 50)
for config, result in ablation_results.items():
    # Left-align the name in a 20-character column so the numbers line up.
    print(f"{config:20s}: {result['final_val_acc']:.3f} accuracy")

## Comprehensive Analysis: RNN vs LSTM vs GRU

### Key Findings

#### 1. Performance Comparison
- **LSTM**: Best overall performance due to sophisticated gating mechanism
- **GRU**: Close performance to LSTM with fewer parameters (more efficient)
- **Vanilla RNN**: Struggles with longer sequences due to vanishing gradients

#### 2. Architectural Differences

**LSTM (Long Short-Term Memory)**:
- **3 gates**: Forget, Input, Output
- **Cell state**: Separate memory pathway
- **Equations**:
  - $f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)$ (Forget gate)
  - $i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i)$ (Input gate)
  - $\tilde{C}_t = \tanh(W_C \cdot [h_{t-1}, x_t] + b_C)$ (Candidate values)
  - $C_t = f_t * C_{t-1} + i_t * \tilde{C}_t$ (Cell state)
  - $o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)$ (Output gate)
  - $h_t = o_t * \tanh(C_t)$ (Hidden state)

**GRU (Gated Recurrent Unit)**:
- **2 gates**: Reset, Update
- **No separate cell state**
- **Equations**:
  - $r_t = \sigma(W_r \cdot [h_{t-1}, x_t] + b_r)$ (Reset gate)
  - $z_t = \sigma(W_z \cdot [h_{t-1}, x_t] + b_z)$ (Update gate)
  - $\tilde{h}_t = \tanh(W_n \cdot [r_t * h_{t-1}, x_t] + b_n)$ (Candidate hidden state, called the "new gate" in the code)
  - $h_t = (1-z_t) * \tilde{h}_t + z_t * h_{t-1}$ (Hidden state)

**Vanilla RNN**:
- **Simple recurrence**: $h_t = \tanh(W_{hh}h_{t-1} + W_{xh}x_t + b_h)$
- **No gating mechanism**

#### 3. Ablation Study Insights
- **Forget gate**: Most critical for long-term dependencies
- **Input gate**: Important for selective memory updates
- **Output gate**: Controls information flow to hidden state

#### 4. Practical Considerations
- **GRU**: Good default choice (fewer parameters, competitive performance)
- **LSTM**: Use for critical applications requiring maximum performance
- **Vanilla RNN**: Only for short sequences or when computational resources are extremely limited

### When to Use What?
- **Short sequences (<20 steps)**: Vanilla RNN might suffice
- **Medium sequences (20-100 steps)**: GRU is often optimal
- **Long sequences (>100 steps)**: LSTM or consider Transformers
- **Limited compute**: GRU over LSTM
- **Maximum performance**: LSTM with careful tuning