In [None]:
import os
# NOTE(review): absolute, user-specific path. The relative paths used in the cells
# below (data/..., etc/...) are resolved against this directory because of the
# os.chdir call, so the notebook only runs unchanged where this directory exists.
working_dir = "/home/gpinon/more_europa/clean_rdc_experiments/projects/P04_official_reg_db_creation"
os.chdir(working_dir)
print(f"Changed working directory to {working_dir}")
import logging
import time
import pandas as pd
import json
from pathlib import Path
from dotenv import load_dotenv
import weaviate

from src.p04_official_reg_db_creation import config
import llm_backends
from llm_backends.cache import DiskCacheStorage
from llm_backends.mistral import dummy_config
# NOTE(review): this import rebinds `dummy_config`, so the Mistral object imported on
# the previous line is no longer reachable under that name.
from llm_backends.openai import dummy_config

# 0. Raw publications loading

In [None]:
# Start the wall-clock timer for the publication loading step.
start_time = time.time()

# Properties kept from each Publication_v2 object; everything else is dropped.
PUBLICATION_FIELDS = ("object_id", "title", "abstract", "data_source_name")

items = []
weaviate_client = weaviate.connect_to_custom(**config.WEAVIATE_PROD_CONF)
try:
    collections = weaviate_client.collections
    # load publications
    collection_publications = collections.get("Publication_v2")
    # Iterate over the collection without vectors and keep only the whitelisted properties.
    for item in collection_publications.iterator(include_vector=False):
        items.append(
            {k: v for k, v in item.properties.items() if k in PUBLICATION_FIELDS}
        )
finally:
    # Close the connection even when the iteration above raises, so a failed
    # run does not leave the client open.
    weaviate_client.close()
# df_items = pd.DataFrame(items)

In [None]:
import random

# Fixed seed so the shuffled order (and therefore batch membership) is repeatable.
random.seed(42)

batch_size = 2000
shuffled_items = items.copy()
random.shuffle(shuffled_items)

# Folder receiving one JSONL file per batch; created if missing.
output_dir = Path("data/from_scripts/SW01/nbtk_testing/R04_extraction_all/0_raw_publications")
output_dir.mkdir(parents=True, exist_ok=True)

# Consecutive slices of `batch_size` items are written to 1.jsonl, 2.jsonl, ...
batch_offsets = range(0, len(shuffled_items), batch_size)
for batch_number, offset in enumerate(batch_offsets, start=1):
    output_file = output_dir / f"{batch_number}.jsonl"
    with output_file.open("w") as f:
        f.writelines(
            json.dumps(publication) + "\n"
            for publication in shuffled_items[offset : offset + batch_size]
        )

print(f"Loading time: {(time.time() - start_time):.2f} seconds")

In [None]:
# Load environment variables via python-dotenv, then read the Mistral key from the environment.
load_dotenv()
MISTRAL_API_KEY = os.environ.get("MISTRAL_API_KEY")

# 1. First, filter out publications that are not registry related

In [None]:
# Field being extracted and model preset used for the registry-related filter.
FIELD = "registry_related"
MODEL = "small_mistral"

# INPUTS
# Directory containing batch files
raw_publications_dir = (
    "data/from_scripts/SW01/nbtk_testing/R04_extraction_all/0_raw_publications"
)
prompt_txt = f"etc/prompts/extraction/prompt_{FIELD}.txt"
model_config = f"etc/configs/{MODEL}_config.json"

# OUTPUTS - plain (non-f) strings; {FIELD} and {batch} are filled in later with str.format
output_registry_related_template = (
    "data/from_scripts/SW01/nbtk_testing/R04_extraction_all/"
    "01_{FIELD}_publications_{batch}.json"
)
output_all_records_template = (
    "data/from_scripts/SW01/nbtk_testing/R04_extraction_all/"
    "01_{FIELD}_raw_inferences_{batch}.json"
)

In [None]:
# Load model configuration
with open(model_config, "r", encoding="utf-8") as f:
    model_cfg = json.load(f)

model_name = model_cfg.get("model", "unknown")
print(f"Using model: {model_name}")

# Load the annotation prompt
with open(prompt_txt, "r", encoding="utf-8") as f:
    annotation_prompt = f.read().strip()

# Testing cap on how many batch files are processed. The cap is ACTIVE: only the
# first batch is processed. Set to None to process every batch file.
MAX_BATCHES = 1

# Batch files are expected to be named <number>.jsonl and are processed in numeric
# order. Files whose stem is not a number are skipped so that int() cannot fail
# on an unrelated *.jsonl file in the same directory.
batch_files = sorted(
    (p for p in Path(raw_publications_dir).glob("*.jsonl") if p.stem.isdecimal()),
    key=lambda p: int(p.stem),
)
if MAX_BATCHES is not None:
    batch_files = batch_files[:MAX_BATCHES]
print(f"Found {len(batch_files)} batch files to process")

# Create the base output directory (all per-batch outputs share the same parent)
out_dir = Path(output_registry_related_template.format(FIELD=FIELD, batch=1)).parent
out_dir.mkdir(parents=True, exist_ok=True)

# Running totals accumulated across batches
total_registry_related = 0
total_not_registry_related = 0
total_records = 0

# Testing cap on the number of records sent to the model per batch. The cap is
# ACTIVE: only the first 5 records of each batch are used. Set to None to run
# inference on every record.
MAX_RECORDS_PER_BATCH = 5

# Decide the backend family from the config path before touching any batch, so an
# unrecognised config fails here with a clear message instead of surfacing later as
# a NameError (or silently reusing a `backend` left over from an earlier cell run).
is_openai_model = "openai" in model_config.lower()
is_mistral_model = "istral" in model_config.lower()
if not (is_mistral_model or is_openai_model):
    raise ValueError(
        "Cannot select an LLM backend: expected 'istral' or 'openai' in "
        f"model_config, got {model_config!r}"
    )

for batch_file in batch_files:
    batch_num = batch_file.stem  # Get batch number from filename
    print(f"\nProcessing batch {batch_num}...")

    # Format output paths for this batch
    output_registry_related_records = output_registry_related_template.format(
        FIELD=FIELD, batch=batch_num
    )
    # NOTE: this file is written as JSON Lines even though the template ends in .json.
    output_all_records_jsonl = output_all_records_template.format(
        FIELD=FIELD, batch=batch_num
    )

    # Load records from this batch (one JSON object per line; blank lines are skipped)
    with open(batch_file, "r", encoding="utf-8") as file:
        records = [json.loads(line) for line in file if line.strip()]
    print(f"Loaded {len(records)} records from batch {batch_num}")

    if MAX_RECORDS_PER_BATCH is not None:
        records = records[:MAX_RECORDS_PER_BATCH]

    # Prepare prompts for LLMs: the annotation prompt followed by title and abstract
    prompts_items = []
    for rec in records:
        object_id = rec.get("object_id", "<unknown>")
        title = rec.get("title", "<no title>")
        abstract = rec.get("abstract", "<no abstract>")
        full_prompt = f"{annotation_prompt}\nText_to_analyze:\nTitle: {title}\nAbstract: {abstract}"
        prompts_items.append({"prompt": full_prompt, "custom_id": object_id})

    # Index prompts by custom_id; setdefault keeps the first prompt for a repeated id,
    # which is what a linear first-match search would return.
    prompt_by_id = {}
    for prompt_item in prompts_items:
        prompt_by_id.setdefault(prompt_item["custom_id"], prompt_item)

    # Run batch inference
    batch_start_time = time.time()
    print(f"Starting batch inference for batch {batch_num} with {model_name}...")

    # The Mistral branch takes precedence when both substrings are present.
    if is_mistral_model:
        backend = llm_backends.MistralBatchBackend(
            api_key=os.getenv("MISTRAL_API_KEY"), cache_storage=DiskCacheStorage()
        )
    else:
        backend = llm_backends.OpenAIAsyncBackend(
            api_key=os.getenv("OPENAI_API_KEY"), cache_storage=DiskCacheStorage()
        )

    raw_responses = backend.infer_many(
        prompt_items=prompts_items,
        model_config=model_cfg,
    )

    # Raw records (object_id, prompt, raw response) and parsed responses for this batch
    prompt_response_records = []
    llm_responses = []
    for raw_response in raw_responses:
        custom_id = raw_response["custom_id"]
        prompt_obj = prompt_by_id.get(custom_id)
        if prompt_obj is None:
            # Response for an id that was not sent in this batch: ignored.
            continue
        prompt_response_records.append(
            {
                "object_id": custom_id,
                "prompt": prompt_obj["prompt"],
                "llm_response": raw_response,
            }
        )
        # NOTE(review): _parse_response is a private method of the backend.
        parsed_response = backend._parse_response(raw_response)
        parsed_response["custom_id"] = custom_id
        llm_responses.append(parsed_response)

    elapsed_total = time.time() - batch_start_time
    print(
        f"Batch {batch_num} inference completed with {len(llm_responses)} responses in {elapsed_total:.2f} seconds"
    )

    # Index parsed responses by custom_id (first response wins for a repeated id)
    response_by_id = {}
    for parsed in llm_responses:
        response_by_id.setdefault(parsed["custom_id"], parsed)

    # Build results for this batch: only records that received a response are kept
    results = []
    for rec in records:
        object_id = rec.get("object_id", "<unknown>")
        resp = response_by_id.get(object_id)
        if resp:
            results.append(
                {
                    "object_id": object_id,
                    "title": rec.get("title", "<no title>"),
                    "abstract": rec.get("abstract", "<no abstract>"),
                    "llm_response": resp,
                }
            )

    # Count statistics for this batch
    registry_related_records = [
        rec for rec in results if rec["llm_response"].get("registry_related") == "yes"
    ]
    not_registry_related_records = [
        rec for rec in results if rec["llm_response"].get("registry_related") == "no"
    ]
    print(
        f"Batch {batch_num}: Found {len(registry_related_records)} records related to registry"
    )
    print(
        f"Batch {batch_num}: Found {len(not_registry_related_records)} records not related to registry"
    )

    batch_total = len(results)
    # Answers that are neither "yes" nor "no" are counted in batch_total only;
    # report them so the gap between the two counts and the total is visible.
    unclassified_count = (
        batch_total - len(registry_related_records) - len(not_registry_related_records)
    )
    if unclassified_count:
        print(
            f"Batch {batch_num}: {unclassified_count} records had a registry_related value other than 'yes'/'no'"
        )

    rel_percentage = (
        len(registry_related_records) / batch_total * 100 if batch_total > 0 else 0
    )
    print(f"Batch {batch_num}: {rel_percentage:.2f}% of records are registry-related")

    # Update totals
    total_registry_related += len(registry_related_records)
    total_not_registry_related += len(not_registry_related_records)
    total_records += batch_total

    # Save the batch results: registry-related records as one JSON array...
    with open(output_registry_related_records, "w", encoding="utf-8") as f:
        json.dump(registry_related_records, f, indent=4, ensure_ascii=False)

    # ...and every prompt/raw-response pair as one JSON object per line.
    with open(output_all_records_jsonl, "w", encoding="utf-8") as fp:
        for record in prompt_response_records:
            fp.write(json.dumps(record) + "\n")

    print(f"Saved batch {batch_num} results to {output_registry_related_records}")
    print(f"Saved batch {batch_num} raw inferences to {output_all_records_jsonl}")

# Summarise the totals accumulated over all processed batches.
summary_lines = [
    "\n--- Overall Statistics ---",
    f"Total records processed: {total_records}",
    f"Total registry-related records: {total_registry_related}",
    f"Total not registry-related records: {total_not_registry_related}",
]
# The percentage line is only emitted when at least one record was processed.
if total_records > 0:
    overall_percentage = total_registry_related / total_records * 100
    summary_lines.append(
        f"Overall percentage of registry-related records: {overall_percentage:.2f}%"
    )
print("\n".join(summary_lines))

# 2. Registry Name extractions

In [None]:
# Field being extracted and model preset for the registry-name step.
FIELD = "registry_name"
MODEL = "small_mistral"

# INPUTS
# NOTE(review): the batched filter step in this notebook writes files named
# 01_registry_related_publications_<batch>.json; confirm that this un-suffixed
# file is produced somewhere before relying on it.
registry_related_publications = (
    "data/from_scripts/SW01/nbtk_testing/R04_extraction_all/"
    "01_registry_related_publications.json"
)
prompt_txt = f"etc/prompts/extraction/prompt_{FIELD}.txt"
model_config = f"etc/configs/{MODEL}_config.json"

# OUTPUTS
output_json = (
    "data/from_scripts/SW01/nbtk_testing/R04_extraction_all/"
    f"02_{FIELD}_extractions.json"
)
output_records_jsonl = (
    "data/from_scripts/SW01/nbtk_testing/R04_extraction_all/"
    f"02_{FIELD}_raw_inferences.jsonl"
)
output_registry_names_list = (
    "data/from_scripts/SW01/nbtk_testing/R04_extraction_all/"
    f"02_{FIELD}_list.json"
)

In [None]:
# Ensure output directory exists
out_dir = Path(output_json).parent
out_dir.mkdir(parents=True, exist_ok=True)

# Ensure output records directory exists
records_dir = Path(output_records_jsonl).parent
records_dir.mkdir(parents=True, exist_ok=True)

# Load model configuration
with open(model_config, "r", encoding="utf-8") as f:
    model_cfg = json.load(f)

model_name = model_cfg.get("model", "unknown")
print(f"Using model: {model_name}")

# Load the annotation prompt
with open(prompt_txt, "r", encoding="utf-8") as f:
    annotation_prompt = f.read().strip()

# Load the registry-related publications (no filtering happens here).
# The filter step in this notebook saves its results as an indented JSON array,
# which cannot be parsed line by line, so the file is first read as one JSON
# document; a JSON Lines file (one object per line) is still accepted as a fallback.
with open(registry_related_publications, "r", encoding="utf-8") as file:
    raw_text = file.read()
try:
    loaded = json.loads(raw_text)
    # A single top-level object is treated as a one-record list.
    records = loaded if isinstance(loaded, list) else [loaded]
except json.JSONDecodeError:
    # Split on "\n" only: str.splitlines would also break on characters such as
    # U+2028 that may legally appear inside a JSON string.
    records = [json.loads(line) for line in raw_text.split("\n") if line.strip()]

print(f"Loaded {len(records)} records for registry_related filter")

In [None]:
def _build_prompt_item(publication):
    """Return the {"prompt", "custom_id"} item sent to the LLM for one publication.

    Missing fields fall back to the placeholders "<unknown>", "<no title>" and
    "<no abstract>".
    """
    pub_id = publication.get("object_id", "<unknown>")
    pub_title = publication.get("title", "<no title>")
    pub_abstract = publication.get("abstract", "<no abstract>")
    return {
        "prompt": f"{annotation_prompt}\nText_to_analyze:\nTitle: {pub_title}\nAbstract: {pub_abstract}",
        "custom_id": pub_id,
    }


# Prepare prompts for LLMs
# records = records[:5] # Limit to first 5 records for testing
prompts_items = [_build_prompt_item(publication) for publication in records]

# Collects {object_id, prompt, raw response} records during inference
prompt_response_records = []

In [None]:
start_time = time.time()
# Run batch inference based on model type
print(f"Starting batch inference with {model_name}...")
# Reset both accumulators here so that re-running this cell does not append a
# second copy of every record to a list created by an earlier cell.
llm_responses = []
prompt_response_records = []

# The backend family is inferred from the config path: "istral" selects the Mistral
# batch backend (checked first), "openai" selects the OpenAI async backend.
is_openai_model = "openai" in model_config.lower()
is_mistral_model = "istral" in model_config.lower()
if is_mistral_model:
    backend = llm_backends.MistralBatchBackend(
        api_key=os.getenv("MISTRAL_API_KEY"), cache_storage=DiskCacheStorage()
    )
elif is_openai_model:
    backend = llm_backends.OpenAIAsyncBackend(
        api_key=os.getenv("OPENAI_API_KEY"), cache_storage=DiskCacheStorage()
    )
else:
    # Without this branch `backend` would be undefined (or a stale object from a
    # previous run of the notebook) when the path matches neither family.
    raise ValueError(
        "Cannot select an LLM backend: expected 'istral' or 'openai' in "
        f"model_config, got {model_config!r}"
    )

raw_responses = backend.infer_many(
    prompt_items=prompts_items,
    model_config=model_cfg,
)

# Index prompts by custom_id; setdefault keeps the first prompt for a repeated id,
# which is what a linear first-match search would return.
prompt_by_id = {}
for prompt_item in prompts_items:
    prompt_by_id.setdefault(prompt_item["custom_id"], prompt_item)

for raw_response in raw_responses:
    custom_id = raw_response["custom_id"]
    prompt_obj = prompt_by_id.get(custom_id)
    if prompt_obj is None:
        # Response for an id that was not sent: ignored.
        continue
    # Store the raw response with object_id and prompt for the records file
    prompt_response_records.append(
        {
            "object_id": custom_id,
            "prompt": prompt_obj["prompt"],
            "llm_response": raw_response,
        }
    )
    # parse raw response
    # NOTE(review): _parse_response is a private method of the backend.
    parsed_response = backend._parse_response(raw_response)
    parsed_response["custom_id"] = custom_id
    llm_responses.append(parsed_response)

print(f"Batch inference completed with {len(llm_responses)} responses")
elapsed_total = time.time() - start_time
print(f"Total time for batch inference: {elapsed_total:.2f} seconds")

In [None]:
# Build the results list (plain dicts, not a DataFrame): one entry per record
# that received a parsed response.
# Index responses by custom_id once instead of scanning the whole response list
# for every record; setdefault keeps the first response for a repeated id.
response_by_id = {}
for parsed in llm_responses:
    response_by_id.setdefault(parsed["custom_id"], parsed)

results = []
for rec in records:
    object_id = rec.get("object_id", "<unknown>")
    resp = response_by_id.get(object_id)
    if resp:
        results.append(
            {
                "object_id": object_id,
                "title": rec.get("title", "<no title>"),
                "abstract": rec.get("abstract", "<no abstract>"),
                "llm_response": resp,
            }
        )

# Save to JSON
with open(output_json, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=4, ensure_ascii=False)

print(f"Saved metadata extractions to {output_json}")

# Save records to JSONL file (one prompt/raw-response record per line)
with open(output_records_jsonl, "w", encoding="utf-8") as fp:
    for record in prompt_response_records:
        fp.write(json.dumps(record) + "\n")
print(f"Saved {len(prompt_response_records)} records to {output_records_jsonl}")

In [None]:
# reload the results from the output_json
with open(output_json, "r", encoding="utf-8") as f:
    publications_w_list_of_registries = json.load(f)

# Flatten the per-publication registry lists into one list of registry entries.
registry_name_list = []
for publication in publications_w_list_of_registries:
    object_id = publication["object_id"]
    # `or []` also covers a response where the key is present but null.
    list_of_registries = publication["llm_response"].get("List of Registry names") or []
    # `index` is the position of the registry within its own publication.
    for index, registry in enumerate(list_of_registries):
        registry_name_list.append(
            {
                'index': index,
                "registry_name": registry.get('registry_name', ''),
                "acronym": registry.get('acronym', ''),
                "is_official": registry.get('is_official', ''),
                "object_id": object_id,
            }
        )

# Report the real total: the per-publication index restarts at 0 for every
# publication, so it cannot be used as an overall count.
print(f"Found {len(registry_name_list)} registry names in total")

# Save the registry names list to a JSON file
with open(output_registry_names_list, "w", encoding="utf-8") as f:
    json.dump(registry_name_list, f, indent=4, ensure_ascii=False)

In [None]:
# === Define Inputs and Outputs for Registry Name Extractions ===

import json
import glob
from pathlib import Path

# INPUTS:
# Glob pattern matching every per-batch registry-related publication file (Section 1 outputs)
registry_publications_pattern = "data/from_scripts/SW01/nbtk_testing/R04_extraction_all/01_registry_related_publications_*.json"
print(f"Using registry publication files matching: {registry_publications_pattern}")

# Number of publications per extraction batch
batch_size_registry = 2000

# Model and prompt specifications:
FIELD = "registry_name"
MODEL = "small_mistral"
prompt_txt = f"etc/prompts/extraction/prompt_{FIELD}.txt"
model_config = f"etc/configs/{MODEL}_config.json"

# OUTPUTS:
# One directory per kind of output, all under the same extraction root
_extraction_root = Path("data/from_scripts/SW01/nbtk_testing/R04_extraction_all")
raw_inferences_dir = _extraction_root / f"02_{FIELD}_raw_inferences"
parsed_extractions_dir = _extraction_root / f"02_{FIELD}_extractions"
registry_list_dir = _extraction_root / f"02_{FIELD}_list"

# Create the output directories if they do not exist
for _output_dir in (raw_inferences_dir, parsed_extractions_dir, registry_list_dir):
    _output_dir.mkdir(parents=True, exist_ok=True)

print("Output directories:")
print(f" Raw extractions: {raw_inferences_dir}")
print(f" Parsed extractions: {parsed_extractions_dir}")
print(f" Registry names list: {registry_list_dir}")

In [None]:
# Collect the per-batch publication files in sorted path order
registry_files = sorted(glob.glob(registry_publications_pattern))
print(f"Found {len(registry_files)} registry-related publication files.")

# Concatenate the contents of every file into a single list
all_registry_publications = []
for registry_file in registry_files:
    with open(registry_file, "r", encoding="utf-8") as handle:
        all_registry_publications.extend(json.load(handle))
print(f"Aggregated {len(all_registry_publications)} registry-related publications.")

# Read the annotation prompt (surrounding whitespace stripped)
with open(prompt_txt, "r", encoding="utf-8") as handle:
    annotation_prompt = handle.read().strip()

# Read the model configuration; the model name is only used for log messages
with open(model_config, "r", encoding="utf-8") as handle:
    model_cfg = json.load(handle)
model_name = model_cfg.get("model", "unknown")
print(f"Using model: {model_name}")

In [None]:
import random

# Shuffle a copy rather than the aggregated list itself: shuffling in place would
# make this cell non-idempotent, because a second run would reshuffle an already
# shuffled list and produce different batches despite the fixed seed.
random.seed(42)
shuffled_registry_publications = all_registry_publications.copy()
random.shuffle(shuffled_registry_publications)

# Consecutive slices of `batch_size_registry` publications
batches = [
    shuffled_registry_publications[i : i + batch_size_registry]
    for i in range(0, len(shuffled_registry_publications), batch_size_registry)
]
print(f"Divided into {len(batches)} batches.")

In [None]:
# For each new batch, run the registry name extraction inference and save outputs.

# Decide the backend family from the config path before processing any batch, so an
# unrecognised config fails here with a clear message instead of surfacing later as
# a NameError (or silently reusing a `backend` left over from an earlier cell run).
is_openai_model = "openai" in model_config.lower()
is_mistral_model = "istral" in model_config.lower()
if not (is_mistral_model or is_openai_model):
    raise ValueError(
        "Cannot select an LLM backend: expected 'istral' or 'openai' in "
        f"model_config, got {model_config!r}"
    )

for idx, batch in enumerate(batches, start=1):
    print(f"\nProcessing extraction batch {idx} with {len(batch)} records...")

    # Prepare prompt items for inference: annotation prompt followed by title and abstract
    prompts_items = []
    for rec in batch:
        object_id = rec.get("object_id", "<unknown>")
        title = rec.get("title", "<no title>")
        abstract = rec.get("abstract", "<no abstract>")
        full_prompt = f"{annotation_prompt}\nText_to_analyze:\nTitle: {title}\nAbstract: {abstract}"
        prompts_items.append({"prompt": full_prompt, "custom_id": object_id})

    # Index prompts by custom_id; setdefault keeps the first prompt for a repeated id,
    # which is what a linear first-match search would return.
    prompt_by_id = {}
    for prompt_item in prompts_items:
        prompt_by_id.setdefault(prompt_item["custom_id"], prompt_item)

    # Lists for storing raw and parsed responses
    prompt_response_records = []
    llm_responses = []

    batch_start_time = time.time()
    print(f"Starting batch {idx} registry inference with {model_name}...")

    # Instantiate the backend; the Mistral branch takes precedence when both
    # substrings are present in the config path.
    if is_mistral_model:
        backend = llm_backends.MistralBatchBackend(
            api_key=os.getenv("MISTRAL_API_KEY"), cache_storage=DiskCacheStorage()
        )
    else:
        backend = llm_backends.OpenAIAsyncBackend(
            api_key=os.getenv("OPENAI_API_KEY"), cache_storage=DiskCacheStorage()
        )

    # Run inference
    raw_responses = backend.infer_many(
        prompt_items=prompts_items,
        model_config=model_cfg,
    )

    # Parse and record responses
    for raw_response in raw_responses:
        custom_id = raw_response["custom_id"]
        prompt_obj = prompt_by_id.get(custom_id)
        if prompt_obj is None:
            # Response for an id that was not sent in this batch: ignored.
            continue
        prompt_response_records.append({
            "object_id": custom_id,
            "prompt": prompt_obj["prompt"],
            "llm_response": raw_response,
        })
        # NOTE(review): _parse_response is a private method of the backend.
        parsed_response = backend._parse_response(raw_response)
        parsed_response["custom_id"] = custom_id
        llm_responses.append(parsed_response)

    elapsed = time.time() - batch_start_time
    print(f"Batch {idx} completed with {len(llm_responses)} responses in {elapsed:.2f} seconds")

    # Index parsed responses by custom_id (first response wins for a repeated id)
    response_by_id = {}
    for parsed in llm_responses:
        response_by_id.setdefault(parsed["custom_id"], parsed)

    # Build parsed results by matching each publication with its inference response
    results = []
    for rec in batch:
        object_id = rec.get("object_id", "<unknown>")
        resp = response_by_id.get(object_id)
        if resp:
            results.append({
                "object_id": object_id,
                "title": rec.get("title", "<no title>"),
                "abstract": rec.get("abstract", "<no abstract>"),
                "llm_response": resp,
            })

    # Extract registry names list from parsed responses
    registry_name_list = []
    for publication in results:
        object_id = publication["object_id"]
        # `or []` also covers a response where the key is present but null.
        list_of_registries = publication["llm_response"].get("List of Registry names") or []
        # `index` is the position of the registry within its own publication.
        for i, registry in enumerate(list_of_registries):
            registry_name_list.append({
                "index": i,
                "registry_name": registry.get("registry_name", ""),
                "acronym": registry.get("acronym", ""),
                "is_official": registry.get("is_official", ""),
                "object_id": object_id,
            })

    # Define output file paths for the current batch using the naming convention
    raw_out_file = raw_inferences_dir / f"{idx}.jsonl"
    parsed_out_file = parsed_extractions_dir / f"{idx}.json"
    list_out_file = registry_list_dir / f"{idx}.json"

    # Save raw extractions (each record on a new line in a JSONL file)
    with open(raw_out_file, "w", encoding="utf-8") as f:
        for rec in prompt_response_records:
            f.write(json.dumps(rec) + "\n")
    print(f"Saved raw extractions to {raw_out_file}")

    # Save parsed extractions (a JSON file with an array of records)
    with open(parsed_out_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=4, ensure_ascii=False)
    print(f"Saved parsed extractions to {parsed_out_file}")

    # Save extracted registry names list (JSON file)
    with open(list_out_file, "w", encoding="utf-8") as f:
        json.dump(registry_name_list, f, indent=4, ensure_ascii=False)
    print(f"Saved registry names list to {list_out_file}")