# Solutions: Task 8.6 - GPT Text Generation

This notebook contains solutions to the exercises from notebook 06.

---

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import time

try:
    from transformers import GPT2LMHeadModel, GPT2Tokenizer
    HAS_TRANSFORMERS = True
except ImportError:
    HAS_TRANSFORMERS = False
    print("Please install transformers: pip install transformers")

# Pick the compute device once for the whole notebook: use the GPU when CUDA
# is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Device: {device}")

## Exercise 1: Implement Nucleus Sampling (Top-p)

**Task:** Implement top-p (nucleus) sampling from scratch.

In [None]:
def top_p_sampling(
    model,
    tokenizer,
    prompt,
    max_length=50,
    p=0.9,
    temperature=1.0,
    device='cpu'
):
    """
    Generate text using nucleus (top-p) sampling.

    At every step the next-token distribution is sorted by probability and
    truncated to the smallest prefix whose cumulative probability reaches
    ``p``. The prefix is re-normalized and one token is sampled from it.
    Generation stops after ``max_length`` new tokens or as soon as the
    end-of-sequence token is sampled (the EOS token itself is not kept).

    Args:
        model: Language model; called with a (1, seq_len) tensor of token ids
            and expected to return an object with a ``logits`` attribute.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and
            ``eos_token_id``.
        prompt: Starting text
        max_length: Maximum number of new tokens to generate
        p: Cumulative probability threshold (0.9 = top 90%). If the
            cumulative sum never reaches ``p`` the full vocabulary is used.
        temperature: Sampling temperature; must be strictly positive.
        device: Device to use. The model is moved to this device.

    Returns:
        The decoded prompt followed by the generated tokens.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    # Dividing logits by 0 yields inf/NaN probabilities (and a negative value
    # would invert the ranking), so reject it up front with a clear message
    # instead of failing later inside torch.multinomial.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    model.eval()
    model.to(device)

    # Encode prompt; `generated` holds prompt ids plus everything sampled so far.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated = input_ids[0].tolist()

    eos_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        for _ in range(max_length):
            # The whole sequence is re-fed each step (no key/value cache).
            context = torch.tensor([generated], device=device)
            logits = model(context).logits[0, -1, :]  # Last token's logits

            # Temperature-scaled distribution over the vocabulary
            probs = F.softmax(logits / temperature, dim=-1)

            # Sort by probability (descending) and accumulate the mass
            sorted_probs, sorted_indices = torch.sort(probs, descending=True)
            cumulative = torch.cumsum(sorted_probs, dim=-1)

            # Nucleus = every token up to and including the first one at
            # which the cumulative mass reaches p.
            crossing = (cumulative >= p).nonzero()
            if crossing.numel() > 0:
                nucleus_size = crossing[0].item() + 1
            else:
                nucleus_size = sorted_probs.numel()

            nucleus_probs = sorted_probs[:nucleus_size]
            nucleus_indices = sorted_indices[:nucleus_size]

            # Re-normalize so the kept probabilities sum to 1, then sample
            nucleus_probs = nucleus_probs / nucleus_probs.sum()
            choice = torch.multinomial(nucleus_probs, num_samples=1).item()
            next_token = nucleus_indices[choice].item()

            # Stop at EOS without appending it
            if next_token == eos_token_id:
                break

            generated.append(next_token)

    return tokenizer.decode(generated)

# Test: load the pretrained GPT-2 checkpoint and print one nucleus-sampled
# continuation. Runs only when the `transformers` import succeeded.
if HAS_TRANSFORMERS:
    print("Loading GPT-2...")
    # `model` and `tokenizer` are bound at module level; the later test cells
    # in this notebook reuse them instead of loading the checkpoint again.
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    
    prompt = "The future of artificial intelligence"
    
    print(f"\nPrompt: {prompt}")
    print("\nGenerated (top-p=0.9):")
    # No random seed is set in this cell, so the sampled text can differ
    # between runs.
    text = top_p_sampling(model, tokenizer, prompt, max_length=50, p=0.9, device=device)
    print(text)

## Exercise 2: Implement Beam Search

**Task:** Implement beam search decoding from scratch.

In [None]:
def beam_search(
    model,
    tokenizer,
    prompt,
    max_length=50,
    beam_width=5,
    length_penalty=1.0,
    device='cpu'
):
    """
    Generate text using beam search.

    Maintains up to ``beam_width`` partial hypotheses (beams), extends each
    with its most likely next tokens, and keeps the best-scoring extensions.
    A hypothesis that emits the end-of-sequence token is moved to the
    completed pool. The hypothesis with the highest length-normalized score
    is returned.

    The final score is ``sum(log_probs) / len(sequence) ** length_penalty``,
    where ``len(sequence)`` counts the prompt tokens as well as the
    generated ones.

    Args:
        model: Language model; called with a (1, seq_len) tensor of token ids
            and expected to return an object with a ``logits`` attribute.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and
            ``eos_token_id``.
        prompt: Starting text
        max_length: Maximum number of new tokens to generate
        beam_width: Number of beams to maintain; must be at least 1. Values
            larger than the vocabulary size are clamped when expanding a beam.
        length_penalty: Exponent of the length normalization
            (>1 encourages longer sequences)
        device: Device to use. The model is moved to this device.

    Returns:
        Tuple ``(text, score)``: the decoded best sequence (including the
        EOS token if it was generated) and its length-normalized score.

    Raises:
        ValueError: If ``beam_width`` is smaller than 1.
    """
    # With no beams there is nothing to expand and nothing to return.
    if beam_width < 1:
        raise ValueError(f"beam_width must be >= 1, got {beam_width}")

    model.eval()
    model.to(device)

    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    eos_token_id = tokenizer.eos_token_id

    def _normalized(seq, score):
        # Length normalization keeps longer hypotheses comparable to short ones.
        return score / (len(seq) ** length_penalty)

    # Each beam: (sequence as list, cumulative log probability)
    beams = [(input_ids[0].tolist(), 0.0)]
    completed = []

    with torch.no_grad():
        for _ in range(max_length):
            candidates = []

            for seq, score in beams:
                logits = model(torch.tensor([seq], device=device)).logits[0, -1, :]
                log_probs = F.log_softmax(logits, dim=-1)

                # topk raises if asked for more entries than the vocabulary has.
                top_k = min(beam_width, log_probs.numel())
                top_log_probs, top_indices = log_probs.topk(top_k)

                for log_prob, token_id in zip(top_log_probs.tolist(), top_indices.tolist()):
                    new_seq = seq + [token_id]
                    new_score = score + log_prob

                    if token_id == eos_token_id:
                        # Finished hypotheses leave the beam and are scored now.
                        completed.append((new_seq, _normalized(new_seq, new_score)))
                    else:
                        candidates.append((new_seq, new_score))

            # Keep top beam_width candidates (ranked by raw cumulative log prob)
            candidates.sort(key=lambda item: item[1], reverse=True)
            beams = candidates[:beam_width]

            # Every extension ended in EOS: nothing left to expand.
            if not beams:
                break

    # Unfinished beams compete with the completed ones on normalized score.
    completed.extend((seq, _normalized(seq, score)) for seq, score in beams)

    # `completed` is never empty here: every expansion step adds each new
    # hypothesis either to `completed` or to the surviving beams.
    best_seq, best_score = max(completed, key=lambda item: item[1])
    return tokenizer.decode(best_seq), best_score

# Test: decode one prompt with the from-scratch beam search and show its score.
# Uses the module-level `model`, `tokenizer` and `device` set up by earlier cells.
if HAS_TRANSFORMERS:
    prompt = "The best way to learn programming is"
    
    print(f"Prompt: {prompt}")
    print("\nBeam search (width=5):")
    # beam_search returns the decoded text together with its
    # length-normalized cumulative log-probability.
    text, score = beam_search(model, tokenizer, prompt, max_length=30, beam_width=5, device=device)
    print(f"Score: {score:.4f}")
    print(text)

## Exercise 3: Compare Decoding Strategies

**Task:** Compare different decoding strategies on the same prompt.

In [None]:
def compare_decoding_strategies(model, tokenizer, prompt, device='cpu'):
    """
    Compare different text generation strategies.

    Runs ``model.generate`` on the same prompt with greedy decoding, top-k
    sampling, top-p sampling, beam search and temperature sampling. Prints
    each continuation with its wall-clock time, then prints a summary table.

    Args:
        model: Model exposing a Hugging Face style ``generate`` method.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and
            ``eos_token_id``.
        prompt: Starting text shared by all strategies.
        device: Device to use. The model is moved to this device.

    Returns:
        A list with one dict per strategy, in execution order, with keys
        ``'name'``, ``'text'`` (decoded output), ``'time'`` (seconds spent
        inside ``generate``) and ``'tokens'`` (length of the returned
        sequence, which for decoder-only models such as GPT-2 includes the
        prompt tokens).
    """
    print(f"Prompt: {prompt}")
    print("=" * 70)

    # Keep model and inputs on the same device, as the other helpers in this
    # notebook do.
    model.eval()
    model.to(device)

    # Encode once: the prompt is identical for every strategy, and tokenizing
    # inside the timed region would distort the comparison.
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    # Arguments shared by every strategy. An explicit attention mask and pad
    # token avoid the warnings `generate` emits for GPT-2, which has no pad
    # token of its own.
    common_kwargs = {
        'max_new_tokens': 40,
        'attention_mask': torch.ones_like(input_ids),
        'pad_token_id': tokenizer.eos_token_id,
    }

    strategies = [
        ("Greedy", {'do_sample': False}),
        ("Top-k (k=50)", {'do_sample': True, 'top_k': 50, 'temperature': 1.0}),
        ("Top-p (p=0.9)", {'do_sample': True, 'top_p': 0.9, 'temperature': 1.0}),
        ("Beam (width=5)", {'num_beams': 5, 'early_stopping': True}),
        ("Temperature=0.7", {'do_sample': True, 'temperature': 0.7}),
    ]

    results = []

    for name, strategy_kwargs in strategies:
        torch.manual_seed(42)  # Same seed for every strategy, for reproducibility

        # perf_counter is monotonic and meant for measuring intervals.
        start = time.perf_counter()
        with torch.no_grad():
            output_ids = model.generate(input_ids, **common_kwargs, **strategy_kwargs)
        elapsed = time.perf_counter() - start

        text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        results.append({
            'name': name,
            'text': text,
            'time': elapsed,
            'tokens': len(output_ids[0])
        })

        print(f"\n{name} ({elapsed:.2f}s):")
        print(f"  {text}")

    # Summary
    print("\n" + "=" * 70)
    print("\nSummary:")
    print(f"{'Strategy':<20} {'Time':<10} {'Tokens':<10}")
    print("-" * 40)
    for r in results:
        print(f"{r['name']:<20} {r['time']:.3f}s    {r['tokens']}")

    return results

# Run the comparison on one prompt. `results` keeps the list of per-strategy
# records returned by compare_decoding_strategies. Uses the module-level
# `model`, `tokenizer` and `device` set up by earlier cells.
if HAS_TRANSFORMERS:
    results = compare_decoding_strategies(
        model, tokenizer,
        "In the year 2050, humans will",
        device=device
    )

## Challenge: Implement Contrastive Search

**Task:** Implement contrastive search which balances likelihood with diversity.

In [None]:
def contrastive_search(
    model,
    tokenizer,
    prompt,
    max_length=50,
    k=5,
    alpha=0.6,
    device='cpu'
):
    """
    Generate text using contrastive search.

    At each step, selects among the top-k candidates the token that maximizes:
    (1 - alpha) * log_prob(token) - alpha * max_sim(token, prev_tokens)

    ``max_sim`` is the largest cosine similarity between the candidate's
    last-layer hidden state (obtained by appending the candidate to the
    current sequence and running the model) and the last-layer hidden states
    of all tokens already in the sequence, prompt included. A candidate whose
    representation is close to an earlier token is penalized, which
    discourages repetition. With ``alpha=0`` this reduces to greedy decoding.

    Note: the published contrastive-search objective uses the token
    probability as the likelihood term; this function keeps the
    log-probability form stated above.

    Each step runs one forward pass over a batch of k candidate sequences.
    That pass also supplies the logits and hidden states for the next step.

    Args:
        model: Language model that accepts ``output_hidden_states=True`` and
            returns an object with ``logits`` and ``hidden_states``.
        tokenizer: Tokenizer providing ``encode``, ``decode`` and
            ``eos_token_id``.
        prompt: Starting text
        max_length: Maximum number of new tokens to generate
        k: Number of top candidates to consider; must be at least 1. Values
            larger than the vocabulary size are clamped.
        alpha: Balance between likelihood (0) and diversity (1)
        device: Device to use. The model is moved to this device.

    Returns:
        The decoded prompt followed by the generated tokens. Generation stops
        early, without keeping the token, when EOS is selected.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")

    model.eval()
    model.to(device)

    # Encode prompt
    input_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)
    generated = input_ids[0].tolist()

    eos_token_id = tokenizer.eos_token_id

    if max_length <= 0:
        return tokenizer.decode(generated)

    with torch.no_grad():
        # Initial pass over the prompt: logits for the first new token and the
        # hidden state of every prompt token.
        outputs = model(
            torch.tensor([generated], device=device),
            output_hidden_states=True
        )
        next_logits = outputs.logits[0, -1, :]
        context_hidden = outputs.hidden_states[-1][0]  # (context_len, hidden_dim)

        for _ in range(max_length):
            # Top-k candidates by model likelihood
            log_probs = F.log_softmax(next_logits, dim=-1)
            top_log_probs, top_indices = log_probs.topk(min(k, log_probs.numel()))
            candidate_ids = top_indices.tolist()

            # One batched pass: row i is the current sequence + candidate i.
            # The hidden state at the last position is that candidate's own
            # representation in context.
            candidate_batch = torch.tensor(
                [generated + [token_id] for token_id in candidate_ids],
                device=device
            )
            outputs = model(candidate_batch, output_hidden_states=True)
            last_layer = outputs.hidden_states[-1]  # (k, context_len + 1, hidden_dim)
            candidate_hidden = last_layer[:, -1, :]  # (k, hidden_dim)

            # Degeneration penalty: similarity of each candidate to its most
            # similar previous token. Broadcasting gives a (k, context_len) matrix.
            similarities = F.cosine_similarity(
                candidate_hidden.unsqueeze(1),
                context_hidden.unsqueeze(0),
                dim=-1
            )
            max_sim = similarities.max(dim=1).values

            scores = (1 - alpha) * top_log_probs - alpha * max_sim
            best_idx = int(torch.argmax(scores).item())
            next_token = candidate_ids[best_idx]

            if next_token == eos_token_id:
                break

            generated.append(next_token)

            # Reuse the chosen candidate's row for the next step: its last
            # position predicts the following token, and the row holds the
            # hidden states of the extended context.
            next_logits = outputs.logits[best_idx, -1, :]
            context_hidden = last_layer[best_idx]

    return tokenizer.decode(generated)

# Test: run contrastive search on one prompt with two alpha settings so the
# printed continuations can be compared side by side.
# Uses the module-level `model`, `tokenizer` and `device` set up by earlier cells.
if HAS_TRANSFORMERS:
    prompt = "The scientist discovered that"
    
    print(f"Prompt: {prompt}")
    print("\nContrastive search (alpha=0.6):")
    text = contrastive_search(model, tokenizer, prompt, max_length=40, alpha=0.6, device=device)
    print(text)
    
    # Lower alpha puts more weight on the likelihood term of the score.
    print("\nContrastive search (alpha=0.3, more focused):")
    text = contrastive_search(model, tokenizer, prompt, max_length=40, alpha=0.3, device=device)
    print(text)

---

## Key Takeaways

1. **Top-p sampling** dynamically adjusts the candidate set based on probability mass
2. **Beam search** finds high-probability sequences but can be repetitive
3. **Temperature** controls randomness - lower is more deterministic
4. **Contrastive search** explicitly encourages diversity to reduce repetition
5. Best strategy depends on use case: creative writing vs factual generation

---