# Proper HMM Training from Scratch

**Goal:** Train a PROPER HMM that gives good letter predictions

Key insight: Test words are NOT in corpus, so we need to learn PATTERNS, not memorize words

In [13]:
import sys
sys.path.append('../src')

import numpy as np
import pickle
from collections import Counter, defaultdict
from tqdm import tqdm

# Load corpus
with open('../Data/corpus.txt', 'r', encoding='utf-8') as f:
    corpus_words = [line.strip().lower() for line in f if line.strip()]

# Clean words
corpus_words = [''.join(c for c in word if c.isalpha()) for word in corpus_words]
corpus_words = [w for w in corpus_words if len(w) > 0]

print(f"Loaded {len(corpus_words)} words from corpus")
print(f"Sample: {corpus_words[:5]}")

Loaded 50000 words from corpus
Sample: ['suburbanize', 'asmack', 'hypotypic', 'promoderationist', 'consonantly']


## Build Better HMM Model

Focus on what actually helps:
1. Global letter frequency (for first guesses)
2. Bigrams (letter pairs)
3. Position-based frequency
4. Pattern-based predictions

In [8]:
class ImprovedHMM:
    def __init__(self):
        self.alphabet = 'abcdefghijklmnopqrstuvwxyz'
        
        # Global letter frequency
        self.global_freq = Counter()
        
        # Bigrams: char1 -> char2 -> count
        self.bigrams = defaultdict(Counter)
        
        # Position frequencies (first 20 positions)
        self.position_freq = defaultdict(Counter)
        
        # Letter patterns by word length
        self.length_patterns = defaultdict(Counter)
    
    def train(self, words):
        print("Training Improved HMM...")
        
        for word in tqdm(words, desc="Processing words"):
            # Global frequency
            for char in word:
                self.global_freq[char] += 1
            
            # Bigrams
            for i in range(len(word) - 1):
                self.bigrams[word[i]][word[i+1]] += 1
            
            # Position frequency
            for i, char in enumerate(word):
                if i < 20:
                    self.position_freq[i][char] += 1
            
            # Length patterns
            length = len(word)
            for char in set(word):
                self.length_patterns[length][char] += 1
        
        # Normalize global frequency
        total = sum(self.global_freq.values())
        self.global_freq = {c: count/total for c, count in self.global_freq.items()}
        
        print(f"✓ Training complete")
        print(f"  Top letters: {sorted(self.global_freq.items(), key=lambda x: -x[1])[:10]}")
    
    def predict_letter_probabilities(self, masked_word, guessed_letters, word_length):
        """Predict letter probabilities for current state."""
        probs = {c: 0.0 for c in self.alphabet}
        
        # Strategy 1: Use global frequency (baseline)
        for char in self.alphabet:
            if char not in guessed_letters:
                probs[char] += self.global_freq.get(char, 0.0) * 1.0
        
        # Strategy 2: Use bigrams from revealed letters
        for i, char in enumerate(masked_word):
            if char is not None:
                # Look ahead
                if i + 1 < len(masked_word) and masked_word[i+1] is None:
                    if char in self.bigrams:
                        total = sum(self.bigrams[char].values())
                        if total > 0:
                            for next_char, count in self.bigrams[char].items():
                                if next_char not in guessed_letters:
                                    probs[next_char] += (count / total) * 2.0
                
                # Look behind
                if i > 0 and masked_word[i-1] is None:
                    for prev_char in self.alphabet:
                        if prev_char not in guessed_letters and prev_char in self.bigrams:
                            if char in self.bigrams[prev_char]:
                                count = self.bigrams[prev_char][char]
                                total = sum(self.bigrams[prev_char].values())
                                if total > 0:
                                    probs[prev_char] += (count / total) * 2.0
        
        # Strategy 3: Position-based frequency for blank positions
        for i, char in enumerate(masked_word):
            if char is None and i < 20:
                if i in self.position_freq:
                    total = sum(self.position_freq[i].values())
                    if total > 0:
                        for c, count in self.position_freq[i].items():
                            if c not in guessed_letters:
                                probs[c] += (count / total) * 1.5
        
        # Strategy 4: Length-based patterns
        if word_length in self.length_patterns:
            total = sum(self.length_patterns[word_length].values())
            if total > 0:
                for c, count in self.length_patterns[word_length].items():
                    if c not in guessed_letters:
                        probs[c] += (count / total) * 0.5
        
        # Normalize
        total = sum(probs.values())
        if total > 0:
            probs = {c: p/total for c, p in probs.items()}
        
        return probs
    
    def save(self, filepath):
        with open(filepath, 'wb') as f:
            pickle.dump(self, f)
        print(f"✓ Model saved to {filepath}")
    
    @staticmethod
    def load(filepath):
        with open(filepath, 'rb') as f:
            return pickle.load(f)

print("✓ ImprovedHMM class defined")

✓ ImprovedHMM class defined


In [9]:
# Train the improved HMM
hmm = ImprovedHMM()
hmm.train(corpus_words)

Training Improved HMM...


Processing words: 100%|██████████| 50000/50000 [00:00<00:00, 179770.78it/s]

✓ Training complete
  Top letters: [('e', 0.10366177251017158), ('a', 0.0886802624817838), ('i', 0.088591813870427), ('o', 0.0754529832453059), ('r', 0.07079890155248372), ('n', 0.0701565961604879), ('t', 0.06779164876635246), ('s', 0.061164320672546395), ('l', 0.057714824829631126), ('c', 0.04573635574873856)]





In [10]:
# Test on a sample
print("Testing HMM predictions:\n")

# Test 1: Empty word
probs = hmm.predict_letter_probabilities([None]*5, set(), 5)
top = sorted(probs.items(), key=lambda x: -x[1])[:5]
print(f"Empty 5-letter word:")
print(f"  Top 5: {top}\n")

# Test 2: With 'e' revealed
masked = [None, 'e', None, None, None]
probs = hmm.predict_letter_probabilities(masked, {'e'}, 5)
top = sorted(probs.items(), key=lambda x: -x[1])[:5]
print(f"Word: _e___")
print(f"  Top 5: {top}\n")

# Test 3: With 'ing' ending
masked = [None, None, None, 'i', 'n', 'g']
probs = hmm.predict_letter_probabilities(masked, {'i', 'n', 'g'}, 6)
top = sorted(probs.items(), key=lambda x: -x[1])[:5]
print(f"Word: ___ing")
print(f"  Top 5: {top}")

Testing HMM predictions:

Empty 5-letter word:
  Top 5: [('e', 0.0944451971757346), ('a', 0.0901352794680006), ('o', 0.07865477524185965), ('r', 0.07671637486089959), ('i', 0.07334579624967802)]

Word: _e___
  Top 5: [('r', 0.07881948412134224), ('v', 0.06761419297406523), ('t', 0.06269464086398246), ('n', 0.05917035749228527), ('l', 0.055121413538864183)]

Word: ___ing
  Top 5: [('t', 0.07404734518306579), ('r', 0.07216262892429859), ('l', 0.06139897668415388), ('d', 0.060572843365315675), ('a', 0.05679379676741129)]


In [11]:
# Quick evaluation on test set
from hangman_env import HangmanEnv
from utils import calculate_final_score

# Load test words
with open('../Data/test.txt', 'r') as f:
    test_words = [''.join(c for c in line.strip().lower() if c.isalpha()) for line in f if line.strip()][:500]

def test_hmm(hmm, test_words):
    results = []
    for word in tqdm(test_words, desc="Testing HMM"):
        env = HangmanEnv(word, max_lives=6)
        env.reset()
        
        while not env.done:
            masked = env.get_masked_word_list()
            probs = hmm.predict_letter_probabilities(masked, env.guessed_letters, len(word))
            
            # Pick best available letter
            available = {k: v for k, v in probs.items() if k not in env.guessed_letters}
            if available:
                action = max(available, key=available.get)
            else:
                break
            
            env.step(action)
        
        stats = env.get_stats()
        results.append({'won': env.won, 'wrong': stats['wrong_count'], 'repeated': stats['repeated_count']})
    
    wins = sum(1 for r in results if r['won'])
    rate = wins / len(results)
    wrong = sum(r['wrong'] for r in results)
    repeated = sum(r['repeated'] for r in results)
    score = calculate_final_score(rate, wrong, repeated, len(results))
    
    return rate, score, wrong, repeated

print("\n" + "="*60)
print("TESTING IMPROVED HMM (500 words)")
print("="*60)

rate, score, wrong, repeated = test_hmm(hmm, test_words)

print(f"\nWin Rate: {rate:.4f} ({rate*100:.2f}%)")
print(f"Wrong Guesses: {wrong} (avg: {wrong/len(test_words):.2f})")
print(f"Repeated: {repeated}")
print(f"Score: {score:.2f}")
print("="*60)


TESTING IMPROVED HMM (500 words)


Testing HMM: 100%|██████████| 500/500 [00:00<00:00, 1937.59it/s]


Win Rate: 0.2260 (22.60%)
Wrong Guesses: 2737 (avg: 5.47)
Repeated: 0
Score: -13572.00





In [12]:
# Save the improved HMM
hmm.save('../models/improved_hmm.pkl')
print("\n✅ Improved HMM saved and ready for RL training!")

✓ Model saved to ../models/improved_hmm.pkl

✅ Improved HMM saved and ready for RL training!
