Libraries

In [None]:
# Libraries
import numpy as np
import sys
import os
import json
from pathlib import Path
import re

# add path to the dataset entities
sys.path.append(os.path.abspath("../0. Helpers"))
sys.path.append(os.path.abspath("../2. Data Processing/_dataset_entities"))

from datasetProcessing import Entity, recursive_fix
from performance import Prediction, Performance

In [None]:
class ErrorClass:
    """Error-category counts for one evaluated instance, as built by ``classify_error``.

    Attributes:
        tp: predictions with exactly the gold span and the same type.
        type: predictions with exactly the gold span but a different type.
        misalign: predictions whose span overlaps a gold span of the same type.
        fp: predictions left unpaired with any gold entity.
        fn: gold entities left unpaired with any prediction.
    """

    def __init__(self, tp, type, misalign, fp, fn):
        # ``type`` shadows the builtin only inside this method; the name is
        # kept because callers construct with ``type=`` and read ``.type``.
        self.tp, self.type, self.misalign = tp, type, misalign
        self.fp, self.fn = fp, fn


def spans_overlap(span1: str, span2: str) -> bool:
    """Return True when one span's text is contained in the other's.

    The check is plain (case- and whitespace-sensitive) substring
    containment, not token-offset overlap.  An empty span never overlaps
    anything: ``"" in s`` is True for every ``s``, so without the guard an
    empty gold or predicted span would "overlap" every entity.  Identical
    non-empty spans still count as overlapping.
    """
    if not span1 or not span2:
        return False
    return span1 in span2 or span2 in span1

def spans_match(span1: str, span2: str) -> bool:
    """Return True only for an exact, character-for-character span match (case-sensitive)."""
    identical = span1 == span2
    return identical

def classify_error(true_entities: list[dict], pred_entities: list[dict]):
    """Classify how the predicted entities of one instance relate to the gold ones.

    Each entity is a dict with a ``"span"`` (surface text) and an ``"entity"``
    (type label) key.  Gold and predicted entities are paired one-to-one in
    three greedy passes, strictest first, so that a looser pairing can never
    consume a prediction that a stricter one needs:

    1. exact span, same type       -> ``tp``
    2. exact span, different type  -> ``type`` (type error)
    3. overlapping span (``spans_overlap``), same type -> ``misalign``

    Gold entities left unpaired are coverage false negatives (``fn``) and
    predictions left unpaired are coverage false positives (``fp``).
    Overlapping spans with *different* types are not paired, so they show up
    as one ``fn`` plus one ``fp``.

    Args:
        true_entities: Gold-standard entities.
        pred_entities: Predicted entities.

    Returns:
        An ``ErrorClass``.  Because every entity is used at most once,
        ``tp + type + misalign + fn == len(true_entities)`` and
        ``tp + type + misalign + fp == len(pred_entities)``.
    """
    matched_true: set[int] = set()
    matched_pred: set[int] = set()

    def _pair(accept) -> int:
        # One greedy pass over still-unpaired entities: pair each unpaired
        # gold entity with the first unpaired prediction satisfying
        # ``accept(gold, pred)``; return the number of new pairs.
        n_pairs = 0
        for i, t in enumerate(true_entities):
            if i in matched_true:
                continue
            for j, p in enumerate(pred_entities):
                if j not in matched_pred and accept(t, p):
                    matched_true.add(i)
                    matched_pred.add(j)
                    n_pairs += 1
                    break
        return n_pairs

    n_tp = _pair(
        lambda t, p: spans_match(t["span"], p["span"]) and t["entity"] == p["entity"]
    )
    n_type = _pair(
        lambda t, p: spans_match(t["span"], p["span"]) and t["entity"] != p["entity"]
    )
    n_misalign = _pair(
        lambda t, p: t["entity"] == p["entity"] and spans_overlap(t["span"], p["span"])
    )

    # Everything not paired above is a coverage error.
    n_coverage_fn = len(true_entities) - len(matched_true)
    n_coverage_fp = len(pred_entities) - len(matched_pred)

    return ErrorClass(n_tp, n_type, n_misalign, n_coverage_fp, n_coverage_fn)


In [None]:
def process_instance(file_path):
    """Load one result JSON file and classify its entity-level errors.

    The file is parsed as JSON (after dropping one trailing comma) and passed
    through ``recursive_fix``.  Gold entities are read from
    ``"true_entities"`` and model predictions from ``"entities"``; a missing
    or null key is treated as an empty list.  Predictions are de-duplicated on
    their whitespace-stripped, lower-cased ``("span", "entity")`` pair,
    keeping the first occurrence, before being compared with the gold
    entities by ``classify_error``.

    Args:
        file_path: Path to the ``.json`` result file (``str`` or ``Path``).

    Returns:
        The ``ErrorClass`` from ``classify_error``, or ``None`` when the file
        is empty or any step fails (the reason is printed).
    """
    try:
        with open(file_path, mode="r", encoding="utf-8") as f:
            content = f.read()

        if not content.strip():
            print(f"🗑️ Empty file detected: {file_path}")
            # Auto-deleting the empty file (file_path.unlink()) is deliberately disabled.
            return None

        # Drop a trailing comma after the top-level JSON value so json.loads accepts it.
        content = re.sub(r",\s*$", "", content)
        data = json.loads(content)

        # Apply encoding fix
        data = recursive_fix(data)

        # ``or []`` also covers keys that are present but null.
        true_entities = data.get("true_entities") or []
        llm_entities = data.get("entities") or []

        # Remove duplicate predictions keyed on the normalised (span, type)
        # pair; a seen-set keeps this linear instead of a pairwise scan.
        seen = set()
        predicted_entities = []
        for entity in llm_entities:
            key = (entity["span"].strip().lower(), entity["entity"].strip().lower())
            if key not in seen:
                seen.add(key)
                predicted_entities.append(entity)

        return classify_error(true_entities, predicted_entities)

    except Exception as e:
        # Batch boundary: report and skip this file rather than abort the whole run.
        # Auto-deleting the offending file (file_path.unlink()) is deliberately disabled.
        print(f"❌ Error reading {file_path}: {e}")
        return None

Evaluate each model

In [None]:
# Result files for a topic are read from <folder_prefix>/<topic>/<folder_suffix><n>.
folder_prefix = "results/demo_type"
folder_suffix = "in_context_top"

# Topic/dataset name -> n, the number appended to ``folder_suffix`` in the
# results path (presumably the number of top-ranked in-context examples — confirm).
all_configs = {
    **dict.fromkeys(("ai", "literature", "music"), 10),
    **dict.fromkeys(
        (
            "politics",
            "science",
            "multinerd_en",
            "multinerd_pt",
            "ener",
            "lener",
            "neuralshift",
        ),
        20,
    ),
}

In [None]:
# Evaluate every configured topic and print its aggregated error counts.
for topic, n in all_configs.items():

    print(f"Processing topic: {topic}")
    topic_path = Path(f"{folder_prefix}/{topic}/{folder_suffix}{n}")
    print(topic_path)

    if not topic_path.exists():
        print(f"❌ Topic folder {topic} does not exist.")
        continue

    # Classify every result file (sorted so log output is reproducible);
    # files that process_instance could not handle return None and are dropped.
    dataset_performance = [
        result
        for result in map(process_instance, sorted(topic_path.glob("*.json")))
        if result is not None
    ]
    if not dataset_performance:
        print(f"⚠️ No valid instances found in {topic_path.name}. Skipping...")
        continue
    print(f"✅ Processed {len(dataset_performance)} valid instances in {topic_path.name}.")

    # Sum the per-instance counters over the whole topic.
    metrics_dict = {
        "total_samples": len(dataset_performance),
        "type_errors": sum(inst.type for inst in dataset_performance),
        "misalign_errors": sum(inst.misalign for inst in dataset_performance),
        "coverage_fn": sum(inst.fn for inst in dataset_performance),
        "coverage_fp": sum(inst.fp for inst in dataset_performance),
        "true_positives": sum(inst.tp for inst in dataset_performance),
    }

    print("topic:", topic)
    print("metrics:", metrics_dict)
    print("\n\n")