Libraries

In [None]:
# Libraries
import sys
import os
import json
import litellm 
from pydantic import BaseModel
from enum import Enum
from datasets import load_dataset, load_from_disk("...")

# add path to the dataset entities
sys.path.append(os.path.abspath("../0. Helpers"))
sys.path.append(os.path.abspath("../2. Data Processing/_dataset_entities"))

from datasetProcessing import tokens_to_sentence, tokens_to_entities, join_datasets, recursive_fix
from performance import Prediction

Configurations

In [None]:
# Prepare LLM environment: export the Azure credentials/endpoint through the
# process environment (presumably read by litellm for "azure/..." models — verify).
# NOTE(review): the values are hard-coded in the source; consider loading them
# from a secrets store or an untracked .env file instead.
os.environ.update(
    {
        "AZURE_API_KEY": "...",
        "AZURE_API_BASE": "...",
    }
)

# One entity as returned by the LLM; these are the items of LLM_Entity_List.entities.
# The two fields mirror the "span" / "entity" fields requested in the prompt.
class LLM_Entity(BaseModel):
    span: str  # exact text of the entity as it appears in the input
    entity: str  # category/type assigned to that span

# Top-level structured response: handed to the completion call as `response_format`,
# and the reply is later read back through its "entities" key.
class LLM_Entity_List(BaseModel):
    entities: list[LLM_Entity]  # all entities extracted from one input text

In [None]:
# Root output directory; per-instance results are written to
# <results_folder>/<topic>/<config>/<i>.json and prompts to .../<config>/prompts/.
results_folder = "results/entity_info"

class Config:
    """A single prompt configuration: prompt language, prompt type and optional subtype.

    ``str(config)`` gives ``"<prompt_type>"`` or ``"<prompt_type>_<prompt_subtype>"``;
    that string is used as the name of the configuration's results folder.
    """

    def __init__(self, lang, prompt_type, prompt_subtype=None):
        self.lang, self.prompt_type, self.prompt_subtype = lang, prompt_type, prompt_subtype

    def __str__(self):
        # A falsy subtype (None or "") adds nothing to the name.
        if not self.prompt_subtype:
            return f"{self.prompt_type}"
        return f"{self.prompt_type + '_' + self.prompt_subtype}"
    
# Every topic is evaluated with the same five (prompt_type, prompt_subtype)
# variants, in this order; only the prompt language differs between topics.
_PROMPT_VARIANTS = (
    ("list", None),
    ("description", "general"),
    ("description", "in-context"),
    ("point", "token"),
    ("point", "span"),
)

# Prompt language per topic. Insertion order is the order the topics are iterated in.
_TOPIC_LANGUAGES = {
    "ai": "en",
    "literature": "en",
    "music": "en",
    "politics": "en",
    "science": "en",
    "multinerd_pt": "pt",
    "multinerd_en": "en",
    "ener": "en",
    "lener": "pt",
    "neuralshift": "pt",
}

# topic -> list of Config objects to run for that topic
all_configs = {
    topic: [
        Config(lang, prompt_type, prompt_subtype)
        for prompt_type, prompt_subtype in _PROMPT_VARIANTS
    ]
    for topic, lang in _TOPIC_LANGUAGES.items()
}

Prompt

In [None]:
# File name (without extension) of the example-cluster JSON for each "point" prompt subtype.
_POINT_FILES = {"span": "_point_span_4", "token": "_point_token_6"}


def _load_entity_info(topic, prompt_type, prompt_sub_type):
    """Read the entity-label text for a topic; returns "" for an unknown prompt type."""
    if prompt_type == "list":
        with open(f"entity_info/list/{topic}.txt", "r", encoding="utf-8") as f:
            return f.read()

    if prompt_type == "description":
        with open(f"entity_info/description/{prompt_sub_type}/{topic}.txt", "r", encoding="utf-8") as f:
            return f.read()

    if prompt_type == "point":
        point_file = _POINT_FILES.get(prompt_sub_type)
        if point_file is None:
            # Previously this fell through and crashed on an unbound local variable.
            raise ValueError(f"Unsupported point prompt sub type {prompt_sub_type} for topic {topic}.")

        path = f"entity_info/point_entities/{prompt_sub_type}/{topic}/train/{point_file}.json"
        with open(path, "r", encoding="utf-8") as f:
            point_dict = json.load(f)

        # One bullet per entity label followed by its example clusters.
        return "".join(
            f'- "{entity}" e.g. {", ".join(clusters)}\n'
            for entity, clusters in point_dict.items()
        )

    return ""


def get_prompt_prefix(topic, lang, prompt_type, prompt_sub_type=""):
    """Build the instruction part of the NER prompt (everything before the input sentence).

    The entity-label information is read from files under ``entity_info/``
    (relative to the current working directory), embedded in a topic-specific
    preamble, and followed by the output-format instructions in ``lang``.

    Args:
        topic: dataset/topic key ("ai", "literature", "music", "politics",
            "science", "multinerd_pt", "multinerd_en", "ener", "lener" or
            "neuralshift").
        lang: language of the closing instructions, "en" or "pt".
        prompt_type: "list", "description" or "point".
        prompt_sub_type: sub-folder for "description" prompts, or "span" /
            "token" for "point" prompts; ignored for "list".

    Returns:
        The prompt prefix as a string, ending with a newline.

    Raises:
        ValueError: if the topic, prompt type, point sub type or language is
            not supported, or if the entity information turned out empty.
        OSError: if the expected entity-info file cannot be opened.
    """
    entity_info = _load_entity_info(topic, prompt_type, prompt_sub_type)

    # Stays empty for an unknown topic so the validation below can report it.
    prompt_prefix = ""

    # Split by topic

    # AI
    if topic == "ai":
        prompt_prefix = f"""In this context, an *entity* refers to any real-world object or concept that is specifically named or referred to in the domain of artificial intelligence.
Dates, times, abstract concepts, adjectives and verbs are NOT entities.

Use the following set of possible entity labels:
{entity_info}

If an entity does not fit the types above it is considered "misc".
Be sure to prioritize more specific entities, such as "researcher" over "person", "conference" over "location" and "university" over "organisation", when it makes sense.
"""
    # LITERATURE
    elif topic == "literature":
        prompt_prefix = f"""In this context, an *entity* refers to any real-world object or concept that is specifically named or referred to in the domain of literature.
Dates, times, abstract concepts, adjectives and verbs are NOT entities.

Use the following set of possible entity labels:
{entity_info}

If an entity does not fit the types above it is considered "misc".
Be sure to prioritize more specific entities, such as "writer" over "person", when it makes sense.
"""

    # MUSIC
    elif topic == "music":
        prompt_prefix = f"""In this context, an *entity* refers to any real-world object or concept that is specifically named or referred to in the domain of music.
Dates, times, abstract concepts, adjectives and verbs are NOT entities.

Use the following set of possible entity labels:
{entity_info}

If an entity does not fit the types above it is considered "misc".
Be sure to prioritize more specific entities, such as "musical artist" over "person" and "band" over "organisation", when it makes sense.
"""

    # POLITICS
    elif topic == "politics":
        prompt_prefix = f"""In this context, an *entity* refers to any real-world object or concept that is specifically named or referred to in the domain of politics.
Dates, times, abstract concepts, adjectives and verbs are NOT entities.

Use the following set of possible entity labels:
{entity_info}

If an entity does not fit the types above it is considered "misc".
Be sure to prioritize more specific entities, such as "politician" over "person" and "political party" over "organisation", when it makes sense.
"""

    # SCIENCE
    elif topic == "science":
        prompt_prefix = f"""In this context, an *entity* refers to any real-world object or concept that is specifically named or referred to in the domain of science.
Dates, times, abstract concepts, adjectives and verbs are NOT entities.

Use the following set of possible entity labels:
{entity_info}

Abstract scientific concepts can be entities if they have a name associated with them.

If an entity does not fit the types above it is considered "misc".
Be sure to prioritize more specific entities, such as "scientist" over "person" and "university" over "organisation" or "location", when it makes sense.
"""

    # MULTINERD PT
    elif topic == "multinerd_pt":
        prompt_prefix = f"""Neste contexto, uma *entidade* refere-se a qualquer objeto ou conceito do mundo real que seja especificamente mencionado ou referido.
Datas, horas, conceitos abstratos, adjetivos e verbos NÃO são entidades.

Usa o seguinte conjunto de tipos possíveis de entidade:
{entity_info}

Se uma entidade não se encaixar em nenhum dos tipos acima, não a incluas na resposta.
"""

    # MULTINERD EN
    elif topic == "multinerd_en":
        prompt_prefix = f"""In this context, an *entity* refers to any real-world object or concept that is specifically named or referred to.
Dates, times, abstract concepts, adjectives and verbs are NOT entities.

Use the following set of possible entity labels:
{entity_info}
"""

    # E-NER
    elif topic == "ener":
        prompt_prefix = f"""In this context, an *entity* refers to any real-world object or concept that is specifically named or referred to in the legal domain.
Dates, times, abstract concepts, adjectives and verbs are NOT entities.

Use the following set of possible entity labels:
{entity_info}

If an entity does not fit the types above it is considered "misc".
"""

    # LeNER-Br + NEURALSHIFT
    elif topic in ("lener", "neuralshift"):
        prompt_prefix = f"""Neste contexto, uma *entidade* refere-se a qualquer objeto ou conceito do mundo real que seja especificamente mencionado ou referido no domínio legal.
Conceitos abstratos, adjetivos e verbos NÃO são entidades.

Usa o seguinte conjunto de tipos possíveis de entidade:
{entity_info}

Se uma entidade não se encaixar em nenhum dos tipos acima, não a incluas na resposta.
"""

    ################## FINAL PROMPT ##################

    # validation: empty entity info (unknown prompt type / empty file) or unknown topic
    if entity_info == "" or prompt_prefix == "":
        raise ValueError(f"Error retrieving entity info for topic {topic}, prompt type {prompt_type} and prompt sub type {prompt_sub_type}.")

    # final prompt instruction
    if lang == "en":
        return f"""{prompt_prefix}

Return the entities in a structured JSON format with the following fields:
- "span": the exact span of the entity as it appears in the input
- "entity": the category/type of the entity

Now extract entities from the following text:
"""
    elif lang == "pt":
        return f"""{prompt_prefix}

Retorna as entidades num formato JSON estruturado com os seguintes campos:
- "span": o span exato da entidade conforme aparece no texto de input
- "entity": a classe/tipo da entidade

Agora extrai as entidades do seguinte texto:
"""
    else:
        raise ValueError(f"Language {lang} not supported.")

LLM functions

In [None]:
# Call LLM
def safe_llm_call(prompt, system, instance):
    """Send a system + user message pair to the "azure/gpt-4o-mini" model via litellm.

    The reply is requested in the ``LLM_Entity_List`` response format.

    Args:
        prompt: content of the user message.
        system: content of the system message.
        instance: only printed for diagnostics when the call fails.

    Returns:
        The ``content`` of the first choice's message.

    Raises:
        Exception: any error from the call is printed together with
            ``instance`` and then re-raised unchanged.
    """
    try:
        chat = [
            {"role": "system", "content": system},
            {"role": "user", "content": prompt},
        ]
        request = {
            "model": "azure/gpt-4o-mini",
            "messages": chat,
            "temperature": 0.1,
            "response_format": LLM_Entity_List,
        }
        response = litellm.completion(**request)

        # extract LLM predictions
        first_choice = response.choices[0]
        return first_choice.message["content"]

    except Exception as e:
        print(f"\n❌❌ LLM call failed: {e}\n")
        print(f"\nInstance: {instance}")
        raise

def _performance_to_dict(performance):
    """Serialize a performance object into the metrics dict stored in result files."""
    return {
        "tp": performance.tp,
        "fp": performance.fp,
        "fn": performance.fn,
        "precision": performance.precision(),
        "recall": performance.recall(),
        "f1": performance.f1(),
    }


def process_instance(topic, config: Config, i, instance, system_prompt, entity_names_parsed, start_of_entity_indices, entity_index_to_name):
    """Run the LLM on one dataset instance, score the prediction and save it as JSON.

    An instance whose result file already exists is skipped, so an interrupted
    run can be resumed. When the LLM call or the parsing of its reply fails,
    the error is printed and no result file is written, which means the
    instance is attempted again on the next run.

    This function does not create directories: ``<results_folder>/<topic>/<config>/``
    and its ``prompts/`` sub-folder must already exist.

    Args:
        topic: dataset/topic key, also used as a results sub-folder.
        config: prompt configuration; ``str(config)`` names the results sub-folder.
        i: index of the instance; used for the file names and the Prediction id.
        instance: mapping with at least ``'tokens'`` and ``'ner_tags'``.
        system_prompt: system message text passed to the LLM.
        entity_names_parsed: forwarded to ``tokens_to_entities``.
        start_of_entity_indices: forwarded to ``tokens_to_entities``.
        entity_index_to_name: forwarded to ``tokens_to_entities``.

    Returns:
        None.
    """
    config_dir = f"{results_folder}/{topic}/{str(config)}"
    results_path = f"{config_dir}/{i}.json"

    # Check if the results file already exists
    if os.path.exists(results_path):
        print(f" >>> Results for sentence #{i+1} already exist. Skipping...")
        return

    # get the instance
    sentence = tokens_to_sentence(instance['tokens'])
    true_entities = tokens_to_entities(instance['tokens'], instance['ner_tags'], entity_names_parsed, start_of_entity_indices, entity_index_to_name)

    # set prompt
    prompt = get_prompt_prefix(topic, config.lang, config.prompt_type, config.prompt_subtype)
    prompt = f"{prompt}\n{sentence}"

    # Save prompt to txt file
    prompt_file_path = f"{config_dir}/prompts/prompt_{i}.txt"
    with open(prompt_file_path, "w", encoding="utf-8") as f:
        f.write(prompt)

    # create prediction object
    prediction = Prediction(i, sentence)

    # call the LLM
    try:
        llm_response = safe_llm_call(prompt, system_prompt, sentence)

        if llm_response is None:
            print(f"❌ LLM response is None for sentence #{i}: {sentence}")
            return

        llm_json = json.loads(llm_response)
        llm_entities = llm_json.get('entities', [])

    except Exception as e:
        # Best effort: report and move on so one bad instance does not stop the run.
        print(f"❌ Error on sentence #{i}: {e}")
        return

    # compute predictions
    prediction.set_results(true_entities, llm_entities)
    prediction.compute_performance()
    prediction.compute_relaxed_performance()

    # write json to file
    result_json = {
        "id": i,
        "sentence": sentence,
        "true_entities": [entity.to_dict() for entity in true_entities],
        "entities": llm_entities,
        "performance": _performance_to_dict(prediction.performance),
        "relaxed_performance": _performance_to_dict(prediction.relaxed_performance),
        "tokens": instance['tokens'],
        "ner_tags": instance['ner_tags']
    }

    # save results to file ("w", not "a": the file must hold exactly one JSON
    # document, and appending to a leftover file would make it unparseable)
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(result_json, f, ensure_ascii=False, indent=4)

Run for each config

In [None]:
def system_prompt(lang):
    """Return the system message for the NER task in the given prompt language.

    Args:
        lang: prompt language, "en" or "pt".

    Returns:
        The system prompt text.

    Raises:
        ValueError: if ``lang`` is not supported. Without this the function
            would fall through and hand ``None`` to the LLM call as the system
            message; ``get_prompt_prefix`` rejects unsupported languages the
            same way.
    """
    if lang == "en":
        return "You are a named entity recognition (NER) system. Your task is to extract all entities mentioned in the input text. Always respond with JSON containing named entities."
    if lang == "pt":
        return "És um sistema de reconhecimento de entidades (NER). A tua tarefa é extrair todas as entidades mencionadas no texto de input. Responde sempre no formato JSON com as entidades."
    raise ValueError(f"Language {lang} not supported.")

In [None]:
# Evaluate every topic under each of its prompt configurations.
for topic, configs in all_configs.items():

    # Pick the label definitions and the dataset that belong to this topic
    if topic == "lener":
        from entities_leNER import entity_names, entity_names_parsed
        dataset = load_from_disk("...")

    elif topic == "neuralshift":
        from entities_neuralshift import entity_names, entity_names_parsed
        dataset = load_from_disk("...")

    elif topic == "ener":
        from entities_eNER import entity_names, entity_names_parsed
        dataset = load_from_disk("...")

    elif topic == "multinerd_en":
        from entities_multinerd_en import entity_names, entity_names_parsed
        dataset = load_from_disk("...")

    elif topic == "multinerd_pt":
        from entities_multinerd_pt import entity_names, entity_names_parsed
        dataset = load_from_disk("...")

    else:
        # every remaining topic uses the crossNER label set
        from entities_crossNER import entity_names, entity_names_parsed
        dataset = load_dataset("...")

    # only the test split is evaluated
    all_data = dataset["test"]

    # Derive the label lookups from the tag names in a single pass:
    #  - indices of tags that open an entity ("B-" / "U-" prefixes)
    #  - tag index -> entity type (the piece after the first "-")
    # NOTE(review): split("-")[1] keeps only the second piece, so a type that
    # itself contains "-" would be truncated — confirm no label set has one.
    start_of_entity_indices = []
    entity_index_to_name = {}
    for label_idx in range(len(entity_names)):
        label = entity_names[label_idx]
        if label.startswith(("B-", "U-")):
            start_of_entity_indices.append(label_idx)
        if label != "O":
            entity_index_to_name[label_idx] = label.split("-")[1]
    # index 0 is mapped to "O" unconditionally (presumably the outside tag — verify)
    entity_index_to_name[0] = "O"

    total = len(all_data)

    # Run for each config
    for config in configs:

        lang = config.lang
        config_folder = str(config)
        config_dir = f"{results_folder}/{topic}/{config_folder}"

        # Ensure the results directory exists
        os.makedirs(config_dir, exist_ok=True)
        os.makedirs(f"{config_dir}/prompts", exist_ok=True)

        # the system message depends only on the config language
        lang_system_prompt = system_prompt(lang)

        # Run through all instances
        print("\n\nRunning config:", topic, config_folder)
        for i, instance in enumerate(all_data):
            print(f"\r\tProcessing instance {i+1}/{total}", end="", flush=True)
            process_instance(
                topic,
                config,
                i,
                instance,
                lang_system_prompt,
                entity_names_parsed,
                start_of_entity_indices,
                entity_index_to_name,
            )