In [1]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [2]:
DRIVE_ROOT = f'/content/drive/MyDrive/murmansk'

## Loading lenta dataset

In [None]:
import os
import pandas as pd
lenta_path = os.path.join(DRIVE_ROOT, 'lenta_pairs_typos.csv')
lenta_data = pd.read_csv(lenta_path)

In [None]:
lenta_data.head()

Unnamed: 0.1,Unnamed: 0,source,target
0,0,Австрийские правоохранительные оргпны не предс...,Австрийские правоохранительные органы не предс...
1,1,Сотрудники социальной сети Instagram проанализ...,Сотрудники социальной сети Instagram проанализ...
2,2,С начала расследоаания российского вмешатеюьст...,С начала расследования российского вмешательст...
3,3,Хакерская грумпировва Anonymous опубликовала н...,Хакерская группировка Anonymous опубликовала н...
4,4,Архиепископ канонической Украинской православн...,Архиепископ канонической Украинской православн...


##Loading Clang8 dataset

In [120]:
!git clone https://github.com/google-research-datasets/clang8

Cloning into 'clang8'...
remote: Enumerating objects: 31, done.[K
remote: Counting objects: 100% (31/31), done.[K
remote: Compressing objects: 100% (22/22), done.[K
remote: Total 31 (delta 9), reused 25 (delta 5), pack-reused 0[K
Receiving objects: 100% (31/31), 9.09 KiB | 9.09 MiB/s, done.
Resolving deltas: 100% (9/9), done.


In [121]:
import os
SCRIPT_PATH = os.path.join(DRIVE_ROOT, 'run.sh')

In [122]:
!pip install virtualenv

Collecting virtualenv
  Downloading virtualenv-20.24.5-py3-none-any.whl (3.7 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.7 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.9/3.7 MB[0m [31m26.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.7/3.7 MB[0m [31m57.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting distlib<1,>=0.3.7 (from virtualenv)
  Downloading distlib-0.3.7-py2.py3-none-any.whl (468 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/468.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m468.9/468.9 kB[0m [31m43.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: distlib, virtualenv
Successfully installed distlib-0.3.7 virtualenv-20.24.5


In [123]:
!cp -f $SCRIPT_PATH clang8/run.sh

In [None]:
%cd clang8
!chmod u+x run.sh
! ./run.sh
%cd ../

In [125]:
import pandas as pd
DATA_PATH = '/content/clang8/output_data/clang8_source_target_ru.spacy_tokenized.tsv'
data = pd.read_table(DATA_PATH, encoding='utf-8', on_bad_lines='skip', sep='\t', names=['source','target'])

In [126]:
data[data.source != data.target]

Unnamed: 0,source,target
1,"Краткое содержание этой книги , герой не ходит...","Краткое содержание этой книги , герой не ходит..."
2,"Ни с кем не говарить , не встречаться , не вый...","Ни с кем не разговаривать , не встречаться , н..."
3,"Но , конец он нашёл свою мечту . Когда мне был...","Но , наконец , он нашёл свою мечту . Когда мне..."
4,Я очень испытала симпатию его . Чувство полово...,Я очень испытала симпатию его . Чувство полово...
6,Но я чувствую его далеко чем в детстве .,"Но я чувствую его далеко сильнее , чем в детст..."
...,...,...
44803,"После окончания университета , я наверняка буд...","После окончания университета , мне наверняка б..."
44804,"Два года назад , у нас всё было по - другому .",Два года назад у нас всё было по - другому .
44814,"Теперь хочу показать некоторые фотографии , не...","Теперь хочу показать некоторые фотографии , не..."
44815,"К сожалению , у меня нет фотографий с двух лет...","К сожалению , у меня нет фотографий с двух лет..."


## Building model

In [None]:
!pip install transformers

In [14]:
MODEL_NAME = 'DeepPavlov/rubert-base-cased'

In [15]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")

In [16]:
from torch.nn import Module, Dropout, Linear, Sequential
from transformers import AutoModel


class Model(Module):
    def __init__(self, pretrained_model_name, num_punct_classes, freeze = True, **kwargs):
        super().__init__()
        self.emb = AutoModel.from_pretrained(pretrained_model_name, output_attentions=False, output_hidden_states=False)
        self.emb_size = list(self.emb.parameters())[-1].shape[0]
        if freeze:
            for param in self.emb.parameters():
                param.requires_grad = False

        self.punct_hid_size = 392
        self.spelling_hid_size = 392
        self.punctuation_head = Sequential(
                Dropout(p = 0.1),
                Linear(self.emb_size, self.punct_hid_size),
                Linear(self.punct_hid_size, num_punct_classes)
        )

        self.spelling_head = Sequential(
                Dropout(p = 0.1),
                Linear(self.emb_size, self.spelling_hid_size),
                Linear(self.spelling_hid_size, 1)
        )


    def forward(self, input_ids, attention_mask, **kwargs):
        emb = self.emb(input_ids = input_ids, attention_mask = attention_mask)[0]
        punct_output = self.punctuation_head(emb)
        spelling_output = self.spelling_head(emb)
        return punct_output, spelling_output

## Preparing data

In [24]:
PUNCT = ',;:.'
tag2punct = dict(enumerate(PUNCT))
punct2tag = dict(zip(tag2punct.values(), tag2punct.keys()))

In [None]:
import re
import numpy as np
import torch
torch.set_default_dtype(torch.float32)


class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, punct2tag):
        self.data = data
        self.tokenizer = tokenizer
        self.punct2tag = punct2tag
        punct_ids = self.tokenizer(list(self.punct2tag.keys()), add_special_tokens = False).input_ids
        punct_ids = [p[0] for p in punct_ids]
        assert len(punct_ids) == len(self.punct2tag)
        self.id2tag = dict(zip(punct_ids, punct2tag.values()))
        self.punct_regexp = re.compile('|'.join(f'\{p}' for p in self.punct2tag))

    def delete_punct(self, text):
        return self.punct_regexp.sub('', str(text))

    def get_punct_label(self, source_tokens, source_tokens_with_punct):
        punct = []
        for i, token in enumerate(source_tokens):
            punct_label = torch.zeros(len(self.id2tag))
            mentions = np.flatnonzero(np.array(source_tokens_with_punct) == token)
            for m in mentions:
                if np.isclose(i, m, atol=0.5e1) and m < len(source_tokens_with_punct) - 1:
                    if source_tokens_with_punct[m + 1] in self.id2tag:
                        punct_label[self.id2tag[source_tokens_with_punct[m + 1]]] = 1
            punct.append(punct_label)
        return punct

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx):
        instance = self.data.iloc[idx]
        source, target = str(instance.source), str(instance.target)
        source_tokens_with_punct = self.tokenizer(source, add_special_tokens = False).input_ids
        source, target = self.delete_punct(source), self.delete_punct(target)
        source_tokens = self.tokenizer(source, add_special_tokens = False).input_ids
        target_tokens = self.tokenizer(target, add_special_tokens = False).input_ids

        spelling = []
        spelling = [0.0 if token in target_tokens else 1.0 for token in source_tokens]

        punct = self.get_punct_label(source_tokens, source_tokens_with_punct)

        return {'input_ids':torch.tensor(source_tokens),
                'spelling': torch.tensor(spelling).unsqueeze(1),
                'punct': torch.vstack(punct)}

In [None]:
def pad(sequence, max_length, padding = 0):
    pad_size = max_length - len(sequence)
    pad_sequence = [padding] * pad_size
    if pad_sequence:
        pad_tensor = torch.tensor(pad_sequence)
        return torch.cat([sequence, pad_tensor])
    return sequence


def collate_fn(batch):
    max_length = max(len(example['input_ids']) for example in batch)
    padded_input_ids, attention_mask, padded_spelling,padded_punct = [], [], [] , []
    for example in batch:
        padded_input_ids.append(pad(example['input_ids'], max_length))
        attention_mask.append(pad(torch.ones_like(example['input_ids']), max_length))
        padded_spelling.append(pad(example['spelling'], max_length, padding = [0]))
        padded_punct.append(pad(example['punct'], max_length, padding = [0] * len(PUNCT)))
    return {'input_ids': torch.stack(padded_input_ids),
            'attention_mask': torch.stack(attention_mask),
            'spelling': torch.stack(padded_spelling),
            'punct': torch.stack(padded_punct)}

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader

RANDOM_STATE = 42
BATCH_SIZE = 64
clang_train_data, clang_val_data = train_test_split(data, test_size=0.2, random_state = RANDOM_STATE)
clang_train_dataset = Dataset(clang_train_data, tokenizer, punct2tag)
clang_val_dataset = Dataset(clang_val_data, tokenizer, punct2tag)
lenta_train_data, lenta_val_data = train_test_split(lenta_data, test_size=0.2, random_state = RANDOM_STATE)
lenta_train_dataset = Dataset(lenta_train_data, tokenizer, punct2tag)
lenta_val_dataset = Dataset(lenta_val_data, tokenizer, punct2tag)
train_dataset = clang_train_dataset  + lenta_train_dataset
val_dataset = clang_val_dataset  + lenta_val_dataset
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle = True, collate_fn = collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle = True, collate_fn = collate_fn)

## Trainer

In [None]:
from tqdm.notebook import tqdm, trange
from torch.optim import Adam, AdamW
from torch.nn.functional import sigmoid
from torchvision.ops import sigmoid_focal_loss
from sklearn.metrics import f1_score

def accuracy(true, logits, threshold):
    pred = (sigmoid(logits) >= threshold).int()
    correct = torch.count_nonzero(true == pred)
    total = true.size().numel()
    return correct / total

def f1(true, logits, threshold):
    pred = (sigmoid(logits.cpu()) >= threshold).int().view(-1, logits.shape[-1])
    true = true.cpu().int().view(-1, true.shape[-1])
    return f1_score(true.cpu(), pred.cpu(), average='macro', zero_division = 1)

def forward(batch, model, device, alpha = 0.85, spelling_threshold = 0.5, punct_threshold = 0.5):
    ids = batch['input_ids'].to(device)
    mask = batch['attention_mask'].to(device)
    spelling_labels = batch['spelling'].to(device)
    punct_labels = batch['punct'].to(device)

    punct_output, spelling_output = model(ids, mask)

    spelling_loss = sigmoid_focal_loss(spelling_output, spelling_labels, reduction = 'mean', alpha = alpha)
    punct_loss = sigmoid_focal_loss(punct_output, punct_labels, reduction = 'mean', alpha = alpha)

    #spelling_accuracy = accuracy(spelling_labels, spelling_output, spelling_threshold)
    #punct_accuracy = accuracy(punct_labels, punct_output, punct_threshold)

    spelling_f1 = f1(spelling_labels, spelling_output, spelling_threshold)
    punct_f1 = f1(punct_labels, punct_output, punct_threshold)

    return spelling_loss, punct_loss, spelling_f1, punct_f1


def pass_epoch(loader, model, optimizer = None, split = 'train', **kwargs):
    sum_spelling_f1 = 0
    sum_punct_f1 = 0
    tq = tqdm(loader)
    for step, batch in enumerate(tq):
        spelling_loss, punct_loss, spelling_f1, punct_f1 = forward(batch, model, **kwargs)
        if split == 'train':
            spelling_loss.backward()
            punct_loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        sum_spelling_f1 += spelling_f1
        sum_punct_f1 += punct_f1
        avg_spelling_f1 = sum_spelling_f1/ (step + 1)
        avg_punct_f1 = sum_punct_f1/ (step + 1)
        tq.set_description(f'{split.capitalize()}: Spelling loss: {spelling_loss:4.4f} Punct loss: {punct_loss:4.4f}\
        Avg Spelling f1: {avg_spelling_f1:4.4f} Avg Punct f1: {avg_punct_f1:4.4f}')
    return spelling_loss, punct_loss, avg_spelling_f1, avg_punct_f1


def train(train_loader, val_loader, model, optim = Adam,
          epochs = 30, lr=1e-3, device = 'cpu',
          val_epoch = 1, save_epoch = 1, path = 'model.pt', **kwargs):
    optimizer = optim(params=[p for p in model.parameters() if p.requires_grad], lr=lr)
    model.to(device)

    for epoch in trange(1, epochs + 1):
        print(f'Epoch {epoch}')
        model.train()
        spelling_loss, punct_loss, avg_spelling_f1, avg_punct_f1 = pass_epoch(train_loader, model, optimizer = optimizer, split = 'train', device = device, **kwargs)

        if epoch % val_epoch == 0:
            spelling_loss, punct_loss, avg_spelling_f1, avg_punct_f1 = pass_epoch(val_loader, model, split = 'val', device = device, **kwargs)

        if epoch % save_epoch == 0:
            torch.save(model, path)
            print(f'Model saved to {path}')

# Training

In [None]:
#model = Model(pretrained_model_name = MODEL_NAME, num_punct_classes = len(PUNCT))

In [None]:
checkpoint_path = os.path.join(DRIVE_ROOT, 'bert_punct_spelling.pt')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.load(checkpoint_path, map_location = device)

In [None]:
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'
checkpoint_path = os.path.join(DRIVE_ROOT, 'bert_punct_spelling.pt')
model = model.to(device)
train(train_loader, val_loader, model, device = device, path = checkpoint_path, lr=3e-5, alpha = 0.85)

  0%|          | 0/30 [00:00<?, ?it/s]

Epoch 1


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 2


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 3


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 4


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 5


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 6


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 7


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 8


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 9


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 10


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 11


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 12


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 13


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 14


  0%|          | 0/686 [00:00<?, ?it/s]

  0%|          | 0/172 [00:00<?, ?it/s]

Model saved to /content/drive/MyDrive/murmansk/bert_punct_spelling.pt
Epoch 15


  0%|          | 0/686 [00:00<?, ?it/s]

KeyboardInterrupt: ignored

# Masked Correction

In [None]:
PUNCT = ',;:.'
tag2punct = dict(enumerate(PUNCT))
punct2tag = dict(zip(tag2punct.values(), tag2punct.keys()))

In [18]:
!pip install python-Levenshtein

Collecting python-Levenshtein
  Downloading python_Levenshtein-0.21.1-py3-none-any.whl (9.4 kB)
Collecting Levenshtein==0.21.1 (from python-Levenshtein)
  Downloading Levenshtein-0.21.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (172 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m172.5/172.5 kB[0m [31m3.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting rapidfuzz<4.0.0,>=2.3.0 (from Levenshtein==0.21.1->python-Levenshtein)
  Downloading rapidfuzz-3.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m30.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, Levenshtein, python-Levenshtein
Successfully installed Levenshtein-0.21.1 python-Levenshtein-0.21.1 rapidfuzz-3.3.0


In [19]:
import torch
checkpoint_path = os.path.join(DRIVE_ROOT, 'bert_punct_spelling.pt')
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.load(checkpoint_path, map_location = device)

In [None]:
from transformers import AutoModelForMaskedLM

mlm = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)

Downloading pytorch_model.bin:   0%|          | 0.00/714M [00:00<?, ?B/s]

In [None]:
import torch
from torch.nn.functional import sigmoid
from Levenshtein import distance as lev_distance
from nltk.tokenize.treebank import TreebankWordDetokenizer

class MaskCorrector:
    def __init__(self, model, tokenizer, mlm, tag2punct, device = 'cpu', k = 10, spelling_threshold = 0.5,  punct_threshold = 0.5):
        self.EXTENDED_PUNCT = ',.?!:;'
        self.TOKEN_SEP = '##'
        self.MASK = '[MASK]'
        self.device = device
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.mlm = mlm.to(self.device)
        self.tag2punct = tag2punct
        self.spelling_threshold =  spelling_threshold
        self.punct_threshold = punct_threshold
        self.k = k
        self.mask_id = tokenizer(self.MASK, add_special_tokens = False).input_ids[0]
        self.detokenizer =  TreebankWordDetokenizer()

    def get_predictions(self, tokenized_text):
        self.model.eval()
        punct_output, spelling_output = model(**{k:v.to(self.device) for k, v in tokenized_text.items()})
        return punct_output, spelling_output

    def mask_mistakes(self, tokenized_text):
        _, spelling_output  = self.get_predictions(tokenized_text)
        spelling_pred = (sigmoid(spelling_output.cpu()) >= self.spelling_threshold).int().squeeze(-1)
        masked_text = tokenized_text['input_ids'].clone()
        masked_text[spelling_pred == 1] = self.mask_id
        return masked_text, tokenized_text['input_ids'][spelling_pred == 1]

    def select_correction(self, mistakes, candidates, probabilities):
        corrections = []
        for mistake, cands, probs in zip(mistakes, candidates, probabilities):
            lev_threshold = len(mistake)*2
            levs = []
            for cand, prob in zip(cands, probs):
                lev = lev_distance(cand, mistake)
                if lev < lev_threshold:
                    levs.append((lev, prob.item(), cand))
            levs = sorted(levs, key = lambda x: (x[0], -x[1]))
            if levs:
                corrections.append(levs[0][2])
            else:
                corrections.append(mistake)
        return corrections

    def get_corrections(self, masked_text, masked_tokens):
        self.mlm.eval()
        logits = mlm(input_ids = masked_text.to(self.device)).logits
        mask_token_index = torch.where(masked_text == self.mask_id)[1]
        mask_token_logits = logits[0, mask_token_index]
        candidates = torch.topk(mask_token_logits, self.k, dim=1)
        mistakes = self.tokenizer.batch_decode(masked_tokens)
        correction_candidates = [self.tokenizer.batch_decode(c) for c in candidates.indices]
        return mistakes, correction_candidates, candidates.values

    def join_tokens(self, tokens):
        text = []
        full_token = ''
        for token in tokens:
            if token.startswith(self.TOKEN_SEP):
                full_token += token.replace(self.TOKEN_SEP, '')
            else:
                text.append(full_token)
                full_token = token
        text.append(full_token)
        text.pop(0)
        return text

    def correct_spelling(self, tokenized_text):
        masked_text, masked_tokens = self.mask_mistakes(tokenized_text)
        mistakes, candidates, probabilities = self.get_corrections(masked_text, masked_tokens)
        corrections = self.select_correction(mistakes, candidates, probabilities)
        tokens = self.tokenizer.convert_ids_to_tokens(masked_text[0])
        tokens_corrected_spelling = [corrections.pop(0) if token == self.MASK else token for token in tokens]
        return tokens_corrected_spelling

    def change_punctuation(self, tokenized_text, decoded_tokens):
        punct_output, _  = self.get_predictions(tokenized_text)
        punct_pred = sigmoid(punct_output.cpu())
        punct_pred = punct_pred.where(punct_pred >= self.punct_threshold, 0)
        text = []
        next_punct = ''
        for token, pred in zip(decoded_tokens, punct_pred[0]):
            if token in self.EXTENDED_PUNCT:
                text.append(token)
            else:
                text.append(next_punct)
                text.append(token)
            if pred.any() > 0:
                tag = int(torch.argmax(pred))
                next_punct = self.tag2punct[tag]
                if next_punct == '.':
                  next_punct = ''
            else:
                next_punct = ''
        return [token for token in text if token]

    def detokenize(self, tokens):
      text = self.detokenizer.detokenize(tokens)
      text = text.replace(' .','.')
      return text

    def correct(self, text):
        text = str(text)
        tokenized_text = self.tokenizer(text, add_special_tokens = False, return_tensors='pt')
        tokens_corrected_spelling = self.correct_spelling(tokenized_text)
        tokens_corrected_punct = self.change_punctuation(tokenized_text, tokens_corrected_spelling)
        corrected_text = self.join_tokens(tokens_corrected_punct)
        return self.detokenize(corrected_text)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
mask_corrector = MaskCorrector(model = model, tokenizer = tokenizer, mlm = mlm, tag2punct = tag2punct, device = device,
                               k = 10000,
                               spelling_threshold = 0.8, punct_threshold = 0.3)

In [None]:
text = 'Добрй день увжаемые кллеги. Сегодня с нами на совещании:\n- Максим\n- Эдуард\n- Сергей'
print(f'Initial text: {text}')
print(f'Corrected text: {mask_corrector.correct(text)}')

Initial text: Добрй день увжаемые кллеги. Сегодня с нами на совещании:
- Максим
- Эдуард
- Сергей
Corrected text: Добрй день, увжаемые кллеги. Сегодня с нами на совещании: - Максим - Эдуард - Сергей


# Norvig Correction

In [103]:
import os
import json
dict_path = os.path.join(DRIVE_ROOT, 'dict.json')
with open(dict_path) as jsfile:
    dictionary = json.load(jsfile)

In [97]:
class NorvigAlgorithm:
    def __init__(self, dictionary):
        self.dictionary = dictionary
        self.N = sum(self.dictionary.values())
        self.letters    = 'абвгдеёжзийклмнопрстуфхцчшщъыьэюя'

    def P(self, word):
        return self.dictionary[word] / self.N

    def correction(self, word):
        return max(self.candidates(word), key=P)

    def candidates(self, word):
        return (self.known([word]) or self.known(self.edits1(word)) or self.known(self.edits2(word)) or [word])

    def known(self, words):
        return set(w for w in words if w in self.dictionary)

    def edits1(self, word):
        splits     = [(word[:i], word[i:])    for i in range(len(word) + 1)]
        deletes    = [L + R[1:]               for L, R in splits if R]
        transposes = [L + R[1] + R[0] + R[2:] for L, R in splits if len(R)>1]
        replaces   = [L + c + R[1:]           for L, R in splits if R for c in self.letters]
        inserts    = [L + c + R               for L, R in splits for c in self.letters]
        return set(deletes + transposes + replaces + inserts)

    def edits2(self, word):
        return (e2 for e1 in self.edits1(word) for e2 in self.edits1(e1))

In [110]:
import torch
from string import punctuation
from torch.nn.functional import sigmoid
from nltk.tokenize.treebank import TreebankWordDetokenizer

class NorvigCorrector:
    def __init__(self, model, tokenizer, corrector, tag2punct, device = 'cpu', spelling_threshold = 0.5,  punct_threshold = 0.5):
        self.EXTENDED_PUNCT = ',.?!:;'
        self.TOKEN_SEP = '##'
        self.device = device
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        self.tag2punct = tag2punct
        self.spelling_threshold = spelling_threshold
        self.punct_threshold = punct_threshold
        self.detokenizer =  TreebankWordDetokenizer()
        self.corrector =  corrector

    def get_predictions(self, tokenized_text):
        self.model.eval()
        punct_output, spelling_output = model(**{k:v.to(self.device) for k, v in tokenized_text.items()})
        punct_preds = sigmoid(punct_output.cpu())
        punct_preds = punct_preds.where(punct_preds >= self.punct_threshold, 0)
        spelling_preds = (sigmoid(spelling_output.cpu()) >= self.spelling_threshold).int().squeeze(-1)
        return punct_preds, spelling_preds

    def merge_tokens(self, tokenized_text, punct_preds, spelling_preds):
        bpe_tokens = self.tokenizer.convert_ids_to_tokens(tokenized_text['input_ids'][0], skip_special_tokens=True)
        tokens, misspelled, next_puncts = [], [], []
        token = ''
        is_misspelled = 0
        next_punct = ''
        for bpe_token, punct_pred, spelling_pred in zip(bpe_tokens, punct_preds[0], spelling_preds[0]):
            if spelling_pred.item() == 1:
                is_misspelled = 1
            if punct_pred.any() > 0:
                tag = int(torch.argmax(punct_pred))
                next_punct = self.tag2punct[tag]
            if bpe_token.startswith(self.TOKEN_SEP):
                token += bpe_token.replace(self.TOKEN_SEP, '')
            else:
                tokens.append(token)
                misspelled.append(is_misspelled)
                next_puncts.append(next_punct)
                token = bpe_token
                is_misspelled = 0
                next_punct = ''
        tokens.append(token)
        misspelled.append(is_misspelled)
        next_puncts.append(next_punct)
        tokens.pop(0)
        misspelled.pop(0)
        next_puncts.pop(0)
        assert len(tokens) == len(misspelled) == len(next_puncts)
        return tokens, misspelled, next_puncts

    def correct_spelling(self, tokens, misspelled):
        corrected_tokens = []
        for token, is_misspelled in zip(tokens, misspelled):
            if is_misspelled:
                corrected_tokens.append(self.replace(token))
            else:
                corrected_tokens.append(token)
        return corrected_tokens

    def replace(self, token):
        capital = token[0].isupper()
        replacement = self.corrector.correction(token)
        if capital:
            replacement = replacement.capitalize()
        return replacement

    def change_punctuation(self, tokens, next_puncts):
        text = []
        model_punct = False
        for token, next_punct in zip(tokens, next_puncts):
            if next_punct:
                if token in punctuation:
                    text.append(next_punct)
                else:
                    text.append(token)
                    text.append(next_punct)
                model_punct = True
            elif token not in list(self.tag2punct.values()) + ['!', '?'] or not model_punct:
                text.append(token)
                model_punct = False
        return [token for token in text if token]

    def detokenize(self, tokens):
      text = self.detokenizer.detokenize(tokens)
      text = text.replace(' .','.')
      return text

    def correct(self, text):
        text = str(text)
        tokenized_text = self.tokenizer(text, add_special_tokens = False, return_tensors='pt')
        punct_preds, spelling_preds = self.get_predictions(tokenized_text)
        tokens, misspelled, next_puncts = self.merge_tokens(tokenized_text, punct_preds, spelling_preds)
        tokens_corrected_spelling = self.correct_spelling(tokens, misspelled)
        corrected_text = self.change_punctuation(tokens_corrected_spelling, next_puncts)
        return self.detokenize(corrected_text)

In [117]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
algorithm = NorvigAlgorithm(dictionary)
norvig_corrector = NorvigCorrector(model = model, tokenizer = tokenizer, tag2punct = tag2punct, device = device, corrector = algorithm,
                               spelling_threshold = 0.5, punct_threshold = 0.5)

In [118]:
#text = data.source.iloc[21]
import torch
text = 'Добрй день увжаемые кллеги. Сегодня с нами на совещании:\n- Максим\n- Эдуард\n- Сергей'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Initial text: {text}')
print(f'Corrected text: {norvig_corrector.correct(text)}')

Initial text: Добрй день увжаемые кллеги. Сегодня с нами на совещании:
- Максим
- Эдуард
- Сергей
Corrected text: Добрый день уважаемые коллеги. Сегодня с нами на совещании: - Максим - Эдуард - Сергей


# Hyperparam Tuning

In [None]:
!pip install optuna

Collecting optuna
  Downloading optuna-3.3.0-py3-none-any.whl (404 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m404.2/404.2 kB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting alembic>=1.5.0 (from optuna)
  Downloading alembic-1.12.0-py3-none-any.whl (226 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.0/226.0 kB[0m [31m23.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting cmaes>=0.10.0 (from optuna)
  Downloading cmaes-0.10.0-py3-none-any.whl (29 kB)
Collecting colorlog (from optuna)
  Downloading colorlog-6.7.0-py2.py3-none-any.whl (11 kB)
Collecting Mako (from alembic>=1.5.0->optuna)
  Downloading Mako-1.2.4-py3-none-any.whl (78 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.7/78.7 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: Mako, colorlog, cmaes, alembic, optuna
Successfully installed Mako-1.2.4 alembic-1.12.0 cmaes-0.10.0 colorlog-6.7.0 optuna-3.3.0


In [None]:
def objective(trial, train_loader, val_loader, checkpoint_path):
  model = torch.load(checkpoint_path)
  model.to(device)
  learning_rate = trial.suggest_float('learning_rate', 1e-5, 4e-3, log=True)
  optim = trial.suggest_categorical('optimizer', ['adam', 'adamW'])
  alpha = trial.suggest_float('alpha', 0.6, 0.9)
  #use_weight_decay = trial.suggest_int('use_weight_decay', 0, 1)
  #weight_decay = trial.suggest_float('weight_decay', 1e-3, 1e-1, log=True)
  #trainable_layers = trial.suggest_int('trainable_layers', 30, 36)

  optimizers = {'adam':torch.optim.Adam, 'adamW':torch.optim.AdamW}
  optimizer = optimizers[optim](params=[p for p in model.parameters() if p.requires_grad], lr=learning_rate)

  pass_epoch(train_loader, model, optimizer = optimizer, split = 'train', device = device)
  spelling_loss, punct_loss, spelling_f1, punct_f1 = pass_epoch(val_loader, model, optimizer = None, split = 'val', device = device)
  return punct_f1

In [None]:
checkpoint_path = os.path.join(DRIVE_ROOT, 'bert_punct_spelling.pt')
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
import optuna
from optuna.samplers import TPESampler

optuna.logging.set_verbosity(optuna.logging.INFO)

sampler = TPESampler(seed=1)
study = optuna.create_study(study_name="correction-model", direction="maximize", sampler=sampler)
study.optimize(lambda trial: objective(trial, train_loader, val_loader, checkpoint_path), n_trials=10)

[I 2023-09-17 23:30:33,601] A new study created in memory with name: correction-model


  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

[I 2023-09-17 23:34:01,136] Trial 0 finished with value: 0.3083370309699272 and parameters: {'learning_rate': 0.0001216511654495502, 'optimizer': 'adam', 'a': 0.6906997717895519}. Best is trial 0 with value: 0.3083370309699272.


  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

[I 2023-09-17 23:37:25,897] Trial 1 finished with value: 0.32209902661244555 and parameters: {'learning_rate': 2.4091710288757674e-05, 'optimizer': 'adamW', 'a': 0.7036682181129144}. Best is trial 1 with value: 0.32209902661244555.


  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

[I 2023-09-17 23:40:51,975] Trial 2 finished with value: 0.31820744748278085 and parameters: {'learning_rate': 0.00010774888148633154, 'optimizer': 'adam', 'a': 0.8055658501190278}. Best is trial 1 with value: 0.32209902661244555.


  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

[I 2023-09-17 23:44:19,685] Trial 3 finished with value: 0.3141405171223704 and parameters: {'learning_rate': 3.4040585327316366e-05, 'optimizer': 'adam', 'a': 0.8011402530535207}. Best is trial 1 with value: 0.32209902661244555.


  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

[I 2023-09-17 23:47:47,816] Trial 4 finished with value: 0.3199483239359745 and parameters: {'learning_rate': 0.00012185746252275139, 'optimizer': 'adam', 'a': 0.6594304467254636}. Best is trial 1 with value: 0.32209902661244555.


  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

[I 2023-09-17 23:51:14,543] Trial 5 finished with value: 0.30811899622425165 and parameters: {'learning_rate': 0.0012122310545199342, 'optimizer': 'adam', 'a': 0.8076967847007942}. Best is trial 1 with value: 0.32209902661244555.


  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

[I 2023-09-17 23:54:39,736] Trial 6 finished with value: 0.3128279942869596 and parameters: {'learning_rate': 0.0019072918368943854, 'optimizer': 'adam', 'a': 0.6117164349698647}. Best is trial 1 with value: 0.32209902661244555.


  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

[I 2023-09-17 23:58:04,551] Trial 7 finished with value: 0.3065927408294697 and parameters: {'learning_rate': 2.7663615525486703e-05, 'optimizer': 'adam', 'a': 0.7263322875015157}. Best is trial 1 with value: 0.32209902661244555.


  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

[I 2023-09-18 00:01:30,065] Trial 8 finished with value: 0.30153733105614755 and parameters: {'learning_rate': 0.0031080358710310007, 'optimizer': 'adamW', 'a': 0.6946546893018188}. Best is trial 1 with value: 0.32209902661244555.


  0%|          | 0/281 [00:00<?, ?it/s]

  0%|          | 0/71 [00:00<?, ?it/s]

[I 2023-09-18 00:04:54,563] Trial 9 finished with value: 0.31119475525807616 and parameters: {'learning_rate': 0.0006113875601531118, 'optimizer': 'adam', 'a': 0.8250432944834902}. Best is trial 1 with value: 0.32209902661244555.


# Coloring

In [None]:
!pip install colorama

Collecting colorama
  Downloading colorama-0.4.6-py2.py3-none-any.whl (25 kB)
Installing collected packages: colorama
Successfully installed colorama-0.4.6


In [None]:
from colorama import Back, Style

def checker(texts, model, tokenizer, tag2punct, device = 'cpu', spelling_threshold = 0.5, punct_threshold = 0.5):
    TOKEN_SEP = '##'
    model.eval()
    model.to(device)
    max_length = max(len(text.split(' ')) for text in texts)*2
    tokenized_texts = tokenizer(texts, add_special_tokens = False, return_tensors='pt', padding = 'max_length',max_length = max_length)
    punct_output, spelling_output = model(**{k:v.to(device) for k, v in tokenized_texts.items()})
    punct_pred = (sigmoid(punct_output.cpu()) >= punct_threshold).int()
    spelling_pred = (sigmoid(spelling_output.cpu()) >= spelling_threshold).int()
    checked_texts = []
    for i_text, text in enumerate(texts):
        checked_text = ''
        tokens = tokenizer.convert_ids_to_tokens(tokenized_texts['input_ids'][i_text], skip_special_tokens=True)
        for i_token, token in enumerate(tokens):
            space = True
            if token.startswith(TOKEN_SEP):
                space = False
                token = token.replace(TOKEN_SEP, '')
            checked_text += ' ' * space

            if token in PUNCT and not next_ch_is_punct:
                checked_text += f'{Back.CYAN}{token}{Back.RESET}'
            elif spelling_pred[i_text][i_token]:
                checked_text += f'{Back.MAGENTA}{token.upper()}{Back.RESET}'
            else:
                checked_text += token

            next_ch_is_punct = True
            if sum(punct_pred[i_text][i_token]) > 0:
                punct_to_add = tag2punct[int(torch.argmax(punct_pred[i_text][i_token]))]
                if i_token < len(tokens) - 1:
                    if  tokens[i_token + 1] != punct_to_add:
                        checked_text += f'{Back.YELLOW}{punct_to_add}{Back.RESET}'
                else:
                    checked_text += f'{Back.YELLOW}{punct_to_add}{Back.RESET}'
            elif sum(punct_pred[i_text][i_token]) == 0 and i_token < len(tokens) - 1:
                if tokens[i_token + 1] in PUNCT:
                    next_ch_is_punct = False
        checked_texts.append(checked_text)
    for text in checked_texts:
      print(text)
    return checked_texts

In [None]:
checker(data.source.iloc[0:20].tolist(), model, tokenizer, tag2punct, device = device);

 Эта книга как автобиография его [46m.[49m
 Краткое содержание этой книги , герой не ходит в школу [46m.[49m Потом он каждый день сидет дома [46m.[49m
 Ни с кем не [45mГОВ[49m[45mАРИТ[49mь [46m,[49m не встречаться [46m,[49m не выйти из дома [46m.[49m
 Но [46m,[49m конец он нашёл свою мечту [46m.[49m Когда мне было 13 лет я читала этот рассказ на журнале .
 Я очень испытала симпатию его [46m.[49m Чувство полового созревания [46m,[49m думать что я специальный человек[43m,[49m чем всех [46m,[49m и [46m.[49m т [46m.[49m д [46m.[49m [46m.[49m [46m.[49m
 Теперь мне 26 лет [46m.[49m Я прочитала её [46m.[49m Я помню его [46m.[49m
 Но я чувствую его далеко чем в детстве [46m.[49m
 Как готовить " кэ[45mР[49mри рис "
 Я хочу писать как готовить " кэ[45mР[49mри рис - рис с мясом[43m.[49m и овощами[43m.[49m " [46m.[49m
 Этот рис с мясом и овощами готовят из репчатого лука [46m,[49m мяса [46m,[49m картошки и морковы [46m.[49m
 1 [46m.[4