In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
import wandb
from transformers import GPT2Model, GPT2Config
import random
import numpy as np
from collections import Counter
# random.seed(42)

In [None]:
# Wide but shallow Transformer Policy Network for Hangman
class HangmanTransformerPolicy(nn.Module):
    """Policy network that maps an encoded Hangman state to action logits.

    Each input row is laid out as
    ``[word tokens (max_length) | incorrect-guess indicators (vocab_size) | lives (1)]``.
    With the default arguments this is the 50 + 37 + 1 = 88 column layout.

    This variant returns raw logits and does NOT mask letters that were
    already guessed. A later cell of this notebook defines a class with the
    same name that does; whichever cell ran last is the one in effect.

    NOTE(review): the word is processed by ``GPT2Model``, which is presumably
    causal (left-to-right attention) -- confirm that this is intended.

    Args:
        vocab_size: Number of regular vocabulary ids / output actions.
        hidden_size: Width of every embedding and of the transformer.
        num_layers: Number of transformer layers.
        max_length: Number of word-token columns at the start of each row.
    """

    def __init__(self, vocab_size=37, hidden_size=1024, num_layers=2, max_length=50):
        super().__init__()
        self.max_length = max_length
        self.vocab_size = vocab_size

        # The two special tokens sit directly after the regular vocabulary ids
        # (37 and 38 for the default vocab_size of 37).
        self.HIDDEN_TOKEN = vocab_size
        self.PADDING_TOKEN = vocab_size + 1

        # Embedding for the word state: regular ids plus the hidden and padding tokens.
        self.word_embedding = nn.Embedding(vocab_size + 2, hidden_size)

        # Transformer for processing the word state
        self.config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=max_length,
            n_embd=hidden_size,
            n_layer=num_layers,  # Shallow (not deep)
            n_head=16,  # Wide attention heads
            n_inner=hidden_size * 4,  # Wide feed-forward layers
        )
        self.transformer = GPT2Model(self.config)

        # Linear projection for the multi-hot vector of incorrectly guessed letters
        self.incorrect_projection = nn.Linear(vocab_size, hidden_size)

        # Embedding for lives; 11 rows, so the lives value must lie in 0-10
        self.lives_embedding = nn.Embedding(11, hidden_size)

        # Final policy head (action logits) over the three concatenated feature vectors
        self.policy_head = nn.Linear(hidden_size * 3, vocab_size)

    def forward(self, x):
        """Compute one logit per vocabulary entry for every row of ``x``.

        Args:
            x: 2-D tensor whose columns follow the layout described on the class.

        Returns:
            Tensor with ``vocab_size`` logits per input row.
        """
        # Column boundaries are derived from the constructor arguments so that
        # a non-default max_length / vocab_size slices the row correctly.
        word_end = self.max_length
        guesses_end = word_end + self.vocab_size

        word_state = x[:, :word_end].long()
        incorrect_guesses = x[:, word_end:guesses_end].float()
        lives = x[:, guesses_end].long()

        # Process word state
        word_embeds = self.word_embedding(word_state)

        # 1.0 on real positions, 0.0 on padding positions
        attention_mask = (word_state != self.PADDING_TOKEN).float()

        # Pass through transformer
        transformer_outputs = self.transformer(
            inputs_embeds=word_embeds,
            attention_mask=attention_mask
        )
        hidden_states = transformer_outputs.last_hidden_state

        # Mean-pool over the non-padding positions; the epsilon avoids a
        # division by zero for a row that consists of padding only.
        masked_hidden = hidden_states * attention_mask.unsqueeze(-1)
        sum_hidden = masked_hidden.sum(dim=1)
        word_features = sum_hidden / (attention_mask.sum(dim=1, keepdim=True) + 1e-10)

        # Process incorrectly guessed letters
        incorrect_features = self.incorrect_projection(incorrect_guesses)

        # Process lives
        lives_features = self.lives_embedding(lives)

        # Combine all features
        combined_features = torch.cat([word_features, incorrect_features, lives_features], dim=1)

        # Get action logits
        return self.policy_head(combined_features)

In [None]:
######## Masking already guessed letters
# Wide but shallow Transformer Policy Network for Hangman
class HangmanTransformerPolicy(nn.Module):
    """Hangman policy network whose logits exclude already-guessed letters.

    Each input row is laid out as
    ``[word tokens (max_length) | incorrect-guess indicators (vocab_size) | lives (1)]``.
    With the default arguments this is the 50 + 37 + 1 = 88 column layout.

    The returned logits are ``-inf`` for every vocabulary id that is already
    visible in the word or flagged as an incorrect guess, so a softmax/argmax
    over them never picks such an entry.

    NOTE(review): the word is processed by ``GPT2Model``, which is presumably
    causal (left-to-right attention) -- confirm that this is intended.

    Args:
        vocab_size: Number of regular vocabulary ids / output actions.
        hidden_size: Width of every embedding and of the transformer.
        num_layers: Number of transformer layers.
        max_length: Number of word-token columns at the start of each row.
    """

    def __init__(self, vocab_size=37, hidden_size=1024, num_layers=2, max_length=50):
        super().__init__()
        self.max_length = max_length
        self.vocab_size = vocab_size

        # The two special tokens sit directly after the regular vocabulary ids
        # (37 and 38 for the default vocab_size of 37).
        self.HIDDEN_TOKEN = vocab_size
        self.PADDING_TOKEN = vocab_size + 1

        # Embedding for the word state: regular ids plus the hidden and padding tokens.
        self.word_embedding = nn.Embedding(vocab_size + 2, hidden_size)

        # Transformer for processing the word state
        self.config = GPT2Config(
            vocab_size=vocab_size,
            n_positions=max_length,
            n_embd=hidden_size,
            n_layer=num_layers,  # Shallow (not deep)
            n_head=16,  # Wide attention heads
            n_inner=hidden_size * 4,  # Wide feed-forward layers
        )
        self.transformer = GPT2Model(self.config)

        # Linear projection for the multi-hot vector of incorrectly guessed letters
        self.incorrect_projection = nn.Linear(vocab_size, hidden_size)

        # Embedding for lives; 11 rows, so the lives value must lie in 0-10
        self.lives_embedding = nn.Embedding(11, hidden_size)

        # Final policy head (action logits) over the three concatenated feature vectors
        self.policy_head = nn.Linear(hidden_size * 3, vocab_size)

    def forward(self, x):
        """Compute masked action logits for every row of ``x``.

        Args:
            x: 2-D tensor whose columns follow the layout described on the class.

        Returns:
            Tensor with ``vocab_size`` logits per row; entries for ids that
            are visible in the word or marked as incorrect guesses are ``-inf``.
        """
        # Column boundaries are derived from the constructor arguments so that
        # a non-default max_length / vocab_size slices the row correctly.
        word_end = self.max_length
        guesses_end = word_end + self.vocab_size

        word_state = x[:, :word_end].long()
        incorrect_guesses = x[:, word_end:guesses_end].float()
        lives = x[:, guesses_end].long()

        # Process word state
        word_embeds = self.word_embedding(word_state)

        # 1.0 on real positions, 0.0 on padding positions
        attention_mask = (word_state != self.PADDING_TOKEN).float()

        # Pass through transformer
        transformer_outputs = self.transformer(
            inputs_embeds=word_embeds,
            attention_mask=attention_mask
        )
        hidden_states = transformer_outputs.last_hidden_state

        # Mean-pool over the non-padding positions; the epsilon avoids a
        # division by zero for a row that consists of padding only.
        masked_hidden = hidden_states * attention_mask.unsqueeze(-1)
        sum_hidden = masked_hidden.sum(dim=1)
        word_features = sum_hidden / (attention_mask.sum(dim=1, keepdim=True) + 1e-10)

        # Process incorrectly guessed letters
        incorrect_features = self.incorrect_projection(incorrect_guesses)

        # Process lives
        lives_features = self.lives_embedding(lives)

        # Combine all features
        combined_features = torch.cat([word_features, incorrect_features, lives_features], dim=1)

        # Get action logits
        action_logits = self.policy_head(combined_features)

        # Mark every token id that occurs in each row's word state. One scatter
        # replaces the per-sample Python loop; the columns for the hidden and
        # padding tokens are dropped by the slice below.
        token_seen = torch.zeros(
            word_state.shape[0], self.PADDING_TOKEN + 1, device=word_state.device
        )
        token_seen.scatter_(1, word_state, 1.0)
        revealed = token_seen[:, :self.HIDDEN_TOKEN] > 0

        # Letters flagged (value 1) in the incorrect-guess vector
        wrong = incorrect_guesses > 0.5

        # Prevent selecting anything that has already been guessed
        return action_logits.masked_fill(revealed | wrong, float('-inf'))

In [None]:
# Pick the compute device: use CUDA when PyTorch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Load a word list and split into training and validation sets
def load_words(filename="words.txt", min_length=5, split_ratio=0.9):
    """Read a one-word-per-line file and split it into train/validation lists.

    Lines are stripped and lower-cased; words shorter than ``min_length``
    (measured after stripping) are dropped. The kept words are shuffled with
    a fixed seed of 42, so the split is the same on every call, and the
    number of kept words is printed.

    Args:
        filename: Path of the text file to read.
        min_length: Minimum stripped length for a word to be kept.
        split_ratio: Fraction of the shuffled words placed in the training set.

    Returns:
        Tuple ``(training_set, validation_set)`` of lists of words.
    """
    words = []
    with open(filename) as f:
        for line in f:
            word = line.strip()  # strip once instead of once per use
            if len(word) >= min_length:
                words.append(word.lower())
    print(f'len(words)={len(words)}')

    # A private generator seeded with 42 yields the same permutation as
    # seeding the module-level generator, without resetting the global
    # `random` state behind the caller's back.
    random.Random(42).shuffle(words)

    split_idx = int(len(words) * split_ratio)
    training_set = words[:split_idx]
    validation_set = words[split_idx:]

    return training_set, validation_set

In [None]:
# Hangman Game Logic
class Hangman:
    """Minimal Hangman board that narrates every guess on stdout.

    Attributes:
        word: The secret word, lower-cased.
        guessed: Every letter passed to ``guess`` so far.
        incorrect_guesses: The guessed letters that are not in the word.
        state: One entry per character; ``"_"`` while a letter is hidden.
    """

    def __init__(self, word):
        # Alphabetic characters start hidden; anything else is shown as-is.
        self.state = [c if not c.isalpha() else "_" for c in word]
        self.incorrect_guesses = set()
        self.guessed = set()
        self.word = word.lower()

    def guess(self, letter):
        """Record ``letter``, reveal its occurrences or note it as a miss."""
        print(f"Guessing letter: {letter}")
        self.guessed.add(letter)

        if letter not in self.word:
            print(f"{letter} is NOT in the word.")
            self.incorrect_guesses.add(letter)
            return

        print(f"{letter} is in the word!")
        # Update the board in place so existing references to `state` stay valid.
        for position, char in enumerate(self.word):
            if char == letter:
                self.state[position] = letter

    def get_pattern(self):
        """Return the board as one string with no separators."""
        return "".join(self.state)

    def is_solved(self):
        """Return True once no hidden position remains."""
        return all(cell != "_" for cell in self.state)

    def allowed_guesses(self):
        """Return the lowercase letters that have not been guessed yet."""
        return {ch for ch in "abcdefghijklmnopqrstuvwxyz" if ch not in self.guessed}

In [None]:
# Define vocabulary and create vocabulary index mapping
VOCABULARY = ['!', '&', "'", '-', '.', '/', '0', '1', '2', '3', '5',
              'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k',
              'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v',
              'w', 'x', 'y', 'z']
VOCAB_SIZE = len(VOCABULARY)
VOCAB_TO_IDX = {char: idx for idx, char in enumerate(VOCABULARY)}
IDX_TO_VOCAB = {idx: char for idx, char in enumerate(VOCABULARY)}

# Special token values
UNKNOWN_TOKEN = 37  # For unknown letters in the word ('_')
PADDING_TOKEN = 38  # For padding positions beyond the word length

# Fixed number of pattern positions at the start of every state vector
MAX_PATTERN_LENGTH = 50

class ModelMove:
    """Turns a Hangman position into a model input and picks the model's guess.

    Attributes:
        model: Callable that maps an encoded state tensor to action logits.
        state_vector: The tensor produced by the most recent ``encode_state``
            call (only set after the first call).
    """

    def __init__(self, model):
        self.model = model

    def encode_state(self, pattern, incorrect_guesses, lives_left):
        """Encode one game position as a ``(1, 50 + VOCAB_SIZE + 1)`` float tensor.

        Layout of the single row:
          * 50 pattern positions holding the vocabulary index of a revealed
            character, ``UNKNOWN_TOKEN`` (37) for ``'_'`` and for any
            character missing from the vocabulary, and ``PADDING_TOKEN`` (38)
            beyond the end of the word;
          * ``VOCAB_SIZE`` 0/1 flags, one per vocabulary entry, set for every
            incorrect guess that is part of the vocabulary;
          * the number of lives left.

        Args:
            pattern: Sequence of single characters, ``'_'`` for hidden ones.
            incorrect_guesses: Iterable of letters already guessed wrongly.
            lives_left: Remaining lives.

        Returns:
            The encoded state; it is also stored on ``self.state_vector``.

        Raises:
            ValueError: If ``pattern`` has more than 50 positions. Such a
                pattern cannot be padded and would silently shift every
                following column of the vector.
        """
        if len(pattern) > MAX_PATTERN_LENGTH:
            raise ValueError(
                f"pattern has {len(pattern)} positions; at most "
                f"{MAX_PATTERN_LENGTH} are supported"
            )

        pattern_vector = []
        for char in pattern:
            if char == '_':
                pattern_vector.append(UNKNOWN_TOKEN)  # Use 37 for unknown letters
            else:
                pattern_vector.append(VOCAB_TO_IDX.get(char, UNKNOWN_TOKEN))  # Get vocab index

        # Pad to length 50 with PADDING_TOKEN
        pattern_vector += [PADDING_TOKEN] * (MAX_PATTERN_LENGTH - len(pattern_vector))

        # Encode incorrect guesses as a binary vector
        incorrect_guesses_vector = [0] * VOCAB_SIZE
        for letter in incorrect_guesses:
            if letter in VOCAB_TO_IDX:
                incorrect_guesses_vector[VOCAB_TO_IDX[letter]] = 1

        features = pattern_vector + incorrect_guesses_vector + [lives_left]
        state_vector = torch.tensor(features, dtype=torch.float32)
        state_vector = state_vector.unsqueeze(0)
        self.state_vector = state_vector
        return state_vector


    def make_move(self, pattern, incorrect_guesses, lives_left, device):
        """Return the vocabulary character the model scores highest.

        Args:
            pattern: Sequence of single characters, ``'_'`` for hidden ones.
            incorrect_guesses: Iterable of letters already guessed wrongly.
            lives_left: Remaining lives.
            device: Device the encoded state is moved to before the forward pass.

        Returns:
            The character from ``VOCABULARY`` with the largest logit.
        """
        state_vector = self.encode_state(pattern, incorrect_guesses, lives_left).to(device)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            pred_logits = self.model(state_vector)

        # Greedy choice. Softmax is monotonic, so the argmax of the logits is
        # the argmax of the distribution. To sample instead:
        # guess_index = torch.multinomial(torch.softmax(pred_logits, dim=-1), num_samples=1).item()
        guess_index = torch.argmax(pred_logits).item()

        # Get corresponding letter
        guess = IDX_TO_VOCAB[guess_index]

        return guess

In [None]:
# Build the policy network on the selected device and restore the trained weights.
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model = HangmanTransformerPolicy().to(device)
model.load_state_dict(torch.load("hangman_transformer.pth", map_location=device))

<All keys matched successfully>

In [None]:
# Wrap the loaded policy so it can be asked for one guess per game position.
ModelPlayer = ModelMove(model)

In [None]:
# Read the word list with load_words' default arguments and keep both splits.
train_words, val_words = load_words()

len(words)=449959


In [None]:
# Draw reproducible samples from both splits (fixed seed).
random.seed(7)
# Cap the sample sizes so a small word list does not make random.sample raise.
train_sample = random.sample(train_words, k=min(1000, len(train_words)))
val_sample = random.sample(val_words, k=min(10, len(val_words)))
# Show a short preview only; printing the whole sample floods the cell output.
print(f'\ntrain_sample (first 10 of {len(train_sample)})={train_sample[:10]}\n')
# The validation words are deliberately not printed.
print('\nval_sample=No Cheating!\n')


train_sample=['dedicatorily', 'snell', 'uncustomariness', 'nonpearlitic', 'preconquestal', 'proforma', 'maror', 'telephotographs', 'wirephotos', 'exempting', 'inconceivableness', 'graehme', 'fabulosity', 'imbiber', 'hurroosh', 'eternalizing', 'candlewright', 'self-involution', 'nevisdale', 'turnip', 'cageyness', 'premium', 'acoria', 'milarite', 'gullion', 'praising', 'bright-robed', 'crayer', 'broguery', 'dreamtide', 'carunculate', 'tuchunate', 'nightclothes', 'cerulific', 'unnarrow-minded', 'antar', 'woodener', 'irremissibleness', 'underemphasized', 'orifice', 'rewriter', 'ballbuster', 'unavengingly', 'sacian', 'mantises', 'kweichow', 'recompiled', 'archit', 'polarography', 'speen', 'derogation', 'gartering', 'outtowers', 'tautonymy', 'futwa', 'clermont-ferrand', 'wintersome', 'rigidities', 'aminolysis', 'unblade', 'palmae', 'supersensitising', 'lardworm', 'plasmoquin', 'keycard', 'devinna', 'codeia', 'shopmen', 'bibracteolate', 'pishoges', 'conglomeritic', 'boastings', 'long-bow', '

In [None]:
def contains_alpha(s):
    """Return True if at least one element of ``s`` satisfies ``str.isalpha``.

    ``s`` may be a string or any iterable of strings; an empty input gives False.
    """
    for item in s:
        if item.isalpha():
            return True
    return False

In [None]:
# Narrative note written as a bare string expression; the notebook echoes its
# value as this cell's output.
"""
Let's play a few games ourselves to see what it's like
"""

"\nLet's play a few games ourselves to see what it's like\n"

In [None]:
import random

class Hangman:
    """Interactive Hangman game with a limit on wrong guesses and ASCII art.

    Attributes:
        word: The secret word, lower-cased.
        guessed: Every distinct letter guessed so far.
        incorrect_guesses: The guessed letters that are not in the word.
        state: One entry per character; ``"_"`` while a letter is hidden.
        max_attempts: Number of wrong guesses that ends the game (6).
    """

    def __init__(self, word):
        self.max_attempts = 6
        self.word = word.lower()
        # Alphabetic characters start hidden; anything else is shown as-is.
        self.state = [c if not c.isalpha() else "_" for c in word]
        self.guessed = set()
        self.incorrect_guesses = set()

    def guess(self, letter):
        """Apply one guess and return a message describing the outcome.

        The letter is lower-cased first. A repeated guess changes nothing.
        """
        letter = letter.lower()
        if letter in self.guessed:
            return f"You already guessed '{letter}'!"

        self.guessed.add(letter)

        if letter not in self.word:
            self.incorrect_guesses.add(letter)
            return f"Sorry, '{letter}' is NOT in the word."

        # Reveal every occurrence, updating the board in place.
        for idx in range(len(self.word)):
            if self.word[idx] == letter:
                self.state[idx] = letter
        return f"Good guess! '{letter}' is in the word."

    def get_pattern(self):
        """Return the board with a space between positions."""
        return " ".join(self.state)

    def is_solved(self):
        """Return True once no hidden position remains."""
        return "_" not in self.state

    def is_game_over(self):
        """Return True when the word is solved or the wrong-guess limit is reached."""
        if self.is_solved():
            return True
        return len(self.incorrect_guesses) >= self.max_attempts

    def allowed_guesses(self):
        """Return the lowercase letters that have not been guessed yet."""
        return set("abcdefghijklmnopqrstuvwxyz").difference(self.guessed)

    def get_game_state(self):
        """Return a multi-line summary: board, wrong letters, attempts left."""
        if self.incorrect_guesses:
            misses = ", ".join(sorted(self.incorrect_guesses))
        else:
            misses = "None"
        remaining = self.max_attempts - len(self.incorrect_guesses)
        return (
            f"\nWord: {self.get_pattern()}\n"
            f"Incorrect guesses: {misses}\n"
            f"Attempts remaining: {remaining}\n"
        )

    def display_hangman(self):
        """Return the gallows drawing that matches the number of wrong guesses."""
        # One frame per wrong-guess count, from 0 (empty gallows) to 6 (full figure).
        gallows_frames = [
            """
    --------
    |      |
    |
    |
    |
    |
    -
    """,
            """
    --------
    |      |
    |      O
    |
    |
    |
    -
    """,
            """
    --------
    |      |
    |      O
    |      |
    |
    |
    -
    """,
            """
    --------
    |      |
    |      O
    |     /|
    |
    |
    -
    """,
            """
    --------
    |      |
    |      O
    |     /|\\
    |
    |
    -
    """,
            """
    --------
    |      |
    |      O
    |     /|\\
    |     /
    |
    -
    """,
            """
    --------
    |      |
    |      O
    |     /|\\
    |     / \\
    |
    -
    """
        ]
        wrong_count = len(self.incorrect_guesses)
        return gallows_frames[wrong_count]

def _play_single_game(selected_word):
    """Run one interactive game on ``selected_word``; return True if it was solved."""
    game = Hangman(selected_word)

    print("\nNew game started!")
    print("Guess the word one letter at a time.")
    print(game.display_hangman())
    print(game.get_game_state())

    while not game.is_game_over():
        # Get user input
        guess = input("Enter a letter: ").strip()

        # Validate input (an empty string also fails the length check)
        if len(guess) != 1 or not guess.isalpha():
            print("Please enter a single letter.")
            continue

        # Show state before the move
        print("\nBefore your move:")
        print(game.get_game_state())

        # Make the guess
        result = game.guess(guess)
        print(result)

        # Show updated state after the move
        print("\nAfter your move:")
        print(game.display_hangman())
        print(game.get_game_state())

    # Game over
    if game.is_solved():
        print(f"Congratulations! You've guessed the word: {selected_word}")
        return True
    print(f"Game over! The word was: {selected_word}")
    return False


def play_hangman():
    """Play interactive Hangman games over the words in ``val_sample``.

    Words are taken in list order, one per game. After each game the player
    is asked whether to continue; the session ends on any answer that does
    not start with "y", or once every word has been played. The win
    percentage is printed at the end.

    Relies on the module-level ``val_sample`` list and on a ``Hangman`` class
    that provides ``is_game_over``, ``display_hangman`` and ``get_game_state``.
    """
    print("Welcome to Hangman!")
    # The previous prompt asked for a comma-separated word list that was
    # never read; state where the words really come from instead.
    print("Words are taken in order from the validation sample.")

    word_list = val_sample

    if not word_list:
        print("No valid words entered. Using default words.")
        word_list = ["python", "hangman", "computer", "programming", "algorithm"]

    won, played = 0, 0
    # Iterating over the list bounds the session by the number of words. The
    # old condition compared the boolean play_again with len(word_list), so
    # it ran past the end of the list and never played a one-word list.
    for word_ind, selected_word in enumerate(word_list):
        played += 1
        if _play_single_game(selected_word):
            won += 1

        if word_ind == len(word_list) - 1:
            print("\nThat was the last word.")
            break

        # Ask to play again
        play_again_response = input("Do you want to play again? (yes/no): ").lower()
        if not play_again_response.startswith('y'):
            break

    # word_list is never empty here, so at least one game was played.
    print(f"\nYou won {won / played * 100:.2f}% of the games\n")

# Start the interactive session when this cell/script runs as the main module.
if __name__ == "__main__":
    play_hangman()

Welcome to Hangman!
Please enter a list of words separated by commas:

New game started!
Guess the word one letter at a time.

    --------
    |      |
    |
    |
    |
    |
    -
    

Word: _ _ _ _ _ _ _ _ _ _
Incorrect guesses: None
Attempts remaining: 6

Enter a letter: 
Please enter a single letter.
Enter a letter: a

Before your move:

Word: _ _ _ _ _ _ _ _ _ _
Incorrect guesses: None
Attempts remaining: 6

Good guess! 'a' is in the word.

After your move:

    --------
    |      |
    |
    |
    |
    |
    -
    

Word: _ _ _ _ _ _ _ _ a _
Incorrect guesses: None
Attempts remaining: 6

Enter a letter: e

Before your move:

Word: _ _ _ _ _ _ _ _ a _
Incorrect guesses: None
Attempts remaining: 6

Good guess! 'e' is in the word.

After your move:

    --------
    |      |
    |
    |
    |
    |
    -
    

Word: _ e _ _ _ _ _ _ a _
Incorrect guesses: None
Attempts remaining: 6

Enter a letter: i

Before your move:

Word: _ e _ _ _ _ _ _ a _
Incorrect guesses: None
Attempts 

In [None]:
# Narrative note written as a bare string expression; the notebook echoes its
# value as this cell's output.
"""
Let's see how the model plays!
"""

"\nLet's see how the model plays!\n"

In [None]:
# Let the model play every word of the validation sample and report its win rate.
model.eval()
state_vectors = []  # every encoded state the model was queried with
num_games_played = 0
won = 0
# Evaluation only: no_grad avoids building autograd graphs for each forward pass.
with torch.no_grad():
    for word in val_sample:
        Player = Hangman(word)
        lives_left = 6
        # Announce the word once per game, before the first guess.
        print(f"Playing with word: {Player.word}\n")
        while (not Player.is_solved()) and (lives_left > 0):
            print(f"\nCurrent state is: {Player.state}, incorrect guesses: {Player.incorrect_guesses}")
            pattern = Player.state
            incorrect_guesses = Player.incorrect_guesses
            guess = ModelPlayer.make_move(pattern, incorrect_guesses, lives_left, device)
            print(f"Guessed letter {guess}")
            # A wrong guess costs a life, and so does repeating a character
            # that is already visible or already known to be wrong. This
            # must be checked before the guess changes the board.
            if (guess not in Player.word) or guess in Player.state or guess in Player.incorrect_guesses:
                lives_left -= 1
            state_vectors.append(ModelPlayer.state_vector)
            Player.guess(guess)

        if Player.is_solved():
            print("\nSolved!\n")
            won += 1
        else:
            print("\nFailed!\n")
        num_games_played += 1

# Guard the percentage against an empty word list.
if num_games_played:
    print(f"GAMES ARE OVER! YOU SCORED {won / num_games_played * 100}%!")
else:
    print("No games were played.")



Playing with word: demipillar


Current state is: ['_', '_', '_', '_', '_', '_', '_', '_', '_', '_'], incorrect guesses: set()
Guessed letter e
Guessing letter: e
e is in the word!

Current state is: ['_', 'e', '_', '_', '_', '_', '_', '_', '_', '_'], incorrect guesses: set()
Guessed letter i
Guessing letter: i
i is in the word!

Current state is: ['_', 'e', '_', 'i', '_', 'i', '_', '_', '_', '_'], incorrect guesses: set()
Guessed letter l
Guessing letter: l
l is in the word!

Current state is: ['_', 'e', '_', 'i', '_', 'i', 'l', 'l', '_', '_'], incorrect guesses: set()
Guessed letter m
Guessing letter: m
m is in the word!

Current state is: ['_', 'e', 'm', 'i', '_', 'i', 'l', 'l', '_', '_'], incorrect guesses: set()
Guessed letter s
Guessing letter: s
s is NOT in the word.

Current state is: ['_', 'e', 'm', 'i', '_', 'i', 'l', 'l', '_', '_'], incorrect guesses: {'s'}
Guessed letter n
Guessing letter: n
n is NOT in the word.

Current state is: ['_', 'e', 'm', 'i', '_', 'i', 'l', 'l', '

In [None]:
# Let the model play one hand-picked word and report the result.
model.eval()
state_vectors = []  # every encoded state the model was queried with
num_games_played = 0
won = 0
# Evaluation only: no_grad avoids building autograd graphs for each forward pass.
with torch.no_grad():
    for word in ["Recurse"]:
        Player = Hangman(word)
        lives_left = 6
        # Announce the word once per game, before the first guess.
        print(f"Playing with word: {Player.word}\n")
        while (not Player.is_solved()) and (lives_left > 0):
            print(f"\nCurrent state is: {Player.state}, incorrect guesses: {Player.incorrect_guesses}")
            pattern = Player.state
            incorrect_guesses = Player.incorrect_guesses
            guess = ModelPlayer.make_move(pattern, incorrect_guesses, lives_left, device)
            # A wrong guess costs a life, and so does repeating a character
            # that is already visible or already known to be wrong. This
            # must be checked before the guess changes the board.
            if (guess not in Player.word) or guess in Player.state or guess in Player.incorrect_guesses:
                lives_left -= 1
            state_vectors.append(ModelPlayer.state_vector)
            Player.guess(guess)

        if Player.is_solved():
            print("\nSolved!\n")
            won += 1
        else:
            print("\nFailed!\n")
        num_games_played += 1

print(f"GAMES ARE OVER! YOU SCORED {won / num_games_played * 100}%!")

Playing with word: recurse


Current state is: ['_', '_', '_', '_', '_', '_', '_'], incorrect guesses: set()
Guessing letter: e
e is in the word!

Current state is: ['_', 'e', '_', '_', '_', '_', 'e'], incorrect guesses: set()
Guessing letter: r
r is in the word!

Current state is: ['r', 'e', '_', '_', 'r', '_', 'e'], incorrect guesses: set()
Guessing letter: u
u is in the word!

Current state is: ['r', 'e', '_', 'u', 'r', '_', 'e'], incorrect guesses: set()
Guessing letter: s
s is in the word!

Current state is: ['r', 'e', '_', 'u', 'r', 's', 'e'], incorrect guesses: set()
Guessing letter: b
b is NOT in the word.

Current state is: ['r', 'e', '_', 'u', 'r', 's', 'e'], incorrect guesses: {'b'}
Guessing letter: d
d is NOT in the word.

Current state is: ['r', 'e', '_', 'u', 'r', 's', 'e'], incorrect guesses: {'d', 'b'}
Guessing letter: c
c is in the word!

Solved!

GAMES ARE OVER! YOU SCORED 100.0%!
