In [58]:
import time
import sys
from collections import defaultdict, Counter
import re
import numpy as np

In [52]:
class Trie:
    """Prefix-tree node; the root node represents the whole dictionary.

    Attributes:
        word: The complete word that ends at this node, or None when no
            inserted word ends here.
        probability: The value stored alongside ``word`` by ``insert``
            (None until a word ends at this node).
        children: Mapping from a single letter to the child ``Trie`` node.
            It is a ``defaultdict``, so indexing a missing letter creates
            an empty child on the fly.
    """

    def __init__(self):
        self.word = None
        self.probability = None
        # Children are materialised lazily the first time a letter is indexed.
        self.children = defaultdict(Trie)

    def insert(self, word, probability):
        """Store ``word`` in the trie and tag its final node.

        Args:
            word: Sequence of letters to insert; an empty word tags the
                node ``insert`` was called on.
            probability: Value recorded on the terminal node.
        """
        terminal = self
        for ch in word:
            # Descend one level per letter, creating nodes as needed.
            terminal = terminal.children[ch]
        terminal.word, terminal.probability = word, probability

In [53]:
class LevenshteinAutomaton:
    """Incremental Levenshtein-distance calculator for a fixed target word.

    A *state* is one row of the edit-distance matrix: ``state[i]`` is the
    distance between the first ``i`` letters of ``word`` and the letters fed
    to the automaton so far. Feeding letters one at a time lets a caller
    share work between inputs that have a common prefix.

    Attributes:
        word: The target word that inputs are compared against.
        max_dist: Largest edit distance that still counts as a match.
    """

    def __init__(self, word, max_dist):
        self.word = word
        self.max_dist = max_dist

    def start(self):
        """Return the initial state (no input letters consumed yet).

        Returns:
            list[int]: ``[0, 1, ..., len(word)]``. A concrete list is
            returned (rather than a ``range``) so the state has the same
            type as the rows produced by ``step``.
        """
        return list(range(len(self.word) + 1))

    def step(self, state, letter):
        """Consume one input letter and return the next state.

        Args:
            state: Current row, of length ``len(word) + 1``.
            letter: The next letter of the input being matched.

        Returns:
            list[int]: The next row; ``state`` itself is not modified.
        """
        # Column 0: matching against the empty prefix costs one edit per
        # letter consumed.
        new_state = [state[0] + 1]
        for i in range(len(state) - 1):
            # The three classic Levenshtein recurrences for cell i + 1.
            insert_state = new_state[i] + 1
            replace_state = state[i] + (self.word[i] != letter)
            delete_state = state[i + 1] + 1
            new_state.append(min(insert_state, replace_state, delete_state))
        return new_state

    def is_match(self, state):
        """Return True if the input consumed so far is within ``max_dist`` of ``word``."""
        return state[-1] <= self.max_dist

    def can_match(self, state):
        """Return True if some cell of the row is still within ``max_dist``.

        When this is False, no cell of any later row can drop back under the
        threshold, so a caller may stop extending the current input.
        """
        return min(state) <= self.max_dist

In [128]:
class WordSearcher:
    """Finds dictionary words within a bounded edit distance of a query.

    The trie is walked depth-first while a ``LevenshteinAutomaton`` tracks
    the edit distance of the current path; branches that can no longer
    match are pruned.

    Attributes:
        trie: Root trie node; its ``children`` mapping is the search start.
        max_dist: Maximum edit distance for a word to be reported.
        prob_weight: Weight of the ``-log(probability)`` term in the score.
    """

    def __init__(self, trie, max_dist=1, prob_weight=0.1):
        """
        Args:
            trie: Root node exposing ``children``, ``word`` and ``probability``.
            max_dist: Maximum edit distance for a word to be reported.
            prob_weight: Weight of the log-probability term; the default
                ``0.1`` matches the previously hard-coded constant.
        """
        self.trie = trie
        self.max_dist = max_dist
        self.prob_weight = prob_weight

    def search(self, word):
        """Return all candidate corrections for ``word``.

        Returns:
            list[tuple]: ``(dictionary_word, score)`` pairs, where
            ``score = edit_distance - prob_weight * log(probability)``;
            a lower score is a better candidate. The list is empty when no
            dictionary word is within ``max_dist``.
        """
        words = []
        automaton = LevenshteinAutomaton(word, self.max_dist)
        state = automaton.start()
        self.search_recursive(self.trie.children, automaton, state, words)
        return words

    def search_recursive(self, node, automaton, state, words):
        """Depth-first walk that appends matches to ``words`` in place.

        Args:
            node: Mapping ``letter -> child node`` (a trie node's
                ``children``), not the trie node itself.
            automaton: Automaton providing ``step``, ``is_match`` and
                ``can_match`` for the query.
            state: Automaton state for the path leading to ``node``.
            words: Output list that receives ``(word, score)`` tuples.
        """
        for letter, child in node.items():
            new_state = automaton.step(state, letter)
            if child.word is not None and automaton.is_match(new_state):
                # Frequent words get a lower score, so they win ties
                # between candidates at the same edit distance.
                score = new_state[-1] - self.prob_weight * np.log(child.probability)
                words.append((child.word, score))
            # Only descend while some extension of this path can still match.
            if automaton.can_match(new_state):
                self.search_recursive(child.children, automaton, new_state, words)

In [131]:
class SpellChecker:
    """Word-by-word spelling corrector for Russian and English text.

    One trie-backed ``WordSearcher`` is built per language from a plain-text
    corpus. ``check`` detects the language of the input and replaces each
    word with its best-scoring dictionary candidate.

    Attributes:
        rus_searcher: Searcher over the Russian corpus.
        eng_searcher: Searcher over the English corpus.
    """

    def __init__(self, rus_dictionary, eng_dictionary, max_dist=2):
        """
        Args:
            rus_dictionary: Path to the Russian corpus text file.
            eng_dictionary: Path to the English corpus text file.
            max_dist: Maximum edit distance considered for corrections.
        """
        self.rus_searcher = WordSearcher(SpellChecker.get_trie(rus_dictionary), max_dist)
        self.eng_searcher = WordSearcher(SpellChecker.get_trie(eng_dictionary), max_dist)

    @staticmethod
    def get_trie(dictionary):
        """Build a trie of relative word frequencies from a corpus file.

        Args:
            dictionary: Path to a text file; it is tokenised with
                ``get_words``.

        Returns:
            Trie: Every distinct word is inserted with probability
            ``count / total_words``.
        """
        # NOTE(review): the file is decoded with the platform default
        # encoding; confirm it matches the corpus files.
        with open(dictionary) as corpus:
            # Tokenise with get_words rather than a free `words` helper that
            # is not defined in this notebook.
            cnt = Counter(SpellChecker.get_words(corpus.read()))
        trie = Trie()
        N = sum(cnt.values())
        for word, count in cnt.items():
            trie.insert(word, count / N)
        return trie

    @staticmethod
    def get_words(text):
        """Return the lowercased ``\\w+`` tokens of ``text`` in order."""
        return re.findall(r'\w+', text.lower())

    @staticmethod
    def get_language(text):
        """Guess the language of ``text`` by counting letters.

        Latin letters (either case) count as English; every other letter
        counts as Russian. Digits and underscores are ignored.

        Returns:
            str: ``'eng'`` when Latin letters are the strict majority,
            otherwise ``'rus'`` (ties and letter-free text included).
        """
        eng_count = len(re.findall(r'[a-zA-Z]', text))
        # [^\W\d_] matches word characters that are neither digits nor '_',
        # i.e. letters only.
        letter_count = len(re.findall(r'[^\W\d_]', text))
        rus_count = letter_count - eng_count
        return 'eng' if rus_count < eng_count else 'rus'

    def check(self, text):
        """Return the corrected words of ``text``.

        Args:
            text: Input string; it is lowercased and tokenised by
                ``get_words``.

        Returns:
            list[str]: One entry per token. Each token is replaced by the
            candidate with the lowest score; a token with no candidate
            within the edit-distance limit is kept unchanged.
        """
        lang = SpellChecker.get_language(text)
        searcher = self.eng_searcher if lang == 'eng' else self.rus_searcher

        correct_text = []
        for word in SpellChecker.get_words(text):
            candidates = searcher.search(word)
            if candidates:
                correct_text.append(min(candidates, key=lambda x: x[1])[0])
            else:
                # min() would raise on an empty list; keep the word as typed.
                correct_text.append(word)
        return correct_text

In [132]:
# Build one trie per language from the corpus files (default max_dist=2).
speller = SpellChecker('../data/rus.txt', '../data/eng.txt')

In [142]:
# Sanity check: the stray trailing 'q' in 'dogq' should be corrected away.
speller.check('i am dogq')

['i', 'am', 'dog']

In [122]:
# Relative frequency stored for the one-letter word 'я' in the Russian trie.
# `children` is a defaultdict, so indexing a missing letter would create an
# empty node instead of raising KeyError.
speller.rus_searcher.trie.children['я'].probability

0.0066631044272627196