<a href="https://colab.research.google.com/github/Gaussiandra/RuREBus_NER_RE/blob/master/solution.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive

# Mount Google Drive so data and checkpoints persist between Colab sessions.
drive.mount('/content/drive')

# Project root on Drive, and the RuREBus repository checkout inside it (both end with '/').
global_path = '/content/drive/My Drive/ML/NLP sber/8/RuREBus/'
global_repo_path = f'{global_path}data/RuREBus-master/'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
%%writefile requirements.txt
# Python dependencies of this notebook (installed in the next cell).
# NOTE(review): only tensorflow-gpu is pinned; the other packages float, so a fresh install may not
# reproduce the saved outputs — consider pinning the versions that produced them.
razdel
ipymarkup
pytorch-crf
tensorflow-gpu==1.15.2
deeppavlov
russian_tagsets
conllu

Writing requirements.txt


In [None]:
!pip install -r requirements.txt
!python -m deeppavlov install squad_bert
!python -m deeppavlov install syntax_ru_syntagrus_bert

In [None]:
import copy
import os
import re
import sys
from collections import Counter
from collections import namedtuple
from itertools import combinations
from zipfile import ZipFile

import numpy as np
import torch
import razdel
import ipymarkup
import cloudpickle
import conllu
import networkx as nx
import matplotlib.pyplot as plt
from torch import nn
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from torchcrf import CRF
from deeppavlov.models.embedders.elmo_embedder import ELMoEmbedder
from deeppavlov import build_model, configs
from nltk import DependencyGraph

In [None]:
# Pre-trained Russian ELMo (ru-wiki, 600k steps) from DeepPavlov; only the "elmo" output is requested.
_elmo_spec = "http://files.deeppavlov.ai/deeppavlov_data/elmo_ru-wiki_600k_steps.tar.gz"
elmo_embedder = ELMoEmbedder(_elmo_spec, elmo_output_names=["elmo"])

In [None]:
# DeepPavlov Russian (SynTagRus) joint parsing pipeline; downloads its files on first use.
# Its CoNLL-U output supplies the UPOS tags, dependency relations and dependency graphs used below.
syntax_model = build_model(
    configs.syntax.ru_syntagrus_joint_parsing,
    download=True,
)

In [None]:
# Raw contents of one document: the .txt text and the .ann annotation text.
DocRepr = namedtuple('DocRepr', 'text ann')

# One relation candidate: a sentence (token slice into document `doc_idx`), its undirected
# dependency graph, the relation label, and the document-level token indices of both entities.
RESentRepr = namedtuple('RESentRepr', 'doc_idx token_range graph re_type ent1_idx ent2_idx')

# Per-document features. The typename matches the variable name so that instances can be
# pickled by reference (the previous typename 'WordRepr' did not exist at module level).
WordsRepr = namedtuple('WordsRepr', [
    'raw_text',
    'tokens',
    'char_ids',
    'ner_spans',
    'labels',
    'pos_tags',
    'deprels',
    'file_name',
    'elmo_embedding'
])

class NER_RE_Dataset(Dataset):
    """Joint NER / relation-extraction dataset over a folder of ``<name>.txt`` + ``<name>.ann`` pairs.

    Every document is tokenized with ``razdel``. ``self.samples`` holds one ``WordsRepr`` per
    document with per-token BIO label ids, character ids, UPOS ids and dependency-relation ids,
    plus an ELMo embedding of the whole document. ``self.re_sentences`` holds relation
    candidates (``RESentRepr``) collected sentence by sentence.

    Args:
        folder_path: directory walked recursively for ``.txt``/``.ann`` pairs.
        syntax_model: callable mapping a list of token lists to CoNLL-U strings.
        elmo_embedder: callable mapping a list of token lists to per-token embeddings.
        sentence_wise: if True, items are relation candidates; otherwise whole documents for NER.
        batch_size: number of documents per ELMo call.
        is_train: if True the vocabularies are built from this data. Otherwise the ``custom_*``
            vocabularies (normally the training ones) are used read-only: unseen chars, POS tags
            and deprels map to ``'<unk>'`` instead of being added, so indices stay inside the
            embedding tables sized from the training split.
        custom_*: vocabularies to reuse when ``is_train`` is False.
    """

    def __init__(self, folder_path, syntax_model, elmo_embedder, sentence_wise,
                 batch_size=2, is_train=True,
                 custom_ent2idx=None,
                 custom_idx2ent=None,
                 custom_char2idx=None,
                 custom_re2idx=None,
                 custom_pos2idx=None,
                 custom_deprel2idx=None):

        self.folder_path = folder_path
        self.sentence_wise = sentence_wise
        self.docs = self._load_docs(folder_path)

        self.samples = []
        self.ent2idx = {'OUT': 0} if is_train else custom_ent2idx
        self.idx2ent = ['OUT'] if is_train else custom_idx2ent
        self.char2idx = {'<pad>': 0, '<unk>': 1} if is_train else custom_char2idx
        self.re2idx = {'NONE': 0} if is_train else custom_re2idx
        self.pos2idx = {'<unk>': 0} if is_train else custom_pos2idx
        self.deprel2idx = {'<unk>': 0} if is_train else custom_deprel2idx

        self.skipped_sentences = 0
        self.re_sentences = []
        for doc_idx, (name, doc) in enumerate(tqdm(self.docs.items())):
            tokens = list(razdel.tokenize(doc.text))
            ner_spans, re_tokens = self._parse_ann(doc.ann, is_train)
            labels, T_name2token_idx = self._bio_labels(tokens, ner_spans, is_train)
            char_ids = self._to_object_array([
                [self._lookup(self.char2idx, char, is_train) for char in token.text]
                for token in tokens
            ])

            pos_tags = []
            deprels = []
            sentenized_doc = list(razdel.sentenize(doc.text))
            prev_stop, sent_num = 0, 0
            for token_num, token in enumerate(tokens):
                is_last_token = token_num + 1 == len(tokens)
                # A sentence is closed when the current token runs past razdel's sentence end
                # (that token then opens the next sentence) or when the document ends.
                if token.stop > sentenized_doc[sent_num].stop or is_last_token:
                    sent_end = token_num + is_last_token  # exclusive token index
                    sent_range = slice(prev_stop, sent_end)
                    text_sent_tokens = [t.text for t in tokens[sent_range]]
                    try:
                        conllu_data = syntax_model([text_sent_tokens])[0]
                        sent_graph = DependencyGraph(conllu_data).nx_graph().to_undirected()

                        for word in conllu.parse(conllu_data)[0]:
                            pos_tags.append(self._lookup(self.pos2idx, word['upos'], is_train))
                            deprels.append(self._lookup(self.deprel2idx, word['deprel'], is_train))

                        sent_char_start = tokens[prev_stop].start
                        # Exclusive bound for entity ends. For inner sentences tokens[token_num] is the
                        # first token of the next sentence; for the final sentence it is the sentence's
                        # own last token, so +1 keeps an entity that ends exactly there.
                        sent_char_stop = tokens[token_num].stop + is_last_token
                        # NOTE(review): only the FIRST gold relation inside a sentence becomes a
                        # candidate, and a sentence with a gold relation contributes no NONE pairs, so
                        # candidate generation depends on gold relations (also on the test split).
                        for re_type, ent1, ent2 in re_tokens.values():
                            if (ner_spans[ent1][1] >= sent_char_start and
                                    ner_spans[ent1][2] < sent_char_stop and
                                    ner_spans[ent2][1] >= sent_char_start and
                                    ner_spans[ent2][2] < sent_char_stop):
                                self.re_sentences.append(RESentRepr(
                                    doc_idx,
                                    sent_range,
                                    sent_graph,
                                    re_type,
                                    T_name2token_idx[ent1],
                                    T_name2token_idx[ent2],
                                ))
                                break
                        else:
                            # No gold relation here: every pair of entity starts (B- tags) is a NONE candidate.
                            ner_sent_tokens_idx = [
                                i for i, cur_label in enumerate(labels[sent_range], prev_stop)
                                if self.idx2ent[cur_label].startswith('B-')
                            ]
                            for st, en in combinations(ner_sent_tokens_idx, 2):
                                self.re_sentences.append(RESentRepr(
                                    doc_idx, sent_range, sent_graph, 'NONE', st, en,
                                ))
                    except RuntimeError:
                        # razdel.sentenize can produce an overly long sentence on which syntax_model
                        # fails; pad its tags with '<unk>' so they stay aligned with the tokens.
                        self.skipped_sentences += 1
                        n_skipped_tokens = sent_end - prev_stop
                        pos_tags += n_skipped_tokens * [self.pos2idx['<unk>']]
                        deprels += n_skipped_tokens * [self.deprel2idx['<unk>']]
                    prev_stop = token_num
                    sent_num += 1
            assert len(pos_tags) == len(tokens), (len(pos_tags), len(tokens))

            self.samples.append(WordsRepr(
                doc.text,
                tokens,
                char_ids,
                ner_spans,
                labels,
                np.array(pos_tags),
                np.array(deprels),
                name,
                None,
            ))

        # ELMo is run on whole documents, `batch_size` documents at a time.
        for batch_num in tqdm(range(0, len(self.samples), batch_size)):
            embeddings = elmo_embedder([
                [token.text for token in sample.tokens]
                for sample in self.samples[batch_num:batch_num + batch_size]
            ])
            for i, emb in enumerate(embeddings):
                self.samples[batch_num + i] = self.samples[batch_num + i]._replace(elmo_embedding=emb)

        print(f'{self.skipped_sentences} sentences was skipped')

    @staticmethod
    def _load_docs(folder_path):
        """Map ``name`` -> ``DocRepr(text, ann)`` for every ``name.txt``/``name.ann`` pair under folder_path."""
        docs = {}
        seen = set()  # each pair is listed twice by os.walk (.txt and .ann); read it once
        for dirpath, _, filenames in os.walk(folder_path):
            for file_name in filenames:
                filename = os.path.splitext(file_name)[0]
                stem_path = os.path.join(dirpath, filename)
                if stem_path in seen:
                    continue
                seen.add(stem_path)
                try:
                    with open(stem_path + '.txt', encoding='utf-8') as txt_file, \
                            open(stem_path + '.ann', encoding='utf-8') as ann_file:
                        docs[filename] = DocRepr(txt_file.read(), ann_file.read())
                except FileNotFoundError:
                    print(f'"{filename}" was ignored.')
        return docs

    def _parse_ann(self, ann, is_train):
        """Parse annotation text into entity spans and relations.

        Returns:
            ner_spans: ``{'T<n>': (entity_type, char_start, char_stop)}``
            re_tokens: ``{'R<n>': array([relation_type, arg1_T, arg2_T])}``
        """
        ner_spans = {}
        re_tokens = {}
        for line in ann.split('\n'):
            fields = line.split('\t')
            if fields[0].startswith('T'):
                entity, start, stop = fields[1].split()
                ner_spans[fields[0]] = (entity, int(start), int(stop))
            elif fields[0].startswith('R'):
                # 'TSK Arg1:T10 Arg2:T11' -> ['TSK', 'Arg1', 'T10', 'Arg2', 'T11']
                splitted_re = np.array(re.split(r'\s|:', fields[1]))
                if is_train and splitted_re[0] not in self.re2idx:
                    self.re2idx[splitted_re[0]] = len(self.re2idx)
                re_tokens[fields[0]] = splitted_re[[0, 2, 4]]
        return ner_spans, re_tokens

    def _bio_labels(self, tokens, ner_spans, is_train):
        """BIO label id per token, and a map from entity id (T-name) to the token index it starts at."""
        if is_train and tokens:
            for entity, _, _ in ner_spans.values():
                if 'B-' + entity not in self.ent2idx:
                    self.ent2idx['B-' + entity] = len(self.ent2idx)
                    self.ent2idx['I-' + entity] = len(self.ent2idx)
                    self.idx2ent += ['B-' + entity, 'I-' + entity]

        labels = []
        T_name2token_idx = {}
        for token in tokens:
            label = self.ent2idx['OUT']
            for T_name, (entity, start, stop) in ner_spans.items():
                if token.start == start:
                    label = self.ent2idx['B-' + entity]
                    T_name2token_idx[T_name] = len(labels)
                elif token.start > start and token.stop <= stop:
                    label = self.ent2idx['I-' + entity]
            labels.append(label)
        return labels, T_name2token_idx

    @staticmethod
    def _lookup(vocab, key, extend):
        """Index of `key` in `vocab`; unseen keys are appended when `extend`, else mapped to '<unk>'."""
        if key not in vocab:
            if not extend:
                return vocab['<unk>']
            vocab[key] = len(vocab)
        return vocab[key]

    @staticmethod
    def _to_object_array(rows):
        """Pack variable-length rows into a 1-D object array (one list per entry).

        ``np.array(rows)`` would fail on ragged rows in recent NumPy and would silently build a 2-D
        int array when all rows happen to have the same length.
        """
        packed = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            packed[i] = row
        return packed

    def __getitem__(self, idx):
        if not self.sentence_wise:
            sample = self.samples[idx]
            return (sample.elmo_embedding, sample.char_ids, sample.labels)

        re_sent = self.re_sentences[idx]
        doc = self.samples[re_sent.doc_idx]
        token_range = re_sent.token_range
        offset = token_range.start
        # Graph nodes are 1-based word ids within the sentence; the +1/-1 converts to/from
        # document-level token indices.
        path = nx.shortest_path(re_sent.graph, re_sent.ent1_idx - offset + 1, re_sent.ent2_idx - offset + 1)
        syntax_order = [offset + node - 1 for node in path]

        orig_fields = (
            doc.elmo_embedding[token_range],
            doc.char_ids[token_range],
            doc.pos_tags[token_range],
        )
        syntax_fields = (
            doc.elmo_embedding[syntax_order],
            doc.char_ids[syntax_order],
            doc.pos_tags[syntax_order],
            doc.deprels[syntax_order],
        )
        # Relation types missing from a borrowed vocabulary fall back to NONE instead of raising.
        label = self.re2idx.get(re_sent.re_type, self.re2idx['NONE'])
        return (orig_fields, syntax_fields), label

    def __len__(self):
        return len(self.re_sentences) if self.sentence_wise else len(self.samples)

In [None]:
# Training split: builds every vocabulary and yields relation candidates (sentence_wise=True).
task12_train_ds = NER_RE_Dataset(
    folder_path=global_repo_path + 'train_data/train_parts',
    syntax_model=syntax_model,
    elmo_embedder=elmo_embedder,
    sentence_wise=True,
)

"" was ignored.


HBox(children=(FloatProgress(value=0.0, max=188.0), HTML(value='')))

  "The graph doesn't contain a node " "that depends on the root element."





HBox(children=(FloatProgress(value=0.0, max=94.0), HTML(value='')))


8 sentences was skipped


In [None]:
# Test split: reuses every vocabulary of the training split so indices match the trained models.
task12_test_ds = NER_RE_Dataset(
    global_repo_path + 'test_data/test_parts',
    syntax_model,
    elmo_embedder,
    sentence_wise=True,
    is_train=False,
    **{
        f'custom_{vocab}': getattr(task12_train_ds, vocab)
        for vocab in ('ent2idx', 'idx2ent', 'char2idx', 're2idx', 'pos2idx', 'deprel2idx')
    },
)

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))

  "The graph doesn't contain a node " "that depends on the root element."





HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))


0 sentences was skipped


In [None]:
# Number of RE candidates, then vocabulary sizes: BIO tags, chars, relations, POS tags, deprels.
tuple(len(collection) for collection in (
    task12_train_ds,
    task12_train_ds.ent2idx,
    task12_train_ds.char2idx,
    task12_train_ds.re2idx,
    task12_train_ds.pos2idx,
    task12_train_ds.deprel2idx,
))

(8470, 17, 179, 12, 19, 38)

In [None]:
# Class distribution of the training relation candidates.
Counter(sent.re_type for sent in task12_train_ds.re_sentences)

Counter({'FNG': 98,
         'FNT': 75,
         'FPS': 342,
         'GOL': 1326,
         'NNG': 370,
         'NNT': 317,
         'NONE': 37660,
         'NPS': 408,
         'PNG': 54,
         'PNT': 127,
         'PPS': 307,
         'TSK': 1681})

In [None]:
def pad_seq(batch, max_chars_per_word=50):
    """Collate samples into zero-padded batch tensors.

    Each sample is ``(elmo_emb, char_ids, tags)`` or ``(elmo_emb, char_ids, tags, deprels)``
    (all samples in a batch must share one layout). ``char_ids`` is a sequence of per-token
    character-id rows (lists or numpy rows).

    Args:
        batch: list of samples as described above.
        max_chars_per_word: cap on the character width W.

    Returns:
        ``(emb, chars, tags)`` or ``(emb, chars, tags, deprels)`` with shapes (B, T, ...),
        (B, T, W), (B, T) and (B, T). W is the longest word in the batch, capped at
        ``max_chars_per_word``; longer words are truncated.
    """
    emb = [torch.tensor(sample[0]) for sample in batch]

    max_word_len = min(max_chars_per_word, max(len(word) for sample in batch for word in sample[1]))
    char_list = []
    for sample in batch:
        # int() per id: a numpy row plus a Python list would broadcast instead of concatenating.
        words = [
            [int(char) for char in word[:max_word_len]] + [0] * (max_word_len - len(word))
            for word in sample[1]
        ]
        char_list.append(torch.tensor(words))

    tags = [torch.tensor(sample[2]) for sample in batch]

    padded = [
        pad_sequence(emb, batch_first=True),
        pad_sequence(char_list, batch_first=True),
        pad_sequence(tags, batch_first=True),
    ]
    # Decide the layout from the batch itself rather than from a loop variable leaking out of a loop.
    if len(batch[0]) == 4:
        padded.append(pad_sequence([torch.tensor(sample[3]) for sample in batch], batch_first=True))
    return tuple(padded)

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
# Make the RuREBus evaluation scripts importable. The guard keeps sys.path from growing
# every time this cell is re-run.
lib_path = os.path.abspath(os.path.join(global_repo_path, 'eval_scripts'))
if lib_path not in sys.path:
    sys.path.append(lib_path)
import evaluate_ners, evaluate_rels, brat_format

### NER

In [None]:
batch_size = 2

# NER trains on whole documents, but task12_*_ds were built with sentence_wise=True and yield
# relation candidates, which pad_seq cannot collate into (emb, chars, labels). Shallow copies
# share all samples and vocabularies and only switch the view to whole documents.
task1_train_ds = copy.copy(task12_train_ds)
task1_train_ds.sentence_wise = False
task1_test_ds = copy.copy(task12_test_ds)
task1_test_ds.sentence_wise = False

train_iter = DataLoader(task1_train_ds, batch_size, collate_fn=pad_seq, shuffle=True, pin_memory=True, num_workers=2)
test_iter = DataLoader(task1_test_ds, batch_size, collate_fn=pad_seq)

In [None]:
class NERModel(nn.Module):
    def __init__(self, n_classes=17, n_chars=179):
        super().__init__()

        self.char_embedding = nn.Embedding(n_chars, 250)
        convolution_params = [(3, 75), (4, 75), (5, 75), (4, 75), (7, 50)]
        self.conv_list = nn.ModuleList([nn.Conv2d(1, n, (k, 250)) for k, n in convolution_params])
        self.lstm = nn.LSTM(1374, 256, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)
        self.dropout_layer = nn.Dropout(0.1)
        self.output_layer = nn.Linear(2*256, n_classes)

    def forward(self, elmo_emb, chars):
        char_emb = self.char_embedding(chars)
        
        conv_res = []
        for f in self.conv_list:
            convolved_res = f(char_emb.reshape(-1, 1, char_emb.shape[2], char_emb.shape[3]))
            conv_res.append(convolved_res.max(dim=2)[0].squeeze(2).reshape(*char_emb.shape[:2], f.out_channels))
        conv_res = torch.cat(conv_res, dim=2)
        
        elmo_cnn_features = torch.cat([elmo_emb.detach(), conv_res], dim=2)
        x, _ = self.lstm(elmo_cnn_features)
        x = self.dropout_layer(x)
        x = self.output_layer(x)
        return x

# Size the tagger from the training vocabularies (BIO tags, characters) rather than the class defaults,
# so the output layer and char embedding always match the data.
model1 = NERModel(
    n_classes=len(task12_train_ds.ent2idx),
    n_chars=len(task12_train_ds.char2idx),
).to(device)

In [None]:
# CRF over the tagger's emission scores; its tag count must equal the BIO vocabulary size.
loss_fn = CRF(len(task12_train_ds.ent2idx), batch_first=True).to(device)
# The CRF has its own trainable start/end/transition scores. They were previously left out of the
# optimizer and so stayed at their random initialisation.
optimizer = torch.optim.Adam(list(model1.parameters()) + list(loss_fn.parameters()), lr=0.001)
losses = []
n_iter = 0

print(sum(p.numel() for p in model1.parameters()))

5360617


In [None]:
# Train the BiLSTM-CRF tagger on whole documents (negative CRF log-likelihood).
model1.train()
for epoch in range(10):
    for elmo_emb, chars, y in tqdm(train_iter):
        elmo_emb, chars, y = elmo_emb.to(device), chars.to(device), y.to(device)
        # Documents in a batch are zero-padded to the longest one. A real token always has at least
        # one non-'<pad>' (id 0) character, so this mask keeps padding out of the CRF likelihood.
        mask = chars.ne(0).any(dim=-1)

        optimizer.zero_grad()
        y_pred = model1(elmo_emb, chars)
        loss = -loss_fn(y_pred, y, mask=mask)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        n_iter += 1
        if n_iter % 5 == 0:
            print(n_iter, np.mean(losses[-5:]))

In [None]:
# Load the trained NER tagger checkpoint (`pickle` was never imported; cloudpickle.load is pickle.load).
# To refresh the checkpoint after training:
#     with open(global_repo_path + 'model1_1024.pkl', 'wb') as f:
#         cloudpickle.dump(model1, f)
# NOTE: unpickling can execute arbitrary code — only load checkpoints you created yourself.
with open(global_repo_path + 'model1_1024.pkl', 'rb') as model_file:
    model1 = cloudpickle.load(model_file)
# Re-pack the LSTM weights contiguously after unpickling so cuDNN can use its fast path.
model1.lstm.flatten_parameters()

In [None]:
# Predict a BIO tag id per token for every test document (argmax over the emission scores),
# trimmed to the document's real token count. Samples and batch size come from the loader
# itself, so this works regardless of later reassignments of the global `batch_size`.
# NOTE(review): the tagger is trained with a CRF; Viterbi decoding (`loss_fn.decode(y_pred, mask)`)
# would use the learned transitions, but only if that trained CRF is available in this session.
final_predicts = []
model1.eval()
test_samples = test_iter.dataset.samples
with torch.no_grad():
    for batch_idx, (elmo_emb, chars, _) in enumerate(tqdm(test_iter)):
        elmo_emb, chars = elmo_emb.to(device), chars.to(device)
        y_pred = model1(elmo_emb, chars)

        first_sample = batch_idx * test_iter.batch_size
        for sample_idx, sample in enumerate(y_pred):
            cur_len = len(test_samples[first_sample + sample_idx].tokens)
            final_predicts.append(torch.argmax(sample[:cur_len], dim=1).cpu().numpy())

HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))




In [None]:
# Inverse of the BIO vocabulary: tag id -> 'OUT' / 'B-<type>' / 'I-<type>'.
# (task12_test_ds shares its ent2idx with the training set; `task1_test_ds` may not exist here.)
idx2ent = {idx: tag for tag, idx in task12_test_ds.ent2idx.items()}

In [None]:
def _drop_orphan_inside_tags(tags, ent_names, out_idx):
    """In place: turn an I- tag into OUT when it starts the document or follows OUT.

    The check reads the already-updated previous tag, so a whole run of orphan I- tags is cleared.
    """
    for token_num, tag in enumerate(tags):
        if ent_names[tag].startswith('I-') and (token_num == 0 or tags[token_num - 1] == out_idx):
            tags[token_num] = out_idx


def _tags_to_spans(tags, tokens, ent_names):
    """Merge BIO tag ids into ``[entity_type, char_start, char_stop]`` spans.

    A B- tag opens a span. An I- tag extends the most recent span only if it has the same type.
    """
    spans = []
    for token, tag in zip(tokens, tags):
        name = ent_names[tag]
        if name.startswith('B-'):
            spans.append([name[2:], token.start, token.stop])
        elif name.startswith('I-') and spans and spans[-1][0] == name[2:]:
            spans[-1][2] = token.stop
    return spans


# Write the predicted entities of every test document as a .ann file.
ent_names = task12_test_ds.idx2ent  # tag id -> 'OUT' / 'B-<type>' / 'I-<type>'
out_idx = task12_test_ds.ent2idx['OUT']
pred_dir = global_repo_path + 'test_data/pred_parts/'
os.makedirs(pred_dir, exist_ok=True)

n_ners = 0  # entity ids keep increasing across documents
for sample_num, sample in enumerate(tqdm(final_predicts)):
    orig_sample = task12_test_ds.samples[sample_num]
    brat_doc = brat_format.BratDoc(orig_sample.raw_text)

    # Remove inconsistencies: an I- tag cannot open an entity.
    _drop_orphan_inside_tags(sample, ent_names, out_idx)

    for span in _tags_to_spans(sample, orig_sample.tokens, ent_names):
        n_ners += 1
        brat_doc.add_ner(n_ners, *span)

    brat_doc.write_to_file(pred_dir + f'{orig_sample.file_name}.ann')

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))




In [None]:
# NER F1 (RuREBus evaluation script): gold test annotations vs. the predicted .ann files.
gold_dir = global_repo_path + 'test_data/test_parts/'
pred_dir = global_repo_path + 'test_data/pred_parts/'
evaluate_ners.calc_ner_f1(gold_dir, pred_dir)

0.4928182807399347

### RE

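The relation-extraction model below combines two BiLSTM + CNN branches, one over the whole sentence and one over the shortest dependency path between the two entities. It follows https://www.aclweb.org/anthology/P16-1072.pdf.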

In [None]:
class BasicModel(nn.Module):
    """One branch of the RE model: token encoder -> BiLSTM -> sequence CNN with max-pooling.

    Token features are ELMo vectors, a POS embedding, an optional dependency-relation embedding
    and char-CNN word features. ``forward`` returns ``(features, logits)``:
    - ``features`` has shape (batch, 400), the pooled outputs of the sequence CNN;
    - ``logits`` has shape (batch, n_classes).

    Args:
        n_classes, n_chars, n_pos_tags: output and vocabulary sizes.
        n_deprels: deprel vocabulary size, or None to build the branch without deprel input.
        elmo_dim: width of the ELMo vectors (1024 gives the original LSTM input sizes 1406/1438).
    """

    def __init__(self, n_classes, n_chars, n_pos_tags, n_deprels, elmo_dim=1024):
        super().__init__()

        self.use_deprel = n_deprels is not None

        # (kernel height, number of filters) for the char CNN and for the sequence CNN.
        local_convolution_params = [(3, 75), (4, 75), (5, 75), (6, 75), (7, 50)]
        global_convolution_params = [(3, 75), (4, 75), (5, 75), (6, 75), (7, 50), (8, 50)]
        # NOTE(review): dropout_layer is not applied in forward(); kept so existing checkpoints load unchanged.
        self.dropout_layer = nn.Dropout(0.1)

        self.char_embedding = nn.Embedding(n_chars, 250)
        self.pos_embedding = nn.Embedding(n_pos_tags, 32)
        if self.use_deprel:
            self.deprel_embedding = nn.Embedding(n_deprels, 32)
        self.local_conv_list = nn.ModuleList([nn.Conv2d(1, n, (k, 250)) for k, n in local_convolution_params])

        lstm_input_size = (elmo_dim + 32 + sum(n for _, n in local_convolution_params)
                           + (32 if self.use_deprel else 0))
        self.lstm = nn.LSTM(lstm_input_size, 256, num_layers=2, batch_first=True, dropout=0.1, bidirectional=True)

        self.global_conv_list = nn.ModuleList([nn.Conv2d(1, n, (k, 512)) for k, n in global_convolution_params])
        self.output_head = nn.Linear(sum(n for _, n in global_convolution_params), n_classes)

    @staticmethod
    def _pad_dim(x, dim, min_size):
        """Zero-pad `x` at the end of dimension `dim` up to `min_size` (no-op if already long enough)."""
        missing = min_size - x.shape[dim]
        if missing <= 0:
            return x
        pad = [0, 0] * (x.dim() - 1 - dim) + [0, missing]
        return nn.functional.pad(x, pad)

    def forward(self, elmo_emb, chars, pos, deprels):
        # A Conv2d kernel taller than its input raises, so very short words (chars) or very short
        # token sequences (e.g. a 2-word dependency path) are zero-padded up to the tallest kernel.
        chars = self._pad_dim(chars, 2, max(f.kernel_size[0] for f in self.local_conv_list))
        char_emb = self.char_embedding(chars)

        conv_res = []
        for f in self.local_conv_list:
            convolved_res = f(char_emb.reshape(-1, 1, char_emb.shape[2], char_emb.shape[3]))
            conv_res.append(convolved_res.max(dim=2)[0].squeeze(2).reshape(*char_emb.shape[:2], f.out_channels))
        conv_res = torch.cat(conv_res, dim=2)

        token_features = [elmo_emb.detach(), self.pos_embedding(pos)]
        if self.use_deprel:
            token_features.append(self.deprel_embedding(deprels))
        token_features.append(conv_res)
        head_features = torch.cat(token_features, dim=2)

        lstm_output, _ = self.lstm(head_features)
        lstm_output = self._pad_dim(lstm_output, 1, max(f.kernel_size[0] for f in self.global_conv_list))
        conv_res = []
        for f in self.global_conv_list:
            convolved_res = f(lstm_output.reshape(-1, 1, lstm_output.shape[1], lstm_output.shape[2]))
            conv_res.append(convolved_res.max(dim=2)[0].squeeze(2))
        conv_res = torch.cat(conv_res, dim=1)

        output = self.output_head(conv_res)

        return conv_res, output

class REModel(nn.Module):
    """Two-branch relation classifier.

    ``head1`` reads the sentence in surface order without deprel input. ``head2`` reads the second
    group of inputs, which in this notebook are the tokens on the dependency path between the two
    entities, with deprels. ``forward`` returns ``(head1 logits, head2 logits, joint logits)``;
    the joint logits come from a linear layer over both branches' pooled features.

    NOTE(review): the default n_classes=17 equals the NER tag count; pass ``len(re2idx)`` for RE.
    """

    def __init__(self, n_classes=17, n_chars=179, n_pos_tags=19, n_deprels=38):
        super().__init__()

        self.head1 = BasicModel(n_classes, n_chars, n_pos_tags, None)
        self.head2 = BasicModel(n_classes, n_chars, n_pos_tags, n_deprels)
        # Each branch contributes 400 pooled features.
        self.final_output = nn.Linear(2 * 400, n_classes)

    def forward(self, orig_elmo_emb, orig_chars, orig_pos,
                syntax_elmo_emb, syntax_chars, syntax_pos, syntax_deprels):
        surface_features, surface_logits = self.head1(orig_elmo_emb, orig_chars, orig_pos, None)
        path_features, path_logits = self.head2(syntax_elmo_emb, syntax_chars, syntax_pos, syntax_deprels)
        joint_logits = self.final_output(torch.cat((surface_features, path_features), dim=1))
        return surface_logits, path_logits, joint_logits

# Size every vocabulary-dependent layer from the training data. REModel's default n_classes=17 is the
# NER tag count, while the relation label set (re2idx) is smaller.
model2 = REModel(
    n_classes=len(task12_train_ds.re2idx),
    n_chars=len(task12_train_ds.char2idx),
    n_pos_tags=len(task12_train_ds.pos2idx),
    n_deprels=len(task12_train_ds.deprel2idx),
).to(device)

In [None]:
def re_col(batch):
    """Collate RE items ``((orig_fields, syntax_fields), label)`` for a DataLoader.

    Returns ``(pad_seq(orig_fields), pad_seq(syntax_fields), labels)`` where ``labels`` is a
    1-D tensor of relation ids.
    """
    orig_sents = [item[0][0] for item in batch]
    syntax_sents = [item[0][1] for item in batch]
    relation_ids = [item[1] for item in batch]
    return pad_seq(orig_sents), pad_seq(syntax_sents), torch.tensor(relation_ids)

In [None]:
# RE batch sizes get their own names. Reassigning `batch_size` here used to overwrite the NER value,
# which is needed to index NER predictions, if cells were re-run out of order.
re_batch_size = 96
re_test_batch_size = 150

train_iter_re = DataLoader(task12_train_ds, re_batch_size, collate_fn=re_col, shuffle=True, pin_memory=True, num_workers=2)
# Not shuffled: predictions must stay aligned with task12_test_ds.re_sentences.
test_iter_re = DataLoader(task12_test_ds, re_test_batch_size, collate_fn=re_col)

In [None]:
# NOTE(review): NONE dominates the training candidates (see the class counts above); class weights
# or resampling may be worth trying.
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model2.parameters(), lr=1e-3)
losses, n_iter = [], 0

print(sum(param.numel() for param in model2.parameters()))

13156283


In [None]:
# Train the two-branch RE model: auxiliary losses on each branch plus the main loss on the joint head.
model2.train()
for epoch in range(10):
    for orig_sent, syntax_sent, labels in tqdm(train_iter_re):
        orig_sent = [tensor.to(device) for tensor in orig_sent]
        syntax_sent = [tensor.to(device) for tensor in syntax_sent]
        labels = labels.to(device)

        optimizer.zero_grad()
        y_pred1, y_pred2, main_pred = model2(*orig_sent, *syntax_sent)

        loss = (
            0.25 * loss_fn(y_pred1, labels) +
            0.25 * loss_fn(y_pred2, labels) +
            0.5 * loss_fn(main_pred, labels)
        )

        loss.backward()
        optimizer.step()

        losses.append(loss.item())
        n_iter += 1
        if n_iter % 150 == 0:
            print(n_iter, np.mean(losses[-150:]))

    # Checkpoint after every epoch; `with` guarantees the files are flushed and closed.
    with open(global_repo_path + 'model2.pkl', 'wb') as model_file:
        cloudpickle.dump(model2, model_file)
    with open(global_repo_path + 'optimizer2.pkl', 'wb') as optimizer_file:
        cloudpickle.dump(optimizer, optimizer_file)

In [None]:
# Load the RE model checkpoint written by the training loop.
# NOTE: unpickling can execute arbitrary code — only load checkpoints you created yourself.
with open(global_repo_path + 'model2.pkl', 'rb') as model_file:
    model2 = cloudpickle.load(model_file)
# Re-pack the LSTM weights contiguously after unpickling so cuDNN can use its fast path.
model2.head1.lstm.flatten_parameters()
model2.head2.lstm.flatten_parameters()

In [None]:
# Predict a relation id for every test candidate, in re_sentences order (test_iter_re is not shuffled).
# The logits of the three heads are mixed with the same 0.25/0.25/0.5 weights as the training loss.
final_predicts = []
model2.eval()
with torch.no_grad():
    for orig_sent, syntax_sent, _ in tqdm(test_iter_re):
        orig_sent = [tensor.to(device) for tensor in orig_sent]
        syntax_sent = [tensor.to(device) for tensor in syntax_sent]

        y_pred1, y_pred2, final_pred = model2(*orig_sent, *syntax_sent)
        final_pred = 0.25 * y_pred1 + 0.25 * y_pred2 + 0.5 * final_pred

        final_predicts.extend(final_pred.argmax(dim=1).tolist())

HBox(children=(FloatProgress(value=0.0, max=35.0), HTML(value='')))




In [None]:
# Make the RuREBus evaluation scripts importable (also needed when only the RE part is run).
# The guard keeps sys.path from growing every time this cell is re-run.
lib_path = os.path.abspath(os.path.join(global_repo_path, 'eval_scripts'))
if lib_path not in sys.path:
    sys.path.append(lib_path)
import evaluate_ners, evaluate_rels, brat_format

In [None]:
# Inverse of the relation vocabulary: class id -> relation type ('NONE' means no relation).
idx2re = {idx: re_type for re_type, idx in task12_test_ds.re2idx.items()}

In [None]:
# Write gold entities plus predicted relations of every test document as .ann files.
pred_dir = global_repo_path + 'test_data/pred_parts_re/'
os.makedirs(pred_dir, exist_ok=True)

# Predictions are aligned with re_sentences; this also catches a stale `final_predicts` (e.g. from NER).
assert len(final_predicts) == len(task12_test_ds.re_sentences), 'final_predicts must hold the RE predictions'

# Group non-NONE predictions by document once, keeping candidate order,
# instead of rescanning all candidates for every document.
relations_by_doc = {}
for sent, pred in zip(task12_test_ds.re_sentences, final_predicts):
    re_type = idx2re[pred]
    if re_type != 'NONE':
        relations_by_doc.setdefault(sent.doc_idx, []).append((re_type, sent.ent1_idx, sent.ent2_idx))

for sample_idx, sample in enumerate(tqdm(task12_test_ds.samples)):
    brat_doc = brat_format.BratDoc(sample.raw_text)

    # Each gold entity that starts exactly at a token is written back and keyed by that token
    # (first T-id wins when several entities share a start).
    start2T = {}
    for T_name, (_, start, _) in sample.ner_spans.items():
        start2T.setdefault(start, T_name)
    token_idx2T = {}
    for token_idx, token in enumerate(sample.tokens):
        T_name = start2T.get(token.start)
        if T_name is not None:
            entity, start, stop = sample.ner_spans[T_name]
            brat_doc.add_ner(T_name[1:], entity, start, stop)
            token_idx2T[token_idx] = T_name

    for n_rels, (re_type, ent1_idx, ent2_idx) in enumerate(relations_by_doc.get(sample_idx, []), start=1):
        brat_doc.add_relation(
            n_rels,
            re_type,
            token_idx2T[ent1_idx][1:],
            token_idx2T[ent2_idx][1:]
        )

    brat_doc.write_to_file(pred_dir + f'{sample.file_name}.ann')

HBox(children=(FloatProgress(value=0.0, max=30.0), HTML(value='')))




In [None]:
# Relation F1 (RuREBus evaluation script): predicted .ann files vs. the gold test annotations.
pred_dir = global_repo_path + 'test_data/pred_parts_re/'
gold_dir = global_repo_path + 'test_data/test_parts/'
evaluate_rels.calc_rels_f1(pred_dir, gold_dir)

0.3216783216783217