In [1]:
!pip install transformers
!pip install mendelai-brat-parser
!pip install smart-open
!pip install mendelai-brat-parser

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.25.1-py3-none-any.whl (5.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.8/5.8 MB[0m [31m73.0 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m46.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.11.1-py3-none-any.whl (182 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m182.4/182.4 KB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.11.1 tokenizers-0.13.2 transformers-4.25.1
Looking in indexes: https://pypi.org/simple, http

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import pyarrow.parquet as pq
from brat_parser import get_entities_relations_attributes_groups
import glob
import torch
import torch.nn as nn
import torch.utils.data as torch_data
import torch.nn.functional as F
from dataclasses import dataclass
from tqdm import tqdm
import os
from transformers import pipeline, AutoTokenizer, AutoModelForTokenClassification
import pickle

In [3]:
class ReClassifier(nn.Module):
    def __init__(self, input_dim, num_classes):
        super().__init__()
        self.relation_classifier = nn.Sequential(
            nn.Linear(3 * input_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.Dropout(0.3),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )
        self.has_relation = nn.Sequential(
            nn.Linear(3 * input_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.Dropout(0.3),
            nn.Linear(2048, 1024),
            nn.BatchNorm1d(1024),
            nn.Dropout(0.3),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.Dropout(0.3),
            nn.Linear(512, 1)
        )
    
    def forward(self, cls_emb, subj_emb, obj_emb):
        x = torch.cat([cls_emb, subj_emb, obj_emb], dim=1)
        relation_logits = self.relation_classifier(x)
        has_relation    = self.has_relation(x)
        return relation_logits, has_relation

In [4]:
idx2label = {0: 'PPS',1: 'GOL',2: 'TSK',3: 'NPS',4: 'FNT',5: 'PNT',6: 'FNG',7: 'NNT',8: 'FPS',9: 'NNG',10: 'PNG',11: 'NO_RE'}
label2idx = {v: k for k, v in idx2label.items()}

In [5]:
@dataclass
class ReEntity:
    token: str
    embeddings: list

def infer_and_merge_embeddings(tokenizer, model, dataset, disable_tqdm=False):
    model.eval()
    def merge_tokens(tokens, embeddings=None):
        assert embeddings is None or len(tokens.tokens()) == len(embeddings)
        result = []
        for i, token in enumerate(tokens.tokens()):
            if token.startswith('##'):
                result[-1].token += token.lstrip('##')
                if embeddings is not None:
                    result[-1].embeddings.append(embeddings[i].unsqueeze(0))
            else:
                result.append(ReEntity(token, []))
                if embeddings is not None:
                    result[-1].embeddings.append(embeddings[i].unsqueeze(0))
        return tuple(zip(*list(map(lambda entity: (entity.token, torch.cat(entity.embeddings).mean(dim=0) if embeddings is not None else -1), result))))
    
    def get_sub_idx(find_in, to_find):
        l1, l2 = len(find_in), len(to_find)
        for i in range(l1):
            if find_in[i:i+l2] == to_find:
                return i
        raise Exception(f"Something went wrong. Cannot find {to_find} in {find_in}")
    result = []
    bar = dataset if disable_tqdm else tqdm(dataset)
    for text, subj, obj, is_subj_first, label in bar:
        tokenized_text = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            embeddings = model(**tokenized_text, output_hidden_states=True).hidden_states[-3][0]
        tokenized_subj = tokenizer(subj)
        tokenized_obj = tokenizer(obj)
        
        merged_tokens, merged_embeddings = merge_tokens(tokenized_text, embeddings)
        merged_tokens, merged_embeddings = merged_tokens[:-1], merged_embeddings[:-1] # drop SEP
        
        merged_subj, _ = merge_tokens(tokenized_subj) 
        merged_obj, _ = merge_tokens(tokenized_obj)
        
        merged_subj = merged_subj[1:-1] # drop CLS and SEP
        merged_obj = merged_obj[1:-1] # drop CLS and SEP
        
        cls_embed = merged_embeddings[0]
        subj_start_idx = get_sub_idx(merged_tokens, merged_subj)
        obj_start_idx = get_sub_idx(merged_tokens, merged_obj)
        result.append((cls_embed, merged_subj, merged_embeddings[subj_start_idx:subj_start_idx + len(merged_subj)],
                       merged_obj, merged_embeddings[obj_start_idx:obj_start_idx + len(merged_obj)], is_subj_first, label))
    return result

In [6]:
def re_collate(batch):
    (cls_embeds, subj, subj_embeds, obj, obj_embeds, is_subj_first, label) = zip(*batch)
    clses = torch.stack(cls_embeds)
    subjects_embs = torch.stack(list(map(lambda x: torch.stack(x).mean(dim=0), subj_embeds)))
    objects_embs = torch.stack(list(map(lambda x: torch.stack(x).mean(dim=0), obj_embeds)))
    labels = torch.tensor(list(map(lambda x: label2idx[x], label)))

    return clses, subjects_embs, objects_embs, labels

In [8]:
tokenizer = AutoTokenizer.from_pretrained("surdan/LaBSE_ner_nerel")
model = AutoModelForTokenClassification.from_pretrained("surdan/LaBSE_ner_nerel")
classifier = ReClassifier(768, 11)
classifier.load_state_dict(torch.load(f'/content/drive/MyDrive/re_classifier_large_final_epochs=500_lr=1e-3_stepsize=250_gamma=0.1.model', map_location=torch.device('cpu')))

<All keys matched successfully>

In [45]:
def find_relations(classifier, bert, tokenizer, text, has_relation_treshold=0.5, visualizer=None):
    def pack_dataset(text, subj, obj):
        return [
            (text, subj, obj, True, "TSK")
        ]
    ans = {}
    classifier.eval()
    bert.eval()
    answer = []
    tokenized = tokenizer(text)
    tokens = tokenized.tokens()
    seen = set()
    for subj_idx, subj_token in enumerate(tokens):
        if subj_idx == 0 or subj_idx == len(tokens):
            continue
        for obj_idx in range(subj_idx + 1, len(tokens) - 1):
            subj_start, subj_end = tokenized.word_to_chars(tokenized.token_to_word(subj_idx))
            obj_start, obj_end   = tokenized.word_to_chars(tokenized.token_to_word(obj_idx))
            if subj_start == obj_start and subj_end == obj_end or ((subj_start, subj_end), (obj_start, obj_end)) in seen:
                continue
            seen.add(((subj_start, subj_end), (obj_start, obj_end)))
            subj, obj = text[subj_start:subj_end], text[obj_start:obj_end]
            packed = pack_dataset(text, subj, obj)
            preprocess = infer_and_merge_embeddings(tokenizer, bert, packed, disable_tqdm=True)
            clses, subjects_embs, objects_embs, labels = list(map(lambda x: x, re_collate(preprocess)))
            with torch.no_grad():
                relation_to, has_relation_to = classifier(clses, subjects_embs, objects_embs)
                relation_from, has_relation_from = classifier(clses, objects_embs, subjects_embs)
                has_relation_to_score = torch.sigmoid(has_relation_to).item()
                has_relation_from_score = torch.sigmoid(has_relation_from).item()
                if has_relation_to_score < has_relation_treshold and has_relation_from_score < has_relation_treshold:
                    continue
                is_to_relation = has_relation_to_score > has_relation_from_score
                relation = relation_to if is_to_relation else relation_from
                distribution = nn.functional.softmax(relation.squeeze(), dim=0)
                label = distribution.argmax().item()
                confidence = distribution[label].item()
                if visualizer is not None:
                    from_vert = f"{subj}_{subj_idx}" if is_to_relation else f"{obj}_{obj_idx}"
                    to_vert = f"{subj}_{subj_idx}" if not is_to_relation else f"{obj}_{obj_idx}"
                    visualizer.add_edge(from_vert, to_vert, idx2label[label])
                # print(confidence, max(has_relation_to_score, has_relation_from_score), idx2label[label], subj, obj)
                ans[subj + obj] = idx2label[label]
    if visualizer is not None:
        visualizer.visualize()
    return ans

In [14]:
find_relations(classifier, model, tokenizer, "Увеличить количество волонтеров анти  наркотической направленности")

0.5827641487121582 1.0 FPS Увеличить количество
0.9984435439109802 1.0 FPS Увеличить волонтеров
0.9879553318023682 1.0 FPS Увеличить анти
0.8079472780227661 0.9995242357254028 TSK Увеличить наркотической
0.9861086010932922 1.0 NNT Увеличить направленности
0.8610538840293884 1.0 FPS количество волонтеров
0.9998779296875 0.9999986886978149 NPS количество анти
0.9996383190155029 1.0 NPS количество направленности
0.9999880790710449 1.0 NPS волонтеров анти
1.0 0.9821747541427612 FPS волонтеров наркотической
0.9177203178405762 1.0 NNG волонтеров направленности
0.9653666019439697 0.9844523668289185 NNG наркотической направленности


In [49]:
from brat_parser import get_entities_relations_attributes_groups

entities, relations, attributes, groups = get_entities_relations_attributes_groups("/content/test.ann")
relations

{'R3': Relation(id='R3', type='NPS', subj='T6', obj='T5'),
 'R4': Relation(id='R4', type='GOL', subj='T8', obj='T9'),
 'R5': Relation(id='R5', type='GOL', subj='T11', obj='T12'),
 'R6': Relation(id='R6', type='GOL', subj='T15', obj='T16'),
 'R7': Relation(id='R7', type='GOL', subj='T19', obj='T20'),
 'R9': Relation(id='R9', type='GOL', subj='T28', obj='T29'),
 'R10': Relation(id='R10', type='GOL', subj='T30', obj='T31'),
 'R11': Relation(id='R11', type='GOL', subj='T35', obj='T36'),
 'R12': Relation(id='R12', type='GOL', subj='T38', obj='T39'),
 'R13': Relation(id='R13', type='GOL', subj='T42', obj='T43'),
 'R14': Relation(id='R14', type='GOL', subj='T45', obj='T46'),
 'R15': Relation(id='R15', type='NNG', subj='T48', obj='T47'),
 'R16': Relation(id='R16', type='NPS', subj='T53', obj='T52'),
 'R17': Relation(id='R17', type='GOL', subj='T72', obj='T73'),
 'R18': Relation(id='R18', type='GOL', subj='T75', obj='T76'),
 'R19': Relation(id='R19', type='GOL', subj='T80', obj='T81'),
 'R20': 

In [48]:
import numpy
import os

def get_f1(filename):
    seq_example = open(filename + ".txt").read()
    #calc preds
    preds = find_relations(classifier, model, tokenizer, seq_example)
    #read labels
    entities, relations, attributes, groups = get_entities_relations_attributes_groups(filename + ".ann")
    labels = {}
    for t in relations.values():
        labels[entities[t.subj].text + entities[t.obj].text] = t.type
    #calc f1
    tp = 0
    tn = 0
    fp = 0
    fn = 0
    for j in preds.keys():
        if labels.keys().__contains__(j) and labels[j] == preds[j] and labels[j] != "NO_RE":
            tp += 1
        if labels.keys().__contains__(j) and labels[j] == preds[j] and labels[j] == "NO_RE":
            tn += 1
        if labels.keys().__contains__(j) and  labels[j] != "NO_RE" and preds[j] == "NO_RE":
            fn += 1
        if labels.keys().__contains__(j) and  labels[j] == "NO_RE" and preds[j] != "NO_RE":
            fp += 1
        if not labels.keys().__contains__(j) and preds[j] == "NO_RE":
            tn += 1
        if not labels.keys().__contains__(j) and preds[j] != "NO_RE":
            fp += 1
    for j in labels.keys():
        if preds.keys().__contains__(j):
            continue
        if labels[j] == "NO_RE":
            tn += 1
        if labels[j] != "NO_RE":
            fp += 1
    recall = tp / (tp + fn + 0.00001)
    precision = tp / (tp + fp + 0.00001)
    f1 = (2 * precision * recall) / (precision + recall + 0.00001)
    return f1

f1s = list()
folder_name = "test"
for path in os.listdir(folder_name):
    if ".ann" in path:
        continue
    f1s.append(get_f1(folder_name + "/" + path[:-4]))
print(numpy.mean(f1s))

0.126836707
