# Week 7 Lab: Advanced Transformers - T5 and GPT-3

## Learning Objectives
- Understand and implement T5's text-to-text framework
- Explore GPT-3-style few-shot learning
- Analyze emergent abilities at different scales
- Compare different transformer architectures
- Implement efficient inference techniques

## Prerequisites
```bash
pip install transformers torch datasets evaluate accelerate sentencepiece numpy pandas matplotlib seaborn
```

## Part 1: Setup and Imports

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import (
    T5ForConditionalGeneration, T5Tokenizer,
    GPT2LMHeadModel, GPT2Tokenizer,
    AutoModelForCausalLM, AutoTokenizer
)
from datasets import load_dataset
import warnings
warnings.filterwarnings('ignore')

# Set style
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB')

## Part 2: T5 - Text-to-Text Framework

T5 treats every NLP task as a text generation problem. Let's explore this unified approach.

In [None]:
# Load T5 model and tokenizer (using small version for demo)
print("Loading T5-small model...")
t5_model = T5ForConditionalGeneration.from_pretrained('t5-small').to(device)
t5_tokenizer = T5Tokenizer.from_pretrained('t5-small')

# Model info
total_params = sum(p.numel() for p in t5_model.parameters())
print(f"T5-small parameters: {total_params/1e6:.1f}M")
print(f"Vocabulary size: {t5_tokenizer.vocab_size}")

### 2.1 T5 for Multiple Tasks

Demonstrate T5's versatility across different NLP tasks using task prefixes.

In [None]:
def t5_generate(model, tokenizer, prompt, max_length=128):
    """Generate text using T5 model."""
    inputs = tokenizer(prompt, return_tensors='pt', max_length=512, truncation=True).to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=4,  # Beam search without sampling; temperature only applies when do_sample=True
            early_stopping=True
        )
    
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Test different tasks
tasks = [
    ("translate English to French: The weather is beautiful today.",
     "Translation"),
    ("summarize: The transformer architecture has revolutionized natural language processing. "
     "It uses self-attention mechanisms to process sequences in parallel, achieving state-of-the-art "
     "results on many NLP tasks. Unlike RNNs, transformers can capture long-range dependencies efficiently.",
     "Summarization"),
    ("question: What is the capital of France? context: Paris is the capital and largest city of France.",
     "Question Answering"),
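    # T5 was trained on SST-2 with the "sst2 sentence:" prefix; "sentiment:" is not a prefix it has seen
    ("sst2 sentence: This movie was absolutely fantastic! Great acting and storyline.",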
     "Sentiment Analysis")
]

print("T5 Text-to-Text Examples:\n" + "="*50)
for prompt, task_name in tasks:
    result = t5_generate(t5_model, t5_tokenizer, prompt)
    print(f"\n{task_name}:")
    print(f"Input: {prompt[:60]}...")
    print(f"Output: {result}")
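
The task prefixes above are what turn every task into a plain string-to-string problem. As a rough sketch of that framing, the snippet below casts a few tasks into (input text, target text) pairs. The helper `to_text_to_text`, the `PREFIXES` table, and the example targets are hypothetical illustrations, not part of `transformers`; only the prefix strings mirror the ones used in this lab.

```python
# Hypothetical helper: cast different tasks into (input, target) string pairs.
PREFIXES = {
    "translation": "translate English to French: {text}",
    "summarization": "summarize: {text}",
    "qa": "question: {question} context: {context}",
}

def to_text_to_text(task, target, **fields):
    """Return the (input_text, target_text) pair a T5-style model would train on."""
    if task not in PREFIXES:
        raise ValueError(f"Unsupported task: {task}")
    return PREFIXES[task].format(**fields), target

# Made-up examples, for illustration only
pairs = [
    to_text_to_text("translation", "Le temps est beau.", text="The weather is nice."),
    to_text_to_text("qa", "Paris", question="What is the capital of France?",
                    context="Paris is the capital of France."),
]
for source, target in pairs:
    print(f"{source!r} -> {target!r}")
```

Because both sides are plain strings, the same model, loss, and decoding loop can serve all of these tasks.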

### 2.2 T5 Span Corruption Training

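Implement a simplified, word-level version of T5's pre-training objective, span corruption: contiguous spans of the input are replaced by sentinel tokens (`<extra_id_0>`, `<extra_id_1>`, ...) and the model learns to generate the missing spans. The real objective operates on SentencePiece tokens rather than whitespace-separated words.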

In [None]:
def create_span_corruption_example(text, mask_ratio=0.15, mean_span_length=3):
    """Create a simplified, word-level T5-style span corruption example."""
    tokens = text.split()
    n_tokens = len(tokens)
    if n_tokens == 0:
        return "", ""
    # Mask at least one span; int() truncated this to 0 for short inputs
    n_masks = max(1, round(n_tokens * mask_ratio / mean_span_length))

    # Sample spans on the original positions; reject overlapping or adjacent spans
    spans = []
    taken = set()
    for _ in range(n_masks):
        span_length = min(np.random.poisson(mean_span_length - 1) + 1, n_tokens)
        start = np.random.randint(0, n_tokens - span_length + 1)
        end = start + span_length
        if taken.isdisjoint(range(start - 1, end + 1)):
            spans.append((start, end))
            taken.update(range(start, end))

    # Walk the spans left to right so sentinel ids increase with position
    input_tokens, target_tokens, cursor = [], [], 0
    for mask_id, (start, end) in enumerate(sorted(spans)):
        sentinel = f"<extra_id_{mask_id}>"
        input_tokens += tokens[cursor:start] + [sentinel]
        target_tokens += [sentinel] + tokens[start:end]
        cursor = end
    input_tokens += tokens[cursor:]
    # T5 targets close with one final sentinel
    target_tokens.append(f"<extra_id_{len(spans)}>")

    return ' '.join(input_tokens), ' '.join(target_tokens)

# Example
text = "The quick brown fox jumps over the lazy dog in the sunny afternoon"
corrupted, target = create_span_corruption_example(text)

print("T5 Span Corruption Example:")
print(f"Original: {text}")
print(f"Corrupted Input: {corrupted}")
print(f"Target Output: {target}")

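# Visualize the corruption process
from difflib import SequenceMatcher

fig, ax = plt.subplots(figsize=(12, 4))
words = text.split()
positions = np.arange(len(words))

# Mark the original words that were replaced by a sentinel. The alignment is a
# heuristic: it diffs the original words against the corrupted input.
masked = np.zeros(len(words), dtype=bool)
matcher = SequenceMatcher(None, words, corrupted.split(), autojunk=False)
for tag, i1, i2, _, _ in matcher.get_opcodes():
    if tag != 'equal':
        masked[i1:i2] = True
colors = ['salmon' if is_masked else 'lightblue' for is_masked in masked]

ax.bar(positions, [1] * len(words), color=colors, edgecolor='black')
ax.set_ylim([0, 1.5])
ax.set_title('T5 Span Corruption Visualization (masked words in red)')
ax.axis('off')  # Words are drawn inside the bars, so ticks and labels are not needed

for i, word in enumerate(words):
    ax.text(i, 0.5, word, ha='center', va='center', fontsize=10)

plt.tight_layout()
plt.show()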
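
To check that a corrupted input and its target really encode the original sentence, it can help to invert the corruption. The sketch below substitutes each sentinel in the input with the span that follows the same sentinel in the target. The function name `fill_spans` and the hand-written example strings are assumptions made for illustration; they are not produced by the cell above.

```python
import re

SENTINEL = re.compile(r"<extra_id_\d+>")

def fill_spans(corrupted, target):
    """Rebuild the original text by replacing each sentinel with its target span."""
    # Map each sentinel in the target to the text that follows it
    sentinels = SENTINEL.findall(target)
    pieces = SENTINEL.split(target)[1:]
    spans = {s: p.strip() for s, p in zip(sentinels, pieces)}
    # Substitute in the input; a sentinel without a span becomes an empty string
    restored = SENTINEL.sub(lambda m: spans.get(m.group(0), ""), corrupted)
    return " ".join(restored.split())

# Hand-written example pair, for illustration only
corrupted_input = "The quick <extra_id_0> jumps over <extra_id_1> dog"
target_output = "<extra_id_0> brown fox <extra_id_1> the lazy <extra_id_2>"
print(fill_spans(corrupted_input, target_output))
```

A round trip like this is a cheap sanity check when writing your own corruption function.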

## Part 3: GPT-Style Few-Shot Learning

Explore in-context learning with GPT models.

In [None]:
# Load GPT-2 model (as GPT-3 proxy)
print("Loading GPT-2 model...")
gpt_model = GPT2LMHeadModel.from_pretrained('gpt2-medium').to(device)
gpt_tokenizer = GPT2Tokenizer.from_pretrained('gpt2-medium')
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token

total_params = sum(p.numel() for p in gpt_model.parameters())
print(f"GPT-2 Medium parameters: {total_params/1e6:.1f}M")

### 3.1 Zero-Shot, One-Shot, and Few-Shot Learning

In [None]:
def gpt_generate(model, tokenizer, prompt, max_new_tokens=50):
    """Generate text using GPT model."""
    inputs = tokenizer(prompt, return_tensors='pt', padding=True).to(device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.8,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            top_p=0.9
        )
    
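    # Decode only the newly generated tokens. Slicing the decoded string by
    # len(prompt) is fragile because decoding may not reproduce the prompt exactly.
    prompt_length = inputs['input_ids'].shape[1]
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()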

# Sentiment classification examples
task = "Classify the sentiment as positive or negative."

# Zero-shot
zero_shot_prompt = f"""{task}
Text: This restaurant has amazing food and great service!
Sentiment:"""

# One-shot
one_shot_prompt = f"""{task}
Text: The movie was boring and too long.
Sentiment: negative

Text: This restaurant has amazing food and great service!
Sentiment:"""

# Few-shot
few_shot_prompt = f"""{task}
Text: The movie was boring and too long.
Sentiment: negative

Text: I love this product! Works perfectly.
Sentiment: positive

Text: Terrible experience, would not recommend.
Sentiment: negative

Text: This restaurant has amazing food and great service!
Sentiment:"""

print("In-Context Learning Comparison:\n" + "="*50)
print("\nZero-shot:")
print(f"Result: {gpt_generate(gpt_model, gpt_tokenizer, zero_shot_prompt, 5)}")

print("\nOne-shot:")
print(f"Result: {gpt_generate(gpt_model, gpt_tokenizer, one_shot_prompt, 5)}")

print("\nFew-shot:")
print(f"Result: {gpt_generate(gpt_model, gpt_tokenizer, few_shot_prompt, 5)}")
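
The three prompts above differ only in how many demonstrations precede the query, so they can be generated from a single list of examples. The following is a minimal sketch under that assumption; `build_few_shot_prompt` is a hypothetical helper that reproduces the `Text:` / `Sentiment:` layout used in this lab.

```python
def build_few_shot_prompt(task, examples, query):
    """Assemble a k-shot prompt; `examples` is a list of (text, label) pairs."""
    blocks = [f"Text: {text}\nSentiment: {label}" for text, label in examples]
    blocks.append(f"Text: {query}\nSentiment:")
    return task + "\n" + "\n\n".join(blocks)

demos = [
    ("The movie was boring and too long.", "negative"),
    ("I love this product! Works perfectly.", "positive"),
]
query = "This restaurant has amazing food and great service!"
instruction = "Classify the sentiment as positive or negative."

# k = 0 gives the zero-shot prompt, k = 1 the one-shot prompt, and so on
for k in range(len(demos) + 1):
    prompt = build_few_shot_prompt(instruction, demos[:k], query)
    print(f"--- {k}-shot ---\n{prompt}\n")
```

Generating prompts this way makes it easy to vary the number and order of demonstrations when measuring their effect.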

### 3.2 Analyzing Emergent Abilities

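Probe tasks that are often reported to work reliably only at larger model scales. GPT-2 Medium (roughly 355M parameters) will usually fail most of them, which gives a baseline for comparison with larger models.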

In [None]:
# Test different emergent abilities
emergent_tasks = [
    ("Basic arithmetic", "What is 13 + 27? Answer:"),
    ("Logic puzzle", "If all roses are flowers and some flowers fade quickly, can we conclude that some roses fade quickly? Answer:"),
    ("Code generation", "Write a Python function to calculate factorial:\ndef factorial(n):"),
    ("Chain of thought", "Question: A store has 23 apples. They sell 8 and buy 15 more. How many apples do they have?\nLet's think step by step:"),
]

print("Testing Emergent Abilities:\n" + "="*50)
for task_name, prompt in emergent_tasks:
    print(f"\n{task_name}:")
    print(f"Prompt: {prompt}")
    result = gpt_generate(gpt_model, gpt_tokenizer, prompt, max_new_tokens=50)
    print(f"Response: {result}")
    print("-" * 30)
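
Reading sampled responses by eye does not scale to many prompts or models. One simple option, sketched here as an assumption rather than a standard metric, is to extract the first integer from each response and compare it with a reference answer. The helper names and the sample responses are made up for illustration; note that the heuristic breaks on chain-of-thought output, where intermediate numbers appear before the final answer.

```python
import re

def first_number(text):
    """Return the first integer found in a response, or None if there is none."""
    match = re.search(r"-?\d+", text)
    return int(match.group()) if match else None

def exact_match_accuracy(responses, answers):
    """Fraction of responses whose first integer equals the reference answer."""
    hits = sum(first_number(r) == a for r, a in zip(responses, answers))
    return hits / len(answers)

# Made-up model responses, for illustration only
responses = ["40", " The answer is 30 apples.", "I don't know"]
answers = [40, 30, 12]
accuracy = exact_match_accuracy(responses, answers)
print(f"Exact-match accuracy: {accuracy:.2f}")
```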

## Part 4: Model Scaling Analysis

Analyze how model performance changes with scale.

In [None]:
# Simulate scaling laws (using hypothetical data)
# In practice, you would test with actual models of different sizes

# Model sizes (parameters)
model_sizes = np.logspace(7, 12, 20)  # 10M to 1T parameters, so GPT-3 (175B) falls inside the range

# Simulated performance metrics following scaling laws
# Loss = a * N^(-alpha) + L_inf
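alpha = 0.076  # Parameter-scaling exponent reported by Kaplan et al. (2020)
a = 10  # Illustrative scale factor, not a fitted value
L_inf = 1.69  # Illustrative irreducible loss (same value as the constant term in Hoffmann et al., 2022)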

loss = a * model_sizes**(-alpha) + L_inf
perplexity = np.exp(loss)

# Emergent abilities (sigmoid emergence)
def emergence_curve(size, threshold, steepness=4):
    return 100 / (1 + np.exp(-steepness * (np.log10(size) - threshold)))

arithmetic = emergence_curve(model_sizes, 9.5, 3)
reasoning = emergence_curve(model_sizes, 10.5, 4)
coding = emergence_curve(model_sizes, 11, 5)

# Plotting
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Scaling law
axes[0].loglog(model_sizes, loss, 'b-', linewidth=2, label='Loss')
axes[0].axhline(y=L_inf, color='r', linestyle='--', label='Irreducible loss')
axes[0].set_xlabel('Model Parameters')
axes[0].set_ylabel('Loss')
axes[0].set_title('Scaling Laws for Language Models')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Emergent abilities
axes[1].semilogx(model_sizes, arithmetic, 'g-', linewidth=2, label='Arithmetic')
axes[1].semilogx(model_sizes, reasoning, 'b-', linewidth=2, label='Reasoning')
axes[1].semilogx(model_sizes, coding, 'r-', linewidth=2, label='Code Generation')
axes[1].set_xlabel('Model Parameters')
axes[1].set_ylabel('Task Performance (%)')
axes[1].set_title('Emergent Abilities at Scale')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Compute requirements
compute_flops = 6 * model_sizes * 1e12  # Approximate FLOPs for 1T tokens
axes[2].loglog(model_sizes, compute_flops, 'purple', linewidth=2)
axes[2].set_xlabel('Model Parameters')
axes[2].set_ylabel('Training FLOPs')
axes[2].set_title('Compute Requirements')
axes[2].grid(True, alpha=0.3)

# Add annotations for notable models on the arithmetic curve
notable_models = [
    (1.24e8, 'GPT-2'),     # GPT-2 small has roughly 124M parameters
    (1.5e9, 'GPT-2 XL'),
    (1.75e11, 'GPT-3')
]

for size, name in notable_models:
    # Use the same threshold and steepness as the plotted arithmetic curve
    axes[1].annotate(name, xy=(size, emergence_curve(size, 9.5, 3)),
                     xytext=(10, 10), textcoords='offset points',
                     fontsize=8, ha='left', annotation_clip=False)

plt.tight_layout()
plt.show()
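
The two formulas behind the plots can also be evaluated directly for individual models. The sketch below reuses the illustrative constants from the cell above for the loss curve and the common approximation of 6 FLOPs per parameter per training token for compute. The function names are hypothetical, and the 300B-token budget is the training set size reported for GPT-3, applied to both models purely for comparison.

```python
def power_law_loss(n_params, a=10.0, alpha=0.076, l_inf=1.69):
    """Loss = a * N^(-alpha) + L_inf, using the lab's illustrative constants."""
    return a * n_params ** (-alpha) + l_inf

def training_flops(n_params, n_tokens):
    """Approximate training compute for a dense transformer: C = 6 * N * D."""
    return 6 * n_params * n_tokens

for name, n_params in [("GPT-2 XL", 1.5e9), ("GPT-3", 1.75e11)]:
    loss = power_law_loss(n_params)
    flops = training_flops(n_params, 3e11)
    print(f"{name}: simulated loss {loss:.2f}, about {flops:.2e} FLOPs for 300B tokens")
```

The loss values are only as meaningful as the made-up constant `a`; the compute estimate is the part that transfers to real planning.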

## Part 5: Efficient Inference Techniques

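Explore techniques for efficient inference with large models: KV caching, Flash Attention, and quantization. The `EfficientTransformer` class below only marks where KV caching and Flash Attention would apply; it does not implement them.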
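
KV caching avoids re-projecting keys and values for tokens that have already been processed during autoregressive decoding. The toy sketch below uses plain Python lists and a single attention head to show that projecting only the newest token and appending it to a cache gives the same output as recomputing everything. The weight matrices and function names are assumptions chosen for illustration, and the raw token vector stands in for the query to keep the example short.

```python
import math

def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]

def attend(query, keys, values):
    """Single-head scaled dot-product attention for one query vector."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    peak = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

W_K = [[1.0, 0.0], [0.5, 1.0]]  # toy key projection (hypothetical weights)
W_V = [[0.0, 1.0], [1.0, 0.5]]  # toy value projection (hypothetical weights)

def step_no_cache(prefix):
    """Re-project keys and values for every token in the prefix."""
    keys = [matvec(W_K, x) for x in prefix]
    values = [matvec(W_V, x) for x in prefix]
    return attend(prefix[-1], keys, values)

def step_with_cache(x_new, cache):
    """Project only the new token and append it to the cache."""
    cache["k"].append(matvec(W_K, x_new))
    cache["v"].append(matvec(W_V, x_new))
    return attend(x_new, cache["k"], cache["v"])

tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
cache = {"k": [], "v": []}
for t, x in enumerate(tokens):
    full = step_no_cache(tokens[: t + 1])
    cached = step_with_cache(x, cache)
    assert all(math.isclose(a, b) for a, b in zip(full, cached))
print(f"Cached {len(cache['k'])} key/value pairs; outputs match the uncached path.")
```

Without the cache, step `t` performs `t + 1` key and value projections; with it, each step performs exactly one.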

In [None]:
class EfficientTransformer(nn.Module):
    """Demonstration of efficiency techniques."""
    
    def __init__(self, d_model=512, n_heads=8, use_flash_attention=False):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.use_flash = use_flash_attention
        
        self.attention = nn.MultiheadAttention(d_model, n_heads)
        
    def forward(self, x, use_kv_cache=False, past_kv=None):
        """Forward pass with optional KV caching."""
        if use_kv_cache and past_kv is not None:
            # Simulate KV cache usage
            # In practice, this would concatenate past keys/values
            print("Using KV cache for faster inference")
            
        if self.use_flash:
            print("Using Flash Attention for memory efficiency")
            
        # Standard attention (simplified)
        attn_output, _ = self.attention(x, x, x)
        return attn_output

# Demonstrate quantization impact
def simulate_quantization(model_size_gb, bit_width):
    """Calculate model size after quantization."""
    original_bits = 32  # FP32
    compression_ratio = original_bits / bit_width
    quantized_size = model_size_gb / compression_ratio
    return quantized_size, compression_ratio

# Model sizes in GB (FP32)
models = {
    'GPT-3 175B': 700,  # ~700GB in FP32
    'T5-11B': 44,
    'GPT-2 1.5B': 6
}

quantization_levels = [32, 16, 8, 4]  # Bits

print("Quantization Impact on Model Size:\n" + "="*50)
results = []

for model_name, size_gb in models.items():
    print(f"\n{model_name} (Original: {size_gb} GB):")
    for bits in quantization_levels:
        q_size, ratio = simulate_quantization(size_gb, bits)
        print(f"  {bits}-bit: {q_size:.1f} GB (compression: {ratio:.1f}x)")
        results.append({
            'Model': model_name,
            'Bits': bits,
            'Size_GB': q_size,
            'Compression': ratio
        })

# Visualize quantization impact
df_quant = pd.DataFrame(results)
fig, ax = plt.subplots(figsize=(10, 6))

for model in models.keys():
    model_data = df_quant[df_quant['Model'] == model]
    ax.plot(model_data['Bits'], model_data['Size_GB'], 
            marker='o', linewidth=2, label=model)

ax.set_xlabel('Quantization Bits')
ax.set_ylabel('Model Size (GB)')
ax.set_title('Impact of Quantization on Model Size')
ax.legend()
ax.grid(True, alpha=0.3)
ax.set_xticks(quantization_levels)
plt.show()
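
The size calculation above says nothing about the accuracy cost of using fewer bits. As a minimal sketch of one common scheme, symmetric "absmax" quantization, the snippet below maps a list of floats to signed integers with a single shared scale and measures the round-trip error. The function names and the example weights are made up for illustration; production libraries typically quantize per channel or per block and handle outliers separately.

```python
def quantize_absmax(weights, bits=8):
    """Symmetric quantization: map floats to signed integers with one shared scale."""
    qmax = 2 ** (bits - 1) - 1  # 127 for 8 bits, 7 for 4 bits
    scale = max(abs(w) for w in weights) / qmax
    if scale == 0:
        scale = 1.0  # all-zero input; any scale works
    quantized = [round(w / scale) for w in weights]
    return quantized, scale

def dequantize(quantized, scale):
    """Map the integers back to approximate floats."""
    return [q * scale for q in quantized]

weights = [0.12, -0.5, 0.33, 1.27, -0.9]  # made-up example weights
for bits in (8, 4):
    q, scale = quantize_absmax(weights, bits)
    restored = dequantize(q, scale)
    max_err = max(abs(w - r) for w, r in zip(weights, restored))
    print(f"{bits}-bit: ints={q}, max abs error={max_err:.4f}")
```

The worst-case error per weight is half the scale, so halving the bit width roughly doubles the step size many times over: 4-bit has 16 levels where 8-bit has 256.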

## Part 6: Comparing Architectures

Compare encoder-only, decoder-only, and encoder-decoder architectures.

In [None]:
# Architecture comparison
architectures = {
    'Encoder-only (BERT)': {
        'strengths': ['Bidirectional context', 'Great for classification', 'Understanding tasks'],
        'weaknesses': ['Cannot generate text naturally', 'Requires task-specific heads'],
        'use_cases': ['Sentiment analysis', 'NER', 'Question answering']
    },
    'Decoder-only (GPT)': {
        'strengths': ['Natural text generation', 'Few-shot learning', 'Scales well'],
        'weaknesses': ['Unidirectional context', 'Less efficient for classification'],
        'use_cases': ['Text generation', 'Code completion', 'Creative writing']
    },
    'Encoder-Decoder (T5)': {
        'strengths': ['Flexible input/output', 'Good for seq2seq', 'Unified framework'],
        'weaknesses': ['More parameters needed', 'Complex architecture'],
        'use_cases': ['Translation', 'Summarization', 'Question generation']
    }
}

# Create comparison visualization
fig, axes = plt.subplots(1, 3, figsize=(15, 8))
colors = ['#FF6B6B', '#4ECDC4', '#95E77E']

for idx, (arch_name, arch_info) in enumerate(architectures.items()):
    ax = axes[idx]
    
    # Create simple architecture diagram
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.axis('off')
    
    # Title
    ax.text(5, 9, arch_name, ha='center', fontsize=12, fontweight='bold')
    
    # Strengths
    ax.text(5, 7.5, 'Strengths:', ha='center', fontsize=10, fontweight='bold')
    for i, strength in enumerate(arch_info['strengths']):
        ax.text(5, 7 - i*0.5, f'• {strength}', ha='center', fontsize=9, color='green')
    
    # Weaknesses
    ax.text(5, 5, 'Weaknesses:', ha='center', fontsize=10, fontweight='bold')
    for i, weakness in enumerate(arch_info['weaknesses']):
        ax.text(5, 4.5 - i*0.5, f'• {weakness}', ha='center', fontsize=9, color='red')
    
    # Use cases
    ax.text(5, 2.5, 'Best for:', ha='center', fontsize=10, fontweight='bold')
    for i, use_case in enumerate(arch_info['use_cases']):
        ax.text(5, 2 - i*0.5, f'• {use_case}', ha='center', fontsize=9, color='blue')
    
    # Add colored border
    rect = plt.Rectangle((0.5, 0.5), 9, 9, 
                         fill=False, edgecolor=colors[idx], linewidth=3)
    ax.add_patch(rect)

plt.suptitle('Transformer Architecture Comparison', fontsize=14, fontweight='bold', y=0.98)
plt.tight_layout()
plt.show()
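
The "bidirectional" versus "unidirectional" distinction in the comparison comes down to the self-attention mask. The sketch below builds both masks as nested lists, where 1 means "may attend" and 0 means "blocked". The function names are hypothetical. In an encoder-decoder model such as T5, the encoder uses the full mask and the decoder uses the causal mask for self-attention, while also cross-attending to all encoder outputs.

```python
def bidirectional_mask(n):
    """Encoder-style mask: every position may attend to every position."""
    return [[1] * n for _ in range(n)]

def causal_mask(n):
    """Decoder-style mask: position i may attend only to positions j <= i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

def show(name, mask):
    """Print a mask with one row per query position."""
    print(name)
    for row in mask:
        print("  " + " ".join(str(v) for v in row))

show("Encoder self-attention (BERT, T5 encoder):", bidirectional_mask(4))
show("Decoder self-attention (GPT, T5 decoder):", causal_mask(4))
```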

## Part 7: Implementing a Simple MoE Layer

Build a basic Mixture of Experts layer to understand sparse models.

In [None]:
class SimpleMoE(nn.Module):
    """Simple Mixture of Experts layer."""
    
    def __init__(self, d_model=512, n_experts=4, expert_capacity=2):
        super().__init__()
        self.d_model = d_model
        self.n_experts = n_experts
        self.expert_capacity = expert_capacity  # Stored for reference; this simplified layer does not enforce it
        
        # Router network
        self.router = nn.Linear(d_model, n_experts)
        
        # Expert networks (simple FFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_model * 4),
                nn.ReLU(),
                nn.Linear(d_model * 4, d_model)
            ) for _ in range(n_experts)
        ])
        
    def forward(self, x):
        batch_size, seq_len, d_model = x.shape
        
        # Flatten for routing
        x_flat = x.view(-1, d_model)
        
        # Compute router probabilities
        router_logits = self.router(x_flat)
        router_probs = torch.softmax(router_logits, dim=-1)
        
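        # Select the top-1 expert per token and keep its gate probability
        gate_probs, selected_experts = torch.topk(router_probs, k=1, dim=-1)
        selected_experts = selected_experts.squeeze(-1)
        gate_probs = gate_probs.squeeze(-1)
        
        # Route tokens to experts
        output = torch.zeros_like(x_flat)
        for expert_id in range(self.n_experts):
            # Get tokens for this expert
            expert_mask = (selected_experts == expert_id)
            if expert_mask.any():
                expert_input = x_flat[expert_mask]
                expert_output = self.experts[expert_id](expert_input)
                # Scale by the gate probability so the router receives gradients
                output[expert_mask] = expert_output * gate_probs[expert_mask].unsqueeze(-1)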
        
        # Reshape back
        output = output.view(batch_size, seq_len, d_model)
        
        # Return output and routing info for visualization
        return output, router_probs.view(batch_size, seq_len, -1)

# Test MoE layer
moe = SimpleMoE(d_model=256, n_experts=4)
test_input = torch.randn(2, 10, 256)  # batch=2, seq=10, d_model=256

output, routing = moe(test_input)

print(f"MoE Layer Test:")
print(f"Input shape: {test_input.shape}")
print(f"Output shape: {output.shape}")
print(f"Routing shape: {routing.shape}")

# Visualize routing decisions
fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Routing probabilities heatmap
routing_viz = routing[0].detach().cpu().numpy()  # First batch
im1 = axes[0].imshow(routing_viz.T, aspect='auto', cmap='YlOrRd')
axes[0].set_xlabel('Token Position')
axes[0].set_ylabel('Expert ID')
axes[0].set_title('Expert Routing Probabilities')
plt.colorbar(im1, ax=axes[0])

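# Expert load distribution: fraction of tokens whose top-1 expert is each expert
# (the mean router probability alone does not show where tokens are actually sent)
top1_expert = routing.argmax(dim=-1).flatten()
expert_loads = torch.bincount(top1_expert, minlength=routing.shape[-1]).float() / top1_expert.numel()
expert_loads = expert_loads.cpu().numpy()
axes[1].bar(range(len(expert_loads)), expert_loads, color=['#FF6B6B', '#4ECDC4', '#95E77E', '#FFE66D'])
axes[1].set_xlabel('Expert ID')
axes[1].set_ylabel('Fraction of Tokens Routed')
axes[1].set_title('Expert Load Balancing')
axes[1].set_xticks(range(len(expert_loads)))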

plt.tight_layout()
plt.show()

print(f"\nExpert utilization:")
for i, load in enumerate(expert_loads):
    print(f"  Expert {i}: {load:.2%}")
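
`SimpleMoE` accepts an `expert_capacity` argument but never uses it. In capacity-based routing, each expert processes at most a fixed number of tokens per batch, and tokens that overflow are skipped by the expert layer (in Switch Transformer-style models they pass through via the residual connection). The sketch below shows only that bookkeeping step; `apply_capacity` and the assignment list are hypothetical and independent of the class above.

```python
def apply_capacity(assignments, n_experts, capacity):
    """Keep at most `capacity` tokens per expert, in order; collect the overflow."""
    kept = {expert: [] for expert in range(n_experts)}
    dropped = []
    for token_idx, expert in enumerate(assignments):
        if len(kept[expert]) < capacity:
            kept[expert].append(token_idx)
        else:
            dropped.append(token_idx)
    return kept, dropped

# Made-up top-1 expert assignments for 8 tokens across 4 experts
assignments = [0, 0, 1, 0, 2, 0, 3, 1]
kept, dropped = apply_capacity(assignments, n_experts=4, capacity=2)
print("Tokens kept per expert:", kept)
print("Dropped tokens:", dropped)
```

Expert 0 is assigned four tokens but keeps only the first two, which is why an unbalanced router wastes capacity on some experts and drops tokens on others.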

## Part 8: Practical Exercise - Building a Multi-Task Model

Combine what we've learned to build a simple multi-task prompt handler on top of the models loaded earlier.

In [None]:
# Exercise: Create a multi-task prompt handler
class MultiTaskPromptHandler:
    """Handle multiple NLP tasks with appropriate prompting."""
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        
    def process_task(self, task_type, input_text, few_shot_examples=None, target_lang='French'):
        """Build a prompt for the given task type.

        For 'qa', input_text must hold the context and the question
        separated by a newline.
        """
        
        # Task-specific prompt templates
        templates = {
            'classification': 'Classify the following text:\n{input}\nCategory:',
            'generation': 'Continue the following text:\n{input}\n',
            'qa': 'Answer the question based on the context.\nContext: {context}\nQuestion: {question}\nAnswer:',
            'summarization': 'Summarize the following text:\n{input}\nSummary:',
            'translation': 'Translate the following to {target_lang}:\n{input}\nTranslation:'
        }
        
        if task_type not in templates:
            raise ValueError(f"Unknown task type: {task_type}")
        
        # Build prompt
        if task_type == 'qa':
            # Special handling for QA
            context, question = input_text.split('\n', 1)
            prompt = templates[task_type].format(context=context, question=question)
        else:
            # str.format ignores unused keyword arguments, so target_lang is
            # only consumed by the translation template
            prompt = templates[task_type].format(input=input_text, target_lang=target_lang)
        
        # Add few-shot examples if provided
        if few_shot_examples:
            examples_str = '\n\n'.join(few_shot_examples)
            prompt = f"{examples_str}\n\n{prompt}"
        
        return prompt

# Test the multi-task handler
handler = MultiTaskPromptHandler(gpt_model, gpt_tokenizer)

# Test different tasks
test_cases = [
    ('classification', 'The new smartphone has an incredible camera and battery life.'),
    ('summarization', 'Artificial intelligence has made remarkable progress in recent years. '
                      'Deep learning models now excel at tasks like image recognition, '
                      'natural language processing, and game playing. These advances '
                      'have led to practical applications in healthcare, finance, and transportation.'),
    ('generation', 'Once upon a time in a distant galaxy,')
]

print("Multi-Task Model Testing:\n" + "="*50)
for task_type, input_text in test_cases:
    print(f"\nTask: {task_type}")
    prompt = handler.process_task(task_type, input_text)
    print(f"Prompt: {prompt[:100]}...")
    
    # Generate response with GPT-2. The templates above are GPT-style prompts,
    # not T5 task prefixes, and no task type contains 't5'.
    response = gpt_generate(gpt_model, gpt_tokenizer, prompt, max_new_tokens=30)
    
    print(f"Response: {response}")
    print("-" * 40)

## Exercises

1. **Scaling Analysis**: 
   - Load models of different sizes (GPT-2 small, medium, large) and compare their performance on the same tasks
   - Plot the relationship between model size and task performance

2. **Few-Shot Learning**:
   - Create a custom few-shot learning task with your own examples
   - Test how the number of examples (0, 1, 5, 10) affects performance

3. **T5 Fine-tuning**:
   - Fine-tune T5-small on a specific task using the datasets library
   - Compare zero-shot vs fine-tuned performance

4. **MoE Implementation**:
   - Extend the SimpleMoE class to support top-k routing (k>1)
   - Implement load balancing loss to ensure even expert utilization

5. **Efficiency Optimization**:
   - Implement KV caching for faster autoregressive generation
   - Measure the speed improvement with and without caching

## Summary

In this lab, we explored:
- T5's unified text-to-text framework and span corruption training
- GPT-style few-shot learning and emergent abilities
- Scaling laws and their implications for model performance
- Efficient inference techniques including quantization and MoE
- Practical implementation of advanced transformer concepts

Key takeaways:
1. Scale brings emergence - larger models show qualitatively different behaviors
2. Different architectures excel at different tasks
3. Efficiency techniques are crucial for deploying large models
4. Few-shot learning reduces the need for task-specific fine-tuning
5. The field is rapidly evolving with new techniques and models