In [320]:
import pandas as pd
import re, random
import nltk
from tqdm import tqdm
from collections import Counter, defaultdict
import itertools
import matplotlib.pyplot as plt
import time
from sklearn.model_selection import train_test_split
from transformers import TextDataset,DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, AutoModelWithLMHead
from transformers import AutoTokenizer
from transformers import set_seed
import sys
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np

In [321]:
SEED = 314
set_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [322]:
data = pd.read_csv("./train.csv")

In [323]:
class spell:
    def __init__(self, data, corrupted_text, correct_text):
        print("starting")
        words = " ".join(list(data[correct_text])).lower()
        print("extracting tokens")
        words = re.findall(r'[\w]+', words)
        #words = u" ".join(words).split()
        print("creating set of syms")
        self.d_sym = "".join(list(set(list("".join(words)))))
        print("creating set of words")
        self.d_set = set(words)
        print("creating dict")
        self.d_dict = dict(Counter(words))
        print("init done")
        print("")
    
    def create_symspell(self, arr):
        self.pbar = tqdm(total=len(arr))
        self.symdict = defaultdict(list)
        pool = ThreadPool(10)
        pool.map(self.symspell, arr[:10000])
    
    def symspell(self, word):
        words = self.away_2(word)
        for w in words:
            self.symdict[w].append(word)
        self.pbar.update(1)
            
    def away_1(self, word):
        #letters = self.d_sym
        letters = 'абвгдежзийклмнопрстуфхцчшщъыьэюяё'
        splits = [(word[:i], word[i:]) for i in range(len(word) + 1)]
        #deletes = [L + R[1:] for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R) > 1]
        replaces = [L + c + R[1:] for L, R in splits if R for c in letters]
        inserts = [L + c + R for L, R in splits for c in letters]
        return set(transposes + replaces + inserts) #deletes

    def away_2(self, word):
        return set([e2 for e1 in self.away_1(word)
                    for e2 in self.away_1(e1)])
    
    def known(self, words):
        return set(w for w in words if w in self.d_set)

    def edit_candidates(self, word):
        ttt = self.known(self.away_1(word)) | self.known(self.away_2(word))

        return list(ttt)

    def most_freq_edits(self, word):
        lst = self.edit_candidates(word)
        lst.sort(key=lambda x: self.d_dict[x])
        lst.reverse()
        return lst

    def token(self, sent):
        return re.findall(r'[\w]+', sent)



In [324]:
a = spell(data, "corrupted_text", "correct_text")

starting
extracting tokens
creating set of syms
creating set of words
creating dict
init done



In [325]:
a.most_freq_edits("человен")

['человек',
 'человека',
 'человеку',
 'человеке',
 'человеко',
 'мелочен',
 'келовей',
 'человечны',
 'человке',
 'человев',
 'человечно',
 'мелован',
 'человече']

# fine-tuning gpt-2

In [326]:
#!g1.1
train_data = data.correct_text.copy()

In [327]:
#!g1.1
def build_text_files(data_arr, dest_path):
    with open(dest_path, 'w') as f:
        data = ''
        for texts in data_arr:
            data += texts + "  "
        f.write(data)

train, test = train_test_split(train_data,test_size=0.1)

build_text_files(train,'train_dataset.txt')
build_text_files(test,'test_dataset.txt')

In [328]:
#!g1.1
tokenizer = AutoTokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")

train_path = 'train_dataset.txt'
test_path = 'test_dataset.txt'

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=608.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=1713123.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=1270925.0), HTML(value='')))






Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [329]:
#!g1.1
def load_dataset(train_path,test_path,tokenizer):
    train_dataset = TextDataset(
          tokenizer=tokenizer,
          file_path=train_path,
          block_size=128)

    test_dataset = TextDataset(
          tokenizer=tokenizer,
          file_path=test_path,
          block_size=128)

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm=False,
    )
    return train_dataset,test_dataset,data_collator

train_dataset,test_dataset,data_collator = load_dataset(train_path,test_path,tokenizer)




In [330]:
#!g1.1

model = AutoModelWithLMHead.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2")

training_args = TrainingArguments(
    output_dir="./gpt2-ru", 
    overwrite_output_dir=True, 
    num_train_epochs=1, 
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    eval_steps = 500,
    save_steps=1000,
    warmup_steps=500,
    )

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)



HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=551290714.0), HTML(value='')))




In [None]:
#!g1.1
trainer.train()

In [332]:
#!g1.1
trainer.save_model()

Saving model checkpoint to ./gpt2-ru
Configuration saved in ./gpt2-ru/config.json
Model weights saved in ./gpt2-ru/pytorch_model.bin


# Поиск и замена ошибок

In [333]:
#!g2.mig
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
    model = GPT2LMHeadModel.from_pretrained('./gpt2-ru') #finetuned rugpt-2
    model.to(device)
    model.eval()

tokenizer = GPT2Tokenizer.from_pretrained('sberbank-ai/rugpt3small_based_on_gpt2')

def score(sentence):
    tokenize_input = tokenizer.encode(sentence)
    tensor_input = torch.tensor([tokenize_input]).to(device)
    loss = model(tensor_input, labels=tensor_input)[0]
    return np.exp(loss.cpu().detach().numpy())

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=1713123.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=1270925.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=608.0), HTML(value='')))






In [342]:
#!g2.mig
def find_best(sent, strategy = None, top = 6, att_score = 500, att_freq = 3):
    global a
    tokenized_lower = a.token(sent.lower())
    tokenized_normal = a.token(sent)

    # ищем какие слова - ошибки
    words_mistakes = {}  # ключ - слово, значение - индекс в tokenized_lower
    for i, word in enumerate(tokenized_lower):
        if not [word] == list(a.known([word])):
            words_mistakes[word] = i
    
    
    # ищем по словарю возможные замены
    edit_suggestions = []
    words_mistakes_list = []
    for word in words_mistakes:
        words_mistakes_list.append(word)
        edit_suggestion = a.most_freq_edits(word)
        if edit_suggestion:
          if top is not None:
            if len(edit_suggestion) > top:
              edit_suggestion = edit_suggestion[:top-1]
          edit_suggestions.append(edit_suggestion)
        else:
            edit_suggestions.append([word])

    # генерируем подстановоки вместо ошибок

    pairs = list(itertools.product(*edit_suggestions))
    sent_suggestions = []
    for pair in pairs:
        sent_tmp = sent
        for i in range(len(words_mistakes)):
            if tokenized_normal[words_mistakes[words_mistakes_list[i]]][0].isupper():
                replace = pair[i]
                replace = replace[0].upper() + replace[1:]
            else:
                replace = pair[i]
            sent_tmp = sent_tmp.replace(tokenized_normal[words_mistakes[words_mistakes_list[i]]], replace)

        sent_suggestions.append(sent_tmp)
    #sent_suggestions = list(set(sent_suggestions))
    #print(len(sent_suggestions))

    # проверяем семантическую адекватность подстановок и выбираем лучшую
    scores = []
    for i in sent_suggestions:
        scores.append(score(i))
    #plt.hist(scores)
    #plt.show()
    return sent_suggestions[np.argmin(scores)]

In [343]:
#!g2.mig
find_best("Об этом чернз минуту.")

'Об этом через минуту.'

# Тесты

## Тест всего билда

In [344]:
#!g2.mig
def validate(data, verbose = False):
  res = []
  time1 = time.time()
  for i, elm in tqdm(enumerate(data.corrupted_text), total = len(data)):
    res.append(find_best(elm))
  l = 0
  n = 0
  time2 = time.time()
  for i, elm in enumerate(data.correct_text):
    n += 1
    if elm == res[i]:
      l += 1
    else:
      if verbose:
        print("FAILED || ", data.corrupted_text.iloc[i], '==>', res[i], '!!!===', elm)
  
  print("TOTAL ACU: ",l/n)
  print("SECONDS PER ITER :", np.round((time2-time1)/len(data), 4))
  print("TOTAL HOURS FOR ALL PRiVATE: ", np.round((time2-time1)/len(data) * 56000 / (3600),1) )

In [346]:
#!g2.mig
data_val = data[0:200]
validate(data_val, verbose = True)

FAILED ||  Считает, что ссожет ить вечно! ==> Считает, что сможет быть вечно! !!!=== Считает, что сможет жить вечно!
FAILED ||  Вы имеетн в виду силу ьога? ==> Вы имеете в виду силу тогда? !!!=== Вы имеете в виду силу бога?
FAILED ||  босаснов зефир ванильный темной глазури ==> боссанова зефир ванильный темной глазури !!!=== БоссаНова зефир ванильный темной глазури
FAILED ||  Филир Морис комп Эксперт ==> Филир Морис комп Эксперт !!!=== Филип Морис комп Эксперт
FAILED ||  Сегнал заднего хода ==> Сигналы заднего хода !!!=== Сигнал заднего хода
FAILED ||  - Во сколько твой ресй в Вашингтон? ==> - Во сколько твой тест в Вашингтон? !!!=== - Во сколько твой рейс в Вашингтон?
FAILED ||  я шлубaко взволнован наш с ней встречей. ==> я глубоко взволнован наш с ней встречей. !!!=== я глубоко взволнован нашей с ней встречей.
FAILED ||  А ты пока азправ. ==> А ты пока заправь. !!!=== А ты пока заправься.
FAILED ||  эт мой боат. Помните, мой пладш брат. ==> эт мой брат. Помните, мой плачу брат. !!!=

100%|██████████| 200/200 [00:38<00:00,  5.16it/s]


# Private submit

In [349]:
#!g2.mig
data_sub = pd.read_csv("./private_test.csv")

In [350]:
#!g2.mig
data_sub.head(5)

Unnamed: 0,corrupted_text
0,мясыне блюда говядина
1,- А можно я пойд?
2,Бордюры обонй ассортименте
3,Вместо союса кетчуп
4,"Не прдесталя, как она могла туда папаст."


In [351]:
#!g2.mig
len(data_sub)

56526

In [352]:
#!g2.mig
def submit(data, outpath):
    with open(outpath, 'w') as file:
        for i, elm in tqdm(enumerate(data.corrupted_text), total = len(data)):
            file.write(find_best(elm) + "\n")


In [353]:
#!g2.mig
submit(data_sub, "private.submit")

100%|██████████| 56526/56526 [3:37:05<00:00,  4.34it/s]


In [354]:
#!g2.mig
with open("private.submit") as file:
    file.read().split("n")