In [None]:
import requests
import re
import csv
import json
from tqdm import tqdm

In [None]:
import pickle
from genre.trie import Trie, MarisaTrie
from genre.fairseq_model import mGENRE

In [None]:
# Load the mGENRE resources: the (lang, title) -> Wikidata-ID lookup table,
# the marisa trie of valid target titles used to constrain decoding, and the
# pretrained multilingual entity-disambiguation model (put in eval mode).
# NOTE(review): pickle.load can execute arbitrary code — these .pkl files must
# come from a trusted source.
def _unpickle(path):
    """Return the single object stored in the pickle file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


lang_title2wikidataID = _unpickle("data/lang_title2wikidataID-normalized_with_redirect.pkl")
trie = _unpickle("data/titles_lang_all105_marisa_trie_with_redirect.pkl")

model = mGENRE.from_pretrained(
    "models/fairseq_multilingual_entity_disambiguation"
).eval()

In [None]:
# Read the input sentences into a list of dicts (one per CSV row). Later cells
# read the "id" and "sentence" columns of each row.
# newline="" is what the csv module requires so quoted fields containing line
# breaks are parsed correctly; an explicit encoding keeps accented Italian text
# from being decoded with the platform default.
# NOTE(review): assumes the file is UTF-8 encoded — confirm.
with open("../data/sentences_it.csv", "r", newline="", encoding="utf-8") as f:
    dict_reader = csv.DictReader(f, delimiter=",")
    data = list(dict_reader)

# Fail with a clear message instead of an IndexError if the file has no rows.
assert data, "../data/sentences_it.csv contains no data rows"
print(data[0])

In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# Multilingual token-classification (NER) model whose entity spans are later
# wrapped in [START]/[END] markers and disambiguated by mGENRE.
NER_CHECKPOINT = "Babelscape/wikineural-multilingual-ner"

tokenizer = AutoTokenizer.from_pretrained(NER_CHECKPOINT)
tagger2 = AutoModelForTokenClassification.from_pretrained(NER_CHECKPOINT)
nlp = pipeline(task="ner", model=tagger2, tokenizer=tokenizer)

In [None]:
# Quick sanity check: run the NER pipeline on the first sentence only, with
# sub-word predictions merged into whole entity spans ("simple" aggregation).
sample_row = data[0]
txt = sample_row["sentence"]
ner_result = nlp(txt, aggregation_strategy="simple")
print(ner_result)

In [None]:
def _mark_mention(text: str, start: int, end: int) -> str:
    """Return ``text`` with ``text[start:end]`` wrapped in mGENRE mention markers.

    Spacing: a space before and after ``[START]`` and before ``[END]``; the
    rest of the sentence follows ``[END]`` directly.
    """
    return text[:start] + " [START] " + text[start:end] + " [END]" + text[end:]


def _allowed_tokens(batch_id, sent):
    """Constrain decoding to token ids the title trie allows after prefix ``sent``.

    Ids outside the model's target dictionary are dropped.
    """
    return [e for e in trie.get(sent.tolist()) if e < len(model.task.target_dictionary)]


def _title_to_wikidata_id(title: str) -> str:
    """Map a generated ``"Title >> lang"`` string to one Wikidata ID.

    The string is split on ``" >> "`` and reversed into a ``(lang, title)`` key.
    When that key lists several IDs, the one with the largest numeric part
    (``int(id[1:])``) is kept.
    """
    return max(
        lang_title2wikidataID[tuple(reversed(title.split(" >> ")))],
        key=lambda qid: int(qid[1:]),
    )


# One output row per NER entity: its character span, NER type, and the
# mGENRE-linked title, Wikidata ID and score.
output = []

for item in tqdm(data, desc="linking"):
    text = item["sentence"]
    ner_result = nlp(text, aggregation_strategy="simple")
    if not ner_result:
        # No entities in this sentence: skip the model call on an empty batch.
        continue

    sentences = [_mark_mention(text, ent["start"], ent["end"]) for ent in ner_result]
    results = model.sample(
        sentences,
        prefix_allowed_tokens_fn=_allowed_tokens,
        text_to_id=_title_to_wikidata_id,
        marginalize=True,
    )

    # model.sample returns one candidate list per input sentence, in order.
    for ent, candidates in zip(ner_result, results):
        # Take the first candidate (presumably the best-scoring one — assumes
        # mGENRE returns candidates best-first; confirm).
        best = candidates[0]
        output.append(
            {
                "id": item["id"],
                "start_pos": ent["start"],
                "end_pos": ent["end"],
                "type": ent["entity_group"],
                "alias": best["texts"][0],
                "wb_id": best["id"],
                "score": best["score"].item(),
            }
        )

In [None]:
# Write one CSV row per linked mention to output_nel.csv.
# The column order follows the dicts built by the linking loop. The explicit
# list is only a fallback so an empty run still writes a header instead of
# failing with an IndexError on output[0].
NEL_FIELDS = ["id", "start_pos", "end_pos", "type", "alias", "wb_id", "score"]
keys = output[0].keys() if output else NEL_FIELDS

# `with` guarantees the file is closed even if writing fails.
# newline="" is required by the csv module (otherwise blank rows appear on
# Windows); utf-8 ensures non-ASCII titles/aliases can always be encoded.
with open("output_nel.csv", "w", newline="", encoding="utf-8") as a_file:
    dict_writer = csv.DictWriter(a_file, keys)
    dict_writer.writeheader()
    dict_writer.writerows(output)