In [1]:
!pip install transformers[sentencepiece]
!pip install --user -U nltk



In [2]:
import re
import requests
import pandas as pd
from transformers import M2M100ForConditionalGeneration, M2M100Tokenizer, logging
import nltk
import time
import torch

Подсоединение к Google Drive

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
initTime = time.time()

nltk.download('punkt')
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_checkpoint = "ai-forever/RuM2M100-1.2B"
model = M2M100ForConditionalGeneration.from_pretrained(model_checkpoint)
tokenizer = M2M100Tokenizer.from_pretrained(model_checkpoint, src_lang="ru", tgt_lang="ru")

model = model.to(device)

print ("Time initialization: ", time.time() - initTime)
logging.set_verbosity_error()

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Time initialization:  46.91407513618469


In [5]:
with open('/content/drive/My Drive/errorsents.txt', encoding="utf-8") as f:
  sentmarked_list = f.readlines()

Собираем информацию о словах с ошибками

In [6]:
errPattern = re.compile("_([\w\s-]+)_")
sent_list = []
errtup_list = []
for s in sentmarked_list:
    errIter = errPattern.finditer(s)
    ss = re.sub("_", "", s)
    sent_list.append(ss)
    j = 0
    errtup = []
    for err in errIter:
        errtup.append((err.start()-j*2, err.end()-2*(1 + j)))
        j += 1
    errtup_list.append(errtup)

Загружаем предложения без ошибок.
В последствии здесь тоже могут выявлены ошибки, поэтому также
формируем errors_pos

In [7]:
with open('/content/drive/My Drive/correctsents.txt', encoding="utf-8") as f:
  sentcorrect_list = f.readlines()

In [8]:
for s in sentcorrect_list:
    errIter = errPattern.finditer(s)
    ss = re.sub("_", "", s)
    sent_list.append(ss)
    j = 0
    errtup = []
    for err in errIter:
        errtup.append((err.start()-j*2, err.end()-2*(1 + j)))
        j += 1
    errtup_list.append(errtup)

Собираем DataFrame

In [9]:
sentences = pd.DataFrame(list(zip(sent_list, errtup_list)), columns =['sentences', 'errors_pos'])

Начальные параметры

In [10]:
p = re.compile(r'[\w-]+')

TP = 0
TN = 0
FP = 0
FN = 0

Маскирование токенов в предложении и определение ошибок

In [11]:
def m2m100_sent(text, model, tokenizer):
    specsigns = [('№', 'No')]
    for ss in specsigns:
      text = re.sub(ss[0], ss[1], text)

    # masking whole text and return errors if available
    tr = 0.1 #threshold of the error (parameter !!!)

    p = re.compile(r'[\w-]+')
    listErr = []
    ind = 0

    # собираем данные о словах
    iter = p.finditer(text)
    words = []
    for w in iter:
      words.append([w.group(), w.start(), w.end()])

    inputs = tokenizer(text, text_target=text, return_tensors="pt").to(device)
    logits = model(**inputs).logits
    match_str = ""
    istored = -1

    for i, x in enumerate(inputs["input_ids"][0]):
      if tokenizer.decode(x) not in tokenizer.all_special_tokens:
        match_str += re.escape(tokenizer.decode(x))
        masked_token = x.item()
        mask_token_index = torch.where(inputs["input_ids"] == x)[1]
        mask_token_logits = logits[0, mask_token_index, :]
        probs = torch.nn.functional.softmax(mask_token_logits, dim=1)
        # print(f'Token: [{tokenizer.decode(x)}]; probability: {probs[0][masked_token].item()}')

        # Add to error tokens, if threshold > probability
        if tr > probs[0][masked_token].item():
          merr = re.match(r''+ match_str, text)
          if merr:
            #print (merr, words)
            #print (match_str)
            for w in words:
              if merr.end() > w[1] and merr.end() <= w[2]:
                if istored != i:
                  errorrDesc = {
                    "word": w[0],
                    "start": w[1],
                    "end": w[2],
                    "prob": probs[0][masked_token].item()
                  }
                  listErr.append(errorrDesc)
                  istored = i
          else:
            print (merr, words)
            print (match_str)

        match_str += "\s*"
    return listErr

In [12]:
# ind = 29
# response = m2m100_sent(sentences["sentences"][ind], model, tokenizer)

# print (response)

Перебор всех предложений с DataFrame и сбор сведений для расчета метрики

In [13]:
outerrors = []
startPos = 0
for ind in sentences.index:
   if (ind % 10) == 0:
     print ("Index: ", ind, "   Time:", time.time() - initTime)
   if ind >=0 and ind < 10000:
    response = m2m100_sent(sentences["sentences"][ind], model, tokenizer)
    if len(response) == 0:
        # No errors found
        ntoken = len(p.findall(sentences["sentences"][ind]))
        nerr = len(sentences["errors_pos"][ind])
        # True Negative = number of words without errors
        TN += ntoken - nerr
        # False Negative - nerr errors were not found
        FN += nerr
        article_err = sentences["sentences"][ind]
    else:
        # errors were found: iter by found errors
        nfound = 0
        startPos = 0
        article_err = ""
        for e in response:
            article_err += sentences["sentences"][ind][startPos:e['start']] + "<mark>" + sentences["sentences"][ind][e['start']:e['end']] + "</mark>"
            startPos = e['end']
            # found errors correct or not
            iscorrect = 0
            for trueerr in sentences["errors_pos"][ind]:
                if trueerr[0] == e["start"]:
                    iscorrect = 1
                    nfound += 1
                    break
            if iscorrect == 1:
                # True Positive - error word was found
                TP += 1
        article_err += sentences["sentences"][ind][startPos:]
        # False Positive - wrong errors were found
        if (len(response) - nfound) > 0:
            FP += len(response) - nfound
        # False Negative - errors were not found
        if (len(sentences["errors_pos"][ind]) - nfound) > 0:
            FN += len(sentences["errors_pos"][ind]) - nfound
    outerrors.append(article_err)

Index:  0    Time: 47.22422194480896
Index:  10    Time: 70.32453942298889
Index:  20    Time: 104.87640357017517
Index:  30    Time: 129.22294330596924
Index:  40    Time: 165.66042041778564
Index:  50    Time: 189.97816801071167
Index:  60    Time: 218.03775310516357
Index:  70    Time: 251.3896210193634
Index:  80    Time: 278.4599356651306
Index:  90    Time: 305.5596489906311
Index:  100    Time: 335.0792078971863
Index:  110    Time: 362.78939509391785
Index:  120    Time: 390.8618516921997
Index:  130    Time: 419.58144664764404
Index:  140    Time: 448.3350067138672
Index:  150    Time: 476.5265474319458
Index:  160    Time: 501.8937351703644
Index:  170    Time: 528.3508427143097
Index:  180    Time: 553.5342960357666
Index:  190    Time: 582.9996991157532
Index:  200    Time: 612.6839668750763
Index:  210    Time: 649.8606510162354
Index:  220    Time: 674.9029729366302
Index:  230    Time: 709.7499227523804
Index:  240    Time: 730.326123714447
Index:  250    Time: 758.69344

Сохранение файла с метками

In [14]:
with open('/content/drive/My Drive/errorsmarks.txt', "w", encoding="utf-8") as file:
  file.writelines(outerrors)

Расчет метрики

In [15]:
precision = TP / (TP + FP)
recall = TP / (TP + FN)
fmeasure = 2 * recall * precision / (recall + precision)
print ("Precision: ", precision)
print ("Recall: ", recall)
print ("F-measure: ", fmeasure)

Precision:  0.8441176470588235
Recall:  0.5180505415162455
F-measure:  0.6420581655480986
