In [None]:
%pip install torch tqdm

In [None]:
import torch, random, string, math
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Sampler
from torch import Tensor
from tqdm.auto import tqdm
import torch.nn.functional as F

In [None]:
# Configuration: RNG seeds, compute device, and the candidate word list.
SEED = 0  # guesses are sampled and evaluation draws random targets, so fix the seed for reproducibility
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model's output index i is mapped back to words[i], so this list presumably must keep the
# same order that was used when the checkpoint was trained -- do not sort or deduplicate it.
with open("wordle_words.txt", "r", encoding="utf-8") as f:
    words = [line.strip().lower() for line in f if line.strip()]

# tokenize() encodes letters as ord(ch) - 96 and assumes exactly 5 letters per word.
assert all(len(w) == 5 and all('a' <= c <= 'z' for c in w) for w in words), \
    "wordle_words.txt must contain only 5-letter a-z words"

In [None]:
def tokenize(gold: str, tries: list[str]):
    '''
    Encode a Wordle game state (the previous tries, coloured against `gold`) as integer tokens.

    The model input is the sequence of previous tries, tokenized as follows:
    one token per letter, plus an EoS token labeled 0.
    Gray letters are labeled 1-26, yellow letters 27-52, and green letters 53-78, for a total vocabulary size of 79.
    Example: tokenize('hello', ['sales', 'round']) = [19, 1, 64, 31, 19, 18, 41, 21, 14, 4, 0].
    Here, for example, the 'l' in 'sales' (the third token) is the correct letter in the correct position;
    as 'l' is the 12th letter of the alphabet, it is assigned token 12 + 52 = 64.

    Repeated letters follow Wordle's rules: greens are assigned first, then yellows left to right,
    and a letter is only marked yellow while unmatched copies of it remain in `gold`.

    All words must be 5 lowercase a-z letters, so the tokenized sequence always has length 1 (mod 5).

    Args:
        gold: the hidden target word.
        tries: previous guesses, in order.

    Returns:
        list[int]: 5 tokens per try followed by the EoS token 0.

    Raises:
        ValueError: if `gold` or any entry of `tries` is not a 5-letter lowercase a-z word.
    '''
    def _validate(word: str) -> None:
        # Tokens are ord(ch) - 96, which is only meaningful for 'a'..'z'; the fixed length keeps
        # each guess aligned with the model's 5 cyclic positional embeddings.
        if len(word) != 5 or not all('a' <= ch <= 'z' for ch in word):
            raise ValueError(f"expected a 5-letter lowercase a-z word, got {word!r}")

    _validate(gold)
    gold_counts = {}
    for ch in gold:
        gold_counts[ch] = gold_counts.get(ch, 0) + 1

    tokenized_seq = []
    for word in tries:
        _validate(word)
        remaining = dict(gold_counts)  # fresh copy: each try is coloured independently
        marks = [0] * 5  # per position of `word`: 0 = gray, 1 = yellow, 2 = green

        # Greens first, so they consume their letters before any yellow is handed out.
        for i, ch in enumerate(word):
            if ch == gold[i]:
                marks[i] = 2
                remaining[ch] -= 1

        # Then yellows, left to right, only while unmatched copies of the letter remain.
        for i, ch in enumerate(word):
            if marks[i] == 0 and remaining.get(ch, 0) > 0:
                marks[i] = 1
                remaining[ch] -= 1

        # 'a' -> 1 ... 'z' -> 26, shifted by 26 for yellow and by 52 for green.
        tokenized_seq.extend(ord(ch) - 96 + 26 * mark for ch, mark in zip(word, marks))

    tokenized_seq.append(0)  # EoS
    return tokenized_seq

In [None]:
class WordlePlayer(nn.Module):
    '''
    A pretty standard pre-norm transformer encoder (GELU): 6 layers, dimension 128, 8 attention heads.

    The input is the tokenized sequence of previous guesses ending in an EoS token (see `tokenize`).
    The representation of the final (EoS) position is mapped to logits of length `self.num_words`,
    from which the next word is chosen.

    Args:
        tokenizer: callable `(correct_word, tries_so_far) -> list[int]` used by `next_guess`.
        num_words: number of candidate words, i.e. the size of the output layer.
        word_list: optional list mapping output index -> word. When omitted, `next_guess` falls
            back to the notebook-level `words` list.
    '''
    def __init__(self, tokenizer, num_words: int, word_list=None):
        super().__init__()
        if word_list is not None and len(word_list) != num_words:
            raise ValueError(f"word_list has {len(word_list)} entries but num_words={num_words}")
        self.num_words = num_words
        self.tokenize = tokenizer
        self.word_list = word_list

        vocab_size = 79  # EoS + 26 letters x 3 colours (gray / yellow / green)
        d_model = 128
        num_heads = 8
        num_layers = 6

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # 5 per-letter positions (repeated for every guess) + 1 for the EoS token.
        self.pos_emb = nn.Embedding(6, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True, activation="gelu", norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.cls_head = nn.Linear(d_model, self.num_words)

        self._reset_parameters()

    def _reset_parameters(self):
        '''Xavier-initialise the embeddings and the output head, and zero the output bias.'''
        nn.init.xavier_uniform_(self.tok_emb.weight)
        nn.init.xavier_uniform_(self.pos_emb.weight)
        nn.init.xavier_uniform_(self.cls_head.weight)
        nn.init.zeros_(self.cls_head.bias)

    @torch.no_grad()  # pure inference: don't build an autograd graph for every sampled guess
    def next_guess(self, correct_word: str, tries_so_far: list[str], topP: float=1.0):
        '''
        Sample the next guess given the previous tries, coloured against `correct_word`.

        Args:
            correct_word: the hidden target word (only used to colour the previous tries).
            tries_so_far: previous guesses, in order.
            topP: nucleus-sampling threshold. For 0 < topP < 1, sampling is restricted to the
                smallest set of most-likely words whose cumulative probability exceeds topP;
                any other value samples from the full distribution.

        Returns:
            str: the sampled word.
        '''
        # Top-p rather than top-k: early in a game there are many viable guesses, and top-p keeps
        # that variety while still cutting off particularly unlikely words.
        dev = self.tok_emb.weight.device
        tokens = torch.as_tensor(self.tokenize(correct_word, tries_so_far), dtype=torch.long, device=dev)
        logits = self(tokens.unsqueeze(0))  # (1, num_words)

        if 0.0 < topP < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Drop everything after the first word that pushes the cumulative mass past topP;
            # shifting right keeps that crossing word, and the top word is always kept.
            to_remove = cumulative_probs > topP
            to_remove[..., 1:] = to_remove[..., :-1].clone()
            to_remove[..., 0] = False

            sorted_logits = sorted_logits.masked_fill(to_remove, float('-inf'))
            filtered_logits = torch.full_like(logits, float('-inf'))
            filtered_logits.scatter_(1, sorted_indices, sorted_logits)
            logits = filtered_logits

        probs = F.softmax(logits, dim=-1)
        choice = torch.multinomial(probs, num_samples=1).item()

        vocab = self.word_list if self.word_list is not None else words
        return vocab[choice]

    def forward(self, x: Tensor) -> Tensor:
        '''
        Map token sequences to next-word logits.

        Args:
            x: integer token ids of shape (batch, L), where L = 5 * (number of guesses) + 1 and the
                last token is EoS. Moved to the model's device if necessary.

        Returns:
            Tensor of shape (batch, num_words): logits read from the EoS position.

        Raises:
            ValueError: if L is not 1 (mod 5).

        Positional embeddings: the order of the guesses doesn't matter, only the position of each
        letter within its guess (first letter, second, ...). So only *five* positional embeddings
        are needed (plus one for the EoS token), repeated cyclically. This means fewer parameters
        to train, and the model can run on more prior guesses than it was trained on. The training
        data only features up to 5 previous guesses like real Wordle, but with these periodic
        embeddings we can measure how long it takes to find the word even beyond 6 guesses.
        '''
        dev = self.tok_emb.weight.device
        x = x.to(dev)
        _, L = x.shape
        if L % 5 != 1:
            raise ValueError(f"expected sequence length 5*k + 1 (k guesses plus EoS), got {L}")

        pos = torch.arange(L, device=dev) % 5  # 0,1,2,3,4,0,1,2,3,4,...
        pos[-1] = 5  # dedicated EoS position
        h = self.tok_emb(x) + self.pos_emb(pos)
        h = self.encoder(h)

        eos_repr = h[:, -1, :]
        return self.cls_head(eos_repr)

In [None]:
# Build the model and load the trained weights for inference.
model = WordlePlayer(tokenize, len(words)).to(device)
ckpt_path = "models/phase2.pt"
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code -- only load
# trusted checkpoints (or pass weights_only=True if this checkpoint holds only tensors/containers).
checkpoint = torch.load(ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
# Inference only: eval() disables the encoder layers' dropout, which is otherwise active and
# randomly perturbs every prediction.
model.eval()

In [None]:
def guesses_until_correct(model, gold: str, max_guesses: int, topP: float = 0.95):
    '''
    Play one game of Wordle with `model` against the hidden word `gold`.

    Args:
        model: object with a `next_guess(gold, tries, topP=...)` method (e.g. a WordlePlayer).
        gold: the hidden target word.
        max_guesses: maximum number of guesses before giving up.
        topP: nucleus-sampling threshold forwarded to `model.next_guess`.

    Returns:
        list[str]: the guesses made, in order (at most `max_guesses` of them). The game was won
        iff the list is non-empty and its last element equals `gold`.
    '''
    tries = []
    # Check the limit before guessing, so a timeout yields exactly max_guesses guesses.
    while len(tries) < max_guesses:
        guess = model.next_guess(gold, tries, topP=topP)
        tries.append(guess)
        if guess == gold:
            break
    return tries

In [None]:
# Play a single demo game against a fixed target word and print every guess.
correct_word = "hello"
max_guesses = 100
guesses = guesses_until_correct(model, correct_word, max_guesses=max_guesses)
for guess in guesses:
    print(guess)
# Decide the outcome from whether the target was actually guessed, not from the list length:
# a win on exactly the last allowed guess has the same length as a timeout.
if correct_word in guesses:
    print(f"Guessed in {len(guesses)} tries!")
else:
    print("Timed out!")

In [None]:
def evaluate_model(model, word_list, num_trials=1000, max_guesses=100, max_successful_guesses=6):
    '''
    Estimate a model's Wordle performance by playing `num_trials` games on random target words.

    Args:
        model: the player passed to `guesses_until_correct`.
        word_list: candidate target words; each trial draws one with `random.choice`.
        num_trials: number of games to play (must be positive).
        max_guesses: per-game guess limit before a game counts as timed out.
        max_successful_guesses: a game is a success if the target is found within this many
            guesses (6, as in real Wordle).

    Returns:
        (avg_guesses, success_rate). Timed-out games contribute the number of guesses made to
        the average.

    Raises:
        ValueError: if num_trials is not positive.
    '''
    if num_trials <= 0:
        raise ValueError(f"num_trials must be positive, got {num_trials}")

    total_guesses = 0
    total_successes = 0
    for _ in tqdm(range(num_trials)):
        word = random.choice(word_list)
        guesses = guesses_until_correct(model, word, max_guesses=max_guesses)
        total_guesses += len(guesses)
        # Require an actual hit, not merely a short guess list.
        if word in guesses and len(guesses) <= max_successful_guesses:
            total_successes += 1

    avg_guesses = total_guesses / num_trials
    success_rate = total_successes / num_trials
    return avg_guesses, success_rate

In [None]:
# Headline metrics over 1000 random games drawn from the full word list.
avg_guesses, success_rate = evaluate_model(
    model,
    word_list=words,
    num_trials=1000,
)
print(f"Average number of guesses: {avg_guesses}")
print(f"Success rate: {success_rate}")