In [None]:
import os
import re
from datetime import datetime
from pathlib import Path

import numpy as np
import polars as pl
import sentence_transformers
import torch
from model2vec import StaticModel
from sentence_transformers import SentenceTransformer
from tqdm import tqdm

# Set the path (not needed now)

# path = "/home/alex/ews/diagnoses"

# os.chdir(path)

# print("Current working directory: ", os.getcwd()) # And here we can check it

In [2]:
# Set up the output directory. It can be overridden with the EWS_DIAGNOSES_DIR
# environment variable instead of editing the hard-coded absolute path.
output_dir = os.environ.get("EWS_DIAGNOSES_DIR", "/home/alex/ews/diagnoses")


def batch_sort_key(path):
    """Natural-sort key so that ``..._batch_2.parquet`` sorts before ``..._batch_10.parquet``."""
    # re.split with a capturing group alternates text (even positions) and digit runs
    # (odd positions), so elements at the same position always share a type.
    return [
        int(token) if position % 2 else token
        for position, token in enumerate(re.split(r"(\d+)", path.name))
    ]


# List all processed batch files. Path.glob yields files in arbitrary (filesystem)
# order; sort them so the row order of the combined embeddings is reproducible.
processed_files = sorted(
    Path(output_dir).glob("diagnoses_ordered_new*.parquet"), key=batch_sort_key
)

In [5]:
# Sanity-check that the data for a single batch look right (check passed).
# Resolve the file against `output_dir` rather than the kernel's current working
# directory, because the os.chdir in the first cell is commented out.
small = pl.read_parquet(Path(output_dir) / "diagnoses_ordered_new_batch_1.parquet")
small.head()

In [7]:
# Load the sentence-transformer model used to embed the Previous_Diagnoses text.
# NOTE: the original author states that the embeddings extracted with this model are
# not used downstream ("not important here"). TODO confirm this, since the last cell
# saves them "for later use".
# truncate_dim keeps only the first EMBEDDING_DIM dimensions of each embedding.
EMBEDDING_DIM = 30
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model = SentenceTransformer(
    "tomaarsen/static-similarity-mrl-multilingual-v1",
    # NOTE(review): trust_remote_code executes Python code shipped with the model repo;
    # confirm it is actually required for this model and drop it if not.
    trust_remote_code=True,
    truncate_dim=EMBEDDING_DIM,
    device=DEVICE,
)

In [None]:
# Create a function which extracts the embeddings per batch and then concatenates all the information

def process_batches(processed_files, model, verbose_every=50):
    """Embed the ``Previous_Diagnoses`` of every batch file and combine everything into one frame.

    Each parquet file is processed in the given order:
    - The ``PT_ID``, ``CSN`` and ``Previous_Diagnoses`` columns are read.
    - Null diagnoses are replaced with the literal string ``"NA"``.
    - The diagnoses are encoded with ``model.encode`` using L2-normalised embeddings.

    Files with zero rows are skipped.

    Args:
        processed_files: Iterable of parquet file paths to process.
        model: Embedding model with a SentenceTransformer-style ``encode`` method.
            Encoding runs on whatever device the model was loaded on.
        verbose_every: Print a detailed progress report every N batches
            (default: 50). ``None`` or a value <= 0 disables the periodic report.

    Returns:
        A Polars DataFrame with columns ``PT_ID``, ``CSN`` and ``Previous_Diagnoses``,
        followed by ``diagn_embed_0`` ... ``diagn_embed_{d-1}`` (one per embedding
        dimension). Rows appear in file order.

    Raises:
        ValueError: If ``processed_files`` is empty or every file has zero rows.
    """
    processed_files = list(processed_files)
    total_batches = len(processed_files)
    if total_batches == 0:
        # Fail fast with a clear message instead of a cryptic np.concatenate error
        # (and a division by zero in the final summary).
        raise ValueError(
            "process_batches: no batch files to process (check output_dir and the glob pattern)"
        )

    # Guard against `% 0` and treat None / non-positive values as "no periodic report".
    report_every = verbose_every if verbose_every and verbose_every > 0 else None

    all_pt_ids = []
    all_csn = []
    all_previous_diagnoses = []
    all_embeddings = []

    start_time = datetime.now()

    for idx, batch_file in enumerate(tqdm(processed_files, desc="Processing batches")):
        # Read only the needed columns, and replace missing diagnoses with the string
        # "NA" so that every row can be encoded.
        batch_df = pl.read_parquet(
            batch_file, columns=["PT_ID", "CSN", "Previous_Diagnoses"]
        ).with_columns(pl.col("Previous_Diagnoses").fill_null("NA"))

        if batch_df.height > 0:
            diagnoses = batch_df["Previous_Diagnoses"].to_list()
            # No explicit `device=`: encode() uses the device the model was loaded on,
            # so this also works on CPU-only machines.
            embeddings_batch = model.encode(
                diagnoses,
                show_progress_bar=False,
                normalize_embeddings=True,
            )

            # Store PT_IDs, CSNs, Previous_Diagnoses, and embeddings
            all_pt_ids.extend(batch_df["PT_ID"].to_list())
            all_csn.extend(batch_df["CSN"].to_list())
            all_previous_diagnoses.extend(diagnoses)
            all_embeddings.append(embeddings_batch)

        # Release the batch frame before the next file is read.
        del batch_df

        batches_done = idx + 1
        if report_every is not None and batches_done % report_every == 0:
            elapsed_time = datetime.now() - start_time
            avg_time_per_batch = elapsed_time / batches_done
            estimated_remaining_time = (total_batches - batches_done) * avg_time_per_batch
            embeddings_gb = sum(e.nbytes for e in all_embeddings) / 1e9
            # tqdm.write keeps the report from garbling the live progress bar.
            tqdm.write(
                f"\n{'=' * 80}\n"
                f"Batch Progress Report ({batches_done}/{total_batches}):\n"
                f"Total PT_IDs processed: {len(all_pt_ids)}\n"
                f"Memory usage of embeddings: {embeddings_gb:.2f} GB\n"
                f"Time elapsed: {elapsed_time}\n"
                f"Estimated time remaining: {estimated_remaining_time}\n"
                f"Average time per batch: {avg_time_per_batch.total_seconds():.2f} seconds\n"
                f"{'=' * 80}\n"
            )

    if not all_embeddings:
        raise ValueError("process_batches: every batch file is empty; nothing to embed")

    # Stack the per-batch embedding matrices row-wise and give each dimension its own column.
    final_embeddings = np.concatenate(all_embeddings, axis=0)
    embedding_cols = [f"diagn_embed_{i}" for i in range(final_embeddings.shape[1])]

    embeddings_df = pl.DataFrame(
        {
            "PT_ID": all_pt_ids,
            "CSN": all_csn,
            "Previous_Diagnoses": all_previous_diagnoses,
            **{col: final_embeddings[:, i] for i, col in enumerate(embedding_cols)},
        }
    )

    # Print final summary
    total_time = datetime.now() - start_time
    print(f"\n{'='*80}")
    print("Final Processing Summary:")
    print(f"Total batches processed: {total_batches}")
    print(f"Total PT_IDs processed: {len(all_pt_ids)}")
    print(f"Final DataFrame shape: {embeddings_df.shape}")
    print(f"Total processing time: {total_time}")
    print(f"Average time per batch: {total_time.total_seconds()/total_batches:.2f} seconds")
    print(f"{'='*80}\n")

    return embeddings_df

# Run the embedding pipeline over every batch file listed in `processed_files`,
# printing a progress report every 50 batches.
embeddings_df = process_batches(
    processed_files=processed_files,
    model=model,
    verbose_every=50,
)


In [None]:
# `embeddings_df` holds the patient identifiers (PT_ID), hospitalization numbers (CSN),
# the previous-diagnoses text, and one `diagn_embed_*` column per embedding dimension.
# Display its (rows, columns) dimensions.
(embeddings_df.height, embeddings_df.width)

In [10]:
# Save the embeddings as a parquet file next to the batch files in `output_dir`
# (used later on). Writing is opt-in so that re-running the notebook does not
# overwrite an existing file; set SAVE_EMBEDDINGS = True to (re)write it.
SAVE_EMBEDDINGS = False
if SAVE_EMBEDDINGS:
    embeddings_df.write_parquet(Path(output_dir) / "embed_diagnoses_updated_prevs.parquet")