# Part 4: Model Inference Details 🔬

## 📚 Learning Goals:
1. Understand `calc_surprisal_hf.py` **line-by-line**
2. See how **batching** works (processing multiple pieces at once)
3. Understand **padding** and **masking**
4. See how **target spans** extract the right surprisals
5. Understand **PPL** (Perplexity) calculation
6. Run actual code on our data!

---

## 🎯 The Big Picture:

We've learned:
- Part 1: What files exist (INPUT → OUTPUT)
- Part 2: Why 31 JSONs (context modification)
- Part 3: What is surprisal (mathematical definition)

**Now**: How does `calc_surprisal_hf.py` actually compute ALL those surprisals?

### The Script's Job:
```
INPUT:  ngram_2-contextfunc_delete.json
        (contains ~212,000 context-target pairs)
        
OUTPUT: scores.json
        (contains ~212,000 surprisal values)
        + eval.txt (perplexity)
```

**Challenge**: Can't process all 212K pieces one-by-one (too slow!)

**Solution**: **BATCHING** - process many pieces at once!

In [11]:
import torch
import json
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.nn import CrossEntropyLoss

print("✅ Imports loaded")

✅ Imports loaded


## 📂 Step 1: Load Data

Let's load one of our JSON files and see what we're working with.

In [12]:
# Load the data
data_path = '../data/DC/ngram_2-contextfunc_delete.json'
with open(data_path) as f:
    article2piece = json.load(f)

print(f"Loaded: {data_path}")
print(f"Number of articles: {len(article2piece)}")
print(f"\nFirst article ID: {list(article2piece.keys())[0]}")
print(f"Number of pieces in first article: {len(article2piece['1'])}")

# Look at first 3 pieces
print("\n" + "="*80)
print("FIRST 3 PIECES FROM ARTICLE 1:")
print("="*80)
for i, (context, target) in enumerate(article2piece['1'][:3]):
    context_clean = context.replace('▁', ' ').strip()
    target_clean = target.replace('▁', ' ').strip()
    print(f"\nPiece {i}:")
    print(f"  Context: '{context_clean}'")
    print(f"  Target:  '{target_clean}'")

Loaded: ../data/DC/ngram_2-contextfunc_delete.json
Number of articles: 20

First article ID: 1
Number of pieces in first article: 2573

FIRST 3 PIECES FROM ARTICLE 1:

Piece 0:
  Context: ''
  Target:  'Are'

Piece 1:
  Context: 'Are'
  Target:  'tourists'

Piece 2:
  Context: 'tourists'
  Target:  'ent iced'


## 🤖 Step 2: Load Model

Same GPT-2 model we used in Part 3.

In [13]:
# Load model
model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

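# Pick the best available device: CUDA GPU, Apple-silicon GPU (MPS), or CPU
if torch.cuda.is_available():
    device = 'cuda'
elif torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'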
model = model.to(device)
model.eval()

# Create loss function (used to compute cross-entropy = surprisal)
loss_fct = CrossEntropyLoss(ignore_index=-100, reduction="none")

print(f"✅ Loaded {model_name} on {device}")
print(f"📊 Parameters: {sum(p.numel() for p in model.parameters()):,}")

✅ Loaded gpt2 on mps
📊 Parameters: 124,439,808


## 🔍 Step 3: Process ONE Piece (Understand the Logic)

Before batching, let's process a single piece to understand what happens.

### The Process:
1. **Tokenize context and target**
2. **Identify target span** (where target tokens are in the sequence)
3. **Run model forward pass**
4. **Extract surprisals** for target tokens only
5. **Sum them** to get word-level surprisal

In [14]:
# Take first piece
context, target = article2piece['1'][0]
context_clean = context.replace('▁', ' ').strip()
target_clean = target.replace('▁', ' ').strip()

print("Processing single piece:")
print(f"Context: '{context_clean}'")
print(f"Target:  '{target_clean}'")
print()

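# Build the model input. For an empty context (first word), use GPT-2's
# <|endoftext|> token as a start-of-text marker. Otherwise, restore the space
# between context and target that .strip() removed: GPT-2's BPE keeps a word's
# leading space on the word itself, so the target is " word", not "word".
if not context_clean:
    context_for_model = "<|endoftext|>"
    target_for_model = target_clean
else:
    context_for_model = context_clean
    target_for_model = " " + target_clean
text = context_for_model + target_for_model

print(f"Full text: '{text}'")
print()

# Tokenize context, target, and full text separately
encoded_context = tokenizer(context_for_model, return_tensors="pt")
encoded_target = tokenizer(target_for_model, return_tensors="pt")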
encoded_text = tokenizer(text, return_tensors="pt")

print("Tokenization:")
print(f"  Context tokens: {tokenizer.convert_ids_to_tokens(encoded_context['input_ids'][0])}")
print(f"  Target tokens:  {tokenizer.convert_ids_to_tokens(encoded_target['input_ids'][0])}")
print(f"  Full tokens:    {tokenizer.convert_ids_to_tokens(encoded_text['input_ids'][0])}")
print()

# Find target span
start = len(encoded_context['input_ids'][0])
target_len = len(encoded_target['input_ids'][0])
target_span = (start, start + target_len - 1)

print(f"Target span: {target_span}")
print(f"  → Tokens at positions {target_span[0]} to {target_span[1]}")
print(f"  → {tokenizer.convert_ids_to_tokens(encoded_text['input_ids'][0][target_span[0]:target_span[1]+1])}")

Processing single piece:
Context: ''
Target:  'Are'

Full text: '<|endoftext|>Are'

Tokenization:
  Context tokens: ['<|endoftext|>']
  Target tokens:  ['Are']
  Full tokens:    ['<|endoftext|>', 'Are']

Target span: (1, 1)
  → Tokens at positions 1 to 1
  → ['Are']


## 🎯 Step 4: Compute Surprisal for This Piece

Now let's run the model and extract surprisal.

### Key Insight: Input/Output Alignment
```
INPUT:  [token₀, token₁, token₂, token₃]  ← What model sees
OUTPUT: [pred₁,  pred₂,  pred₃,  pred₄]   ← What model predicts
GOLD:   [token₁, token₂, token₃, token₄]  ← Actual next tokens
```

Model at position `i` predicts token at position `i+1`.

So we:
- **Input IDs**: All tokens except last
- **Gold IDs**: All tokens except first
- **Loss**: Compares predictions to gold
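The index bookkeeping is the easiest place to slip, so here is a minimal pure-Python sketch of it. It assumes spans are inclusive `(start, end)` positions in the full token sequence, as computed in Step 3; `loss_slice` is a helper name made up for this illustration, not a function from `calc_surprisal_hf.py`.

```python
def loss_slice(span):
    """Map an inclusive (start, end) span over text positions to a slice over
    loss positions. Loss position i scores text token i + 1, so shift by -1."""
    start, end = span
    return slice(start - 1, end)


text_ids = [50256, 8491]   # '<|endoftext|>', 'Are' (IDs from the Step 3/4 output)
input_ids = text_ids[:-1]  # what the model sees
gold_ids = text_ids[1:]    # what it has to predict

target_span = (1, 1)
print(gold_ids[loss_slice(target_span)])            # [8491]
print(text_ids[target_span[0]:target_span[1] + 1])  # [8491]

# A longer toy sequence: the target occupies text positions 2..4
toy_text = [10, 11, 12, 13, 14]
print(toy_text[1:][loss_slice((2, 4))])             # [12, 13, 14]
```

Both routes selecting the same IDs is the invariant that the `target_span[0]-1:target_span[1]` slice in the next cell relies on.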

In [15]:
# Prepare input and gold
text_ids = encoded_text['input_ids'][0]
input_ids = text_ids[:-1]  # All except last
gold_ids = text_ids[1:]    # All except first

print("Input/Output alignment:")
print(f"  Text IDs:  {text_ids.tolist()}")
print(f"  Input IDs: {input_ids.tolist()}  ← Model input")
print(f"  Gold IDs:  {gold_ids.tolist()}   ← Actual next tokens")
print()

# Run model
with torch.no_grad():
    outputs = model(input_ids.unsqueeze(0).to(device))  # Add batch dimension
    logits = outputs.logits  # Shape: (1, seq_len, vocab_size)
    
print(f"Logits shape: {logits.shape}")
print(f"  → (batch_size=1, sequence_length={logits.shape[1]}, vocabulary_size={logits.shape[2]})")
print()

# Compute loss (cross-entropy = surprisal in nats, we'll convert to bits)
loss_per_token = loss_fct(
    logits.transpose(1, 2),  # CrossEntropyLoss expects (batch, vocab, seq)
    gold_ids.unsqueeze(0).to(device)
)

print(f"Loss per token shape: {loss_per_token.shape}")
print(f"Loss values (in nats): {loss_per_token[0].cpu().tolist()}")
print()

# Extract target surprisals
# NOTE: target_span refers to text_ids positions, but loss is for input_ids
# So we need to adjust by -1
target_surprisals_nats = loss_per_token[0, target_span[0]-1:target_span[1]].cpu().tolist()
target_surprisal_total = sum(target_surprisals_nats)

print(f"Target surprisals (nats): {target_surprisals_nats}")
print(f"Total surprisal (nats):   {target_surprisal_total:.3f}")
print(f"Total surprisal (bits):   {target_surprisal_total / np.log(2):.3f}")
print()
print(f"💡 This single number ({target_surprisal_total:.3f} nats) goes into scores.json!")

Input/Output alignment:
  Text IDs:  [50256, 8491]
  Input IDs: [50256]  ← Model input
  Gold IDs:  [8491]   ← Actual next tokens

Logits shape: torch.Size([1, 1, 50257])
  → (batch_size=1, sequence_length=1, vocabulary_size=50257)

Loss per token shape: torch.Size([1, 1])
Loss values (in nats): [7.436365127563477]

Target surprisals (nats): [7.436365127563477]
Total surprisal (nats):   7.436
Total surprisal (bits):   10.728

💡 This single number (7.436 nats) goes into scores.json!


## 🤔 Wait... Why Use a Loss Function Instead of Manual Calculation?

**Great question!** In Part 3, we manually:
1. Got logits
2. Applied softmax → probabilities
3. Extracted probability of actual token
4. Computed `-log2(prob)` → surprisal

**Here we use `CrossEntropyLoss`**. Why?

### They're the SAME thing! Just more efficient.

**CrossEntropyLoss** combines steps 2-4 into ONE operation:
```python
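# logits: (1, seq_len, vocab_size) from the model; gold_ids: (seq_len,) actual next tokens
# (both tensors on the same device)
positions = torch.arange(len(gold_ids))

# What we did in Part 3 (manual):
probs = torch.softmax(logits[0], dim=-1)       # (seq_len, vocab_size)
prob_of_actual = probs[positions, gold_ids]    # probability of each actual token
surprisal = -torch.log(prob_of_actual)         # natural log (nats)

# What CrossEntropyLoss does (automatic): instantiate the module, then call it
loss = CrossEntropyLoss(reduction="none")(logits[0], gold_ids)
# Returns: -log(softmax(logits)[gold_id]) at each position
# This IS surprisal! (in nats)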
```
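The same identity can be checked without a model. A toy numpy sketch: random logits stand in for GPT-2's output, and the vocabulary has 10 entries instead of 50,257. The second route computes `logsumexp(z) - z_gold`, the per-position quantity that `CrossEntropyLoss(reduction="none")` returns; it illustrates the math, not PyTorch's actual kernel.

```python
import numpy as np

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 10))    # 4 positions, toy vocabulary of 10 tokens
gold_ids = np.array([3, 0, 7, 7])    # the actual next token at each position
rows = np.arange(len(gold_ids))

# Part 3 route: softmax, pick the gold probability, take -log
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
manual = -np.log(probs[rows, gold_ids])

# Cross-entropy route: logsumexp(z) - z_gold, computed in log space
m = logits.max(axis=1, keepdims=True)
logsumexp = (m + np.log(np.exp(logits - m).sum(axis=1, keepdims=True)))[:, 0]
cross_entropy = logsumexp - logits[rows, gold_ids]

print(np.allclose(manual, cross_entropy))  # True
```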

### Key Differences:

| Aspect | Part 3 (Manual) | Part 4 (Loss Function) |
|--------|----------------|----------------------|
| Softmax | Explicit | Built-in (more stable) |
| Log base | log₂ (bits) | ln (nats) |
| Speed | Slower | Faster (optimized) |
| Numerical stability | Can have issues | Very stable |
| Use case | Learning | Production |

### Why Nats vs Bits?
- **Nats**: Natural log (ln) - PyTorch default
- **Bits**: Log base 2 (log₂) - information theory standard
- **Conversion**: `bits = nats / ln(2) ≈ nats × 1.443`
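A quick check of the conversion on the Step 4 value for 'Are'. The last line shows why the unit choice does not change perplexity: exponentiating a loss in nats gives the same number as raising 2 to the same loss in bits.

```python
import math

nats = 7.436365127563477      # loss for 'Are' from the Step 4 output
bits = nats / math.log(2)     # divide by ln(2)

print(round(bits, 4))             # 10.7284
print(round(1 / math.log(2), 3))  # 1.443

# Same perplexity in either unit: e^(nats) == 2^(bits)
print(math.isclose(math.exp(nats), 2 ** bits))  # True
```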

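### Numerical Stability
CrossEntropyLoss is more stable because it stays in log space: it computes the log-softmax directly with the "log-sum-exp trick" instead of taking the log of an already-computed probability.
```python
# Naive (exp can overflow to inf; a tiny prob can underflow to 0, making the surprisal inf):
surprisal = -log(exp(logit_i) / sum(exp(logit_j)))

# Stable (what PyTorch does), with m = max(logits):
surprisal = -(logit_i - (m + log(sum(exp(logit_j - m)))))
```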

**Bottom line**: Same math, just optimized for speed and stability!
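To see the difference concretely, a small numpy sketch with deliberately extreme logits (values chosen for illustration only). The naive route overflows and returns `inf`; the log-sum-exp route returns the exact surprisal of 2000 nats.

```python
import numpy as np

logits = np.array([1000.0, 0.0, -1000.0])  # extreme values, chosen to break the naive route
i = 2                                      # index of the actual token

# Naive: exp(1000) overflows to inf and exp(-1000) underflows to 0, so prob = 0/inf = 0
with np.errstate(all="ignore"):
    naive = -np.log(np.exp(logits[i]) / np.exp(logits).sum())

# Stable log-softmax: subtract the max before exponentiating
m = logits.max()
stable = -(logits[i] - (m + np.log(np.exp(logits - m).sum())))

print(naive)   # inf
print(stable)  # 2000.0
```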

In [16]:
# Let's prove they're the same!
print("Comparing Manual vs CrossEntropyLoss:")
print("="*80)

# Take first target token
first_target_pos = target_span[0] - 1  # Adjust for input/output offset
first_target_id = gold_ids[first_target_pos].item()

print(f"Target token ID: {first_target_id}")
print(f"Target token: '{tokenizer.decode([first_target_id])}'")
print()

# Method 1: Manual (like Part 3)
print("METHOD 1: Manual Calculation (Part 3 style)")
logit_for_position = logits[0, first_target_pos]  # Shape: (vocab_size,)
probs = torch.softmax(logit_for_position, dim=-1)
prob_of_actual = probs[first_target_id]
surprisal_manual_nats = -torch.log(prob_of_actual)
surprisal_manual_bits = surprisal_manual_nats / np.log(2)

print(f"  1. Logits shape: {logit_for_position.shape}")
print(f"  2. After softmax → probabilities")
print(f"  3. Probability of actual token: {prob_of_actual:.6f}")
print(f"  4. Surprisal (nats): -log({prob_of_actual:.6f}) = {surprisal_manual_nats:.4f}")
print(f"  5. Surprisal (bits): {surprisal_manual_bits:.4f}")
print()

# Method 2: CrossEntropyLoss
print("METHOD 2: CrossEntropyLoss (Part 4 style)")
surprisal_loss = loss_per_token[0, first_target_pos]
surprisal_loss_bits = surprisal_loss / np.log(2)

print(f"  1. CrossEntropyLoss directly computes: {surprisal_loss:.4f} nats")
print(f"  2. Convert to bits: {surprisal_loss_bits:.4f}")
print()

# Compare
print("COMPARISON:")
print(f"  Manual (nats):  {surprisal_manual_nats:.6f}")
print(f"  Loss (nats):    {surprisal_loss:.6f}")
print(f"  Difference:     {abs(surprisal_manual_nats - surprisal_loss):.10f}")
print()
print(f"  Manual (bits):  {surprisal_manual_bits:.6f}")
print(f"  Loss (bits):    {surprisal_loss_bits:.6f}")
print()
print("✅ They're the SAME! (tiny difference due to floating point precision)")
print("   CrossEntropyLoss just does it faster and more stably!")

Comparing Manual vs CrossEntropyLoss:
Target token ID: 8491
Target token: 'Are'

METHOD 1: Manual Calculation (Part 3 style)
  1. Logits shape: torch.Size([50257])
  2. After softmax ‚Üí probabilities
  3. Probability of actual token: 0.000589
  4. Surprisal (nats): -log(0.000589) = 7.4364
  5. Surprisal (bits): 10.7284

METHOD 2: CrossEntropyLoss (Part 4 style)
  1. CrossEntropyLoss directly computes: 7.4364 nats
  2. Convert to bits: 10.7284

COMPARISON:
  Manual (nats):  7.436365
  Loss (nats):    7.436365
  Difference:     0.0000000000

  Manual (bits):  10.728407
  Loss (bits):    10.728407

✅ They're the SAME! (tiny difference due to floating point precision)
   CrossEntropyLoss just does it faster and more stably!


## 📦 Step 5: Understanding Batching

Processing one piece at a time is **SLOW**. We have 212,000 pieces!

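### The Problem:
- Pieces have different lengths
- A batch is one rectangular tensor, so every sequence in it must have the same length (GPUs are fast precisely because they push the whole batch through the same matrix operations)

### The Solution: Padding
```
Piece 1: [45, 23, 12, 89]            → length 4
Piece 2: [12, 34]                    → length 2
Piece 3: [78, 90, 23, 45, 67, 11]    → length 6

Pad to max length (6):
Piece 1: [45, 23, 12, 89, -100, -100]        ← padded
Piece 2: [12, 34, -100, -100, -100, -100]    ← padded
Piece 3: [78, 90, 23, 45, 67, 11]
```

`-100` is a placeholder, not a real token ID. It matches the `ignore_index=-100` of our `CrossEntropyLoss`, so padded **gold** positions contribute no loss. The model cannot embed `-100`, so padded **input** positions are swapped for a valid ID (0 in the cells below), and an **attention mask** (1 = real token, 0 = padding) tells the model which positions to ignore.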
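A minimal numpy sketch of the two padding conventions on the toy IDs above (after the input/gold shift, so each row is one token shorter). It pads the input with 0 directly instead of padding with -100 and replacing afterwards, as the cells below do; the constant names `PAD_INPUT` and `IGNORE` are ours.

```python
import numpy as np

PAD_INPUT = 0   # any valid token ID works: the attention mask hides it
IGNORE = -100   # CrossEntropyLoss's ignore_index; used only in the gold labels

text_ids_list = [[45, 23, 12, 89], [12, 34], [78, 90, 23, 45, 67, 11]]
inputs = [ids[:-1] for ids in text_ids_list]   # model input: all but the last token
golds = [ids[1:] for ids in text_ids_list]     # gold labels: all but the first token
max_len = max(len(seq) for seq in inputs)


def pad(seqs, value):
    """Right-pad every sequence to max_len with the given value."""
    return np.array([seq + [value] * (max_len - len(seq)) for seq in seqs])


padded_input = pad(inputs, PAD_INPUT)
padded_gold = pad(golds, IGNORE)
attention_mask = np.array([[1] * len(seq) + [0] * (max_len - len(seq)) for seq in inputs])

print(padded_input.tolist())    # [[45, 23, 12, 0, 0], [12, 0, 0, 0, 0], [78, 90, 23, 45, 67]]
print(padded_gold.tolist())     # [[23, 12, 89, -100, -100], [34, -100, -100, -100, -100], [90, 23, 45, 67, 11]]
print(attention_mask.tolist())  # [[1, 1, 1, 0, 0], [1, 0, 0, 0, 0], [1, 1, 1, 1, 1]]
```

The mask and the -100 labels mark exactly the same positions, which is what lets the loss and the model agree on what counts as padding.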

In [17]:
# Demonstrate batching with 3 pieces
pieces = article2piece['1'][:3]
text_ids_list = []
target_spans = []

for context, target in pieces:
    context_clean = context.replace('▁', ' ').strip()
    target_clean = target.replace('▁', ' ').strip()
    
    if context_clean:
        text = context_clean + target_clean
    else:
        text = "<|endoftext|>" + target_clean
    
    encoded_context = tokenizer(context_clean if context_clean else "<|endoftext|>", return_tensors="pt")
    encoded_target = tokenizer(target_clean, return_tensors="pt")
    encoded_text = tokenizer(text, return_tensors="pt")
    
    start = len(encoded_context['input_ids'][0])
    target_len = len(encoded_target['input_ids'][0])
    target_span = (start, start + target_len - 1)
    
    text_ids_list.append(encoded_text['input_ids'][0])
    target_spans.append(target_span)

print("3 pieces (different lengths):")
for i, ids in enumerate(text_ids_list):
    print(f"  Piece {i}: length {len(ids)} → {ids.tolist()}")
print()

# Prepare for batching
input_ids_list = [ids[:-1] for ids in text_ids_list]
gold_ids_list = [ids[1:] for ids in text_ids_list]

# Pad sequences
pad_id = -100
padded_input = torch.nn.utils.rnn.pad_sequence(
    input_ids_list, batch_first=True, padding_value=pad_id
)
padded_gold = torch.nn.utils.rnn.pad_sequence(
    gold_ids_list, batch_first=True, padding_value=pad_id
)

print("After padding:")
print(f"Padded input shape: {padded_input.shape}")
print(f"  Piece 0: {padded_input[0].tolist()}")
print(f"  Piece 1: {padded_input[1].tolist()}")
print(f"  Piece 2: {padded_input[2].tolist()}")
print()

# Create attention mask (1 for real tokens, 0 for padding)
attention_mask = (padded_input > -1).int()
print("Attention mask:")
print(f"  Piece 0: {attention_mask[0].tolist()}")
print(f"  Piece 1: {attention_mask[1].tolist()}")
print(f"  Piece 2: {attention_mask[2].tolist()}")
print()
print("💡 Model will IGNORE positions with mask=0!")

3 pieces (different lengths):
  Piece 0: length 2 ‚Üí [50256, 8491]
  Piece 1: length 4 ‚Üí [32, 1186, 454, 1023]
  Piece 2: length 6 ‚Üí [83, 454, 1023, 298, 220, 3711]

After padding:
Padded input shape: torch.Size([3, 5])
  Piece 0: [50256, -100, -100, -100, -100]
  Piece 1: [32, 1186, 454, -100, -100]
  Piece 2: [83, 454, 1023, 298, 220]

Attention mask:
  Piece 0: [1, 0, 0, 0, 0]
  Piece 1: [1, 1, 1, 0, 0]
  Piece 2: [1, 1, 1, 1, 1]

💡 Model will IGNORE positions with mask=0!
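A caveat that the output above exposes. In this cell the cleaned context and target are joined with no space between them, so Piece 1 is tokenized as the single string 'Aretourists' ([32, 1186, 454, 1023]) and Piece 2 as 'touristsent iced'. The first ID of Piece 1 is not 8491, the ID of 'Are' on its own (Step 3), so the span (1, 3), computed from the separately tokenized context and target, no longer lines up with the target's tokens, and the surprisal extracted for Piece 1 in Step 6 is not the surprisal of 'tourists' after 'Are'. A minimal sketch of a check and a fix, assuming, as the printed strings suggest, that `.strip()` removed each word's leading space; `context_prefix_ok` and `build_pair` are hypothetical helpers, not functions from `calc_surprisal_hf.py`.

```python
def context_prefix_ok(text_ids, context_ids):
    """The span arithmetic assumes the joint tokenization begins with the
    context's own tokens. Return True if that assumption holds."""
    return list(text_ids[:len(context_ids)]) == list(context_ids)


def build_pair(context_clean, target_clean, bos="<|endoftext|>"):
    """Return (context_for_model, target_for_model), restoring the space between
    words that .strip() removed; GPT-2 BPE keeps that space on the target word."""
    if not context_clean:
        return bos, target_clean
    return context_clean, " " + target_clean


# IDs copied from the outputs above ('Are' on its own is [8491], see Step 3)
print(context_prefix_ok([50256, 8491], [50256]))         # True  (Piece 0)
print(context_prefix_ok([32, 1186, 454, 1023], [8491]))  # False (Piece 1: 'Aretourists')

print(build_pair("", "Are"))          # ('<|endoftext|>', 'Are')
print(build_pair("Are", "tourists"))  # ('Are', ' tourists')
```

Encoding `context_for_model` and `target_for_model` separately usually yields the same IDs as encoding their concatenation, because GPT-2's pre-tokenizer splits the text right before each space; running the prefix check on every piece flags the cases where it does not.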


## 🚀 Step 6: Batched Forward Pass

Now let's run the model on all 3 pieces at once!

In [8]:
# Replace padding in the INPUT with a real token ID (0): -100 is not a valid
# embedding index, and the attention mask hides these positions anyway.
# The GOLD labels keep their -100 so that ignore_index=-100 skips them in the loss.
padded_input_clean = torch.where(padded_input == pad_id, 0, padded_input)

print("Running batched forward pass...")
with torch.no_grad():
    outputs = model(
        input_ids=padded_input_clean.to(device),
        attention_mask=attention_mask.to(device)
    )
    logits = outputs.logits

print(f"Logits shape: {logits.shape}")
print(f"  → (batch_size={logits.shape[0]}, max_length={logits.shape[1]}, vocab_size={logits.shape[2]})")
print()

# Compute loss for all pieces at once (padded gold positions get a loss of 0)
batched_losses = loss_fct(
    logits.transpose(1, 2),
    padded_gold.to(device)
)

print(f"Batched losses shape: {batched_losses.shape}")
print(f"  → (batch_size={batched_losses.shape[0]}, max_length={batched_losses.shape[1]})")
print()

# Extract surprisals for each piece's target
print("Extracting surprisals for target spans:")
for i, span in enumerate(target_spans):
    target_losses = batched_losses[i, span[0]-1:span[1]].cpu().tolist()
    total_surprisal = sum(target_losses)
    print(f"  Piece {i}: span={span} → surprisal={total_surprisal:.3f} nats")

print("\n✅ This is how batching speeds up computation!")
print("   Process 50-100 pieces at once instead of one-by-one!")

Running batched forward pass...
Logits shape: torch.Size([3, 5, 50257])
  → (batch_size=3, max_length=5, vocab_size=50257)

Batched losses shape: torch.Size([3, 5])
  → (batch_size=3, max_length=5)

Extracting surprisals for target spans:
  Piece 0: span=(1, 1) → surprisal=7.436 nats
  Piece 1: span=(1, 3) → surprisal=29.196 nats
  Piece 2: span=(3, 5) → surprisal=26.908 nats

✅ This is how batching speeds up computation!
   Process 50-100 pieces at once instead of one-by-one!
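One detail of the loss deserves a check. With `reduction="none"`, PyTorch's `CrossEntropyLoss` returns 0 at every position whose label equals `ignore_index` (-100 here). That is why padding the gold labels with -100 matters: replacing those labels with a real token ID would give padded positions a nonzero loss, which would leak into any all-token average such as PPL. The per-piece surprisals are unaffected either way, because the span slices only touch real positions. A numpy sketch of the behavior; `cross_entropy_none` is a stand-in written for illustration, not the PyTorch implementation.

```python
import numpy as np

IGNORE = -100


def cross_entropy_none(logits, labels, ignore_index=IGNORE):
    """Per-position -log softmax(logits)[label]; positions whose label equals
    ignore_index get 0, mirroring reduction="none" with ignore_index."""
    m = logits.max(axis=-1, keepdims=True)
    log_probs = logits - (m + np.log(np.exp(logits - m).sum(axis=-1, keepdims=True)))
    keep = labels != ignore_index
    safe = np.where(keep, labels, 0)  # any valid index; masked out below
    picked = np.take_along_axis(log_probs, safe[..., None], axis=-1)[..., 0]
    return np.where(keep, -picked, 0.0)


rng = np.random.default_rng(0)
logits = rng.normal(size=(2, 4, 6))           # (batch, seq_len, toy vocab)
labels = np.array([[1, 2, IGNORE, IGNORE],
                   [3, 0, 5, 4]])
losses = cross_entropy_none(logits, labels)

print(losses[0, 2:])           # [0. 0.]  (padding contributes nothing)
print((losses[1] > 0).all())   # True
```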


## 📊 Step 7: Perplexity (PPL) Calculation

The script also outputs `eval.txt` with a **Perplexity** value.

### What is Perplexity?
- Measures how "surprised" the model is on average
- Lower PPL = better predictions

### Formula:
$$
\text{PPL} = \exp\left(\frac{1}{N} \sum_{i=1}^{N} \text{loss}_i\right)
$$

Where:
- $N$ = total number of tokens
- $\text{loss}_i$ = cross-entropy loss for token $i$ (in nats)

### Intuition:
- PPL = 1: Perfect prediction (no surprise)
- PPL = 100: On average, the model is as uncertain as a uniform choice among ~100 tokens
- PPL = 1000: Much more uncertain

**Note**: PPL is computed over ALL tokens (not just targets), but for simplicity we'll use target tokens here.
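The "equally likely" reading comes from a uniform distribution: if the model assigns probability 1/k to every candidate, each token's loss is ln(k) nats, and the exponential of the mean is exactly k. A small numpy sketch; the `perplexity` helper simply restates the formula above.

```python
import numpy as np


def perplexity(token_losses_nats):
    """exp of the mean per-token cross-entropy, losses given in nats."""
    return float(np.exp(np.mean(token_losses_nats)))


# A model that is uniform over k tokens has loss ln(k) at every position,
# so its perplexity is exactly k.
for k in (1, 100, 1000):
    print(k, round(perplexity([np.log(k)] * 5), 6))
```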

In [9]:
# Collect all losses from our 3 pieces
all_losses = []
for i, span in enumerate(target_spans):
    target_losses = batched_losses[i, span[0]-1:span[1]].cpu().tolist()
    all_losses.extend(target_losses)

# Compute PPL
mean_loss = np.mean(all_losses)
perplexity = np.exp(mean_loss)

print("Perplexity calculation:")
print(f"  Total tokens: {len(all_losses)}")
print(f"  Mean loss (nats): {mean_loss:.3f}")
print(f"  Perplexity: {perplexity:.2f}")
print()
print(f"💡 PPL={perplexity:.2f} means the model is, on average, as uncertain as a uniform choice among ~{perplexity:.0f} tokens")
print("   (Lower is better - more confident predictions)")

Perplexity calculation:
  Total tokens: 7
  Mean loss (nats): 9.077
  Perplexity: 8753.64

💡 PPL=8753.64 means the model is, on average, as uncertain as a uniform choice among ~8754 tokens
   (Lower is better - more confident predictions)


## 🔄 Step 8: The Full Loop

In the actual script, this process repeats for:
- **All articles** in the JSON file
- **All pieces** in each article
- Using **batches of 50** (default) for efficiency
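Batches are formed within each article, so an article's last batch is usually partial. A worked example with the 2,573 pieces of article 1 (from the Step 1 output) and the default batch size of 50.

```python
import math

BATCH_SIZE = 50
pieces_in_article_1 = 2573  # from the Step 1 output

n_batches = math.ceil(pieces_in_article_1 / BATCH_SIZE)
last_batch = pieces_in_article_1 - (n_batches - 1) * BATCH_SIZE
print(n_batches, last_batch)  # 52 23

# Batch start offsets, as produced by range(0, n, BATCH_SIZE)
starts = list(range(0, pieces_in_article_1, BATCH_SIZE))
print(len(starts), starts[-1])  # 52 2550
```

Summing these per-article ceilings gives the exact batch count; dividing the corpus total by 50 only approximates it.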

### Pseudocode:
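```python
import json
import numpy as np

BATCH_SIZE = 50          # default batch size
article2scores = {}      # article ID -> list of per-piece surprisals
all_token_losses = []    # per-token losses (nats), pooled for PPL

for article, pieces in article2piece.items():
    article2scores[article] = []

    # Process the article's pieces in batches of BATCH_SIZE
    for batch_start in range(0, len(pieces), BATCH_SIZE):
        batch = pieces[batch_start:batch_start + BATCH_SIZE]

        # 1-4. Tokenize, pad, run the forward pass, extract target spans.
        # score_batch is a placeholder for the cells above, not a real function:
        # it returns one summed surprisal per piece and the per-token losses.
        piece_surprisals, token_losses = score_batch(batch)

        # 5. Save to article2scores[article]
        article2scores[article].extend(piece_surprisals)
        all_token_losses.extend(token_losses)

perplexity = float(np.exp(np.mean(all_token_losses)))

# Save all results
with open('scores.json', 'w') as f:
    json.dump(article2scores, f)
with open('eval.txt', 'w') as f:
    json.dump({'PPL': perplexity}, f)
```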

## 🎓 Advanced: Sorting by Length

One optimization in the script: **sort pieces by length before batching**.

### Why?
- Pieces of similar length → less padding needed
- Less padding → faster computation, less memory

### Example:
**Without sorting:**
```
Batch 1: [length 50, length 10, length 45] → max=50, lots of padding
Batch 2: [length 12, length 48, length 15] → max=48, lots of padding
```

**With sorting:**
```
Batch 1: [length 10, length 12, length 15] → max=15, minimal padding
Batch 2: [length 45, length 48, length 50] → max=50, minimal padding
```
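Counting the padding in the two layouts above makes the saving concrete. A small pure-Python sketch; `total_padding` is a helper written for this illustration and assumes each batch is padded to its own longest piece.

```python
def total_padding(batches):
    """Number of pad positions needed when each batch is padded to its own max length."""
    return sum(max(batch) - length for batch in batches for length in batch)


unsorted_batches = [[50, 10, 45], [12, 48, 15]]
lengths = sorted(length for batch in unsorted_batches for length in batch)
sorted_batches = [lengths[i:i + 3] for i in range(0, len(lengths), 3)]

print(sorted_batches)                   # [[10, 12, 15], [45, 48, 50]]
print(total_padding(unsorted_batches))  # 114
print(total_padding(sorted_batches))    # 15
```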

The script tracks original indices to restore order at the end!
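The notebook cell that follows only demonstrates the sort. A minimal sketch of the full round trip: keep each item's original index, process in sorted order, and scatter the results back. In the real script the processing happens per batch rather than per item; `sort_and_restore` and the word-count stand-in are hypothetical names for this illustration.

```python
def sort_and_restore(items, key, process):
    """Process items shortest-first, then return results in the original order."""
    order = sorted(range(len(items)), key=lambda i: key(items[i]))  # original indices, by length
    results_sorted = [process(items[i]) for i in order]
    restored = [None] * len(items)
    for pos, orig_idx in enumerate(order):
        restored[orig_idx] = results_sorted[pos]
    return restored


processed_order = []


def word_count(piece):
    processed_order.append(piece)  # record the order in which pieces are processed
    return len(piece.split())


pieces = ["a b c d e", "a", "a b c"]  # toy pieces of different lengths
result = sort_and_restore(pieces, key=lambda p: len(p.split()), process=word_count)
print(processed_order)  # ['a', 'a b c', 'a b c d e']
print(result)           # [5, 1, 3]
```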

In [10]:
# Demonstrate sorting optimization
print("Original order (by article position):")
for i, ids in enumerate(text_ids_list):
    print(f"  Piece {i}: length {len(ids)}")
print()

# Sort by length
ids_span_idx = [(ids, span, i) for i, (ids, span) in enumerate(zip(text_ids_list, target_spans))]
ids_span_idx_sorted = sorted(ids_span_idx, key=lambda x: len(x[0]))

print("Sorted order (by length):")
for ids, span, orig_idx in ids_span_idx_sorted:
    print(f"  Piece {orig_idx}: length {len(ids)}")
print()

print("💡 After processing, results are restored to original order using indices!")

Original order (by article position):
  Piece 0: length 2
  Piece 1: length 4
  Piece 2: length 6

Sorted order (by length):
  Piece 0: length 2
  Piece 1: length 4
  Piece 2: length 6

💡 After processing, results are restored to original order using indices!


## 📝 Summary: What Did We Learn?

### 1. **The Script's Flow**:
```
Load JSON → For each article → For each batch:
  1. Tokenize pieces
  2. Identify target spans
  3. Pad to same length
  4. Create attention mask
  5. Run model forward pass
  6. Compute cross-entropy loss
  7. Extract target surprisals
  8. Sum per piece
→ Save to scores.json + eval.txt
```

### 2. **Key Concepts**:
- **Input/Output alignment**: Input = text[:-1], Gold = text[1:]
- **Target spans**: Where to extract surprisals from loss tensor
- **Batching**: Process multiple pieces together for speed
- **Padding**: Fill shorter sequences up to the batch's max length (a valid ID such as 0 in the input, -100 in the gold labels so the loss ignores them)
- **Attention mask**: Tell the model which positions are padding
- **Perplexity**: Average uncertainty = exp(mean_loss)

### 3. **Why This Matters**:
- Processes 212,000 pieces in reasonable time (~1-2 hours on GPU)
- Outputs scores.json: article ‚Üí list of surprisals
- These surprisals feed into statistical analysis (Part 5!)

### 4. **The Numbers**:
| Metric | Value |
|--------|-------|
| Total pieces | ~212,000 |
| Batch size | 50 |
| Number of batches | ~4,240 |
| Loss function | CrossEntropyLoss (nats) |
| Output surprisal | Summed per word |

---

## ✅ Ready for Part 5?

**Next**: Statistical analysis with `dundee.py` - how do we use these surprisals to compute PPP?

### Check Your Understanding:
1. Why do we use batching? **→ To speed up computation by processing multiple pieces at once**
2. What is padding? **→ Filling shorter sequences up to the batch's max length (a placeholder ID hidden by the attention mask in the input; -100 in the gold labels so the loss ignores them)**
3. Why input_ids[:-1]? **→ The model at position i predicts token i+1**
4. What is a target span? **→ The positions where the target's tokens sit in the full sequence**
5. What is perplexity? **→ exp(mean_loss), a measure of average uncertainty**
6. Why sort by length? **→ To minimize padding, which speeds up computation**

Tell me when you're ready for Part 5! 🚀