# ELECTRA: Efficiently Learning an Encoder that Classifies Token Replacements Accurately

**Rank**: #2 - Revolutionary Impact

## Background & Motivation

BERT's Masked Language Modeling (MLM) has a fundamental inefficiency: only 15% of tokens are masked, so the model only learns from a small fraction of the input at each step. This means BERT needs enormous amounts of compute to reach good performance.

**The Problem with MLM:**
- Only 15% of tokens contribute to the loss
- The other 85% of positions serve only as context and receive no prediction loss
- Need massive datasets and compute for good results
- Small models significantly underperform large ones

**ELECTRA's Innovation:**
- Learn from **ALL** tokens, not just 15%
- Replace MLM with "Replaced Token Detection" (RTD)
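- Use a generator-discriminator setup (GAN-like in structure, but the generator is trained with maximum likelihood, not adversarially)
- Roughly 4x more compute-efficient: the paper reports matching RoBERTa and XLNet with less than 1/4 of their compute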

## What You'll Learn:
1. **Generator-Discriminator Architecture**: How ELECTRA uses two models
2. **Replaced Token Detection**: The core task that replaces MLM
3. **Sample Efficiency**: Why ELECTRA learns faster
4. **Implementation**: Building ELECTRA from scratch
5. **Results**: Why small ELECTRA models outperform BERT models of the same size

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import sys
from collections import defaultdict
import random

np.random.seed(42)
random.seed(42)

# Set style for better visualizations
try:
    plt.style.use('seaborn-v0_8-darkgrid')
except OSError:
    try:
        plt.style.use('seaborn-darkgrid') 
    except OSError:
        plt.style.use('default')
        
print("ELECTRA: Efficiently Learning an Encoder that Classifies Token Replacements Accurately")
print("Paper: Clark et al., 2020 - Google Research & Stanford")
print("Impact: matches RoBERTa/XLNet with under 1/4 of the compute; small models outperform same-size BERT")

## Part 1: Understanding the Core Problem with MLM

Let's visualize why BERT's MLM is inefficient and how ELECTRA solves it.

In [None]:
def visualize_mlm_inefficiency():
    """
    Show the fundamental inefficiency of Masked Language Modeling
    """
    
    # Example sentence
    sentence = "The quick brown fox jumps over the lazy dog".split()
    n_tokens = len(sentence)
    
    # MLM: Only 15% of tokens are masked
    mask_prob = 0.15
    num_masked = int(n_tokens * mask_prob)
    
    # Randomly select tokens to mask
    masked_positions = np.random.choice(n_tokens, num_masked, replace=False)
    
    # Create visualization
    fig, axes = plt.subplots(3, 1, figsize=(14, 10))
    
    # 1. Original sentence
    colors_original = ['lightblue'] * n_tokens
    bars1 = axes[0].bar(range(n_tokens), [1] * n_tokens, color=colors_original)
    axes[0].set_title('Original Sentence', fontsize=14, fontweight='bold')
    axes[0].set_ylim(0, 1.5)
    axes[0].set_xticks(range(n_tokens))
    axes[0].set_xticklabels(sentence, rotation=45)
    axes[0].set_ylabel('Token Status')
    
    # Add text labels
    for i, word in enumerate(sentence):
        axes[0].text(i, 0.5, word, ha='center', va='center', fontweight='bold')
    
    # 2. BERT MLM - only some tokens contribute to learning
    colors_mlm = ['red' if i in masked_positions else 'lightgray' for i in range(n_tokens)]
    bars2 = axes[1].bar(range(n_tokens), [1] * n_tokens, color=colors_mlm)
    axes[1].set_title(f'BERT MLM: Only {num_masked}/{n_tokens} tokens (mask rate {mask_prob*100:.0f}%) contribute to learning',
                     fontsize=14, fontweight='bold')
    axes[1].set_ylim(0, 1.5)
    axes[1].set_xticks(range(n_tokens))
    axes[1].set_xticklabels(['[MASK]' if i in masked_positions else word 
                           for i, word in enumerate(sentence)], rotation=45)
    axes[1].set_ylabel('Learning Signal')
    
    # Add legend
    axes[1].bar([], [], color='red', label='Contributes to Loss')
    axes[1].bar([], [], color='lightgray', label='No Learning Signal')
    axes[1].legend()
    
    # 3. ELECTRA RTD - all tokens contribute to learning
    colors_electra = ['green'] * n_tokens  # All tokens contribute
    bars3 = axes[2].bar(range(n_tokens), [1] * n_tokens, color=colors_electra)
    axes[2].set_title(f'ELECTRA RTD: All {n_tokens}/{n_tokens} tokens (100%) contribute to learning!', 
                     fontsize=14, fontweight='bold')
    axes[2].set_ylim(0, 1.5)
    axes[2].set_xticks(range(n_tokens))
    
    # Generate some replaced tokens for visualization
    replacements = sentence.copy()
    replaced_positions = np.random.choice(n_tokens, num_masked, replace=False)
    replacement_words = ['cat', 'slow', 'red'][:len(replaced_positions)]
    
    for i, pos in enumerate(replaced_positions):
        if i < len(replacement_words):
            replacements[pos] = replacement_words[i]
    
    axes[2].set_xticklabels(replacements, rotation=45)
    axes[2].set_ylabel('Learning Signal')
    
    plt.tight_layout()
    plt.show()
    
    # Fraction of token positions that receive a loss under each objective
    mlm_efficiency = mask_prob
    electra_efficiency = 1.0
    
    print(f"\nLOSS COVERAGE COMPARISON:")
    print(f"BERT MLM: {mlm_efficiency:.1%} of tokens provide learning signal")
    print(f"ELECTRA RTD: {electra_efficiency:.1%} of tokens provide learning signal")
    print(f"\nELECTRA applies its loss to {electra_efficiency/mlm_efficiency:.1f}x more positions per example")
    
    return masked_positions, replaced_positions

masked_pos, replaced_pos = visualize_mlm_inefficiency()

## Part 2: ELECTRA's Generator-Discriminator Architecture

ELECTRA trains two models together: a small generator that proposes replacement tokens and a discriminator that learns to spot them.

## Understanding Generator vs Discriminator in ELECTRA

ELECTRA uses **two separate neural networks** that are trained jointly. The setup resembles a Generative Adversarial Network (GAN), but it is not adversarial: the generator is trained with maximum likelihood (ordinary MLM), not to fool the discriminator.

### 🎭 **The Generator (The "Faker")**

**What it does:** The generator is like BERT - it tries to predict masked tokens.

**Role:** Creates **plausible but wrong** token replacements to challenge the discriminator.

**Architecture:**
- **Smaller model** (the paper finds generators about 1/4 to 1/2 the discriminator's size work best)
- **Task**: Masked Language Modeling (MLM) - same as BERT
- **Input**: Sentence with some tokens masked as `[MASK]`
- **Output**: Predictions for what the masked tokens should be

**Example:**
```
Original:  "The quick brown fox jumps"
Masked:    "The [MASK] brown fox jumps"
Generator: "The slow brown fox jumps"    # Predicts "slow" for [MASK]
```
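
The generator's replacement is sampled from its output distribution over the vocabulary, which is why replacements tend to be plausible. The following is a minimal sketch of that sampling step; the four-word vocabulary `toy_vocab`, the scores in `toy_logits`, and the helper `sample_replacement` are made up for illustration and are not part of ELECTRA.

```python
import numpy as np

def sample_replacement(logits, rng):
    """Sample one token ID from softmax(logits); also return the probabilities."""
    shifted = logits - np.max(logits)  # subtract the max for numerical stability
    probs = np.exp(shifted) / np.sum(np.exp(shifted))
    token_id = int(rng.choice(len(probs), p=probs))
    return token_id, probs

# Hypothetical vocabulary and generator scores for the [MASK] slot (illustrative values)
toy_vocab = ["quick", "slow", "fast", "purple"]
toy_logits = np.array([2.0, 1.5, 1.0, -3.0])

rng = np.random.default_rng(0)
token_id, probs = sample_replacement(toy_logits, rng)

print(f"Sampled replacement: {toy_vocab[token_id]}")
for word, p in zip(toy_vocab, probs):
    print(f"  P({word}) = {p:.3f}")
```

Plausible words receive most of the probability mass, so the discriminator mostly sees realistic substitutions. The original word ("quick" here) can also be sampled, in which case the token is left unchanged.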

### 🕵️ **The Discriminator (The "Detective")**

**What it does:** The discriminator learns to detect which tokens are "real" vs "fake".

**Role:** Learns to spot the generator's replacements, building better language understanding.

**Architecture:**
- **Larger model** (like BERT-Base: 110M parameters)
- **Task**: Replaced Token Detection (RTD) - binary classification for each token
- **Input**: Sentence with some tokens replaced by generator
- **Output**: For each token, probability it's been replaced (0 = original, 1 = replaced); a position where the generator happens to produce the original token counts as original

**Example:**
```
Original:        "The quick brown fox jumps"
After Generator: "The slow brown fox jumps"
Discriminator:   [0.02, 0.95, 0.01, 0.03, 0.01]  # High probability that "slow" is replaced
```
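
The training targets for the discriminator come from a position-by-position comparison with the original sentence. Here is a small sketch of that labeling rule; the helper name `rtd_labels` is hypothetical and word strings stand in for token IDs.

```python
def rtd_labels(original, corrupted):
    """Return 1 where the token differs from the original, else 0."""
    if len(original) != len(corrupted):
        raise ValueError("sequences must have the same length")
    return [int(o != c) for o, c in zip(original, corrupted)]

original = "The quick brown fox jumps".split()
corrupted = "The slow brown fox jumps".split()
labels = rtd_labels(original, corrupted)
print(labels)  # [0, 1, 0, 0, 0]

# If the generator happens to sample the original token, nothing changed,
# so the position is labeled original (0) even though it was masked.
unchanged_labels = rtd_labels(original, original)
print(unchanged_labels)  # [0, 0, 0, 0, 0]
```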

### 🔄 **How They Work Together**

**Step 1: Generator Creates Fakes**
- Take original sentence: `"The quick brown fox"`
- Mask some tokens: `"The [MASK] brown fox"`
- Generator predicts: `"The slow brown fox"`

**Step 2: Discriminator Detects Fakes**
- Gets sentence with replacements: `"The slow brown fox"`
- For each position, predicts: `[ORIGINAL, REPLACED, ORIGINAL, ORIGINAL]`
- Learns to distinguish real vs fake tokens

**Step 3: Both Models Are Trained Jointly**
- **Generator** gets better at MLM, so its replacements become more plausible
- **Discriminator** gets better at detecting subtle differences
- Unlike a GAN, this is **not adversarial**: the generator is trained with maximum likelihood, not to fool the discriminator
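
The three steps can be summarized as one combined objective. The sketch below uses this notebook's label convention (1 = replaced), which is the reverse of the paper's, where the discriminator outputs the probability that a token is original. The symbols $m$ (the set of masked positions), $y_t$ (the replaced label), and $p_t$ (the discriminator's predicted probability of replacement) are introduced here for illustration.

$$
\mathcal{L}_{\text{MLM}} = -\sum_{i \in m} \log p_G\left(x_i \mid x^{\text{masked}}\right)
$$

$$
\mathcal{L}_{\text{RTD}} = -\sum_{t=1}^{n} \Big[\, y_t \log p_t + (1 - y_t) \log (1 - p_t) \Big], \qquad y_t = \mathbb{1}\left[x^{\text{corrupt}}_t \neq x_t\right]
$$

$$
\min_{\theta_G,\, \theta_D} \; \mathcal{L}_{\text{MLM}} + \lambda \, \mathcal{L}_{\text{RTD}}
$$

The weight $\lambda$ balances the two terms (the paper reports using 50). The discriminator loss is not backpropagated through the generator's sampling step, so the generator only ever receives the MLM gradient.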

### 🎯 **Key Differences from BERT**

| Aspect | BERT (MLM) | ELECTRA Generator | ELECTRA Discriminator |
|--------|------------|-------------------|----------------------|
| **Task** | Predict masked tokens | Predict masked tokens | Detect replaced tokens |
| **Learning Signal** | 15% of tokens | 15% of tokens | **100% of tokens** |
| **Architecture** | Single model | Smaller model | Larger model |
| **Final Use** | Use for downstream | Discarded after training | **Used for downstream** |

### 💡 **Why This Design is Brilliant**

1. **More Learning Signal**: Discriminator learns from ALL tokens, not just 15%
2. **Harder Examples Over Time**: As the generator's MLM improves, its replacements become more plausible
3. **Efficiency**: Small generator can train large discriminator effectively
4. **Better Representations**: Discriminator learns fine-grained understanding

**The key insight:** Instead of wasting 85% of tokens, ELECTRA learns from every single token position!

In [None]:
class SimpleELECTRA:
    """
    Simplified ELECTRA implementation to demonstrate core concepts

    Two main components:
    1. Generator: Small model that predicts masked tokens (like BERT)
    2. Discriminator: Larger model that detects which tokens were replaced
    """

    def __init__(self, vocab_size=8192, hidden_size=192, generator_size=64):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.generator_size = generator_size

        print("\n=== INITIALIZING ELECTRA MODELS ===")

        # ===== GENERATOR ("The Faker") =====
        # Smaller model that predicts masked tokens (like BERT MLM)
        print("\n🎭 Generator (Faker):")
        print(f"   - Size: {generator_size} hidden dimensions (smaller model)")
        print(f"   - Task: Predict masked tokens (MLM)")
        print(f"   - Role: Create plausible but wrong replacements")

        self.generator_embeddings = np.random.randn(vocab_size, generator_size) * 0.02
        self.generator_mlm_head = np.random.randn(generator_size, vocab_size) * 0.02
        gen_params = self.generator_embeddings.size + self.generator_mlm_head.size

        # ===== DISCRIMINATOR ("The Detective") =====
        # Larger model that detects which tokens are replaced
        print("\n🕵️ Discriminator (Detective):")
        print(f"   - Size: {hidden_size} hidden dimensions (larger model)")
        print(f"   - Task: Detect replaced tokens (RTD)")
        print(f"   - Role: Learn to spot generator's fakes")

        self.discriminator_embeddings = np.random.randn(vocab_size, hidden_size) * 0.02
        self.discriminator_rtd_head = np.random.randn(hidden_size, 1) * 0.02
        disc_params = self.discriminator_embeddings.size + self.discriminator_rtd_head.size

        print(f"\n📊 Parameter Comparison:")
        print(f"   Generator parameters: {gen_params:,}")
        print(f"   Discriminator parameters: {disc_params:,}")
        print(f"   Ratio: {disc_params/gen_params:.1f}x larger discriminator")

    def softmax(self, x):
        """Compute softmax for probability distributions"""
        exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def sigmoid(self, x):
        """Compute sigmoid for binary classification"""
        return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

    def generator_step(self, input_ids, masked_positions):
        """
        🎭 GENERATOR STEP: Predict masked tokens (like BERT MLM)

        Input: ["The", "[MASK]", "brown", "fox"]
        Output: ["The", "slow", "brown", "fox"]  # Predicts "slow" for [MASK]

        This is essentially the same task as BERT's MLM!
        """
        print("\n🎭 GENERATOR STEP (Faker creates replacements):")

        # Simple embedding lookup + linear layer.
        # Note: this toy generator has no attention, so it ignores context and
        # every [MASK] position gets the same output distribution.
        embeddings = self.generator_embeddings[input_ids]  # [seq_len, generator_size]
        print(f"   - Input embeddings shape: {embeddings.shape}")

        # MLM predictions: project embeddings onto the vocabulary
        logits = embeddings @ self.generator_mlm_head  # [seq_len, vocab_size]
        probs = self.softmax(logits)
        print(f"   - Output probabilities shape: {probs.shape}")
        print(f"   - Generating predictions for {len(masked_positions)} masked positions")

        # Sample predictions for masked positions
        generated_tokens = input_ids.copy()

        for pos in masked_positions:
            # Sample from the probability distribution
            generated_token = np.random.choice(self.vocab_size, p=probs[pos])
            generated_tokens[pos] = generated_token
            print(f"   - Position {pos}: Generated token ID {generated_token}")

        return generated_tokens, probs, logits

    def discriminator_step(self, corrupted_tokens, original_tokens):
        """
        🕵️ DISCRIMINATOR STEP: Detect which tokens are replaced

        Input: ["The", "slow", "brown", "fox"]  # Some tokens replaced by generator
        Output: [0.02, 0.95, 0.01, 0.03]    # Probability each token is replaced

        This is a binary classification task for EVERY token position!
        """
        print("\n🕵️ DISCRIMINATOR STEP (Detective finds fakes):")

        # Embedding lookup
        embeddings = self.discriminator_embeddings[corrupted_tokens]  # [seq_len, hidden_size]
        print(f"   - Input embeddings shape: {embeddings.shape}")

        # Binary classification for each position
        logits = embeddings @ self.discriminator_rtd_head  # [seq_len, 1]
        logits = logits.squeeze(-1)  # [seq_len]
        print(f"   - Classification logits shape: {logits.shape}")

        # Sigmoid to get probabilities (0 = original, 1 = replaced)
        probs = self.sigmoid(logits)

        # True labels: 1 if token was replaced, 0 if original
        # (a sampled token that equals the original counts as original)
        labels = (corrupted_tokens != original_tokens).astype(float)
        num_replaced = int(labels.sum())
        print(f"   - Detecting {num_replaced} replaced tokens out of {len(labels)} total")

        return probs, labels, logits

    def train_step(self, input_ids, mask_prob=0.15):
        """
        🔄 COMPLETE ELECTRA TRAINING STEP

        This shows how Generator and Discriminator work together:
        1. Mask some tokens in original sentence
        2. Generator predicts what masked tokens should be
        3. Replace masked tokens with generator's predictions
        4. Discriminator tries to detect which tokens were replaced
        5. Compute each model's loss (this simplified demo does not update any weights)
        """
        print("\n🔄 ELECTRA TRAINING STEP:")
        original_tokens = input_ids.copy()

        # Step 1: Create masked input for generator
        masked_input = input_ids.copy()
        mask_token_id = self.vocab_size - 1  # Reserve the last ID as the [MASK] token

        # Randomly select positions to mask
        num_mask = max(1, int(len(input_ids) * mask_prob))
        masked_positions = np.random.choice(len(input_ids), num_mask, replace=False)
        print(f"   Step 1: Masking {num_mask} tokens at positions {masked_positions}")

        # Mask tokens
        for pos in masked_positions:
            masked_input[pos] = mask_token_id

        # Step 2: Generator predicts masked tokens
        print(f"   Step 2: Generator predicts masked tokens...")
        generated_tokens, gen_probs, gen_logits = self.generator_step(masked_input, masked_positions)

        # Step 3: Create corrupted sequence for discriminator
        corrupted_tokens = original_tokens.copy()
        for pos in masked_positions:
            corrupted_tokens[pos] = generated_tokens[pos]
        print(f"   Step 3: Created corrupted sequence with generator predictions")

        # Step 4: Discriminator detects replaced tokens
        print(f"   Step 4: Discriminator detects replaced tokens...")
        disc_probs, disc_labels, disc_logits = self.discriminator_step(corrupted_tokens, original_tokens)

        # Calculate losses
        # Generator loss: MLM cross-entropy on masked positions
        gen_loss = 0
        for pos in masked_positions:
            target = original_tokens[pos]
            gen_loss += -np.log(gen_probs[pos, target] + 1e-10)
        gen_loss /= len(masked_positions)

        # Discriminator loss: Binary cross-entropy on all positions
        disc_loss = 0
        for i in range(len(disc_labels)):
            p = disc_probs[i]
            label = disc_labels[i]
            disc_loss += -(label * np.log(p + 1e-10) + (1-label) * np.log(1-p + 1e-10))
        disc_loss /= len(disc_labels)

        print(f"\n📊 Training Results:")
        print(f"   Generator loss (MLM): {gen_loss:.4f}")
        print(f"   Discriminator loss (RTD): {disc_loss:.4f}")

        return {
            'generator_loss': gen_loss,
            'discriminator_loss': disc_loss,
            'original_tokens': original_tokens,
            'corrupted_tokens': corrupted_tokens,
            'discriminator_predictions': disc_probs,
            'discriminator_labels': disc_labels,
            'masked_positions': masked_positions
        }

# Demonstrate ELECTRA with detailed explanations
print("=" * 70)
print("ELECTRA DEMONSTRATION: Generator vs Discriminator")
print("=" * 70)

electra = SimpleELECTRA()

# Example input (this small vocab list is used only to display the input words;
# the model itself works over vocab_size IDs)
vocab = ['the', 'quick', 'brown', 'fox', 'jumps', 'over', 'lazy', 'dog', '[MASK]', '[PAD]']
sentence_ids = np.array([0, 1, 2, 3, 4, 5, 6, 7])  # "the quick brown fox jumps over lazy dog"

def token_str(token_id):
    """Map an ID to a word; IDs outside the small display vocab are shown numerically."""
    return vocab[token_id] if token_id < len(vocab) else f"<id {token_id}>"

print(f"\n📝 Input sentence:")
print(f"   Token IDs: {sentence_ids}")
print(f"   Words: {[vocab[i] for i in sentence_ids]}")

# Run training step
result = electra.train_step(sentence_ids)

# Show detailed results
print(f"\n🔍 DETAILED RESULTS:")
print(f"\n📍 Masked positions: {result['masked_positions']}")
print(f"\n📖 Token-by-token breakdown:")
for i in range(len(sentence_ids)):
    orig_word = token_str(result['original_tokens'][i])
    corr_word = token_str(result['corrupted_tokens'][i])
    pred = result['discriminator_predictions'][i]
    label = result['discriminator_labels'][i]
    status = "REPLACED" if label == 1 else "ORIGINAL"

    marker = "🎯" if label == 1 else "✅"
    print(f"   {marker} Position {i}: '{orig_word}' -> '{corr_word}' | Prob: {pred:.3f} | {status}")

print(f"\n🎯 Key Insight: Discriminator learns from ALL {len(sentence_ids)} tokens, not just {len(result['masked_positions'])} masked ones!")
print(f"   That is {len(sentence_ids)/len(result['masked_positions']):.1f}x as many loss positions per example as MLM in this toy run")
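
The training step above only computes the two losses and never changes any weights. As a rough sketch of what an update could look like, the block below takes plain gradient-descent steps on a linear RTD head using the mean binary cross-entropy. The helper names `rtd_bce` and `rtd_head_sgd_step`, the learning rate, and the random toy inputs are assumptions for illustration; real ELECTRA training updates the full Transformer with an optimizer such as Adam.

```python
import numpy as np

def rtd_bce(embeddings, head, labels):
    """Mean binary cross-entropy of a linear RTD head (labels: 1 = replaced)."""
    logits = (embeddings @ head).squeeze(-1)
    probs = 1.0 / (1.0 + np.exp(-logits))
    eps = 1e-10
    return float(-np.mean(labels * np.log(probs + eps)
                          + (1 - labels) * np.log(1 - probs + eps)))

def rtd_head_sgd_step(embeddings, head, labels, lr=0.1):
    """One gradient-descent step on the RTD head only (embeddings stay fixed)."""
    logits = (embeddings @ head).squeeze(-1)
    probs = 1.0 / (1.0 + np.exp(-logits))
    # Gradient of the mean BCE with respect to each logit is (p - y) / seq_len
    grad_logits = (probs - labels) / len(labels)
    grad_head = embeddings.T @ grad_logits[:, None]  # [hidden_size, 1]
    return head - lr * grad_head

# Toy inputs (illustrative): 8 positions, 16 hidden dimensions, one replaced token
rng = np.random.default_rng(0)
toy_embeddings = rng.normal(size=(8, 16))
toy_head = rng.normal(size=(16, 1)) * 0.02
toy_labels = np.array([0, 1, 0, 0, 0, 0, 0, 0], dtype=float)

loss_before = rtd_bce(toy_embeddings, toy_head, toy_labels)
for _ in range(20):
    toy_head = rtd_head_sgd_step(toy_embeddings, toy_head, toy_labels)
loss_after = rtd_bce(toy_embeddings, toy_head, toy_labels)

print(f"RTD loss before: {loss_before:.4f}, after 20 steps: {loss_after:.4f}")
```

Every position contributes a term to `grad_logits`, including the seven unreplaced ones, which is the dense signal the discriminator gets from each example.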

## Part 3: Why ELECTRA is More Sample Efficient

Let's analyze mathematically why ELECTRA learns faster than BERT.

In [None]:
def analyze_sample_efficiency():
    """
    Analyze why ELECTRA is more sample efficient than BERT
    """
    
    # Parameters
    sequence_length = 128
    mask_probability = 0.15
    num_examples = 1000
    
    # Calculate learning signals per example
    bert_signals_per_example = sequence_length * mask_probability
    electra_signals_per_example = sequence_length  # All tokens
    
    # Total learning signals
    bert_total_signals = bert_signals_per_example * num_examples
    electra_total_signals = electra_signals_per_example * num_examples
    
    print("SAMPLE EFFICIENCY ANALYSIS:")
    print(f"\nSequence length: {sequence_length} tokens")
    print(f"Number of examples: {num_examples:,}")
    print(f"\nBERT MLM:")
    print(f"  Mask probability: {mask_probability:.1%}")
    print(f"  Learning signals per example: {bert_signals_per_example:.1f}")
    print(f"  Total learning signals: {bert_total_signals:,.0f}")
    
    print(f"\nELECTRA RTD:")
    print(f"  All tokens contribute: 100%")
    print(f"  Learning signals per example: {electra_signals_per_example:.1f}")
    print(f"  Total learning signals: {electra_total_signals:,.0f}")
    
    efficiency_ratio = electra_total_signals / bert_total_signals
    print(f"\nEfficiency ratio: {efficiency_ratio:.1f}x")
    
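    # Simulate learning curves (illustrative only: a saturating exponential with
    # an arbitrary time constant, not measured training results)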
    steps = np.linspace(0, num_examples, 100)
    
    # BERT learning curve (slower due to fewer signals)
    bert_performance = 1 - np.exp(-steps * bert_signals_per_example / 10000)
    
    # ELECTRA learning curve (faster due to more signals)
    electra_performance = 1 - np.exp(-steps * electra_signals_per_example / 10000)
    
    # Visualization
    fig, axes = plt.subplots(1, 2, figsize=(15, 6))
    
    # Sample efficiency comparison
    methods = ['BERT MLM', 'ELECTRA RTD']
    signals = [bert_signals_per_example, electra_signals_per_example]
    colors = ['lightcoral', 'lightblue']
    
    bars = axes[0].bar(methods, signals, color=colors, alpha=0.8)
    axes[0].set_ylabel('Learning Signals per Example')
    axes[0].set_title('Sample Efficiency Comparison')
    axes[0].grid(True, alpha=0.3)
    
    # Add value labels on bars
    for bar, value in zip(bars, signals):
        height = bar.get_height()
        axes[0].text(bar.get_x() + bar.get_width()/2., height + 2,
                    f'{value:.1f}', ha='center', va='bottom', fontweight='bold')
    
    # Learning curves
    axes[1].plot(steps, bert_performance, 'r-', linewidth=3, label='BERT MLM', alpha=0.8)
    axes[1].plot(steps, electra_performance, 'b-', linewidth=3, label='ELECTRA RTD', alpha=0.8)
    axes[1].set_xlabel('Training Examples')
    axes[1].set_ylabel('Performance')
    axes[1].set_title('Simulated Learning Curves')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)
    
    # Add annotations
    axes[1].annotate('ELECTRA reaches high\nperformance faster', 
                    xy=(300, 1 - np.exp(-300 * electra_signals_per_example / 10000)), xytext=(500, 0.6),
                    arrowprops=dict(arrowstyle='->', color='blue', lw=2),
                    fontsize=12, color='blue', fontweight='bold')
    
    plt.tight_layout()
    plt.show()
    
    return efficiency_ratio

efficiency_gain = analyze_sample_efficiency()

# Additional analysis
print("\n" + "="*60)
print("WHY THIS MATTERS:")
print(f"\n1. Denser Signal: {efficiency_gain:.1f}x more loss positions per example")
print(f"2. Compute Cost: the paper reports comparable quality with under 1/4 of the compute")
print(f"3. Model Size: Small ELECTRA models outperform same-size BERT")
print(f"4. Accessibility: Strong results become reachable on modest hardware")
print(f"\n5. Mathematical Intuition:")
print(f"   - BERT: Loss only on 15% of tokens")
print(f"   - ELECTRA: Loss on 100% of tokens")
print(f"   - Result: {100/15:.1f}x more loss positions per example")
print(f"   - Caveat: each RTD position is a binary decision, so this ratio is not a measured speedup")

## Part 4: The Replaced Token Detection Task

Let's dive deeper into how RTD works and why it's effective.

In [None]:
def demonstrate_rtd_task():
    """
    Demonstrate the Replaced Token Detection task in detail
    """
    
    # Example sentences with different types of replacements
    examples = [
        {
            'original': 'The quick brown fox jumps over the lazy dog'.split(),
            'corrupted': 'The fast brown fox jumps over the lazy dog'.split(),
            'explanation': 'Semantic replacement: quick -> fast'
        },
        {
            'original': 'Machine learning algorithms require large datasets'.split(),
            'corrupted': 'Machine learning algorithms require purple datasets'.split(),
            'explanation': 'Nonsensical replacement: large -> purple'
        },
        {
            'original': 'The cat sat on the comfortable mat'.split(),
            'corrupted': 'The cat sat on the comfortable cat'.split(),
            'explanation': 'Repetition replacement: mat -> cat'
        }
    ]
    
    print("REPLACED TOKEN DETECTION (RTD) EXAMPLES:")
    print("="*70)
    
    fig, axes = plt.subplots(len(examples), 1, figsize=(14, 4*len(examples)))
    if len(examples) == 1:
        axes = [axes]
    
    for ex_idx, example in enumerate(examples):
        original = example['original']
        corrupted = example['corrupted']
        explanation = example['explanation']
        
        print(f"\nExample {ex_idx + 1}: {explanation}")
        print(f"Original:  {' '.join(original)}")
        print(f"Corrupted: {' '.join(corrupted)}")
        
        # Find differences
        labels = []
        for i, (orig, corr) in enumerate(zip(original, corrupted)):
            if orig != corr:
                labels.append(1)  # Replaced
                print(f"Position {i}: '{orig}' -> '{corr}' [REPLACED]")
            else:
                labels.append(0)  # Original
        
        # Visualize
        colors = ['red' if label == 1 else 'lightblue' for label in labels]
        bars = axes[ex_idx].bar(range(len(corrupted)), [1] * len(corrupted), color=colors)
        axes[ex_idx].set_title(f'Example {ex_idx + 1}: {explanation}', fontweight='bold')
        axes[ex_idx].set_ylim(0, 1.5)
        axes[ex_idx].set_xticks(range(len(corrupted)))
        axes[ex_idx].set_xticklabels(corrupted, rotation=45)
        axes[ex_idx].set_ylabel('Token Status')
        
        # Add token labels on bars
        for i, (word, label) in enumerate(zip(corrupted, labels)):
            axes[ex_idx].text(i, 0.5, word, ha='center', va='center', 
                             fontweight='bold', color='white' if label == 1 else 'black')
        
        # Add legend for first plot
        if ex_idx == 0:
            axes[ex_idx].bar([], [], color='red', label='Replaced Token')
            axes[ex_idx].bar([], [], color='lightblue', label='Original Token')
            axes[ex_idx].legend()
    
    plt.tight_layout()
    plt.show()
    
    # RTD Task Analysis
    print("\n" + "="*70)
    print("RTD TASK CHARACTERISTICS:")
    print("\n1. Binary Classification: Each token is either ORIGINAL or REPLACED")
    print("2. Contextual Understanding: Model must use context to detect anomalies")
    print("3. All Positions Matter: Every token contributes to the loss")
    print("4. Generator Quality: A better generator makes the task harder, which helps up to a point")
    
    print("\nWHY RTD WORKS BETTER THAN MLM:")
    print("+ Dense learning signal (100% vs 15%)")
    print("+ Contextual reasoning required")
    print("+ Plausible, hard negatives from the jointly trained generator")
    print("+ Encourages better representation learning")
    
demonstrate_rtd_task()
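
Because at most about 15% of positions are replaced, the RTD labels are heavily imbalanced, and a discriminator that always answers "original" already looks accurate. The sketch below computes the loss and accuracy of such a constant predictor as a reference point; the helper `constant_baseline_bce` and the value of `replaced_rate` are assumptions for illustration.

```python
import numpy as np

def constant_baseline_bce(replaced_rate):
    """Mean BCE (in nats) of always predicting `replaced_rate` as P(replaced)."""
    r = replaced_rate
    return float(-(r * np.log(r) + (1 - r) * np.log(1 - r)))

# Assumed upper bound: every masked position ends up replaced. The true rate is
# lower whenever the generator samples the original token.
replaced_rate = 0.15
baseline_loss = constant_baseline_bce(replaced_rate)
baseline_accuracy = 1 - replaced_rate  # accuracy of always answering "original"

print(f"Constant-predictor RTD loss: {baseline_loss:.3f} nats")
print(f"'Always original' accuracy:  {baseline_accuracy:.0%}")
```

At a 15% replacement rate the constant predictor scores about 0.423 nats and 85% accuracy, so a discriminator has only learned something useful once its loss drops below that baseline.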

## Summary: ELECTRA's Revolutionary Impact

### **Why ELECTRA Ranks #2**

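1. **Efficiency Revolution**: Roughly 4x more compute-efficient (the paper reports matching RoBERTa and XLNet with under 1/4 of their compute)
2. **Accessibility**: Small models outperform same-size BERT and approach much larger models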
3. **Paradigm Shift**: From generative to discriminative pre-training
4. **Practical Impact**: Enabled deployment in resource-constrained environments

### **Core Innovation Comparison**

| Aspect | BERT MLM | ELECTRA RTD |
|--------|----------|-------------|
| **Learning Signal** | 15% of tokens | 100% of tokens |
| **Task** | Generate masked tokens | Detect replaced tokens |
| **Architecture** | Single model | Generator + Discriminator |
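| **Compute for Comparable Quality** | Baseline | Roughly 1/4 (paper's comparison with RoBERTa and XLNet) |
| **Small Model Performance** | Weaker at small sizes | Strong; gains are largest for small models |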

### **Key Insights**

1. **Dense Learning Signal**: Every token contributes to learning
2. **Generator-Discriminator Setup**: A jointly trained generator supplies plausible, challenging replacements (GAN-like, but not adversarial)
3. **Contextual Understanding**: Model must understand context to detect replacements
4. **Computational Efficiency**: More learning per compute unit

**ELECTRA proved that smarter training objectives can dramatically improve efficiency while maintaining quality.**