In [1]:
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv, VecNormalize
from stable_baselines3.common.monitor import Monitor
import os
import random
from ppo_agents import WordleFeatureExtractor, WordleFeatureExtractor_Markov
from environments import WordleEnv, WordleEnvMarkov  
from trainer import WordleTrainingCallback
from heuristics import HeuristicWordleAgent

def make_env(word_list_path, rank=0):
    """
    Create a WordleEnv environment for Stable Baselines3.
    
    Args:
        word_list_path: Path to the word list file
        rank: Environment rank (for vectorized environments)
        
    Returns:
        A function that creates an instance of the environment
    """
    def _init():
        env = WordleEnvMarkov(word_list_path, max_attempts=6, word_length=5)
        env = Monitor(env)
        return env
    return _init


def train_ppo_agent(word_list_path, total_timesteps=100000, log_dir='./logs'):
    """
    Train a PPO agent on the Wordle environment.
    
    Args:
        word_list_path: Path to the word list file
        total_timesteps: Number of training steps
        log_dir: Directory to save logs
        
    Returns:
        The trained PPO model
    """
    # Create environments
    env = DummyVecEnv([make_env(word_list_path)])
    env = VecNormalize(env, norm_obs=False, norm_reward=True)
    
    eval_env = DummyVecEnv([make_env(word_list_path)])
    eval_env = VecNormalize(eval_env, norm_obs=False, norm_reward=True, training=False)
    
    # Policy kwargs for the feature extractor
    policy_kwargs = {
        'features_extractor_class': WordleFeatureExtractor_Markov,
        'features_extractor_kwargs': {'features_dim': 256}
    }
    
    # Create the PPO agent
    model = PPO(
        "MultiInputPolicy",
        env,
        policy_kwargs=policy_kwargs,
        verbose=1,
        learning_rate=3e-4,
        n_steps=2048,
        batch_size=64,
        n_epochs=10,
        gamma=0.99,
        gae_lambda=0.95,
        clip_range=0.2,
        ent_coef=0.01,
        tensorboard_log=log_dir
    )
    
    # Create the callback
    callback = WordleTrainingCallback(eval_env, check_freq=5000, log_dir=log_dir)
    
    # Train the agent
    model.learn(total_timesteps=total_timesteps, callback=callback)
    
    # Plot training metrics
    callback.plot_metrics()
    
    # Save the model
    model.save(os.path.join(log_dir, "ppo_wordle"))
    
    return model


def evaluate_ppo_agent(model, word_list_path, num_episodes=50, render=True):
    """
    Evaluate a trained PPO agent on the Wordle environment.
    
    Args:
        model: The trained PPO model
        word_list_path: Path to the word list file
        num_episodes: Number of episodes to evaluate
        render: Whether to render the environment
        
    Returns:
        Evaluation results
    """
    # Create the evaluation environment
    eval_env = WordleEnvMarkov(word_list_path, max_attempts=6, word_length=5, render_mode="human" if render else None)
    
    # Load word list
    try:
        with open(word_list_path, 'r') as f:
            words = [w.strip().lower() for w in f.readlines() if len(w.strip()) == 5 and w.strip().isalpha()]
    except FileNotFoundError:
        # Use sample words if file doesn't exist
        print(f"Warning: {word_list_path} not found. Using a small sample of words.")
        words = [
            'apple', 'baker', 'child', 'dance', 'early', 'first', 'grand', 'house', 'input',
            'jolly', 'knife', 'light', 'mouse', 'night', 'ocean', 'piano', 'queen', 'river',
            'sound', 'table', 'under', 'value', 'water', 'xenon', 'youth', 'zebra'
        ]
    
    # Evaluate the agent
    wins = 0
    rewards = []
    attempts = []
    
    for episode in range(num_episodes):
        obs, _ = eval_env.reset()
        done = False
        episode_reward = 0
        used_words = []
        
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            
            # Get the word
            word = words[action]
            if word in used_words:
                # If we've already used this word, pick a random unused word
                unused_words = [w for w in words if w not in used_words]
                if unused_words:
                    word = random.choice(unused_words)
                    action = words.index(word)
                
            used_words.append(word)
            
            obs, reward, terminated, truncated, info = eval_env.step(action)
            episode_reward += reward
            done = terminated or truncated
            
            if render:
                eval_env.render()
        
        rewards.append(episode_reward)
        won = eval_env.won
        wins += 1 if won else 0
        if won:
            attempts.append(eval_env.current_attempt)
        
        if render:
            print(f"Episode {episode+1}/{num_episodes} | " +
                  f"Result: {'Won' if won else 'Lost'} | " +
                  f"Word: {eval_env.target_word} | " +
                  f"Attempts: {eval_env.current_attempt}/{eval_env.max_attempts}" +
                  f" | Reward: {episode_reward:.2f}")
    
    win_rate = wins / num_episodes
    avg_reward = sum(rewards) / num_episodes
    avg_attempts = sum(attempts) / len(attempts) if attempts else 0
    
    print(f"\nEvaluation Results:")
    print(f"Win Rate: {win_rate:.2f}")
    print(f"Average Reward: {avg_reward:.2f}")
    print(f"Average Attempts (when won): {avg_attempts:.2f}")
    
    return {
        'win_rate': win_rate,
        'avg_reward': avg_reward,
        'avg_attempts': avg_attempts
    }


def compare_with_heuristic(model, word_list_path, num_episodes=50):
    """
    Compare the PPO agent with a heuristic agent.
    
    Args:
        model: The trained PPO model
        word_list_path: Path to the word list file
        num_episodes: Number of episodes for comparison
        
    Returns:
        Comparison results
    """
    
    # Load word list
    try:
        with open(word_list_path, 'r') as f:
            words = [w.strip().lower() for w in f.readlines() if len(w.strip()) == 5 and w.strip().isalpha()]
    except FileNotFoundError:
        # Use sample words if file doesn't exist
        print(f"Warning: {word_list_path} not found. Using a small sample of words.")
        words = [
            'apple', 'baker', 'child', 'dance', 'early', 'first', 'grand', 'house', 'input',
            'jolly', 'knife', 'light', 'mouse', 'night', 'ocean', 'piano', 'queen', 'river',
            'sound', 'table', 'under', 'value', 'water', 'xenon', 'youth', 'zebra'
        ]
    
    # Create the heuristic agent
    heuristic_agent = HeuristicWordleAgent(words)
    
    # Create environment
    env = WordleEnvMarkov(word_list_path, max_attempts=6, word_length=5)
    
    # Generate random target words
    target_words = random.sample(words, min(num_episodes, len(words)))
    
    # Evaluate both agents
    results = {
        'ppo': {'wins': 0, 'attempts': []},
        'heuristic': {'wins': 0, 'attempts': []}
    }
    
    for i, target_word in enumerate(target_words):
        print(f"\nGame {i+1}/{num_episodes} - Target word: {target_word}")
        
        # Evaluate PPO agent
        obs, _ = env.reset(options={'target_word': target_word})
        done = False
        used_words = []
        
        print("PPO agent playing...")
        while not done:
            action, _ = model.predict(obs, deterministic=True)
            word = words[action]
            
            # Avoid repeated words
            if word in used_words:
                unused_words = [w for w in words if w not in used_words]
                if unused_words:
                    word = random.choice(unused_words)
                    action = words.index(word)
            
            used_words.append(word)
            obs, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
        
        ppo_won = env.won
        ppo_attempts = env.current_attempt
        results['ppo']['wins'] += 1 if ppo_won else 0
        if ppo_won:
            results['ppo']['attempts'].append(ppo_attempts)
        
        print(f"PPO result: {'Won' if ppo_won else 'Lost'} in {ppo_attempts} attempts")
        
        # Evaluate heuristic agent
        heuristic_agent.reset()
        env.reset(options={'target_word': target_word})
        
        print("Heuristic agent playing...")
        for attempt in range(env.max_attempts):
            guess = heuristic_agent.get_action()
            action = words.index(guess) if guess in words else 0
            
            obs, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            
            # Generate feedback
            feedback = []
            for j in range(env.word_length):
                if env.board[attempt, j, env.CORRECT] == 1:
                    feedback.append(2)  # Correct
                elif env.board[attempt, j, env.PRESENT] == 1:
                    feedback.append(1)  # Present
                else:
                    feedback.append(0)  # Absent
            
            heuristic_agent.update(guess, feedback)
            
            if done:
                break
        
        heuristic_won = env.won
        heuristic_attempts = env.current_attempt
        results['heuristic']['wins'] += 1 if heuristic_won else 0
        if heuristic_won:
            results['heuristic']['attempts'].append(heuristic_attempts)
        
        print(f"Heuristic result: {'Won' if heuristic_won else 'Lost'} in {heuristic_attempts} attempts")
    
    # Calculate statistics
    ppo_win_rate = results['ppo']['wins'] / num_episodes
    heuristic_win_rate = results['heuristic']['wins'] / num_episodes
    
    ppo_avg_attempts = sum(results['ppo']['attempts']) / len(results['ppo']['attempts']) if results['ppo']['attempts'] else 0
    heuristic_avg_attempts = sum(results['heuristic']['attempts']) / len(results['heuristic']['attempts']) if results['heuristic']['attempts'] else 0
    
    print("\nComparison Results:")
    print(f"PPO Win Rate: {ppo_win_rate:.2f}")
    print(f"Heuristic Win Rate: {heuristic_win_rate:.2f}")
    print(f"PPO Average Attempts (when won): {ppo_avg_attempts:.2f}")
    print(f"Heuristic Average Attempts (when won): {heuristic_avg_attempts:.2f}")
    
    return results

2025-03-21 00:19:37.809838: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2025-03-21 00:19:38.127282: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-03-21 00:19:38.147443: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2025-03-21 00:19:38.147467: I tensorflow/compiler/xla/stream_executor/cuda/cudart_stub.cc:29] Ignore 

In [None]:
word_list_path = "target_words.txt"
log_dir = "./logs/ppo_wordle"
os.makedirs(log_dir, exist_ok=True)
print("Training PPO agent...")
model = train_ppo_agent(word_list_path, total_timesteps=100000, log_dir=log_dir)
model.save("ppo_wordle")
print("\nEvaluating PPO agent...")

Training PPO agent...
Using cpu device
Logging to ./logs/ppo_wordle/PPO_20
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 5.97     |
|    ep_rew_mean     | 67.7     |
| time/              |          |
|    fps             | 118      |
|    iterations      | 1        |
|    time_elapsed    | 17       |
|    total_timesteps | 2048     |
---------------------------------


KeyboardInterrupt: 

In [None]:
model = PPO.load("ppo_wordle2025-03-19_21-20-40.zip")
evaluate_ppo_agent(model, word_list_path, num_episodes=20, render=True)

⬛🟨⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 1/6
⬛🟨⬛⬛⬛
⬛⬛⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 2/6
⬛🟨⬛⬛⬛
⬛⬛⬛⬛⬛
⬛🟨⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 3/6
⬛🟨⬛⬛⬛
⬛⬛⬛⬛⬛
⬛🟨⬛⬛⬛
⬛⬛⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 4/6
⬛🟨⬛⬛⬛
⬛⬛⬛⬛⬛
⬛🟨⬛⬛⬛
⬛⬛⬛⬛⬛
⬛🟨🟨⬛⬛
⬜⬜⬜⬜⬜
Attempt 5/6
⬛🟨⬛⬛⬛
⬛⬛⬛⬛⬛
⬛🟨⬛⬛⬛
⬛⬛⬛⬛⬛
⬛🟨🟨⬛⬛
⬛🟨⬛⬛🟨
Attempt 6/6
Game over! The word was: rally
Episode 1/20 | Result: Lost | Word: rally | Attempts: 6/6 | Reward: 0.00
⬛⬛⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 1/6
⬛⬛⬛⬛⬛
⬛⬛⬛🟩🟩
⬜⬜⬜⬜⬜
Attempt 2/6
⬛⬛⬛⬛⬛
⬛⬛⬛🟩🟩
⬛⬛⬛⬛🟩
⬜⬜⬜⬜⬜
Attempt 3/6
⬛⬛⬛⬛⬛
⬛⬛⬛🟩🟩
⬛⬛⬛⬛🟩
⬛⬛⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 4/6
⬛⬛⬛⬛⬛
⬛⬛⬛🟩🟩
⬛⬛⬛⬛🟩
⬛⬛⬛⬛⬛
⬛⬛⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 5/6
⬛⬛⬛⬛⬛
⬛⬛⬛🟩🟩
⬛⬛⬛⬛🟩
⬛⬛⬛⬛⬛
⬛⬛⬛⬛⬛
⬛⬛🟩⬛🟩
Attempt 6/6
Game over! The word was: sunny
Episode 2/20 | Result: Lost | Word: sunny | Attempts: 6/6 | Reward: 1.40
⬛⬛🟨🟨⬛
⬜⬜⬜⬜⬜
Attempt 1/6
⬛⬛🟨🟨⬛
⬛🟩⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 2/6
⬛⬛🟨🟨⬛
⬛🟩⬛⬛⬛
🟨⬛⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 3/6
⬛⬛🟨🟨⬛
⬛🟩⬛⬛⬛
🟨⬛⬛⬛⬛
⬛⬛⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 4/6
⬛⬛🟨🟨⬛
⬛🟩⬛⬛⬛
🟨⬛⬛⬛⬛
⬛⬛⬛⬛⬛
⬛🟩⬛⬛⬛
⬜⬜⬜⬜⬜
Attempt 5/6
⬛⬛🟨🟨⬛
⬛🟩⬛⬛⬛
🟨⬛⬛⬛⬛
⬛⬛⬛⬛⬛
⬛🟩⬛⬛⬛
⬛⬛⬛🟩⬛
Attempt 6/6
Game over! The word was: mimic
Episode 3/20 | Result: Lost | Word: mimic | Attempts: 6/6 | Reward: 0.90
⬛⬛⬛⬛

{'win_rate': 0.0, 'avg_reward': np.float32(1.1850001), 'avg_attempts': 0}

In [None]:

# Compare with heuristic agent
print("\nComparing with heuristic agent...")
compare_with_heuristic(model, word_list_path, num_episodes=20)


Comparing with heuristic agent...

Game 1/20 - Target word: visit
PPO agent playing...
PPO result: Lost in 6 attempts
Heuristic agent playing...
Heuristic result: Lost in 6 attempts

Game 2/20 - Target word: filer
PPO agent playing...
PPO result: Lost in 6 attempts
Heuristic agent playing...
Heuristic result: Lost in 6 attempts

Game 3/20 - Target word: axiom
PPO agent playing...
PPO result: Lost in 6 attempts
Heuristic agent playing...
Heuristic result: Won in 2 attempts

Game 4/20 - Target word: sworn
PPO agent playing...
PPO result: Lost in 6 attempts
Heuristic agent playing...
Heuristic result: Lost in 6 attempts

Game 5/20 - Target word: carry
PPO agent playing...
PPO result: Lost in 6 attempts
Heuristic agent playing...
Heuristic result: Lost in 6 attempts

Game 6/20 - Target word: hunch
PPO agent playing...
PPO result: Lost in 6 attempts
Heuristic agent playing...
Heuristic result: Lost in 6 attempts

Game 7/20 - Target word: flail
PPO agent playing...
PPO result: Lost in 6 att

{'ppo': {'wins': 0, 'attempts': []},
 'heuristic': {'wins': 3, 'attempts': [2, 3, 3]}}