In [None]:
import os
# Switch the process working directory to the project root first: the data,
# config and prompt paths used by the later cells ("data/...", "etc/...",
# "../../datasets/...") are relative and are resolved against this directory.
# NOTE(review): presumably the `src.p04_official_reg_db_creation` import below
# also relies on this directory being the cwd — confirm.
# NOTE(review): absolute, user-specific path; the notebook fails here on any
# machine where this directory does not exist.
working_dir = "/home/gpinon/more_europa/clean_rdc_experiments/projects/P04_official_reg_db_creation"
os.chdir(working_dir)
print(f"Changed working directory to {working_dir}")
import logging
import time
import pandas as pd
import json
from pathlib import Path
from dotenv import load_dotenv

from src.p04_official_reg_db_creation import config
import llm_backends
from llm_backends.cache import DiskCacheStorage
from llm_backends.mistral import dummy_config
from llm_backends.openai import dummy_config
from p04_official_reg_db_creation.config import MAPPING, INVERSE_MAPPING

In [None]:
# Publications dataset variant: "eval" or "test" (the keys of RAW_PUBLICATIONS_DICT).
# NOTE(review): DATASET_TYPE and RAW_PUBLICATIONS_DICT are not referenced by any
# other cell visible in this notebook — confirm whether they are still needed.
DATASET_TYPE ="eval"
RAW_PUBLICATIONS_DICT = {
    "eval": "../../datasets/001_publications_dataset/publications_dataset.jsonl",
    "test": "../../datasets/001_publications_dataset/prod_publication_test_dataset.jsonl",
}
# Metadata field being evaluated; it is interpolated into the input/output
# paths, the judge prompt file name and the "correct_<FIELD>" / "inferred_<FIELD>" keys.
FIELD = "registry_name"
# Name of the extraction run whose outputs are compared; only used to build paths.
MODEL = "small_mistral"

# Load environment variables from the .env file, then read the API keys that
# are handed to the judge backends. os.getenv returns None when a key is unset.
load_dotenv()
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

In [None]:
# Inputs of the comparison (all paths are relative to the working directory).
# JSON file with the extractions produced by MODEL for FIELD.
metadata_extractions=f"data/from_notebooks/NW01/R01_extraction/{MODEL}/{FIELD}/{FIELD}_extractions.json"
# JSON file with the annotated evaluation records for FIELD (the reference values).
eval_dataset_json=f"../../datasets/005_evaluate_extraction_process_datasets/{FIELD}/final_eval_dataset.json"
# JSON config of the judge model. The file NAME matters: a later cell checks it
# for the substrings "openai" / "istral" to decide which backend to instantiate.
llm_judge_model_config= "etc/configs/gpt4o_openai_config.json"  # alternatives: "etc/configs/large_mistral_config.json", "etc/configs/gpt4_1_openai_config.json"
# Prompt template for the judge; the comparison loop substitutes the
# "{{content_a}}" (inferred) and "{{content_b}}" (reference) markers in it.
prompt_llm_judge = f"etc/prompts/llm_as_a_judge/compare_{FIELD}.txt"

In [None]:
# Destination of the comparison results; the final cell writes this JSON file
# and an .xlsx file derived from the same path.
output_json=f"data/from_notebooks/NW01/R02_comparison/{MODEL}/{FIELD}/compare_{FIELD}.json"

In [None]:
# Resolve the name FIELD is mapped to (falls back to FIELD itself when unmapped)
field_name = INVERSE_MAPPING.get(FIELD, FIELD)

# Create the folder that will receive the comparison files, parents included;
# an already existing folder is not an error
out_dir = Path(output_json).parent
os.makedirs(out_dir, exist_ok=True)

print(f"Processing field: {FIELD} (mapped to {field_name})")

In [None]:
def _read_json(path):
    """Return the parsed content of the UTF-8 encoded JSON file at ``path``."""
    with open(path, "r", encoding="utf-8") as handle:
        return json.load(handle)


# Judge prompt template, without leading/trailing whitespace
with open(prompt_llm_judge, "r", encoding="utf-8") as f:
    prompt_template = f.read().strip()

# The backend family is inferred from the config file name:
# "openai" -> OpenAI model, "istral" -> Mistral model
_judge_config_name = llm_judge_model_config.lower()
is_openai_model = "openai" in _judge_config_name
is_mistral_model = "istral" in _judge_config_name

# Judge model configuration
judge_model_cfg = _read_json(llm_judge_model_config)
print(f"Using model config: {judge_model_cfg.get('model', 'unknown')}")

# Extractions to evaluate
extraction_records = _read_json(metadata_extractions)
print(f"Loaded {len(extraction_records)} records from metadata extractions")

# Annotated evaluation dataset
eval_records = _read_json(eval_dataset_json)
print(f"Loaded {len(eval_records)} records from evaluation dataset")

# Keep only the records that are not flagged as still needing manual annotation
filtered_eval_records = []
for annotated in eval_records:
    if annotated.get("needs_manual_annotation", False):
        continue
    filtered_eval_records.append(annotated)
print(f"Filtered to {len(filtered_eval_records)} records with completed annotations")

# Index the remaining evaluation records by object_id
# (for a repeated object_id the last record wins)
eval_lookup = {}
for annotated in filtered_eval_records:
    eval_lookup[annotated["object_id"]] = annotated

In [None]:
# Accumulators filled by the comparison loop:
#   processed_records -> one output record per evaluated extraction
#   llm_judge_prompts -> prompts for the pairs that need the LLM judge
processed_records, llm_judge_prompts = [], []

# Counters reported at the end of the notebook; everything except the total
# number of extractions starts at zero
stats = {"total_extracted": len(extraction_records)}
stats.update(
    dict.fromkeys(
        (
            "in_eval_dataset",
            "model_agreement",
            "one_model_unspecified",
            "need_llm_judge",
            "llm_same",
            "llm_different",
        ),
        0,
    )
)

# Key under which the evaluation records store the reference value
correct_field_col = "correct_" + FIELD

In [None]:
# Compare every extraction that has a completed annotation with its reference
# value and triage it into one of three outcomes:
#   1. case-insensitive exact match                       -> final_label = 1
#   2. exactly one side is "not found" / "not specified"   -> final_label = 0
#   3. anything else                                       -> queued for the LLM judge
#      (the record is appended without a label; the judge cell fills it in)
for rec in extraction_records:
    object_id = rec.get("object_id")

    # Only extractions present in the (filtered) evaluation dataset are scored
    if object_id not in eval_lookup:
        continue

    stats["in_eval_dataset"] += 1
    eval_rec = eval_lookup[object_id]

    # Get the model's extracted value for this field. A missing or non-dict
    # response (e.g. null) is treated as "nothing extracted" instead of crashing.
    model_response = rec.get("llm_response", {})
    if not isinstance(model_response, dict):
        model_response = {}

    inferred_list = model_response.get("List of Registry names", "Not found")

    # Always start from a fresh default. Previously inferred_value was only
    # assigned when inferred_list was a list, so a non-list value raised a
    # NameError on the first record or silently reused the previous record's
    # value on later ones.
    inferred_value = "Not found"
    if isinstance(inferred_list, list):
        # Build "name (acronym), name, ..." from the extracted registries
        formatted_registries = []
        for registry in inferred_list:
            if isinstance(registry, dict):
                name = registry.get("registry_name") or ""
                acronym = registry.get("acronym") or ""
            else:
                # Tolerate entries that are plain values rather than dicts
                name, acronym = str(registry), ""
            if not name and not acronym:
                continue
            formatted_registries.append(f"{name} ({acronym})" if acronym else name)
        # An empty list means nothing was extracted: keep the "Not found" default
        # so it is recognised by the "unspecified" checks below
        if formatted_registries:
            inferred_value = ", ".join(formatted_registries)
    elif isinstance(inferred_list, str) and inferred_list.strip():
        # The model (or the default above) returned a plain string
        inferred_value = inferred_list

    # Reference value; a null annotation is treated as "Not found" and any
    # other non-string value is converted so that .lower() below cannot fail
    correct_value = eval_rec.get(correct_field_col, "Not found")
    if correct_value is None:
        correct_value = "Not found"
    elif not isinstance(correct_value, str):
        correct_value = str(correct_value)

    # Create the basic record structure
    output_record = {
        "object_id": object_id,
        "title": rec.get("title", ""),
        "abstract": rec.get("abstract", ""),
        f"inferred_{FIELD}": inferred_value,
        correct_field_col: correct_value,
    }

    # Check for exact match (case insensitive)
    if inferred_value.lower() == correct_value.lower():
        output_record["final_label"] = 1
        output_record["labeling_reason"] = "model_agreement"
        stats["model_agreement"] += 1
        processed_records.append(output_record)
        continue

    # Check whether exactly one of the two values is an "unspecified" marker
    inferred_unspecified = inferred_value.lower() in ("not found", "not specified")
    correct_unspecified = correct_value.lower() in ("not found", "not specified")

    if inferred_unspecified != correct_unspecified:
        output_record["final_label"] = 0
        output_record["labeling_reason"] = "one_model_unspecified"
        stats["one_model_unspecified"] += 1
        processed_records.append(output_record)
        continue

    # Need LLM judge
    stats["need_llm_judge"] += 1

    # Prepare prompt for LLM judge: content_a is the inferred value,
    # content_b the reference value
    full_prompt = prompt_template.replace("{{content_a}}", inferred_value)
    full_prompt = full_prompt.replace("{{content_b}}", correct_value)
    # Debug output: print the part of the prompt that follows '</example6>'
    # (presumably the last few-shot example of the template, i.e. the part
    # holding the two values — confirm against the prompt file)
    print('----')
    stop = full_prompt.find('</example6>') + 13
    print(full_prompt[stop:])
    llm_judge_prompts.append({"prompt": full_prompt, "custom_id": object_id})

    # Save the record for later updating with LLM judgment
    processed_records.append(output_record)


In [None]:
start_time = time.time()
print(f"Running LLM as a judge with model: {judge_model_cfg.get('model', 'unknown')}")

# The judge only runs when the comparison loop queued at least one prompt
if llm_judge_prompts:
    print(f"Running LLM judgment for {len(llm_judge_prompts)} records")

    # Pick the backend matching the judge config; OpenAI takes precedence
    # when the config name matches both families
    if not (is_openai_model or is_mistral_model):
        raise ValueError("Unsupported model type. Please use OpenAI or Mistral models.")
    if is_openai_model:
        backend = llm_backends.OpenAIAsyncBackend(
            api_key=OPENAI_API_KEY, cache_storage=DiskCacheStorage()
        )
    else:
        backend = llm_backends.MistralBatchBackend(
            api_key=MISTRAL_API_KEY, cache_storage=DiskCacheStorage()
        )

    # Send all queued prompts at once
    judge_results = backend.infer_many(llm_judge_prompts, judge_model_cfg)

    # Write each verdict back onto the first still-unlabelled record that has
    # the same object_id as the result's custom_id
    for result in judge_results:
        parsed_response = backend._parse_response(result)
        parsed_response["custom_id"] = result.get("custom_id", "")
        judged_id = result["custom_id"]

        pending_record = next(
            (
                candidate
                for candidate in processed_records
                if candidate["object_id"] == judged_id
                and "final_label" not in candidate
            ),
            None,
        )
        if pending_record is None:
            continue

        # Anything other than "same" (case-insensitive) counts as different
        if parsed_response["final_decision"].lower() == "same":
            verdict, label = "llm_same", 1
        else:
            verdict, label = "llm_different", 0
        pending_record["final_label"] = label
        pending_record["labeling_reason"] = verdict
        stats[verdict] += 1

        # Keep the judge's explanation next to the label
        pending_record["llm_explanation"] = parsed_response["explanation"]

    # Records that received no verdict default to "not matching"
    for record in processed_records:
        if "final_label" in record:
            continue
        record["final_label"] = 0
        record["labeling_reason"] = "undetermined"


elapsed_time = time.time() - start_time
print(f"Total processing time: {elapsed_time:.2f} seconds")

In [None]:
# Show the complete list of processed records as the cell output for manual inspection.
# NOTE(review): this prints every record (including abstracts); consider a slice for large runs.
processed_records

In [None]:
# Report the counters collected during the comparison
print("Comparison statistics:")
for stat_name, stat_value in stats.items():
    print(f"  {stat_name}: {stat_value}")

# Tally the final labels (records without a label count for neither side)
final_labels = [r.get("final_label") for r in processed_records]
positive_labels = final_labels.count(1)
negative_labels = final_labels.count(0)
print(
    f"Final labels: {positive_labels} positive, {negative_labels} negative"
)

# Share of "same" / "different" verdicts among the pairs sent to the LLM judge
if stats["need_llm_judge"] > 0:
    judged_total = stats["need_llm_judge"]
    llm_same_percent = (stats["llm_same"] / judged_total) * 100
    llm_different_percent = (stats["llm_different"] / judged_total) * 100
    print(
        f"LLM judge breakdown: {stats['llm_same']} same ({llm_same_percent:.1f}%), {stats['llm_different']} different ({llm_different_percent:.1f}%)"
    )

# Persist the records as JSON (non-ASCII characters are written as-is)
with open(output_json, "w", encoding="utf-8") as f:
    json.dump(processed_records, f, indent=4, ensure_ascii=False)
print(f"Saved comparison results to {output_json}")

# Same records as a spreadsheet, written next to the JSON file
output_excel = output_json.replace(".json", ".xlsx")
results_frame = pd.DataFrame(processed_records)
results_frame.to_excel(output_excel, index=False)
print(f"Saved comparison results to {output_excel}")