In [32]:
from razdel import sentenize, tokenize
from tqdm.auto import tqdm
from dawg import BytesDAWG, IntDAWG
from nltk import ngrams

import textdistance as td
import numpy as np

import re
import sys
import gc
import string

from collections import Counter
from typing import *
punct = set(string.punctuation + "«»—…“”")

In [2]:
with open("data/correct_sents.txt", 'r', encoding='utf-8') as f:
    correct_sents = f.readlines()
    
with open("data/sents_with_mistakes.txt", 'r', encoding='utf-8') as f:
    sents_with_mistakes = f.readlines()
    
with open("data/wiki_data.txt", 'r', encoding='utf-8') as f:
    wiki_data = f.readlines()

In [4]:
len(correct_sents), len(sents_with_mistakes), len(wiki_data)

(916, 916, 20002)

In [5]:
correct_sents = list(map(lambda s: re.sub(r'[^\w\s]','',s).replace("  ", " "),  correct_sents))

In [6]:
sents_with_mistakes = list(map(lambda s: re.sub(r'[^\w\s]','',s).replace("  ", " "), sents_with_mistakes))

In [7]:
wiki_data = list(map(lambda s: s.replace("#", ' '), wiki_data))

In [8]:
wiki_sentences = []
for text in tqdm(wiki_data):
    sentences = list(map(
        lambda s: re.sub(
            r'[^\w\s]',
            ' ',
            s.replace(")", " ").replace("(", " ")
        ).replace("  ", " ").replace("  ", " ").replace("  ", " ").strip(), # squash multiple space in the middle
        map(
            lambda x: x.text,
            sentenize(text)
        )
    ))
    wiki_sentences += sentences

HBox(children=(FloatProgress(value=0.0, max=20002.0), HTML(value='')))




In [9]:
unigrams = []
bigrams  = []
trigrams = []
for sent in tqdm(wiki_sentences):
    words = ["<start>"] + list(map(lambda x: x.text.lower(), tokenize(sent))) + ["<end>"]
    unigrams += words
    bigrams += ngrams(words, n=2)
    trigrams += ngrams(words, n=3)

HBox(children=(FloatProgress(value=0.0, max=217867.0), HTML(value='')))




In [10]:
unigram_frequencies = Counter(unigrams)
bigram_frequencies = Counter(bigrams)
trigram_frequencies = Counter(trigrams)

In [11]:
vocab = list(set(unigram_frequencies.keys()))

In [12]:
def gen_deletions(word: str, n_deletions: int = 2) -> List[str]:
    first_order = []
    for i in range(len(word)):
        variant = word[:i] + word[i+1:]
        first_order.append(variant)
        
    if n_deletions == 1:
        return list(set(first_order))
    else:
        second_order = []
        for variant in first_order:
            second_order += gen_deletions(variant, n_deletions-1)
            
        return list(set(first_order + second_order))

In [13]:
symspell_index = []
for i, word in enumerate(tqdm(vocab)):
    _word = re.sub(r'[0-9]+', '', word)
    if _word == word and len(word) > 2:
        variants = gen_deletions(_word, 2)
        for v in variants:
            if len(v) > 2:
                symspell_index.append((v, i.to_bytes(4, 'little')))

HBox(children=(FloatProgress(value=0.0, max=365947.0), HTML(value='')))




## Note
Можно конечно хранить индекс в питоновском dict, но при сериализации в json он занимал на диске 3гб места, и вероятно очень много в оперативке

Вместо питоновского dict, я использую альтернативный вариант с dawg-контейнерами, которые позволяют хранить пары вида строковый_ключ -> список байтовый массивов, строковый_ключ -> int-число c очень очень маленьким расходом памяти и таким-же временем доступа, как и у диктов

А ещё эта штука используется в fast-варианте pymorphy2

In [14]:
symspell_index = BytesDAWG(symspell_index)
gc.collect()

3

In [15]:
unigram_frequencies = unigram_frequencies.most_common()
bigram_frequencies = list(map(lambda x: (" ".join(x[0]), x[1]), bigram_frequencies.most_common()))
trigram_frequencies = list(map(lambda x: (" ".join(x[0]), x[1]), trigram_frequencies.most_common()))

In [16]:
unigrams_index = IntDAWG(unigram_frequencies)
bigrams_index = IntDAWG(bigram_frequencies)
trigrams_index = IntDAWG(trigram_frequencies)

In [17]:
unigrams_index.save("unigrams.index")
bigrams_index.save("bigrams.index")
trigrams_index.save("trigrams.index")
symspell_index.save("symspell.index")

In [18]:
!pbzip2 -kzvf -9 symspell.index unigrams.index bigrams.index trigrams.index

Parallel BZIP2 v1.1.9     - by: Jeff Gilchrist [http://compression.ca]
[Apr. 13, 2014]               (uses libbzip2 by Julian Seward)
Major contributions: Yavor Nikolov <nikolov.javor+pbzip2@gmail.com>

         # CPUs: 16
 BWT Block Size: 900 KB
File Block Size: 900 KB
 Maximum Memory: 100 MB
-------------------------------------------
         File #: 1 of 4
     Input Name: symspell.index
    Output Name: symspell.index.bz2

     Input Size: 101412872 bytes
Compressing data...
    Output Size: 68295847 bytes
-------------------------------------------
         File #: 2 of 4
     Input Name: unigrams.index
    Output Name: unigrams.index.bz2

     Input Size: 3149828 bytes
Compressing data...
    Output Size: 1919098 bytes
-------------------------------------------
         File #: 3 of 4
     Input Name: bigrams.index
    Output Name: bigrams.index.bz2

     Input Size: 35648516 bytes
Compressing data...
    Output Size: 20356333 bytes
-------------------------------------------
 

In [19]:
!ls -lh *.index.bz2

-rw-r--r-- 1 root root  20M Dec  5 12:54 bigrams.index.bz2
-rw-r--r-- 1 root root  66M Dec  5 12:54 symspell.index.bz2
-rw-r--r-- 1 root root  60M Dec  5 12:54 trigrams.index.bz2
-rw-r--r-- 1 root root 1.9M Dec  5 12:54 unigrams.index.bz2


In [20]:
def get_matches_symspell(word, vocab: List[str], symspell_index: BytesDAWG):
    variants = gen_deletions(word, n_deletions=2)
    candidates = []
    for v in variants:
        _binary = symspell_index.get(v, None)
        if _binary:
            candidates += list(
                map(
                    lambda x: vocab[int.from_bytes(x, 'little')], # декодируем 
                    _binary # в таком порядке: bytes -> int(индекс слова в массиве) -> str(само слово)
                )
            )
    
    candidates = list(set(candidates))
    candidates = list(map(lambda x: (x, td.damerau_levenshtein(x, word)), candidates))
    candidates = sorted(candidates, key=lambda x: x[1])
    return list(candidates)

In [67]:
def score_by_ngram(
    word, context, unigrams, bigrams, trigrams,
    total_words_count: int, weights = (0.1, 0.1, 0.8), do_trigram_only=False
) -> float:
    score = 0.0
    if word not in unigrams:
        return 0.0
        
    unigram_weight, bigram_weight, trigram_weigth = weights
    
    if not do_trigram_only:
        unigram_score = unigrams.get(word, 0.0) / total_words_count
        score += unigram_score * unigram_weight
        if len(context) >= 1:
            bigram_score = bigrams.get(f"{context[-1]} {word}", 0.0) / unigrams.get(context[-1], 1.0)
            score += bigram_score * bigram_weight
        else:
            score += bigram_weight
        
    if len(context) >= 2:
        trigram_score = trigrams.get(
            f"{context[-2]} {context[-1]} {word}", 0.0
        ) / bigrams.get(f"{context[-2]} {context[-1]}", 1.0)
        score += trigram_score * trigram_weigth
    else:
        score += trigram_weigth
        
    return score

In [55]:
total_words_count = sum(unigrams_index.get(word) for word in vocab)

In [75]:
def spellcorrect(sentence, vocab, vocabset, symspell_index, unigrams_index, bigrams_index, trigrams_index, total_words_count, do_ngram=True):
    words = ["<start>"] + [x.text for x in tokenize(sentence)] + ["<end>"]
    corrected_words = []
    for i, word in enumerate(words):
        if word in vocabset:
            corrected_words.append(word)
            continue
        
        candidates = get_matches_symspell(word, vocab, symspell_index)
        if len(candidates) == 0:
            corrected_words.append(word)
            continue
            
        max_dist = max(x[1] for x in candidates)
        scored_candidates = []
        for candidate, dist in candidates:
            context = words[max(i-2, 0):i]
            if do_ngram: # если будем ранжировать варианты по n-gramm вероятностям
                lm_score = score_by_ngram(
                    candidate, context,
                    unigrams_index,
                    bigrams_index,
                    trigrams_index,
                    total_words_count
                )
            else: # иначе оставим у всех одинаковый скор
                lm_score = 1.0
            score = ((max_dist + 1) - dist)*lm_score # нужно чтобы у слова с меньшей дистанцией был наивысший скор
            scored_candidates.append((candidate, score))
            
        best_candidate, best_score = sorted(scored_candidates, key=lambda x: x[1])[-1]
        corrected_words.append(best_candidate)
            
    return " ".join(corrected_words[1:-1])

In [76]:
sents_with_mistakes = list(x.lower().replace("\n", "") for x in sents_with_mistakes)

In [77]:
# исправляем с ngram-ранжированием
vocabset = set(vocab) # set(..) проще сделать один раз чем ждать пока его на каждой итерации будет делать функция
corrected_ngram = []
for sent in tqdm(sents_with_mistakes):
    correction = spellcorrect(
        sent.lower(), vocab, vocabset,
        symspell_index,
        unigrams_index,
        bigrams_index, 
        trigrams_index,
        total_words_count
    )
    corrected_ngram.append(correction)

HBox(children=(FloatProgress(value=0.0, max=916.0), HTML(value='')))




In [78]:
vocabset = set(vocab)
corrected_symspell = []
# исправляем только по индексу + дистанции
for sent in tqdm(sents_with_mistakes):
    correction = spellcorrect(
        sent.lower(), vocab, vocabset,
        symspell_index,
        unigrams_index,
        bigrams_index, 
        trigrams_index,
        total_words_count,
        do_ngram=False
    )
    corrected_symspell.append(correction)

HBox(children=(FloatProgress(value=0.0, max=916.0), HTML(value='')))




In [79]:
def align_words(sent_1, sent_2, sent_3):
    tokens_1 = sent_1.lower().split()
    tokens_2 = sent_2.lower().split()
    tokens_3 = sent_3.lower().split()
    
    tokens_1 = [re.sub('(^\W+|\W+$)', '', token) for token in tokens_1 if (set(token)-punct)]
    tokens_2 = [re.sub('(^\W+|\W+$)', '', token) for token in tokens_2 if (set(token)-punct)]
    tokens_3 = [re.sub('(^\W+|\W+$)', '', token) for token in tokens_3 if (set(token)-punct)]
    
    return list(zip(tokens_1, tokens_2, tokens_3))

In [80]:
def metrics(original, corrected, reference):
    correct, total = 0.0, 0.0
    mistaken_fixed, mistaken_total = 0.0, 0.0
    correct_broken, correct_total = 0.0, 0.0
    for s_orig, s_corrected, s_reference in tqdm(zip(original, corrected, reference)):
        for word_orig, word_corrected, word_reference in align_words(s_orig, s_corrected, s_reference):
            total += 1.0
            if word_corrected == word_reference:
                correct += 1.0
                
            if word_orig == word_reference:
                correct_total += 1.0
                if word_corrected != word_reference:
                    correct_broken += 1.0
            else:
                mistaken_total += 1.0
                if word_corrected == word_reference:
                    mistaken_fixed += 1.0
                    
    return {
        'correct_ratio': correct/total,
        'mistakes_correct_ratio': mistaken_fixed/mistaken_total,
        'broken_ratio': correct_broken/correct_total
    }

In [81]:
## for symspell-only model
print(metrics(sents_with_mistakes, corrected_symspell, correct_sents))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


{'correct_ratio': 0.8447552447552448, 'mistakes_correct_ratio': 0.4191235059760956, 'broken_ratio': 0.09423186750428326}


In [82]:
## for symspell + ngram model
print(metrics(sents_with_mistakes, corrected_ngram, correct_sents))

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))


{'correct_ratio': 0.8419580419580419, 'mistakes_correct_ratio': 0.39681274900398406, 'broken_ratio': 0.09423186750428326}


## Заключение:
Почему-то модель без ngram работает немножечко лучше, чем модель с ngram.
Есть две возможные причины:
1. я где-то накосячил в расчете вероятностей(но если заменить log-exp - то получаются ещё хуже метрики)
2. так сложилось - так загрузился датасет, посчитался индекс и иные причины, не зависящие напрямую от меня