In [None]:
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from collections import defaultdict, Counter
import re

## Load data
def build_dictionary(dictionary_file_location):
    """Read a word list from a text file, one entry per line.

    Args:
        dictionary_file_location: Path of the text file to read.

    Returns:
        A list with one string per line of the file, line endings removed.
    """
    # The context manager closes the handle even when read() raises.
    with open(dictionary_file_location, "r") as text_file:
        return text_file.read().splitlines()

# Load the full word list, then hold out 0.4% of it (test_size=0.004) as the
# words to play against; the remainder (full_dic) is the training data.
# NOTE(review): no random_state is passed, so the split presumably differs
# between runs — confirm whether reproducibility is wanted.
full_dictionary_location = "words_250000_train.txt"
full_dictionary = build_dictionary(full_dictionary_location)
full_dic, answers = train_test_split(full_dictionary, test_size = 0.004)

In [None]:
## Record every space-free substring of length 2 to `order` (30 below), grouped by length
def get_len_stat(dic, order):
    """Group every space-free substring of ``dic`` by its length.

    Args:
        dic: The text to scan. Substrings containing a space are skipped,
            so when words are joined by spaces no recorded substring
            spans two words.
        order: Largest substring length to record; lengths 2 through
            ``order`` (inclusive) are collected.

    Returns:
        A ``defaultdict(list)`` mapping each length to the list of
        substrings of that length, in scan order with duplicates kept.
    """
    len_stat = defaultdict(list)
    for length in tqdm(range(2, order + 1)):
        # Bound the window start by the current length rather than by
        # ``order``: the old bound dropped every gram near the end of the
        # text and recorded nothing when the text was shorter than ``order``.
        for start in range(len(dic) - length + 1):
            gram = dic[start:start + length]
            if " " not in gram:
                len_stat[length].append(gram)
    return len_stat

# Build the substring table for lengths 2..30 from the training words joined
# by single spaces (get_len_stat skips any substring containing a space).
len_stat = get_len_stat(" ".join(full_dic), 30)

In [None]:
## Check if the substring pattern is in the training data
def check_pattern(pattern, len_stat):
    """Return the training strings that ``pattern`` matches in full.

    Args:
        pattern: Regular expression describing the known letters, with
            ``.`` standing in for each unknown position.
        len_stat: Either a dict mapping a length to a list of strings, in
            which case only the entries stored under ``len(pattern)`` are
            scanned, or any other iterable of strings, which is scanned
            as-is.

    Returns:
        A list of the scanned strings matched by ``pattern``, in their
        original order.
    """
    if isinstance(len_stat, dict):
        # .get avoids a KeyError on a plain dict and avoids inserting an
        # empty entry into a defaultdict for lengths that were never seen.
        pool = len_stat.get(len(pattern), [])
    else:
        pool = len_stat

    # fullmatch anchors both ends: re.match only anchors the start, so a
    # pattern such as "a." would also accept longer words like "abc" when a
    # plain word list is scanned.
    matcher = re.compile(pattern)
    return [entry for entry in pool if matcher.fullmatch(entry)]

## Check whether vowels make up at least `threshold` (a fraction, default 0.55) of the pattern's characters
def exceed_vowels_threshold(pattern, vowels, threshold=0.55):
    """Tell whether vowels make up at least ``threshold`` of ``pattern``.

    Args:
        pattern: String to inspect; every character counts towards the
            length, including placeholder characters.
        vowels: Container of the characters treated as vowels.
        threshold: Minimum vowel fraction (vowel count divided by the
            pattern length) for the result to be True.

    Returns:
        True when the vowel fraction is greater than or equal to
        ``threshold``; False otherwise, and False for an empty pattern.
    """
    # An empty pattern has no vowel fraction; answer False instead of
    # dividing by zero.
    if not pattern:
        return False
    vowel_count = sum(1 for letter in pattern if letter in vowels)
    return vowel_count / len(pattern) >= threshold

## Count all the letters in the candidate patterns
def count_candidate(pattern, candidate_pattern, guessed_letters, vowels, vowel_filter=False):
    """Tally the letters of the candidate strings that may still be guessed.

    Args:
        pattern: The pattern the candidates were matched against; only
            consulted for the vowel-share check.
        candidate_pattern: Iterable of candidate strings.
        guessed_letters: Letters already tried; they are left out.
        vowels: Container of the characters treated as vowels.
        vowel_filter: When true and ``exceed_vowels_threshold`` reports
            that ``pattern`` is vowel-heavy, vowels are left out as well.

    Returns:
        A ``Counter`` mapping each remaining alphabetic letter to its number
        of occurrences across all candidates, in first-seen order.
    """
    tallies = Counter("".join(candidate_pattern))
    # Short-circuit keeps the threshold helper from running unless the
    # filter was requested.
    skip_vowels = vowel_filter and exceed_vowels_threshold(pattern, vowels)

    remaining = Counter()
    for symbol, occurrences in tallies.items():
        if skip_vowels and symbol in vowels:
            continue
        if symbol in guessed_letters or not symbol.isalpha():
            continue
        remaining[symbol] = occurrences
    return remaining

## Guess the next letter
def generate_letter(guess, len_stat, guessed_letters):
    """Pick the next letter to guess for a partially revealed word.

    Three strategies are tried in order; the first one that yields a letter
    wins:

    1. Match the whole pattern against the module-level ``full_dic`` word
       list and take the most frequent remaining letter.
    2. Match the pattern, then its windows of 1/2, 1/3, 1/4 and 1/5 of its
       length (windows shorter than 3 are skipped), against ``len_stat``
       and take the most frequent remaining letter pooled over all window
       positions.
    3. Fall back to the most frequent unguessed letter of ``full_dic``,
       skipping vowels while the pattern is vowel-heavy.

    Args:
        guess: Current word state with ``_`` for each unknown letter.
        len_stat: Substring table passed through to ``check_pattern``.
        guessed_letters: Letters that were already tried.

    Returns:
        The chosen letter, or None when no strategy produces one.
    """
    vowels = set("aeiou")
    pattern = guess.replace("_", ".")
    pattern_len = len(pattern)

    # Strategy 1: whole pattern against the training word list.
    word_matches = check_pattern(pattern, full_dic)
    tallies = count_candidate(pattern, word_matches, guessed_letters, vowels, vowel_filter=True)
    if tallies:
        best = tallies.most_common()[0][0]
        print("First step return.", pattern)
        return best

    # Strategy 2: progressively shorter windows against the substring table.
    for divisor in (1, 2, 3, 4, 5):
        window = pattern_len // divisor
        full_width = window == pattern_len
        # The full-width pass keeps the vowel filter and accepts any
        # non-empty pattern; partial windows need at least 3 characters.
        shortest_allowed = 1 if full_width else 3
        if window < shortest_allowed:
            continue

        pooled = Counter()
        for start in range(pattern_len - window + 1):
            piece = pattern[start:start + window]
            matches = check_pattern(piece, len_stat)
            pooled += count_candidate(piece, matches, guessed_letters, vowels, full_width)
        if pooled:
            return pooled.most_common()[0][0]

    # Strategy 3: overall letter frequency of the training words.
    for letter, _ in Counter("".join(full_dic)).most_common():
        if letter in guessed_letters or not letter.isalpha():
            continue
        if letter in vowels and exceed_vowels_threshold(pattern, vowels):
            continue
        print("Last step return.")
        return letter

    return None

def play(answer, len_stat, nTrials=6):
    """Play one hangman round against ``answer`` using ``generate_letter``.

    Args:
        answer: Target word laid out as letter/space pairs (for example
            ``"c a t "``), matching the ``"_ "`` layout used for the guess.
        len_stat: Substring table forwarded to ``generate_letter``.
        nTrials: Number of wrong guesses allowed before the round is lost.

    Returns:
        A ``(guess, flag)`` tuple: the final guess string in the same
        spaced layout as ``answer``, and True when every letter was
        revealed before the wrong-guess budget ran out.
    """
    guess = "_ " * int(len(answer) / 2)
    guess_clean = guess[::2].replace(" ", "")
    guessed_letters = []
    errors = 0
    flag = False
    while errors < nTrials:
        c = generate_letter(guess_clean, len_stat, guessed_letters)
        if c is None:
            # generate_letter has nothing left to propose; stop here rather
            # than let answer.find(None) raise TypeError.
            break
        guessed_letters.append(c)
        if c in answer:
            # Reveal every occurrence of the letter at its position in the
            # spaced layout.
            for j, char in enumerate(answer):
                if char == c:
                    guess = guess[:j] + c + guess[j + 1:]
        else:
            errors += 1
        # Every second character of the spaced guess is a letter slot.
        guess_clean = guess[::2]
        if "_" not in guess_clean:
            flag = True
            break
    return guess, flag

# Play one round per held-out word and report the success rate, both per
# window of 100 games and overall.
N = len(answers)
success = 0  # wins inside the current window of 100 games
total_success = 0  # wins over all games played

for i, answer in enumerate(tqdm(answers)):
    # Lay the word out as letter/space pairs, matching the "_ " layout that
    # play() builds for its guess string.
    answer = " ".join(answer) + " "
    res, flag = play(answer, len_stat)
    if flag:
        success += 1
        total_success += 1
        print("Success! The answer is " + res)
    else:
        print("Failed! The answer is " + answer + " The guess is " + res)

    # Every 100 games: print the windowed success rate and reset the counter.
    if (i + 1) % 100 == 0:
        acc = success / 100.0 * 100
        print("Success rate for last 100 answers: %0.2f%%" % acc)
        success = 0 

# NOTE(review): this divides by N, so an empty `answers` would raise
# ZeroDivisionError — confirm the split always yields at least one word.
total_acc = total_success / (N * 1.0) * 100
print("Overall success rate is %0.2f%%" % total_acc)
