In [1]:
import re
import csv
import json
from tqdm import tqdm
from nltk.tokenize import sent_tokenize
import pickle
from genre.trie import Trie, MarisaTrie
from genre.fairseq_model import mGENRE

# Load the resources needed for entity disambiguation.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load these pickles from a trusted source.

# Lookup table used below to turn a generated "title >> lang" string into
# Wikidata IDs (it is indexed with a (lang, title) tuple in `text_to_id`).
with open("../GENRE/data/lang_title2wikidataID-normalized_with_redirect.pkl", "rb") as f:
    lang_title2wikidataID = pickle.load(f)

# Prefix trie queried during generation to restrict which tokens may follow
# the tokens produced so far.
with open("../GENRE/data/titles_lang_all105_marisa_trie_with_redirect.pkl", "rb") as f2:
    trie = pickle.load(f2)
    
# Pretrained mGENRE model; `.eval()` presumably switches it to inference
# mode -- confirm against the GENRE/fairseq documentation.
model = mGENRE.from_pretrained("../GENRE/models/fairseq_multilingual_entity_disambiguation").eval()

# Page texts: one row per page; the loop below reads the "pg_num" and "text"
# columns.
with open("data/pages_vol_1.csv", "r", encoding="utf-8") as f:
    dict_reader = csv.DictReader(f, delimiter=",")
    data = list(dict_reader)

# NER spans: one row per detected mention; the loop below reads the
# "page_number", "start_pos" and "end_pos" columns.
# Decode explicitly as UTF-8, like the page file above, so the result does
# not depend on the platform's default encoding.
with open("data/pages_names.csv", "r", encoding="utf-8") as f:
    dict_reader = csv.DictReader(f, delimiter=",")
    ner_output = list(dict_reader)

# Disambiguate every NER mention on every page with mGENRE.
#
# For each page, every mention is wrapped in " [START] ... [END]" markers and
# surrounded by up to CONTEXT_CHARS characters of page text on each side, the
# model is sampled once per page, and one row per mention is appended to
# `output` with the candidates and their scores joined by "/ ".

# Number of characters of page text kept before and after a mention.
CONTEXT_CHARS = 500

# Index the NER rows by page once, so that each page does not have to rescan
# the complete NER list. Row order within a page is preserved.
ner_rows_by_page = {}
for ner_row in ner_output:
    ner_rows_by_page.setdefault(ner_row["page_number"], []).append(ner_row)

# The callbacks handed to the model do not depend on the page, so they are
# built once instead of on every iteration.
vocab_size = len(model.task.target_dictionary)


def _allowed_tokens(batch_id, sent):
    # Continuations proposed by the trie for the tokens generated so far,
    # restricted to ids that exist in the target dictionary.
    return [e for e in trie.get(sent.tolist()) if e < vocab_size]


def _text_to_id(x):
    # "title >> lang" -> (lang, title) lookup key; among the IDs stored for
    # that key, keep the one with the largest numeric part.
    key = tuple(reversed(x.split(" >> ")))
    return max(lang_title2wikidataID[key], key=lambda y: int(y[1:]))


output = []

# Wrapping the iterable lets tqdm advance (and close) the bar by itself, even
# for pages that are skipped.
pbar = tqdm(data)
for item in pbar:
    text = item["text"]
    data_id = item["pg_num"]

    spans = [
        (int(ner_row["start_pos"]), int(ner_row["end_pos"]))
        for ner_row in ner_rows_by_page.get(data_id, ())
    ]
    if not spans:
        # No mentions on this page: there is nothing to send to the model and
        # no rows to emit.
        continue

    sentences = []
    for start_pos, end_pos in spans:
        # Clamp the context window to the page boundaries.
        history_start = max(start_pos - CONTEXT_CHARS, 0)
        future_end = min(end_pos + CONTEXT_CHARS, len(text))
        sentences.append(
            text[history_start:start_pos]
            + " [START] "
            + text[start_pos:end_pos]
            + " [END]"
            + text[end_pos:future_end]
        )

    results = model.sample(
        sentences,
        prefix_allowed_tokens_fn=_allowed_tokens,
        text_to_id=_text_to_id,
        marginalize=True,
    )

    # `results` is paired positionally with the spans that produced the
    # input sentences.
    for (start_pos, end_pos), result in zip(spans, results):
        candidate_ids = [
            candidate["texts"][0] + " (" + candidate["id"] + ")"
            for candidate in result
        ]
        candidate_scores = [str(candidate["score"].item()) for candidate in result]
        output.append(
            {
                "id": data_id,
                "start_pos": start_pos,
                "end_pos": end_pos,
                "candidates": "/ ".join(candidate_ids),
                "scores": "/ ".join(candidate_scores),
            }
        )


# Write one CSV row per disambiguated mention.
# The header is taken from the first result row; when there are no results,
# fall back to the known column names so that a header-only file is written
# instead of failing with an IndexError.
if output:
    keys = list(output[0].keys())
else:
    keys = ["id", "start_pos", "end_pos", "candidates", "scores"]

# newline="" is what the csv module expects for writer file objects (it
# prevents extra blank rows on platforms that translate line endings), and an
# explicit UTF-8 encoding keeps non-ASCII candidate titles writable whatever
# the platform default is. The `with` block closes the file even if writing
# raises.
with open("data/disambiguated_names.csv", "w", newline="", encoding="utf-8") as a_file:
    dict_writer = csv.DictWriter(a_file, keys)
    dict_writer.writeheader()
    dict_writer.writerows(output)


2023-05-23 17:09:28 | INFO | fairseq.tasks.text_to_speech | Please install tensorboardX: pip install tensorboardX
2023-05-23 17:12:35 | INFO | fairseq.file_utils | loading archive file ../GENRE/models/fairseq_multilingual_entity_disambiguation
2023-05-23 17:12:36 | INFO | fairseq.tasks.translation | [source] dictionary: 256001 types
2023-05-23 17:12:36 | INFO | fairseq.tasks.translation | [target] dictionary: 256001 types
100%|█████████████████████████████████████████████████████████████████████████████████| 221/221 [31:27<00:00,  8.85s/it]