In [1]:
import numpy as np
import random as rn
import spacy

import os

os.environ['PYTHONHASHSEED'] = '0'

np.random.seed(12)
rn.seed(34)

In [2]:
!python -m spacy init fill-config ./base_config.cfg ./config.cfg

[38;5;2m✔ Auto-filled config with all values[0m
[38;5;2m✔ Saved config[0m
config.cfg
You can now add your data and train your pipeline:
python -m spacy train config.cfg --paths.train ./train.spacy --paths.dev ./dev.spacy


In [2]:
import pandas as pd
import os
import re

In [3]:
# collect data

def create_dataset(txt_file, ann_file):
    with open(txt_file, "r") as txt_f, open(ann_file, "r") as ann_f:
        fulltext = re.sub('\xa0', ' ', txt_f.read())
        annotations = ann_f.readlines()
        
        sentences = [x for x in filter(None, fulltext.split('\n'))]
        labels = [list() for _ in sentences]
        for line in annotations:
            if line.startswith("T"):
                _, entity_type, start, end, *text = line.split()
                try:
                    start = int(start)
                    end = int(end)
                except Exception as e:
                    break
                for i, curr_sent in enumerate(sentences):
                    esc_text = re.escape(curr_sent)
                    for curr_sent_span in re.finditer(esc_text+'\n', fulltext):
                        curr_sent_span = curr_sent_span.span()
                        if end <= curr_sent_span[1]:
                            try:
                                labels[i].append((re.search(re.escape(" ".join(text)), curr_sent).span()[0],
                                                 re.search(re.escape(" ".join(text)), curr_sent).span()[1],
                                                 entity_type))
#                                 print(text,
#                                       re.search(re.escape(" ".join(text)), curr_sent).span()[0],
#                                       re.search(re.escape(" ".join(text)), curr_sent).span()[1],
#                                       entity_type)
                            except Exception as e:
                                continue
                            break
    return sentences, labels

X_data = []
y_data = []
numberoffiles = 0

for subdir, dirs, files in os.walk('a3data/train_data/'):
    for file in files:
        if file.endswith(".txt"):
            txt_file = os.path.join(subdir, file)
#             print(txt_file)
            filetext = open(txt_file, 'r').read()
            ann_file = os.path.join(txt_file[:-4] + ".ann")
            if os.path.isfile(ann_file):
                numberoffiles += 1
                X, y = create_dataset(txt_file, ann_file)
                X_data.extend(X)
                y_data.extend(y)


In [4]:
assert len(y_data) == len(X_data)
len(y_data)

7976

In [5]:
dataset = list(zip(X_data, y_data))

In [6]:
import spacy
from spacy import displacy

In [7]:
from tqdm import tqdm

In [8]:
nlp = spacy.load("ru_core_news_lg")

In [9]:
X_train = []
y_train = []

for i, item in enumerate(y_data):
    for ann in item:
        y_train.append([ann])
        X_train.append(X_data[i])

In [10]:
list(zip(X_train, y_train))[:3]

[('Полиция Бельгии арестовала двух человек по подозрению в подготовке теракта',
  [(0, 15, 'ORGANIZATION')]),
 ('Полиция Бельгии арестовала двух человек по подозрению в подготовке теракта',
  [(8, 15, 'COUNTRY')]),
 ('Полиция Бельгии арестовала двух человек по подозрению в подготовке теракта',
  [(27, 31, 'NUMBER')])]

In [34]:
from spacy.tokens import DocBin

db = DocBin()
errors = []
for text, annotations in tqdm(zip(X_data, y_data)):
    doc = nlp(text)
    entities = []
    for start, end, label in annotations:
        span = doc.char_span(start, end, label=label)
        entities.append(span)
        try:
            doc.ents += (span,)
        except Exception as e:
            errors.append(e)
    db.add(doc)

7976it [00:47, 166.79it/s]


In [38]:
errors[1]

ValueError('[E1010] Unable to set entity information for token 1 which is included in more than one span in entities, blocked, missing or outside.')

In [37]:
len(db)

7976

In [39]:
db.to_disk('./dataset.spacy')

In [None]:
!python -m spacy train config_trans.cfg --output "./spacy_models" --paths.train "./dataset.spacy" --paths.dev "./dataset.spacy"

[38;5;4mℹ Saving to output directory: spacy_models[0m
[38;5;4mℹ Using CPU[0m
[38;5;4mℹ To switch to GPU 0, use the option: --gpu-id 0[0m
[1m
[2023-04-17 23:47:59,888] [INFO] Set up nlp object from config
[2023-04-17 23:47:59,893] [INFO] Pipeline: ['transformer', 'ner']
[2023-04-17 23:47:59,895] [INFO] Created vocabulary
[2023-04-17 23:47:59,896] [INFO] Finished initializing nlp object
Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a

In [None]:
print("FINISH")

In [42]:
best_model = spacy.load("spacy_models/model-best")
last_model = spacy.load("spacy_models/model-last")

In [66]:
res = best_model('Пулеметы, автоматы и снайперские винтовки изъяты в арендуемом американцами доме в Бишкеке')

In [71]:
res.ents[0].start_char

62

In [67]:
displacy.serve(res, style="ent", auto_select_port=True)


Using the 'ent' visualizer
Serving on http://0.0.0.0:5001 ...

Shutting down server on port 5001.


In [167]:
truth_result = ""
for subdir, dirs, files in os.walk('a3data/train_data/coll3_tacred/'):
    for file in sorted(files):
        if file.endswith(".ann"):
            truth_result += f'{file}\n'
            txt_file = os.path.join(subdir, file)
            testtext = open(txt_file, 'r').read()
            for sent in filter(None, testtext.split('\n')):
                if sent[0] == 'T':
                    sentence = sent.split()
                    truth_result += " ".join([sentence[1], sentence[2], sentence[3]])+"\n"
            truth_result += testtext


In [168]:
truth_result[:100]

'003.ann\nNATIONALITY 62 74\nCITY 82 89\nDATE 117 126\nORGANIZATION 145 179\nCOUNTRY 171 179\nCOUNTRY 221 2'

In [179]:
result = ""

for subdir, dirs, files in os.walk('a3data/train_data/coll3_tacred/'):
    for file in sorted(files):
        if file.endswith(".txt"):
            result += f'{file[:-4]}.ann\n'
            txt_file = os.path.join(subdir, file)
            testtext = open(txt_file, 'r').read()
            sentences = [x.strip() for x in filter(None, testtext.split('\n'))]
            for sentence in sentences:
                sent_result = best_model(sentence)
                for word in sent_result.ents:
                    result += f'{word.label_} {word.start_char} {word.end_char}\n'

In [170]:
result[:100]

'003.ann\nNATIONALITY 62 74\nLOC 82 89\nTIME 0 16\nCITY 0 6\nDATE 8 17\nLOC 62 70\nLOC 112 115\nLOC 118 125\nD'

In [85]:
from collections import defaultdict

def f1_score(submission_answer, truth):
    submission_entities = defaultdict(list)
    current = None
    for line in submission_answer:
        line = line.strip().split()
        if len(line) == 1:
            current = line[0]
            continue
        if len(line) == 3:
            tag, l, r = line
            submission_entities[current].append((tag, l, r))

    true_entities = defaultdict(list)
    current = None
    for line in truth:
        line = line.strip().split()
        if len(line) == 1:
            current = line[0]
            continue
        if len(line) == 3:
            tag, l, r = line
            true_entities[current].append((tag, l, r))

    precision, recall, precision_total, recall_total = 0, 0, 0, 0
    for current in true_entities:
        if current not in submission_entities:
            precision_total += len(submission_entities[current])
            recall_total += len(true_entities[current])
            continue

        for entity in set(submission_entities[current]):
            precision_total += 1
            if entity in true_entities[current]:
                precision += 1

        for entity in set(true_entities[current]):
            recall_total += 1
            if entity[1:] in [e[1:] for e in submission_entities[current]]:
                recall += 1

    if precision_total > 0:
        precision /= precision_total
    if recall_total > 0:
        recall /= recall_total
    if precision + recall > 0:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0

    # print(precision, recall)
    return f1


In [86]:
found = 0
for annotation in result.split('\n'):
    if re.search(annotation, truth_result):
        found+= 1

found / len(truth_result.split('\n'))

0.040515589618533354

In [180]:
result = re.sub(r'\nPER ', '\nPERSON ', result)
result = re.sub(r'\nLOC ', '\nCITY ', result)
result = re.sub(r'\nORG ', '\nORGANIZATION ', result)

In [181]:
f1_score(result.split('\n'), truth_result.split('\n'))

0.042568161829375555

In [177]:
result

'003.ann\nNATIONALITY 62 74\nCOUNTRY 82 89\nTIME 0 16\nCITY 0 6\nDATE 8 17\nCOUNTRY 62 70\nCOUNTRY 112 115\nCOUNTRY 118 125\nDATE 179 189\nORGANIZATION 203 206\nCOUNTRY 207 215\nEVENT 83 90\nCOUNTRY 106 113\nAGE 137 147\nCOUNTRY 159 167\nCOUNTRY 192 195\nEVENT 197 216\nNUMBER 218 223\nNUMBER 304 306\nNUMBER 325 340\nNUMBER 342 345\nPRODUCT 363 370\nNUMBER 371 384\nNUMBER 386 392\nNUMBER 553 566\nPERSON 573 580\nNUMBER 582 586\nORGANIZATION 622 625\nCOUNTRY 96 99\nNUMBER 143 145\nPROFESSION 146 160\nCOUNTRY 181 184\nPROFESSION 211 241\nNUMBER 83 86\nNUMBER 93 97\nNUMBER 120 135\nNUMBER 137 142\nNUMBER 165 178\nNUMBER 180 183\nNUMBER 207 221\nNUMBER 232 235\nNUMBER 258 272\nNUMBER 286 289\nNUMBER 350 354\nNUMBER 361 363\nNUMBER 413 415\nPRODUCT 451 458\nORGANIZATION 13 16\nORGANIZATION 71 82\nCOUNTRY 83 90\nTIME 92 98\nORGANIZATION 160 171\nCOUNTRY 26 29\nCOUNTRY 59 67\nCOUNTRY 116 124\nDATE 134 144\nCOUNTRY 183 186\nCOUNTRY 43 51\nCOUNTRY 74 84\nPROFESSION 126 140\nORGANIZATION 187 200\

In [178]:
open('answer.txt', 'w').write(result)

86541