In [1]:
!pip install transformers[sentencepiece]
!pip install --user -U nltk



In [2]:
import re
import requests
import pandas as pd
from transformers import AutoTokenizer, AutoModel, AutoModelForMaskedLM, pipeline, logging
import nltk
import time
import torch

Подсоединение к Google Drive

In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
initTime = time.time()

nltk.download('punkt')
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# model_checkpoint = "ai-forever/ruRoberta-large"
model_checkpoint = "/content/drive/My Drive/fine-train-ruroberta202403265e5"

# tokenizer = AutoTokenizer.from_pretrained("ai-forever/ruRoberta-large")
model = AutoModelForMaskedLM.from_pretrained(model_checkpoint)
# tokenizer = AutoTokenizer.from_pretrained("cointegrated/rubert-tiny2")
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model     = model.to(device)

print ("Time initialization: ", time.time() - initTime)
logging.set_verbosity_error()

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


Time initialization:  45.34330224990845


In [5]:
with open('/content/drive/My Drive/errorsents.txt', encoding="utf-8") as f:
  sentmarked_list = f.readlines()

Собираем информацию о словах с ошибками

In [6]:
errPattern = re.compile("_([\w\s-]+)_")
sent_list = []
errtup_list = []
for s in sentmarked_list:
    errIter = errPattern.finditer(s)
    ss = re.sub("_", "", s)
    sent_list.append(ss)
    j = 0
    errtup = []
    for err in errIter:
        errtup.append((err.start()-j*2, err.end()-2*(1 + j)))
        j += 1
    errtup_list.append(errtup)

Загружаем предложения без ошибок.
В последствии здесь тоже могут выявлены ошибки, поэтому также
формируем errors_pos

In [7]:
with open('/content/drive/My Drive/correctsents.txt', encoding="utf-8") as f:
  sentcorrect_list = f.readlines()

In [8]:
for s in sentcorrect_list:
    errIter = errPattern.finditer(s)
    ss = re.sub("_", "", s)
    sent_list.append(ss)
    j = 0
    errtup = []
    for err in errIter:
        errtup.append((err.start()-j*2, err.end()-2*(1 + j)))
        j += 1
    errtup_list.append(errtup)

Собираем DataFrame

In [9]:
sentences = pd.DataFrame(list(zip(sent_list, errtup_list)), columns =['sentences', 'errors_pos'])

Начальные параметры

In [10]:
p = re.compile(r'[\w-]+')

TP = 0
TN = 0
FP = 0
FN = 0

Маскирование токенов в предложении и определение ошибок

In [11]:
def mask_bert_sent_old(text, model, tokenizer):
    maskToken = "[MASK]"
    # masking whole text and return errors if available
    tr = 5e-4 #threshold of the error (parameter !!!)
    unmasker = pipeline("fill-mask", model=model, tokenizer=tokenizer)
    p = re.compile(r'[\w-]+')
    sentArticle = nltk.tokenize.sent_tokenize(text, language="russian")
    listErr = []
    ind = 0
    for s in sentArticle:
        ind = text.index(s)
        iter = p.finditer(s)
        for match in iter:
            if match.start() == 0:
                masktext = maskToken + s[match.end():]
            elif match.end() == len(s):
                masktext = s[:match.start()] + maskToken
            else:
                masktext = s[:match.start()] + maskToken + s[match.end():]
            res = unmasker(masktext, targets=[match.group()], batch_size=8)
            if res[0]['score'] < tr:
                errorrDesc = {
                    "word": match.group(),
                    "start": ind + match.start(),
                    "end": ind + match.end(),
                    "prob": res[0]['score']
                }
                listErr.append(errorrDesc)
    return listErr

In [12]:
def mask_bert_sent(text, model, tokenizer):
    # masking whole text and return errors if available
    tr = 1e-3 #threshold of the error (parameter !!!)
    sentArticle = nltk.tokenize.sent_tokenize(text, language="russian")

    p = re.compile(r'[\w-]+')
    listErr = []
    ind = 0

    for s in sentArticle:
      inputs = tokenizer(s, return_tensors="pt").to(device)
      # print (inputs.word_ids)
      list_errtoken = []
      for i, x in enumerate(inputs["input_ids"][0]):
        if tokenizer.decode(x) not in tokenizer.all_special_tokens:
          masked_token = x.item()
          inputs["input_ids"][0][i] = tokenizer.mask_token_id
          token_logits = model(**inputs).logits

          # Find the location of [MASK] and extract its logits
          mask_token_index = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
          mask_token_logits = token_logits[0, mask_token_index, :]

          probs = torch.nn.functional.softmax(mask_token_logits, dim=1)
          spanchar = inputs.token_to_chars(i)
          # Add to error tokens, if threshold > probability
          if tr > probs[0][masked_token].item():
            list_errtoken.append([spanchar.start, probs[0][masked_token].item()])
          # print (f'Probability of masked token ({tokenizer.decode(masked_token)}): {probs[0][masked_token].item()},  spanchar: {spanchar} ')
          inputs["input_ids"][0][i] = masked_token

      iter = p.finditer(s)
      istored = -1
      for idx, word in enumerate(iter):
        for terr in list_errtoken:
          # if start char of found token into the selected word
          # store this word once
          if terr[0] >= word.start() and terr[0] < word.end():
            if istored != idx:
              errorrDesc = {
                "word": word.group(),
                "start": ind + word.start(),
                "end": ind + word.end(),
                "prob": terr[1]
              }
              listErr.append(errorrDesc)
            istored = idx

    #print ("------------------")
    #print(*[(token_id, inputs.token_to_chars(idx)) for idx, token_id in enumerate(inputs.input_ids[0])], sep="\n")
    #print ("------------------")
    #print(*[(word, inputs.word_to_tokens(idx)) for idx, word in enumerate(text.split())], sep="\n")
    return listErr

In [13]:
def token_sent(text, model, tokenizer):
    specsigns = [('№', 'No')]
    for ss in specsigns:
      text = re.sub(ss[0], ss[1], text)

    # masking whole text and return errors if available
    tr = 1e-3 #threshold of the error (parameter !!!)

    p = re.compile(r'[\w-]+')
    listErr = []
    ind = 0

    # собираем данные о словах
    iter = p.finditer(text)
    words = []
    for w in iter:
      words.append([w.group(), w.start(), w.end()])

    inputs = tokenizer(text, text_target=text, return_tensors="pt").to(device)
    logits = model(**inputs).logits
    match_str = ""
    istored = -1

    for i, x in enumerate(inputs["input_ids"][0]):
      if tokenizer.decode(x) not in tokenizer.all_special_tokens:
        match_str += re.escape(tokenizer.decode(x))
        masked_token = x.item()
        mask_token_index = torch.where(inputs["input_ids"] == x)[1]
        mask_token_logits = logits[0, mask_token_index, :]
        probs = torch.nn.functional.softmax(mask_token_logits, dim=1)
        # print(f'Token: [{tokenizer.decode(x)}]; probability: {probs[0][masked_token].item()}')

        # Add to error tokens, if threshold > probability
        if tr > probs[0][masked_token].item():
          merr = re.match(r''+ match_str, text)
          if merr:
            #print (merr, words)
            #print (match_str)
            for w in words:
              if merr.end() > w[1] and merr.end() <= w[2]:
                if istored != i:
                  errorrDesc = {
                    "word": w[0],
                    "start": w[1],
                    "end": w[2],
                    "prob": probs[0][masked_token].item()
                  }
                  listErr.append(errorrDesc)
                  istored = i
          else:
            print (merr, words)
            print (match_str)

        match_str += "\s*"
    return listErr

Перебор всех предложений с DataFrame и сбор сведений для расчета метрики

In [14]:
outerrors = []
startPos = 0
for ind in sentences.index:
    if (ind % 10) == 0:
        print ("Index: ", ind, "   Time:", time.time() - initTime)
    response = mask_bert_sent(sentences["sentences"][ind], model, tokenizer)
    # response = token_sent(sentences["sentences"][ind], model, tokenizer)
    if len(response) == 0:
        # No errors found
        ntoken = len(p.findall(sentences["sentences"][ind]))
        nerr = len(sentences["errors_pos"][ind])
        # True Negative = number of words without errors
        TN += ntoken - nerr
        # False Negative - nerr errors were not found
        FN += nerr
        article_err = sentences["sentences"][ind]
    else:
        # errors were found: iter by found errors
        nfound = 0
        startPos = 0
        article_err = ""
        for e in response:
            article_err += sentences["sentences"][ind][startPos:e['start']] + "<mark>" + sentences["sentences"][ind][e['start']:e['end']] + "</mark>"
            startPos = e['end']
            # found errors correct or not
            iscorrect = 0
            for trueerr in sentences["errors_pos"][ind]:
                if trueerr[0] == e["start"]:
                    iscorrect = 1
                    nfound += 1
                    break
            if iscorrect == 1:
                # True Positive - error word was found
                TP += 1
        article_err += sentences["sentences"][ind][startPos:]
        # False Positive - wrong errors were found
        if (len(response) - nfound) > 0:
            FP += len(response) - nfound
        # False Negative - errors were not found
        if (len(sentences["errors_pos"][ind]) - nfound) > 0:
            FN += len(sentences["errors_pos"][ind]) - nfound
    outerrors.append(article_err)

Index:  0    Time: 46.437843322753906
Index:  10    Time: 54.87690258026123
Index:  20    Time: 64.78771471977234
Index:  30    Time: 71.4721109867096
Index:  40    Time: 80.16538834571838
Index:  50    Time: 86.21599960327148
Index:  60    Time: 92.65598106384277
Index:  70    Time: 101.84962201118469
Index:  80    Time: 108.9059944152832
Index:  90    Time: 115.78187084197998
Index:  100    Time: 123.92688202857971
Index:  110    Time: 130.51052713394165
Index:  120    Time: 137.8007357120514
Index:  130    Time: 145.16491317749023
Index:  140    Time: 151.80784034729004
Index:  150    Time: 159.21632599830627
Index:  160    Time: 164.25568652153015
Index:  170    Time: 170.22371435165405
Index:  180    Time: 177.31410360336304
Index:  190    Time: 185.40057635307312
Index:  200    Time: 192.4130117893219
Index:  210    Time: 200.80059218406677
Index:  220    Time: 206.5223171710968
Index:  230    Time: 215.2523171901703
Index:  240    Time: 219.96142482757568
Index:  250    Time: 22

Сохранение файла с метками

In [15]:
with open('/content/drive/My Drive/errorsmarks.txt', "w", encoding="utf-8") as file:
  file.writelines(outerrors)

Расчет метрики

In [16]:
precision = TP / (TP + FP)
recall = TP / (TP + FN)
fmeasure = 2 * recall * precision / (recall + precision)
print ("Precision: ", precision)
print ("Recall: ", recall)
print ("F-measure: ", fmeasure)

Precision:  0.36597938144329895
Recall:  0.6586270871985158
F-measure:  0.47051027170311466
