In [None]:
# Notebook setup: switch to the project directory, then import dependencies.
# NOTE(review): the chdir happens before the project-local imports below;
# presumably `src.…` and `llm_backends` are resolved relative to that directory,
# so the statement order should be kept — TODO confirm.
import os
# NOTE(review): absolute, user-specific path — the notebook will not run
# unchanged on another machine or checkout location.
working_dir = "/home/gpinon/more_europa/clean_rdc_experiments/projects/P04_official_reg_db_creation"
os.chdir(working_dir)
print(f"Changed working directory to {working_dir}")
import logging
import time
import pandas as pd
import json
from pathlib import Path
from dotenv import load_dotenv

from src.p04_official_reg_db_creation import config
import llm_backends
from llm_backends.cache import DiskCacheStorage
from llm_backends.mistral import dummy_config
# NOTE(review): this import rebinds `dummy_config`, so the Mistral object imported
# on the previous line is no longer reachable under that name.
from llm_backends.openai import dummy_config

In [None]:
# Run parameters: the metadata field to extract and the model configuration to use.
FIELD, MODEL = "medical_condition", "small_mistral"

# Load the .env file into the process environment, then read both API keys from it.
load_dotenv()
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

In [None]:
# INPUTS
# The base PubMed JSONL path was left blank in the notebook (a SyntaxError that
# prevented the whole cell from running). It is now taken from the environment
# (the .env file has already been loaded by `load_dotenv()`), and the cell fails
# with an explicit message when it is not set.
base_pubmed_dataset_jsonl = os.environ.get("BASE_PUBMED_DATASET_JSONL", "")
if not base_pubmed_dataset_jsonl:
    raise ValueError(
        "BASE_PUBMED_DATASET_JSONL is not set: point it (e.g. in .env) to the "
        "base PubMed JSONL dataset used as input for the extraction."
    )
# NOTE(review): DATASET_TYPE is not defined in any cell visible in this notebook;
# it must be bound before this cell runs, otherwise the next line raises NameError.
registry_related_dataset_json = f"../../datasets/006_evaluate_extraction_process_datasets/registry_related/final_{DATASET_TYPE}_dataset.json"
# Prompt template for the selected field and JSON config for the selected model.
prompt_txt = f"etc/prompts/extraction/prompt_{FIELD}.txt"
model_config = f"etc/configs/{MODEL}_config.json"

In [None]:
# OUTPUTS
# Both output files live in the same per-model, per-field directory.
# NOTE(review): the second variable is named "jsonl" but its path ends in ".json".
output_json, output_records_jsonl = (
    f"data/from_notebooks/NW01/R01_extraction/{MODEL}/{FIELD}/{FIELD}_extractions.json",
    f"data/from_notebooks/NW01/R01_extraction/{MODEL}/{FIELD}/{FIELD}_extractions_records.json",
)

In [None]:
# Make sure the parent directories of both output files exist
# (created recursively; an already existing directory is not an error).
out_dir = Path(output_json).parent
records_dir = Path(output_records_jsonl).parent
for target_dir in (out_dir, records_dir):
    target_dir.mkdir(parents=True, exist_ok=True)

# Parse the JSON model configuration for the selected model
model_cfg = json.loads(Path(model_config).read_text(encoding="utf-8"))

# The "model" entry is only used for logging; fall back to "unknown" when absent
model_name = model_cfg.get("model", "unknown")
print(f"Using model: {model_name}")

# Read the annotation prompt template, without leading/trailing whitespace
annotation_prompt = Path(prompt_txt).read_text(encoding="utf-8").strip()

# Load registry_related_dataset_json to filter publications.
# The encoding is given explicitly, like the other text files opened in this
# notebook, so the result does not depend on the platform default encoding.
with open(registry_related_dataset_json, "r", encoding="utf-8") as file:
    registry_related_eval_dataset = json.load(file)

# Keep the object_ids of the publications labelled as registry-related ("yes")
registry_related_objectIDs = [
    item["object_id"]
    for item in registry_related_eval_dataset
    if item["correct_registry_related"] == "yes"
]
print(
    f"Found {len(registry_related_objectIDs)} registry-related publications"
)

# Load and filter PubMed records (one JSON object per line).
# A set gives constant-time membership tests instead of scanning the ID list
# once per input line.
registry_related_id_set = set(registry_related_objectIDs)

records = []
with open(base_pubmed_dataset_jsonl, "r", encoding="utf-8") as file:
    for line in file:
        # Blank lines (e.g. a trailing newline at the end of the file) are not
        # valid JSON and would make json.loads raise; skip them.
        if not line.strip():
            continue
        record = json.loads(line)
        # Records without an object_id get the "<unknown>" placeholder, which
        # only matches if that placeholder is itself in the ID list.
        object_id = record.get("object_id", "<unknown>")
        if object_id in registry_related_id_set:
            records.append(record)
print(f"Loaded and filtered to {len(records)} registry-related records")

In [None]:
# Prepare prompts for LLMs
def _build_prompt_item(pub):
    """Return the prompt item for one publication record.

    The item holds the full prompt (annotation prompt followed by the record's
    title and abstract) under "prompt" and the record's object_id under
    "custom_id". Missing fields are replaced by placeholder strings.
    """
    pub_id = pub.get("object_id", "<unknown>")
    pub_title = pub.get("title", "<no title>")
    pub_abstract = pub.get("abstract", "<no abstract>")
    return {
        "prompt": f"{annotation_prompt}\nText_to_analyze:\nTitle: {pub_title}\nAbstract: {pub_abstract}",
        "custom_id": pub_id,
    }


# records = records[:5] # Limit to first 5 records for testing
prompts_items = [_build_prompt_item(pub) for pub in records]

# Filled by the inference cell: one entry per response with object_id, prompt, and raw response
prompt_response_records = []

In [None]:
start_time = time.time()
# Run batch inference based on model type
print(f"Starting batch inference with {model_name}...")

# Both accumulators start empty here so that re-running this cell does not
# append a second copy of every response.
llm_responses = []
prompt_response_records = []

# The backend is selected from the model config path: a path containing
# "istral" selects the Mistral batch backend, one containing "openai" selects
# the OpenAI async backend. Mistral is checked first.
model_config_lower = model_config.lower()
is_mistral_model = "istral" in model_config_lower
is_openai_model = "openai" in model_config_lower
if is_mistral_model:
    backend = llm_backends.MistralBatchBackend(
        api_key=os.getenv("MISTRAL_API_KEY"), cache_storage=DiskCacheStorage()
    )
elif is_openai_model:
    backend = llm_backends.OpenAIAsyncBackend(
        api_key=os.getenv("OPENAI_API_KEY"), cache_storage=DiskCacheStorage()
    )
else:
    # Without this branch `backend` would be unbound (or a stale object from a
    # previous run of the cell) and the failure would surface later and less clearly.
    raise ValueError(
        f"Cannot select an LLM backend from model config path {model_config!r}: "
        "expected it to contain 'mistral' or 'openai'."
    )

raw_responses = backend.infer_many(
    prompt_items=prompts_items,
    model_config=model_cfg,
)

# Index the prompts by custom_id once instead of scanning the whole list for
# every response. setdefault keeps the first prompt for a repeated id, which is
# what the previous first-match scan returned.
prompts_by_id = {}
for prompt_item in prompts_items:
    prompts_by_id.setdefault(prompt_item["custom_id"], prompt_item)

unmatched_responses = 0
for raw_response in raw_responses:
    prompt_obj = prompts_by_id.get(raw_response["custom_id"])
    if prompt_obj is None:
        # Responses whose custom_id matches no prompt are skipped, as before,
        # but counted so the drop is visible.
        unmatched_responses += 1
        continue

    # Store the raw response with object_id and prompt for the records file
    prompt_response_records.append(
        {
            "object_id": raw_response["custom_id"],
            "prompt": prompt_obj["prompt"],
            "llm_response": raw_response,
        }
    )
    # Parse the raw response and tag it with its custom_id for the later join.
    # NOTE(review): `_parse_response` is a private method of the backend.
    parsed_response = backend._parse_response(raw_response)
    parsed_response["custom_id"] = raw_response.get("custom_id", "")
    llm_responses.append(parsed_response)

if unmatched_responses:
    print(f"Skipped {unmatched_responses} responses with no matching prompt")

print(f"Batch inference completed with {len(llm_responses)} responses")
elapsed_total = time.time() - start_time
print(f"Total time for batch inference: {elapsed_total:.2f} seconds")

In [None]:
# llm_responses

In [None]:
# Build the results list: one dict per record that has a parsed LLM response.
# Responses are indexed by custom_id once instead of being scanned for every
# record; setdefault keeps the first response for a repeated id, which is what
# the previous first-match scan returned.
responses_by_id = {}
for parsed in llm_responses:
    responses_by_id.setdefault(parsed["custom_id"], parsed)

results = []
for rec in records:
    object_id = rec.get("object_id", "<unknown>")
    title = rec.get("title", "<no title>")
    abstract = rec.get("abstract", "<no abstract>")

    resp = responses_by_id.get(object_id)

    # Records without a (non-empty) response are left out of the results
    if resp:
        results.append(
            {
                "object_id": object_id,
                "title": title,
                "abstract": abstract,
                "llm_response": resp,
            }
        )

# Write the per-record extractions as one indented JSON array (non-ASCII kept as-is)
with Path(output_json).open("w", encoding="utf-8") as json_out:
    json.dump(results, json_out, indent=4, ensure_ascii=False)

print(f"Saved metadata extractions to {output_json}")

# Write the prompt/response records as JSON Lines: one serialized record per line
with Path(output_records_jsonl).open("w", encoding="utf-8") as records_out:
    records_out.writelines(
        json.dumps(entry) + "\n" for entry in prompt_response_records
    )
print(f"Saved {len(prompt_response_records)} records to {output_records_jsonl}")