In [72]:
from enum import Enum
from itertools import product
import logging
import random

import pandas as pd
import numpy as np
from tabulate import tabulate
from termcolor import colored
from tqdm import tqdm



MIN_WORD_LEN = 3

In [70]:

def get_logger():
    logger = logging.getLogger(__name__)
    logger.setLevel(logging.ERROR)
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.ERROR)
    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger

logger = get_logger()

In [55]:
class Direction(Enum):
    NONE = 0
    ACROSS = 1
    DOWN = 2
    INVALID = 4

In [56]:
class WordGrid:
    def __init__(self, shape) -> None:
        self.puzzle = np.chararray(shape)
        self.puzzle[:] = '-'
        self.shape = np.array(self.puzzle.shape)
        self.state = np.zeros(shape, dtype=np.byte)
        self.state[:] = 0

    def __str__(self) -> str:
        return str(tabulate(self.puzzle, tablefmt="plain"))
    
    def __repr__(self) -> str:
        return str(self)
    
    def reset(self) -> None:
        self.puzzle[:] = '-'
        self.state[:] = 0
    
    def add_word(self, position: tuple, direction: Direction, word: str) -> bool:
        y, x = position
        if direction == Direction.DOWN:
            if len(word) + x > self.puzzle.shape[0]:
                logger.warning(f"Cannot place word of length {len(word)}, '{word}' at {position}")
                return False
            
            if (self.state[x:x + len(word), y] & Direction.DOWN.value).any():
                logger.warning(f"Word overlap detected while trying to place '{word}' at {position}")
                return False
            
            if x - 1 > 0 and self.state[x - 1, y] != Direction.NONE.value :
                logger.warning(f"Word interference detected while trying to place '{word}' at {position}")
                return False
            
            if x + len(word) + 1 < self.shape[0] and self.state[x + len(word) + 1, y] != Direction.NONE.value:
                logger.warning(f"Word interference detected while trying to place '{word}' at {position}")
                return False
            
            self.puzzle[x:x + len(word), y] = list(word.lower())
            self.state[x:x + len(word), y] |= Direction.DOWN.value

        else:
            if len(word) + y > self.puzzle.shape[1]:
                logger.warning(f"Cannot place word of length {len(word)}, '{word}' at {position}")
                return False
            
            if (self.state[x, y:y + len(word)] & Direction.ACROSS.value).any():
                logger.warning(f"Word overlap detected while trying to place '{word}' at {position}")
                return False
            
            if y - 1 > 0 and self.state[x,y - 1]  != Direction.NONE.value:
                logger.warning(f"Word interference detected while trying to place '{word}' at {position}")
                return False
            
            if y + len(word) + 1 < self.shape[1] and self.state[x, y + len(word) + 1] != Direction.NONE.value:
                logger.warning(f"Word interference detected while trying to place '{word}' at {position}")
                return False
            
            self.puzzle[x, y:y + len(word)] = list(word.lower())
            self.state[x, y:y + len(word)] |= Direction.ACROSS.value
            
        return True
    
    def get_letters(self, position: tuple, direction: Direction, length: int):
        y, x = position
        letters = []
        if direction == Direction.DOWN:
            for i, letter in enumerate(self.puzzle[x:x + length, y]):
                if letter == b'-':
                    continue
                letters.append((i, letter.decode()))
        else:
            for i, letter in enumerate(self.puzzle[x, y:y + length]):
                if letter == b'-':
                    continue
                letters.append((i, letter.decode()))
        
        return letters
        
    
    def get_letter(self, position: tuple):
        return self.puzzle[position[1], position[0]].decode()
        

In [57]:
puzzle = WordGrid((5,10))
puzzle

-  -  -  -  -  -  -  -  -  -
-  -  -  -  -  -  -  -  -  -
-  -  -  -  -  -  -  -  -  -
-  -  -  -  -  -  -  -  -  -
-  -  -  -  -  -  -  -  -  -

In [58]:
word_index = pd.read_csv("data/word_index.csv")
dictionary = word_index[word_index["lang_code"] == "en"]
dictionary = dictionary[dictionary["len"] >= MIN_WORD_LEN]
dictionary = dictionary[dictionary["len"] <= max(puzzle.shape)]

In [59]:
while not puzzle.add_word((1,2), Direction.DOWN, dictionary["word"].sample(1).item()):
    pass
print(puzzle)
puzzle.reset()



-  -  -  -  -  -  -  -  -  -
-  -  -  -  -  -  -  -  -  -
-  e  -  -  -  -  -  -  -  -
-  a  -  -  -  -  -  -  -  -
-  m  -  -  -  -  -  -  -  -


In [60]:
def get_candidates(puzzle: WordGrid, position: tuple, direction: Direction, blacklist: list):
    if direction == Direction.DOWN:
        max_len = puzzle.shape[0] - position[1]
    else:
        max_len = puzzle.shape[1] - position[0]
    candidates = dictionary[dictionary["len"] <= max_len]
    
    candidates = candidates[~candidates["word"].isin(blacklist)]
        
    letters = puzzle.get_letters(position, direction, max_len)
    for index, letter in letters:
        candidates = candidates[candidates["word"].str[index] == letter]
        
    return candidates
    

In [75]:
n = 0
direction = Direction.DOWN
word_list = []
puzzle.reset()
positions = {
    Direction.DOWN: {pos: [] for pos in product(range(puzzle.shape[1] - MIN_WORD_LEN), range(puzzle.shape[0] - MIN_WORD_LEN))},
    Direction.ACROSS: {pos: [] for pos in product(range(puzzle.shape[1] - MIN_WORD_LEN), range(puzzle.shape[0] - MIN_WORD_LEN))}
}

pbar = tqdm()
while n < 12:
    if direction == Direction.DOWN and len(positions[Direction.DOWN]) == 0:
        direction = Direction.ACROSS
    elif direction == Direction.ACROSS and len(positions[Direction.ACROSS]) == 0:
        direction = Direction.DOWN
    
    position = random.choice(list(positions[direction]))
    
    blacklist = positions[direction][position] + word_list
    candidates = get_candidates(puzzle, position, direction, blacklist)

    if len(candidates) == 0:
        positions[direction].pop(position, None)
        if len(positions[Direction.ACROSS]) == 0 and len(positions[Direction.DOWN]) == 0:
            break
        continue
    
    try:
        word = candidates["word"].sample(1, weights=np.log(candidates.freq)).item()
    except:
        word = candidates["word"].sample(1).item()
    
    pbar.update(n)
    pbar.set_description(f"word: {word}, position: {position}, candidates: {len(candidates)}", refresh=True)
    
    if puzzle.add_word(position, direction, word):
        if direction == Direction.DOWN and len(positions[Direction.ACROSS]) > 0:
            direction = Direction.ACROSS
        elif direction == Direction.ACROSS and len(positions[Direction.DOWN]) > 0:
            direction = Direction.DOWN
            
        word_list.append(word)
        n += 1
    else:
        positions[direction][position].append(word)
        logger.info(f"Can't place word {word} at {position}")

print(word_list)
print(puzzle)

word: loner, position: (5, 0), candidates: 13: : 100it [00:25,  3.97it/s]
word: squi, position: (6, 0), candidates: 1: : 4850it [01:31, 59.23it/s]     

IndexError: Cannot choose from an empty sequence

word: squi, position: (6, 0), candidates: 1: : 4850it [01:46, 59.23it/s]

In [64]:
def custom_print(puzzle: WordGrid):
    to_print = []
    for chars, states in zip(puzzle.puzzle, puzzle.state):
        data = []
        for char, state in zip(chars, states):
            if state & Direction.ACROSS.value and state & Direction.DOWN.value:
                color = "magenta"
            elif state & Direction.ACROSS.value:
                color = "blue"
            elif state & Direction.DOWN.value:
                color = "yellow"
            else:
                color = "white"
            data.append(colored(char.decode('utf-8'), color))
        to_print.append(data)
    print(tabulate(to_print))

In [65]:
custom_print(puzzle)

-  -  -  -  -  -  -  -  -  -
[97m-[0m  [33mg[0m  [33mf[0m  [34mf[0m  [34mo[0m  [34mi[0m  [34ml[0m  [97m-[0m  [97m-[0m  [97m-[0m
[97m-[0m  [33mo[0m  [33mu[0m  [97m-[0m  [97m-[0m  [97m-[0m  [35mp[0m  [34mb[0m  [34mk[0m  [97m-[0m
[97m-[0m  [33mo[0m  [33mn[0m  [97m-[0m  [97m-[0m  [97m-[0m  [33mi[0m  [97m-[0m  [97m-[0m  [97m-[0m
[97m-[0m  [33md[0m  [33mk[0m  [97m-[0m  [97m-[0m  [97m-[0m  [33mt[0m  [97m-[0m  [97m-[0m  [97m-[0m
[97m-[0m  [97m-[0m  [33my[0m  [97m-[0m  [97m-[0m  [97m-[0m  [97m-[0m  [97m-[0m  [97m-[0m  [97m-[0m
-  -  -  -  -  -  -  -  -  -
