# 2) Trigram Language Model (Baseline #1)

This notebook implements a trigram language model for next-word prediction.

**Model Description:**
- Predicts next word based on previous 2 words
- Uses Maximum Likelihood Estimation (MLE): P(w_i | w_{i-2}, w_{i-1}) = Count(w_{i-2}, w_{i-1}, w_i) / Count(w_{i-2}, w_{i-1})
- Filters candidates by first letter constraint
- Falls back to bigram → unigram if trigram not seen
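
As a minimal illustration of the MLE formula above, the sketch below counts trigrams in a made-up three-sentence corpus. The sentences and the helper name `mle_trigram_prob` are assumptions for illustration only and are not part of the notebook's pipeline.

```python
from collections import Counter

# Toy corpus, invented purely for illustration
toy_sentences = [
    "the cat sat on the mat",
    "the cat sat on the sofa",
    "the cat ran away",
]

trigram_counts = Counter()
context_counts = Counter()
for sentence in toy_sentences:
    # Same padding convention as the model below: two <s>, one </s>
    tokens = ["<s>", "<s>"] + sentence.split() + ["</s>"]
    for w1, w2, w3 in zip(tokens, tokens[1:], tokens[2:]):
        trigram_counts[(w1, w2, w3)] += 1
        context_counts[(w1, w2)] += 1

def mle_trigram_prob(w1, w2, w3):
    """P(w3 | w1, w2) = Count(w1, w2, w3) / Count(w1, w2); 0.0 for an unseen context."""
    context = context_counts[(w1, w2)]
    return trigram_counts[(w1, w2, w3)] / context if context else 0.0

# "the cat" occurs three times; two of those are followed by "sat"
print(mle_trigram_prob("the", "cat", "sat"))
print(mle_trigram_prob("the", "cat", "ran"))
```

Any continuation that never followed a context gets probability 0 under MLE, which is why the model needs the backoff step listed above.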

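**Performance (initial estimates vs. results measured on the dev set in sections 2.7 and 2.9):**
- 10K data: estimated ~18-20%, measured 43.42%
- 100K data: estimated ~33-37%, measured 49.89%
- 1M data: estimated ~50-52%, measured 54.08%
- Full (3.8M) data: estimated ~52-55%, measured 58.12%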

## 2.1 Setup and Imports

In [1]:
import pandas as pd
from collections import defaultdict, Counter
from typing import List, Tuple, Dict
import time
from tqdm import tqdm

print("Imports successful!")

Imports successful!


## 2.2 Load Data

In [2]:
# Load training data
print("Loading training data...")
with open('train.src.tok', 'r', encoding='utf-8') as f:
    train_lines = [line.strip() for line in f]  # iterate the file directly

print(f"Total training sentences: {len(train_lines):,}")

# Load dev set
print("\nLoading development set...")
dev_df = pd.read_csv('dev_set.csv')
print(f"Development set size: {len(dev_df):,} predictions")
print(f"Columns: {list(dev_df.columns)}")

# Show sample
print("\nSample dev set entries:")
print(dev_df.head(3))

Loading training data...
Total training sentences: 3,803,957

Loading development set...
Development set size: 94,825 predictions
Columns: ['context', 'first letter', 'answer']

Sample dev set entries:
                                             context first letter   answer
0  south korea and the united states on monday wa...            d      day
1  after agreeing to drastically cut its car impo...            t      the
2  three soldiers were injured in a bombing ambus...            m  morning


## 2.3 Data Sampling

We'll use simple sequential sampling (first N sentences) as decided in EDA.

In [11]:
# Data sizes for experiments
DATA_SIZES = {
    'debug': 10_000,
    'dev': 100_000,
    'large': 1_000_000,
    'full': 3_803_957
}

def sample_data(train_lines: List[str], size_key: str = 'debug') -> List[str]:
    """
    Sample training data sequentially (simple, no shuffling).
    
    Args:
        train_lines: Full training corpus
        size_key: One of 'debug', 'dev', 'large', 'full'
    
    Returns:
        First N sentences from corpus
    """
    size = DATA_SIZES[size_key]
    if size >= len(train_lines):
        return train_lines
    return train_lines[:size]

# Use 'debug' (10K) for fast testing; switch to 'dev', 'large', or 'full' for real runs
# The run recorded below uses 'full'
# CURRENT_SIZE = 'debug'
# CURRENT_SIZE = 'dev'
# CURRENT_SIZE = 'large'
CURRENT_SIZE = 'full'

train_data = sample_data(train_lines, CURRENT_SIZE)
print(f"Using {CURRENT_SIZE} dataset: {len(train_data):,} sentences")
print(f"\nFirst 3 training sentences:")
for i, sent in enumerate(train_data[:3]):
    print(f"{i+1}. {sent}")

Using full dataset: 3,803,957 sentences

First 3 training sentences:
1. australia ' s current account deficit shrunk by a record 1 . 11 billion dollars - lrb - 1 . 11 billion us - rrb - in the june quarter due to soaring commodity prices , figures released monday showed .
2. at least two people were killed in a suspected bomb attack on a passenger bus in the strife - torn southern philippines on monday , the military said .
3. australian shares closed down 1 . 1 percent monday following a weak lead from the united states and lower commodity prices , dealers said .


## 2.4 Trigram Model Implementation

### Key Design Decisions:

1. **Sentence boundaries**: Add two `<s>` tokens at the start and one `</s>` at the end
   - This ensures we don't use context from previous sentence
   - Follows standard n-gram practice

2. **Backoff strategy**: Trigram → Bigram → Unigram
   - If trigram (w1, w2, ?) not seen, fall back to bigram (w2, ?)
   - If bigram not seen, fall back to unigram (?)
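   - Every candidate comes from the training vocabulary, so the unigram step always yields a nonzero score; if no vocabulary word starts with the given letter, the letter itself is returned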

3. **First letter filtering**:
   - Build vocabulary index by first character
   - Only consider words starting with given first letter
   - Handles special characters (`,`, `'`, `1`, etc.)

4. **Data structures**:
   - `trigram_counts`: Dict[(w1, w2, w3)] → count
   - `bigram_counts`: Dict[(w1, w2)] → count
   - `unigram_counts`: Dict[w] → count
   - `vocab_by_first_char`: Dict[char] → Set[words]
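
To make decisions 1 and 3 concrete, here is a small sketch of the padding step and the first-character index on toy input. The helper names `pad` and `index_by_first_char` are hypothetical and exist only for this illustration; the class below does the same work inside `train`.

```python
from collections import defaultdict

def pad(sentence):
    """Two <s> markers give the first real word a full two-token context."""
    return ["<s>", "<s>"] + sentence.split() + ["</s>"]

def index_by_first_char(sentences):
    """Map each first character to the set of vocabulary words starting with it."""
    index = defaultdict(set)
    for sentence in sentences:
        for word in sentence.split():
            index[word[0]].add(word)
    return index

tokens = pad("shares closed down 1 . 1 percent")
print(tokens[:3])            # ['<s>', '<s>', 'shares']

index = index_by_first_char(["shares closed down 1 . 1 percent", "dealers said ."])
print(sorted(index["d"]))    # ['dealers', 'down']
print(sorted(index["1"]))    # ['1']  (digits and punctuation are valid "first letters")
```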

In [12]:
class TrigramModel:
    """
    Trigram language model with backoff strategy.
    
    Predicts P(w_i | w_{i-2}, w_{i-1}) using Maximum Likelihood Estimation.
    Falls back to bigram/unigram if trigram not seen.
    """
    
    def __init__(self):
        # N-gram counts
        self.trigram_counts = defaultdict(int)   # (w1, w2, w3) -> count
        self.bigram_counts = defaultdict(int)    # (w1, w2) -> count
        self.unigram_counts = defaultdict(int)   # w -> count
        
        # Context counts (for probability calculation)
        self.bigram_context_counts = defaultdict(int)  # (w1, w2) -> count
        self.unigram_context_counts = defaultdict(int)  # w1 -> count
        
        # Vocabulary indexed by first character
        self.vocab_by_first_char = defaultdict(set)  # char -> {words}
        
        # Statistics
        self.total_trigrams = 0
        self.total_bigrams = 0
        self.total_unigrams = 0
        
    def train(self, sentences: List[str]):
        """
        Train the trigram model on a list of sentences.
        
        Args:
            sentences: List of tokenized sentences (strings)
        """
        print(f"Training trigram model on {len(sentences):,} sentences...")
        start_time = time.time()
        
        for sentence in tqdm(sentences, desc="Processing sentences"):
            # Tokenize sentence
            tokens = sentence.split()
            
            # Add sentence boundaries
            # We use TWO <s> tokens at start for trigram context
            tokens = ['<s>', '<s>'] + tokens + ['</s>']
            
            # Extract n-grams and count
            for i in range(len(tokens)):
                # Unigram
                if i >= 2:  # Skip the <s> tokens
                    word = tokens[i]
                    self.unigram_counts[word] += 1
                    self.total_unigrams += 1
                    
                    # Add to vocabulary index
                    if word not in ['<s>', '</s>']:
                        first_char = word[0]
                        self.vocab_by_first_char[first_char].add(word)
                
                # Bigram
                if i >= 1:
                    bigram = (tokens[i-1], tokens[i])
                    self.bigram_counts[bigram] += 1
                    self.total_bigrams += 1
                    
                    # Count context (for probability: count(w1, w2) / count(w1))
                    if i >= 2:
                        self.unigram_context_counts[tokens[i-1]] += 1
                
                # Trigram
                if i >= 2:
                    trigram = (tokens[i-2], tokens[i-1], tokens[i])
                    self.trigram_counts[trigram] += 1
                    self.total_trigrams += 1
                    
                    # Count context (for probability: count(w1, w2, w3) / count(w1, w2))
                    context = (tokens[i-2], tokens[i-1])
                    self.bigram_context_counts[context] += 1
        
        elapsed = time.time() - start_time
        print(f"\nTraining complete in {elapsed:.2f} seconds")
        print(f"Total trigrams: {self.total_trigrams:,}")
        print(f"Total bigrams: {self.total_bigrams:,}")
        print(f"Total unigrams: {self.total_unigrams:,}")
        print(f"Unique trigrams: {len(self.trigram_counts):,}")
        print(f"Unique bigrams: {len(self.bigram_counts):,}")
        print(f"Unique unigrams: {len(self.unigram_counts):,}")
        print(f"Vocabulary size: {sum(len(words) for words in self.vocab_by_first_char.values()):,}")
    
    def get_trigram_prob(self, w1: str, w2: str, w3: str) -> float:
        """
        Calculate P(w3 | w1, w2) using MLE.
        
        Returns:
            Probability (0 if trigram never seen)
        """
        trigram = (w1, w2, w3)
        context = (w1, w2)
        
        trigram_count = self.trigram_counts.get(trigram, 0)
        context_count = self.bigram_context_counts.get(context, 0)
        
        if context_count == 0:
            return 0.0
        
        return trigram_count / context_count
    
    def get_bigram_prob(self, w1: str, w2: str) -> float:
        """
        Calculate P(w2 | w1) using MLE.
        
        Returns:
            Probability (0 if bigram never seen)
        """
        bigram = (w1, w2)
        
        bigram_count = self.bigram_counts.get(bigram, 0)
        context_count = self.unigram_context_counts.get(w1, 0)
        
        if context_count == 0:
            return 0.0
        
        return bigram_count / context_count
    
    def get_unigram_prob(self, w: str) -> float:
        """
        Calculate P(w) using MLE.
        
        Returns:
            Probability (0 if word never seen)
        """
        if self.total_unigrams == 0:
            return 0.0
        
        return self.unigram_counts.get(w, 0) / self.total_unigrams
    
    def predict(self, context: str, first_letter: str) -> str:
        """
        Predict next word given context and first letter constraint.
        
        Args:
            context: Previous words as string (e.g., "the cat sat on the")
            first_letter: Required first character of prediction
        
        Returns:
            Predicted word (most likely word starting with first_letter)
        """
        # Tokenize context and get last 2 words
        context_tokens = context.split()
        
        # Handle short contexts
        if len(context_tokens) == 0:
            w1, w2 = '<s>', '<s>'
        elif len(context_tokens) == 1:
            w1, w2 = '<s>', context_tokens[0]
        else:
            w1, w2 = context_tokens[-2], context_tokens[-1]
        
        # Get candidate words (all words starting with first_letter)
        candidates = self.vocab_by_first_char.get(first_letter, set())
        
        if not candidates:
            # No words in vocabulary start with this letter
            # This shouldn't happen with our data, but handle gracefully
            return first_letter  # Return just the letter
        
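        # Score candidates using backoff strategy.
        # Backoff is applied per candidate, so a trigram probability for one
        # word may be compared with a bigram/unigram probability for another.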
        best_word = None
        best_score = -1
        
        for word in candidates:
            # Try trigram first
            score = self.get_trigram_prob(w1, w2, word)
            
            # If trigram not seen, back off to bigram
            if score == 0:
                score = self.get_bigram_prob(w2, word)
            
            # If bigram not seen, back off to unigram
            if score == 0:
                score = self.get_unigram_prob(word)
            
            # Update best
            if score > best_score:
                best_score = score
                best_word = word
        
        # Defensive fallback: not reached in practice, because candidates is
        # non-empty here and every score is >= 0 > -1, so best_word is always set
        if best_word is None:
            # Get most common word by unigram count
            candidates_list = list(candidates)
            best_word = max(candidates_list, 
                          key=lambda w: self.unigram_counts.get(w, 0))
        
        return best_word
    
    def evaluate(self, dev_df: pd.DataFrame, max_examples: int = None) -> Dict:
        """
        Evaluate model on development set.
        
        Args:
            dev_df: DataFrame with columns ['context', 'first letter', 'answer']
            max_examples: Optional limit on number of examples to evaluate
        
        Returns:
            Dictionary with accuracy and other metrics
        """
        print(f"\nEvaluating on development set...")
        
        if max_examples:
            dev_df = dev_df.head(max_examples)
        
        correct = 0
        total = len(dev_df)
        
        predictions = []
        
        for idx, row in tqdm(dev_df.iterrows(), total=total, desc="Predicting"):
            context = row['context']
            first_letter = row['first letter']
            answer = row['answer']
            
            # Predict
            prediction = self.predict(context, first_letter)
            predictions.append(prediction)
            
            # Check correctness
            if prediction == answer:
                correct += 1
        
        accuracy = correct / total if total else 0.0  # guard against an empty dev set
        
        print(f"\nResults:")
        print(f"  Total examples: {total:,}")
        print(f"  Correct: {correct:,}")
        print(f"  Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
        
        return {
            'accuracy': accuracy,
            'correct': correct,
            'total': total,
            'predictions': predictions
        }

print("TrigramModel class defined successfully!")

TrigramModel class defined successfully!
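
One property of `predict` is worth spelling out: backoff is decided per candidate, so a trigram probability for one word can end up compared against a bigram or unigram probability for another. A common alternative picks the backoff level once per context: use trigram counts if any candidate was seen after (w1, w2), otherwise bigram counts, otherwise unigram counts. The sketch below shows that variant on plain dictionaries. The function name `predict_tiered` and the toy counts are assumptions, and this is not the method that produced the results in this notebook.

```python
def predict_tiered(w1, w2, first_letter, trigram_counts, bigram_counts,
                   unigram_counts, vocab_by_first_char):
    """Pick the backoff level once per context, then rank candidates by count."""
    candidates = vocab_by_first_char.get(first_letter, set())
    if not candidates:
        return first_letter  # same fallback as TrigramModel.predict

    # Within one level all candidates share the same denominator,
    # so ranking by raw count equals ranking by MLE probability.
    levels = [
        lambda w: trigram_counts.get((w1, w2, w), 0),
        lambda w: bigram_counts.get((w2, w), 0),
        lambda w: unigram_counts.get(w, 0),
    ]
    for count_of in levels:
        best = max(candidates, key=count_of)
        if count_of(best) > 0:
            return best
    return first_letter

# Toy counts, invented for illustration
trigrams = {("on", "the", "mat"): 1}
bigrams = {("the", "mat"): 1, ("the", "middle"): 50}
unigrams = {"mat": 1, "middle": 60, "monday": 90}
vocab = {"m": {"mat", "middle", "monday"}}

print(predict_tiered("on", "the", "m", trigrams, bigrams, unigrams, vocab))     # mat
print(predict_tiered("in", "the", "m", trigrams, bigrams, unigrams, vocab))     # middle
print(predict_tiered("said", "that", "m", trigrams, bigrams, unigrams, vocab))  # monday
```

Whether this variant scores higher on the dev set is an open question here; it would need to be measured rather than assumed.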


## 2.5 Train Model

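Train on the dataset selected by `CURRENT_SIZE` in section 2.3. The run recorded below uses the full 3.8M-sentence corpus; set `CURRENT_SIZE = 'debug'` (10K sentences) first for a quick check that everything works.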

In [13]:
# Initialize model
model = TrigramModel()

# Train
model.train(train_data)

Training trigram model on 3,803,957 sentences...


Processing sentences: 100%|██████████| 3803957/3803957 [05:56<00:00, 10658.41it/s]


Training complete in 356.90 seconds
Total trigrams: 132,086,229
Total bigrams: 135,890,186
Total unigrams: 132,086,229
Unique trigrams: 27,500,824
Unique bigrams: 6,376,191
Unique unigrams: 99,022
Vocabulary size: 99,021





## 2.6 Test Predictions

Let's test on a few manual examples before evaluating on dev set.

In [14]:
# Test examples
test_cases = [
    ("the cat sat on the", "m"),  # mat?
    ("president of the united", "s"),  # states?
    ("new york", "c"),  # city?
    ("in the", "m"),  # morning? middle?
    ("on", "m"),  # monday?
]

print("Testing predictions:\n")
for context, first_letter in test_cases:
    prediction = model.predict(context, first_letter)
    print(f"Context: '{context}'")
    print(f"First letter: '{first_letter}'")
    print(f"Prediction: {prediction}")
    print()

Testing predictions:

Context: 'the cat sat on the'
First letter: 'm'
Prediction: middle

Context: 'president of the united'
First letter: 's'
Prediction: states

Context: 'new york'
First letter: 'c'
Prediction: city

Context: 'in the'
First letter: 'm'
Prediction: middle

Context: 'on'
First letter: 'm'
Prediction: monday



## 2.7 Evaluate on Dev Set

Now let's evaluate on the full development set (all 94,825 examples).

**Note:** Evaluation time grows with the training set: about 30 seconds with the 10K model and about 14 minutes with the full model in the runs recorded in this notebook.
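
Evaluation slows down as the vocabulary grows because `predict` scores every vocabulary word that starts with the requested letter. A possible speed-up, sketched below, is to index trigram counts by context and first character once after training, so the trigram step only looks at continuations that were actually seen. The names `build_continuation_index` and `best_continuation` are hypothetical, and the extra index would add memory on top of an already large model.

```python
from collections import defaultdict

def build_continuation_index(trigram_counts):
    """(w1, w2) -> first char -> {w3: count}, built once after training."""
    index = defaultdict(lambda: defaultdict(dict))
    for (w1, w2, w3), count in trigram_counts.items():
        index[(w1, w2)][w3[0]][w3] = count
    return index

def best_continuation(index, w1, w2, first_letter):
    """Most frequent seen continuation with the given first letter, or None."""
    by_char = index.get((w1, w2))
    if not by_char or first_letter not in by_char:
        return None  # caller would back off to bigram/unigram here
    seen = by_char[first_letter]
    return max(seen, key=seen.get)

# Toy counts, invented for illustration
toy_trigrams = {
    ("in", "the", "middle"): 5,
    ("in", "the", "morning"): 9,
    ("in", "the", "city"): 7,
}
index = build_continuation_index(toy_trigrams)
print(best_continuation(index, "in", "the", "m"))  # morning
print(best_continuation(index, "in", "the", "z"))  # None
```

Note that this lookup returns the most frequent seen trigram continuation, which is not always what the per-candidate scoring in `predict` would return, so accuracy should be re-measured if it is adopted.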

In [15]:
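# Evaluate on full dev set
# This took ~14 minutes with the full training data (~30 seconds with 10K)
# Change max_examples=1000 for quick testing
results = model.evaluate(dev_df, max_examples=None)

# NOTE: the "~18-20%" below is hard-coded and was the estimate for the 10K
# 'debug' run only; it does not change with CURRENT_SIZE
print(f"\nExpected accuracy for {CURRENT_SIZE} dataset: ~18-20%")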
print(f"Actual accuracy: {results['accuracy']*100:.2f}%")


Evaluating on development set...


Predicting: 100%|██████████| 94825/94825 [13:56<00:00, 113.36it/s]


Results:
  Total examples: 94,825
  Correct: 55,116
  Accuracy: 0.5812 (58.12%)

Expected accuracy for full dataset: ~18-20%
Actual accuracy: 58.12%





## 2.8 Error Analysis

Let's look at some examples where the model got it right vs wrong.

In [8]:
# Add predictions to dev_df
# Note: This uses all predictions from the full dev set evaluation above
dev_sample = dev_df.copy()
dev_sample['prediction'] = results['predictions']
dev_sample['correct'] = dev_sample['prediction'] == dev_sample['answer']

# Show correct predictions
print("=" * 80)
print("CORRECT PREDICTIONS (Sample of 5)")
print("=" * 80)
correct_samples = dev_sample[dev_sample['correct']].head(5)
for idx, row in correct_samples.iterrows():
    print(f"\nContext: {row['context']}")
    print(f"First letter: '{row['first letter']}'")
    print(f"Prediction: {row['prediction']}")
    print(f"Answer: {row['answer']}")
    print(f"✓ CORRECT")

# Show incorrect predictions
print("\n" + "=" * 80)
print("INCORRECT PREDICTIONS (Sample of 5)")
print("=" * 80)
incorrect_samples = dev_sample[~dev_sample['correct']].head(5)
for idx, row in incorrect_samples.iterrows():
    print(f"\nContext: {row['context']}")
    print(f"First letter: '{row['first letter']}'")
    print(f"Prediction: {row['prediction']}")
    print(f"Answer: {row['answer']}")
    print(f"✗ INCORRECT")

CORRECT PREDICTIONS (Sample of 5)

Context: after agreeing to drastically cut its car import duties , taiwan on thursday won european union support for its bid to enter
First letter: 't'
Prediction: the
Answer: the
✓ CORRECT

Context: three soldiers were injured in a bombing ambush launched by suspect thai southern insurgents on wednesday
First letter: 'm'
Prediction: morning
Answer: morning
✓ CORRECT

Context: the movement for democracy in liberia - lrb - model - rrb - , the second largest rebel group in war - ravaged west african country has made a firm commitment to continue with the peace talks in ghana , according
First letter: 't'
Prediction: to
Answer: to
✓ CORRECT

Context: an indian court that heard a stunning confession from the lone surviving gunman in the mumbai attacks put a gag order on his latest testimony - - a message to his handlers in pakistan and
First letter: 'a'
Prediction: a
Answer: a
✓ CORRECT

Context: authorities in belarus detained two prominent russian lawma

## 2.9 Scaling Experiments

Now let's see how accuracy changes with more training data.

**Note:** Runtime (training plus evaluation on the full dev set) grows with data size. In the runs recorded in this notebook:
- 10K: ~30 seconds
- 100K: ~2 minutes
- 1M: ~6 minutes
- Full (3.8M): ~20 minutes

In [21]:
# Run experiments on different data sizes
# Comment out sizes you don't want to run
sizes_to_test = [
    'debug',   # 10K
    'dev',     # 100K
    'large',   # 1M
    # 'full',    # 3.8M
]

scaling_results = []

for size_key in sizes_to_test:
    print("\n" + "=" * 80)
    print(f"TRAINING ON {size_key.upper()} DATASET ({DATA_SIZES[size_key]:,} sentences)")
    print("=" * 80)
    
    # Sample data
    data = sample_data(train_lines, size_key)
    
    # Train model
    model = TrigramModel()
    model.train(data)
    
    # Evaluate on full dev set
    # Takes from ~30 seconds (10K model) to ~5 minutes (1M model)
    results = model.evaluate(dev_df, max_examples=None)
    
    # Store results
    scaling_results.append({
        'size': size_key,
        'num_sentences': len(data),
        'accuracy': results['accuracy']
    })

# Show summary
print("\n" + "=" * 80)
print("SCALING RESULTS SUMMARY")
print("=" * 80)
print(f"{'Dataset':<15} {'# Sentences':<15} {'Accuracy':<15}")
print("-" * 45)
for result in scaling_results:
    print(f"{result['size']:<15} {result['num_sentences']:<15,} {result['accuracy']*100:<14.2f}%")


TRAINING ON DEBUG DATASET (10,000 sentences)
Training trigram model on 10,000 sentences...


Processing sentences: 100%|██████████| 10000/10000 [00:00<00:00, 25352.65it/s]



Training complete in 0.40 seconds
Total trigrams: 337,976
Total bigrams: 347,976
Total unigrams: 337,976
Unique trigrams: 185,369
Unique bigrams: 102,716
Unique unigrams: 14,885
Vocabulary size: 14,884

Evaluating on development set...


Predicting: 100%|██████████| 94825/94825 [00:28<00:00, 3367.44it/s]



Results:
  Total examples: 94,825
  Correct: 41,172
  Accuracy: 0.4342 (43.42%)

TRAINING ON DEV DATASET (100,000 sentences)
Training trigram model on 100,000 sentences...


Processing sentences: 100%|██████████| 100000/100000 [00:04<00:00, 21178.54it/s]



Training complete in 4.72 seconds
Total trigrams: 3,435,020
Total bigrams: 3,535,020
Total unigrams: 3,435,020
Unique trigrams: 1,454,743
Unique bigrams: 576,344
Unique unigrams: 42,351
Vocabulary size: 42,350

Evaluating on development set...


Predicting: 100%|██████████| 94825/94825 [02:02<00:00, 774.04it/s]



Results:
  Total examples: 94,825
  Correct: 47,310
  Accuracy: 0.4989 (49.89%)

TRAINING ON LARGE DATASET (1,000,000 sentences)
Training trigram model on 1,000,000 sentences...


Processing sentences: 100%|██████████| 1000000/1000000 [00:56<00:00, 17693.93it/s]



Training complete in 56.52 seconds
Total trigrams: 34,184,775
Total bigrams: 35,184,775
Total unigrams: 34,184,775
Unique trigrams: 8,776,975
Unique bigrams: 2,433,677
Unique unigrams: 79,021
Vocabulary size: 79,020

Evaluating on development set...


Predicting: 100%|██████████| 94825/94825 [05:14<00:00, 301.50it/s]



Results:
  Total examples: 94,825
  Correct: 51,285
  Accuracy: 0.5408 (54.08%)

SCALING RESULTS SUMMARY
Dataset         # Sentences     Accuracy       
---------------------------------------------
debug           10,000          43.42         %
dev             100,000         49.89         %
large           1,000,000       54.08         %
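
To read the scaling trend more directly, the sketch below computes the accuracy gain between consecutive training sizes. The figures are copied from the runs recorded in this notebook, including the 58.12% full-data result from section 2.7; the variable names are illustrative.

```python
# (training sentences, dev accuracy in %) as reported by the runs in this notebook
measured = [
    (10_000, 43.42),
    (100_000, 49.89),
    (1_000_000, 54.08),
    (3_803_957, 58.12),
]

gains = []
for (n_prev, acc_prev), (n_next, acc_next) in zip(measured, measured[1:]):
    gain = round(acc_next - acc_prev, 2)
    gains.append(gain)
    print(f"{n_prev:>9,} -> {n_next:>9,}: +{gain:.2f} points")
```

The first two steps are 10x increases in data, while the last step is roughly 3.8x, so the per-step gains are not directly comparable.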


## 2.10 Save Model (Optional)

Save the trained model for later use.

In [None]:
import pickle

# Save model
model_filename = f'trigram_model_{CURRENT_SIZE}.pkl'
with open(model_filename, 'wb') as f:
    pickle.dump(model, f)

print(f"Model saved to {model_filename}")

# To load later:
# with open(model_filename, 'rb') as f:
#     loaded_model = pickle.load(f)

OSError: [Errno 28] No space left on device

In [19]:
import sys

# Quick model size check
def get_size_mb(obj):
    """Rough estimate in MB; sys.getsizeof is shallow, so nested contents are undercounted"""
    size = sys.getsizeof(obj)
    if hasattr(obj, '__dict__'):
        for key, val in obj.__dict__.items():
            size += sys.getsizeof(val)
            if isinstance(val, dict):
                for k, v in val.items():
                    size += sys.getsizeof(k) + sys.getsizeof(v)
    return size / (1024 * 1024)

# Check model size
size_mb = get_size_mb(model)
print(f"Model size: {size_mb:.2f} MB")

if size_mb > 500:
    print("⚠️ WARNING: Model is too large to save!")
else:
    print("✓ Safe to save")

Model size: 5382.06 MB
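
`sys.getsizeof` is shallow: for a tuple key it counts the tuple itself but not the strings it refers to, and for a set it ignores the members. The figure above is therefore a rough estimate. A recursive variant, sketched below with a hypothetical `deep_size_bytes` helper, follows containers and counts each object once. It is slow on tens of millions of entries, so it is probably best tried on a small model first.

```python
import sys

def deep_size_bytes(obj, seen=None):
    """Recursively sum sys.getsizeof over containers, counting each object once."""
    if seen is None:
        seen = set()
    if id(obj) in seen:
        return 0
    seen.add(id(obj))

    size = sys.getsizeof(obj)
    if isinstance(obj, dict):
        for key, value in obj.items():
            size += deep_size_bytes(key, seen) + deep_size_bytes(value, seen)
    elif isinstance(obj, (list, tuple, set, frozenset)):
        for item in obj:
            size += deep_size_bytes(item, seen)
    elif hasattr(obj, "__dict__"):
        # Plain Python objects: follow their attribute dictionary
        size += deep_size_bytes(vars(obj), seen)
    return size

# Toy counts, invented for illustration
counts = {("new", "york", "city"): 3, ("in", "the", "morning"): 9}
shallow = sys.getsizeof(counts)
deep = deep_size_bytes(counts)
print(shallow, deep)  # deep is larger because keys and values are included
```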


## 2.11 Next Steps

**Current Status:**
- ✅ Trigram model implemented
- ✅ Trained on the 10K, 100K, 1M, and full (3.8M) datasets
- ✅ Evaluated on dev set (58.12% accuracy with the full dataset)

**To improve performance:**

1. **More data**: Already explored in sections 2.7 and 2.9
   - Measured: 43.42% on 10K, 49.89% on 100K, 54.08% on 1M, 58.12% on Full

2. **Better smoothing**: Add-k smoothing or Kneser-Ney
   - Current: Simple MLE with backoff
   - Improvement: +3-8% accuracy

3. **Higher-order n-grams**: 4-gram or 5-gram
   - More context → better predictions
   - Expected: +3-5% accuracy

4. **KenLM**: Use optimized library with Modified Kneser-Ney
   - Expected: 58-65% accuracy
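
As a pointer for item 2, add-k smoothing replaces the MLE estimate with (Count(w1, w2, w3) + k) / (Count(w1, w2) + k·V), where V is the vocabulary size. The sketch below is the generic textbook form with a hypothetical function name and toy numbers; the value of k would need tuning on the dev set, and no accuracy figure is claimed for it here.

```python
def add_k_trigram_prob(trigram_count, context_count, vocab_size, k=0.1):
    """Add-k estimate of P(w3 | w1, w2); never zero for k > 0 and vocab_size > 0."""
    return (trigram_count + k) / (context_count + k * vocab_size)

V = 1_000  # toy vocabulary size, invented for illustration
seen = add_k_trigram_prob(trigram_count=8, context_count=10, vocab_size=V)
unseen = add_k_trigram_prob(trigram_count=0, context_count=10, vocab_size=V)
print(f"seen: {seen:.4f}, unseen: {unseen:.6f}")

# The smoothed distribution still sums to 1 over the vocabulary
toy_counts = [8, 2] + [0] * (V - 2)
total = sum(add_k_trigram_prob(c, 10, V) for c in toy_counts)
print(f"sum over vocabulary: {total:.6f}")
```

Within a single context, add-k does not change the ranking of candidates, since the estimate increases with the raw count. Its effect on this task would come from how probabilities of different orders are combined, for example through interpolation.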

**Next notebook:**
- `3_4gram_Model.ipynb` - Implement 4-gram baseline
- Or jump to `3_KenLM.ipynb` for best n-gram performance