In [8]:
import random
import matplotlib.pyplot as plt
import numpy as np
import random
from tqdm import tqdm
from collections import Counter


from environments import WordleEnv, WordleEnvMarkov

# -----------------------------
# Q-Learning Agent Definition
# -----------------------------
def train_q_learning(env, num_episodes=1000, alpha=0.5, gamma=0.9, epsilon=0.2):
    """
    A simple tabular Q-learning algorithm that trains on the Wordle environment.
    
    The state is defined as a tuple: (attempt_number, last_feedback)
    where last_feedback is a tuple of length word_length.
    
    This version prevents the agent from guessing the same word twice in a single episode.
    """
    Q = {}  # Q-table

    def get_state(observation):
        # Extract Information state from observation
        return ()
    
    def choose_action(state, guessed_actions):
        allowed_actions = [a for a in range(env.action_space.n) if a not in guessed_actions]
        if not allowed_actions:
            return env.action_space.sample()
        if random.random() < epsilon or state not in Q:
            return random.choice(allowed_actions)
        else:
            q_values = Q[state]
            allowed_q = {a: q_values[a] for a in allowed_actions}
            return max(allowed_q, key=allowed_q.get)
    
    def update_Q(state, action, reward, next_state, done):
        if state not in Q:
            Q[state] = {a: 0 for a in range(env.action_space.n)}
        if next_state not in Q:
            Q[next_state] = {a: 0 for a in range(env.action_space.n)}
        best_next = max(Q[next_state].values())
        Q[state][action] += alpha * (reward + gamma * best_next * (1 - int(done)) - Q[state][action])
    
    for episode in range(num_episodes):
        observation, _ = env.reset()
        state = get_state(observation)
        guessed_actions = set() 
        done = False
        
        while not done:
            action = choose_action(state, guessed_actions)
            guessed_actions.add(action)
            next_observation, reward, done, _, _ = env.step(action)
            print("Episode: ", episode, "Action: ", action, "Reward: ", reward, "Done: ", done)
            next_state = get_state(next_observation)
            update_Q(state, action, reward, next_state, done)
            state = next_state
        # Optional: print progress every 100 episodes
        if (episode + 1) % 100 == 0:
            print(f"Episode {episode + 1}/{num_episodes} completed")
    return Q


# -----------------------------
# Testing the Trained Agent
# -----------------------------
def test_agent(env, Q):
    observation, _ = env.reset()
    state = (observation["attempt"], tuple(observation["feedback"].tolist()))
    guessed_actions = set()
    done = False
    print("\nTesting trained agent:")
    while not done:
        allowed_actions = [a for a in range(env.action_space.n) if a not in guessed_actions]
        if not allowed_actions:
            # Fallback if all actions have been guessed.
            action = env.action_space.sample()
        elif state in Q:
            # Choose the best allowed action based on Q-values.
            allowed_q = {a: Q[state][a] for a in allowed_actions}
            action = max(allowed_q, key=allowed_q.get)
        else:
            action = random.choice(allowed_actions)
        guessed_actions.add(action)
        observation, reward, done, _, _ = env.step(action)
        print(f"Guess: {env.word_list[action]}, Feedback: {observation['feedback']}, Reward: {reward}")
        state = (observation["attempt"], tuple(observation["feedback"].tolist()))
    env.render()

def simulate_game_with_target(env, Q, target_word):
    """
    Simulate a single game with the target word fixed to target_word.
    Uses the learned Q-values to choose actions. Returns the number of moves
    taken to solve the word if successful, or 7 if the agent fails within 6 moves.
    """
    observation, _ = env.reset()
    # Override the randomly chosen target with the given target_word.
    env.target = target_word
    state = (observation["attempt"], tuple(observation["feedback"].tolist()))
    guessed_actions = set()
    done = False

    while not done:
        allowed_actions = [a for a in range(env.action_space.n) if a not in guessed_actions]
        if not allowed_actions:
            action = env.action_space.sample()
        elif state in Q:
            allowed_q = {a: Q[state][a] for a in allowed_actions}
            action = max(allowed_q, key=allowed_q.get)
        else:
            action = random.choice(allowed_actions)
        guessed_actions.add(action)
        observation, reward, done, _, _ = env.step(action)
        state = (observation["attempt"], tuple(observation["feedback"].tolist()))
    # If reward is positive, the agent solved the word; otherwise record 7 moves (failure).
    return env.attempt if reward > 0 else 7


def plot_histogram(move_counts):
    """
    Plots a histogram with 7 bins: moves 1 through 6 and 7 for failures.
    The x-axis shows the number of moves and the y-axis the count of words.
    """
    # Create bins labeled 1 through 7.
    bins = np.arange(1, 9) - 0.5  # bin edges for 7 bins
    plt.hist(move_counts, bins=bins, edgecolor="black")
    plt.xlabel("Number of Moves (7 = failure)")
    plt.ylabel("Number of Words")
    plt.title("Histogram of Moves Required to Solve Wordle")
    plt.xticks(range(1, 8))
    plt.show()


In [2]:
env = WordleEnv(word_list_file="target_words.txt")
# Train the Q-learning agent.
Q = train_q_learning(env, num_episodes=1000)
# Test the trained agent.
# test_agent(env, Q)
    # Evaluate the agent on every possible target word.
move_counts = []
for target_word in tqdm(env.word_list):
    # For each target, reset the environment and simulate a game.
    moves = simulate_game_with_target(env, Q, target_word)
    # print(f"Target: {target_word}, Moves: {moves}")
    move_counts.append(moves)

Episode:  0 Action:  245 Reward:  0 Done:  False
Episode:  0 Action:  0 Reward:  2.5 Done:  False
Episode:  0 Action:  410 Reward:  0.5 Done:  False
Episode:  0 Action:  1 Reward:  0.5 Done:  False
Episode:  0 Action:  2 Reward:  1 Done:  False
Episode:  0 Action:  3 Reward:  -10 Done:  True
Episode:  1 Action:  0 Reward:  0.5 Done:  False
Episode:  1 Action:  1 Reward:  0 Done:  False
Episode:  1 Action:  2 Reward:  1 Done:  False
Episode:  1 Action:  3 Reward:  0.5 Done:  False
Episode:  1 Action:  4 Reward:  1.5 Done:  False
Episode:  1 Action:  5 Reward:  -10 Done:  True
Episode:  2 Action:  0 Reward:  0 Done:  False
Episode:  2 Action:  19 Reward:  0.5 Done:  False
Episode:  2 Action:  1 Reward:  1.5 Done:  False
Episode:  2 Action:  2 Reward:  1 Done:  False
Episode:  2 Action:  3 Reward:  1 Done:  False
Episode:  2 Action:  4 Reward:  -10 Done:  True
Episode:  3 Action:  0 Reward:  0.5 Done:  False
Episode:  3 Action:  1 Reward:  0.5 Done:  False
Episode:  3 Action:  2 Reward:  

100%|██████████| 2309/2309 [00:04<00:00, 501.28it/s]


In [3]:
print(Counter(move_counts).keys())
print(Counter(move_counts).values())

dict_keys([7, 2, 1, 4, 3])
dict_values([2302, 3, 1, 2, 1])


In [19]:
import random
import torch

def train_q_learning(env, num_episodes=1000, alpha=0.5, gamma=0.9, epsilon=0.2, word_length=5, num_letters=26):
    """
    Tabular Q-learning algorithm for Wordle, using an information state.

    Args:
        env: The Wordle environment
        num_episodes: Number of episodes to train
        alpha: Learning rate
        gamma: Discount factor
        epsilon: Exploration rate
        word_length: Length of the target word
        num_letters: Number of possible letters (26 for English alphabet)

    Returns:
        Q: The learned Q-table
    """
    Q = {}  # Q-table

    def get_information_state(observation):
        """
        Extract a compact information state from the observation, mimicking WordleFeatureExtractor_Markov.
        This version directly mirrors the logic of the `forward` method, adapted for a single environment.
        """

        # Initialize state representation
        state = torch.zeros((word_length, num_letters))
        greens = {}  # {letter_idx: [positions]}
        yellows = {}  # {letter_idx: [positions]}
        blacks = {}  # {letter_idx: [positions]}  for truly absent letters
        missing_letters = {pos: [] for pos in range(word_length)}

        attempt_idx = observation['attempt'].item()

        if attempt_idx == 0:
            return "initial" # Using same string as original q learning impl

        for guess_idx in range(attempt_idx):
            last_feedback = observation['board'][guess_idx]
            last_guess = observation['guesses'][guess_idx]

            if (last_guess < 0).any():
                continue  # Skip invalid guesses

            for idx, (feed, letter) in enumerate(zip(last_feedback, last_guess)):
                letter_item = letter.item()
                if feed == 2:  # Green
                    if letter_item not in greens:
                        greens[letter_item] = []
                    greens[letter_item].append(idx)

                elif feed == 1:  # Yellow
                    if letter_item not in yellows:
                        yellows[letter_item] = []
                    yellows[letter_item].append(idx)
                elif feed == 0:  # Black (Gray)
                    if letter_item not in blacks:
                        blacks[letter_item] = []
                    blacks[letter_item].append(idx)

        # Process green positions
        for letter_idx, positions in greens.items():
            for pos in positions:
                state[pos, letter_idx] = 1
                for other_letter in range(num_letters):
                    if other_letter != letter_idx:
                        state[pos, other_letter] = -1

         # Process yellows *after* greens
        for letter_idx, positions in yellows.items():

            # If the letter has been confirmed as green, adjust yellow processing.
            confirmed_greens_count = len(greens.get(letter_idx, []))
            
            # Exclude yellows from the positions and greens from being candidates
            candidate_positions = [p for p in range(word_length) if p not in positions and p not in greens.get(letter_idx,[])]
            for pos in positions:
                state[pos, letter_idx] = -1
                if letter_idx not in missing_letters[pos]:
                    missing_letters[pos].append(letter_idx)

            if candidate_positions:
                yellow_value = min(1.0, len(positions) / len(candidate_positions))

                for pos in candidate_positions:
                  state[pos, letter_idx] = yellow_value
                  if yellow_value == 1:  #yellow confirmed at position
                    for other_letter in range(num_letters):
                      if other_letter != letter_idx:
                        state[pos, other_letter] = -1



        # Process blacks *after* greens and yellows
        for letter_idx, positions in blacks.items():
            has_positive_info = (state[:, letter_idx] > 0).any()

            if has_positive_info:
                # If we have green or yellow info, just mark black positions as impossible
                for pos in positions:
                    state[pos, letter_idx] = -1
                    if letter_idx not in missing_letters[pos]:
                      missing_letters[pos].append(letter_idx)
            else:
                # No positive info, the letter is likely absent
                for pos in range(word_length):
                    state[pos, letter_idx] = -1
                    if letter_idx not in missing_letters[pos]:
                        missing_letters[pos].append(letter_idx)

        # Convert state to string representation for Q-table
        return state.flatten().numpy().tobytes()



    def choose_action(state, guessed_actions):
        """Choose an action based on the current state, using epsilon-greedy strategy."""
        allowed_actions = [a for a in range(env.action_space.n) if a not in guessed_actions]

        if not allowed_actions:
            return env.action_space.sample()  # All actions were tried

        if random.random() < epsilon or state not in Q:
            return random.choice(allowed_actions)
        else:
            q_values = Q[state]
            allowed_q = {a: q_values[a] for a in allowed_actions}
            return max(allowed_q, key=allowed_q.get)


    def update_Q(state, action, reward, next_state, done):
        """Update the Q-table using the Q-learning update rule."""
        if state not in Q:
            Q[state] = {a: 0 for a in range(env.action_space.n)}
        if next_state not in Q:
            Q[next_state] = {a: 0 for a in range(env.action_space.n)}

        best_next_action_value = max(Q[next_state].values()) if not done else 0
        Q[state][action] += alpha * (reward + gamma * best_next_action_value - Q[state][action])


    # Main training loop
    for episode in range(num_episodes):
        observation, _ = env.reset()
        state = get_information_state(observation)
        done = False
        guessed_actions = set()

        while not done:
            action = choose_action(state, guessed_actions)
            guessed_actions.add(action)
            next_observation, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            next_state = get_information_state(next_observation)
            update_Q(state, action, reward, next_state, done)
            state = next_state

        if episode % 100 == 0:
            print(f"Episode {episode}/{num_episodes} completed")

    return Q


def test_agent(env, Q, num_test_episodes=100, word_length=5, num_letters=26):
    """
    Test the trained agent.

    Args:
        env: The Wordle environment
        Q: The learned Q-table
        num_test_episodes: Number of test episodes to run

    Returns:
        results: List of (target_word, num_attempts, success) tuples
    """
    def get_information_state(observation):
        """
        Extract a compact information state from the observation, mimicking WordleFeatureExtractor_Markov.
        This version directly mirrors the logic of the `forward` method, adapted for a single environment.
        """

        # Initialize state representation
        state = torch.zeros((word_length, num_letters))
        greens = {}  # {letter_idx: [positions]}
        yellows = {}  # {letter_idx: [positions]}
        blacks = {}  # {letter_idx: [positions]}  for truly absent letters
        missing_letters = {pos: [] for pos in range(word_length)}

        attempt_idx = observation['attempt'].item()

        if attempt_idx == 0:
            return "initial" # Using same string as original q learning impl

        for guess_idx in range(attempt_idx):
            last_feedback = observation['board'][guess_idx]
            last_guess = observation['guesses'][guess_idx]

            if (last_guess < 0).any():
                continue  # Skip invalid guesses

            for idx, (feed, letter) in enumerate(zip(last_feedback, last_guess)):
                letter_item = letter.item()
                if feed == 2:  # Green
                    if letter_item not in greens:
                        greens[letter_item] = []
                    greens[letter_item].append(idx)

                elif feed == 1:  # Yellow
                    if letter_item not in yellows:
                        yellows[letter_item] = []
                    yellows[letter_item].append(idx)
                elif feed == 0:  # Black (Gray)
                    if letter_item not in blacks:
                        blacks[letter_item] = []
                    blacks[letter_item].append(idx)

        # Process green positions
        for letter_idx, positions in greens.items():
            for pos in positions:
                state[pos, letter_idx] = 1
                for other_letter in range(num_letters):
                    if other_letter != letter_idx:
                        state[pos, other_letter] = -1

         # Process yellows *after* greens
        for letter_idx, positions in yellows.items():

            # If the letter has been confirmed as green, adjust yellow processing.
            confirmed_greens_count = len(greens.get(letter_idx, []))
            
            # Exclude yellows from the positions and greens from being candidates
            candidate_positions = [p for p in range(word_length) if p not in positions and p not in greens.get(letter_idx,[])]
            for pos in positions:
                state[pos, letter_idx] = -1
                if letter_idx not in missing_letters[pos]:
                    missing_letters[pos].append(letter_idx)

            if candidate_positions:
                yellow_value = min(1.0, len(positions) / len(candidate_positions))

                for pos in candidate_positions:
                  state[pos, letter_idx] = yellow_value
                  if yellow_value == 1:  #yellow confirmed at position
                    for other_letter in range(num_letters):
                      if other_letter != letter_idx:
                        state[pos, other_letter] = -1



        # Process blacks *after* greens and yellows
        for letter_idx, positions in blacks.items():
            has_positive_info = (state[:, letter_idx] > 0).any()

            if has_positive_info:
                # If we have green or yellow info, just mark black positions as impossible
                for pos in positions:
                    state[pos, letter_idx] = -1
                    if letter_idx not in missing_letters[pos]:
                      missing_letters[pos].append(letter_idx)
            else:
                # No positive info, the letter is likely absent
                for pos in range(word_length):
                    state[pos, letter_idx] = -1
                    if letter_idx not in missing_letters[pos]:
                        missing_letters[pos].append(letter_idx)

        # Convert state to string representation for Q-table
        return state.flatten().numpy().tobytes()
    results = []
    for episode in range(num_test_episodes):
        observation, _ = env.reset()
        target_word = env.target_word
        state = get_information_state(observation)
        done = False
        guessed_actions = set()
        attempts = 0

        while not done:
            attempts += 1
            allowed_actions = [a for a in range(env.action_space.n) if a not in guessed_actions]
            if not allowed_actions: #if no actions are left for some reason, force it to quit out
                attempts = 7
                action = env.action_space.sample()
                done = True
            elif state in Q:
                # Find the best action among the allowed ones.
                q_values = Q[state]
                allowed_q = {a: q_values[a] for a in allowed_actions}
                action = max(allowed_q, key=allowed_q.get)

            else:
                action = random.choice(allowed_actions) #random allowed action

            guessed_actions.add(action)
            next_observation, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            next_state = get_information_state(next_observation)
            state = next_state

        success = env.won
        results.append((target_word, attempts, success))
        if episode % 20 == 0:
                print(f"tested {episode} / {num_test_episodes}")

    return results

In [20]:
env = WordleEnvMarkov("target_words.txt")
# Train the Q-learning agent.
# Train the Q-learning agent
Q = train_q_learning(env, num_episodes=10000)

# Test the agent
results = test_agent(env, Q, num_test_episodes=100)

# Analyze results
total_episodes = len(results)
successful_episodes = sum(1 for _, _, success in results if success)
success_rate = successful_episodes / total_episodes
print(f"Success rate: {success_rate:.4f}")

# Calculate average moves for successful games
move_counts = [attempts for _, attempts, success in results if success]
if move_counts:
    average_moves = sum(move_counts) / len(move_counts)
    print(f"Average moves for successful games: {average_moves:.2f}")
else:
    print("No successful games to calculate average moves.")

#create same histogram
# performance = evaluate_performance(results)
# plot_histogram(performance['move_distribution']

Episode 0/10000 completed
Episode 100/10000 completed
Episode 200/10000 completed
Episode 300/10000 completed
Episode 400/10000 completed
Episode 500/10000 completed
Episode 600/10000 completed
Episode 700/10000 completed
Episode 800/10000 completed
Episode 900/10000 completed
Episode 1000/10000 completed
Episode 1100/10000 completed
Episode 1200/10000 completed
Episode 1300/10000 completed
Episode 1400/10000 completed
Episode 1500/10000 completed
Episode 1600/10000 completed
Episode 1700/10000 completed
Episode 1800/10000 completed
Episode 1900/10000 completed
Episode 2000/10000 completed
Episode 2100/10000 completed


KeyboardInterrupt: 