In [86]:
"""
IMPROVED RL AGENT
=================
Incorporates best practices from the high-performing solution:
1. Balanced reward function
2. Simpler feature representation
3. More stable training
"""

import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import deque, Counter
import matplotlib.pyplot as plt
import pickle
import re


"""
IMPROVED HMM WITH BIGRAM MODELING AND DYNAMIC WEIGHTING
======================================================
Incorporates the bigram transition approach and adds dynamic weighting between
the sequence model and the pattern-matching model for more robust predictions.
"""

import numpy as np
import pickle
from collections import Counter
import re

# ==============================================================================
# FINAL HMM CLASS DEFINITION (COPY THIS ENTIRE CLASS INTO YOUR RL NOTEBOOK)
# ==============================================================================

class ImprovedHangmanHMM:
    """
    Letter predictor for Hangman that blends two signals:

    * a sequence model: interpolated trigram / bigram / unigram letter
      statistics conditioned on the revealed letters to the LEFT of each blank;
    * a pattern model: letter frequencies among the stored words that are still
      compatible with the masked word and the letters guessed so far.

    The pattern model's weight grows as fewer dictionary words remain
    compatible (see ``predict_letter_probabilities``).
    """

    def __init__(self, smoothing=1.0):
        """
        Args:
            smoothing: additive count given to every n-gram before training
                counts are accumulated.
        """
        self.smoothing = smoothing
        self.alphabet = 'abcdefghijklmnopqrstuvwxyz'
        self.letter_to_idx = {c: i for i, c in enumerate(self.alphabet)}
        self.idx_to_letter = {i: c for i, c in enumerate(self.alphabet)}

        # Uniform placeholders until train() or load() supplies real estimates.
        self.unigram_probs = np.ones(26) / 26
        self.bigram_probs = np.ones((26, 26)) / 26
        self.trigram_probs = np.ones((26, 26, 26)) / 26

        # length -> list of lower-cased dictionary words of that length
        self.words_by_length = {}
        self.trained = False

    def train(self, words):
        """
        Estimate the n-gram tables and rebuild the per-length dictionary.

        Words are lower-cased; empty words and words containing any character
        outside ``self.alphabet`` are skipped.

        Args:
            words: iterable of strings.
        """
        print("\nTraining HMM with Trigram, Bigram, and Unigram models...")
        unigram_counts = np.full(26, self.smoothing, dtype=float)
        bigram_counts = np.full((26, 26), self.smoothing, dtype=float)
        trigram_counts = np.full((26, 26, 26), self.smoothing, dtype=float)

        # The counts start from scratch on every call, so the dictionary must
        # too; otherwise a second train() call would duplicate every word.
        self.words_by_length = {}

        for word in words:
            word = word.lower()
            if not word or not all(c in self.alphabet for c in word):
                continue
            self.words_by_length.setdefault(len(word), []).append(word)

            indices = [self.letter_to_idx[c] for c in word]
            for idx in indices:
                unigram_counts[idx] += 1.0
            for prev_idx, curr_idx in zip(indices, indices[1:]):
                bigram_counts[prev_idx, curr_idx] += 1.0
            for p_prev_idx, prev_idx, curr_idx in zip(indices, indices[1:], indices[2:]):
                trigram_counts[p_prev_idx, prev_idx, curr_idx] += 1.0

        self.unigram_probs = unigram_counts / unigram_counts.sum()
        # The small epsilon keeps the division defined when smoothing == 0.
        self.bigram_probs = bigram_counts / (bigram_counts.sum(axis=1, keepdims=True) + 1e-9)
        self.trigram_probs = trigram_counts / (trigram_counts.sum(axis=2, keepdims=True) + 1e-9)
        self.trained = True

    def _get_pattern_probs(self, masked_word, guessed_letters):
        """
        Letter distribution over the dictionary words compatible with the game.

        A word is compatible when it has the same length, agrees with every
        revealed position, and has no already-known letter (guessed or
        revealed) in a hidden position. In Hangman a correct guess reveals
        every occurrence of the letter, so a letter that is already known
        cannot be hiding behind a blank.

        Args:
            masked_word: string with '_' for hidden positions.
            guessed_letters: iterable of letters guessed so far.

        Returns:
            (pattern_probs, num_matching_words): a length-26 array where each
            not-yet-known letter is weighted by the number of compatible words
            containing it (normalised to sum to 1, or all zeros when nothing
            matches), and the number of compatible words.
        """
        pattern_probs = np.zeros(26)
        candidates = self.words_by_length.get(len(masked_word))
        if not candidates:
            return pattern_probs, 0

        known = set(guessed_letters) | {ch for ch in masked_word if ch != '_'}
        # Only alphabet letters go into the character class, so no escaping is
        # needed there; revealed characters are escaped to be matched literally.
        excluded = ''.join(sorted(ch for ch in known if ch in self.letter_to_idx))
        blank = f'[^{excluded}]' if excluded else '.'
        pattern_regex = re.compile(
            ''.join(blank if ch == '_' else re.escape(ch) for ch in masked_word)
        )

        matching_words = [w for w in candidates if pattern_regex.fullmatch(w)]
        if not matching_words:
            return pattern_probs, 0

        # Count each candidate word once per distinct, still-unknown letter.
        letter_counts = Counter(
            c
            for w in matching_words
            for c in set(w)
            if c not in known and c in self.letter_to_idx
        )
        total_counts = sum(letter_counts.values())
        if total_counts:
            for letter, count in letter_counts.items():
                pattern_probs[self.letter_to_idx[letter]] = count / total_counts
        return pattern_probs, len(matching_words)

    def predict_letter_probabilities(self, masked_word, guessed_letters):
        """
        Score every letter of the alphabet as the next guess.

        Args:
            masked_word: string with '_' for hidden positions.
            guessed_letters: iterable of letters guessed so far.

        Returns:
            Length-26 numpy array. Uniform when the model is untrained, all
            zeros when there is no blank left; otherwise guessed letters are
            zeroed and the remainder is normalised to sum to 1 (when non-zero).
        """
        if not self.trained:
            return np.ones(26) / 26

        # Interpolation weights for unigram, bigram and trigram respectively.
        lambda1, lambda2, lambda3 = 0.1, 0.3, 0.6
        if masked_word.count('_') == 0:
            return np.zeros(26)

        presence_probs = np.zeros(26)
        for i, char in enumerate(masked_word):
            if char != '_':
                continue
            # .get() yields None for '_' and for any character outside the
            # alphabet, which is then treated as "no usable context".
            prev_idx = self.letter_to_idx.get(masked_word[i - 1]) if i > 0 else None
            p_prev_idx = self.letter_to_idx.get(masked_word[i - 2]) if i > 1 else None

            # Compare against None explicitly: index 0 ('a') is falsy.
            if prev_idx is None:
                prob_dist = self.unigram_probs
            elif p_prev_idx is None:
                prob_dist = ((1 - lambda1) * self.bigram_probs[prev_idx, :]
                             + lambda1 * self.unigram_probs)
            else:
                prob_dist = (lambda3 * self.trigram_probs[p_prev_idx, prev_idx, :]
                             + lambda2 * self.bigram_probs[prev_idx, :]
                             + lambda1 * self.unigram_probs)
            presence_probs += prob_dist

        if presence_probs.sum() > 0:
            presence_probs /= presence_probs.sum()
        pattern_probs, num_matching_words = self._get_pattern_probs(masked_word, guessed_letters)

        if pattern_probs.sum() > 0:
            # Trust the dictionary more as the candidate set shrinks; the
            # weight is clamped to the range [0.5, 0.95].
            pattern_weight = max(0.5, min(0.95, 0.95 - (num_matching_words / 500.0)))
            combined = pattern_weight * pattern_probs + (1.0 - pattern_weight) * presence_probs
        else:
            combined = presence_probs.copy()

        for letter in guessed_letters:
            if letter in self.letter_to_idx:
                combined[self.letter_to_idx[letter]] = 0

        if combined.sum() > 0:
            combined /= combined.sum()

        return combined

    def save(self, filename='improved_hmm_model.pkl'):
        """Pickle the probability tables, dictionary and alphabet to ``filename``."""
        with open(filename, 'wb') as f:
            pickle.dump({
                'unigram_probs': self.unigram_probs,
                'bigram_probs': self.bigram_probs,
                'trigram_probs': self.trigram_probs,
                'words_by_length': self.words_by_length,
                'alphabet': self.alphabet,
                'letter_to_idx': self.letter_to_idx,
                'smoothing': self.smoothing
            }, f)
        print(f"Model saved to {filename}")

    def load(self, filename='improved_hmm_model.pkl'):
        """
        Restore a model written by ``save`` and mark it as trained.

        SECURITY: pickle.load can execute arbitrary code; only load files that
        come from a trusted source.
        """
        with open(filename, 'rb') as f:
            data = pickle.load(f)
        self.unigram_probs = data['unigram_probs']
        self.bigram_probs = data['bigram_probs']
        self.trigram_probs = data['trigram_probs']
        self.words_by_length = data['words_by_length']
        self.alphabet = data['alphabet']
        self.letter_to_idx = data['letter_to_idx']
        # Keep the reverse mapping consistent with whatever mapping was loaded.
        self.idx_to_letter = {i: c for c, i in self.letter_to_idx.items()}
        self.smoothing = data['smoothing']
        self.trained = True
        print(f"Model loaded from {filename}")

In [87]:
# ==============================================================================
# IMPROVED ENVIRONMENT WITH BETTER REWARDS
# ==============================================================================

class ImprovedHangmanEnvironment:
    """
    Hangman environment with a balanced reward function.

    Rewards handed out by ``step``:
        * repeated guess: -0.2 (no state change besides the repeat counter)
        * correct guess:  +0.5 per revealed position, +10.0 extra on a win
        * wrong guess:    -1.0, -2.0 extra when the last life is lost
        * any guess after the game has ended: 0.0 (no state change)
    """

    def __init__(self, word_list, max_wrong=6):
        """
        Args:
            word_list: either a flat sequence of words or a dict mapping a
                word length to a sequence of words of that length.
            max_wrong: number of wrong guesses that ends the game.
        """
        self.word_list = word_list
        self.max_wrong = max_wrong
        self.reset()

    def reset(self, word=None):
        """
        Start a new game and return its initial state.

        Args:
            word: target word to use; when None a random one is drawn from
                ``self.word_list``. For a dict, a length key is drawn first and
                then a word from that bucket.
        """
        if word is None:
            if isinstance(self.word_list, dict):
                random_length = random.choice(list(self.word_list.keys()))
                word = random.choice(self.word_list[random_length])
            else:
                word = random.choice(self.word_list)
        self.target_word = word.lower()

        self.masked_word = ['_'] * len(self.target_word)
        self.guessed_letters = set()
        self.wrong_guesses = 0
        self.repeated_guesses = 0
        self.done = False

        return self.get_state()

    def get_state(self):
        """Return a snapshot dict of the current game (guessed set is copied)."""
        return {
            'masked_word': ''.join(self.masked_word),
            'guessed_letters': self.guessed_letters.copy(),
            'wrong_guesses': self.wrong_guesses,
            'lives_left': self.max_wrong - self.wrong_guesses,
            'done': self.done,
            'target_word': self.target_word
        }

    def step(self, letter):
        """
        Apply one guess.

        Args:
            letter: the guessed letter (case-insensitive).

        Returns:
            (state, reward, done, info) where ``info`` always carries the
            'repeated' and 'correct' flags, 'positions' for a correct guess,
            and 'game_over' when the guess arrived after the game had ended.
        """
        letter = letter.lower()

        # A finished game is frozen: without this guard a guess made after a
        # loss could still reveal letters and even collect the win bonus.
        if self.done:
            info = {'repeated': False, 'correct': False, 'game_over': True}
            return self.get_state(), 0.0, True, info

        if letter in self.guessed_letters:
            self.repeated_guesses += 1
            info = {'repeated': True, 'correct': False}
            return self.get_state(), -0.2, self.done, info

        self.guessed_letters.add(letter)

        if letter in self.target_word:
            positions_revealed = 0
            for i, char in enumerate(self.target_word):
                if char == letter:
                    self.masked_word[i] = letter
                    positions_revealed += 1

            reward = 0.5 * positions_revealed
            if '_' not in self.masked_word:
                self.done = True
                reward += 10.0

            info = {'repeated': False, 'correct': True, 'positions': positions_revealed}
        else:
            self.wrong_guesses += 1
            reward = -1.0
            if self.wrong_guesses >= self.max_wrong:
                self.done = True
                reward -= 2.0

            info = {'repeated': False, 'correct': False}

        return self.get_state(), reward, self.done, info

In [88]:
# ==============================================================================
# SIMPLIFIED NEURAL NETWORK
# ==============================================================================

class SimplifiedDQN(nn.Module):
    """
    Small fully connected Q-network: state -> 128 -> 64 -> one value per action.

    Dropout (p=0.1) is applied after the first hidden activation only.
    """

    def __init__(self, state_size, action_size):
        super().__init__()
        wide, narrow = 128, 64

        # Attribute names and creation order are part of the checkpoint format
        # (state_dict keys) and of seeded initialisation, so they stay fixed.
        self.fc1 = nn.Linear(state_size, wide)
        self.fc2 = nn.Linear(wide, narrow)
        self.fc3 = nn.Linear(narrow, action_size)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        """Map a batch of encoded states to a batch of per-action Q-values."""
        first_hidden = self.dropout(F.relu(self.fc1(x)))
        second_hidden = F.relu(self.fc2(first_hidden))
        return self.fc3(second_hidden)

In [89]:
# ==============================================================================
# IMPROVED RL AGENT
# ==============================================================================

class ImprovedRLAgent:
    """
    Double-DQN agent whose actions are *ranks* in the HMM's letter ordering.

    Action ``k`` (0-based) means "guess the (k+1)-th most probable un-guessed
    letter according to the HMM". The state is a 55-dimensional vector:
    3 scalar features + 26 HMM probabilities + 26 guessed-letter flags.
    """

    def __init__(self, hmm_model, max_wrong=6):
        """
        Args:
            hmm_model: object exposing
                ``predict_letter_probabilities(masked_word, guessed_letters)``
                returning a length-26 array.
            max_wrong: number of lives a game starts with; used only to
                normalise the "lives left" feature.
        """
        self.hmm = hmm_model
        self.alphabet = 'abcdefghijklmnopqrstuvwxyz'
        self.letter_to_idx = {c: i for i, c in enumerate(self.alphabet)}
        self.idx_to_letter = {i: c for i, c in enumerate(self.alphabet)}
        self.max_wrong = max_wrong

        self.state_size = 55   # 3 scalars + 26 HMM probs + 26 guessed flags
        self.action_size = 5   # pick the 1st .. 5th best HMM suggestion

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.qnetwork = SimplifiedDQN(self.state_size, self.action_size).to(self.device)
        self.target_qnetwork = SimplifiedDQN(self.state_size, self.action_size).to(self.device)
        self.target_qnetwork.load_state_dict(self.qnetwork.state_dict())
        # The target network is only ever used for inference; eval mode turns
        # its dropout layer off so bootstrap targets are deterministic.
        self.target_qnetwork.eval()

        self.optimizer = optim.Adam(self.qnetwork.parameters(), lr=0.001)
        self.memory = deque(maxlen=20000)
        self.batch_size = 128
        self.gamma = 0.99
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.9995

        self.update_target_every = 10
        self.losses = []
        self.episode_rewards = []
        self.episode_steps = []
        self.episode_wins = []

    def encode_state(self, state):
        """
        Turn an environment state dict into the 55-dim float32 feature vector.

        Layout: [lives_left / max_wrong, blanks / len, unique_revealed / len,
        26 HMM probabilities, 26 guessed-letter flags]. Guessed entries that
        are not alphabet letters are ignored.
        """
        masked_word = state['masked_word']
        guessed_letters = state['guessed_letters']
        word_len = len(masked_word)

        lives_norm = state['lives_left'] / self.max_wrong if self.max_wrong else 0.0
        blanks_norm = masked_word.count('_') / word_len if word_len > 0 else 0
        unique_revealed = len(set(c for c in masked_word if c != '_'))
        unique_revealed_norm = unique_revealed / word_len if word_len > 0 else 0

        hmm_probs = self.hmm.predict_letter_probabilities(masked_word, guessed_letters)

        guessed_vec = np.zeros(26)
        for letter in guessed_letters:
            idx = self.letter_to_idx.get(letter)
            if idx is not None:
                guessed_vec[idx] = 1.0

        return np.concatenate([
            np.array([lives_norm, blanks_norm, unique_revealed_norm]),
            hmm_probs,
            guessed_vec
        ]).astype(np.float32)

    def get_ranked_letter_choices(self, state):
        """
        Return up to ``action_size`` letter indices, best HMM candidate first,
        excluding letters that were already guessed.
        """
        guessed = state['guessed_letters']
        hmm_probs = self.hmm.predict_letter_probabilities(state['masked_word'], guessed)
        by_probability_desc = np.argsort(hmm_probs)[::-1]
        ranked_choices = [
            int(idx) for idx in by_probability_desc
            if self.idx_to_letter[int(idx)] not in guessed
        ]
        return ranked_choices[:self.action_size]

    def get_action(self, state, training=True):
        """
        Choose a rank action and translate it to a letter index.

        Args:
            state: environment state dict.
            training: when True, explore with probability ``epsilon``.

        Returns:
            (action_idx, letter_idx). ``(0, 0)`` is returned when no
            un-guessed letter is available.
        """
        ranked_choices = self.get_ranked_letter_choices(state)
        if not ranked_choices:
            return 0, 0

        if training and random.random() < self.epsilon:
            action_idx = random.randrange(len(ranked_choices))
        else:
            encoded_state = self.encode_state(state)
            state_tensor = torch.FloatTensor(encoded_state).unsqueeze(0).to(self.device)

            # Greedy selection must not be perturbed by dropout: switch to eval
            # mode for the forward pass and restore the previous mode after.
            was_training = self.qnetwork.training
            self.qnetwork.eval()
            try:
                with torch.no_grad():
                    q_values = self.qnetwork(state_tensor).cpu().numpy()[0]
            finally:
                self.qnetwork.train(was_training)

            # Ranks beyond the number of available letters are not selectable.
            q_values[len(ranked_choices):] = -np.inf
            action_idx = int(np.argmax(q_values))

        return action_idx, ranked_choices[action_idx]

    def remember(self, state, action, reward, next_state, done):
        """Encode both states and append the transition to the replay buffer."""
        encoded_state = self.encode_state(state)
        encoded_next_state = self.encode_state(next_state)
        self.memory.append((encoded_state, action, reward, encoded_next_state, done))

    def replay(self):
        """
        Run one Double-DQN update on a random minibatch.

        Does nothing until the buffer holds at least ``batch_size`` transitions.
        """
        if len(self.memory) < self.batch_size:
            return

        batch = random.sample(self.memory, self.batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Explicit dtypes: rewards may mix Python and NumPy floats, dones are bools.
        states = torch.from_numpy(np.stack(states).astype(np.float32)).to(self.device)
        actions = torch.from_numpy(np.asarray(actions, dtype=np.int64)).to(self.device)
        rewards = torch.from_numpy(np.asarray(rewards, dtype=np.float32)).to(self.device)
        next_states = torch.from_numpy(np.stack(next_states).astype(np.float32)).to(self.device)
        dones = torch.from_numpy(np.asarray(dones, dtype=np.float32)).to(self.device)

        # squeeze(1) (not squeeze()) keeps the batch dimension even for size 1.
        self.qnetwork.train()
        current_q = self.qnetwork(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Double DQN: the online network picks the next action, the target
        # network evaluates it. Both run without dropout for stable targets.
        self.qnetwork.eval()
        try:
            with torch.no_grad():
                best_next_actions = self.qnetwork(next_states).argmax(1, keepdim=True)
                next_q = self.target_qnetwork(next_states).gather(1, best_next_actions).squeeze(1)
                target_q = rewards + (1 - dones) * self.gamma * next_q
        finally:
            self.qnetwork.train()

        loss = F.mse_loss(current_q, target_q)

        self.optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.qnetwork.parameters(), 1.0)
        self.optimizer.step()

        self.losses.append(loss.item())

    def update_epsilon(self):
        """Multiply epsilon by ``epsilon_decay``, never going below ``epsilon_min``."""
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def update_target_network(self):
        """Copy the online network's weights into the target network."""
        self.target_qnetwork.load_state_dict(self.qnetwork.state_dict())

    def save(self, filename='improved_rl_agent.pth'):
        """
        Write networks, optimizer state, epsilon and training stats to disk.

        Stats are converted to plain Python numbers so the checkpoint contains
        no NumPy scalar objects, which restricted unpicklers may refuse.
        """
        torch.save({
            'qnetwork': self.qnetwork.state_dict(),
            'target_qnetwork': self.target_qnetwork.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': float(self.epsilon),
            'stats': {
                'rewards': [float(r) for r in self.episode_rewards],
                'steps': [int(s) for s in self.episode_steps],
                'losses': [float(l) for l in self.losses],
                'wins': [int(w) for w in self.episode_wins]
            }
        }, filename)
        print(f"Agent saved to {filename}")

    def load(self, filename='improved_rl_agent.pth'):
        """
        Restore an agent written by ``save``.

        Checkpoints written before 'wins' was stored load with an empty list.
        """
        checkpoint = torch.load(filename, map_location=self.device)
        self.qnetwork.load_state_dict(checkpoint['qnetwork'])
        self.target_qnetwork.load_state_dict(checkpoint['target_qnetwork'])
        self.target_qnetwork.eval()
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.epsilon = checkpoint['epsilon']
        if 'stats' in checkpoint:
            stats = checkpoint['stats']
            self.episode_rewards = stats['rewards']
            self.episode_steps = stats['steps']
            self.losses = stats['losses']
            self.episode_wins = stats.get('wins', [])
        print(f"Agent loaded from {filename}")

In [90]:
def train_improved_agent(agent, env, num_episodes=8000):
    """
    Train ``agent`` on ``env`` with a word-length curriculum.

    The maximum word length grows with the episode index (6 -> 8 -> 11 -> 25).
    For every wrong, non-repeated guess an extra penalty of
    ``(1 - hmm_prob_of_that_letter) * 0.25`` is subtracted from the
    environment reward before the transition is stored.

    Args:
        agent: object providing get_action / remember / replay /
            update_epsilon / update_target_network, the ``idx_to_letter`` and
            ``hmm`` attributes, and the ``episode_*`` stat lists (which are
            reset here).
        env: environment whose ``word_list`` is either a dict mapping length
            to words or a flat sequence of words.
        num_episodes: number of games to play.

    Raises:
        ValueError: if ``env.word_list`` contains no words.
    """
    print("\n" + "="*70)
    print("TRAINING RL AGENT WITH STRATEGIC ACTION SPACE")
    print("="*70)

    # Clear stats from previous runs, if any
    agent.episode_rewards, agent.episode_steps, agent.episode_wins = [], [], []

    # The environment accepts both layouts, so the trainer must as well.
    if isinstance(env.word_list, dict):
        # Drop empty buckets: random.choice would fail on them.
        words_by_length = {l: ws for l, ws in env.word_list.items() if ws}
    else:
        words_by_length = {}
        for w in env.word_list:
            words_by_length.setdefault(len(w), []).append(w)
    if not words_by_length:
        raise ValueError("env.word_list contains no words to train on")
    all_lengths = list(words_by_length.keys())

    for episode in range(num_episodes):
        # Curriculum: start with short words and unlock longer ones over time.
        if episode < 1500:   max_len = 6
        elif episode < 3000: max_len = 8
        elif episode < 5000: max_len = 11
        else:                max_len = 25

        # Fall back to every length when none is short enough for this stage.
        available_lengths = [l for l in all_lengths if l <= max_len] or all_lengths

        chosen_length = random.choice(available_lengths)
        word = random.choice(words_by_length[chosen_length])
        state = env.reset(word=word)

        total_reward = 0.0
        steps = 0

        while not state['done']:
            strategic_action, letter_action = agent.get_action(state, training=True)
            letter_to_guess = agent.idx_to_letter[letter_action]
            next_state, reward, done, info = env.step(letter_to_guess)

            is_wrong_guess = not info['correct'] and not info['repeated']
            if is_wrong_guess:
                # Shaping: a miss on a letter the HMM considered unlikely is
                # punished harder than a miss on a plausible letter.
                hmm_probs = agent.hmm.predict_letter_probabilities(state['masked_word'], state['guessed_letters'])
                prob_of_wrong_guess = hmm_probs[letter_action]
                reward -= (1.0 - prob_of_wrong_guess) * 0.25

            # Plain float so stored rewards and stats hold no NumPy scalars.
            reward = float(reward)

            agent.remember(state, strategic_action, reward, next_state, done)
            agent.replay()

            total_reward += reward
            steps += 1
            state = next_state

        # A game is won exactly when no blank is left in the final state.
        is_win = '_' not in state['masked_word']
        agent.episode_wins.append(1 if is_win else 0)

        if episode % agent.update_target_every == 0:
            agent.update_target_network()
        agent.update_epsilon()

        agent.episode_rewards.append(total_reward)
        agent.episode_steps.append(steps)

        if (episode + 1) % 200 == 0:
            avg_reward = np.mean(agent.episode_rewards[-200:])
            avg_steps = np.mean(agent.episode_steps[-200:])
            success_rate = np.mean(agent.episode_wins[-200:])

            print(f"Episode {episode+1}/{num_episodes} | "
                  f"Success Rate: {success_rate:.1%} | "
                  f"Avg Reward: {avg_reward:.2f} | "
                  f"Avg Steps: {avg_steps:.2f} | "
                  f"Epsilon: {agent.epsilon:.3f} | "
                  f"Max Len: {max_len}")

    print("\nTraining complete!")

In [91]:
# ==============================================================================
# EVALUATION (Corrected)
# ==============================================================================

def evaluate_improved_agent(agent, test_words, max_wrong=6):
    """
    Play one greedy (no exploration) game per test word and report statistics.

    Args:
        agent: object whose ``get_action(state, training=False)`` returns
            ``(strategic_action, letter_idx)`` and which exposes
            ``idx_to_letter``.
        test_words: sequence of target words; may be empty.
        max_wrong: number of wrong guesses allowed per game.

    Returns:
        dict with 'success_rate', 'avg_wrong', 'avg_repeated' and
        'final_score'. The score is
        ``success_rate * 2000 - 5 * wrong - 2 * repeated`` with the wrong and
        repeated totals rescaled to 1000 games. All values are 0.0 for an
        empty test set.
    """
    print("\n" + "="*70)
    print("EVALUATING IMPROVED AGENT")
    print("="*70)

    num_games = len(test_words)
    if num_games == 0:
        # Nothing to play: avoid indexing test_words[0] and dividing by zero.
        print("No test words supplied; nothing to evaluate.")
        return {
            'success_rate': 0.0,
            'avg_wrong': 0.0,
            'avg_repeated': 0.0,
            'final_score': 0.0
        }

    wins = 0
    total_wrong = 0
    total_repeated = 0

    # One environment instance is reused; every game is reset to its own word.
    env = ImprovedHangmanEnvironment([test_words[0]], max_wrong=max_wrong)

    for i, word in enumerate(test_words):
        state = env.reset(word=word)

        while not state['done']:
            # get_action returns (rank action, letter index); only the letter
            # index is needed to play.
            _, letter_action = agent.get_action(state, training=False)
            letter = agent.idx_to_letter[letter_action]
            state, reward, done, info = env.step(letter)

        if '_' not in state['masked_word']:
            wins += 1
        total_wrong += env.wrong_guesses
        total_repeated += env.repeated_guesses

        if (i + 1) % 500 == 0:
            print(f"Evaluated {i+1}/{num_games} words...")

    success_rate = wins / num_games
    avg_wrong = total_wrong / num_games
    avg_repeated = total_repeated / num_games

    # Penalty totals are rescaled as if exactly 1000 games had been played.
    scaling_factor = 1000 / num_games
    final_score = (success_rate * 2000) - ((total_wrong * scaling_factor) * 5) - ((total_repeated * scaling_factor) * 2)

    print(f"\n{'='*70}")
    print("EVALUATION RESULTS")
    print(f"{'='*70}")
    print(f"Games Played: {num_games}")
    print(f"Success Rate: {success_rate:.2%}")
    print(f"Average Wrong Guesses per Game: {avg_wrong:.2f}")
    print(f"Average Repeated Guesses per Game: {avg_repeated:.2f}")
    print(f"FINAL SCORE (scaled to 1000 games as per prompt): {final_score:.2f}")
    print(f"{'='*70}")

    return {
        'success_rate': success_rate,
        'avg_wrong': avg_wrong,
        'avg_repeated': avg_repeated,
        'final_score': final_score
    }

In [None]:
# ==============================================================================
# MAIN
# ==============================================================================

if __name__ == "__main__":
    # End-to-end run: load the HMM, train the agent on the corpus, save it,
    # then evaluate on the held-out test words.
    print("="*70)
    print("IMPROVED HANGMAN SOLUTION")
    print("="*70)

    # Seed every RNG the pipeline draws from so a fresh run is repeatable.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Load HMM
    print("\nLoading improved HMM...")
    hmm = ImprovedHangmanHMM()
    hmm.load('improved_hmm_model.pkl')

    # Load corpus (explicit encoding: the platform default is not portable)
    print("\nLoading corpus...")
    with open('corpus.txt', 'r', encoding='utf-8') as f:
        corpus_words = [line.strip().lower() for line in f if line.strip()]
    print(f"Loaded {len(corpus_words)} words")

    # Group by length: the training curriculum samples a length first.
    words_by_len_dict = {}
    for word in corpus_words:
        words_by_len_dict.setdefault(len(word), []).append(word)
    print(f"Loaded and grouped {len(corpus_words)} words")

    # Pass the grouped dictionary to the environment
    env = ImprovedHangmanEnvironment(words_by_len_dict, max_wrong=6)
    agent = ImprovedRLAgent(hmm)

    train_improved_agent(agent, env, num_episodes=8000)

    agent.save('improved_rl_agent.pth')

    # Evaluate
    print("\nLoading test set...")
    with open('test.txt', 'r', encoding='utf-8') as f:
        test_words = [line.strip().lower() for line in f if line.strip()]

    results = evaluate_improved_agent(agent, test_words)

    print("\n✅ COMPLETE!")

IMPROVED HANGMAN SOLUTION

Loading improved HMM...
Model loaded from improved_hmm_model.pkl

Loading corpus...
Loaded 50000 words
Loaded and grouped 50000 words

TRAINING RL AGENT WITH STRATEGIC ACTION SPACE
Episode 200/8000 | Success Rate: 33.5% | Avg Reward: -3.01 | Avg Steps: 7.19 | Epsilon: 0.905 | Max Len: 6
Episode 400/8000 | Success Rate: 35.5% | Avg Reward: -2.61 | Avg Steps: 7.04 | Epsilon: 0.819 | Max Len: 6
Episode 600/8000 | Success Rate: 38.5% | Avg Reward: -2.21 | Avg Steps: 6.99 | Epsilon: 0.741 | Max Len: 6
