In [1]:
# !python3 -m spacy download ru_core_news_lg -q

In [2]:
import pandas as pd
import random
import spacy
from spacy.training import Example
import ru_core_news_lg
from tqdm import trange, tqdm
import json
import warnings

In [3]:
warnings.filterwarnings('ignore')

In [4]:
train_data_df = pd.read_json("../data/data_train.jsonl", lines=True)
train_data = []


for ind, row in train_data_df.iterrows():
    entities = sorted(row['ners'], key=lambda x: x[1] - x[0], reverse=True)
    spans = set()
    no_overlap = []

    for start, end, label in entities:
        if not any(i in spans for i in range(start, end + 1)):
            no_overlap.append((start, end + 1, label))
            spans.update(range(start, end + 1))

    train_data.append((row['sentences'], {"entities": no_overlap}))

In [5]:
test_data_df = pd.read_json("../data/data_test.jsonl", lines=True)
test_data_df = test_data_df.rename(columns={"senences": "sentences"})

In [6]:
NERS = ['AGE',
             'AWARD',
             'CITY',
             'COUNTRY',
             'CRIME',
             'DATE',
             'DISEASE',
             'DISTRICT',
             'EVENT',
             'FACILITY',
             'FAMILY',
             'IDEOLOGY',
             'LANGUAGE',
             'LAW',
             'LOCATION',
             'MONEY',
             'NATIONALITY',
             'NUMBER',
             'ORDINAL',
             'ORGANIZATION',
             'PENALTY',
             'PERCENT',
             'PERSON',
             'PRODUCT',
             'PROFESSION',
             'RELIGION',
             'STATE_OR_PROVINCE',
             'TIME',
             'WORK_OF_ART']

In [7]:
model = ru_core_news_lg.load()
optimizer = model.initialize()

In [8]:
num_iterations = 12

for iteration in range(1, num_iterations + 1):
    print(f'Running epoch {iteration} of {num_iterations}...')
    total_loss = 0
    losses_dict = {}
    for idx, (text, annotations) in enumerate(tqdm(train_data)):
        doc = model.make_doc(text)
        batch = Example.from_dict(doc, annotations)
        model.update([batch], sgd=optimizer, losses=losses_dict)
        total_loss += losses_dict.get("ner", 0)
    print(
        f'Epoch {iteration}/{num_iterations}, total avg loss: {total_loss / len(train_data)}')

Running epoch 1 of 12...


100%|██████████| 519/519 [00:53<00:00,  9.63it/s]


Epoch 1/12, total avg loss: 19900.315172498602
Running epoch 2 of 12...


100%|██████████| 519/519 [00:53<00:00,  9.70it/s]


Epoch 2/12, total avg loss: 11141.89735617086
Running epoch 3 of 12...


100%|██████████| 519/519 [00:53<00:00,  9.67it/s]


Epoch 3/12, total avg loss: 8334.962051563403
Running epoch 4 of 12...


100%|██████████| 519/519 [00:53<00:00,  9.62it/s]


Epoch 4/12, total avg loss: 6705.955228019603
Running epoch 5 of 12...


100%|██████████| 519/519 [00:53<00:00,  9.69it/s]


Epoch 5/12, total avg loss: 5112.413778605281
Running epoch 6 of 12...


100%|██████████| 519/519 [00:53<00:00,  9.69it/s]


Epoch 6/12, total avg loss: 4290.200580888229
Running epoch 7 of 12...


100%|██████████| 519/519 [00:53<00:00,  9.72it/s]


Epoch 7/12, total avg loss: 3491.961507691918
Running epoch 8 of 12...


100%|██████████| 519/519 [00:54<00:00,  9.58it/s]


Epoch 8/12, total avg loss: 3116.9145020454803
Running epoch 9 of 12...


100%|██████████| 519/519 [00:54<00:00,  9.59it/s]


Epoch 9/12, total avg loss: 2656.9037263403343
Running epoch 10 of 12...


100%|██████████| 519/519 [00:54<00:00,  9.56it/s]


Epoch 10/12, total avg loss: 2374.369271675872
Running epoch 11 of 12...


100%|██████████| 519/519 [00:53<00:00,  9.66it/s]


Epoch 11/12, total avg loss: 2165.2046653613734
Running epoch 12 of 12...


100%|██████████| 519/519 [00:54<00:00,  9.54it/s]

Epoch 12/12, total avg loss: 2043.7623535432963





In [12]:
entities_combined = []
for index, record in test_data_df.iterrows():
    prediction = model(record['sentences'])
    entities = []
    for entity in prediction.ents:
        if entity.label_ in NERS:
            entities.append([entity.start_char, entity.end_char - 1, entity.label_])
    entities_combined.append({'id': record['id'], 'ners': entities})

In [13]:
with open("test.jsonl", "w") as f:
    for data in entities_combined:
        json.dump(data, f)
        f.write('\n')