# Part 3: Beam Search and Modern Translation

## Welcome to Part 3!

You've learned seq2seq (Part 1) and attention (Part 2). Now let's explore how to find BETTER translations!

**What you'll learn:**
- Why greedy decoding fails
- How beam search works
- Quality vs speed tradeoffs
- Modern translation systems

**Time: 15-20 minutes**

---

In [None]:
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)
print("Ready to explore search strategies!")

## 1. The Problem with Greedy Decoding

### What is Greedy?

In Parts 1 and 2, we used **greedy decoding**:
```python
next_index = np.argmax(probabilities)  # Index of the highest-probability word
word = vocab[next_index]               # Look the word up in the vocabulary
```

Pick the best word NOW, never look back.

### Why Greedy Fails

Sometimes the best word NOW leads to a BAD translation LATER!

**Example:**
```
Greedy Approach:
Step 1: "Le" (90%) vs "La" (10%) → Pick "Le"
Step 2: "chat" (30%)  [total: 90% × 30% = 27%]
Step 3: "assis" (20%) [total: 27% × 20% = 5.4%]
Result: "Le chat assis" (5.4% overall)

Better Path:
Step 1: "La" (10%)    → Pick "La"
Step 2: "petite" (80%) [total: 10% × 80% = 8%]
Step 3: "chatte" (90%) [total: 8% × 90% = 7.2%]
Result: "La petite chatte" (7.2% overall, beating greedy's 5.4%)
```

**Key Insight:** A lower-probability first step can lead to a better overall sentence!
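
To make this concrete, here is a minimal sketch that compares greedy decoding with an exhaustive search over a tiny hand-written probability table. The table `MODEL` and the helper names `greedy_decode` and `exhaustive_decode` are hypothetical, and the probabilities are made up so that the greedy path loses.

```python
# Hypothetical next-word probabilities, keyed by the words generated so far.
MODEL = {
    (): {"Le": 0.55, "La": 0.45},
    ("Le",): {"chat": 0.6, "chien": 0.4},
    ("La",): {"petite": 0.9, "grande": 0.1},
    ("Le", "chat"): {"assis": 0.6, "court": 0.4},
    ("Le", "chien"): {"assis": 0.5, "court": 0.5},
    ("La", "petite"): {"chatte": 0.9, "fille": 0.1},
    ("La", "grande"): {"chatte": 0.5, "fille": 0.5},
}

def greedy_decode(model, length=3):
    """Pick the most probable next word at every step."""
    words, score = (), 1.0
    for _ in range(length):
        probs = model[words]
        word = max(probs, key=probs.get)
        score *= probs[word]
        words += (word,)
    return words, score

def exhaustive_decode(model, length=3):
    """Score every complete sequence and return the most probable one."""
    best_words, best_score = (), 0.0
    stack = [((), 1.0)]
    while stack:
        words, score = stack.pop()
        if len(words) == length:
            if score > best_score:
                best_words, best_score = words, score
            continue
        for word, prob in model[words].items():
            stack.append((words + (word,), score * prob))
    return best_words, best_score

greedy_words, greedy_score = greedy_decode(MODEL)
best_words, best_score = exhaustive_decode(MODEL)
print(f"Greedy: {' '.join(greedy_words)} ({greedy_score:.4f})")  # Le chat assis (0.1980)
print(f"Best:   {' '.join(best_words)} ({best_score:.4f})")      # La petite chatte (0.3645)
```

Exhaustive search is only feasible here because the table is tiny. With a real vocabulary the number of sequences grows exponentially with length, which is why a cheaper compromise is needed.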

## 2. Beam Search: Keeping Multiple Options

### The Idea

Instead of keeping ONLY the best path (greedy), keep the **top k best paths**!

**Beam width k = 2** (percentages are illustrative and not normalized):
```
Step 1: Generate all first words
  "Le" (90%), "La" (10%), "Un" (5%)
  Keep top 2: "Le" (90%), "La" (10%)

Step 2: For each, generate all second words
  "Le" + "chat" (63%), "Le" + "chien" (18%)
  "La" + "petite" (8%), "La" + "grande" (6%)
  Keep top 2: "Le chat" (63%), "Le chien" (18%)

Step 3: Continue...
```
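
The "keep top k" step at the end of Step 2 can be sketched in a few lines. The scores are copied from the walkthrough above; the helper name `prune` is hypothetical.

```python
import heapq

# Joint scores after step 2, copied from the walkthrough above.
candidates = {
    ("Le", "chat"): 0.63,
    ("Le", "chien"): 0.18,
    ("La", "petite"): 0.08,
    ("La", "grande"): 0.06,
}

def prune(candidates, beam_width):
    """Keep the beam_width highest-scoring partial sequences."""
    return heapq.nlargest(beam_width, candidates.items(), key=lambda item: item[1])

for k in (1, 2, 3):
    kept = [" ".join(words) for words, _ in prune(candidates, k)]
    print(f"k={k}: {kept}")
```

With k=2 both surviving paths start with "Le". Only at k=3 does "La petite" stay in the beam.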

### Let's Implement It

In [None]:
def beam_search(beam_width=3, max_length=4):
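    """
    Simple beam search over a toy vocabulary.

    Next-word probabilities are random placeholders (there is no trained
    model), so the output illustrates the mechanics, not real translations.
    """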
    vocab = ['le', 'la', 'chat', 'chien', 'assis', 'court', '<END>']
    
    # Start with empty sequence
    beams = [{'words': [], 'score': 1.0}]
    
    print(f"Beam Search (width={beam_width})")
    print("="*60)
    
    for step in range(max_length):
        print(f"\nStep {step+1}:")
        print(f"Current beams: {len(beams)}")
        
        # Show current beams
        for i, beam in enumerate(beams):
            words_str = ' '.join(beam['words']) if beam['words'] else '[empty]'
            print(f"  Beam {i+1}: '{words_str}' (score: {beam['score']:.3f})")
        
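        # Expand all beams
        candidates = []
        for beam in beams:
            # A finished beam is carried forward unchanged, never extended past <END>
            if beam['words'] and beam['words'][-1] == '<END>':
                candidates.append(beam)
                continue
            
            # Generate next word probabilities (random stand-in for a model)
            probs = np.random.dirichlet(np.ones(len(vocab))) * 0.5 + 0.1
            probs = probs / probs.sum()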
            
            # Create candidates
            for word, prob in zip(vocab, probs):
                new_beam = {
                    'words': beam['words'] + [word],
                    'score': beam['score'] * prob
                }
                candidates.append(new_beam)
        
        print(f"  Generated {len(candidates)} candidates")
        
        # Keep top k
        candidates.sort(key=lambda x: x['score'], reverse=True)
        beams = candidates[:beam_width]
        
        # Check if done
        if all(b['words'] and b['words'][-1] == '<END>' for b in beams):
            print("  All beams reached <END>")
            break
    
    print("\n" + "="*60)
    print("FINAL RESULTS:")
    for i, beam in enumerate(beams):
        words = [w for w in beam['words'] if w != '<END>']
        print(f"  {i+1}. '{' '.join(words)}' (score: {beam['score']:.4f})")
    
    return beams[0]['words']

# Run with beam width 3 (change beam_width to try other values)
result = beam_search(beam_width=3, max_length=4)
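
Because `beam_search` above draws random probabilities, its output changes with every expansion. As a complement, here is a minimal sketch that runs beam search over a fixed table, so the effect of the beam width is reproducible. It also sums log-probabilities instead of multiplying probabilities, a common way to avoid underflow on long sequences. The table `MODEL` and the function `beam_search_logprob` are hypothetical, with made-up numbers.

```python
import math

# Hypothetical next-word probabilities, keyed by the words generated so far.
MODEL = {
    (): {"Le": 0.55, "La": 0.45},
    ("Le",): {"chat": 0.6, "chien": 0.4},
    ("La",): {"petite": 0.9, "grande": 0.1},
    ("Le", "chat"): {"assis": 0.6, "court": 0.4},
    ("Le", "chien"): {"assis": 0.5, "court": 0.5},
    ("La", "petite"): {"chatte": 0.9, "fille": 0.1},
    ("La", "grande"): {"chatte": 0.5, "fille": 0.5},
}

def beam_search_logprob(model, beam_width, length=3):
    """Return the final beams as (words, summed log-probability), best first."""
    beams = [((), 0.0)]
    for _ in range(length):
        # Extend every beam with every possible next word
        candidates = [
            (words + (word,), logp + math.log(prob))
            for words, logp in beams
            for word, prob in model[words].items()
        ]
        # Keep only the beam_width best partial sequences
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:beam_width]
    return beams

for k in (1, 2):
    words, logp = beam_search_logprob(MODEL, k)[0]
    print(f"k={k}: {' '.join(words)} (p = {math.exp(logp):.4f})")
# k=1: Le chat assis (p = 0.1980)
# k=2: La petite chatte (p = 0.3645)
```

With k=1 the search is greedy and commits to "Le". With k=2 the "La" path survives the first step and ends up with the higher total probability.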

## 3. Beam Search Visualization

Let's visualize how beam search explores the search space:

```
              START
             /  |  \
           /    |    \
         Le    La    Un
        (90%)  (10%) (5%) ← PRUNE Un!
         |      |
      -------  ------
     /       \/      \
  chat  chien petite grande
  (63%) (18%)  (8%)  (6%)  ← Keep top 2
    |      |
   BEST  runner-up
```

**Text representation:**

In [None]:
def visualize_beam_text():
    """
    Show beam search as a text tree
    """
    print("\nBeam Search Tree (k=2):")
    print("\n" + "  "*0 + "START")
    print("  "*0 + "│")
    print("  "*0 + "├─ Le (0.90) ✓ KEEP")
    print("  "*0 + "├─ La (0.10) ✓ KEEP")
    print("  "*0 + "└─ Un (0.05) ✗ PRUNE")
    print()
    print("  "*1 + "Le:")
    print("  "*1 + "├─ chat (0.63) ✓ KEEP")
    print("  "*1 + "└─ chien (0.18) ✓ KEEP")
    print()
    print("  "*1 + "La:")
    print("  "*1 + "├─ petite (0.08) ✗ PRUNE")
    print("  "*1 + "└─ grande (0.06) ✗ PRUNE")
    print()
    print("  "*2 + "Best: Le chat (0.63)")
    print()
    print("Key:")
    print("  ✓ = Kept in beam")
    print("  ✗ = Pruned (score too low)")

visualize_beam_text()

## 4. Comparing Search Strategies

Let's compare different approaches:

In [None]:
# Compare quality and speed (illustrative numbers, not measurements)
strategies = ['Greedy\n(k=1)', 'Beam\n(k=3)', 'Beam\n(k=5)', 'Beam\n(k=10)']
quality = [65, 82, 88, 91]
speed = [100, 75, 55, 30]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Quality
colors = ['#FF6B6B', '#FFD93D', '#6BCB77', '#4D96FF']
ax1.bar(range(len(strategies)), quality, color=colors, edgecolor='black', linewidth=2)
ax1.set_ylabel('Translation Quality (%)', fontweight='bold')
ax1.set_title('Quality Comparison', fontweight='bold')
ax1.set_xticks(range(len(strategies)))
ax1.set_xticklabels(strategies)
ax1.set_ylim(0, 100)
ax1.grid(True, alpha=0.3, axis='y')
for i, (bar, val) in enumerate(zip(ax1.patches, quality)):
    ax1.text(bar.get_x() + bar.get_width()/2, val + 2, f'{val}%',
            ha='center', va='bottom', fontweight='bold')

# Speed
ax2.bar(range(len(strategies)), speed, color=colors, edgecolor='black', linewidth=2)
ax2.set_ylabel('Relative Speed', fontweight='bold')
ax2.set_title('Speed Comparison', fontweight='bold')
ax2.set_xticks(range(len(strategies)))
ax2.set_xticklabels(strategies)
ax2.set_ylim(0, 110)
ax2.grid(True, alpha=0.3, axis='y')
for i, (bar, val) in enumerate(zip(ax2.patches, speed)):
    ax2.text(bar.get_x() + bar.get_width()/2, val + 2, f'{val}',
            ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

print("\nTradeoffs:")
print("- Greedy (k=1): Fastest but lowest quality")
print("- Beam (k=3-5): Best balance (MOST COMMON)")
print("- Beam (k=10): Higher quality but much slower")
print("\nMany NMT toolkits default to a beam width of about 4-5")
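
One way to see where the slowdown comes from is to count how many (beam, word) candidates get scored. The sketch below uses the toy setup from `beam_search` (7 vocabulary words, 4 steps) and assumes no beam finishes early, so the totals are upper bounds. The function name `candidates_scored` is hypothetical.

```python
def candidates_scored(beam_width, vocab_size, max_length):
    """Count candidates scored when every live beam is expanded at each step."""
    beams, total = 1, 0  # the search starts from a single empty beam
    for _ in range(max_length):
        generated = beams * vocab_size
        total += generated
        beams = min(beam_width, generated)  # cannot keep more than were generated
    return total

for k in (1, 3, 5, 10):
    print(f"k={k:>2}: {candidates_scored(k, vocab_size=7, max_length=4)} candidates")
# k= 1: 28 candidates
# k= 3: 70 candidates
# k= 5: 112 candidates
# k=10: 196 candidates
```

The work grows roughly linearly with k, while the quality gains in the chart flatten out, which is why small beams are a popular compromise.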

## 5. Modern Translation Systems

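### Evolution Timeline

The eras below are approximate periods of mainstream deployment, and the quality percentages are illustrative rather than benchmark scores.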

```
2000s: Rule-Based
  - Hand-written grammar rules
  - Dictionary lookups
  - Quality: 30%

2010-2016: Statistical (SMT)
  - Learn patterns from data
  - Phrase-based
  - Quality: 45-55%

2016-2017: Neural + Attention (NMT)
  - What you learned in this course!
  - Quality jump: 75%
  - Google Translate switched in 2016

2017-now: Transformers
  - Self-attention everywhere
  - Parallel processing
  - Quality: 85-95%
```

In [None]:
# Visualize evolution (quality values are illustrative, not benchmark scores)
years = [2005, 2010, 2015, 2016, 2018, 2020, 2024]
quality = [30, 45, 55, 75, 85, 92, 95]
labels = ['Rule\nBased', 'SMT', 'SMT', 'NMT\n+Attn', 'Trans\nformer', 'GPT-3', 'GPT-4']

plt.figure(figsize=(11, 5))
plt.plot(years, quality, 'o-', linewidth=3, markersize=12, color='#4ECDC4')

# Annotations
for i, (year, qual, label) in enumerate(zip(years, quality, labels)):
    if i in [0, 3, 6]:  # Key milestones
        plt.annotate(label, xy=(year, qual), 
                   xytext=(year, qual + 8 if i % 2 == 0 else qual - 12),
                   fontsize=10, fontweight='bold', ha='center',
                   arrowprops=dict(arrowstyle='->', lw=2))

# Highlight breakthrough
plt.axvspan(2015.5, 2017, alpha=0.2, color='green')
plt.text(2016.2, 65, 'Attention\nRevolution', ha='center', 
        fontsize=11, fontweight='bold', color='darkgreen')

plt.xlabel('Year', fontweight='bold', fontsize=12)
plt.ylabel('Translation Quality (%)', fontweight='bold', fontsize=12)
plt.title('Evolution of Machine Translation', fontweight='bold', fontsize=14)
plt.ylim(0, 105)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("\nKey Milestones:")
print("- 2016: Google switches to NMT (a large quality jump, drawn here as +20 points)")
print("- 2017: Transformer paper ('Attention Is All You Need')")
print("- 2020+: GPT-style models approach human quality on some language pairs")

## 6. Connection to ChatGPT

### The Surprising Truth

ChatGPT uses the SAME core ideas you learned:

```
Translation (2016):          ChatGPT (2024):
- Encoder-Decoder            - Decoder-only
- Attention in decoder       - Self-attention everywhere
- Beam search (k=5)          - Sampling + temperature
- ~200M parameters           - 175B parameters (GPT-3)
- Trained on sentence pairs  - Trained on web-scale text
```
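
The "Sampling + temperature" entry can be sketched as follows: instead of always taking the top-scoring word, the decoder draws a word at random from a softmax distribution whose sharpness is controlled by a temperature. The logits below are made up, and `sample_with_temperature` is a hypothetical helper, not the API of any particular system.

```python
import numpy as np

def sample_with_temperature(logits, temperature, rng):
    """Draw one index from softmax(logits / temperature); also return the probabilities."""
    scaled = np.asarray(logits, dtype=float) / temperature
    scaled -= scaled.max()  # subtract the max for numerical stability
    probs = np.exp(scaled)
    probs /= probs.sum()
    return rng.choice(len(probs), p=probs), probs

rng = np.random.default_rng(0)
logits = [2.0, 1.0, 0.1]  # hypothetical scores for three candidate words

for t in (0.5, 1.0, 2.0):
    index, probs = sample_with_temperature(logits, t, rng)
    print(f"T={t}: probs={np.round(probs, 3)}, sampled index={index}")
```

A low temperature concentrates probability on the top word, so sampling behaves almost like greedy decoding. A high temperature flattens the distribution and produces more varied output.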

**Same principles, bigger scale!**

The ideas you've learned power:
- Google Translate
- ChatGPT / GPT-4
- BERT, T5, Claude
- Image captioning
- Speech recognition

## 7. Summary: Your Learning Journey

### Part 1: Basic Seq2Seq
- Encoder compresses input → context vector
- Decoder expands context → output
- Problem: **Bottleneck**

### Part 2: Attention
- Decoder looks at ALL encoder states
- Custom context for each output word
- Solution to bottleneck!

### Part 3: Beam Search
- Keep multiple paths (not just best)
- Quality vs speed tradeoff
- Used in real systems

### The Big Picture

You now understand **the foundation of modern NLP**:

```
Week 1-3: Foundations (N-grams, RNNs)
    ↓
Week 4: Seq2Seq + Attention ← YOU ARE HERE!
    ↓
Week 5: Transformers (next week!)
    ↓
Week 6+: Pre-trained models (BERT, GPT)
```

### What's Next?

**Week 5: Transformers**
- Self-attention (attention on steroids!)
- Parallel processing
- The architecture behind GPT and BERT

---

## Try It Yourself

In [None]:
# Experiment: Try different beam widths
print("Comparing different beam widths:\n")

for k in [1, 2, 5]:
    print(f"\n{'='*60}")
    print(f"Beam width k={k}:")
    result = beam_search(beam_width=k, max_length=3)
    print(f"Best: {' '.join([w for w in result if w != '<END>'])}")

print("\n" + "="*60)
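print("\nObservations:")
print("- k=1 (greedy): Fastest, might miss better paths")
print("- k=2: Some exploration")
print("- k=5: More exploration, usually better quality")
print("- Note: probabilities here are random, so scores are not comparable across runs")
print("\nQuestion: What happens with k=10? Add 10 to the list above and rerun!")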

## Congratulations!

You've completed the Week 4 lab series!

### You Now Understand:
- How neural translation works
- Why attention is crucial
- How to find better translations
- The foundation of modern AI

### Keep Learning!

Next week: **Transformers** - the architecture that changed everything.