# BiLSTM Step-by-Step: Processing "I love NLP"

## Setup
- Sentence: **"I love NLP"**
- Hidden units: **4** (per direction)
- Embedding dimension: **3**
- **Same as RNN and LSTM examples for comparison!**

## Bidirectional LSTM (BiLSTM)
- Processes sequence **FORWARD** (left-to-right): I → love → NLP
- Processes sequence **BACKWARD** (right-to-left): NLP → love → I
- **Concatenates** both directions at each time step
- Output dimension: **8** (4 forward + 4 backward)
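
The three bullets above fit in a few lines of Python. The sketch below is only an illustration, not this notebook's implementation. `bidirectional` is a hypothetical helper that accepts any recurrent update `step(x, h) -> h`. The toy "cell" is a running sum rather than an LSTM, so the numbers are easy to check by hand. For brevity it reuses the same `step` in both directions, whereas a real BiLSTM has separate weights per direction (Step 2).

```python
import numpy as np

def bidirectional(step, xs, h0):
    """Hypothetical helper: run `step(x, h) -> h` in both directions and concatenate per position."""
    # Forward direction: I -> love -> NLP
    fwd, h = [], h0
    for x in xs:
        h = step(x, h)
        fwd.append(h)
    # Backward direction: NLP -> love -> I
    bwd, h = [], h0
    for x in reversed(xs):
        h = step(x, h)
        bwd.append(h)
    bwd.reverse()  # re-align so bwd[t] belongs to word t
    return [np.concatenate([f, b]) for f, b in zip(fwd, bwd)]

# Toy 1-D "cell": a running sum instead of an LSTM
xs = [np.array([1.0]), np.array([10.0]), np.array([100.0])]  # stand-ins for I, love, NLP
outputs = bidirectional(lambda x, h: h + x, xs, np.zeros(1))

for word, out in zip(["I", "love", "NLP"], outputs):
    # First entry: sum up to this word (past); second entry: sum from this word to the end (future)
    print(word, out)
```

At "love" the forward half is 11 (it has seen I + love) and the backward half is 110 (it has seen love + NLP). This is the same past/future split that Step 6 describes for the real hidden states.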

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle, FancyBboxPatch, FancyArrowPatch

# Set random seed for reproducibility
np.random.seed(42)

# Configuration (SAME as RNN and LSTM)
vocab = {"I": 0, "love": 1, "NLP": 2}
embedding_dim = 3
hidden_units = 4  # Per direction

print("Vocabulary:", vocab)
print(f"Embedding dimension: {embedding_dim}")
print(f"Hidden units per direction: {hidden_units}")
print(f"Total output dimension: {hidden_units * 2} (forward + backward)")
print()
print("BiLSTM has TWO separate LSTMs:")
print("  1. Forward LSTM:  processes I → love → NLP")
print("  2. Backward LSTM: processes NLP → love → I")

## Step 1: Create Embedding Matrix (SAME as RNN/LSTM)

In [None]:
# Embedding matrix (SAME as RNN and LSTM examples)
embedding_matrix = np.array([
    [0.5, 0.2, 0.1],   # "I"
    [0.8, 0.6, 0.3],   # "love"
    [0.1, 0.9, 0.7]    # "NLP"
])

print("Embedding Matrix:")
print(embedding_matrix)

# Get embeddings for our sentence
sentence = ["I", "love", "NLP"]
sentence_ids = [vocab[word] for word in sentence]
embeddings = embedding_matrix[sentence_ids]

print("\nWord Embeddings:")
for i, word in enumerate(sentence):
    print(f"{word:6s}: {embeddings[i]}")

print("\nForward sequence:  I (idx=0) → love (idx=1) → NLP (idx=2)")
print("Backward sequence: NLP (idx=2) → love (idx=1) → I (idx=0)")
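
Indexing `embedding_matrix[sentence_ids]` is a row lookup. It gives the same result as multiplying one-hot vectors by the embedding matrix, which is a useful way to see that the embedding layer is just a (trainable) linear map. Below is a quick check on the same toy values, using the reversed order that the backward LSTM reads. The names `vocab_demo`, `E` and `demo_ids` are illustrative and chosen so they do not overwrite the notebook's variables.

```python
import numpy as np

vocab_demo = {"I": 0, "love": 1, "NLP": 2}
E = np.array([
    [0.5, 0.2, 0.1],   # "I"
    [0.8, 0.6, 0.3],   # "love"
    [0.1, 0.9, 0.7],   # "NLP"
])
demo_ids = [vocab_demo[w] for w in ["NLP", "love", "I"]]  # backward reading order

# Row lookup (what the notebook does)
lookup = E[demo_ids]

# Same result as one-hot rows times the matrix: (3 x |V|) @ (|V| x 3)
one_hot = np.eye(len(vocab_demo))[demo_ids]
via_matmul = one_hot @ E

print(np.allclose(lookup, via_matmul))  # True
```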

## Step 2: Initialize Weights for BOTH Directions

BiLSTM has **TWO independent sets of weights** (one for forward, one for backward); the two directions share no parameters.

In [None]:
# Helper functions
def sigmoid(x):
    return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

# Simplified LSTM weights (using same structure as LSTM example)
# FORWARD LSTM weights
np.random.seed(42)
W_f_fwd_x = np.random.randn(hidden_units, embedding_dim) * 0.1
W_f_fwd_h = np.random.randn(hidden_units, hidden_units) * 0.1
b_f_fwd = np.ones(hidden_units) * 0.5

W_i_fwd_x = np.random.randn(hidden_units, embedding_dim) * 0.1
W_i_fwd_h = np.random.randn(hidden_units, hidden_units) * 0.1
b_i_fwd = np.ones(hidden_units) * 0.1

W_C_fwd_x = np.random.randn(hidden_units, embedding_dim) * 0.1
W_C_fwd_h = np.random.randn(hidden_units, hidden_units) * 0.1
b_C_fwd = np.ones(hidden_units) * 0.1

W_o_fwd_x = np.random.randn(hidden_units, embedding_dim) * 0.1
W_o_fwd_h = np.random.randn(hidden_units, hidden_units) * 0.1
b_o_fwd = np.ones(hidden_units) * 0.1

# BACKWARD LSTM weights (different from forward)
np.random.seed(123)
W_f_bwd_x = np.random.randn(hidden_units, embedding_dim) * 0.1
W_f_bwd_h = np.random.randn(hidden_units, hidden_units) * 0.1
b_f_bwd = np.ones(hidden_units) * 0.5

W_i_bwd_x = np.random.randn(hidden_units, embedding_dim) * 0.1
W_i_bwd_h = np.random.randn(hidden_units, hidden_units) * 0.1
b_i_bwd = np.ones(hidden_units) * 0.1

W_C_bwd_x = np.random.randn(hidden_units, embedding_dim) * 0.1
W_C_bwd_h = np.random.randn(hidden_units, hidden_units) * 0.1
b_C_bwd = np.ones(hidden_units) * 0.1

W_o_bwd_x = np.random.randn(hidden_units, embedding_dim) * 0.1
W_o_bwd_h = np.random.randn(hidden_units, hidden_units) * 0.1
b_o_bwd = np.ones(hidden_units) * 0.1

print("BiLSTM Weight Configuration:")
print("="*60)
print("Forward LSTM:  3 gates (forget, input, output) + candidate = 4 weight sets")
print("Backward LSTM: 3 gates (forget, input, output) + candidate = 4 weight sets (not shared)")
print()
print("Total parameters = exactly 2× a single LSTM")
print("="*60)
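
For readability, the notebook keeps eight separate matrices per direction. Many implementations instead stack the four transforms into a single matrix of shape `(4*hidden_units, embedding_dim + hidden_units)`, apply it once to the concatenated `[x_t; h_prev]`, and split the result. The sketch below uses random weights and illustrative names, and checks that both layouts give the same pre-activations. The gate order used here (f, i, C̃, o) is an arbitrary choice, and libraries differ: PyTorch's `nn.LSTM`, for example, stacks them as i, f, g, o.

```python
import numpy as np

rng = np.random.default_rng(0)
H, D = 4, 3  # hidden_units, embedding_dim

# Separate per-transform weights, as in the notebook: W_*_x is (H, D), W_*_h is (H, H)
Wx = {g: rng.normal(scale=0.1, size=(H, D)) for g in "fiCo"}
Wh = {g: rng.normal(scale=0.1, size=(H, H)) for g in "fiCo"}
b = {g: rng.normal(scale=0.1, size=H) for g in "fiCo"}

# Fused layout: [W_x | W_h] side by side per transform, transforms stacked -> (4H, D + H)
W = np.vstack([np.hstack([Wx[g], Wh[g]]) for g in "fiCo"])
b_all = np.concatenate([b[g] for g in "fiCo"])

x = rng.normal(size=D)
h = rng.normal(size=H)

# One matmul on [x; h], then split into 4 blocks of size H
z = W @ np.concatenate([x, h]) + b_all
z_f, z_i, z_C, z_o = np.split(z, 4)

# Per-transform computation for comparison
separate = {g: Wx[g] @ x + Wh[g] @ h + b[g] for g in "fiCo"}
print(W.shape)  # (16, 7)
print(all(np.allclose(zz, separate[g]) for zz, g in zip([z_f, z_i, z_C, z_o], "fiCo")))  # True
```

The fused form does the same arithmetic as the separate one. Its advantage is practical: one larger matrix multiply per step tends to run faster than four small ones.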

## Step 3: LSTM Helper Function

In [None]:
def lstm_step(x_t, h_prev, C_prev, W_f_x, W_f_h, b_f, W_i_x, W_i_h, b_i,
              W_C_x, W_C_h, b_C, W_o_x, W_o_h, b_o):
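    """One LSTM time step (used by both the forward and the backward pass).

    Shapes: x_t is (embedding_dim,); h_prev and C_prev are (hidden_units,);
    W_*_x is (hidden_units, embedding_dim); W_*_h is (hidden_units, hidden_units).
    """
    
    # Forget gate: how much of the previous cell state to keep
    f_t = sigmoid(W_f_x @ x_t + W_f_h @ h_prev + b_f)
    
    # Input gate: how much of the candidate to write into the cell
    i_t = sigmoid(W_i_x @ x_t + W_i_h @ h_prev + b_i)
    
    # Candidate cell state
    C_tilde = np.tanh(W_C_x @ x_t + W_C_h @ h_prev + b_C)
    
    # New cell state
    C_t = f_t * C_prev + i_t * C_tilde
    
    # Output gate: how much of the squashed cell state to expose
    o_t = sigmoid(W_o_x @ x_t + W_o_h @ h_prev + b_o)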
    
    # New hidden state
    h_t = o_t * np.tanh(C_t)
    
    return h_t, C_t

print("LSTM step function defined!")
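
For reference, `lstm_step` computes the standard LSTM update written out below. Here $\sigma$ is the logistic sigmoid and $\odot$ is element-wise multiplication. Weight names mirror the function's arguments; in this example $W_{\cdot,x} \in \mathbb{R}^{4\times 3}$ and $W_{\cdot,h} \in \mathbb{R}^{4\times 4}$:

$$
\begin{aligned}
f_t &= \sigma\left(W_{f,x}\,x_t + W_{f,h}\,h_{t-1} + b_f\right) \\
i_t &= \sigma\left(W_{i,x}\,x_t + W_{i,h}\,h_{t-1} + b_i\right) \\
\tilde{C}_t &= \tanh\left(W_{C,x}\,x_t + W_{C,h}\,h_{t-1} + b_C\right) \\
C_t &= f_t \odot C_{t-1} + i_t \odot \tilde{C}_t \\
o_t &= \sigma\left(W_{o,x}\,x_t + W_{o,h}\,h_{t-1} + b_o\right) \\
h_t &= o_t \odot \tanh(C_t)
\end{aligned}
$$

The backward LSTM uses the same equations with its own weights. In that pass $t$ runs from 3 down to 1, and $h_{t+1}$, $C_{t+1}$ take the place of $h_{t-1}$, $C_{t-1}$.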

## Step 4: FORWARD LSTM Pass (Left-to-Right)

Process: **I → love → NLP**

In [None]:
print("FORWARD LSTM PASS (I → love → NLP)")
print("="*70)

# Initialize forward states
h_fwd = [np.zeros(hidden_units)]  # h_0
C_fwd = [np.zeros(hidden_units)]  # C_0

print("Initial state:")
print(f"h_fwd_0: {h_fwd[0]}")
print(f"C_fwd_0: {C_fwd[0]}")
print()

# Process each word in FORWARD direction
for i, word in enumerate(sentence):
    x_t = embeddings[i]
    h_prev = h_fwd[-1]
    C_prev = C_fwd[-1]
    
    print(f"Time step {i+1}: Processing '{word}'")
    print("-"*70)
    print(f"Input x_{i+1}: {x_t}")
    
    h_t, C_t = lstm_step(x_t, h_prev, C_prev,
                        W_f_fwd_x, W_f_fwd_h, b_f_fwd,
                        W_i_fwd_x, W_i_fwd_h, b_i_fwd,
                        W_C_fwd_x, W_C_fwd_h, b_C_fwd,
                        W_o_fwd_x, W_o_fwd_h, b_o_fwd)
    
    h_fwd.append(h_t)
    C_fwd.append(C_t)
    
    print(f"h_fwd_{i+1}: {h_t}")
    print(f"C_fwd_{i+1}: {C_t}")
    print()

print("="*70)
print("Forward LSTM complete!")
print(f"Final forward hidden state h_fwd_3: {h_fwd[3]}")
print(f"  → Encodes information from 'I love NLP' (left-to-right)")
print("="*70)
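
The claim that the forward state carries only past context can be tested directly. If only the last word changes, the forward states at earlier positions should stay exactly the same. The sketch below is self-contained, so it re-declares a compact LSTM with random weights and zero biases. `make_lstm_params` and `run_lstm` are hypothetical helpers, not the notebook's variables. The sketch then replaces "NLP" with a random vector.

```python
import numpy as np

def logistic(z):
    return 1.0 / (1.0 + np.exp(-z))

def make_lstm_params(rng, H, D):
    """Hypothetical helper: a (W_x, W_h, b) triple per transform, keyed 'f', 'i', 'C', 'o'."""
    return {g: (rng.normal(scale=0.1, size=(H, D)),
                rng.normal(scale=0.1, size=(H, H)),
                np.zeros(H))
            for g in "fiCo"}

def run_lstm(params, xs, H):
    """Hypothetical helper: run one LSTM left-to-right and return [h_1, ..., h_T]."""
    h, C, hs = np.zeros(H), np.zeros(H), []
    for x in xs:
        z = {g: W_x @ x + W_h @ h + b for g, (W_x, W_h, b) in params.items()}
        f, i, o = logistic(z["f"]), logistic(z["i"]), logistic(z["o"])
        C = f * C + i * np.tanh(z["C"])
        h = o * np.tanh(C)
        hs.append(h)
    return hs

rng = np.random.default_rng(42)
H, D = 4, 3
params = make_lstm_params(rng, H, D)

xs = [np.array([0.5, 0.2, 0.1]),   # "I"
      np.array([0.8, 0.6, 0.3]),   # "love"
      np.array([0.1, 0.9, 0.7])]   # "NLP"
xs_changed = xs[:2] + [rng.normal(size=D)]  # same prefix, different last word

states = run_lstm(params, xs, H)
states_changed = run_lstm(params, xs_changed, H)

for t, word in enumerate(["I", "love", "NLP"]):
    # True means the forward state at this position was unaffected by the change
    print(word, np.allclose(states[t], states_changed[t]))
```

You should see `True` for "I" and "love" and `False` for "NLP". For the backward LSTM the pattern is mirrored: changing "I" would affect only the backward state at "I".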

## Step 5: BACKWARD LSTM Pass (Right-to-Left)

Process: **NLP → love → I**

In [None]:
print("\nBACKWARD LSTM PASS (NLP → love → I)")
print("="*70)

# Initialize backward states
h_bwd = [np.zeros(hidden_units)]  # h_0
C_bwd = [np.zeros(hidden_units)]  # C_0

print("Initial state:")
print(f"h_bwd_0: {h_bwd[0]}")
print(f"C_bwd_0: {C_bwd[0]}")
print()

# Process each word in BACKWARD direction (reverse order)
for i, word in enumerate(reversed(sentence)):
    idx = len(sentence) - 1 - i  # Actual index in original sequence
    x_t = embeddings[idx]
    h_prev = h_bwd[-1]
    C_prev = C_bwd[-1]
    
    print(f"Backward step {i+1}: Processing '{word}' (word position {idx+1})")
    print("-"*70)
    print(f"Input x_{idx+1}: {x_t}")
    
    h_t, C_t = lstm_step(x_t, h_prev, C_prev,
                        W_f_bwd_x, W_f_bwd_h, b_f_bwd,
                        W_i_bwd_x, W_i_bwd_h, b_i_bwd,
                        W_C_bwd_x, W_C_bwd_h, b_C_bwd,
                        W_o_bwd_x, W_o_bwd_h, b_o_bwd)
    
    h_bwd.append(h_t)
    C_bwd.append(C_t)
    
    print(f"h_bwd_{i+1}: {h_t}")
    print(f"C_bwd_{i+1}: {C_t}")
    print()

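# h_bwd[k] is the state after k backward steps (k=1 read 'NLP', k=3 read 'I').
# Reverse them so that h_bwd_aligned[t] belongs to word position t, matching h_fwd[t].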
h_bwd_aligned = [h_bwd[0]] + list(reversed(h_bwd[1:]))
C_bwd_aligned = [C_bwd[0]] + list(reversed(C_bwd[1:]))

print("="*70)
print("Backward LSTM complete!")
print(f"Final backward hidden state h_bwd_3 (after reading 'I'): {h_bwd[3]}")
print(f"  → Encodes information from 'I love NLP' (right-to-left)")
print("="*70)
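
The re-alignment line `[h_bwd[0]] + list(reversed(h_bwd[1:]))` is easy to misread. The sketch below applies the same list expression to string labels instead of arrays, so you can see which backward step ends up at each word position. The names `words`, `labels_bwd` and `labels_aligned` are illustrative and chosen so they do not overwrite `h_bwd` or `h_bwd_aligned`.

```python
words = ["I", "love", "NLP"]
T = len(words)

# Labels standing in for h_bwd, in the order the backward pass produces them
# (index k = number of backward steps taken; index 0 is the zero initial state)
labels_bwd = ["zeros"] + [f"after reading '{w}'" for w in reversed(words)]

# Same list expression as the notebook: keep index 0, reverse the rest
labels_aligned = [labels_bwd[0]] + list(reversed(labels_bwd[1:]))

for t in range(1, T + 1):
    k = T + 1 - t  # backward step that lands at word position t
    print(f"position {t} ({words[t - 1]}): backward step {k}, {labels_aligned[t]}")
```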

## Step 6: Concatenate Forward and Backward States

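At each word position, concatenate the forward state with the aligned backward state. The `t=0` row below only pairs the two zero initial states and is shown for reference; it is not a BiLSTM output.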

In [None]:
print("\nCONCATENATING FORWARD AND BACKWARD STATES")
print("="*70)

# Concatenate hidden states at each time step
h_bilstm = []
for i in range(len(sentence) + 1):
    h_combined = np.concatenate([h_fwd[i], h_bwd_aligned[i]])
    h_bilstm.append(h_combined)

print("BiLSTM outputs at each time step:")
print()
for i, word in enumerate(['(init)', 'I', 'love', 'NLP']):
    print(f"t={i} ({word}):")
    print(f"  Forward:  {h_fwd[i]}")
    print(f"  Backward: {h_bwd_aligned[i]}")
    print(f"  Combined: {h_bilstm[i]}")
    print(f"  Shape: {h_bilstm[i].shape} (4 forward + 4 backward = 8 total)")
    print()

print("="*70)
print("Key Insight:")
print("At each position, BiLSTM has context from BOTH directions:")
print("  - Forward: what came BEFORE")
print("  - Backward: what comes AFTER")
print()
print("Example at 'love' (t=2):")
print("  Forward state knows: 'I love'")
print("  Backward state knows: 'love NLP'")
print("  → BiLSTM knows FULL sentence context!")
print("="*70)
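
Two readouts are commonly taken from these states. Token-level tasks (tagging) use each word's 8-dimensional vector directly, so it is convenient to stack them into a `(3, 8)` matrix and drop the `t=0` zero row. When a single sentence-level vector is needed, a common choice is the forward state at the last word together with the backward state at the first word, because each of those has read the whole sentence. The sketch below uses random stand-in arrays with hypothetical names in place of the notebook's states.

```python
import numpy as np

T, H = 3, 4
rng = np.random.default_rng(0)

# Stand-ins for h_fwd and h_bwd_aligned (index 0 = zero initial state)
fwd_states = [np.zeros(H)] + [rng.normal(size=H) for _ in range(T)]
bwd_states_aligned = [np.zeros(H)] + [rng.normal(size=H) for _ in range(T)]

# Token-level output: one 2H-dim row per word, skipping the t=0 initial state
per_token = np.hstack([np.stack(fwd_states[1:]), np.stack(bwd_states_aligned[1:])])
print(per_token.shape)  # (3, 8)

# Sentence-level vector: forward state at the last word + backward state at the first word
sentence_vec = np.concatenate([fwd_states[T], bwd_states_aligned[1]])
print(sentence_vec.shape)  # (8,)
```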

## Visualization 1: Directional Flow

In [None]:
fig, ax = plt.subplots(figsize=(16, 8))
ax.set_xlim(0, 16)
ax.set_ylim(0, 8)
ax.axis('off')

# Title
ax.text(8, 7.5, 'BiLSTM: Bidirectional Processing of "I love NLP"', 
        ha='center', fontsize=18, fontweight='bold')

# Time step positions
x_positions = [3, 6, 9, 12]
words = ['(init)', 'I', 'love', 'NLP']

# Draw forward path (top)
y_fwd = 5.5
for i, (x, word) in enumerate(zip(x_positions, words)):
    # Forward neuron cluster
    if i > 0:
        for j in range(4):
            y_offset = y_fwd + (j - 1.5) * 0.2
            circle = plt.Circle((x, y_offset), 0.12, color='#5B9BD5', 
                              ec='black', linewidth=1.5, zorder=3)
            ax.add_patch(circle)
        
        ax.text(x, y_fwd - 0.6, word, ha='center', fontsize=11, fontweight='bold')
    
    # Forward arrows
    if i < len(x_positions) - 1:
        ax.arrow(x + 0.4, y_fwd, x_positions[i+1] - x - 0.9, 0,
                head_width=0.2, head_length=0.2, fc='#5B9BD5', ec='#5B9BD5', 
                linewidth=3, zorder=2)

ax.text(1.5, y_fwd, 'FORWARD →', ha='center', fontsize=12, 
        fontweight='bold', color='#5B9BD5')

# Draw backward path (bottom)
y_bwd = 2.5
for i, (x, word) in enumerate(zip(x_positions, words)):
    # Backward neuron cluster
    if i > 0:
        for j in range(4):
            y_offset = y_bwd + (j - 1.5) * 0.2
            circle = plt.Circle((x, y_offset), 0.12, color='#E67E22',
                              ec='black', linewidth=1.5, zorder=3)
            ax.add_patch(circle)
        
        ax.text(x, y_bwd + 0.6, word, ha='center', fontsize=11, fontweight='bold')
    
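    # Backward arrows (right to left), drawn only between words: the backward
    # pass starts at 'NLP', so nothing flows from 'I' into the (init) column
    if i > 1: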
        ax.arrow(x - 0.4, y_bwd, -(x - x_positions[i-1] - 0.9), 0,
                head_width=0.2, head_length=0.2, fc='#E67E22', ec='#E67E22',
                linewidth=3, zorder=2)

ax.text(14.5, y_bwd, '← BACKWARD', ha='center', fontsize=12,
        fontweight='bold', color='#E67E22')

# Draw concatenation
for i, x in enumerate(x_positions[1:], 1):
    # Vertical connector
    ax.plot([x, x], [y_fwd - 0.5, y_bwd + 0.5], 'k--', linewidth=2, alpha=0.5)
    
    # Combined output
    rect = Rectangle((x - 0.4, 4 - 0.3), 0.8, 0.6, 
                     facecolor='#2ECC71', edgecolor='black', linewidth=2)
    ax.add_patch(rect)
    ax.text(x, 4, f'BiLSTM\n({hidden_units*2})', ha='center', va='center',
           fontsize=9, fontweight='bold', color='white')

# Legend
ax.text(1, 1, 'Forward (4 units): Past context', fontsize=10, color='#5B9BD5')
ax.text(1, 0.6, 'Backward (4 units): Future context', fontsize=10, color='#E67E22')
ax.text(1, 0.2, 'Combined (8 units): Full context', fontsize=10, color='#2ECC71')

plt.tight_layout()
plt.savefig('bilstm_directional_flow.png', dpi=300, bbox_inches='tight')
plt.show()

print("Visualization saved as 'bilstm_directional_flow.png'")

## Visualization 2: State Heatmaps

In [None]:
# Prepare data for heatmaps
h_fwd_array = np.array(h_fwd[1:])  # Skip initial zero state
h_bwd_array = np.array(h_bwd_aligned[1:])  # Skip initial zero state; rows follow word order
h_bilstm_array = np.array(h_bilstm[1:])

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Forward hidden states
sns.heatmap(h_fwd_array, annot=True, fmt='.2f', cmap='Blues',
            xticklabels=[f'Unit {i+1}' for i in range(4)],
            yticklabels=['I', 'love', 'NLP'],
            ax=axes[0], cbar_kws={'label': 'Activation'})
axes[0].set_title('Forward LSTM\n(I → love → NLP)', fontsize=13, fontweight='bold')
axes[0].set_ylabel('Time Step', fontsize=11)

# Backward hidden states
sns.heatmap(h_bwd_array, annot=True, fmt='.2f', cmap='Oranges',
            xticklabels=[f'Unit {i+1}' for i in range(4)],
            yticklabels=['I', 'love', 'NLP'],
            ax=axes[1], cbar_kws={'label': 'Activation'})
axes[1].set_title('Backward LSTM\n(NLP → love → I)', fontsize=13, fontweight='bold')
axes[1].set_ylabel('Word Position (aligned)', fontsize=11)

# BiLSTM combined
sns.heatmap(h_bilstm_array, annot=True, fmt='.2f', cmap='Greens',
            xticklabels=[f'Fwd{i+1}' for i in range(4)] + [f'Bwd{i+1}' for i in range(4)],
            yticklabels=['I', 'love', 'NLP'],
            ax=axes[2], cbar_kws={'label': 'Activation'})
axes[2].set_title('BiLSTM Combined\n(Forward + Backward)', fontsize=13, fontweight='bold')
axes[2].set_ylabel('Time Step', fontsize=11)

plt.tight_layout()
plt.savefig('bilstm_state_heatmaps.png', dpi=300, bbox_inches='tight')
plt.show()

print("Visualization saved as 'bilstm_state_heatmaps.png'")

## Visualization 3: Context at Each Position

In [None]:
fig, ax = plt.subplots(figsize=(14, 8))
ax.set_xlim(0, 14)
ax.set_ylim(0, 8)
ax.axis('off')

# Title
ax.text(7, 7.5, 'BiLSTM: Context Available at Each Position', 
        ha='center', fontsize=16, fontweight='bold')

# Words and their contexts
contexts = [
    {
        'word': 'I',
        'x': 3,
        'fwd': 'START + I',
        'bwd': 'I + love + NLP',
        'combined': 'Full sentence'
    },
    {
        'word': 'love',
        'x': 7,
        'fwd': 'I + love',
        'bwd': 'love + NLP',
        'combined': 'Full sentence'
    },
    {
        'word': 'NLP',
        'x': 11,
        'fwd': 'I + love + NLP',
        'bwd': 'NLP + END',
        'combined': 'Full sentence'
    }
]

for ctx in contexts:
    x = ctx['x']
    
    # Word box
    word_box = FancyBboxPatch((x - 0.6, 5), 1.2, 0.8,
                              boxstyle="round,pad=0.1",
                              edgecolor='black', facecolor='#3498DB',
                              linewidth=2.5)
    ax.add_patch(word_box)
    ax.text(x, 5.4, ctx['word'], ha='center', va='center',
           fontsize=14, fontweight='bold', color='white')
    
    # Forward context
    fwd_box = FancyBboxPatch((x - 1, 3.5), 2, 0.6,
                             boxstyle="round,pad=0.05",
                             edgecolor='#5B9BD5', facecolor='#E8F4F9',
                             linewidth=2)
    ax.add_patch(fwd_box)
    ax.text(x - 1, 3.9, 'Forward:', ha='left', va='center',
           fontsize=9, fontweight='bold', color='#5B9BD5')
    ax.text(x, 3.7, ctx['fwd'], ha='center', va='center',
           fontsize=8, color='#2C3E50')
    
    # Backward context
    bwd_box = FancyBboxPatch((x - 1, 2.5), 2, 0.6,
                            boxstyle="round,pad=0.05",
                            edgecolor='#E67E22', facecolor='#FDF2E9',
                            linewidth=2)
    ax.add_patch(bwd_box)
    ax.text(x - 1, 2.9, 'Backward:', ha='left', va='center',
           fontsize=9, fontweight='bold', color='#E67E22')
    ax.text(x, 2.7, ctx['bwd'], ha='center', va='center',
           fontsize=8, color='#2C3E50')
    
    # Combined context
    combined_box = FancyBboxPatch((x - 1, 1.3), 2, 0.6,
                                 boxstyle="round,pad=0.05",
                                 edgecolor='#2ECC71', facecolor='#E8F8F5',
                                 linewidth=2.5)
    ax.add_patch(combined_box)
    ax.text(x - 1, 1.7, 'BiLSTM:', ha='left', va='center',
           fontsize=9, fontweight='bold', color='#2ECC71')
    ax.text(x, 1.5, ctx['combined'], ha='center', va='center',
           fontsize=8, fontweight='bold', color='#27AE60')

# Legend
ax.text(7, 0.5, '✓ BiLSTM sees ENTIRE sentence at each position (past + future)',
        ha='center', fontsize=11, style='italic', color='#27AE60', fontweight='bold')

plt.tight_layout()
plt.savefig('bilstm_context_diagram.png', dpi=300, bbox_inches='tight')
plt.show()

print("Visualization saved as 'bilstm_context_diagram.png'")

## Comparison: RNN vs LSTM vs BiLSTM

In [None]:
print("\n" + "="*80)
print("COMPARISON: RNN vs LSTM vs BiLSTM")
print("="*80)

comparison = [
    ("States", "1 (h_t)", "2 (h_t, C_t)", "2 per direction (h_fwd, C_fwd, h_bwd, C_bwd)"),
    ("Directions", "Forward only", "Forward only", "Forward + Backward"),
    ("Gates", "None", "3 gates", "3 gates per direction"),
    ("Output dim", "4", "4", "8 (4 fwd + 4 bwd)"),
    ("Parameters", "32", "128", "256 (2× LSTM)"),
    ("Context", "Past only", "Past only", "Past + Future"),
    ("For 'love'", "Knows 'I'", "Knows 'I'", "Knows 'I' AND 'NLP'"),
    ("Speed", "Fast", "Medium", "Slower (2 passes)"),
    ("Best for", "Simple tasks", "Long sequences", "Tasks needing full context"),
]

print(f"{'Aspect':<15} {'RNN':<20} {'LSTM':<20} {'BiLSTM':<30}")
print("-"*80)
for row in comparison:
    print(f"{row[0]:<15} {row[1]:<20} {row[2]:<20} {row[3]:<30}")
print("="*80)

print("\nKey Advantages of BiLSTM:")
print("  1. Sees FUTURE context (backward pass)")
print("  2. Better for tasks where future info helps (NER, POS tagging, sentiment)")
print("  3. Each position has complete sentence information")
print()
print("Trade-offs:")
print("  - 2× parameters compared to LSTM")
print("  - ~2× the compute (two passes); the passes are independent, so they can run in parallel")
print("  - Cannot be used for real-time/streaming (needs full sequence)")
print("="*80)
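
The Parameters row follows directly from the layer shapes. The sketch below counts only the recurrent layer: input weights, recurrent weights and one bias vector per transform, with no embedding or output layer. `recurrent_params` is a hypothetical helper. Some libraries keep two bias vectors per transform (PyTorch's `b_ih` and `b_hh`), which makes their counts slightly higher.

```python
def recurrent_params(input_dim, hidden, n_transforms):
    """Hypothetical helper: (input weights + recurrent weights + bias) for each transform."""
    return n_transforms * (hidden * input_dim + hidden * hidden + hidden)

D, H = 3, 4  # embedding_dim, hidden_units

rnn = recurrent_params(D, H, 1)         # a single tanh transform
lstm = recurrent_params(D, H, 4)        # forget, input, output gates + candidate
bilstm = 2 * recurrent_params(D, H, 4)  # two independent LSTMs

print(rnn, lstm, bilstm)  # 32 128 256
```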

## When to Use BiLSTM vs LSTM

### Use BiLSTM when:
- ✅ You have the **complete sequence** available
- ✅ Future context helps (Named Entity Recognition, POS tagging)
- ✅ Accuracy is more important than speed
- ✅ You're doing sequence labeling (classify each token)

**Examples:**
- "**John** works at Microsoft" ← the words after "John" ("works at Microsoft") signal that it is a PERSON; a forward-only LSTM at "John" has not seen them yet
- "The **bank** was steep" vs "The **bank** was closed" ← future words disambiguate meaning
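
For sequence labeling, the usual pattern is a small classifier applied separately to each position's concatenated BiLSTM state. The sketch below uses a hypothetical tag set, random stand-in BiLSTM outputs and an untrained linear layer. The predicted tags are therefore meaningless; only the shapes and the per-token softmax matter.

```python
import numpy as np

rng = np.random.default_rng(0)
tags = ["O", "B-PER", "B-ORG"]              # hypothetical NER tag set
words = ["John", "works", "at", "Microsoft"]

bilstm_out = rng.normal(size=(len(words), 8))        # stand-in for per-token outputs (T, 2H)
W_tag = rng.normal(scale=0.1, size=(8, len(tags)))   # untrained linear layer
b_tag = np.zeros(len(tags))

logits = bilstm_out @ W_tag + b_tag                  # (T, num_tags): one score vector per token
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)            # softmax over tags, per token
pred = [tags[k] for k in probs.argmax(axis=1)]

for w, p in zip(words, pred):
    print(f"{w:10s} -> {p}")
```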

### Use LSTM when:
- ✅ Real-time/streaming predictions (don't have future tokens)
- ✅ Speed is important
- ✅ Language modeling / next word prediction
- ✅ Sequence generation

**Examples:**
- Autocomplete (predicting next word)
- Real-time speech recognition
- Stock price prediction (can't see future)

## Summary

### What is BiLSTM?

**Two separate LSTMs:**
```
Forward LSTM:  processes I → love → NLP
Backward LSTM: processes NLP → love → I
```

**Output at each position:**
```
BiLSTM_output[t] = [h_forward[t], h_backward[t]]
                    ↑               ↑
                  4 dims          4 dims
                    └─────────┬─────────┘
                           8 dims total
```

### Example at position "love":
```
Forward state:  encodes "I love" (past)
Backward state: encodes "love NLP" (future)
Combined:       encodes full sentence "I love NLP"
```

### Dimensions:
- Forward hidden state: **4 dimensions**
- Backward hidden state: **4 dimensions**
- BiLSTM output: **8 dimensions** (concatenated)

### Key Insight:
**BiLSTM processes the sequence TWICE** (forward and backward), giving each position access to **complete context** from both past and future!