In [1]:
import datrie
import re
import nltk
from nltk.collocations import *
from nltk.tokenize import sent_tokenize
import string
from pymystem3 import Mystem
# Mystem morphological analyser; correct_text() uses it for lemmatisation.
m = Mystem()

# Known-word dictionary (loaded from the working directory). Tokens found in it
# are kept as-is by correct_text(); suggest_word() draws candidates from it.
rusdict = datrie.Trie.load('russian_dic.trie')

# Tokeniser: a run of digits / Cyrillic letters (ё and Ё included), or a run of
# the punctuation marks . , : ; ? !
r_alphabet = re.compile(u'[0-9а-яА-ЯёЁ]+|[.,:;?!]+')

# Hyphenation tables, also loaded from the working directory:
#   d1 - read by lookup(): "first second" (lower-cased) -> merged token
#   d2 - read by lookup_non_hyphens(): hyphenated token -> space-separated words
d1 = datrie.Trie.load('hyphen_ext.trie')
d2 = datrie.Trie.load('wrong_hyphen.trie')

def lookup(tokens, table=None):
    """Merge adjacent token pairs listed in a hyphenation table.

    Tokens are consumed left to right. Whenever the last two output tokens,
    lower-cased and joined by one space, form a key of *table*, both are
    replaced by the table's value. The merged value becomes the new last
    token, so it may in turn merge with the following token.

    Args:
        tokens: sequence of token strings.
        table: mapping from "first second" (lower-cased, space-separated) to
            the replacement token. Defaults to the module-level ``d1`` trie.

    Returns:
        A new list of tokens.
    """
    if table is None:
        table = d1
    merged = []
    for token in tokens:
        merged.append(token)
        if len(merged) > 1:
            key = merged[-2].lower() + ' ' + merged[-1].lower()
            # Test membership on the mapping itself: ``table.keys()`` on a
            # datrie.Trie builds a list of every key, i.e. O(dictionary size)
            # per token.
            if key in table:
                merged[-2:] = [table[key]]
    return merged

def lookup_non_hyphens(tokens, table=None):
    """Split hyphenated tokens that the table lists as wrongly hyphenated.

    A token is replaced by the whitespace-split words of its table value only
    when it contains '-' and is a key of *table*. Every other token is kept
    unchanged.

    Args:
        tokens: sequence of token strings.
        table: mapping from hyphenated token to a space-separated replacement.
            Defaults to the module-level ``d2`` trie.

    Returns:
        A new list of tokens.
    """
    if table is None:
        table = d2
    result = []
    for token in tokens:
        # Direct membership test; ``table.keys()`` would materialise all keys.
        if '-' in token and token in table:
            result.extend(table[token].split())
        else:
            result.append(token)
    return result

def assemble_string(tokens):
    """Join tokens into a sentence string.

    The first token gets its first character upper-cased (the rest is kept).
    Each later token is preceded by a single space, unless it starts with an
    ASCII punctuation character (``string.punctuation``). In that case it is
    glued to the previous token. Empty tokens contribute nothing.

    Args:
        tokens: sequence of token strings.

    Returns:
        The assembled string; '' for an empty sequence.
    """
    parts = []
    for idx, token in enumerate(tokens):
        # token[:1] rather than token[0]: an empty token must not raise.
        head = token[:1]
        if idx == 0:
            parts.append(head.upper() + token[1:])
        elif head in string.punctuation:
            parts.append(token)
        else:
            parts.append(' ' + token)
    return ''.join(parts)

def get_prob(s, table=None, floor=0.0001):
    """Return the language-model score of key *s* plus a small floor.

    The floor keeps unseen keys from scoring exactly zero, which would
    otherwise zero out the product of scores computed by the caller.

    Args:
        s: lookup key (a space-separated word pair).
        table: score mapping; defaults to the module-level ``probdict``.
        floor: constant added to every score (default 0.0001).

    Returns:
        ``table[s] + floor`` if *s* is present, else ``floor``.
    """
    if table is None:
        table = probdict
    prob = 0.0
    # Membership on the trie itself; ``probdict.keys()`` would build the full
    # key list of the language model on every call.
    if s in table:
        prob += table[s]
    return prob + floor

def word_bigrams(word):
    """Return the character bigrams of *word*, plus its first and last char.

    Example: 'кот' -> ['ко', 'от', 'к', 'т']. The edge characters are
    appended so that word boundaries also take part in the comparison.

    Args:
        word: string to decompose.

    Returns:
        A list of strings; always a list, [] for an empty word.
    """
    if not word:
        return []
    grams = [word[k:k + 2] for k in range(len(word) - 1)]
    grams.append(word[0])
    grams.append(word[-1])
    return grams
    
def similarity(s1, s2):
    """Count the characters *s1* and *s2* have in common, with multiplicity.

    Each character of *s2* can be matched at most once. This is the size of
    the multiset intersection, e.g. similarity('aa', 'aab') == 2. The
    previous loop used ``s2.replace(ch, '')``, which deleted every
    occurrence. Repeated letters were then never counted, even though the
    caller compares the result against ``len(word) // 2``, a length that
    counts repeats.

    Args:
        s1, s2: strings to compare.

    Returns:
        Number of shared characters (int).
    """
    from collections import Counter

    return sum((Counter(s1) & Counter(s2)).values())

def compare(bigrams1, bigrams2):
    """Score the overlap of two bigram collections (higher = more similar).

    The score is |common| / |symmetric difference| over the distinct
    elements.

    Args:
        bigrams1, bigrams2: iterables of hashable items (e.g. from
            ``word_bigrams``).

    Returns:
        A float ratio. ``inf`` when both collections have the same non-empty
        set of elements. The division is undefined there, but it is a perfect
        match and must outrank every other candidate. Previously it returned
        0.0, the worst score. Returns 0.0 when both are empty.
    """
    set1, set2 = set(bigrams1), set(bigrams2)
    common = len(set1 & set2)
    unique = len(set1 ^ set2)
    if unique:
        return common / unique
    return float('inf') if common else 0.0
    
def levenshtein(s, t):
    """Return the Levenshtein edit distance between sequences *s* and *t*.

    Insertions, deletions and substitutions each cost 1. The dynamic-
    programming table is filled row by row, keeping only the previous row.
    """
    if s == t:
        return 0
    if not s:
        return len(t)
    if not t:
        return len(s)

    previous = list(range(len(t) + 1))
    for row, ch_s in enumerate(s, start=1):
        current = [row]
        for col, ch_t in enumerate(t, start=1):
            substitution = previous[col - 1] + (0 if ch_s == ch_t else 1)
            current.append(min(current[col - 1] + 1,
                               previous[col] + 1,
                               substitution))
        previous = current
    return previous[-1]

def suggest_word(i, max_len_diff=2):
    """Return the best-scoring dictionary candidates for word *i*.

    Candidates are the ``rusdict`` entries whose key starts with the first
    character of *i*. A candidate survives when both hold:
      * its stored value lies within ``max_len_diff`` of ``len(i)``;
      * it shares at least ``len(i) // 2`` characters with *i*
        (``similarity``).
    Survivors are scored with ``compare`` over ``word_bigrams``. Every
    candidate with the maximal score is returned, in trie order.

    Args:
        i: the (presumably misspelled) word. The parameter name is kept for
            callers.
        max_len_diff: allowed distance between the candidate's stored value
            and ``len(i)`` (default 2).

    Returns:
        A non-empty list of strings; ``[i]`` when nothing survives or *i* is
        empty.
    """
    word = i
    if not word:
        # Empty input used to raise IndexError on word[0].
        return [word]
    target_bigrams = word_bigrams(word)
    min_shared = len(word) // 2
    low, high = len(word) - max_len_diff, len(word) + max_len_diff
    options = {}
    for candidate, value in rusdict.items(word[0]):
        # NOTE(review): the stored value is compared with the word length,
        # so it presumably is the candidate's length; confirm against the
        # script that built russian_dic.trie.
        if low <= value <= high and similarity(word, candidate) >= min_shared:
            options[candidate] = compare(target_bigrams, word_bigrams(candidate))
    if not options:
        return [word]
    best = max(options.values())
    return [candidate for candidate, score in options.items() if score == best]

# Language-model tables, loaded from an absolute path on the shared folder.
#   probdict      - read by get_prob(); keyed by space-separated word pairs
#                   (presumably lemma bigrams; confirm with the model builder)
#   char_probdict - read by count_char_prob(); keyed by 2-character substrings
probdict = datrie.Trie.load(
    '/media/sf_communalflat/ihatelinguistics/lmodel.trie')
char_probdict = datrie.Trie.load(
    '/media/sf_communalflat/ihatelinguistics/lmodel_chars.trie')

def count_char_prob(word):
    """Sum the ``char_probdict`` scores of every character bigram of *word*.

    Bigrams absent from the table contribute nothing. A word shorter than two
    characters therefore scores 0.0.
    """
    total = 0.0
    for start in range(len(word) - 1):
        pair = word[start:start + 2]
        if pair in char_probdict:
            total += char_probdict[pair]
    return total
                                        
def correct_text(sentences):
    corrected_text=[]
    for i in sentences:
        tokens = r_alphabet.findall(i.lower())
        tokens1=lookup(tokens)
        tokens2=lookup_non_hyphens(tokens1)
        sent=[]
        for word in tokens2:
            if word in rusdict:
                sent.append(word)
            else:
                possible=[]
                for i in suggest_word(word):
                    possible.append(i)
                if len(possible)==1:
                    sent.append(possible[0])
                elif len(possible)>1:
                    probs={}
                    lemmas={}
                    reverse_lemmas={}
                    windows=[]
                    for i in possible:
                        lemmas[i]=m.lemmatize(i)[0]
                        reverse_lemmas[m.lemmatize(i)[0]]=i
                        if len(tokens2)==1:
                            probs[i]=1-(levenshtein(i, word))
                        elif len(tokens2)>=2 and tokens2[0]==word:
                            windows.append('pad'+' '+lemmas[i]+' '+m.lemmatize(tokens2[1])[0])
                        elif len(tokens2)>=2 and tokens2[-1]==word:
                            windows.append(m.lemmatize(tokens2[-2])[0]+' '+lemmas[i]+' '+'pad')
                        else:
                            windows.append(m.lemmatize(tokens2[tokens2.index(word)-1])[0]+' ' +lemmas[i]+' '+m.lemmatize(tokens2[tokens2.index(word)+1])[0])
                    for i in windows:
                        window_tokens=i.split()
                        probs[reverse_lemmas[window_tokens[1]]]=get_prob('[] []'.format(window_tokens[0],window_tokens[1]))*get_prob('[] []'.format(window_tokens[1],window_tokens[2]))
                    max_prob=max(probs.values())
                    extra_probs=[]
                    for i in probs.items():
                        if i[1]==max_prob:
                            extra_probs.append(i[0])
                    if len(extra_probs)>1:
                        extra_dict={}
                        for i in extra_probs:
                            extra_dict[i]=count_char_prob(i)
                        val_l=max(extra_dict, key=extra_dict.get)
                        sent.append(val_l)
                    else:
                        sent.append(extra_probs[0])
        corrected_text.append(assemble_string(sent))
    return ' '.join(i for i in corrected_text)

# Evaluation driver: correct every third line of the test file and write the
# results, one per line, to spelltest_checked7.txt.
with open('/media/sf_communalflat/ihatelinguistics/spell_test.txt', encoding='utf-8') as f:
    test = f.readlines()

# Lines 0, 3, 6, ... are the ones to check. A slice never runs past the end;
# the former `while i <= len(test)` raised IndexError whenever len(test) was a
# multiple of 3.
test_lines = test[::3]

checked_lines = []
for line in test_lines:
    try:
        checked_lines.append(correct_text([line]))
    except Exception as exc:
        # Best effort: keep the uncorrected line, but report why. Strip its
        # newline so the write below does not emit an extra blank line.
        print('correct_text failed on {!r}: {!r}'.format(line, exc))
        checked_lines.append(line.rstrip('\n'))

with open('spelltest_checked7.txt', 'w', encoding='utf-8') as f:
    f.writelines(line + '\n' for line in checked_lines)