In [2]:
from genre.hf_model import GENRE
from genre.entity_linking import get_end_to_end_prefix_allowed_tokens_fn_hf as get_prefix_allowed_tokens_fn
from genre.utils import get_entity_spans_hf as get_entity_spans
import os
import json
import torch
from multiprocessing import Pool
from tqdm import trange
import numpy as np
from rouge import Rouge

In [2]:
# Read the evaluation corpus (one JSON document per line), then keep only the
# 'aug_default' records and cap the evaluation set at the first 100 of them.
with open("/harddisk/user/keminglu/pretrained_data_processed/wikipedia_with_mention_wo_title_simplified_aug_eval/corpus_filtered_test_prompt_rephrased") as f:
    data = list(map(json.loads, f))
data = [record for record in data if record['aug_type'] == 'aug_default'][:100]

In [8]:
def inference(
        node_data,
        batch_size=4,
        config="./models/hf_e2e_entity_linking_aidayago",
        device="cuda:4",
    ):
    """Run GENRE end-to-end entity linking over ``node_data`` in mini-batches.

    Args:
        node_data: Sliceable sequence of dicts, each holding the text to link
            under the ``'inputs'`` key. Each dict is updated IN PLACE with an
            ``'outputs'`` key (the returned list holds the same objects).
        batch_size: Number of samples passed to ``get_entity_spans`` per call.
            Must be positive.
        config: Model path/name handed to ``GENRE.from_pretrained``.
        device: Torch device string the model is moved to. Defaults to
            ``"cuda:4"``, the value that used to be hard-coded.

    Returns:
        The sample dicts in input order. ``sample['outputs']`` is a list of
        ``(item[3], item[2].replace("_", " "))`` pairs built from every
        4-element span returned for that sample; spans of any other length are
        dropped. Presumably the pair is (mention text, entity title), which is
        how the scoring cell consumes it; confirm against ``get_entity_spans``.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # Nothing to link: skip the expensive model load entirely.
    if len(node_data) == 0:
        return []

    num_batch = int(np.ceil(len(node_data) / batch_size))

    model = GENRE.from_pretrained(config).eval()
    model = model.to(torch.device(device))

    results = []
    for i in trange(num_batch):
        batch_begin, batch_end = i * batch_size, (i + 1) * batch_size
        batch_data = node_data[batch_begin:batch_end]
        text_inputs = [each['inputs'] for each in batch_data]
        outputs = get_entity_spans(
            model,
            text_inputs
        )
        assert len(batch_data) == len(outputs)
        for sample, spans in zip(batch_data, outputs):
            try:
                processed_output = [
                    (item[3], item[2].replace("_", " "))
                    for item in spans
                    if len(item) == 4
                ]
            except IndexError:
                # Best effort: log the malformed spans and record no
                # predictions for this sample. Without the reset below the
                # previous sample's predictions would be written here (or the
                # name would be unbound on the very first sample).
                print(spans)
                processed_output = []
            sample['outputs'] = processed_output
        results.extend(batch_data)
    return results

In [9]:
# Link entities in the samples loaded above using the default batch size,
# model path and device; the progress bar below reports 25 batches.
results = inference(data)

100%|██████████| 25/25 [12:15<00:00, 29.42s/it]


In [10]:
# Persist the predictions as JSON Lines: one serialized sample per line.
with open("genre_output.json", "w") as f:
    for result in results:
        print(json.dumps(result), end="\n", file=f)

# Calculating Scores

In [3]:
def _read_jsonl(path):
    """Return the list obtained by JSON-decoding each line of ``path``."""
    with open(path) as handle:
        return [json.loads(line) for line in handle]


# Load the four prediction files and concatenate them, in order, into the
# single list that the scoring cell iterates over.
data_0 = _read_jsonl("genre_output.json")
data_1 = _read_jsonl("genre_output_1.json")
data_2 = _read_jsonl("genre_output_2.json")
data_3 = _read_jsonl("genre_output_3.json")
data = data_0 + data_1 + data_2 + data_3

In [14]:
# Score the predictions in `data` against the gold entities of each sample.
#
# Counters accumulated over all samples:
#   pos_mention*  - gold mentions that also appear as a predicted mention
#   pos_title*    - of those, how many predicted titles match the gold title
#   total*        - number of gold mentions (all / `ood` == 'in' / `ood` != 'in')
#   pred_total    - number of distinct predicted mentions
report = {
    "pos_mention": 0,
    "pos_mention_seen": 0,
    "pos_mention_unseen": 0,
    "pos_title": 0,
    "pos_title_seen": 0,
    "pos_title_unseen": 0,
    "total": 0,
    "pred_total": 0,
    "total_seen": 0,
    "total_unseen": 0,
}
# Minimum ROUGE-L F-score for a predicted title to count as matching the gold
# title in the overall (non seen/unseen) counts.
threshold = 0.99

R = Rouge()


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


def _f1(precision, recall):
    """Return the harmonic mean of precision and recall (0.0 if both are zero)."""
    return _safe_ratio(2 * precision * recall, precision + recall)


def _titles_match_rouge(predicted, gold):
    """Return True when the ROUGE-L F-score of the two titles reaches ``threshold``."""
    try:
        score = R.get_scores(predicted, gold)[0]['rouge-l']['f']
    except ValueError:
        # The rouge package rejects strings with no scorable content (e.g. an
        # empty title); fall back to plain equality instead of aborting the run.
        return predicted == gold
    return score >= threshold


def _titles_match_exact(predicted, gold):
    """Return True when the two titles are identical strings."""
    return predicted == gold


def _count_hits(gold, predicted, titles_match):
    """Return ``(mentions_found, titles_matched)`` for one sample.

    ``gold`` and ``predicted`` map mention -> title. A gold mention is found
    when it is a key of ``predicted``; its title is matched when
    ``titles_match(predicted_title, gold_title)`` is true.
    """
    mentions_found = 0
    titles_matched = 0
    for mention, gold_title in gold.items():
        if mention in predicted:
            mentions_found += 1
            if titles_match(predicted[mention], gold_title):
                titles_matched += 1
    return mentions_found, titles_matched


# NOTE(review): the overall title counts use the ROUGE-L threshold while the
# seen/unseen counts use exact string equality, so `title_recall` need not be
# consistent with the seen/unseen title recalls. Left unchanged so previously
# reported numbers stay reproducible; confirm which criterion is intended.
for sample in data:
    # Decode the gold annotation once per sample.
    entities = json.loads(sample['targets'])['entities']
    targets = {each['mention']: each['title'] for each in entities}
    unseen_targets = {each['mention']: each['title'] for each in entities if each['ood'] != 'in'}
    seen_targets = {each['mention']: each['title'] for each in entities if each['ood'] == 'in'}
    # Duplicate predicted mentions collapse to the last predicted title.
    outputs = {item[0]: item[1] for item in sample['outputs']}

    pos_mention, pos_title = _count_hits(targets, outputs, _titles_match_rouge)
    pos_mention_seen, pos_title_seen = _count_hits(seen_targets, outputs, _titles_match_exact)
    pos_mention_unseen, pos_title_unseen = _count_hits(unseen_targets, outputs, _titles_match_exact)

    report['pos_mention'] += pos_mention
    report['pos_mention_seen'] += pos_mention_seen
    report['pos_mention_unseen'] += pos_mention_unseen

    report['pos_title'] += pos_title
    report['pos_title_seen'] += pos_title_seen
    report['pos_title_unseen'] += pos_title_unseen

    report['pred_total'] += len(outputs)

    report['total'] += len(targets)
    report['total_seen'] += len(seen_targets)
    report['total_unseen'] += len(unseen_targets)

# Derived metrics. Any ratio with a zero denominator is reported as 0.0 rather
# than raising ZeroDivisionError (e.g. no predictions, or no unseen entities).
report['mention_recall'] = _safe_ratio(report['pos_mention'], report['total'])
report['mention_precision'] = _safe_ratio(report['pos_mention'], report['pred_total'])
report['mention_f1'] = _f1(report['mention_precision'], report['mention_recall'])

report['title_recall'] = _safe_ratio(report['pos_title'], report['total'])
report['title_precision'] = _safe_ratio(report['pos_title'], report['pred_total'])
report['title_f1'] = _f1(report['title_precision'], report['title_recall'])

report['seen_mention_recall'] = _safe_ratio(report['pos_mention_seen'], report['total_seen'])
report['seen_title_recall'] = _safe_ratio(report['pos_title_seen'], report['total_seen'])
report['unseen_mention_recall'] = _safe_ratio(report['pos_mention_unseen'], report['total_unseen'])
report['unseen_title_recall'] = _safe_ratio(report['pos_title_unseen'], report['total_unseen'])

In [15]:
# Show the title recall restricted to gold entities whose `ood` field is 'in'.
report['seen_title_recall']

0.247637608293346