# DistilBERT Step-by-Step: Processing "I love NLP"

## Setup
- Sentence: **"I love NLP"**
- Hidden dimension: **4** (simplified from 768 for education)
- **Same sentence as RNN/LSTM/BiLSTM for comparison!**

## CRITICAL DIFFERENCE: Transformer vs RNN

**DistilBERT is a TRANSFORMER, not an RNN!**

```
RNN/LSTM/BiLSTM:          DistilBERT (Transformer):
Sequential processing     Parallel processing
Hidden states             Attention mechanism
Process one at a time     Process all tokens at once
```

**Note:** We use 4 dimensions (instead of the real 768) so the numbers are easy to compare with the previous examples.
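To make the contrast concrete, here is a minimal NumPy sketch. The weights `W_x` and `W_h`, the input `X`, and the tanh recurrence are illustrative assumptions, not DistilBERT parameters. The recurrent update needs a loop because each step depends on the previous hidden state, while a per-token projection is a single matrix product.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4))      # 3 tokens ("I", "love", "NLP"), 4 dims each
W_x = rng.normal(size=(4, 4))    # hypothetical input weights
W_h = rng.normal(size=(4, 4))    # hypothetical recurrent weights

# RNN-style: each step needs the previous hidden state, so the loop is unavoidable.
h = np.zeros(4)
rnn_states = []
for x_t in X:
    h = np.tanh(x_t @ W_x + h @ W_h)
    rnn_states.append(h)
rnn_states = np.stack(rnn_states)

# Transformer-style projection: every token is transformed in one matrix product.
parallel_out = X @ W_x

print(rnn_states.shape, parallel_out.shape)
```

Both results have one row per token, but only the second can be computed for all rows at the same time.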

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Circle, FancyArrowPatch, FancyBboxPatch

np.random.seed(42)

# Configuration
sentence = "I love NLP"
hidden_dim = 4  # Simplified (real DistilBERT uses 768)
n_heads = 2     # Number of attention heads (real: 12); the walkthrough below computes a single head
n_layers = 2    # Number of transformer layers (real: 6); the walkthrough below covers one attention pass

print("Sentence:", sentence)
print(f"Hidden dimension: {hidden_dim} (simplified from 768)")
print(f"Attention heads: {n_heads} (simplified from 12)")
print(f"Transformer layers: {n_layers} (simplified from 6)")
print()
print("Key Difference from RNN/LSTM:")
print("  ✗ No sequential processing")
print("  ✗ No hidden states passed through time")
print("  ✓ All tokens processed in PARALLEL")
print("  ✓ Uses SELF-ATTENTION to look at all words at once")

## Step 1: Tokenization

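Unlike simple word-level tokenization, BERT and DistilBERT use **WordPiece tokenization**, which splits words that are missing from the vocabulary into smaller subword pieces.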
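As a rough illustration of the idea, the sketch below implements greedy longest-match-first subword splitting. The `toy_vocab` set is hypothetical (the real vocabulary has 30,522 entries), so the way "NLP" is split here is only an example of the mechanism, not a claim about the real tokenizer's output.

```python
def wordpiece(word, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one word into subword pieces."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        piece = None
        while end > start:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate   # continuation pieces carry a ## prefix
            if candidate in vocab:
                piece = candidate
                break
            end -= 1
        if piece is None:                      # nothing matched: the whole word is unknown
            return [unk]
        pieces.append(piece)
        start = end
    return pieces

# Hypothetical toy vocabulary, lowercased as an uncased model would expect.
toy_vocab = {"[CLS]", "[SEP]", "[UNK]", "i", "love", "nl", "##p"}

words = "I love NLP".lower().split()
demo_tokens = ["[CLS]"] + [p for w in words for p in wordpiece(w, toy_vocab)] + ["[SEP]"]
print(demo_tokens)
```

With this toy vocabulary, "nlp" is not a single entry, so it is split into `nl` and `##p`. The rest of this notebook keeps one token per word for simplicity.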

In [None]:
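# Simplified word-level tokenization (real DistilBERT uses WordPiece)
# Real DistilBERT also wraps the input in special tokens: [CLS] ... [SEP]
# Uncased checkpoints lowercase the text, and rare words may be split into subword pieces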

tokens = ['[CLS]', 'I', 'love', 'NLP', '[SEP]']
vocab = {token: idx for idx, token in enumerate(tokens)}
token_ids = [vocab[t] for t in tokens]

print("Tokenization:")
print("="*60)
print(f"Input: '{sentence}'")
print(f"Tokens: {tokens}")
print(f"Token IDs: {token_ids}")
print()
print("Special tokens:")
print("  [CLS]: Classification token (for sentence-level tasks)")
print("  [SEP]: Separator token (marks end of sequence)")
print()
print(f"Sequence length: {len(tokens)}")
print("="*60)

## Step 2: Token Embeddings

Each token ID is converted to a dense vector by looking up its row in an embedding matrix.
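The lookup can be read as a matrix product with one-hot vectors, which is why the embedding matrix is trainable like any other weight. A small sketch, using a random stand-in table `E` rather than real embeddings:

```python
import numpy as np

rng = np.random.default_rng(0)
vocab_size, dim = 5, 4
E = rng.normal(size=(vocab_size, dim))   # hypothetical embedding table
ids = np.array([0, 1, 2, 3, 4])          # token IDs for a 5-token sequence

looked_up = E[ids]                       # row indexing
one_hot = np.eye(vocab_size)[ids]        # one one-hot row per token, shape (5, 5)
via_matmul = one_hot @ E                 # the same rows, selected by multiplication

print(np.allclose(looked_up, via_matmul))
```

In practice, frameworks use the indexing form because it avoids building the large one-hot matrix.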

In [None]:
# Create token embeddings (simplified)
# Real DistilBERT: 30,522 vocab × 768 dimensions
vocab_size = len(tokens)
embedding_matrix = np.array([
    [0.1, 0.2, 0.3, 0.4],  # [CLS]
    [0.5, 0.2, 0.1, 0.3],  # I
    [0.8, 0.6, 0.3, 0.5],  # love
    [0.1, 0.9, 0.7, 0.6],  # NLP
    [0.2, 0.3, 0.4, 0.5],  # [SEP]
])

# Get embeddings for our tokens
token_embeddings = embedding_matrix[token_ids]

print("Token Embeddings:")
print("="*60)
for token, emb in zip(tokens, token_embeddings):
    print(f"{token:8s}: {emb}")
print()
print(f"Shape: ({len(tokens)}, {hidden_dim})")
print("      (sequence_length, hidden_dimension)")
print("="*60)

## Step 3: Positional Encodings

**Critical for Transformers:** Since we process all tokens in parallel, we need to tell the model about token positions!
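The sketch below illustrates why. It runs plain self-attention on embeddings that carry no position information, once in the original order and once reversed. The weights and inputs are random stand-ins, and `self_attention` is a helper defined only for this example.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(X, W_q, W_k, W_v):
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    return softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V

rng = np.random.default_rng(0)
X = rng.normal(size=(3, 4))                        # embeddings without positions
W_q, W_k, W_v = (rng.normal(size=(4, 4)) for _ in range(3))

perm = [2, 1, 0]                                   # "NLP love I"
out = self_attention(X, W_q, W_k, W_v)
out_perm = self_attention(X[perm], W_q, W_k, W_v)

# Reordering the input only reorders the output rows; each token's vector is unchanged.
print(np.allclose(out[perm], out_perm))
```

Each token ends up with the same vector regardless of word order, so the model needs position information added to the input.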

In [None]:
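# Simplified positional encodings
# Real DistilBERT uses learned positional embeddings; here we use the fixed
# sinusoidal encodings from the original Transformer, which need no training
def get_positional_encoding(seq_len, hidden_dim):
    """Sinusoidal positional encoding: sin on even dimensions, cos on odd dimensions."""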
    position = np.arange(seq_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, hidden_dim, 2) * -(np.log(10000.0) / hidden_dim))
    
    pos_encoding = np.zeros((seq_len, hidden_dim))
    pos_encoding[:, 0::2] = np.sin(position * div_term)
    pos_encoding[:, 1::2] = np.cos(position * div_term)
    
    return pos_encoding

positional_encodings = get_positional_encoding(len(tokens), hidden_dim)

print("Positional Encodings:")
print("="*60)
for i, (token, pos_enc) in enumerate(zip(tokens, positional_encodings)):
    print(f"Position {i} ({token:8s}): {pos_enc}")
print()
print("Why needed? Without position info, attention cannot tell 'I love NLP' from 'NLP love I'")
print("Positional encoding tells the model token ORDER")
print("="*60)

## Step 4: Input Embeddings = Token + Position

Combine token embeddings with positional encodings by element-wise addition.
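With learned positional embeddings, as in real DistilBERT, the position vectors come from a second lookup table indexed by position rather than from a formula. A minimal sketch, where `position_table` and `token_vectors` are random stand-ins for trained weights:

```python
import numpy as np

rng = np.random.default_rng(0)
max_positions, dim = 512, 4                        # 512 positions; dim shrunk to 4 for this example
position_table = rng.normal(scale=0.02, size=(max_positions, dim))  # stands in for trained weights
token_vectors = rng.normal(size=(5, dim))          # stand-in token embeddings for 5 tokens

position_ids = np.arange(token_vectors.shape[0])   # positions 0 to 4
combined = token_vectors + position_table[position_ids]

print(combined.shape)
```

The addition is the same as with sinusoidal encodings; only the source of the position vectors differs.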

In [None]:
# Add token embeddings and positional encodings
input_embeddings = token_embeddings + positional_encodings

print("Input Embeddings (Token + Position):")
print("="*60)
for token, emb in zip(tokens, input_embeddings):
    print(f"{token:8s}: {emb}")
print()
print("These embeddings now contain:")
print("  1. What the token is (semantic meaning)")
print("  2. Where the token is (position in sequence)")
print("="*60)

## Step 5: Self-Attention Mechanism

**The key innovation of Transformers!**

### How it works:
1. Each token creates **Query (Q)**, **Key (K)**, **Value (V)** vectors
2. Compute attention scores: how much each token should attend to others
3. Use scores to create weighted combination of values

### Formula:
```
Attention(Q, K, V) = softmax(Q @ K^T / √d_k) @ V
```
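The formula can be written as one short function. This is a sketch for a single head with no masking; the inputs `Q_demo`, `K_demo`, and `V_demo` are random placeholders.

```python
import numpy as np

def scaled_dot_product_attention(Q, K, V):
    """Return (output, weights) for single-head attention without masking."""
    d_k = K.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    scores = scores - scores.max(axis=-1, keepdims=True)   # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)          # softmax over each row
    return weights @ V, weights

rng = np.random.default_rng(0)
Q_demo, K_demo, V_demo = (rng.normal(size=(5, 4)) for _ in range(3))
output, weights = scaled_dot_product_attention(Q_demo, K_demo, V_demo)

print(output.shape, weights.shape)
print(weights.sum(axis=-1))
```

The steps that follow compute the same quantities one at a time so each intermediate result can be inspected.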

In [None]:
# Initialize Q, K, V weight matrices
np.random.seed(42)
W_q = np.random.randn(hidden_dim, hidden_dim) * 0.1
W_k = np.random.randn(hidden_dim, hidden_dim) * 0.1
W_v = np.random.randn(hidden_dim, hidden_dim) * 0.1

# Compute Q, K, V for all tokens
Q = input_embeddings @ W_q  # (5, 4) @ (4, 4) = (5, 4)
K = input_embeddings @ W_k
V = input_embeddings @ W_v

print("Query, Key, Value Matrices:")
print("="*60)
print("Each token has:")
print("  Query (Q):  'What am I looking for?'")
print("  Key (K):    'What do I offer?'")
print("  Value (V):  'What information do I have?'")
print()
print(f"Q shape: {Q.shape}")
print(f"K shape: {K.shape}")
print(f"V shape: {V.shape}")
print()
print("Q matrix (each row = query vector for one token):")
for token, q in zip(tokens, Q):
    print(f"  {token:8s}: {q}")
print("="*60)

## Step 6: Compute Attention Scores
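The scores are divided by √d_k because dot products of longer vectors spread over a wider range, which would push softmax toward a one-hot output. The sketch below checks this empirically with random unit-variance vectors (an assumption made for the demonstration; real queries and keys are not exactly distributed this way).

```python
import numpy as np

rng = np.random.default_rng(0)
results = {}
for d_k in (4, 64, 768):
    q = rng.normal(size=(10_000, d_k))
    k = rng.normal(size=(10_000, d_k))
    raw = (q * k).sum(axis=1)            # 10,000 independent dot products
    scaled = raw / np.sqrt(d_k)
    results[d_k] = (raw.std(), scaled.std())
    print(f"d_k={d_k:4d}  std(raw)={raw.std():6.2f}  std(scaled)={scaled.std():.2f}")
```

The raw spread grows roughly like √d_k, while the scaled spread stays close to 1 for every dimension.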

In [None]:
# Compute scaled attention scores: Q @ K^T / sqrt(d_k), with d_k = hidden_dim for this single head
attention_scores = Q @ K.T / np.sqrt(hidden_dim)

print("Attention Scores (before softmax):")
print("="*60)
print("How much each token (row) attends to each token (column)")
print()
print(f"         {' '.join([f'{t:8s}' for t in tokens])}")
for i, (token, scores) in enumerate(zip(tokens, attention_scores)):
    scores_str = ' '.join([f'{s:8.3f}' for s in scores])
    print(f"{token:8s} {scores_str}")
print()
print("Higher score = more attention to that token")
print("="*60)

## Step 7: Apply Softmax to Get Attention Weights
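The softmax used in this step subtracts the row maximum before exponentiating. The result is mathematically unchanged, but it avoids overflow for large scores. A small sketch of the difference, with deliberately extreme scores chosen for illustration:

```python
import numpy as np

def naive_softmax(x):
    e = np.exp(x)
    return e / e.sum(axis=-1, keepdims=True)

def stable_softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

big_scores = np.array([1000.0, 1001.0, 1002.0])

# exp(1000) overflows to inf, and inf / inf is nan; warnings are silenced for the demo.
with np.errstate(over="ignore", invalid="ignore"):
    naive = naive_softmax(big_scores)
stable = stable_softmax(big_scores)

print(naive)
print(stable)
```

The naive version returns only `nan` values here, while the stable version returns a valid probability distribution.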

In [None]:
def softmax(x):
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

# Apply softmax to get attention weights (sum to 1 for each row)
attention_weights = softmax(attention_scores)

print("Attention Weights (after softmax):")
print("="*60)
print("Normalized probabilities - each row sums to 1.0")
print()
print(f"         {' '.join([f'{t:8s}' for t in tokens])}")
for i, (token, weights) in enumerate(zip(tokens, attention_weights)):
    weights_str = ' '.join([f'{w:8.3f}' for w in weights])
    print(f"{token:8s} {weights_str}  (sum={weights.sum():.3f})")
print()
print("Example: 'love' attends most to which tokens?")
love_idx = 2
max_attention = np.argmax(attention_weights[love_idx])
print(f"  → '{tokens[love_idx]}' attends most to '{tokens[max_attention]}'")
print("="*60)

## Step 8: Compute Attention Output
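The cell below computes the output of a single attention head. With `n_heads = 2`, the 4 hidden dimensions would be split into two heads of 2 dimensions each, attended separately, then concatenated and mixed by an output projection. The following is a sketch of that idea; `multi_head_attention` and `W_o` are names introduced for this example, the weights are random, and biases are omitted.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def multi_head_attention(X, W_q, W_k, W_v, W_o, n_heads):
    seq_len, dim = X.shape
    head_dim = dim // n_heads

    def split(M):
        # (seq_len, dim) -> (n_heads, seq_len, head_dim)
        return M.reshape(seq_len, n_heads, head_dim).transpose(1, 0, 2)

    Q, K, V = split(X @ W_q), split(X @ W_k), split(X @ W_v)
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim)    # (n_heads, seq_len, seq_len)
    weights = softmax(scores)                                 # one attention matrix per head
    heads = weights @ V                                       # (n_heads, seq_len, head_dim)
    concat = heads.transpose(1, 0, 2).reshape(seq_len, dim)   # back to (seq_len, dim)
    return concat @ W_o, weights

rng = np.random.default_rng(0)
X_demo = rng.normal(size=(5, 4))
W_q, W_k, W_v, W_o = (rng.normal(size=(4, 4)) * 0.1 for _ in range(4))
mha_out, mha_weights = multi_head_attention(X_demo, W_q, W_k, W_v, W_o, n_heads=2)

print(mha_out.shape, mha_weights.shape)
```

Each head produces its own 5 × 5 attention matrix, so different heads can focus on different relationships between tokens.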

In [None]:
# Multiply attention weights by values
attention_output = attention_weights @ V

print("Attention Output:")
print("="*60)
print("Weighted combination of all token values")
print()
for token, output in zip(tokens, attention_output):
    print(f"{token:8s}: {output}")
print()
print("Key insight:")
print("  Each token's output is influenced by ALL tokens, including itself")
print("  In a unidirectional RNN, 'I' only affects later tokens")
print("  Here 'NLP' can also influence the representation of 'I'!")
print("="*60)

## Visualization 1: Attention Heatmap

In [None]:
fig, ax = plt.subplots(figsize=(10, 8))

sns.heatmap(attention_weights, annot=True, fmt='.3f', cmap='YlOrRd',
            xticklabels=tokens, yticklabels=tokens,
            cbar_kws={'label': 'Attention Weight'},
            vmin=0, vmax=1, ax=ax, linewidths=1, linecolor='black')

ax.set_xlabel('Attends TO (Keys)', fontsize=12, fontweight='bold')
ax.set_ylabel('Attends FROM (Queries)', fontsize=12, fontweight='bold')
ax.set_title('Self-Attention Matrix\nHow much each token attends to others', 
             fontsize=14, fontweight='bold', pad=20)

# Add note
fig.text(0.5, 0.02, 'Darker = more attention  |  Each row sums to 1.0', 
         ha='center', fontsize=10, style='italic')

plt.tight_layout()
plt.savefig('distilbert_attention_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

print("Visualization saved as 'distilbert_attention_heatmap.png'")

## Visualization 2: Information Flow Diagram

In [None]:
fig, ax = plt.subplots(figsize=(14, 10))
ax.set_xlim(0, 14)
ax.set_ylim(0, 10)
ax.axis('off')

# Title
ax.text(7, 9.5, 'DistilBERT: All Tokens Attend to Each Other (Parallel)', 
        ha='center', fontsize=16, fontweight='bold')

# Token positions
token_positions = [(2, 5), (4.5, 5), (7, 5), (9.5, 5), (12, 5)]

# Draw tokens
for (x, y), token in zip(token_positions, tokens):
    circle = Circle((x, y), 0.4, color='#3498DB', ec='black', linewidth=2.5, zorder=3)
    ax.add_patch(circle)
    ax.text(x, y, token, ha='center', va='center', 
           fontsize=11, fontweight='bold', color='white')

# Draw attention connections (simplified - show strongest connections)
# Find top 2 connections for each token
for i, (token, (x1, y1)) in enumerate(zip(tokens, token_positions)):
    # Get top 2 attention weights (excluding self)
    weights = attention_weights[i].copy()
    weights[i] = 0  # Exclude self-attention for visualization
    top_indices = np.argsort(weights)[-2:]
    
    for j in top_indices:
        x2, y2 = token_positions[j]
        weight = attention_weights[i][j]
        
        # Draw arrow with thickness based on weight
        arrow = FancyArrowPatch((x1, y1), (x2, y2),
                               arrowstyle='->,head_width=0.3,head_length=0.2',
                               color='orange', linewidth=weight*5,
                               alpha=0.6, zorder=1,
                               connectionstyle="arc3,rad=0.3")
        ax.add_patch(arrow)

# Add legend
ax.text(7, 3, 'Self-Attention: Every token can attend to every other token', 
        ha='center', fontsize=12, style='italic', bbox=dict(boxstyle='round',
        facecolor='#FFF9E6', edgecolor='orange', linewidth=2))

ax.text(1, 1.5, 'Arrow thickness = attention strength', fontsize=10, style='italic')
ax.text(1, 1, 'Showing top 2 connections per token', fontsize=10, style='italic')

# Comparison text
ax.text(7, 7.5, 'vs RNN/LSTM: Sequential (I → love → NLP)', 
        ha='center', fontsize=11, color='gray', style='italic')

plt.tight_layout()
plt.savefig('distilbert_attention_flow.png', dpi=300, bbox_inches='tight')
plt.show()

print("Visualization saved as 'distilbert_attention_flow.png'")

## Visualization 3: Complete Architecture
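Before drawing the diagram, it may help to see one encoder layer as code. The sketch below chains self-attention, residual connections, layer normalization, and a feed-forward network. It rests on several simplifications: a single attention head, random weights, no biases or learned LayerNorm parameters, and the tanh approximation of GELU. The helper names (`encoder_block`, `init_block`) are introduced only for this example.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def layer_norm(x, eps=1e-12):
    # Normalize each token vector; the learned gain and bias are omitted here.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def encoder_block(x, p):
    # Self-attention sub-layer (single head for brevity), then residual + LayerNorm.
    Q, K, V = x @ p["W_q"], x @ p["W_k"], x @ p["W_v"]
    attn = softmax(Q @ K.T / np.sqrt(K.shape[-1])) @ V @ p["W_o"]
    x = layer_norm(x + attn)
    # Position-wise feed-forward sub-layer, then residual + LayerNorm.
    ffn = gelu(x @ p["W_1"]) @ p["W_2"]
    return layer_norm(x + ffn)

def init_block(rng, dim, ffn_dim):
    shapes = {"W_q": (dim, dim), "W_k": (dim, dim), "W_v": (dim, dim),
              "W_o": (dim, dim), "W_1": (dim, ffn_dim), "W_2": (ffn_dim, dim)}
    return {name: rng.normal(size=shape) * 0.1 for name, shape in shapes.items()}

rng = np.random.default_rng(0)
dim, ffn_dim, n_blocks = 4, 16, 2
blocks = [init_block(rng, dim, ffn_dim) for _ in range(n_blocks)]

hidden = rng.normal(size=(5, dim))     # stand-in for the 5 input embeddings
for params in blocks:                  # each block has its own weights
    hidden = encoder_block(hidden, params)

print(hidden.shape)
```

The output keeps the input shape, which is what allows identical layers to be stacked.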

In [None]:
fig, ax = plt.subplots(figsize=(12, 14))
ax.set_xlim(0, 12)
ax.set_ylim(0, 14)
ax.axis('off')

# Title
ax.text(6, 13.5, 'DistilBERT Architecture', 
        ha='center', fontsize=18, fontweight='bold')

y_pos = 12.5
box_width = 10
box_height = 0.8
x_center = 6

layers = [
    ('Input Text', '#E8F4F8', '"I love NLP"'),
    ('Tokenization', '#D5E8F0', '[CLS] I love NLP [SEP]'),
    ('Token Embeddings', '#B8E6F0', 'Each token → 4-dim vector'),
    ('+ Positional Encoding', '#A0D8E8', 'Add position information'),
    ('⬇', 'white', ''),
    ('Multi-Head Self-Attention', '#7DD3E8', 'All tokens attend to all'),
    ('+ Residual & LayerNorm', '#6BC5DD', 'Add & normalize'),
    ('⬇', 'white', ''),
    ('Feed-Forward Network', '#5AB7D2', 'Dense → GELU → Dense'),
    ('+ Residual & LayerNorm', '#48A9C7', 'Add & normalize'),
    ('⬇', 'white', ''),
    ('Repeat 6 times', '#FFF9E6', '(We simplified to 2)'),
    ('⬇', 'white', ''),
    ('Final Representations', '#70AD47', '5 tokens × 4 dims'),
    ('Classification Head', '#5A9636', '[CLS] token → prediction'),
]

for i, (label, color, note) in enumerate(layers):
    y = y_pos - i * 0.9
    
    if label == '⬇':
        ax.text(x_center, y, '⬇', ha='center', va='center', 
               fontsize=20, color='gray')
    else:
        rect = FancyBboxPatch((x_center - box_width/2, y - box_height/2),
                             box_width, box_height,
                             boxstyle="round,pad=0.1",
                             edgecolor='black', facecolor=color,
                             linewidth=2)
        ax.add_patch(rect)
        
        ax.text(x_center, y + 0.1, label, ha='center', va='center',
               fontsize=11, fontweight='bold')
        
        if note:
            ax.text(x_center, y - 0.25, note, ha='center', va='center',
                   fontsize=8, style='italic', color='#666')

plt.tight_layout()
plt.savefig('distilbert_architecture.png', dpi=300, bbox_inches='tight')
plt.show()

print("Visualization saved as 'distilbert_architecture.png'")

## Comparison: RNN vs LSTM vs BiLSTM vs DistilBERT

In [None]:
print("\n" + "="*90)
print("COMPARISON: RNN vs LSTM vs BiLSTM vs DistilBERT (Transformer)")
print("="*90)

comparison = [
    ("Architecture", "Recurrent", "Recurrent", "Recurrent", "Transformer"),
    ("Processing", "Sequential", "Sequential", "Sequential", "Parallel"),
    ("States", "1 (h_t)", "2 (h_t, C_t)", "4 (2 dirs)", "None - uses attention"),
    ("Context", "Past only", "Past only", "Past + Future", "All tokens at once"),
    ("For 'love'", "Knows 'I'", "Knows 'I'", "Knows 'I' & 'NLP'", "Knows ALL tokens"),
    ("Attention", "None", "None", "None", "Self-attention"),
    ("Position info", "Implicit", "Implicit", "Implicit", "Explicit encoding"),
    ("Parameters", "~32 (toy)", "~128 (toy)", "~256 (toy)", "~66M (real model)"),
    ("Speed", "Fast", "Medium", "Slow", "Fast (parallel)"),
    ("Long sequences", "Poor", "Good", "Good", "Good (up to 512 tokens)"),
    ("Vanishing grad", "Problem", "Mitigated", "Mitigated", "Mitigated (residuals)"),
    ("Pre-training", "Usually no", "Usually no", "Usually no", "Yes (BERT/DistilBERT)"),
]

print(f"{'Aspect':<18} {'RNN':<15} {'LSTM':<15} {'BiLSTM':<15} {'DistilBERT':<25}")
print("-"*90)
for row in comparison:
    print(f"{row[0]:<18} {row[1]:<15} {row[2]:<15} {row[3]:<15} {row[4]:<25}")
print("="*90)

print("\nKey Differences:")
print("="*90)
print("\n1. PROCESSING ORDER:")
print("   RNN/LSTM/BiLSTM: Process tokens ONE AT A TIME (sequential)")
print("   DistilBERT:      Process ALL TOKENS AT ONCE (parallel)")

print("\n2. CONTEXT AWARENESS:")
print("   RNN:        'love' only knows about 'I' (past)")
print("   LSTM:       'love' only knows about 'I' (past, better memory)")
print("   BiLSTM:     'love' knows 'I' (past) AND 'NLP' (future)")
print("   DistilBERT: 'love' ATTENDS to ALL tokens simultaneously")

print("\n3. MECHANISM:")
print("   RNN/LSTM/BiLSTM: Hidden states carry information")
print("   DistilBERT:      Self-attention computes relationships")

print("\n4. POSITION INFORMATION:")
print("   RNN/LSTM/BiLSTM: Position is implicit (order of processing)")
print("   DistilBERT:      Position is EXPLICIT (positional encodings)")

print("\n5. PRE-TRAINING:")
print("   RNN/LSTM/BiLSTM: Typically trained from scratch for each task")
print("   DistilBERT:      Pre-trained on massive text, fine-tune for task")
print("="*90)

## Why DistilBERT is Powerful

### 1. Parallel Processing
```
RNN:        I → love → NLP  (3 sequential steps)
DistilBERT: All tokens processed simultaneously (1 parallel step)
```

### 2. Full Context Attention
```
At 'love':
RNN:        Knows 'I' (via hidden state)
BiLSTM:     Knows 'I' and 'NLP' (via two passes)
DistilBERT: Directly attends to 'I', 'NLP', [CLS], [SEP]
```

### 3. Self-Attention Captures Relationships
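```python
# Illustrative values for a trained model, not output from this notebook
# (the random, untrained weights used above do not produce them)
attention_weights[2, 3] = 0.35  # 'love' (row 2) strongly attends to 'NLP' (column 3)
attention_weights[2, 1] = 0.28  # 'love' also attends to 'I' (column 1)
# A trained model LEARNS which tokens are important for each token
```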

### 4. Pre-trained Knowledge
- DistilBERT is pre-trained on billions of words
- Already knows language patterns, grammar, semantics
- Fine-tune for specific tasks (sentiment, NER, QA)
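For a sentence-level task such as sentiment analysis, fine-tuning usually adds a small classification head on top of the final `[CLS]` vector. The sketch below uses a single linear layer followed by softmax. This is a simplification with random stand-in weights (`W_cls`, `b_cls`) and a random stand-in for the encoder output, not the head used by any particular library.

```python
import numpy as np

rng = np.random.default_rng(0)
final_hidden = rng.normal(size=(5, 4))     # stand-in for the encoder output: 5 tokens x 4 dims
W_cls = rng.normal(size=(4, 2)) * 0.1      # hypothetical head weights: 4 dims -> 2 classes
b_cls = np.zeros(2)

cls_vector = final_hidden[0]               # [CLS] is the first token
logits = cls_vector @ W_cls + b_cls
probs = np.exp(logits - logits.max())
probs /= probs.sum()                       # softmax over the 2 classes

print(probs)
```

During fine-tuning, both the head and the pre-trained encoder weights are typically updated on the labeled task data.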

### 5. DistilBERT vs Full BERT
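- About 40% fewer parameters than BERT-base (66M vs. 110M)
- About 60% faster at inference
- Retains about 97% of BERT's language-understanding performance, as reported by its authors on the GLUE benchmark
- Achieved through **knowledge distillation**: a smaller student model is trained to reproduce the outputs of the larger BERT teacher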
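The core of knowledge distillation is a loss that pushes the student's output distribution toward the teacher's softened distribution. The sketch below shows only this soft-target term; DistilBERT's published training objective combines it with a masked language modeling loss and a cosine embedding loss. The logits and the temperature of 2.0 are arbitrary values chosen for illustration.

```python
import numpy as np

def softmax(x, temperature=1.0):
    z = x / temperature
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """Cross-entropy between softened teacher and student distributions."""
    p_teacher = softmax(teacher_logits, temperature)
    log_p_student = np.log(softmax(student_logits, temperature))
    return -(p_teacher * log_p_student).sum(axis=-1).mean()

teacher = np.array([[4.0, 1.0, 0.5]])
good_student = np.array([[3.8, 1.1, 0.4]])   # close to the teacher
poor_student = np.array([[0.5, 1.0, 4.0]])   # prefers the wrong class

print(distillation_loss(good_student, teacher))
print(distillation_loss(poor_student, teacher))
```

The student that mimics the teacher gets the lower loss, which is the signal that drives training.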

## Summary: Key Concepts

### What is DistilBERT?
A **Transformer-based** encoder (NOT an RNN!) that uses **self-attention** to build a contextual representation of every token.

### Core Components:
```
1. Tokenization:        "I love NLP" → [CLS] I love NLP [SEP]
2. Token Embeddings:    Each token → 768-dim vector (we used 4)
3. Positional Encoding: Add position information
4. Self-Attention:      Each token attends to all others
5. Feed-Forward:        Dense → GELU → Dense, applied to each token
6. Repeat:              6 transformer layers
7. Output:              Contextualized representations
```

### Dimensions:
```
Input:  (sequence_length, hidden_dim) = (5, 4)
Q, K, V: (5, 4) for each
Attention weights: (5, 5) - each token to each token
Output: (5, 4) - same shape as input
```

### Real DistilBERT:
- Hidden dim: **768** (not 4)
- Attention heads: **12** (not 2)
- Layers: **6** (not 2)
- Vocabulary: **30,522** tokens
- Parameters: **66 million**
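The 66 million figure can be reproduced with simple arithmetic. The calculation below assumes the base configuration: a feed-forward inner size of 3,072, 512 learned positions, biases on every linear layer, and no task-specific head.

```python
vocab_size, hidden, max_positions = 30_522, 768, 512
ffn_dim, n_layers = 3_072, 6

layer_norm = 2 * hidden                                   # gain and bias
embeddings = vocab_size * hidden + max_positions * hidden + layer_norm

attention = 4 * (hidden * hidden + hidden)                # Q, K, V, output projections with biases
feed_forward = (hidden * ffn_dim + ffn_dim) + (ffn_dim * hidden + hidden)
per_layer = attention + layer_norm + feed_forward + layer_norm

total = embeddings + n_layers * per_layer
print(f"{total:,}")   # 66,362,880
```

About a third of the parameters sit in the token embedding table, and the rest are spread evenly across the 6 layers.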

### When to Use What?

**RNN/LSTM:**
- Simple tasks
- Limited compute
- Streaming/real-time

**BiLSTM:**
- Need future context
- Sequence labeling
- Medium-length sequences

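**DistilBERT:**
- High accuracy on a smaller compute budget than full BERT
- Transfer learning from a pre-trained model
- Language-understanding tasks (classification, NER, extractive QA)
- Complete sequences available (up to 512 tokens)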