In [1]:
import re
import json
import random
import pandas as pd
import numpy as np
from tqdm.auto import tqdm
from itertools import combinations, permutations

# to split text into sentences
from nltk.tokenize import sent_tokenize

# LUKE model
from transformers import LukeTokenizer, LukeForEntityPairClassification, LukeForEntitySpanClassification

  from .autonotebook import tqdm as notebook_tqdm


In [10]:
path_to_data = '../benchmark_data/'
path_to_models = 'models/'

# Data

In [3]:
results = pd.DataFrame(columns=['model_name', 'macro_precision', 'macro_recall', 'macro_f1', 'p_works_at', 'p_partners_with', 'p_acquired_by', 'r_works_at', 'r_partners_with', 'r_acquired_by', 'f1_works_at', 'f1_partners_with', 'f1_acquired_by'])

In [4]:
class NoSentenceError(Exception):
    pass

def get_sentence(text, span):
    """Returns sentence for the given span"""

    sentences = sent_tokenize(text)
    for sentence in sentences:
        start_idx = text.find(sentence)
        end_idx = start_idx + len(sentence)
        if span[0] >= start_idx and span[1] <= end_idx:
            return sentence, start_idx
    
    raise NoSentenceError("entities are not in one sentence")

def adjust_span(span, offset):
    """Adjusts span"""

    return (span[0] - offset, span[1] - offset)

def transform_json_re(path, additional_rows=100):
    """Transfomrs json file into dataframe with entity relations, adds also rows with no relation"""

    with open(path) as f:
        data = json.load(f)

    texts = []
    entity_pairs = []
    entity_spans_pairs = []
    relation_labels = []

    total_additional_rows = 0

    for item in data:
        item_entities = {}
        text = item['data']['text']
        annotations = item['annotations'][0]['result']

        related_pairs = set()
        for annotation in annotations:

            # first create list of entities in this text
            if annotation['type'] == 'labels':
                entity_span = (annotation['value']['start'], annotation['value']['end'])
                entity_text = annotation['value']['text']
                entity_id = annotation['id']
                entity_label = annotation['value']['labels'][0]
                item_entities[entity_id] = {'text': entity_text, 'span': entity_span, 'label': entity_label}

            # create entity pairs according to annotations
            elif annotation['type'] == 'relation':
                from_entity = item_entities[annotation['from_id']]
                to_entity = item_entities[annotation['to_id']]
                try:
                    label = annotation['labels'][0]
                except KeyError:
                    # if the labels is missing skip annotation
                    continue

                from_span = from_entity['span']
                to_span = to_entity['span']

                try: 
                    if from_span[0] < to_span[0]:
                        sentence, offset = get_sentence(text, (from_span[0], to_span[1]))
                    else:
                        sentence, offset = get_sentence(text, (to_span[0], from_span[1]))
                except NoSentenceError as e:
                    print(e)
                    continue

                # adjust span of the entities according to the sentence
                from_span_adjusted = adjust_span(from_span, offset)
                to_span_adjusted = adjust_span(to_span, offset)

                # add relation to the dataframe
                texts.append(sentence)
                entity_pairs.append((from_entity['text'], to_entity['text']))
                entity_spans_pairs.append((from_span_adjusted, to_span_adjusted))
                relation_labels.append(label)

                # list of entity pairs that are actually related
                related_pairs.add((from_entity['text'], to_entity['text']))

        # create all possible pairs of the entities in this text
        all_entities = list(item_entities.values())
        all_pairs = list(combinations(all_entities, 2))

        # entity pairs that does not have a relation
        non_related_pairs = [(e1, e2) for e1, e2 in all_pairs if (e1['text'], e2['text']) not in related_pairs]

        random.seed(4)
        random.shuffle(non_related_pairs)
        for non_related_pair in non_related_pairs:
            if total_additional_rows >= additional_rows:
                break

            e1, e2 = non_related_pair
            from_span = e1['span']
            to_span = e2['span']

            try: 
                if from_span[0] < to_span[0]:
                    sentence, offset = get_sentence(text, (from_span[0], to_span[1]))
                else:
                    sentence, offset = get_sentence(text, (to_span[0], from_span[1]))
            except NoSentenceError as e:
                continue

            from_span_adjusted = adjust_span(from_span, offset)
            to_span_adjusted = adjust_span(to_span, offset)

            texts.append(sentence)
            entity_pairs.append((e1['text'], e2['text']))
            entity_spans_pairs.append((from_span_adjusted, to_span_adjusted))
            relation_labels.append('NIL')

            total_additional_rows += 1

    df = pd.DataFrame({'text': texts, 'entity_pairs': entity_pairs, 'entity_spans_pairs': entity_spans_pairs, 'label': relation_labels})
    return df


In [5]:
df_re = transform_json_re(path_to_data + 'annotations.json')
test_split = pd.read_csv(path_to_data + 'train_test_split/test.csv')

test_df = pd.merge(df_re, test_split, how='inner', on='text')

test_df

entities are not in one sentence
entities are not in one sentence
entities are not in one sentence
entities are not in one sentence
entities are not in one sentence
entities are not in one sentence
entities are not in one sentence
entities are not in one sentence
entities are not in one sentence


Unnamed: 0,text,entity_pairs,entity_spans_pairs,label
0,FLX Networks Partners with GK3 Capital to Prov...,"(FLX Networks, GK3 Capital)","((0, 12), (27, 38))",partners_with
1,Informatica Expands Partnership with Google Cl...,"(Informatica, Google Cloud)","((0, 11), (37, 49))",partners_with
2,Macy's has settled its proxy fight with Arkhou...,"(Macy, Arkhouse)","((0, 4), (40, 48))",NIL
3,"Atto, a leading provider of credit risk soluti...","(Atto, Fico)","((0, 4), (129, 133))",partners_with
4,Bosch and Randox: Strategic partnership brings...,"(Bosch, Randox)","((0, 5), (10, 16))",partners_with
...,...,...,...,...
126,"TCG World, the fast-growing and immersive Web3...","(TCG World, STYNGR)","((0, 9), (126, 132))",partners_with
127,"TCG World, the fast-growing and immersive Web3...","(TCG World, Downtown)","((0, 9), (203, 211))",partners_with
128,"TCG World, the fast-growing and immersive Web3...","(TCG World, STYNGR)","((320, 329), (331, 337))",partners_with
129,"TCG World, the fast-growing and immersive Web3...","(TCG World, Downtown)","((320, 329), (343, 351))",partners_with


In [6]:
grouped_df = test_df.groupby('text').agg({
    'entity_pairs': list,
    'entity_spans_pairs': list,
    'label': list
}).reset_index()

grouped_df

Unnamed: 0,text,entity_pairs,entity_spans_pairs,label
0,"""Petrolicious' often-copied but unparalleled f...","[(Antoine Tessier, DRG)]","[((132, 147), (123, 126))]",[works_at]
1,"(marketscreener.com) April 10, 2024 Accenture ...","[(Unlimited, Accenture), (Unlimited, Accenture)]","[((55, 64), (36, 45)), ((182, 191), (159, 168))]","[acquired_by, acquired_by]"
2,"(marketscreener.com) DENVER, April 10, 2024 /P...","[(Dennis Pullin, DaVita Inc.)]","[((158, 171), (60, 71))]",[works_at]
3,"(marketscreener.com) John Marshall Bank , subs...","[(Sean Biehl, John Marshall Bank), (John Marsh...","[((126, 136), (21, 39)), ((21, 39), (126, 136)...","[works_at, NIL, NIL, NIL]"
4,(marketscreener.com) The following is a round-...,"[(Sylvanite Gold Tailings, Fulcrum Metals Cana...","[((315, 338), (233, 258))]",[acquired_by]
...,...,...,...,...
91,Unveiling its more powerful next-generation AI...,"[(Pat Gelsinger, Intel)]","[((91, 104), (81, 86))]",[works_at]
92,VYRE Network announces a groundbreaking partne...,"[(VYRE Network, Triad Entertainment Network)]","[((0, 12), (57, 84))]",[partners_with]
93,Weave has plans to undertake an extensive eigh...,"[(Weave, KKR)]","[((93, 98), (239, 242))]",[partners_with]
94,"Wednesday 10 April, 2024 Bosch and Randox Labo...","[(Bosch, Randox Laboratories Ltd.), (Bosch, Bo...","[((25, 30), (35, 59)), ((25, 30), (264, 269))]","[partners_with, NIL]"


# Models

In [11]:
# models' names
model_name_ner = "nk_LUKE_ner"
model_name_re = "nk_LUKE_re"
model_name_luke_base = "studio-ousia/luke-base"

task_entity_span = "entity_span_classification"
task_pair_class = "entity_pair_classification"

# models
model_luke_ner = LukeForEntitySpanClassification.from_pretrained(path_to_models + model_name_ner)
model_luke_re = LukeForEntityPairClassification.from_pretrained(path_to_models + model_name_re)

# tokenizer
tokenizer_ner = LukeTokenizer.from_pretrained(model_name_luke_base, task=task_entity_span)
tokenizer_re = LukeTokenizer.from_pretrained(model_name_luke_base, task=task_pair_class)

id2label = {0: 'NIL', 1: 'works_at', 2: 'partners_with', 3: 'acquired_by'}
label2id = {'NIL': 0, 'works_at': 1, 'partners_with': 2, 'acquired_by': 3}

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Evaluation

In [14]:
word_start_pattern = re.compile(r'\b\w')
word_end_pattern = re.compile(r'\w\b')

def luke_ner_generate_entity_spans(text):

    model_entity_spans = []

    # word start and end positions to calculate spans 
    word_start_positions = [match.start() for match in word_start_pattern.finditer(text)]
    word_end_positions = [match.end() for match in word_end_pattern.finditer(text)] 

    # all possible entity spans
    # we consider only entity spans that are not longer then 6 words
    for i, start_pos in enumerate(word_start_positions):
        for end_pos in word_end_positions[i:i+6]:
            model_entity_spans.append((start_pos, end_pos))

    model_entity_spans.sort(key=lambda x: x[0])

    return model_entity_spans


def evaluation_luke_pipeline(name, ner_model, re_model, tokenizer_ner, tokenizer_re, test_df):
    ner_model.eval()
    re_model.eval()

    class_labels = [1, 2, 3]

    TP = {1:0, 2:0, 3:0}
    FP = {1:0, 2:0, 3:0}
    FN = {1:0, 2:0, 3:0}

    precisions = []
    recalls = []
    f1_scores = []

    for row in test_df.itertuples():#tqdm(test_df.itertuples()):
        text = row.text
        true_pairs = row.entity_spans_pairs
        true_labels = row.label

        detected_entities = []
        
        # first perform NER to deted entities
        entity_spans = luke_ner_generate_entity_spans(text)
        ner_encoding = tokenizer_ner(text, entity_spans=entity_spans, truncation=True, return_tensors='pt')
        ner_output = ner_model(**ner_encoding)

        ner_logits = ner_output.logits

        ner_predicted_class_indices = ner_logits.argmax(-1).squeeze().tolist()
            
        for span, predicted_class_idx in zip(entity_spans, ner_predicted_class_indices):
            predicted_label = ner_model.config.id2label[predicted_class_idx]
            if predicted_label in ['ORG', 'PER']:
                detected_entities.append((span[0], span[1]))

        if len(detected_entities) == 0:
            for true_label in true_labels:
                if true_label != 'NIL':
                    FN[label2id[true_label]] = FN[label2id[true_label]] + 1
                    
            continue

        all_pairs = list(permutations(detected_entities, 2))

        """
        true_spans = set([item for sublist in true_pairs for item in sublist])
        print('Entity spans: ', entity_spans)
        print('True spans: ', true_pairs)
        print('Intersection: ', set(entity_spans).intersection(true_spans))
        print('Detected entities: ', detected_entities
        """

        # classify all possible pairs of entities
        for pair in all_pairs:
            re_encoding = tokenizer_re(text, entity_spans=list(pair), return_tensors='pt')
            re_output = re_model(**re_encoding)

            re_logits = re_output.logits
            predicted_class_idx = re_logits.argmax(-1).item()

            #print('Pred', pair, re_model.config.id2label[predicted_class_idx])

            # calculate the results
            if pair in true_pairs:
                inx = true_pairs.index(pair)
                true_class = label2id[true_labels[inx]]
                if predicted_class_idx == true_class:
                    if true_class == 0:
                        continue
                    elif true_class in class_labels:
                        TP[true_class] = TP[true_class] + 1
                else:
                    if true_class == 0:
                        FP[predicted_class_idx] = FP[predicted_class_idx] + 1
                    elif predicted_class_idx == 0:
                        FN[true_class] = FN[true_class] + 1
                    else:
                        FN[true_class] = FN[true_class] + 1
                        FP[predicted_class_idx] = FP[predicted_class_idx] + 1
            else:
                if predicted_class_idx == 0:
                    continue
                else:
                    FP[predicted_class_idx] = FP[predicted_class_idx] + 1

        
    try:
        for class_ in class_labels:
            precision_class = TP[class_] / (TP[class_] + FP[class_])
            recall_class = TP[class_] / (TP[class_] + FN[class_])

            precisions.append(precision_class)
            recalls.append(recall_class)
            f1_scores.append(2 * precision_class * recall_class / (precision_class + recall_class))

        macro_precision = sum(precisions) / len(class_labels)
        macro_recall = sum(recalls) / len(class_labels)
        macro_f1 = sum(f1_scores) / len(class_labels)

        print(name)
        print('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
        print(f"Precision macro: {macro_precision}")
        print(f"Recall macro: {macro_recall}")
        print(f"F1 Score macro: {macro_f1}")
        print('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')

    except ZeroDivisionError as e:
        print(e)
        macro_precision, macro_recall, macro_f1 = 0, 0, 0

    return [name, macro_precision, macro_recall, macro_f1] + precisions + recalls + f1_scores


In [15]:
"""
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  
Precision macro: 0.5887799564270153  
Recall macro: 0.3829365079365079  
F1 Score macro: 0.45279790660225444  
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%  
"""

result = evaluation_luke_pipeline('pipeline', model_luke_ner, model_luke_re, tokenizer_ner, tokenizer_re, grouped_df)
result

pipeline
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Precision macro: 0.5887799564270153
Recall macro: 0.3829365079365079
F1 Score macro: 0.45279790660225444
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%


['pipeline',
 0.5887799564270153,
 0.3829365079365079,
 0.45279790660225444,
 0.9166666666666666,
 0.5555555555555556,
 0.29411764705882354,
 0.4583333333333333,
 0.35714285714285715,
 0.3333333333333333,
 0.611111111111111,
 0.43478260869565216,
 0.3125]