In [3]:
%%capture
!pip install transformers
!pip install ujson

In [4]:
%%capture
from transformers import pipeline
import logging
import os
import ujson
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import requests
from bs4 import BeautifulSoup

In [35]:
def get_entity_name(entity_id):
    # URL per ottenere i dettagli dell'entità
    entity_url = f"https://www.wikidata.org/wiki/Special:EntityData/{entity_id}.json"

    response = requests.get(entity_url)
    if response.status_code != 200:
        raise Exception(f"Errore durante l'accesso ai dettagli dell'entità: {response.status_code}")

    entity_data = response.json()

    # Estrarre i dettagli dell'entità
    entity_labels = entity_data["entities"][next(iter(entity_data["entities"]))]["labels"]
    # Ottenere il nome dell'entità in inglese
    entity_name = entity_labels.get("en", {}).get("value")

    return entity_name

In [36]:
def get_entity_id(entity_name):
    # URL di ricerca sul sito Wikidata
    search_url = f"https://www.wikidata.org/w/index.php"
    params = {
        "search": entity_name,
        "title": "Special:Search",
        "profile": "advanced",  # Usa il profilo avanzato per i risultati più accurati
        "fulltext": 1,
        "ns0": 1,  # Cerca solo nello spazio principale (entità)
    }

    # Richiesta al sito Wikidata
    response = requests.get(search_url, params=params)
    if response.status_code != 200:
        raise Exception(f"Errore durante l'accesso alla pagina di ricerca: {response.status_code}")

    # Analisi del contenuto HTML
    soup = BeautifulSoup(response.text, "html.parser")

    # Trova il primo risultato
    first_result = soup.find("div", class_="mw-search-result-heading")
    if first_result:
        # Estrai il link associato al risultato
        link = first_result.find("a")["href"]
        # Estrai il codice dell'entità dal link
        entity_id = link.split("/")[-1]

        # URL per ottenere i dettagli dell'entità
        entity_url = f"https://www.wikidata.org/wiki/Special:EntityData/{entity_id}.json"

        response = requests.get(entity_url)
        if response.status_code != 200:
            raise Exception(f"Errore durante l'accesso ai dettagli dell'entità: {response.status_code}")

        return list(response.json()['entities'].keys())[0]

    else:
        # print("Nessun risultato trovato.")
        return None

In [37]:
def clean_entities(entities):
    results = []
    for entity in entities:
        # Clean up extra spaces introduced during tokenization
        cleaned_word = entity['word'].replace(" ' ", "' ")
        results.append(cleaned_word)
    return results

In [38]:
def find_entity(sentence):
  entities = ner_pipeline(sentence)
  cleaned_words = clean_entities(entities)
  return combine_words(cleaned_words)

In [10]:
def combine_words(words):
    combined = []
    temp = ""

    for word in words:
        # Se il token inizia con '##', è una parte di una parola più lunga
        if word.startswith('##'):
            temp += word[2:]  # Aggiungi solo la parte dopo '##'
        else:
            if temp:  # Se c'era una parte precedentemente accumulata, aggiungila
                combined.append(temp)
            temp = word  # Inizia una nuova parola
    if temp:
        combined.append(temp)  # Aggiungi l'ultima parte accumulata

    return " ".join(combined)

In [45]:
def process_line(line):
    data = ujson.loads(line)
    sentence = data.get('source')
    entities_id = data.get('wikidata_id')

    tot = 1
    generated = 0
    corrects = 0

    true_entity = get_entity_name(entities_id)
    entitiy_obtained = find_entity(sentence)

    true_id = get_entity_id(true_entity)
    id_obtained = get_entity_id(entitiy_obtained)

    if id_obtained:
      generated = 1
      if id_obtained == true_id:
        corrects = 1

    # print(tot,corrects, generated)
    return tot, corrects, generated

In [40]:
def evaluate_NER(file_path,n_lines):
  with open(file_path, 'r') as file:
    limited_lines = (line for i, line in enumerate(file) if i < n_lines)
    with ThreadPoolExecutor() as executor:
        results = list(tqdm(executor.map(process_line, limited_lines), desc="\tProcessing lines"))

    tot = sum(res[0] for res in results)
    corrects = sum(res[1] for res in results)
    generated = sum(res[2] for res in results)

    return tot,corrects,generated

In [46]:
folder_path = "semeval-data/references/validation"
results= dict()
ner_names = [
    "dslim/bert-base-NER",
    "dslim/bert-large-NER",
    "Jean-Baptiste/camembert-ner"
]

for ner_name in ner_names:
  print("=====================================================")
  print('NER:',ner_name)
  print("=====================================================")
  tot,corrects,generated= 0,0,0
  for file_name in os.listdir(folder_path):

      file_path = os.path.join(folder_path, file_name)
      if os.path.isfile(file_path):
          print('\tanalyzing',file_name)
          # Set logging level to suppress the warning
          logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
          ner_pipeline = pipeline(model=ner_name, aggregation_strategy='simple', device=0)
          res = evaluate_NER(file_path, n_lines=50)
          tot += res[0]
          corrects += res[1]
          generated += res[2]

  precision = corrects / generated if generated > 0 else 0
  recall = corrects / tot if tot > 0 else 0
  f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
  results[ner_name]= {'\tprecision':precision, 'recall': recall, 'f1_score': f1_score}

NER: dslim/bert-base-NER
	analyzing zh_TW.jsonl


	Processing lines: 50it [00:24,  2.05it/s]


	analyzing ar_AE.jsonl


	Processing lines: 50it [00:27,  1.82it/s]


	analyzing th_TH.jsonl


	Processing lines: 50it [00:28,  1.78it/s]


	analyzing es_ES.jsonl


	Processing lines: 50it [00:23,  2.11it/s]


	analyzing de_DE.jsonl


	Processing lines: 50it [00:27,  1.82it/s]


	analyzing tr_TR.jsonl


	Processing lines: 50it [00:25,  1.92it/s]


	analyzing it_IT.jsonl


	Processing lines: 50it [00:25,  1.98it/s]


	analyzing fr_FR.jsonl


	Processing lines: 50it [00:24,  2.02it/s]


	analyzing ko_KR.jsonl


	Processing lines: 50it [00:28,  1.75it/s]


	analyzing ja_JP.jsonl


	Processing lines: 50it [00:29,  1.70it/s]


NER: dslim/bert-large-NER
	analyzing zh_TW.jsonl


	Processing lines: 50it [00:27,  1.79it/s]


	analyzing ar_AE.jsonl


	Processing lines: 50it [00:30,  1.61it/s]


	analyzing th_TH.jsonl


	Processing lines: 50it [00:28,  1.75it/s]


	analyzing es_ES.jsonl


	Processing lines: 50it [00:23,  2.15it/s]


	analyzing de_DE.jsonl


	Processing lines: 50it [00:26,  1.90it/s]


	analyzing tr_TR.jsonl


	Processing lines: 50it [00:25,  1.97it/s]


	analyzing it_IT.jsonl


	Processing lines: 50it [00:25,  1.98it/s]


	analyzing fr_FR.jsonl


	Processing lines: 50it [00:24,  2.01it/s]


	analyzing ko_KR.jsonl


	Processing lines: 50it [00:25,  1.99it/s]


	analyzing ja_JP.jsonl


	Processing lines: 50it [00:30,  1.62it/s]


NER: Jean-Baptiste/camembert-ner
	analyzing zh_TW.jsonl


	Processing lines: 50it [00:34,  1.46it/s]


	analyzing ar_AE.jsonl


	Processing lines: 50it [00:32,  1.53it/s]


	analyzing th_TH.jsonl


	Processing lines: 50it [00:30,  1.62it/s]


	analyzing es_ES.jsonl


	Processing lines: 50it [00:27,  1.80it/s]


	analyzing de_DE.jsonl


	Processing lines: 50it [00:29,  1.67it/s]


	analyzing tr_TR.jsonl


	Processing lines: 50it [00:29,  1.71it/s]


	analyzing it_IT.jsonl


	Processing lines: 50it [00:25,  1.96it/s]


	analyzing fr_FR.jsonl


	Processing lines: 50it [00:25,  1.93it/s]


	analyzing ko_KR.jsonl


	Processing lines: 50it [00:26,  1.86it/s]


	analyzing ja_JP.jsonl


	Processing lines: 50it [00:27,  1.83it/s]


In [47]:
for result in results.items():
  print(result[0])
  print(result[1])
  print()

dslim/bert-base-NER
{'\tprecision': 0.7572815533980582, 'recall': 0.624, 'f1_score': 0.6842105263157895}

dslim/bert-large-NER
{'\tprecision': 0.7598039215686274, 'recall': 0.62, 'f1_score': 0.6828193832599119}

Jean-Baptiste/camembert-ner
{'\tprecision': 0.834070796460177, 'recall': 0.754, 'f1_score': 0.792016806722689}

