In [5]:
from contextlib import contextmanager
from datetime import datetime
from fastembed import TextEmbedding
from fuzzywuzzy import fuzz
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from requests.exceptions import HTTPError
from ratelimit import limits, sleep_and_retry
from tabulate import tabulate
from tqdm import tqdm

import requests
import faiss
import json
import numpy as np
import os
import pandas as pd
import re
import string
import subprocess
import time
import sys
import pickle
import ast

class LLMClient:
    """Class to manage LLM API configurations and queries."""
    def __init__(self, api_key, base_url, model_name="llama3-groq-70b-8192-tool-use-preview", max_tokens_per_day=500000, max_queries_per_minute=30, temperature=0.7):
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name
        self.max_tokens_per_day = max_tokens_per_day
        self.max_queries_per_minute = max_queries_per_minute
        self.total_tokens_used = 0
        self.temperature = temperature
        self.headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}

    def query(self, user_input, system_message):
        """Sends a query to the LLM API."""
        print(f"TOTAL_TOKENS_USED before query: {self.total_tokens_used}")

        # Check token limit
        estimated_tokens = len(user_input.split()) + len(system_message.split())
        if self.total_tokens_used + estimated_tokens > self.max_tokens_per_day:
            raise Exception("Token limit exceeded for the day.")

        # Enforce rate limit
        time.sleep(60 / self.max_queries_per_minute)

        # Construct the payload
        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": system_message},
                {"role": "user", "content": user_input}
            ],
            "temperature": self.temperature,  # Use the temperature from the instance
        }

        # Send the request
        response = requests.post(self.base_url, headers=self.headers, json=payload)
        response.raise_for_status()  # Raise an error for bad responses

        # Update token usage
        self.total_tokens_used += estimated_tokens

        # Parse and return the response content
        result = response.json()
        return result["choices"][0]["message"]["content"] if "choices" in result else "No content returned."

# Global Constants
FLAG_FILE = "process_completed.flag"
INITIALIZED = False  # Tracks if the function has been called
TOTAL_TOKENS_USED = 0  # Initialize token usage tracking globally
MAX_QUERIES_PER_MINUTE = 30
MAX_TOKENS_PER_DAY = 500000
MAX_QUERIES_PER_DAY = MAX_QUERIES_PER_MINUTE * 60 * 24

def load_prompts(file_path="system_prompts.json"):
    with open(file_path, "r") as file:
        return json.load(file)
    
# Load prompts
prompts = load_prompts()

# Access prompts when needed
system_message_I = prompts["system_message_I"]
system_message_II = prompts["system_message_II"]

def initialize_groq_environment():
    """Initializes the LLMClient with user input or default settings."""
    global INITIALIZED, llm_client

    if INITIALIZED:
        print("GROQ environment is already initialized. Skipping.")
        return

    # Gather inputs from the user
    api_key = input("Enter your API key (required): ").strip()
    if not api_key:
        raise ValueError("API key is required to proceed.")

    base_url = input("Enter base URL for API use (default 'https://api.groq.com/openai/v1/chat/completions'): ").strip()
    base_url = base_url if base_url else "https://api.groq.com/openai/v1/chat/completions"

    model_name = input("Enter LLM model name (default 'llama3-groq-70b-8192-tool-use-preview'): ").strip()
    model_name = model_name if model_name else "llama3-groq-70b-8192-tool-use-preview"

    max_tokens_per_day = input("Enter max tokens per day (default 500000): ").strip()
    max_tokens_per_day = int(max_tokens_per_day) if max_tokens_per_day else MAX_TOKENS_PER_DAY

    max_queries_per_minute = input("Enter max queries per minute (default 30): ").strip()
    max_queries_per_minute = int(max_queries_per_minute) if max_queries_per_minute else MAX_QUERIES_PER_MINUTE

    temperature = input("Enter the temperature for the model (default 0.2, range 0.0-1.0): ").strip()
    try:
        temperature = float(temperature) if temperature else 0.2
        if not 0.0 <= temperature <= 1.0:
            raise ValueError
    except ValueError:
        print("Invalid temperature value. Setting to default (0.7).")
        temperature = 0.7

    # Create the LLMClient instance
    llm_client = LLMClient(api_key, base_url, model_name, max_tokens_per_day, max_queries_per_minute)
    llm_client.temperature = temperature  # Add the temperature to the client object

    # Mark initialization as complete
    INITIALIZED = True

    # Save process flag
    with open(FLAG_FILE, "w") as flag:
        flag.write("Process completed.")

    print("GROQ environment successfully initialized.")
    print(f"  API Key: {'*' * len(api_key)} (hidden)")
    print(f"  Base URL: {base_url}") 
    print(f"  Max Queries Per Minute: {max_queries_per_minute}")
    print(f"  Max Tokens Per Day: {max_tokens_per_day}")
    print(f"  Model Name: {model_name}")
    print(f"  Temperature: {temperature}")

def generate_hpo_terms(df, system_message):
    """Generates HPO terms from a DataFrame of phrases and metadata."""
    responses = []
    for _, row in df.iterrows():
        user_input = row['phrase']
        unique_metadata_list = row['unique_metadata']
        original_sentence = row['original_sentence']
        
        # Prepare the human message with context
        context_items = []
        for item in unique_metadata_list:
            parsed_item = clean_and_parse(item)
            if parsed_item:
                for description, hp_id in parsed_item.items():
                    context_items.append(f"- {description} ({hp_id})")

        context_text = '\n'.join(context_items)
        human_message = (f"Query: {user_input}\n"
                         f"Original Sentence: {original_sentence}\n"
                         f"Context: The following related information is available to assist in determining the appropriate HPO terms:\n"
                         f"{context_text}")
        
        # Query the LLM using the unified function
        response_text = llm_client.query(human_message, system_message)
        
        # Extract HPO terms with regex
        hpo_terms = re.findall(r'HP:\d+', response_text)
        if hpo_terms:
            responses.append({"phrase": user_input, "response": ', '.join(hpo_terms)})
        else:
            responses.append({"phrase": user_input, "response": 'No HPO terms found'})
    
    return pd.DataFrame(responses)

# Context manager for subprocess handling
@contextmanager
def managed_subprocess(*args, **kwargs):
    """Manages a subprocess, ensuring it terminates properly upon completion or failure."""
    proc = subprocess.Popen(*args, **kwargs)
    try:
        yield proc
    finally:
        proc.terminate()  # Ensures proper termination of the subprocess
        proc.wait()

# Function for printing with timestamps
def timestamped_print(message):
    """Prints a message with the current timestamp for easy log tracking."""
    print(f"{time.strftime('%Y-%m-%d %H:%M:%S')} - {message}")

# Embeddings related functions
def initialize_embeddings_model():
    """Initializes the embeddings model for processing clinical notes."""
    model_name = "BAAI/bge-small-en-v1.5"
    try:
        embeddings_model = TextEmbedding(model_name=model_name)
        return embeddings_model
    except Exception:
        exit(1)  # Exit on failure to initialize the embeddings model
        print("Error: Unable to initialize the embeddings model.")

def load_embedded_documents(file_path):
    """Loads embedded documents (embeddings) from a given file path."""
    if os.path.exists(file_path):
        return np.load(file_path, allow_pickle=True)
    else:
        exit(1)  # Exit if file not found
        print("Error: File not found.")

def prepare_embeddings_list(embedded_documents):
    """Converts the embedded documents into a NumPy array of embeddings."""
    embeddings_list = [np.array(doc['embedding']) for doc in embedded_documents if isinstance(doc['embedding'], np.ndarray) and doc['embedding'].size > 0]
    if not embeddings_list:
        exit(1)  # Exit if no valid embeddings found
    first_embedding_size = embeddings_list[0].shape[0]  # Ensure uniform embedding size
    return np.vstack([emb for emb in embeddings_list if emb.shape[0] == first_embedding_size])

def create_faiss_index(embeddings_array):
    """Creates a FAISS index for efficient similarity searching on embeddings."""
    if embeddings_array.dtype != np.float32:
        embeddings_array = embeddings_array.astype(np.float32)  # Ensure the correct data type for FAISS
    dimension = embeddings_array.shape[1]  # Determine embedding dimensionality
    index = faiss.IndexFlatL2(dimension)  # L2 (Euclidean) distance index
    index.add(embeddings_array)  # Add embeddings to FAISS index
    return index

def process_row(clinical_note, system_message_I, embeddings_model, index, embedded_documents):
    """Processes a single clinical note by extracting findings and matching metadata."""
    findings_text = llm_client.query(clinical_note, system_message_I)  # Extract findings from clinical note
    if not findings_text:
        return None  # If no findings, skip processing

    findings = extract_findings(findings_text)  # Extract findings from the note
    if not findings:
        return None  # Skip processing if no findings are extracted

    results_df = process_findings(findings, clinical_note, embeddings_model, index, embedded_documents)  # Process and match findings to metadata
    return results_df

def extract_findings(response_content):
    """Extracts findings (key information) from the response content generated by the LLM."""
    try:
        data = json.loads(response_content)  # Parse the JSON content
        findings = data.get("findings", [])
        return findings
    except json.JSONDecodeError:
        return []  # Return an empty list if the content cannot be parsed

def process_findings(findings, clinical_note, embeddings_model, index, embedded_documents):
    """Matches findings with their most relevant metadata entries from embeddings."""
    results = []
    # Split the clinical note into sentences for matching findings to specific contexts
    sentences = clinical_note.split('.')
    
    for finding in findings:
        # Embed the query phrase (finding) using the embeddings model
        query_vector = np.array(list(embeddings_model.embed([finding]))[0]).astype(np.float32).reshape(1, -1)
        distances, indices = index.search(query_vector, 800)  # Search for the most similar embeddings
        
        seen_metadata = set()
        unique_metadata = []

        for idx in indices[0]:
            # Retrieve metadata for the matched embedding
            metadata = embedded_documents[idx]['unique_metadata']
            metadata_str = json.dumps(metadata)  # Convert the metadata dict to a string

            if metadata_str not in seen_metadata:
                seen_metadata.add(metadata_str)  # Track unique metadata
                unique_metadata.append(metadata_str)
                if len(unique_metadata) == 20:  # Limit to the first 20 unique metadata items
                    break

        # Find the best matching sentence from the clinical note for the finding
        finding_words = set(re.findall(r'\b\w+\b', finding.lower()))
        best_match_sentence = None
        max_matching_words = 0

        for sentence in sentences:
            sentence_words = set(re.findall(r'\b\w+\b', sentence.lower()))
            common_words = finding_words & sentence_words

            if len(common_words) > max_matching_words:
                max_matching_words = len(common_words)
                best_match_sentence = sentence.strip()  # Store the sentence with the most matching words

        # Store the results for this finding
        results.append({
            "phrase": finding,
            "unique_metadata": unique_metadata,  # Save unique metadata for this finding
            "original_sentence": best_match_sentence})  # Save the best matching sentence 

    # Convert the results to a DataFrame and save as a CSV file
    # The resulting CSV file will contain the extracted findings and their matched metadata 
    faiss_results_df = pd.DataFrame(results)
    faiss_results_df.to_csv('faiss_search_results.csv', index=False)
    return faiss_results_df  # Return the DataFrame of results

# Helper functions for text cleaning and metadata processing
#Cleans and parses a JSON string by fixing formatting issues.
def clean_and_parse(json_str):
    try:
        json_str = json_str.replace("'", '"')  # Replace single quotes with double quotes
        json_str = re.sub(r'\s+', ' ', json_str)  # Remove excess whitespace
        return json.loads(json_str)  # Return parsed JSON
    except json.JSONDecodeError:
        return None  # Return None if the string cannot be parsed

def process_unique_metadata(metadata):
    """Processes unique metadata by converting all keys to lowercase."""
    if isinstance(metadata, list):
        processed_list = []
        for item in metadata:
            try:
                item_dict = json.loads(item)  # Convert string back to dictionary
                processed_item = {k.lower(): v for k, v in item_dict.items()}  # Make keys lowercase
                processed_list.append(json.dumps(processed_item))  # Convert back to string
            except (json.JSONDecodeError, TypeError):
                continue
        return processed_list  # Return the list of processed metadata
    else:
        return []

def clean_text(text):
    """Cleans input text by converting to lowercase and removing punctuation."""
    text = text.lower()  # Convert to lowercase
    text = text.translate(str.maketrans('', '', string.punctuation)).strip()  # Remove punctuation and trim whitespace
    return text

def extract_hpo_term(phrase, metadata_list):
    """ Extracts HPO terms by matching phrases against a list of metadata. The metadata is loaded from the unique_metadata field in the npy array."""
    cleaned_phrase = clean_text(phrase)  # Clean the input phrase
    fuzzy_matches = []

    # Step 1: Fuzzy Matching with Metadata
    for metadata in metadata_list:
        try:
            # Convert metadata string to dictionary if necessary
            metadata_dict = json.loads(metadata) if isinstance(metadata, str) else metadata
            for term, hp_id in metadata_dict.items():
                cleaned_term = clean_text(term)  # Clean the metadata term
                if fuzz.ratio(cleaned_phrase, cleaned_term) > 80:  # Check for high similarity
                    fuzzy_matches.append({term: hp_id})
        except (json.JSONDecodeError, TypeError):
            continue  # Skip invalid metadata entries

    # If we have fuzzy matches, extend the list
    if fuzzy_matches:
        metadata_list.extend([json.dumps(match) for match in fuzzy_matches])

    # Step 2: Exact Substring Matching
    exact_matches = []
    for metadata in metadata_list:
        try:
            metadata_dict = json.loads(metadata) if isinstance(metadata, str) else metadata
            for term, hp_id in metadata_dict.items():
                cleaned_term = clean_text(term)  # Clean the metadata term
                if cleaned_term in cleaned_phrase:  # Check if term is a substring of the phrase
                    exact_matches.append({term: hp_id})
        except (json.JSONDecodeError, TypeError):
            continue  # Skip invalid metadata entries

    # If we have exact matches, extend the list
    if exact_matches:
        metadata_list.extend([json.dumps(match) for match in exact_matches])

    # Step 3: Exact Matching within Metadata List
    for metadata in metadata_list:
        if not metadata.strip():
            continue
        try:
            metadata_dict = json.loads(metadata) if isinstance(metadata, str) else metadata
            for term, hp_id in metadata_dict.items():
                cleaned_term = clean_text(term)
                if cleaned_phrase == cleaned_term:  # Check for an exact match
                    print(f"Exact match found: {hp_id}")
                    return hp_id  # Return the exact match
        except (json.JSONDecodeError, TypeError):
            continue  # Skip invalid metadata entries

    return None  # Return None if no match is found


# Process the results
def process_results(final_result_df):
    save_or_display = input("Do you want to save the results as a CSV file or display them? (save/display): ").strip().lower()
    
    if save_or_display == 'save':
        output_file = input("Enter the name of the output file (with .csv extension): ").strip()
        
        # Prepare the structured data for saving
        new_data = []
        for idx, row in final_result_df.iterrows():
            patient_id = row['patient_id']
            hpo_terms = row['HPO_Terms']  # Directly access the list of dictionaries
            
            # Process each term
            for term in hpo_terms:
                phrase = term.get('phrase', '').strip()
                hpo_id = term.get('HPO_Term', '').replace("HP:HP:", "HP:")
                
                if not hpo_id:
                    print(f"Blank HPO_Term for patient_id {patient_id} with phrase '{phrase}'")
                
                new_data.append({
                    'Patient ID': patient_id,
                    'Phenotype name': phrase,
                    'HPO ID': hpo_id
                })
        
        # Create a new DataFrame and save it
        new_df = pd.DataFrame(new_data)
        new_df.to_csv(output_file, index=False)
        json_csv_file = f"{output_file}_json.csv"
        final_result_df.to_csv(json_csv_file, index=False)
        timestamped_print(f"Data has been successfully saved to {output_file}")
    
    elif save_or_display == 'display':
        # Prepare the structured data for display
        flattened_data = []
        for idx, row in final_result_df.iterrows():
            patient_id = row['patient_id']
            hpo_terms = row['HPO_Terms']  # Directly access the list of dictionaries
            
            # Process each term
            for term in hpo_terms:
                flattened_data.append({
                    'Case': f"Case {patient_id}",
                    'Phenotype name': term.get('phrase', '').strip(),
                    'HPO ID': term.get('HPO_Term', '').replace("HP:HP:", "HP:")
                })
        
        if flattened_data:
            flattened_df = pd.DataFrame(flattened_data)
            print(tabulate(flattened_df, headers='keys', tablefmt='psql'))
        else:
            timestamped_print("No HPO terms found to display.")
    else:
        print("Invalid choice. Please choose either 'save' or 'display'.")
        
if __name__ == "__main__":
    if check_and_initialize(): # Run your one-time initialization process here
        initialize_groq_environment()
    try:
        start_time = time.time()  # Record the start time
        timestamped_print("Starting the HPO term extraction process.")

        # Check for temporary files at the very beginning
        temp_files = [
            'temp_combined_results.pkl',
            'temp_exact_matches.pkl',
            'temp_non_exact_matches.pkl',
            'temp_final_result.pkl',
            'responses_backup.pkl'
        ]
        temp_files_exist = any(os.path.exists(f) for f in temp_files)

        if temp_files_exist:
            timestamped_print("Temporary files found from a previous run. Attempting to resume processing.")

            # Load existing results if they exist
            combined_results_df = pd.read_pickle('temp_combined_results.pkl') if os.path.exists('temp_combined_results.pkl') else pd.DataFrame()
            exact_matches_df = pd.read_pickle('temp_exact_matches.pkl') if os.path.exists('temp_exact_matches.pkl') else pd.DataFrame()
            non_exact_matches_df = pd.read_pickle('temp_non_exact_matches.pkl') if os.path.exists('temp_non_exact_matches.pkl') else pd.DataFrame()
            final_result_df = pd.read_pickle('temp_final_result.pkl') if os.path.exists('temp_final_result.pkl') else pd.DataFrame()
            responses_df = pd.read_pickle('responses_backup.pkl') if os.path.exists('responses_backup.pkl') else pd.DataFrame()

            # Debugging statements to check DataFrame shapes
            timestamped_print(f"Combined results DataFrame shape: {combined_results_df.shape}")
            timestamped_print(f"Exact matches DataFrame shape: {exact_matches_df.shape}")
            timestamped_print(f"Non-exact matches DataFrame shape: {non_exact_matches_df.shape}")
            timestamped_print(f"Final results DataFrame shape: {final_result_df.shape}")
            timestamped_print(f"Responses DataFrame shape: {responses_df.shape if 'responses_df' in locals() else 'Not loaded'}")

            # Determine which steps need to be executed based on which temp files are missing
            steps_to_run = []

            if combined_results_df.empty:
                steps_to_run.append('process_clinical_notes')
            if exact_matches_df.empty or non_exact_matches_df.empty:
                steps_to_run.append('process_exact_non_exact_matches')
            if not non_exact_matches_df.empty and ('HPO_Term' not in non_exact_matches_df.columns or non_exact_matches_df['HPO_Term'].isna().any()):
                steps_to_run.append('generate_hpo_terms')
            if final_result_df.empty:
                steps_to_run.append('compile_final_results')

            timestamped_print(f"Steps to run: {steps_to_run}")
        else:
            timestamped_print("No temporary files found. Starting a new processing run.")

            # Option for user input
            user_input = input("Do you want to provide clinical notes directly (yes/no)? ").strip().lower()
            if user_input == 'yes':
                clinical_notes = []
                while True:
                    note = input("Enter clinical note (or type 'done' to finish): ").strip()
                    if note.lower() == 'done':
                        break
                    clinical_notes.append(note)
                df = pd.DataFrame({'clinical_note': clinical_notes})

                # Assign patient IDs if not provided
                df['patient_id'] = range(1, len(df) + 1)
            else:
                while True:
                    input_file = input("Enter the filename of the CSV containing clinical notes: ").strip()
                    if input_file.lower().endswith('.csv'):
                        try:
                            df = pd.read_csv(input_file)

                            # Check for 'patient_id' or 'case number' columns
                            if 'patient_id' in df.columns:
                                df['patient_id'] = df['patient_id']
                            elif 'case number' in df.columns:
                                df['patient_id'] = df['case number']
                            else:
                                # Assign new patient IDs if not present
                                df['patient_id'] = range(1, len(df) + 1)

                            break
                        except FileNotFoundError:
                            print("File not found. Please ensure the file exists and the path is correct.")
                        except pd.errors.EmptyDataError:
                            print("The file is empty. Please provide a valid CSV file with data.")
                        except Exception as e:
                            print(f"An error occurred: {e}")
                    else:
                        print("The file must have a .csv extension. Please provide a valid CSV file.")

            df = df.dropna(subset=['clinical_note'])  # Remove rows where 'clinical_note' is NaN
            df['clinical_note'] = df['clinical_note'].astype(str)  # Ensure 'clinical_note' column is of type string

            # Initialize empty DataFrames
            combined_results_df = pd.DataFrame()
            exact_matches_df = pd.DataFrame()
            non_exact_matches_df = pd.DataFrame()
            responses_df = pd.DataFrame()
            final_result_df = pd.DataFrame()

            steps_to_run = ['process_clinical_notes', 'process_exact_non_exact_matches', 'generate_hpo_terms', 'compile_final_results']

        # Proceed with processing based on steps_to_run
        if 'process_clinical_notes' in steps_to_run:
            timestamped_print("Processing clinical notes.")

            # Initialize models and data
            timestamped_print("Initializing embeddings model")
            embeddings_model = initialize_embeddings_model()
            timestamped_print("Loading embedded documents")
            embedded_documents = load_embedded_documents('G2GHPO_metadata.npy')
            timestamped_print("Preparing embeddings list")
            embeddings_array = prepare_embeddings_list(embedded_documents)
            timestamped_print("Creating FAISS index")
            index = create_faiss_index(embeddings_array)

            # Determine which patient_ids still need to be processed
            if not combined_results_df.empty:
                processed_patient_ids = set(combined_results_df['patient_id'].unique())
            else:
                processed_patient_ids = set()

            all_patient_ids = set(df['patient_id'].unique())
            remaining_patient_ids = all_patient_ids - processed_patient_ids

            if remaining_patient_ids:
                timestamped_print(f"Processing {len(remaining_patient_ids)} new clinical notes.")

                # Process each remaining clinical note
                for _, row in df[df['patient_id'].isin(remaining_patient_ids)].iterrows():
                    clinical_note = row['clinical_note']
                    patient_id = row['patient_id']
                    timestamped_print(f"Processing clinical note for patient_id {patient_id}: {clinical_note[:30]}...")
                    result_df = process_row(clinical_note, system_message_I, embeddings_model, index, embedded_documents)
                    if result_df is not None:
                        result_df['patient_id'] = patient_id
                        combined_results_df = pd.concat([combined_results_df, result_df], ignore_index=True)
                    time.sleep(2)
                    combined_results_df.to_pickle('temp_combined_results.pkl')
            else:
                timestamped_print("All clinical notes have been processed in 'temp_combined_results.pkl'.")

        else:
            timestamped_print("Combined results loaded from 'temp_combined_results.pkl'.")

        if combined_results_df.empty:
            timestamped_print("No data in combined_results_df. Exiting.")
            sys.exit(0)

        combined_results_df['phrase'] = combined_results_df['phrase'].str.lower()
        combined_results_df['unique_metadata'] = combined_results_df['unique_metadata'].apply(process_unique_metadata)

        if 'process_exact_non_exact_matches' in steps_to_run:
            timestamped_print("Processing exact and non-exact matches.")

            # Add HPO terms for exact matches
            combined_results_df['HPO_Term'] = combined_results_df.apply(
                lambda row: extract_hpo_term(row['phrase'], row['unique_metadata']), axis=1
            )

            # Separate exact and non-exact matches
            exact_matches_df = combined_results_df.dropna(subset=['HPO_Term'])
            non_exact_matches_df = combined_results_df[combined_results_df['HPO_Term'].isna()]
            exact_matches_df.to_pickle('temp_exact_matches.pkl')
            non_exact_matches_df.to_pickle('temp_non_exact_matches.pkl')
        else:
            timestamped_print("Exact and non-exact matches loaded from temporary files.")

        # Ensure 'HPO_Term' column exists and is of correct type
        if 'HPO_Term' in non_exact_matches_df.columns:
            # Identify indices where 'HPO_Term' is NaN (unprocessed entries)
            remaining_indices = non_exact_matches_df[non_exact_matches_df['HPO_Term'].isna()].index
            timestamped_print(f"Number of unprocessed non-exact matches: {len(remaining_indices)}")
        else:
            remaining_indices = non_exact_matches_df.index
            # Initialize 'HPO_Term' column with NaN values
            non_exact_matches_df['HPO_Term'] = np.nan
            timestamped_print(f"Total non-exact matches to process: {len(remaining_indices)}")

        # Check if non_exact_matches_df is empty
        if non_exact_matches_df.empty:
            timestamped_print("No non-exact matches found. Skipping HPO term generation for non-exact matches.")
            # Ensure 'HPO_Term' column exists
            if 'HPO_Term' not in non_exact_matches_df.columns:
                non_exact_matches_df['HPO_Term'] = pd.Series(dtype="object")
            # Remove 'generate_hpo_terms' from steps_to_run if present
            if 'generate_hpo_terms' in steps_to_run:
                steps_to_run.remove('generate_hpo_terms')

        if 'generate_hpo_terms' in steps_to_run:
            timestamped_print("Generating HPO terms for non-exact matches.")
            timestamped_print(f"Processing {len(remaining_indices)} non-exact matches.")

            # Initialize responses DataFrame if empty
            if responses_df.empty:
                responses_df = pd.DataFrame(columns=['response'])

            # Process the remaining non-exact matches
            if len(remaining_indices) > 0:
                for idx in tqdm(remaining_indices, total=len(remaining_indices)):
                    row = non_exact_matches_df.loc[idx]
                    response_df = generate_hpo_terms(
                        pd.DataFrame([row]),
                        system_message_II
                    )
                    if response_df['response'].iloc[0] is not None:
                        responses_df = pd.concat([responses_df, response_df], ignore_index=True)
                        non_exact_matches_df.at[idx, 'HPO_Term'] = response_df['response'].iloc[0]
                    else:
                        # Handle cases where response is None due to repeated failures
                        non_exact_matches_df.at[idx, 'HPO_Term'] = 'Error: Unable to process'
                    time.sleep(2)
                    if idx % 10 == 0:
                        responses_df.to_pickle('responses_backup.pkl')
                        non_exact_matches_df.to_pickle('temp_non_exact_matches.pkl')
                # Save final progress
                responses_df.to_pickle('responses_backup.pkl')
                non_exact_matches_df.to_pickle('temp_non_exact_matches.pkl')
            else:
                timestamped_print("No unprocessed non-exact matches found.")
        else:
            timestamped_print("All non-exact matches have been processed.")

        if 'compile_final_results' in steps_to_run:
            timestamped_print("Compiling final results.")

            # Combine results and prepare final output
            final_combined_df = pd.concat([exact_matches_df, non_exact_matches_df], ignore_index=True)
            final_combined_df_grouped = final_combined_df.groupby('patient_id').apply(
                lambda group: group[['phrase', 'HPO_Term']].to_dict('records')
            )

            final_result_df = pd.DataFrame({
                'patient_id': final_combined_df_grouped.index,
                'HPO_Terms': final_combined_df_grouped.values
            })
            final_result_df.to_pickle('temp_final_result.pkl')
        else:
            timestamped_print("Final results loaded from 'temp_final_result.pkl'.")

        # Process the final results
        process_results(final_result_df) 

        timestamped_print(f"Total execution time: {time.time() - start_time:.2f} seconds")

        # Delete temporary files
        for temp_file in temp_files:
            if os.path.exists(temp_file):
                os.remove(temp_file)

    except Exception as e:
        timestamped_print(f"An unexpected error occurred: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

Initialization has already been completed. Skipping.
2024-12-18 16:24:07 - Starting the HPO term extraction process.
2024-12-18 16:24:07 - No temporary files found. Starting a new processing run.
2024-12-18 16:24:20 - Processing clinical notes.
2024-12-18 16:24:20 - Initializing embeddings model


Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

2024-12-18 16:24:20 - Loading embedded documents
2024-12-18 16:24:21 - Preparing embeddings list
2024-12-18 16:24:21 - Creating FAISS index
2024-12-18 16:24:21 - Processing 1 new clinical notes.
2024-12-18 16:24:21 - Processing clinical note for patient_id 1: A 32-year-old man presented to...
TOTAL_TOKENS_USED before query: 1885
2024-12-18 16:24:26 - Processing exact and non-exact matches.
Exact match found: HP:0001903
Exact match found: HP:0001250
Exact match found: HP:0012185
Exact match found: HP:0003002
Exact match found: HP:0005231
Exact match found: HP:0001009
2024-12-18 16:24:26 - Number of unprocessed non-exact matches: 17
2024-12-18 16:24:26 - Generating HPO terms for non-exact matches.
2024-12-18 16:24:26 - Processing 17 non-exact matches.


  0%|          | 0/17 [00:00<?, ?it/s]

TOTAL_TOKENS_USED before query: 2367


  6%|▌         | 1/17 [00:04<01:09,  4.35s/it]

TOTAL_TOKENS_USED before query: 2613


 12%|█▏        | 2/17 [00:08<01:04,  4.31s/it]

TOTAL_TOKENS_USED before query: 2896


 18%|█▊        | 3/17 [00:12<01:00,  4.29s/it]

TOTAL_TOKENS_USED before query: 3152


 24%|██▎       | 4/17 [00:17<00:55,  4.30s/it]

TOTAL_TOKENS_USED before query: 3437


 29%|██▉       | 5/17 [00:21<00:51,  4.30s/it]

TOTAL_TOKENS_USED before query: 3769


 35%|███▌      | 6/17 [00:25<00:47,  4.30s/it]

TOTAL_TOKENS_USED before query: 4021


 41%|████      | 7/17 [00:30<00:42,  4.29s/it]

TOTAL_TOKENS_USED before query: 4331


 47%|████▋     | 8/17 [00:34<00:38,  4.30s/it]

TOTAL_TOKENS_USED before query: 4656


 53%|█████▎    | 9/17 [00:38<00:34,  4.30s/it]

TOTAL_TOKENS_USED before query: 4943


 59%|█████▉    | 10/17 [00:42<00:30,  4.30s/it]

TOTAL_TOKENS_USED before query: 5294


 65%|██████▍   | 11/17 [00:47<00:25,  4.30s/it]

TOTAL_TOKENS_USED before query: 5619


 71%|███████   | 12/17 [00:51<00:21,  4.30s/it]

TOTAL_TOKENS_USED before query: 5923


 76%|███████▋  | 13/17 [00:55<00:17,  4.31s/it]

TOTAL_TOKENS_USED before query: 6298


 82%|████████▏ | 14/17 [01:00<00:12,  4.32s/it]

TOTAL_TOKENS_USED before query: 6559


 88%|████████▊ | 15/17 [01:04<00:08,  4.31s/it]

TOTAL_TOKENS_USED before query: 6872


 94%|█████████▍| 16/17 [01:08<00:04,  4.31s/it]

TOTAL_TOKENS_USED before query: 7132


100%|██████████| 17/17 [01:13<00:00,  4.30s/it]


2024-12-18 16:25:39 - Compiling final results.
+----+--------+--------------------------------+--------------------+
|    | Case   | Phenotype name                 | HPO ID             |
|----+--------+--------------------------------+--------------------|
|  0 | Case 1 | anaemia                        | HP:0001903         |
|  1 | Case 1 | epilepsy                       | HP:0001250         |
|  2 | Case 1 | carpal tunnel syndrome         | HP:0012185         |
|  3 | Case 1 | breast cancer                  | HP:0003002         |
|  4 | Case 1 | chronic gastritis              | HP:0005231         |
|  5 | Case 1 | telangiectasia                 | HP:0001009         |
|  6 | Case 1 | heart arrhythmia               | HP:0011675         |
|  7 | Case 1 | c5 nerve root compression      | HP:0003406         |
|  8 | Case 1 | bowel cancer                   | HP:0003003         |
|  9 | Case 1 | gastric mucosa                 | HP:0004295         |
| 10 | Case 1 | minor foveolar hyperplasia 