In [1]:
import os
working_dir = "/home/gpinon/more_europa/clean_rdc_experiments/projects/P04_official_reg_db_creation"
os.chdir(working_dir)
print(f"Changed working directory to {working_dir}")
import logging
import time
import pandas as pd
import json
from pathlib import Path
from dotenv import load_dotenv
import pycountry
import weaviate
import boto3
from tqdm import tqdm

from src.p04_official_reg_db_creation import config
from src.p04_official_reg_db_creation.utils  import load_jsonl_from_s3, upload_jsonl_to_s3
import llm_backends
from llm_backends.cache import DiskCacheStorage
from llm_backends.mistral import dummy_config
from llm_backends.openai import dummy_config

Changed working directory to /home/gpinon/more_europa/clean_rdc_experiments/projects/P04_official_reg_db_creation


# 0. Define variables and load data

In [2]:
FIELD = "medical_condition"
MODEL = "small_mistral"

# Load environment variables from .env file and get API key
load_dotenv()
MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")

In [3]:
# INPUTS
input_registry_data_jsonl_template = "registry_data_catalog_experiments/P04_official_reg_db_creation/registries_dataset_version2/v4/registry_dataset_with_publis_metadata/1.jsonl"
collection = "Publication_v2"
prompt_txt = prompt_txt=f"etc/prompts/extraction/prompt_{FIELD}.txt"
model_config=f"etc/configs/{MODEL}_config.json"

In [4]:
# OUTPUTS
local_output_raw_inferences_jsonl_template="data/from_scripts/SW01/R05_update_medical_condition/raw_inferences/1.jsonl"
local_output_publis_jsonl_template="data/from_scripts/SW01/R05_update_medical_condition/publi_data/1.jsonl"
local_output_registries_jsonl_template="data/from_scripts/SW01/R05_update_medical_condition/registry_data/1.jsonl"
s3_output_publis_dir = "registry_data_catalog_experiments/P04_official_reg_db_creation/datasets_versions/update_medical_condition/publi_data"
s3_output_registries_dir = "registry_data_catalog_experiments/P04_official_reg_db_creation/datasets_versions/update_medical_condition/registry_data"

## a. Load the registry data from s3 bucket

In [5]:
# load the official registry data
input_dir = Path(input_registry_data_jsonl_template).parent
input_dir_str = str(input_dir)
# retrieve total_batches = how many files are in the folder
s3 = boto3.client("s3")
response = s3.list_objects_v2(Bucket=config.BUCKET_NAME_DEV, Prefix=input_dir_str)
total_batches = len(response.get("Contents", []))
registry_dataset = []
for batch_number in range(1, total_batches + 1):
    file_name = f"{batch_number}.jsonl"
    batch = load_jsonl_from_s3(config.BUCKET_NAME_DEV, input_dir, file_name)
    registry_dataset.extend(batch)
print(f"Loaded {len(registry_dataset)} records from the registry data, from {total_batches} batches.")

2025-07-24 11:21:23,399 - botocore.credentials - INFO - Found credentials from IAM Role: vm-gpinon-role20241129125126670400000001
Loaded 54335 records from the registry data, from 28 batches.


## b. Load the publications data from weaviate

In [6]:
# Add the metadata of the publications, previously extracted from Weaviate, to the registries_dict
# track time of execution
start_time = time.time()

weaviate_client = weaviate.connect_to_custom(**config.WEAVIATE_PROD_CONF)
collections = weaviate_client.collections  #
# load publications
collection_publications = collections.get(collection)
# load data source names
publis_dataset_all = []
for item in collection_publications.iterator(include_vector=False):
    # Extract subset of properties
    publis_dataset_all.append(
        {
            k: v
            for k, v in item.properties.items()
            if k
            in [
                "object_id",
                "title",
                "abstract",
                "geographical_area",
                "medical_condition",
                # "outcome_measure",
                # "population_sex",
                # "population_age_group",
                # "population_size",
                # "population_follow_up",
            ]
        }
    )
# close weaviate connection
weaviate_client.close()

2025-07-24 11:21:30,851 - httpx - INFO - HTTP Request: GET https://weaviate-new.eu-more-europa-gpu.quinten.io/v1/.well-known/openid-configuration "HTTP/1.1 404 Not Found"
2025-07-24 11:21:30,920 - httpx - INFO - HTTP Request: GET https://weaviate-new.eu-more-europa-gpu.quinten.io/v1/meta "HTTP/1.1 200 OK"


2025-07-24 11:21:30,992 - httpx - INFO - HTTP Request: GET https://pypi.org/pypi/weaviate-client/json "HTTP/1.1 200 OK"


## c. Load the rest

In [7]:
# Create the output directories
local_output_raw_inferences_dir = Path(local_output_raw_inferences_jsonl_template).parent
local_output_publis_json_dir = Path(local_output_publis_jsonl_template).parent
local_output_registries_jsonl_dir = Path(local_output_registries_jsonl_template).parent

local_output_raw_inferences_dir.mkdir(parents=True, exist_ok=True)
local_output_publis_json_dir.mkdir(parents=True, exist_ok=True)
local_output_registries_jsonl_dir.mkdir(parents=True, exist_ok=True)

In [8]:
# Load model configuration
with open(model_config, "r", encoding="utf-8") as f:
    model_cfg = json.load(f)

model_name = model_cfg.get("model", "unknown")
print(f"Using model: {model_name}")

# Load the annotation prompt
with open(prompt_txt, "r", encoding="utf-8") as f:
    prompt_template = f.read().strip()

Using model: mistral-small-latest


## d. Filter which publis to process

In [9]:
# Retrieve the list of unique publication ids from registry_dataset
# registry_dataset["list_publi_ids"] is a list of publi ids
registry_publi_ids = set()
for record in registry_dataset:
    if isinstance(record.get("list_publi_ids"), list):
        registry_publi_ids.update(record["list_publi_ids"])
print(f"Found {len(registry_publi_ids)} unique publication IDs in the registry dataset.")
# Filter the publications dataset to only include records with IDs in registry_publi_ids (use 'object_id' as the key of the publication)
publis_dataset = [
    publi for publi in publis_dataset_all if publi["object_id"] in registry_publi_ids
]

Found 136643 unique publication IDs in the registry dataset.


In [10]:
# select first 20 publis for testing
publis_dataset = publis_dataset[:2000]

# 1. Prepare prompts

In [11]:
# Prepare prompts for LLMs
prompts_items = []
total_prompts = 0
for publi in publis_dataset:
    object_id = publi.get("object_id", "<unknown>")
    title = publi.get("title", "<no title>")
    abstract = publi.get("abstract", "<no abstract>")
    full_prompt = f"{prompt_template}\nText_to_analyze:\nTitle: {title}\nAbstract: {abstract}"
    prompts_items.append({"prompt": full_prompt, "custom_id": object_id})
    total_prompts += 1
            
# print the number of prompts prepared
print(f"Prepared {total_prompts} prompts for LLM processing.")

Prepared 2000 prompts for LLM processing.


In [12]:
# show fisrt item of prompts_items
print(f"First prompt item: {prompts_items[0]}")

First prompt item: {'prompt': 'CONTEXT:\nYou are an expert of real-world clinical studies, especially at litterature review. You are provided with a publication\'s title and abstract extracted from PubMed or Semantic Scholar, in which use or analysis of patient/medical registry was proven.\nRegister and registry are synonyms. \n\nDEFINITION:\nIn a clinical study or its publication, a medical condition is typically defined as a specific disease, disorder, health-related state, or procedure (or even environmental/demographic/lifestyle factors) that is the focus of the research and directly characterizes/defines the studied population or cohort\n\nDIFFERENCES.\nThis is to be distinguished from:\n1. Outcome measure / endpoint: Results or effects observed in a study.\n2. Risk factors: Characteristics increasing the likelihood of a condition.\n3. Main variables / descriptive statistics: Primary factors or data summaries analyzed in a study.\n4. Target/study population: The specific group of 

In [13]:
batch_prompts_list = []
batch_size = 1000
for i in range(0, len(prompts_items), batch_size):
    batch_prompts_list.append(prompts_items[i:i + batch_size])
# print how many batches we have
print(f"Total batches created: {len(batch_prompts_list)}")

Total batches created: 2


# 2. Make Inferences

In [14]:
batch_raw_responses_list = []
for batch_prompts in batch_prompts_list:
    # Run batch inference based on model type
    print(f"Starting batch inference with {model_name}...")

    is_openai_model = "openai" in model_config.lower()
    # if "istral" in the name of llm_judge_model_config, then we are using Mistral model
    is_mistral_model = "istral" in model_config.lower()
    if is_mistral_model:
        backend = llm_backends.MistralBatchBackend(
            api_key=os.getenv("MISTRAL_API_KEY"), cache_storage=DiskCacheStorage()
        )
    elif is_openai_model:
        backend = llm_backends.OpenAIAsyncBackend(
            api_key=os.getenv("OPENAI_API_KEY"), cache_storage=DiskCacheStorage()
        )

    batch_raw_responses = backend.infer_many(
        prompt_items=batch_prompts,
        model_config=model_cfg,
    )
    batch_raw_responses_list.append(batch_raw_responses)

Starting batch inference with mistral-small-latest...
2025-07-24 11:22:42,607 - llm_backends.cache.disk.DiskCacheStorage - INFO - Disk cache initialized at: /home/gpinon/more_europa/clean_rdc_experiments/src/llm_backends/llm_backends/.cache
Starting batch inference with mistral-small-latest...
2025-07-24 11:22:42,733 - llm_backends.cache.disk.DiskCacheStorage - INFO - Disk cache initialized at: /home/gpinon/more_europa/clean_rdc_experiments/src/llm_backends/llm_backends/.cache


In [15]:
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
logging.getLogger("httpx").setLevel(logging.WARNING)

In [16]:
# Create a list to store the records with object_id, prompt, and raw response
prompt_response_records = []

# Precompute a mapping from custom_id to prompt object
prompt_map = {p["custom_id"]: p for p in prompts_items}

batch_number = 1
batch_llm_responses_list = []
initial_time = time.time()

for batch_raw_responses in tqdm(batch_raw_responses_list, desc="Processing batches"):
    print(f"--- Processing Batch N°{batch_number} ---")
    start_time = time.time()
    llm_responses = []
    inference_number = 1
    for raw_response in tqdm(batch_raw_responses, desc=f"Batch {batch_number} processing", leave=False):
        custom_id = raw_response.get("custom_id", "")
        prompt_obj = prompt_map.get(custom_id)
        if prompt_obj:
            prompt_response_records.append({
                "custom_id": custom_id,
                "prompt": prompt_obj["prompt"],
                "llm_response": raw_response,
            })
            # Parse raw response and add additional info
            parsed_response = backend._parse_response(raw_response)
            parsed_response["custom_id"] = custom_id
            llm_responses.append(parsed_response)
        inference_number += 1
    
    elapsed_total = (time.time() - start_time) / 60  # Convert to minutes
    print(f"Batch inference completed with {len(llm_responses)} responses")
    print(f"Total time for inference : {elapsed_total:.1f} minutes\n")
    batch_llm_responses_list.append(llm_responses)
    batch_number += 1

total_computation_time = (time.time() - initial_time) / 60  # Convert to minutes
print(f"--> Total computation time for all batches: {total_computation_time:.1f} minutes <--")

Processing batches:   0%|          | 0/2 [00:00<?, ?it/s]

--- Processing Batch N°1 ---




2025-07-24 11:22:42,911 - llm_backends.cache.disk.DiskCacheStorage - INFO - Attempting to retrieve cache for key: d08b54889a133c93e45588ac98c227b34610fbe26ffdb355f3fd50ee80cdcfaf
2025-07-24 11:22:42,912 - llm_backends.cache.disk.DiskCacheStorage - INFO - Cache file not found for key: d08b54889a133c93e45588ac98c227b34610fbe26ffdb355f3fd50ee80cdcfaf


/home/jovyan/.pyenv/versions/3.11.11/envs/P04_official_reg_db_creation/lib/python3.11/site-packages/mistralai/models/batchjobin.py:42: PydanticDeprecatedSince211: Accessing the 'model_fields' attribute on the instance is deprecated. Instead, you should access this attribute from the model class. Deprecated in Pydantic V2.11 to be removed in V3.0.
  for n, f in self.model_fields.items():
  return self._state == HTTPConnectionState.IDLE


2025-07-24 11:23:05,345 - llm_backends.cache.disk.DiskCacheStorage - INFO - Storing cache for key: d08b54889a133c93e45588ac98c227b34610fbe26ffdb355f3fd50ee80cdcfaf
2025-07-24 11:23:05,375 - llm_backends.cache.disk.DiskCacheStorage - INFO - Cache stored successfully for key: d08b54889a133c93e45588ac98c227b34610fbe26ffdb355f3fd50ee80cdcfaf


Processing batches:  50%|█████     | 1/2 [00:22<00:22, 22.54s/it]

Batch inference completed with 1000 responses
Total time for inference : 0.4 minutes

--- Processing Batch N°2 ---




2025-07-24 11:23:05,453 - llm_backends.cache.disk.DiskCacheStorage - INFO - Attempting to retrieve cache for key: f4cccf89f1103be8443d620b5aa4de103683e4f7b4d3e6893848c0ac71b167fb
2025-07-24 11:23:05,454 - llm_backends.cache.disk.DiskCacheStorage - INFO - Cache file not found for key: f4cccf89f1103be8443d620b5aa4de103683e4f7b4d3e6893848c0ac71b167fb


  super().__init__()


2025-07-24 11:23:27,262 - llm_backends.cache.disk.DiskCacheStorage - INFO - Storing cache for key: f4cccf89f1103be8443d620b5aa4de103683e4f7b4d3e6893848c0ac71b167fb
2025-07-24 11:23:27,292 - llm_backends.cache.disk.DiskCacheStorage - INFO - Cache stored successfully for key: f4cccf89f1103be8443d620b5aa4de103683e4f7b4d3e6893848c0ac71b167fb


Processing batches: 100%|██████████| 2/2 [00:44<00:00, 22.22s/it]

Batch inference completed with 1000 responses
Total time for inference : 0.4 minutes

--> Total computation time for all batches: 0.7 minutes <--





# 3. Update the medical condition field in the publications dataset

In [17]:
# Prebuild a mapping from publication object_id to the publication record
pub_map = {pub.get("object_id", ""): pub for pub in publis_dataset}

# Update the publis_dataset with the extracted field using progress bars
for llm_responses in tqdm(batch_llm_responses_list, desc="Updating publications batches"):
    if not llm_responses:
        continue  # Skip empty batches
    for response in tqdm(llm_responses, desc="Processing responses", leave=False):
        publi_id = response.get("custom_id", "")
        updated_field = response.get(FIELD, None)       
        if updated_field is not None:
            publi = pub_map.get(publi_id)
            if publi is None:
                continue
            details = None
            formatted_details = []
            if "[" in updated_field and "]" in updated_field:
                start_idx = updated_field.index("[") + 1
                end_idx = updated_field.index("]")
                details = updated_field[start_idx:end_idx]
                formatted_details = [detail.strip() for detail in details.split(";") if detail.strip()]
                updated_field = updated_field.replace(details, "").replace("[", "").replace("]", "").strip()
            formatted_updated_field = [condition.strip() for condition in updated_field.split(";") if condition.strip()]
            publi[FIELD] = formatted_updated_field
            publi[f"{FIELD}_details"] = formatted_details if details else []

# Print the first 20 updated publications (their IDs and the updated field)
n = 20
print(f"First {n} updated publications with extracted field:")
for publi in publis_dataset[:n]:
    publi_id = publi.get("object_id", "<unknown>")
    extracted_field = publi.get(FIELD, "<not extracted>")
    print(f"Publication ID: {publi_id}")
    print(f"Extracted {FIELD}: {extracted_field}")
    print(f"Details: {publi.get(f'{FIELD}_details', '<no details>')}")
    print("---")

Updating publications batches: 100%|██████████| 2/2 [00:00<00:00, 85.64it/s]

First 20 updated publications with extracted field:
Publication ID: 0000572b-fd58-5db3-b3b5-56e70f2098b5
Extracted medical_condition: ['Fabry Disease']
Details: ['Classical Phenotype', 'Nonclassical Phenotype']
---
Publication ID: 000088cb-3888-5e61-8097-3ef8aad34609
Extracted medical_condition: ['Postpartum Breast Cancer']
Details: []
---
Publication ID: 000168f2-596b-5cef-8c17-fbf40d8c4312
Extracted medical_condition: ['Coronary Artery Disease']
Details: []
---
Publication ID: 0002004d-00b2-52d6-bb7d-ba344e8ab8ec
Extracted medical_condition: ['Non-Small Cell Lung Cancer']
Details: ['Brain Metastases', 'Other Single Organ Metastases']
---
Publication ID: 000295f9-c288-59c0-84a0-37f19af4f42a
Extracted medical_condition: ['Nonvalvular Atrial Fibrillation']
Details: []
---
Publication ID: 0003bfcc-a2d1-59a4-98e8-e5600a19f04b
Extracted medical_condition: ['End-Stage Lung Disease']
Details: ['Chronic Obstructive Pulmonary Disease', 'Cystic Fibrosis']
---
Publication ID: 0003c511-1875-5d3d-




# 4. Update the updated field in the registry_dataset

In [18]:
def format_string(string):
    """Format string to remove unwanted characters."""
    # remove punctuation and special characters, lower case
    return ''.join(e for e in string if e.isalnum() or e.isspace()).lower().strip()

In [19]:
# Prebuild a dictionary mapping publication object_ids to their medical conditions
pub_dict = {publi.get("object_id"): publi.get(FIELD, []) for publi in publis_dataset}

# now update updated_field in registry_dataset, as a list of all the medical conditions found in the publications,
# and the number of times they occurred, then rank them by count (highest to lowest)
for registry in tqdm(registry_dataset, desc="Processing registries"):
    terms_counts = {}
    for pub_id in registry.get("list_publi_ids", []):
        updated_field = pub_dict.get(pub_id, [])
        for term in updated_field:
            formatted_term = format_string(term)
            if formatted_term:
                key = formatted_term.title()
                terms_counts[key] = terms_counts.get(key, 0) + 1
    # Rank the medical conditions by count (highest first)
    ranked_terms = dict(sorted(terms_counts.items(), key=lambda item: item[1], reverse=True))
    registry[FIELD] = ranked_terms

# Print the first 3 registries with their ranked medical condition counts
for registry in registry_dataset[:3]:
    print(f"Registry: {registry.get('registry_name', 'Unknown')}")
    print(f"Ranked Medical Conditions Counts: {registry.get(FIELD, {})}\n")

Processing registries: 100%|██████████| 54335/54335 [00:00<00:00, 188454.69it/s]

Registry: Spanish ABPM Registry
Ranked Medical Conditions Counts: {}

Registry: Fasa Registry for Systolic Heart Failure
Ranked Medical Conditions Counts: {}

Registry: OnCovid Registry
Ranked Medical Conditions Counts: {'Breast Cancer': 1}






# 5. Save the results: registries and publications datasets

In [20]:
# save in batches the raw responses to output_raw_inferences_dir
# print number of batches created
print(f"Total batches created for raw responses: {len(prompt_response_records) // batch_size + 1}")
for i in range(0, len(prompt_response_records), batch_size):
    batch = prompt_response_records[i:i + batch_size]
    batch_number = (i // batch_size) + 1
    local_file_path = local_output_raw_inferences_dir / f"{batch_number}.jsonl"
    with open(local_file_path, "w", encoding="utf-8") as fp:
        for record in batch:
            fp.write(json.dumps(record) + "\n")
    print(f"Saved batch {batch_number} of raw responses to {local_file_path}")

Total batches created for raw responses: 3


Saved batch 1 of raw responses to data/from_scripts/SW01/R05_update_medical_condition/raw_inferences/1.jsonl
Saved batch 2 of raw responses to data/from_scripts/SW01/R05_update_medical_condition/raw_inferences/2.jsonl


In [21]:
# Save the publications info with updated field, in batches, using output_publis_json_dir
# print number of batches created
print(f"Total batches created for publications: {len(publis_dataset) // batch_size + 1}")
for i in range(0, len(publis_dataset), batch_size):
    batch = publis_dataset[i:i + batch_size]
    batch_number = i // batch_size + 1
    local_file_path = local_output_publis_json_dir / f"{batch_number}.jsonl"
    # Save the publications with updated field
    with open(local_file_path, "w", encoding="utf-8") as fp:
        for record in batch:
            fp.write(json.dumps(record) + "\n")
    s3_file_name = f"{batch_number}.jsonl"
    upload_jsonl_to_s3(
        batch,
        config.BUCKET_NAME_DEV,
        s3_output_publis_dir,
        s3_file_name
    )

Total batches created for publications: 3




In [22]:
# Save the registry info with updated field, in batches, using output_publis_json_dir
# print number of batches created
print(f"Total batches created for registry data: {len(registry_dataset) // batch_size + 1}")
for i in range(0, len(registry_dataset), batch_size):
    batch = publis_dataset[i:i + batch_size]
    batch_number = i // batch_size + 1
    local_file_path = local_output_registries_jsonl_dir / f"{batch_number}.jsonl"
    # Save the publications with updated field
    with open(local_file_path, "w", encoding="utf-8") as fp:
        for record in batch:
            fp.write(json.dumps(record) + "\n")
    s3_file_name = f"{batch_number}.jsonl"
    upload_jsonl_to_s3(
        batch,
        config.BUCKET_NAME_DEV,
        s3_output_registries_dir,
        s3_file_name,
    )


Total batches created for registry data: 55


