In [None]:
import os
from tqdm import tqdm
from rapidfuzz import process, fuzz


def load_dictionary(file_path):
    """Read a UTF-8 text file and return its distinct lines as a set.

    Args:
        file_path: Path of the dictionary file, one entry per line.

    Returns:
        A ``set`` of the file's lines with line terminators removed.
        Duplicate lines collapse into a single entry.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        contents = handle.read()
    entries = contents.splitlines()
    return set(entries)

def load_queries(file_path):
    """Read a UTF-8 text file and return its lines in file order.

    Args:
        file_path: Path of the queries file, one query per line.

    Returns:
        A ``list`` of the file's lines with line terminators removed.
        Order and duplicates are preserved.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        contents = handle.read()
    return contents.splitlines()

def restricted_levenshtein(s1, s2):
    """Return the edit distance between ``s1`` and ``s2``.

    The allowed unit-cost operations are inserting a character, deleting
    a character and swapping two adjacent characters.  Substitution is
    not an operation here, so replacing one character by another costs 2
    (a deletion plus an insertion).

    Args:
        s1: Source string.
        s2: Target string.

    Returns:
        The minimal number of operations as an ``int``.
    """
    rows, cols = len(s1) + 1, len(s2) + 1

    # table[r][c] is the cost of turning s1[:r] into s2[:c].
    # Row 0 and column 0 are the pure-insertion / pure-deletion borders.
    table = [list(range(cols))]
    table.extend([r] + [0] * (cols - 1) for r in range(1, rows))

    for r, ch1 in enumerate(s1, start=1):
        current = table[r]
        above = table[r - 1]
        for c, ch2 in enumerate(s2, start=1):
            if ch1 == ch2:
                # Matching characters cost nothing extra.
                current[c] = above[c - 1]
                continue

            # Cheapest of "insert ch2" (left cell) and "delete ch1" (upper cell).
            best = min(current[c - 1], above[c]) + 1

            # An adjacent swap applies when the last two characters of both
            # prefixes are the same pair in opposite order.
            swapped = r > 1 and c > 1 and ch1 == s2[c - 2] and s1[r - 2] == ch2
            if swapped:
                best = min(best, table[r - 2][c - 2] + 1)

            current[c] = best

    return table[-1][-1]

def _single_edit_candidates(word, dict_word):
    """Yield every one-edit variant of ``word`` in the order they are tried."""
    # Insertion: put the dictionary word's i-th character at index i.
    for i, ch in enumerate(dict_word):
        yield word[:i] + ch + word[i:]

    # Transposition: swap each pair of adjacent characters.
    for i in range(len(word) - 1):
        yield word[:i] + word[i + 1] + word[i] + word[i + 2:]

    # Deletion: drop each character in turn.
    for i in range(len(word)):
        yield word[:i] + word[i + 1:]


def find_operation(candidate):
    """Apply one edit to a word that brings it closer to its dictionary word.

    Insertions are tried first, then adjacent transpositions, then
    deletions; the first variant whose ``restricted_levenshtein`` distance
    to ``dict_word`` is strictly smaller than ``dist`` wins.

    Args:
        candidate: A ``(word, dist, dict_word)`` triple, where ``dist`` is
            the current distance between ``word`` and ``dict_word``.

    Returns:
        A ``(new_word, new_dist, dict_word)`` triple for the first improving
        edit.  If no single edit lowers the distance, the input triple is
        returned unchanged, so the word and the distance always stay
        consistent with each other.
    """
    word, dist, dict_word = candidate

    for new_word in _single_edit_candidates(word, dict_word):
        new_dist = restricted_levenshtein(new_word, dict_word)
        if new_dist < dist:
            return (new_word, new_dist, dict_word)

    # Nothing improved: keep the original word rather than handing back the
    # last variant that was tried (which may not even exist for empty input).
    return (word, dist, dict_word)


def find_correction(candidate):
    """Build the chain of intermediate words between a word and its fix.

    ``find_operation`` is applied ``dist - 1`` times, each time feeding the
    previous result back in, and the word produced by every step is
    collected.

    Args:
        candidate: A ``(word, dist, dict_word)`` triple.

    Returns:
        The intermediate words joined by single spaces; an empty string
        when ``dist`` is 1 or less.
    """
    start_word, start_dist, target = candidate
    state = (start_word, start_dist, target)

    intermediates = []
    for _ in range(start_dist - 1):
        state = find_operation(state)
        intermediates.append(state[0])

    return ' '.join(intermediates)


def process_queries(dictionary, queries, output_file):
    """Write a correction line for every query word to ``output_file``.

    Each output line has one of three forms:

    * ``<word> 0`` when the word is already in the dictionary;
    * ``<word> 5+`` when the closest candidate is 5 or more edits away, or
      when no candidate is available at all;
    * ``<word> <dist> [<intermediate words>] <dict_word>`` otherwise, with
      all fields separated by exactly one space.

    Only the 7 dictionary entries ranked highest by ``fuzz.ratio`` are
    examined; among them the one with the smallest
    ``restricted_levenshtein`` distance is used (first one wins on ties).

    Args:
        dictionary: Collection of valid words; membership is tested with
            ``in`` and it is passed to ``process.extract`` as the choices.
        queries: Iterable of words to correct.
        output_file: Path of the UTF-8 text file to (over)write.
    """
    with open(output_file, 'w', encoding='utf-8') as out_f:
        for word in tqdm(queries, desc="Processing Queries", ncols=100):
            if word in dictionary:
                out_f.write(f"{word} 0\n")
                continue

            candidates = process.extract(word, dictionary, scorer=fuzz.ratio, limit=7)
            best_candidate = None

            # Re-rank the fuzzy matches by the task's own edit distance.
            for dict_word, _score, _key in candidates:
                true_distance = restricted_levenshtein(word, dict_word)
                if best_candidate is None or true_distance < best_candidate[1]:
                    best_candidate = (word, true_distance, dict_word)

            # No candidate at all (e.g. empty dictionary) is reported the
            # same way as "too far away" instead of crashing on None.
            if best_candidate is None or best_candidate[1] >= 5:
                out_f.write(f"{word} 5+\n")
                continue

            _, distance, target = best_candidate
            intermediate_words = find_correction(best_candidate)

            # Skip the intermediate field when it is empty (distance 1) so
            # the line never contains two consecutive spaces.
            fields = [word, str(distance)]
            if intermediate_words:
                fields.append(intermediate_words)
            fields.append(target)
            out_f.write(' '.join(fields) + '\n')


# Input and output locations for the typo-correction task.
dictionary_file, queries_file, output_file = (
    'data/4867_correct_all_typos_2/dict.txt',
    'data/4867_correct_all_typos_2/queries.txt',
    'data/4867_correct_all_typos_2/answer.txt',
)

# Load the dictionary first, then the queries.
dictionary = load_dictionary(dictionary_file)
queries = load_queries(queries_file)

# Write one answer line per query to the output file.
process_queries(dictionary=dictionary, queries=queries, output_file=output_file)