In [1]:
import numpy as np
from tqdm import tqdm
import numba
from numba import njit
from jax import numpy as jnp
import jax
from jax import vmap, pmap
from jax import jit as jjit



In [20]:
import string
import pathlib

import os

# Location of the C++ sources. Set the WORDLE_SOLVER_CPP_SRC environment variable to run
# the notebook from a different checkout; the default keeps the original location.
path = os.environ.get(
    "WORDLE_SOLVER_CPP_SRC",
    "/Users/isaacbreen/Documents/Projects/WordleSolver/cpp/src",
)
# Save to cpp/src/data/wordlist.hpp and cpp/src/data/guesslist.hpp
RAW_DATA_FOLDER = pathlib.Path(path).parent.parent / "data"
CPP_DATA_FOLDER = pathlib.Path(path).parent.parent / "cpp" / "src" / "data"
USE_CONSTEXPR = False

# Optional caps on how many entries are kept from each file (None keeps everything)
nw = None
ng = None


def _read_word_list(file_path, limit=None):
    """
    Reads a file of quoted, comma-separated words (e.g. "soare", "roate") and returns them sorted.

    All whitespace (including a trailing newline) and quotes are removed, empty entries are
    dropped, and only the first `limit` entries of the file are kept before sorting.
    """
    with open(file_path, "r") as f:
        raw = f.read()
    # str.split() with no argument splits on any whitespace, so joining removes all of it
    cleaned = "".join(raw.replace('"', "").split())
    entries = [entry for entry in cleaned.split(",") if entry]
    return sorted(entries[:limit])


# Import words from solutions_nyt.txt and nonsolutions_nyt.txt.
words = _read_word_list(RAW_DATA_FOLDER / "solutions_nyt.txt", nw)
# Every solution is also a legal guess, so the guess list starts with the solutions
guesses = words + _read_word_list(RAW_DATA_FOLDER / "nonsolutions_nyt.txt", ng)
# Ensure each guess is unique
assert len(set(guesses)) == len(guesses), "guess list contains duplicate words"
print(f"There are {len(words)} words and {len(guesses)} guesses.")
# Print the first 10 words and guesses with a ... to indicate that there are more
print(f"wordlist  = [{', '.join(sorted(words)[:10])} ...]")
print(f"guesslist = [{', '.join(sorted(guesses)[:10])} ...]")
# Letter <-> index helpers: each lowercase letter maps to its alphabet position (a -> 0, ..., z -> 25)
letters = dict(zip(string.ascii_lowercase, range(len(string.ascii_lowercase))))


def str_to_int8(s):
    """Encodes a lowercase string as an int8 array of alphabet positions."""
    codes = [letters[ch] for ch in s]
    return jnp.array(codes, dtype=np.int8)


def int8_to_str(a):
    """Decodes an iterable of alphabet positions back into a lowercase string."""
    decoded = []
    for code in a:
        decoded.append(string.ascii_lowercase[code])
    return "".join(decoded)
# Encode every word and guess as a row of alphabet indices. From here on `words` and
# `guesses` are int8 arrays (one row per word) rather than lists of strings.
words = jnp.array([str_to_int8(w) for w in words], dtype=np.int8)
guesses = jnp.array([str_to_int8(g) for g in guesses], dtype=np.int8)

There are 2309 words and 12947 guesses.
wordlist  = [aback, abase, abate, abbey, abbot, abhor, abide, abled, abode, abort ...]
guesslist = [aahed, aalii, aargh, aarti, abaca, abaci, aback, abacs, abaft, abaka ...]


In [21]:
# Game parameters
WORD_LENGTH = 5
NUM_HINTS = 3 ** WORD_LENGTH  # a hint is WORD_LENGTH base-3 digits
NUM_WORDS = len(words)
NUM_GUESSES = len(guesses)
MAX_TURNS = 6

# Boolean mask in which every candidate word is still possible
FULL_WORDLIST = jnp.ones((NUM_WORDS, ), dtype=np.bool_)

# Backend switch: JAX when True, numba + numpy otherwise
USE_JAX = True
if USE_JAX:
    jit = jjit
    # Convert to jax arrays
    words, guesses = jnp.array(words), jnp.array(guesses)
else:
    jit = njit
    # Convert to numpy arrays
    words, guesses = np.array(words), np.array(guesses)

@jit
def make_hint(word, guess):
    """
    Encodes the feedback for `guess` against the answer `word` as a base-3 integer.

    Position i contributes digit * 3**i, where the digit is 2 when guess[i] == word[i],
    1 when guess[i] equals the letter at some other position j of `word` and neither
    i nor j is an exact match, and 0 otherwise.

    NOTE(review): a letter is marked as misplaced regardless of how many times it occurs
    in `word`, so repeated letters in the guess may be scored differently from the usual
    Wordle rule - confirm this is intended.
    """
    if not USE_JAX:
        # Plain-loop variant (used when the numba backend is selected)
        total = 0
        for pos in range(WORD_LENGTH):
            weight = 3**pos
            if guess[pos] == word[pos]:
                total += weight * 2
                continue
            # Not an exact match: look for the letter at another non-exact position of the word
            for other in range(WORD_LENGTH):
                if guess[other] != word[other] and guess[pos] == word[other]:
                    total += weight
                    break
        return total
    else:
        # JAX version
        weights = 3**jnp.arange(WORD_LENGTH)
        exact = guess == word
        inexact = guess != word
        # pairs[i, j]: guess letter i equals word letter j and neither position is an exact match
        pairs = inexact[jnp.newaxis, :] * inexact[:, jnp.newaxis] * (guess[:, jnp.newaxis] == word[jnp.newaxis, :])
        misplaced = jnp.any(pairs, axis=1)
        digits = exact * weights * 2
        digits += misplaced * weights
        return digits.sum()
    
def hint_to_str(hint):
    """
    Renders a hint code as a string of WORD_LENGTH base-3 digits.

    The first character is the digit for position 0 (the least-significant digit of
    `hint`), matching the 3**i weighting used by make_hint.
    """
    digits = []
    remaining = hint
    for _ in range(WORD_LENGTH):
        digits.append(str(remaining % 3))
        # Rebind rather than use `//=`: an in-place floor-divide would modify a mutable
        # 0-d numpy array owned by the caller.
        remaining = remaining // 3
    return "".join(digits)

@jit
def _precompute_hints_helper(word, guesses):
    """
    Returns the hint code of every guess in `guesses` against the single answer `word`.
    """
    if not USE_JAX:
        # Hint codes range up to NUM_HINTS - 1 (242 for five letters), which does not fit
        # in int8, so the buffer is int16.
        hints = np.zeros((NUM_GUESSES,), dtype=np.int16)
        for g in range(NUM_GUESSES):
            hints[g] = make_hint(word, guesses[g])
        return hints
    else:
        # JAX version: map make_hint over the guesses while holding the word fixed
        return vmap(make_hint, in_axes=(None, 0))(word, guesses)

# Precompute the hints for each (word, guess) pair and store them in a (NUM_WORDS, NUM_GUESSES) array
def precompute_hints(words, guesses):
    """
    Builds the full hint table: entry [w, g] is make_hint(words[w], guesses[g]).
    """
    if not USE_JAX:
        # int16 for the same reason as in the helper: codes above 127 overflow int8
        hints = np.zeros((NUM_WORDS, NUM_GUESSES), dtype=np.int16)
        for w in tqdm(range(NUM_WORDS)):
            hints[w, :] = _precompute_hints_helper(words[w], guesses)
        return hints
    else:
        # JAX version: map the per-word helper over the words while holding the guesses fixed
        return vmap(_precompute_hints_helper, (0, None))(words, guesses)
    
@jit
def get_hint(word_index, guess_index, hints):
    """
    Looks up the precomputed hint for a (word, guess) pair in the `hints` table.
    """
    cell = (word_index, guess_index)
    return hints[cell]

# Build the full hint table once; the lookup helpers defined in this notebook index into it.
hints = precompute_hints(words, guesses) # (NUM_WORDS, NUM_GUESSES)

In [23]:
# Sanity check of make_hint on one known (word, guess) pair
sample_word, sample_guess = words[0], guesses[20]
print(int8_to_str(sample_word))
print(int8_to_str(sample_guess))
sample_hint = hint_to_str(make_hint(sample_word, sample_guess))
print(sample_hint)
assert sample_hint == "20000"

aback
adept
20000


In [24]:
@jit
def _compatible_word_guess_hint_triple(word_index, guess_index, hint_index, hints):
    """
    Returns true if the given word, guess, and hint are compatible.
    """
    return get_hint(word_index, guess_index, hints) != hint_index

@jit
def _precompute_compatible_word_guess_hint_triples_helper(g, hints):
    """
    Builds the compatibility mask for guess `g`.

    Returns a boolean array with one row per hint and one column per word (row of
    `hints`); entry [h, w] is True when `hints[w, g] == h`.
    """
    # Hint produced by every word for this guess
    observed = hints[:, g]
    # Broadcast (1, words) against (NUM_HINTS, 1) instead of writing element by element,
    # so the function also works when `jit` is jax.jit, and return the result.
    return observed[np.newaxis, :] == np.arange(NUM_HINTS)[:, np.newaxis]

def precompute_compatible_word_guess_hint_triples(hints):
    """
    Precomputes compatible word-guess-hint triples with numpy.

    `hints` is the (words, guesses) hint table. The result is a boolean array of shape
    (guesses, NUM_HINTS, words), filled one guess at a time by
    _precompute_compatible_word_guess_hint_triples_helper.
    """
    hints = np.asarray(hints)
    # Take the sizes from the table itself rather than from the global `guesses`/`words`,
    # which other cells of this notebook rebind.
    num_words, num_guesses = hints.shape
    compatible_word_guess_hint_triples = np.zeros((num_guesses, NUM_HINTS, num_words), dtype=bool)
    for g in tqdm(range(num_guesses)):
        compatible_word_guess_hint_triples[g] = _precompute_compatible_word_guess_hint_triples_helper(g, hints)
    return compatible_word_guess_hint_triples

def precompute_compatible_word_guess_hint_triples_jax(hints):
    """
    JAX version of the compatibility table.

    Returns a boolean array of shape (NUM_GUESSES, NUM_HINTS, NUM_WORDS) whose entry
    [g, h, w] is True when `hints[w, g] == h`.
    """
    # Index grids laid out along the guess, hint and word axes respectively
    g_axis = jnp.arange(NUM_GUESSES).reshape(-1, 1, 1)
    h_axis = jnp.arange(NUM_HINTS).reshape(1, -1, 1)
    w_axis = jnp.arange(NUM_WORDS).reshape(1, 1, -1)
    # Gather the stored hint for every (guess, word) pair, then compare against every hint
    observed = hints[w_axis, g_axis]
    return observed == h_axis
    
# Dense boolean table with one entry per (guess, hint, word) triple; the lookup helpers
# defined after this line index into it.
compatibilities = precompute_compatible_word_guess_hint_triples_jax(hints) # (NUM_GUESSES, NUM_HINTS, NUM_WORDS)

@jit
def compatible_word_guess_hint_triple(word_index, guess_index, hint_index, compatibilities):
    """
    Returns true if the given word, guess, and hint are compatible.

    `compatibilities` is laid out as (guess, hint, word), the same layout that
    get_compatible_words indexes.
    """
    # Index in (guess, hint, word) order to match the table layout
    return compatibilities[guess_index, hint_index, word_index]

@jit
def get_compatible_words(guess_index, hint_index, compatibilities):
    """
    Selects the boolean word mask stored for one (guess, hint) pair: True marks the
    words that are compatible with that guess and hint.
    """
    masks_for_guess = compatibilities[guess_index]
    return masks_for_guess[hint_index]

In [27]:
# Hint NUM_HINTS - 1 (242 for five letters) has digit 2 at every position, i.e. the guess
# matches the answer exactly, so exactly one word should remain.
num_words_after_correct_guess = get_compatible_words(0, NUM_HINTS - 1, compatibilities).sum()
assert num_words_after_correct_guess == 1, f"Expected 1, got {num_words_after_correct_guess}"

# Spot check of one known (guess, hint) pair
guess = 100
hint = 16
print(int8_to_str(guesses[guess]))
num_words_left = get_compatible_words(guess, hint, compatibilities).sum()
# The message now reports the value that is actually being checked
assert num_words_left == 12, f"Expected 12, got {num_words_left}"

arbor


In [28]:
@jit
def get_hint_counts_jax(guess_index, wordlist, hints):
    """
    Counts how many words of `wordlist` produce each hint for the given guess.
    The result is a dense array with one entry per hint.
    """
    # Shift the hints up by one so that bucket 0 collects the words masked out by `wordlist`
    shifted_hints = hints[:, guess_index] + 1
    buckets = shifted_hints * wordlist
    tally = jnp.bincount(buckets, length=NUM_HINTS+1)
    # Drop the bucket of masked-out words
    return tally[1:]

@jit
def get_next_wordlists_and_counts_jax(guess_index, wordlist, hints, compatibilities):
    """
    For the given guess, returns one candidate word mask per hint together with the
    number of words of `wordlist` that produce each hint.
    """
    # One row per hint: the words compatible with that hint, restricted to the current wordlist
    masks_per_hint = compatibilities[guess_index, :, :]
    wordlists = wordlist[jnp.newaxis, :] & masks_per_hint
    hint_counts = get_hint_counts_jax(guess_index, wordlist, hints)
    return wordlists, hint_counts

def get_next_wordlists_and_counts_sparse(guess_index, wordlist, hints, compatibilities):
    """
    Like get_next_wordlists_and_counts_jax, but keeps only the hints that at least one
    word of `wordlist` produces. Returns (wordlists, counts, hint_indices).
    """
    hint_counts = get_hint_counts_jax(guess_index, wordlist, hints)
    # Hints that occur for at least one remaining word
    occurs = hint_counts>0
    masks_for_occurring_hints = compatibilities[guess_index, occurs, :]
    wordlists = wordlist[jnp.newaxis, :] & masks_for_occurring_hints
    hint_indices = jnp.where(occurs)[0]
    return wordlists, hint_counts[occurs], hint_indices

In [45]:
# Narrow the candidate mask with two guesses that both receive hint 0
guess = 1001
hint = 0
wordlist = get_compatible_words(guess, hint, compatibilities)
guess = 1002
hint = 0
wordlist = wordlist & get_compatible_words(guess, hint, compatibilities)
print(wordlist.sum())
# Partition the remaining words by the hint a third guess would produce
guess = 1003
print(get_hint_counts_jax(guess, wordlist, hints).shape)
next_wordlists, counts = get_next_wordlists_and_counts_jax(guess, wordlist, hints, compatibilities)
print(next_wordlists.shape, counts.shape)
next_wordlists_sparse, counts_sparse, hint_indices = get_next_wordlists_and_counts_sparse(guess, wordlist, hints, compatibilities)
print(f"hint_indices: {hint_indices}")
print(next_wordlists_sparse.shape, counts_sparse.shape)
# The size of each sparse word mask should equal the corresponding hint count
print(next_wordlists_sparse.sum(axis=1))
# print(next_wordlists.sum(axis=1))
print(counts_sparse)
# The sparse result must be exactly the dense result restricted to hints that occur
assert next_wordlists_sparse.shape[0] == (counts>0).sum()
assert (next_wordlists_sparse == next_wordlists[counts>0]).all()
# Only the last expression of a cell is displayed, so the next line produces no output
next_wordlists_sparse.shape
next_wordlists_sparse.sum(axis=1)
# wordlist.shape

848
(243,)
(243, 2309) (243,)
hint_indices: [ 0  9 18 27 36 45 54 63 72]
(9, 2309) (9,)
[473 166  43  88  23  10  36   4   5]
[473 166  43  88  23  10  36   4   5]


DeviceArray([473, 166,  43,  88,  23,  10,  36,   4,   5], dtype=int32)

In [31]:
# Display the sparse masks (one row per hint that occurs, one column per word)
next_wordlists_sparse

DeviceArray([[False, False,  True, ...,  True, False, False],
             [False, False, False, ..., False, False, False],
             [False, False, False, ..., False, False,  True],
             ...,
             [ True, False, False, ..., False, False, False],
             [False, False, False, ..., False, False, False],
             [False, False, False, ..., False, False, False]], dtype=bool)

In [11]:
def calculate_expectation_breadth_first_yield(f, values, probabilities, *args, **kwargs):
    """
    Incrementally yields the expectation of f(value) under the given probabilities.

    f(value, *args, **kwargs) must return an iterator of non-negative increments. The
    iterators are advanced round-robin (last one first) and every increment is yielded
    scaled by its probability, so the sum of everything yielded is the expectation.
    """
    weighted = [(p, f(x, *args, **kwargs)) for x, p in zip(values, probabilities)]
    assert np.isclose(sum(p for p, _ in weighted), 1), f"Expected sum of probabilities to be 1, got {sum(p for p, _ in weighted)}"
    while weighted:
        # Walk backwards so that removing an exhausted iterator does not shift pending indices
        for idx in range(len(weighted) - 1, -1, -1):
            weight, stream = weighted[idx]
            try:
                step = next(stream)
            except StopIteration:
                del weighted[idx]
                continue
            assert step>=0, f"Expected non-negative result, got {step}"
            yield weight*step
            
def min_breadth_first_yield(f, values, *args, **kwargs):
    """
    Incrementally yields the minimum over `values` of the total of f(value, *args, **kwargs).

    f must return an iterator of non-negative increments; the total of an iterator is the
    sum of everything it yields. All iterators are advanced round-robin. After each round
    the function yields the increase of the best known lower bound on the minimum, so the
    sum of everything yielded equals the minimum total once every iterator is exhausted or
    pruned.
    """
    iterators = [(i, f(x, *args, **kwargs)) for i, x in enumerate(values)]
    results = [0] * len(iterators)  # running total of each iterator
    prev_yielded = 0
    upper_bound = np.inf # The lowest value of a completed iterator so far
    while iterators:
        smallest_result_this_iteration = None
        # Walk backwards so that popping an entry does not shift the indices still to visit
        for j in reversed(range(len(iterators))):
            (i, it) = iterators[j]
            try:
                result = next(it)
            except StopIteration:
                # Iterator finished: its total is now a candidate for the minimum
                iterators.pop(j)
                if results[i] < upper_bound:
                    upper_bound = results[i]
                continue
            assert result >= 0, f"Expected non-negative result, got {result}"
            results[i] += result
            if smallest_result_this_iteration is None or results[i] < smallest_result_this_iteration:
                smallest_result_this_iteration = results[i]
            if results[i] >= upper_bound:
                # Already no better than a completed iterator; stop advancing it
                iterators.pop(j)
        if smallest_result_this_iteration is not None:
            # The minimum can never exceed the best completed total, so cap the running
            # value by upper_bound; otherwise iterators that keep growing after a cheaper
            # one has finished would push the reported minimum past the true one.
            lower_bound = min(smallest_result_this_iteration, upper_bound)
            yield lower_bound - prev_yielded
            prev_yielded = lower_bound

def calculate_expectation_jax(f, values, probabilities, *args, **kwargs):
    """
    Returns the probability-weighted sum of f applied to each entry of `values`,
    evaluating f with vmap over the leading axis.
    """
    batched_f = vmap(f, in_axes=0)
    outcomes = batched_f(values, *args, **kwargs)
    return jnp.sum(probabilities * outcomes)

def _calculate_EMTW_given_guess(guess, wordlist, num_words, turn):
    """
    Incrementally yields the expected cost of continuing after playing `guess` on `wordlist`.

    The remaining words are partitioned by the hint they would produce; the cost is the
    expectation of calculate_EMTW over those partitions, weighted by partition size.
    A guess that leaves every word in a single partition reveals nothing and yields an
    infinite cost.
    """
    next_wordlists, counts, _ = get_next_wordlists_and_counts_sparse(guess, wordlist, hints, compatibilities)
    next_num_words = next_wordlists.sum(axis=1)
    # A partition that still holds every candidate means the guess carried no information
    useless = next_num_words == num_words
    num_words -= counts[useless].sum()
    if num_words > 0:
        next_wordlists = next_wordlists[~useless]
        counts = counts[~useless]
        probabilities = counts / num_words
        assert np.isclose(probabilities.sum(), 1), f"{probabilities.sum()} != 1"
        yield from calculate_expectation_breadth_first_yield(calculate_EMTW, next_wordlists, probabilities, turn=turn)
    else:
        # Yielding nothing would give this guess a total of 0 and make it win the
        # minimisation in min_breadth_first_yield. Report it as unwinnable instead,
        # using the same infinity convention as calculate_EMTW.
        yield jnp.inf

PRINT_TURNS = 0  # progress is printed for turns up to and including this one
MAX_TURNS = 6
def calculate_EMTW(wordlist, turn=0):
    """
    Incrementally yields the expected minimum time to win for the given wordlist under
    the optimal strategy; the sum of the yielded values is the estimate so far.

    Yields infinity once `turn` reaches MAX_TURNS and raises ValueError when the
    wordlist is empty.
    """
    if turn <= PRINT_TURNS:
        print(f"Turn {turn} with {wordlist.sum()} words remaining")
    num_words = wordlist.sum()
    if turn >= MAX_TURNS:
        yield jnp.inf
        return
    if num_words == 0:
        raise ValueError("No words left in wordlist")
    # The guess made on this turn always costs one
    yield 1
    if num_words == 1:
        return
    # More than one candidate left: add the cost of the best guess, searched breadth-first
    yield from min_breadth_first_yield(_calculate_EMTW_given_guess, range(NUM_GUESSES), wordlist=wordlist, num_words=num_words, turn=turn+1)

# Run the search from the full word list and print the running estimate.
# `word` is only referenced by the commented-out lines, which would first narrow the
# word list with the hints that answer `word` gives to guesses 0, 1 and 5.
word = 6
wordlist = FULL_WORDLIST
# guess = 0
# hint = get_hint(word, guess, hints)
# wordlist = wordlist & get_compatible_words(guess, hint, compatibilities)
# guess = 1
# hint = get_hint(word, guess, hints)
# wordlist = wordlist & get_compatible_words(guess, hint, compatibilities)
# guess = 5
# hint = get_hint(word, guess, hints)
# wordlist = wordlist & get_compatible_words(guess, hint, compatibilities)
print(wordlist.sum())

# Each value yielded by calculate_EMTW is an increment; EMTW accumulates the running total.
# NOTE(review): the recorded run of this cell ends in a KeyboardInterrupt, so the printed
# totals are partial values, not the final result.
EMTW = 0
for x in calculate_EMTW(wordlist):
    EMTW += x
    print(EMTW, x)

2309
Turn 0 with 2309 words remaining
1 1
1.0004331 0.00043308793
1.0008662 0.00043308793
1.0012993 0.00043308793
1.0017323 0.00043308793
1.0021654 0.00043308781
1.0025985 0.00043308781
1.0030316 0.00043308781
1.0034647 0.00043308781
1.0038978 0.00043308781
1.0043309 0.00043308781
1.004764 0.00043308781
1.0056301 0.00086617563
1.0060632 0.00043308828
1.0069294 0.00086617563
1.0077956 0.00086617563
1.0082287 0.00043308828
1.0086617 0.00043308828
1.009961 0.0012992639


KeyboardInterrupt: 

In [36]:
# NOTE(review): the recorded value (3) matches none of the assignments to `hint` visible
# in this notebook - presumably left over from an out-of-order run; confirm.
hint

DeviceArray(3, dtype=int32)

In [13]:
# The guess list starts with the solution words, so this row equals words[6]
guesses[6]

DeviceArray([ 5, 14,  2,  0, 11], dtype=int8)

In [14]:
# Encoded letters of solution word 6 (compare with guesses[6])
words[6]

DeviceArray([ 5, 14,  2,  0, 11], dtype=int8)

In [15]:
def words_remaining_after_guesses(word, *guesses):
    """
    Returns the boolean mask of words that are still possible after playing `guesses`
    (guess indices, in order) when the answer is `word` (a word index).

    Each guess is scored against `word` and only the words compatible with every
    (guess, hint) pair are kept. With no guesses, every word is still possible.
    """
    # Start from "everything possible"; the size is taken from the compatibility table so
    # the mask always matches the rows it is combined with.
    wordlist = jnp.ones(compatibilities.shape[-1:], dtype=np.bool_)
    for guess in guesses:
        wordlist = wordlist & get_compatible_words(guess, get_hint(word, guess, hints), compatibilities)
    return wordlist

def num_words_remaining_after_guesses(word, *guesses):
    """
    Counts the words still possible after the given guesses when the answer is `word`.
    """
    remaining_mask = words_remaining_after_guesses(word, *guesses)
    return remaining_mask.sum()

def num_hint_configs_for_last_guess(word, *guesses):
    """
    Counts the distinct hints the last guess can produce over the words that are still
    possible after all earlier guesses (the answer being `word`).
    """
    if len(guesses) != 1:
        # Narrow the candidates with every guess except the last one
        wordlist = words_remaining_after_guesses(word, *guesses[:-1])
    else:
        # Only one guess: nothing has been ruled out yet
        wordlist = np.ones(NUM_WORDS, dtype=np.bool_)
    counts_per_hint = get_hint_counts_jax(guesses[-1], wordlist, hints)
    return (counts_per_hint > 0).sum()
        
print(num_words_remaining_after_guesses(0, 10,20,30,40))
# Estimate the number of words remaining after the 1 to 6 guess by randomly sampling words and guesses
N = 1000
# Seeded generator so that the estimates are reproducible from run to run
rng = np.random.default_rng(0)
expected_cumulative_steps_to_store = 1
for turns in range(1, MAX_TURNS + 1):
    expected_words_remaining = 0
    expected_hint_configs = 0
    for _ in range(N):
        # Local names on purpose: assigning to `word`/`guesses` here would replace the
        # global encoded guess table that other cells index into.
        sampled_word = int(rng.integers(NUM_WORDS))
        sampled_guesses = rng.integers(NUM_GUESSES, size=turns)
        expected_words_remaining += num_words_remaining_after_guesses(sampled_word, *sampled_guesses)
        expected_hint_configs += num_hint_configs_for_last_guess(sampled_word, *sampled_guesses)
    # Each extra turn multiplies the number of states by (guesses) x (average hints per guess)
    expected_cumulative_steps_to_store *= NUM_GUESSES * expected_hint_configs / N
    print(f"Turns: {turns}, Expected words remaining: {expected_words_remaining/N}, Expected hint configs: {expected_hint_configs/N}, Expected cumulative steps to store: {expected_cumulative_steps_to_store}")

2
Turns: 1, Expected words remaining: 217.58099365234375, Expected hint configs: 84.65799713134766, Expected cumulative steps to store: 1096067.125
Turns: 2, Expected words remaining: 37.70100021362305, Expected hint configs: 19.941999435424805, Expected cumulative steps to store: 282992574464.0
Turns: 3, Expected words remaining: 11.404999732971191, Expected hint configs: 6.3420000076293945, Expected cumulative steps to store: 2.323648388844749e+16
Turns: 4, Expected words remaining: 4.7870001792907715, Expected hint configs: 3.0280001163482666, Expected cumulative steps to store: 9.109518927212524e+20
Turns: 5, Expected words remaining: 2.434999942779541, Expected hint configs: 1.8480000495910645, Expected cumulative steps to store: 2.179548759418847e+25
Turns: 6, Expected words remaining: 1.9170000553131104, Expected hint configs: 1.4229999780654907, Expected cumulative steps to store: 4.015509067298246e+29


In [None]:
# Words compatible with hint 0 for whatever `guess` currently holds.
# NOTE(review): `guess` is a leftover global from an earlier cell - confirm which value is intended.
compatibilities[guess, 0, :]

DeviceArray([False,  True, False, ..., False,  True, False], dtype=bool)

In [28]:
# Shape check of the broadcast used by precompute_compatible_word_guess_hint_triples_jax.
# Note this cell compares with `!=`, whereas that function uses `==`; only the shape is inspected here.
guess_indices = jnp.arange(NUM_GUESSES)[:, jnp.newaxis, jnp.newaxis]
hint_indices = jnp.arange(NUM_HINTS)[jnp.newaxis, :, jnp.newaxis]
word_indices = jnp.arange(NUM_WORDS)[jnp.newaxis, jnp.newaxis, :]
(hints[word_indices, guess_indices] != hint_indices).shape


(12947, 243, 2309)

In [21]:
# NOTE(review): this rebinds the global NUM_WORDS (previously len(words)); any earlier cell
# re-run after this one will see 10 - confirm this is intended.
NUM_WORDS = 10
# NOTE(review): `Function` and `BitVecSort` are not imported anywhere in the visible notebook;
# presumably they come from z3 - add the import before relying on this cell.
wordle_update = Function("wordle_update", BitVecSort(NUM_WORDS), BitVecSort(NUM_WORDS))
