In [None]:
!git clone https://github.com/nerel-ds/NEREL

In [None]:
!unzip NEREL-v1.0.zip -d NEREL

In [1]:
## Read the data# Reading files
from collections import namedtuple
import re
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
import nltk

Ann = namedtuple('annotation', ['tag', 'start1', 'end1', 'start2', 'end2', 'text'])
Rel = namedtuple('relationship', ['tag', 'arg1', 'arg2'])

def read_files(folder):
    filenames = sorted(set(e[:e.rfind('.')] for e in os.listdir(folder)))[1:]
    
    texts, entities, relationships  = [], [], []
    for file in tqdm(filenames):
       # if (not file): continue
        path1 = os.path.join(folder, file+'.txt')
        path2 = os.path.join(folder, file+'.ann')
        if (not os.path.exists(path2)):
            print(f'{path2} not found')
            continue
            #with open(path2, 'w'):
            #    pass
            
        with open(path1, 'r', encoding="utf8") as text, open(path2, 'r', encoding="utf8") as ann:
            texts.append(text.read())

            file_entities = {}
            file_relationship = []
            regex_ent = r'T(?P<id>\d+)\s(?P<tag>\w+)\s(?P<start1>\d+) (?P<end1>\d+)(;(?P<start2>\d+) (?P<end2>\d+))?\s(?P<text>.*)'
            regex_rel = r'R(?P<id>\d+)\s(?P<tag>\w+)\sArg1:T(?P<arg1>\d+) Arg2:T(?P<arg2>\d+)'
            
            
            for row in sorted(ann.readlines(), reverse=True):
                #print(row)
                match_ent = re.match(regex_ent, row)
                match_rel = re.match(regex_rel, row)
                if (match_ent):
                    res = match_ent.groupdict()
                    res['start1'] = int(res['start1'])
                    res['end1'] = int(res['end1'])
                    if (res['start2'] is not None):
                        res['start2'] = int(res['start2'])
                        res['end2'] = int(res['end2'])
                    id = res.pop('id')
                    file_entities[id] = Ann(**res)
                elif (match_rel):
                    try:
                        res = match_rel.groupdict()
                        res['arg1'] = file_entities[res['arg1']]
                        res['arg2'] = file_entities[res['arg2']]
                        id = res.pop('id')
                        file_relationship.append(Rel(**res))
                    except KeyError as e:
                        print(f'not found T{e} row={row}')
                else:
                    print(f'incorrect format in: row={row} file={file}')
            entities.append(file_entities)
            relationships.append(file_relationship)
    entities = [sorted(e.values(), key = lambda x: (x.start1, x.end1)) for e in entities]
    return texts, entities, relationships, filenames

In [2]:
folder = 'NEREL/NEREL-v1.0/train'
texts, entities, relationships, filenames = read_files(folder)

incorrect format in: row=По словам очевидцев пешехо
 file=21013_text


100%|██████████| 745/745 [00:02<00:00, 314.00it/s]


In [3]:
def in_range(range1, range2):
    # range1 in range2
    if range1[0] is None: return True
    if range2[0] is None: return False
    return int(range2[0]) <= int(range1[0]) and int(range2[1]) >= int(range1[1])

def is_nested_anns(ent1: Ann, ent2: Ann):
    # ent 1 in ent 2
    res = True
    ent1_range1 = (ent1.start1, ent1.end1)
    ent1_range2 = (ent1.start2, ent1.end2)
    ent2_range1 = (ent2.start1, ent2.end1)
    ent2_range2 = (ent2.start2, ent2.end2)
    #print(f'{ent1_range1=} {ent1_range2=} {ent1_range2=} {ent2_range2=}')
    res = res and (in_range(ent1_range1, ent2_range1) or in_range(ent1_range1, ent2_range2))
    res = res and (in_range(ent1_range2, ent2_range1) or in_range(ent1_range2, ent2_range2))
    return res


def is_nested_anns2(ent1: Ann, ent2: Ann):
    return is_nested_anns(ent1,ent2) or is_nested_anns(ent2, ent1)

def is_nested(rel: Rel):
    return is_nested_anns(rel.arg1, rel.arg2) or is_nested_anns(rel.arg2, rel.arg1)

![](https://i.imgur.com/tgDfc8i.png)             | ![](https://i.imgur.com/oWa5vWo.png)
:-------------------------:|:-------------------------:


In [4]:
%pip install transformers

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m


In [5]:
# #!g1.1
# pre_dataset = []

# for text_id in range(len(texts)):
#     relationships_nested = {(e.arg1, e.arg2):e for e in relationships[text_id] if is_nested(e)}
#     nes = []
#     for i in range(len(entities[text_id])):
#         # O(n^2) eeeeeeeeee
#         for j in range(i+1, len(entities[text_id])):
#             if (is_nested_anns2(entities[text_id][i], entities[text_id][j])):
#                 nes.append((entities[text_id][i], entities[text_id][j]))
#     for e in nes:
#         if (e in relationships_nested):
#             pre_dataset.append((*e, relationships_nested[e].tag))
#         elif ((e[1], e[0]) in relationships_nested):
#             pre_dataset.append((e[1], e[0], relationships_nested[(e[1], e[0])].tag))
#         else:
#             pre_dataset.append((*e, 'None'))
# #             pre_dataset.append((e[1], e[0], 'None'))

# Dataset

In [6]:
#!g1.1
import os
import torch
import time
import torch.optim as optim
from torch.utils.data import DataLoader
import numpy as np
from tqdm.notebook import tqdm
import torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from matplotlib import pyplot as plt
import random
from torch.optim.lr_scheduler import ReduceLROnPlateau

In [7]:
#!g1.1
from transformers import BertForSequenceClassification, BertTokenizer

In [8]:
#!g1.1
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

Using cuda device


In [9]:
#!g1.1
tokenizer = BertTokenizer.from_pretrained('DeepPavlov/rubert-base-cased-sentence')

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=1649718.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=112.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=24.0), HTML(value='')))






In [10]:
#!g1.1
MAX_LENGTH = 100

In [13]:
#!g1.1
import nltk
nltk.download('punkt')

[nltk_data] Downloading package punkt to /home/jupyter/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


True

In [15]:
#!g1.1
sents = nltk.sent_tokenize(texts[0])
len(' '.join(sents))

3855

In [49]:
#!g1.1
class Binary_Tag:
    def __getitem__(self, x):
        if (x == 'None'): return 0
        else: return 1

# tags = list(set(e.tag for e in sum(relationships,[]))) + ['None'] # ids --> string tag
reversed_tags = Binary_Tag() # string tag --> ids

def tok(text1):
    # longest entity is 36 tokens long
    res = tokenizer.encode_plus(text1,
                      max_length = MAX_LENGTH,
                      pad_to_max_length = True,
                      return_tensors = 'pt',
                      )
    return res['input_ids'][0], res['attention_mask'][0], res['token_type_ids'][0]
    
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, folder):
        texts, entities, relationships, filenames = read_files(folder)
        
        pre_dataset = []

        for text_id in range(len(texts)):
            relationships_nested = {(e.arg1, e.arg2):e for e in relationships[text_id] if is_nested(e)}
            nes = []
            for i in range(len(entities[text_id])):
                # O(n^2) eeeeeeeeee
                for j in range(i+1, len(entities[text_id])):
                    if (is_nested_anns2(entities[text_id][i], entities[text_id][j])):
                        nes.append((entities[text_id][i], entities[text_id][j]))
            for e in nes:
                if (e in relationships_nested):
                    pre_dataset.append((*e, relationships_nested[e].tag, text_id))
                elif ((e[1], e[0]) in relationships_nested):
                    pre_dataset.append((e[1], e[0], relationships_nested[(e[1], e[0])].tag, text_id))
                else:
                    pre_dataset.append((*e, 'None', text_id))

#         random.seed(2021)
        data = []
        for e1, e2, tag, text_id in tqdm(pre_dataset):
            
            if (len(e1.text) < len(e2.text)):
                e1,e2 = e2, e1
            
            text = texts[text_id][max(e1.start1 - 50,0): e1.end1 + 50]
#             while (text[0] != ' ' and not e1.start1-50 <= 0): text = text[1:]
#             while (text[-1] != ' '): text = text[:-1]
            subtext = ' { ' + e1.text.replace(e2.text, '[ ' + e2.text + ' ]') + ' } '
            text = text.replace(e1.text, subtext)
#             print(text)
            data.append((reversed_tags[tag], *tok(text)))
        
        self.data = data
    
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)
    

In [50]:
#!g1.1
train_data = MyDataset('NEREL/NEREL-v1.0/train')
train_dl = DataLoader(train_data, shuffle=True, batch_size=32)

dev_data = MyDataset('NEREL/NEREL-v1.0/dev')
dev_dl = DataLoader(dev_data, batch_size=32)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=745.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=11175.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=93.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1301.0), HTML(value='')))

incorrect format in: row=По словам очевидцев пешехо
 file=21013_text






In [51]:
#!g1.1
next(iter(dev_dl))

[tensor([0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0,
         1, 1, 0, 1, 1, 1, 1, 0]),
 tensor([[  101,   284,   222,  ...,     0,     0,     0],
         [  101,   284, 17710,  ...,     0,     0,     0],
         [  101, 58205, 23010,  ...,     0,     0,     0],
         ...,
         [  101, 21551,  1916,  ...,     0,     0,     0],
         [  101,   841,   128,  ...,     0,     0,     0],
         [  101, 13880, 22443,  ...,     0,     0,     0]]),
 tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]])]

In [52]:
#!g1.1
next(iter(dev_dl))[1].size()

torch.Size([32, 100])

# Training

In [54]:
#!g1.1
model = BertForSequenceClassification.from_pretrained(
    "DeepPavlov/rubert-base-cased-sentence", # Use the 12-layer BERT model, with an uncased vocab.
    num_labels = 2, #len(tags), # The number of output labels--18 for our NER dataset
    output_attentions = False, # Whether the model returns attentions weights.
    output_hidden_states = False, # Whether the model returns all hidden-states.
    max_length = MAX_LENGTH
)

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=711456784.0), HTML(value='')))




In [55]:
#!g1.1
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elemen

In [89]:
#!g1.1
def calculate_accuracy(y_pred, y):
    return sum(torch.argmax(y_pred, axis = 1) == y)*1.0/len(y)

def train():
    model.train()
    running_loss = 0
    epoch_accuracy = 0

    pbar = tqdm(enumerate(train_dl), total = len(train_dl))
    for i, (labels, input_ids, attention_mask, token_type_ids) in pbar:
        labels = labels.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        
        optimizer.zero_grad()
        
        outputs = model(input_ids, 
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask, 
                    labels=labels)
        
        loss = outputs[0]
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_accuracy += calculate_accuracy(outputs[1], labels).item()
        running_loss += loss.item()
        pbar.set_description(f'training: running_loss = {running_loss/(i+1.0):.4f} accuracy = {epoch_accuracy/(i+1.0):.4f},')
    
    scheduler.step(running_loss/(i+1))
    print(f'train loss= {running_loss/(i+1):.4f} \n train accuracy = {epoch_accuracy/(i+1):.4f},')

from sklearn.metrics import f1_score
def test():
    with torch.no_grad():
        model.eval()
        running_loss = 0
        epoch_accuracy = 0

        dev_dl = DataLoader(dev_data, batch_size=32)

        pbar = tqdm(enumerate(dev_dl), total = len(dev_dl))

        y_true = []
        y_pred = []

        for i, (labels, input_ids, attention_mask, token_type_ids) in pbar:
            labels = labels.to(device)
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            token_type_ids = token_type_ids.to(device)

            outputs = model(input_ids, 
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask, 
                        labels=labels)


            loss = outputs[0]
            calculate_accuracy(outputs[1], labels)
            y_true += [e.item() for e in labels]
            y_pred += [e.item() for e in torch.argmax(outputs[1], axis = 1)]

            running_loss += loss.item()
            epoch_accuracy += calculate_accuracy(outputs[1], labels).item()
            # pbar.set_description(f'testing: running_loss = {running_loss/(i+1):.4f} accuracy = {epoch_accuracy/(i+1):.4f},')
    print(f'test loss= {running_loss/(i+1):.4f} \n test accuracy = {epoch_accuracy/(i+1):.4f}, F1 = {f1_score(y_true, y_pred)}')



In [90]:
#!g1.1
optimizer = torch.optim.AdamW(model.parameters(),
                  lr = 5e-5, # args.learning_rate 
                  eps = 5e-8#1e-8 # args.adam_epsilon 
                )
scheduler = ReduceLROnPlateau(optimizer, patience=5, cooldown = 1, factor = 0.5)



In [61]:
#!g1.1
seed_val = 2021

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)
for epoch in range(1, 20):
    print(f'epoch = {epoch}')
    print('lr=', optimizer.param_groups[0]['lr'])
    train()
    test()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

epoch = 1
lr= 5e-05

train loss= 0.4092 
 train accuracy = 0.8111,

test loss= 0.4341 
 test accuracy = 0.8250, F1 = 0.7927272727272728
epoch = 2
lr= 5e-05

train loss= 0.2580 
 train accuracy = 0.9046,

test loss= 0.3834 
 test accuracy = 0.8383, F1 = 0.7666666666666667
epoch = 3
lr= 5e-05

train loss= 0.1998 
 train accuracy = 0.9313,

test loss= 0.3822 
 test accuracy = 0.8586, F1 = 0.8059071729957804
epoch = 4
lr= 5e-05

train loss= 0.1566 
 train accuracy = 0.9484,

test loss= 0.4976 
 test accuracy = 0.8525, F1 = 0.8076152304609219
epoch = 5
lr= 5e-05

train loss= 0.1353 
 train accuracy = 0.9593,

test loss= 0.4825 
 test accuracy = 0.8502, F1 = 0.808259587020649
epoch = 6
lr= 5e-05

train loss= 0.1234 
 train accuracy = 0.9654,

test loss= 0.5506 
 test accuracy = 0.8601, F1 = 0.8286252354048963
epoch = 7
lr= 5e-05

train loss= 0.1011 
 train accuracy = 0.9741,

test loss= 0.4922 
 test accuracy = 0.8631, F1 = 0.8265107212475633
epoch = 8
lr= 5e-05

train loss= 0.0972 
 train a

In [62]:
#!g1.1
for epoch in range(20, 21):
    print(f'epoch = {epoch}')
    print('lr=', optimizer.param_groups[0]['lr'])
    train()
    test()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=350.0), HTML(value='')))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=41.0), HTML(value='')))

epoch = 20
lr= 2.5e-05

train loss= 0.0308 
 train accuracy = 0.9928,

test loss= 0.7566 
 test accuracy = 0.8498, F1 = 0.8123195380173244


In [63]:
#!g1.1
torch.save(model.state_dict(), 'models/Bert_binary_nested_only_context')

In [65]:
#!g1.1
#!g1.1
from sklearn.metrics import f1_score

with torch.no_grad():
    model.eval()
    running_loss = 0
    epoch_accuracy = 0

    dev_dl = DataLoader(dev_data, batch_size=1)

    pbar = tqdm(enumerate(dev_dl), total = len(dev_dl))

    y_true = []
    y_pred = []

    for i, (labels, input_ids, attention_mask, token_type_ids) in pbar:
        labels = labels.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)

        outputs = model(input_ids, 
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask, 
                    labels=labels)


        loss = outputs[0]
#         print(input_ids[0].detach().cpu().numpy())
        if (torch.argmax(outputs[1], axis = 1) != labels):
            print('expected:', labels.detach().cpu().item())
            print('found:', torch.argmax(outputs[1], axis = 1).detach().cpu().item(), '(', outputs[1].detach().cpu().numpy() ,')')
            print(tokenizer.decode(input_ids[0].detach().cpu().numpy()).replace(r'[PAD]', ''))
            print()

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1301.0), HTML(value='')))

expected: 0
found: 1 ( [[-3.0432782  3.0335584]] )
[CLS] роходил срочную службу. В 1981 году — перевёлся в { [ Строгановское ] высшее художественное училище }, которое окончил в 1985 году. Записывал и сочиня [SEP]                                                                  

expected: 0
found: 1 ( [[-3.0214646  2.9982398]] )
[CLS] вил музыкант в ходе брифинга с главным менеджером { Международного музыкального фестиваля « [ Ереванские ] перспективы } » Соной Ованнисян. По словам Кисина, Армения засл [SEP]                                                                

expected: 1
found: 0 ( [[ 2.6974087 -2.5493507]] )
[CLS] нерального секретаря Osapp, крупнейшего профсоюза { сотрудников исправительных учреждений [ Италии ] }, в тюрьмах страны находится около 12 тысяч членов [SEP]                                                                         

expected: 0
found: 1 ( [[-3.0436735  3.03302  ]] )
[CLS] ритании возобновит выполнение своих обязанностей { Премьер - министр [ Ве



In [105]:
#!g1.1
#!g1.1
class Binary_Tag:
    def __getitem__(self, x):
        if (x == 'None'): return 0
        else: return 1

# tags = list(set(e.tag for e in sum(relationships,[]))) + ['None'] # ids --> string tag
reversed_tags = Binary_Tag() # string tag --> ids

def tok(text1):
    # longest entity is 36 tokens long
    res = tokenizer.encode_plus(text1,
                      max_length = MAX_LENGTH,
                      pad_to_max_length = True,
                      return_tensors = 'pt',
                      )
    return res['input_ids'][0], res['attention_mask'][0], res['token_type_ids'][0]
    
class MyDataset(torch.utils.data.Dataset):
    def __init__(self, folder):
        texts, entities, relationships, filenames = read_files(folder)
        
        pre_dataset = []

        for text_id in range(len(texts)):
            relationships_nested = {(e.arg1, e.arg2):e for e in relationships[text_id] if is_nested(e)}
            nes = []
            for i in range(len(entities[text_id])):
                # O(n^2) eeeeeeeeee
                for j in range(i+1, len(entities[text_id])):
                    if (is_nested_anns2(entities[text_id][i], entities[text_id][j])):
                        nes.append((entities[text_id][i], entities[text_id][j]))
            for e in nes:
                if (e in relationships_nested):
                    pre_dataset.append((*e, relationships_nested[e].tag, text_id))
                elif ((e[1], e[0]) in relationships_nested):
                    pre_dataset.append((e[1], e[0], relationships_nested[(e[1], e[0])].tag, text_id))
                else:
                    pre_dataset.append((*e, 'None', text_id))

#         random.seed(2021)
        data = []
        for e1, e2, tag, text_id in tqdm(pre_dataset):
            tag1, tag2 = e1.tag, e2.tag
            if (len(e1.text) < len(e2.text)):
                e1,e2 = e2, e1
            
            text = texts[text_id][max(e1.start1 - 50,0): e1.end1 + 50]
#             while (text[0] != ' ' and not e1.start1-50 <= 0): text = text[1:]
#             while (text[-1] != ' '): text = text[:-1]
            subtext = ' { ' + e1.text.replace(e2.text, '[ ' + e2.text + ' ]') + ' } '
            text = text.replace(e1.text, subtext)
#             print(text)
            data.append((reversed_tags[tag], e1.tag, e2.tag, tag, *tok(text)))
        
        self.data = data
    
    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)

dev_data = MyDataset('NEREL/NEREL-v1.0/dev')
dev_dl = DataLoader(dev_data, batch_size=1)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=93.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1301.0), HTML(value='')))




In [117]:
#!g1.1
with torch.no_grad():
    model.eval()
    running_loss = 0
    epoch_accuracy = 0

    dev_dl = DataLoader(dev_data, batch_size=1)

    pbar = tqdm(enumerate(dev_dl), total = len(dev_dl))

    y_true = []
    y_pred = []
    tags = []

    for i, (labels, tag1, tag2, true_tag, input_ids, attention_mask, token_type_ids) in pbar:
        labels = labels.to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)

        outputs = model(input_ids, 
                    token_type_ids=token_type_ids,
                    attention_mask=attention_mask, 
                    labels=labels)


        loss = outputs[0]
        calculate_accuracy(outputs[1], labels)
        y_true += [e.item() for e in labels]
        y_pred += [e.item() for e in torch.argmax(outputs[1], axis = 1)]
        tags += [(tag1, tag2)]

        running_loss += loss.item()
        epoch_accuracy += calculate_accuracy(outputs[1], labels).item()
        if (torch.argmax(outputs[1], axis = 1) != labels):
            print('expected:', labels.detach().cpu().item())
            print('found:', torch.argmax(outputs[1], axis = 1).detach().cpu().item(), '(', outputs[1].detach().cpu().numpy() ,')')
            print(f'{tag1[0]} ---> {tag2[0]} : {true_tag[0]}')
            print(tokenizer.decode(input_ids[0].detach().cpu().numpy()).replace(r'[PAD]', ''))
print()
print(f'test loss= {running_loss/(i+1):.4f} \n test accuracy = {epoch_accuracy/(i+1):.4f}, F1 = {f1_score(y_true, y_pred)}')

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1301.0), HTML(value='')))

expected: 0
found: 1 ( [[-3.0432782  3.0335584]] )
ORGANIZATION ---> PERSON : None
[CLS] роходил срочную службу. В 1981 году — перевёлся в { [ Строгановское ] высшее художественное училище }, которое окончил в 1985 году. Записывал и сочиня [SEP]                                                                  
expected: 0
found: 1 ( [[-3.0214646  2.9982398]] )
EVENT ---> CITY : None
[CLS] вил музыкант в ходе брифинга с главным менеджером { Международного музыкального фестиваля « [ Ереванские ] перспективы } » Соной Ованнисян. По словам Кисина, Армения засл [SEP]                                                                
expected: 1
found: 0 ( [[ 2.6974087 -2.5493507]] )
PROFESSION ---> COUNTRY : ORIGINS_FROM
[CLS] нерального секретаря Osapp, крупнейшего профсоюза { сотрудников исправительных учреждений [ Италии ] }, в тюрьмах страны находится около 12 тысяч членов [SEP]                                                                         
expected: 0
found: 1 ( [[-3.0436735  3.



In [111]:
#!g1.1
from collections import Counter
l = []
for a,b,t in zip(y_true, y_pred, tags):
    if (a!=b):
        l.append(t)

In [116]:
#!g1.1
Counter([(e[0][0], e[1][0]) for e in l]).most_common()

[(('PROFESSION', 'ORGANIZATION'), 22),
 (('LAW', 'LAW'), 13),
 (('ORGANIZATION', 'ORGANIZATION'), 12),
 (('ORGANIZATION', 'COUNTRY'), 7),
 (('ORGANIZATION', 'PERSON'), 6),
 (('EVENT', 'DISEASE'), 6),
 (('EVENT', 'COUNTRY'), 6),
 (('PROFESSION', 'PROFESSION'), 6),
 (('EVENT', 'PERSON'), 6),
 (('PROFESSION', 'COUNTRY'), 5),
 (('EVENT', 'LOCATION'), 5),
 (('DISEASE', 'DISEASE'), 5),
 (('EVENT', 'EVENT'), 5),
 (('ORGANIZATION', 'LOCATION'), 4),
 (('LAW', 'COUNTRY'), 4),
 (('LAW', 'ORDINAL'), 4),
 (('ORGANIZATION', 'STATE_OR_PROVINCE'), 4),
 (('FACILITY', 'PERSON'), 4),
 (('WORK_OF_ART', 'PERSON'), 4),
 (('ORGANIZATION', 'IDEOLOGY'), 4),
 (('EVENT', 'PROFESSION'), 4),
 (('ORGANIZATION', 'PROFESSION'), 3),
 (('PROFESSION', 'STATE_OR_PROVINCE'), 3),
 (('FACILITY', 'LOCATION'), 3),
 (('EVENT', 'CITY'), 2),
 (('PROFESSION', 'LOCATION'), 2),
 (('LAW', 'STATE_OR_PROVINCE'), 2),
 (('WORK_OF_ART', 'FACILITY'), 2),
 (('PENALTY', 'MONEY'), 2),
 (('PROFESSION', 'PERSON'), 2),
 (('ORGANIZATION', 'DISTR