In [None]:
%pip install torch tqdm

In [None]:
import torch, random, string, math
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Sampler
from torch import Tensor
from tqdm.auto import tqdm
import torch.nn.functional as F

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# One candidate word per line; blank lines are skipped. Words are lower-cased because the
# tokenizer maps letters with ord(ch) - 96, and de-duplicated (keeping first occurrence)
# because the model has one output class per list entry -- a duplicate would split the
# probability mass between two indistinguishable classes.
with open("wordle_words.txt", "r", encoding="utf-8") as f:
    words = list(dict.fromkeys(line.strip().lower() for line in f if line.strip()))

# The tokenizer, dataset and periodic positional embeddings all assume 5 lowercase ASCII letters.
assert words, "wordle_words.txt contains no words"
assert all(len(w) == 5 and w.isascii() and w.isalpha() and w.islower() for w in words), \
    "every word must be exactly 5 lowercase ASCII letters"

In [None]:
def tokenize(gold: str, tries: list[str]):
    '''
    Encode a sequence of previous guesses as colour-annotated letter tokens.

    Every guessed letter becomes one token that encodes both the letter and its Wordle
    colour relative to `gold`, and a single end-of-sequence token 0 is appended:
        gray   letters -> 1-26  (a=1, ..., z=26)
        yellow letters -> 27-52 (letter + 26)
        green  letters -> 53-78 (letter + 52)
    for a total vocabulary size of 79.

    Colours follow the usual Wordle rules for repeated letters: greens are assigned first,
    then yellows left-to-right while unmatched copies of that letter remain in the gold;
    everything else is gray.

    Example: tokenize('hello', ['sales', 'round']) = [19, 1, 64, 31, 19, 18, 41, 21, 14, 4, 0].
    Here the 'l' in 'sales' (the third token) is the correct letter in the correct position;
    as 'l' is the 12th letter of the alphabet, it is assigned token 12+52=64.

    All words are assumed to be 5 lowercase letters, so the result has length
    5 * len(tries) + 1, i.e. it is always 1 (mod 5).
    '''
    GRAY, YELLOW, GREEN = 0, 1, 2
    tokens = []
    for guess in tries:
        # Copies of each gold letter not yet matched; rebuilt for every guess.
        unmatched = {letter: gold.count(letter) for letter in gold}
        colours = [GRAY] * 5

        # Pass 1: exact-position matches are green and consume one copy of the letter.
        for pos, letter in enumerate(guess):
            if letter == gold[pos]:
                colours[pos] = GREEN
                unmatched[letter] -= 1

        # Pass 2: other letters turn yellow while unmatched copies remain.
        for pos, letter in enumerate(guess):
            if colours[pos] == GRAY and unmatched.get(letter, 0) > 0:
                colours[pos] = YELLOW
                unmatched[letter] -= 1

        # 'a' -> 1, ..., 'z' -> 26, then shifted by 26 per colour level.
        tokens.extend(ord(letter) - 96 + 26 * colour for letter, colour in zip(guess, colours))

    tokens.append(0)  # end-of-sequence token
    return tokens

In [None]:
class WordDataset(Dataset):
    '''
    Synthetic Wordle training data: items are (tokenizer(gold, tries), gold_idx).

    Every word serves as the gold for 18 consecutive items, so len(dataset) == 18 * len(words)
    and item `idx` has gold `words[idx // 18]`. Guess sequences are generated randomly on every
    access, so the same index yields different guesses on different calls. The kind of
    sequence depends on idx % 18:

      * 0-5:   idx % 18 guesses drawn uniformly from `words`. The exception: 25% of the time
               the first guess is a word that differs from the gold in just one letter, and
               2% of the time the second guess is too. From initial testing, the model could
               get stuck in a loop guessing the same word that is wrong in only one letter.
      * 6:     a single guess that is the gold itself, so the model learns the spelling of
               every word (the output is word-level rather than letter-level).
      * 7-11:  idx % 18 - 6 guesses that prioritise letters not yet guessed, so the sequences
               are high-information. If the gold contains exactly one rare letter, the first
               guess is a word containing it; if it contains several and at least two guesses
               are requested, the first two guesses each contain one of two randomly chosen
               rare letters.
      * 12-17: idx % 18 - 12 guesses that prioritise letters *not* in the gold, so the model
               learns to extract information from gray letters (something it tends to
               struggle with).

    Key feature: the number of tries (i.e. the input length) depends only on idx % 18,
    which lets BucketBatchSampler build batches without padding.

    All words are assumed to be 5 lowercase ASCII letters. Any randomly drawn guess may
    happen to equal the gold.
    '''
    _ITEMS_PER_WORD = 18
    # Letters treated as uncommon when seeding high-information guesses.
    _RARE_LETTERS = ['f', 'h', 'v', 'w', 'y', 'k', 'j', 'x', 'q', 'z']
    # How many times a guess is re-drawn looking for a better candidate before the last draw is kept.
    _MAX_RESAMPLES = 10

    def __init__(self, words: list[str], tokenizer):
        self.words = words
        self.tokenizer = tokenizer

        # letter -> words (in `words` order) containing that letter. Each word is listed once
        # per letter even if the letter repeats, so random.choice is uniform over the words.
        self.words_by_letter = {}
        for word in self.words:
            for letter in dict.fromkeys(word):
                self.words_by_letter.setdefault(letter, []).append(word)

        # word -> words (in `words` order) that differ from it in exactly one position.
        self.words_differing_by_one_letter = self._one_letter_neighbours(self.words)

    @staticmethod
    def _one_letter_neighbours(words: list[str]) -> dict[str, list[str]]:
        '''Map each word to the words (in `words` order) differing from it in exactly one position.

        Two distinct words differ only at position i iff they share the wildcard pattern
        (i, word with position i removed). Grouping indices by pattern therefore finds every
        such pair in O(len(words) * word_length) instead of comparing all pairs of words.
        '''
        buckets = {}
        for j, word in enumerate(words):
            for i in range(len(word)):
                buckets.setdefault((i, word[:i] + word[i + 1:]), []).append(j)

        neighbours = {}
        for word in words:
            if word in neighbours:
                continue  # a repeated entry would get an identical list
            found = set()
            for i in range(len(word)):
                found.update(j for j in buckets[(i, word[:i] + word[i + 1:])] if words[j] != word)
            neighbours[word] = [words[j] for j in sorted(found)]
        return neighbours

    def __len__(self):
        return len(self.words) * self._ITEMS_PER_WORD

    def __getitem__(self, idx: int):
        '''Return (tokenizer(gold, tries), gold_idx) for item `idx`; see the class docstring.'''
        gold_idx, kind = divmod(idx, self._ITEMS_PER_WORD)
        gold = self.words[gold_idx]
        if kind < 6:
            tries = self._random_tries(gold, kind)
        elif kind == 6:
            tries = [gold]
        elif kind < 12:
            tries = self._high_info_tries(gold, kind - 6)
        else:
            tries = self._gray_tries(gold, kind - 12)
        return self.tokenizer(gold, tries), gold_idx

    def _random_tries(self, gold: str, num_tries: int) -> list[str]:
        '''Uniformly random guesses, where the first two may be replaced by one-letter near misses.'''
        near_misses = self.words_differing_by_one_letter[gold]
        # random.random() must exceed these for a near miss: 25% for guess 0, 2% for guess 1.
        thresholds = {0: 0.75, 1: 0.98}
        tries = []
        for i in range(num_tries):
            if i in thresholds and random.random() > thresholds[i] and near_misses:
                tries.append(random.choice(near_misses))
            else:
                tries.append(random.choice(self.words))
        return tries

    def _high_info_tries(self, gold: str, num_tries: int) -> list[str]:
        '''Guesses favouring letters not yet tried, optionally seeded with the gold's rare letters.'''
        tries = []
        rare_in_gold = [c for c in self._RARE_LETTERS if c in gold]
        if len(rare_in_gold) == 1:
            tries.append(random.choice(self.words_by_letter[rare_in_gold[0]]))
        elif len(rare_in_gold) > 1 and num_tries > 1:
            # With several rare letters but only one guess requested, no rare-letter guess is forced.
            for letter in random.sample(rare_in_gold, 2):
                tries.append(random.choice(self.words_by_letter[letter]))

        while len(tries) < num_tries:
            tried = {c for word in tries for c in word}
            untried = [c for c in string.ascii_lowercase if c not in tried]
            # Only untried letters occurring in some word can seed a guess (a custom word list
            # may lack e.g. 'q'); if none is left, fall back to a uniformly random word.
            seeds = [c for c in untried if c in self.words_by_letter]
            if not seeds:
                tries.append(random.choice(self.words))
                continue
            seed_letter = random.choice(seeds)
            # Look for a word with enough new letters; the bar drops as more guesses are made.
            for _ in range(self._MAX_RESAMPLES + 1):
                candidate = random.choice(self.words_by_letter[seed_letter])
                if sum(c in untried for c in candidate) > 5 - len(tries):
                    break
            tries.append(candidate)
        return tries

    def _gray_tries(self, gold: str, num_tries: int) -> list[str]:
        '''Guesses that mostly avoid letters already seen, the gold's letters included.'''
        used = set(gold)
        tries = []
        for _ in range(num_tries):
            # The first guess may reuse no seen letter; the allowance grows slowly afterwards.
            max_reused = (1 + len(tries)) // 2
            for _ in range(self._MAX_RESAMPLES + 1):
                candidate = random.choice(self.words)
                if sum(c in used for c in candidate) <= max_reused:
                    break
            used.update(candidate)
            tries.append(candidate)
        return tries

In [None]:
class BucketBatchSampler(Sampler[list[int]]):
    '''
    Yields shuffled batches of indices that all share the same residue mod `modulo`.

    Since there are only 6 possible input lengths in our dataset (corresponding to 0-5 prior
    guesses), training a PAD token for batching is unnecessary. Rather, we use the fact that
    the input length of a dataset element depends only on idx % 18 to group together items we
    know have the same length.

    On each pass (each call to __iter__), the indices 0..n_samples-1 are split into `modulo`
    buckets by residue. Each bucket is shuffled and cut into batches of at most `batch_size`
    (a bucket's last batch may be smaller), and finally the order of all batches is shuffled.
    Every index is yielded exactly once per pass.

    Args:
        n_samples: dataset length.
        modulo: number of residue buckets (18 for WordDataset).
        batch_size: maximum number of indices per batch.
        seed: optional seed for reproducible shuffling. The RNG is created once, so successive
            passes still differ from each other; None (default) seeds from system entropy.
    '''
    def __init__(self, n_samples, modulo: int, batch_size, seed=None):
        if modulo < 1 or batch_size < 1:
            raise ValueError(f"modulo and batch_size must be positive, got {modulo} and {batch_size}")
        self.n_samples = n_samples
        self.modulo = modulo  # Here modulo will be 18.
        self.batch_size = batch_size
        self._rng = random.Random(seed)

    def _buckets(self):
        '''Indices grouped by residue: bucket r holds r, r + modulo, r + 2 * modulo, ...'''
        return [range(r, self.n_samples, self.modulo) for r in range(self.modulo)]

    def __iter__(self):
        batches = []
        for bucket in self._buckets():
            indices = list(bucket)
            self._rng.shuffle(indices)
            batches.extend(indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size))
        self._rng.shuffle(batches)
        yield from batches

    def __len__(self):
        # Same bucketing as __iter__, so the count always matches what is yielded.
        return sum(math.ceil(len(bucket) / self.batch_size) for bucket in self._buckets())

In [None]:
def collate_fn(batch):
    '''
    Stack a batch of (token_list, gold_idx) pairs into tensors.

    All token lists in a batch must have the same length (guaranteed by BucketBatchSampler);
    returns (tokens of shape (batch, L), gold indices of shape (batch,)).
    '''
    token_seqs, gold_indices = zip(*batch)
    stacked_tokens = torch.stack(list(map(torch.as_tensor, token_seqs)))
    targets = torch.as_tensor(gold_indices)
    return stacked_tokens, targets

In [None]:
class WordlePlayer(nn.Module):
    '''
    A pretty standard transformer model: a 6-layer pre-norm encoder (dimension 128, 8 attention
    heads, GELU) over the guess tokens. The final (EoS) token's representation is projected to
    `num_words` logits, one per candidate word, from which the next guess is chosen.

    Args:
        tokenizer: callable (gold, tries) -> list[int] such as `tokenize`; token ids must be
            below 79 and the sequence must have length 5 * len(tries) + 1, ending in EoS.
        num_words: number of output classes (candidate words).
        word_list: optional list of the `num_words` candidate words, indexed like the output
            classes. If omitted, `next_guess` looks words up in the notebook-global `words`.
    '''
    def __init__(self, tokenizer, num_words: int, word_list=None):
        super().__init__()
        if word_list is not None and len(word_list) != num_words:
            raise ValueError(f"word_list has {len(word_list)} words but num_words is {num_words}")
        self.num_words = num_words
        self.tokenize = tokenizer
        self.word_list = word_list

        vocab_size = 79  # EoS (0) + 26 letters x 3 colours
        d_model = 128
        num_heads = 8
        num_layers = 6

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        # 5 letter positions within a guess + 1 for the EoS token (see forward).
        self.pos_emb = nn.Embedding(6, d_model)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=num_heads, batch_first=True, activation="gelu", norm_first=True)
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.cls_head = nn.Linear(d_model, self.num_words)

        self._reset_parameters()

    def _reset_parameters(self):
        nn.init.xavier_uniform_(self.tok_emb.weight)
        nn.init.xavier_uniform_(self.pos_emb.weight)
        nn.init.xavier_uniform_(self.cls_head.weight)
        nn.init.zeros_(self.cls_head.bias)

    @torch.no_grad()
    def next_guess(self, correct_word: str, tries_so_far: list[str], topP: float=1.0):
        '''
        Sample the next guess, given the guesses made so far against `correct_word`.

        We use nucleus (top-p) sampling to avoid particularly unlikely guesses. Top-p is a better
        choice than top-k here because there are generally many viable early guesses, and top-p
        keeps that variety. Values of topP outside (0, 1) disable the filter.

        Runs without gradient tracking but does NOT change train/eval mode: call `model.eval()`
        first so dropout is disabled while playing.

        Returns:
            The sampled word (a str).
        '''
        tokens = torch.as_tensor(self.tokenize(correct_word, tries_so_far)).unsqueeze(0)
        logits = self.forward(tokens)

        if 0.0 < topP < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

            # Drop words once the cumulative mass exceeds topP, shifted right by one so the word
            # that crosses the threshold is kept; the most likely word is always kept.
            to_remove = cumulative_probs > topP
            to_remove[..., 1:] = to_remove[..., :-1].clone()
            to_remove[..., 0] = False

            sorted_logits = sorted_logits.masked_fill(to_remove, float('-inf'))
            logits = torch.full_like(logits, float('-inf')).scatter(1, sorted_indices, sorted_logits)

        probs = F.softmax(logits, dim=-1)
        choice = torch.multinomial(probs, num_samples=1).item()
        vocab = self.word_list if self.word_list is not None else words
        return vocab[choice]

    def forward(self, x: Tensor) -> Tensor:
        '''
        Compute next-guess logits for a batch of token sequences.

        Args:
            x: integer token ids of shape (batch, L) with L = 5 * num_guesses + 1 (EoS last);
               moved to the model's device.
        Returns:
            Logits of shape (batch, num_words), read from the EoS position.
        Raises:
            ValueError: if L is not 1 (mod 5).
        '''
        model_device = self.cls_head.weight.device
        x = x.to(model_device)
        _, L = x.shape
        if L % 5 != 1:
            raise ValueError(f"expected sequence length 5 * num_guesses + 1, got {L}")

        # Now here's a great little feature of this model:
        # Because the order of the guesses doesn't actually matter, all that matters per-token is
        # the position of that letter within its guess (i.e. is it the first letter, the second, etc.)
        # So we only need *five* positional embeddings (plus one for the EoS token), which repeat
        # cyclically. Not only do we have fewer parameters to train, but we can also expect the
        # model to run reasonably well on inputs longer than those it was trained on. The training
        # data only features up to 5 previous guesses like real Wordle, but with these periodic
        # embeddings we can run the model on any number of prior guesses, and so measure how long
        # it takes to guess the correct word even beyond 6 guesses.
        pos = torch.arange(L, device=model_device) % 5
        pos[-1] = 5  # EoS position
        h = self.tok_emb(x) + self.pos_emb(pos)
        h = self.encoder(h)

        eos_repr = h[:, -1, :]
        return self.cls_head(eos_repr)

In [None]:
BATCH_SIZE = 16
NUM_BATCHES = BATCH_SIZE  # legacy name: this is the number of samples per batch, not a number of batches

model = WordlePlayer(tokenize, len(words)).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.0)
warmup_steps = 100_000
# One epoch is about 18 * len(words) / BATCH_SIZE steps (plus up to 18 partial batches): with just
# under 15k words and a batch size of 16 that is roughly 16-17k steps, so warmup spans ~6 epochs.
# Every LR scheduler overwrites the optimizer's LR when it is constructed, so whichever is built
# last wins. Build the cosine scheduler FIRST and the warmup scheduler last, so the very first
# optimizer step already uses the warmup LR instead of the full 5e-4.
cosine = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=60_000, T_mult=1, eta_min=5e-5)
warmup = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda s: min(1.0, (s + 1) / warmup_steps))
# This is the scheduler that worked decently well for me.
train_ds = WordDataset(words, tokenize)
batch_sampler = BucketBatchSampler(n_samples=len(train_ds), modulo=18, batch_size=BATCH_SIZE)
loader = DataLoader(train_ds, batch_sampler=batch_sampler, num_workers=0, collate_fn=collate_fn)

In [None]:
NUM_EPOCHS = 1000
# I just put a big number here and stopped it when it seemed to be plateauing. It didn't really matter with this scheduler.

global_step = 0
for epoch in range(NUM_EPOCHS):
    # Re-enable dropout each epoch: evaluation code may have switched the model to eval mode.
    model.train()
    epoch_loss = 0.0
    progress = tqdm(loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

    for tries, gold in progress:
        tries = tries.to(device)
        gold = gold.to(device)

        model_logits = model(tries)
        loss = F.cross_entropy(model_logits, gold)
        loss_value = loss.item()  # one host-device sync per step
        epoch_loss += loss_value
        progress.set_postfix(loss=loss_value)

        loss.backward()
        opt.step()
        opt.zero_grad(set_to_none=True)

        # Linear warmup for the first `warmup_steps` steps, then cosine annealing with warm
        # restarts, indexed from the end of warmup.
        if global_step < warmup_steps:
            warmup.step()
        else:
            cosine.step(global_step - warmup_steps)

        global_step += 1

    avg_loss = epoch_loss / max(len(loader), 1)
    print(f"Epoch {epoch+1} | mean loss {avg_loss:.4f}")

    ckpt_path = f"epoch_{epoch:02d}.pt"
    torch.save(
        {
            "epoch": epoch,
            "global_step": global_step,
            "model_state_dict": model.state_dict(),
            "avg_loss": avg_loss
        },
        ckpt_path
    )
    # With this periodic warm-restart scheduler the model's quality oscillates, so it's best to
    # save at the end of every epoch and then check all the checkpoints at the end.

In [None]:
def guesses_until_correct(model, gold: str, max_guesses: int, top_p: float = 0.95):
    '''
    Let `model` play one game against `gold` and return the list of guesses it made.

    The game stops as soon as `gold` is guessed or after `max_guesses` guesses, whichever comes
    first. The result therefore has at most `max_guesses` entries and ends with `gold` exactly
    when the model solved the game.

    Args:
        model: object with `next_guess(gold, tries, topP=...)` returning a word (e.g. WordlePlayer).
        gold: the hidden answer.
        max_guesses: maximum number of guesses to make.
        top_p: nucleus-sampling threshold forwarded to `next_guess`.
    '''
    tries = []
    while len(tries) < max_guesses:
        guess = model.next_guess(gold, tries, topP=top_p)
        tries.append(guess)
        if guess == gold:
            break
    return tries

In [None]:
def evaluate_model(model, word_list, num_trials=1000, max_guesses=100, max_winning_guesses=6):
    '''
    Estimate a model's playing strength on randomly drawn answers.

    Plays `num_trials` games, each against a word drawn uniformly (with replacement) from
    `word_list`, with at most `max_guesses` guesses per game.

    Returns:
        (average number of guesses per game, fraction of games solved within
        `max_winning_guesses` guesses). A game that hits the cap without finding the answer
        counts as a failure and contributes the guesses it made to the average.

    The model is switched to eval mode (disabling dropout) and gradients are not tracked
    during evaluation; its previous train/eval mode is restored afterwards.

    Raises:
        ValueError: if num_trials is not positive.
    '''
    if num_trials <= 0:
        raise ValueError(f"num_trials must be positive, got {num_trials}")
    was_training = model.training
    model.eval()
    total_guesses = 0
    total_successes = 0
    try:
        with torch.no_grad():
            for _ in tqdm(range(num_trials)):
                gold = random.choice(word_list)
                tries = guesses_until_correct(model, gold, max_guesses=max_guesses)
                total_guesses += len(tries)
                # A success must actually find the word, not merely stop early.
                if tries and tries[-1] == gold and len(tries) <= max_winning_guesses:
                    total_successes += 1
    finally:
        model.train(was_training)

    avg_guesses = total_guesses / num_trials
    success_rate = total_successes / num_trials
    return avg_guesses, success_rate

In [None]:
NUM_CHECKPOINTS = 150  # evaluate epochs 0 .. NUM_CHECKPOINTS - 1

# Dropout must be off while measuring playing strength; nothing above switches the model out
# of its default train mode.
model.eval()
with torch.no_grad():
    for i in range(NUM_CHECKPOINTS):
        # Same naming scheme the training loop saves with: epoch_00.pt, epoch_01.pt, ..., epoch_100.pt, ...
        ckpt_path = f"epoch_{i:02d}.pt"
        try:
            checkpoint = torch.load(ckpt_path, map_location=device)
        except FileNotFoundError:
            print(f"No checkpoint {ckpt_path}; stopping after {i} checkpoints.")
            break
        model.load_state_dict(checkpoint["model_state_dict"])
        model.to(device)

        avg_guesses, success_rate = evaluate_model(model, words, num_trials=1000)

        print(f"Epoch: {i}")
        print(f"Average number of guesses: {avg_guesses}")
        print(f"Success rate: {success_rate}")