In [1]:
import os
working_dir = "/home/gpinon/more_europa/clean_rdc_experiments/projects/P05_refine_dedup"
os.chdir(working_dir)
print(f"Changed working directory to {working_dir}")
import logging
import time
import pandas as pd
import json
from pathlib import Path
from dotenv import load_dotenv
import weaviate
import boto3
from tqdm import tqdm

from src.p05_refine_dedup import config
from src.p05_refine_dedup.utils.s3_io_functions import (
    load_jsonl_from_s3,
    upload_jsonl_to_s3,
)
import llm_backends
from llm_backends.cache import DiskCacheStorage
from llm_backends.mistral import dummy_config
from llm_backends.openai import dummy_config

Changed working directory to /home/gpinon/more_europa/clean_rdc_experiments/projects/P05_refine_dedup


In [2]:
s3_input_registries_dir = "registry_data_catalog_experiments/P04_official_reg_db_creation/datasets_versions/update_medical_condition/registry_data"
local_output_embeddings_jsonl_template = "data/W01/R01_embed_registry_names/1.jsonl"
s3_output_embeddings_dir = "registry_data_catalog_experiments/P05_refine_dedup/registry_name_embeddings"

In [3]:
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s %(levelname)s %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger(__name__)
logging.getLogger("httpx").setLevel(logging.WARNING)

# Load environment variables from .env file
load_dotenv()

# Get the API key from environment variables
api_key = os.environ["MISTRAL_API_KEY"]
MISTRAL_EMBEDDING_MODEL = "mistral-embed"
MISTRAL_EMBEDDING_CONFIG = {
    "model": MISTRAL_EMBEDDING_MODEL,
}

In [4]:
# 1. Load registry data from S3
# Retrieve total_batches = how many files are in the folder
s3 = boto3.client("s3")
response = s3.list_objects_v2(
    Bucket=config.BUCKET_NAME_DEV, Prefix=s3_input_registries_dir
)
# total_batches = len(response.get("Contents", []))
total_batches = 1

registry_dataset = []
for batch_number in range(1, total_batches + 1):
    file_name = f"{batch_number}.jsonl"
    batch = load_jsonl_from_s3(
        config.BUCKET_NAME_DEV, s3_input_registries_dir, file_name
    )
    registry_dataset.extend(batch)

logger.warning(
    f"Loaded {len(registry_dataset)} records from the registry data, from {total_batches} batches."
)
# test on 5 records
registry_dataset = registry_dataset[:5]
logger.warning(f"Using a subset of {len(registry_dataset)} records for testing.")

2025-07-29 07:56:13,563 - botocore.credentials - INFO - Found credentials from IAM Role: vm-gpinon-role20241129125126670400000001


In [9]:
# 2. Initialize the backend for Mistral embeddings
backend = llm_backends.MistralEmbeddingBackend(
    api_key=os.getenv("MISTRAL_API_KEY"), cache_storage=DiskCacheStorage()
)

2025-07-29 08:07:37,519 - llm_backends.cache.disk.DiskCacheStorage - INFO - Disk cache initialized at: /home/gpinon/more_europa/clean_rdc_experiments/src/llm_backends/llm_backends/.cache


In [10]:
# 3. Prepare prompts for embeddings
# for all registries, create a new field "full_name" that is the concatenation of "registry_name" and "acronym" IF "acronym" is not empty string or None
for registry in registry_dataset:
    registry_name = registry.get("registry_name", "")
    acronym = registry.get("acronym", "")
    if acronym:
        full_name = f"{registry_name} ({acronym})"
    else:
        full_name = registry_name
    registry["full_name"] = full_name

prompt_items = [
    {"custom_id": registry["object_id"], "prompt": registry["full_name"]}
    for registry in registry_dataset
]

batch_size = 1000
# batch_prompts_list = []
# for i in range(0, len(prompt_items), batch_size):
#     batch_prompts_list.append(prompt_items[i : i + batch_size])
# # print how many batches we have
# print(f"Total batches created: {len(batch_prompts_list)}")

In [11]:
from itertools import islice

def batched(iterable, n, *, strict=False):
    # batched('ABCDEFG', 2) â†’ AB CD EF G
    if n < 1:
        raise ValueError('n must be at least one')
    iterator = iter(iterable)
    while batch := tuple(islice(iterator, n)):
        if strict and len(batch) != n:
            raise ValueError('batched(): incomplete batch')
        yield batch

In [12]:
# 4. Run inference to get embeddings
# track time of execution
start_time = time.time()
for prompt_items_batch in tqdm(
    batched(prompt_items, batch_size), total=1 + len(prompt_items) // batch_size
):
    # print prompt_items_batch
    for prompt_item in prompt_items_batch:
        print(
            f"Processing prompt item: {prompt_item['custom_id']} - {prompt_item['prompt']}"
        )

    # print model config used
    logger.warning(
        f"Using model config: {json.dumps(MISTRAL_EMBEDDING_CONFIG, indent=2)}"
    )
    for result_item in backend.infer_many(
        prompt_items=prompt_items_batch, model_config=MISTRAL_EMBEDDING_CONFIG
    ):
        registry_id = result_item["custom_id"]
        registry_embedding = result_item["embedding"]
        registry_name[registry_id]["registry_embedding"] = registry_embedding
    intermediate_time = time.time()
    logger.warning(f"Processed batch in {intermediate_time - start_time:.0f} seconds")
end_time = time.time()
# logg time in minutes
logger.warning(f"Total time for inference: {(end_time - start_time) / 60:.1f} minutes")

  0%|          | 0/1 [00:00<?, ?it/s]

Processing prompt item: 1 - Spanish ABPM Registry (ABPM)
Processing prompt item: 2 - Fasa Registry for Systolic Heart Failure (FARSH)
Processing prompt item: 3 - OnCovid Registry (OnCovid)
Processing prompt item: 4 - New York State Cancer Registry (NYSCR)
Processing prompt item: 5 - China Liver Transplant Registry (CLTR)
  "model": "mistral-embed"
}


  0%|          | 0/1 [00:00<?, ?it/s]


TypeError: 'str' object does not support item assignment

In [None]:
# 5. Save locally then on s3 bucket
# First create the batches
output_batches = []
for i in range(0, len(registry_dataset), batch_size):
    batch = registry_dataset[i : i + batch_size]
    output_batches.append(batch)

# Then Save to local files and s3
local_output_dir = Path(local_output_embeddings_jsonl_template).parent
local_output_dir.mkdir(parents=True, exist_ok=True)
for batch_number, batch in enumerate(output_batches, start=1):
    output_file = local_output_dir / f"{batch_number}.jsonl"
    with open(output_file, "w") as f:
        for record in batch:
            f.write(f"{record}\n")
    logger.warning(f"Saved batch {batch_number} to {output_file}")
    # Save to S3
    s3_file_name = f"{batch_number}.jsonl"
    upload_jsonl_to_s3(
        batch, config.BUCKET_NAME_DEV, s3_output_embeddings_dir, s3_file_name
    )

In [None]:
## Reload to check the results
from src.p05_refine_dedup.utils.s3_io_functions import (
    load_parquet_from_s3,
    upload_parquet_to_s3,
)

In [None]:
s3_input_embeddings = "registry_data_catalog_experiments/P05_refine_dedup/registry_name_embeddings.parquet"

bucket_name = config.BUCKET_NAME_DEV
