In [None]:
# LLM Evaluation of Sinusitis Surgery Recommendation

# This script evaluates LLM clinical decision-making for sinusitis surgery.
# The workflow is as follows:
# 1.  Setup: Load libraries and configure API keys.
# 2.  Data Loading: Load and combine multiple years of patient data from BigQuery.
# 3.  Preprocessing:
#     - Group all records by patient ID.
#     - For each patient, create a sorted, longitudinal clinical note history.
#     - Identify the date of the earliest surgery (if any) and keep only notes dated before it.
#     - Censor any sentences mentioning surgical plans to create a "blinded" note for the LLM.
#     - Aggregate all other clinical data (labs, meds, demographics).
#     - Create a clean, flat DataFrame where each row is a unique patient.
# 4.  LLM Analysis:
#     - Generate a structured prompt for each case.
#     - Send the prompt to the GPT-4 API and parse the JSON response.
# 5.  Evaluation:
#     - Compare the LLM's decision against actual surgery CPT codes.
#     - Calculate accuracy, precision, recall, and F1-score.
#     - Analyze the LLM's confidence on correct vs. incorrect predictions.
# 6.  Output: Save the full results and a sample for human evaluation to CSV files.

# Setup 

In [None]:
# Standard libraries
import pandas as pd
import numpy as np
import os
import re
import openai
import json
import matplotlib.pyplot as plt

# For evaluation
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# For progress bars
from tqdm.auto import tqdm 

# For BigQuery access
from google.cloud import bigquery

In [None]:
# Setup - OpenAI Authentication
# Key resolution order:
#   1. 'openai_key.txt' in the working directory (takes precedence, as before).
#   2. An OPENAI_API_KEY that is already present in the environment.
try:
    with open("openai_key.txt", "r") as f:
        os.environ["OPENAI_API_KEY"] = f.read().strip()
except FileNotFoundError as e:
    # Raise instead of calling exit() so the failure shows up as a normal
    # traceback and stops a "Run All" at this cell.
    if not os.environ.get("OPENAI_API_KEY"):
        raise RuntimeError(
            "OpenAI key file not found. Make sure 'openai_key.txt' is in the directory, "
            "or set the OPENAI_API_KEY environment variable."
        ) from e

# An empty key (e.g. a blank key file) would only fail later, on the first API call.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError("The OpenAI API key is empty. Check the contents of 'openai_key.txt'.")

openai.api_key = os.environ["OPENAI_API_KEY"]

def query_openai(prompt, model="gpt-4", temperature=0.2):
    """Send a prompt to an OpenAI chat model and return the reply text.

    Works with both generations of the `openai` package: the client-based
    interface (openai>=1.0, detected via `openai.OpenAI`) and the legacy
    module-level `openai.ChatCompletion` interface (openai<1.0). The legacy
    call was removed in 1.0, so without this check every request on a newer
    install would fail and be reported only as a printed error.

    Args:
        prompt: The user message to send.
        model: Chat model name. Defaults to "gpt-4".
        temperature: Sampling temperature. Defaults to 0.2.

    Returns:
        The content of the first choice's message, or None if the request
        raised any exception (the error is printed, not re-raised, so a
        single failed case does not abort a long evaluation run).
    """
    messages = [
        {"role": "system", "content": "You are an expert otolaryngologist. Your task is to provide a surgical recommendation based on the provided patient data. Respond only in the requested JSON format."},
        {"role": "user", "content": prompt}
    ]
    try:
        if hasattr(openai, "OpenAI"):
            # openai>=1.0: api_key=None makes the client fall back to OPENAI_API_KEY.
            client = openai.OpenAI(api_key=getattr(openai, "api_key", None))
            response = client.chat.completions.create(
                model=model,
                messages=messages,
                temperature=temperature
            )
        else:
            # openai<1.0 legacy interface.
            response = openai.ChatCompletion.create(
                model=model,
                messages=messages,
                temperature=temperature
            )
        return response.choices[0].message.content
    except Exception as e:
        print(f"An error occurred with the OpenAI API: {e}")
        return None

In [None]:
# Data Loading - Read in datasets from BigQuery
# GCP Configuration
PROJECT_ID = "som-nero-phi-roxanad-entllm"

print("Connecting to BigQuery to load datasets...")
client = bigquery.Client(project=PROJECT_ID)

# List of datasets to be combined
DATASET_IDS = [
    'starr_phi_confidential_stride_79984_ChronicSinusitis_2016',
    'starr_phi_confidential_stride_79984_ChronicSinusitis_2017',
    'starr_phi_confidential_stride_79984_ChronicSinusitis_2018',
    'starr_phi_confidential_stride_79984_ChronicSinusitis_2019',
    'starr_phi_confidential_stride_79984_ChronicSinusitis_2020_21',
    'starr_phi_confidential_stride_79984_ChronicSinusitis_2022_23',
    'starr_phi_confidential_stride_79984_ChronicSinusitis_2024_25',
]

# Name of tables to load from each dataset
DATA_TABLES = [
    'demographics',
    'clinical_note',
    'procedures',
    'labs',
    'med_orders',
    'radiology_report'
]

# Maps table name -> DataFrame holding that table's rows from every yearly dataset
dataframes = {}

for table_name in DATA_TABLES:
    print(f"Loading and unioning table: '{table_name}'...")
    try:
        # Construct a query to UNION this specific table across all yearly datasets.
        # SELECT * with UNION ALL relies on the table having the same columns in every dataset.
        union_query = "\nUNION ALL\n".join(
            f"SELECT * FROM `{PROJECT_ID}.{dataset_id}.{table_name}`" for dataset_id in DATASET_IDS
        )
        # Load the data and store it in our dictionary
        dataframes[table_name] = client.query(union_query).to_dataframe()
        print(f" -> Loaded {len(dataframes[table_name])} total rows for '{table_name}'.")
    except Exception as e:
        print(f"\n--- QUERY FAILED for table '{table_name}' ---")
        print(f"An error occurred: {e}")
        print("Please check that this table exists in all of your yearly datasets.")
        # Re-raise rather than exit(): later cells need every table, and the
        # original traceback is the most useful diagnostic.
        raise

In [None]:
# Preprocessing - Setup
print("Preprocessing data...")

# Define keywords and CPT codes for efficient searching.
# CPT codes are stored as strings; membership tests against them are exact-match.
SURGERY_CPT_CODES = {
    '31253', '31254', '31255', '31256', '31257', '31259', '31267',
    '31276', '31287', '31288', '31240'
}
# Lower-case lab name fragments.
LAB_KEYWORDS = {'cbc', 'eosinophil count', 'eosinophil %', 'bmp', 'cmp', 'ige', 'igg', 'iga', 'igm', 'sinus culture', 'nasal culture'}
# Lower-case medication name fragments. 'methylprednisolone' is the correct drug
# name; the earlier spelling 'methylprednisone' is kept so that any free text
# using that spelling still matches, but on its own it can never match
# 'methylprednisolone' under substring search.
MED_KEYWORDS = {
    'amoxicillin', 'amoxicillin-clavulanate', 'doxycycline', 'trimethoprim-sulfamethoxazole', 'clindamycin',
    'levofloxacin', 'cefdinir', 'moxifloxacin', 'cefuroxime', 'azithromycin', 'mupirocin', 'gentamicin',
    'tobramycin', 'vancomycin', 'prednisone', 'methylprednisolone', 'methylprednisone', 'dexamethasone',
    'budesonide', 'mometasone', 'fluticasone', 'azelastine', 'saline rinse'
}
DIAGNOSTIC_ENDOSCOPY_CPT_CODES = {'31231', '31237'}

# Short acronyms are matched as whole words only. As plain substrings, 'ess'
# also hits "assessment", "progress", "less", "illness", and 'fess' hits
# "professional", which censored large amounts of unrelated clinical text.
_CENSOR_ACRONYMS = ('fess', 'ess')
# Longer terms keep substring semantics so that variants such as
# "postoperative", "scheduled" or "surgically" are still caught.
_CENSOR_PHRASES = (
    'surgery', 'operative', 'operation', 'surgical',
    'recommend proceeding with', 'plan for', 'schedule', 'consent for',
)
_CENSOR_PATTERN = re.compile(
    r"\b(?:" + "|".join(re.escape(term) for term in _CENSOR_ACRONYMS) + r")\b"
    + "|" + "|".join(re.escape(term) for term in _CENSOR_PHRASES),
    re.IGNORECASE,
)


def censor_surgical_plans(text):
    """Remove sentences that mention surgical plans or recommendations.

    The text is split into sentences on '.', '!' or '?' followed by
    whitespace; any sentence matching a censor term (case-insensitively) is
    dropped and the remaining sentences are re-joined with single spaces.

    Unlike the previous implementation, the surviving text keeps its original
    letter case; matching alone is case-insensitive.

    Args:
        text: The note text. Non-string values are returned unchanged.

    Returns:
        The text with matching sentences removed (possibly an empty string),
        or the input itself if it is not a string.
    """
    if not isinstance(text, str):
        return text  # e.g. None / NaN for a missing note
    # Split after sentence-ending punctuation that is followed by whitespace.
    sentences = re.split(r'(?<=[.!?])\s+', text)
    kept_sentences = [s for s in sentences if not _CENSOR_PATTERN.search(s)]
    return " ".join(kept_sentences)

# Helper functions for parsing nested data

# "ENT" must be a standalone word: as a plain substring, 'ent' also matches
# "patient", "present", "resident", ..., which let nearly every note through.
_ENT_SPECIALTY_PATTERN = re.compile(r"\bent\b|otolaryngology", re.IGNORECASE)


def extract_ent_notes(notes_list):
    """Extract text from ENT outpatient progress notes.

    A note is kept when its 'type' contains "progress note, outpatient"
    (case-insensitive) and either its 'author' or its 'text' mentions ENT
    (as a whole word) or otolaryngology.

    Args:
        notes_list: List of note dicts with optional 'type', 'author' and
            'text' keys. Missing or None values are treated as empty strings.

    Returns:
        The texts of the matching notes joined by "\\n---\\n", or "" when
        `notes_list` is not a list or nothing matches.
    """
    if not isinstance(notes_list, list):
        return ""
    ent_notes_text = []
    for note in notes_list:
        # `or ''` guards against keys that are present but hold None.
        note_type = str(note.get('type') or '').lower()
        author = str(note.get('author') or '')
        text = str(note.get('text') or '')
        if "progress note, outpatient" not in note_type:
            continue
        if _ENT_SPECIALTY_PATTERN.search(author) or _ENT_SPECIALTY_PATTERN.search(text):
            ent_notes_text.append(text)
    return "\n---\n".join(ent_notes_text)

def check_list_for_keywords(items, key_name, keywords):
    """Report whether any keyword occurs in a given field of any dict in a list.

    For every entry in `items`, the value stored under `key_name` (or '' when
    the key is absent) is stringified and lower-cased, then tested for
    substring containment against each keyword.

    Returns:
        True on the first hit; False if there is no hit or `items` is not a list.
    """
    if not isinstance(items, list):
        return False
    for entry in items:
        for term in keywords:
            if term in str(entry.get(key_name, '')).lower():
                return True
    return False

def process_procedures(procedures_list):
    """Flag whether a procedure list contains surgical / diagnostic-endoscopy CPT codes.

    Only entries whose 'code_type' equals 'CPT' are inspected; their 'code' is
    tested for membership in SURGERY_CPT_CODES and DIAGNOSTIC_ENDOSCOPY_CPT_CODES.

    Returns:
        A `(had_surgery, had_endoscopy)` tuple of booleans; `(False, False)`
        when `procedures_list` is not a list.
    """
    if not isinstance(procedures_list, list):
        return False, False

    surgery_seen = False
    endoscopy_seen = False
    for entry in procedures_list:
        if entry.get('code_type') != 'CPT':
            continue  # ignore non-CPT coding systems
        cpt = entry.get('code')
        surgery_seen |= cpt in SURGERY_CPT_CODES
        endoscopy_seen |= cpt in DIAGNOSTIC_ENDOSCOPY_CPT_CODES
    return surgery_seen, endoscopy_seen


def extract_radiology_report(reports_list):
    """Collect the text of sinus-related CT reports.

    A report qualifies when its lower-cased 'type' contains 'ct' and its
    lower-cased 'title' contains 'sinus', 'paranasal' or 'nasal'.

    Returns:
        The non-empty texts of qualifying reports joined by "\\n---\\n", or ""
        when `reports_list` is not a list or nothing qualifies.
    """
    if not isinstance(reports_list, list):
        return ""

    anatomy_terms = ('sinus', 'paranasal', 'nasal')

    def _is_sinus_ct(report):
        modality = str(report.get('type', '')).lower()
        heading = str(report.get('title', '')).lower()
        return 'ct' in modality and any(term in heading for term in anatomy_terms)

    candidate_texts = [report.get('text', '') for report in reports_list if _is_sinus_ct(report)]
    # Drop falsy texts (None, '') before joining.
    return "\n---\n".join(body for body in candidate_texts if body)

In [None]:
# Preprocessing - Main Processing Loop

# "ENT" must match as a standalone word. The previous pattern "ent|otolaryngology"
# also matched "patient", "present", "treatment", ..., so the author/text check
# accepted almost every outpatient progress note regardless of specialty.
ENT_NOTE_PATTERN = r"\bent\b|otolaryngology"

processed_patient_records = []

# Get a unique list of all patients from the demographics table
unique_patient_ids = dataframes['demographics']['patient_id'].unique()

# NOTE(review): each iteration scans every table in full to select one patient's
# rows; for large cohorts, pre-grouping the tables by patient_id would avoid that.
for patient_id in tqdm(unique_patient_ids, desc="Processing Patients"):
    # --- Fetch all records for this specific patient from each table ---
    patient_notes = dataframes['clinical_note'][dataframes['clinical_note']['patient_id'] == patient_id]
    patient_procedures = dataframes['procedures'][dataframes['procedures']['patient_id'] == patient_id]
    patient_labs = dataframes['labs'][dataframes['labs']['patient_id'] == patient_id]
    patient_meds = dataframes['med_orders'][dataframes['med_orders']['patient_id'] == patient_id]
    patient_radiology = dataframes['radiology_report'][dataframes['radiology_report']['patient_id'] == patient_id]
    # Takes the last demographics row for the patient.
    # NOTE(review): "most recent" holds only if rows arrive in chronological order — confirm.
    patient_demographics = dataframes['demographics'][dataframes['demographics']['patient_id'] == patient_id].iloc[-1]

    # --- Perform the longitudinal logic ---
    # Sort notes by date to create the history
    all_notes_sorted = patient_notes.sort_values(by='date', ascending=True)

    # Find the earliest surgery date for this patient
    # NOTE(review): unlike process_procedures(), this does not check code_type == 'CPT' — confirm intended.
    earliest_surgery_date = pd.NaT
    surgeries = patient_procedures[patient_procedures['code'].isin(SURGERY_CPT_CODES)]
    if not surgeries.empty:
        earliest_surgery_date = pd.to_datetime(surgeries['date']).min()

    # Keep only notes dated strictly BEFORE the earliest surgery; patients
    # without a surgery keep their full note history.
    if pd.notna(earliest_surgery_date):
        notes_for_llm_df = all_notes_sorted[pd.to_datetime(all_notes_sorted['date']) < earliest_surgery_date]
    else:
        notes_for_llm_df = all_notes_sorted

    # Filter for relevant ENT progress notes: outpatient progress notes whose
    # author or text mentions ENT / otolaryngology.
    is_outpatient_progress = notes_for_llm_df['type'].str.contains("progress note, outpatient", case=False, na=False)
    mentions_ent = (
        notes_for_llm_df['author'].str.contains(ENT_NOTE_PATTERN, case=False, na=False)
        | notes_for_llm_df['text'].str.contains(ENT_NOTE_PATTERN, case=False, na=False)
    )
    ent_notes_mask = is_outpatient_progress & mentions_ent

    final_ent_notes = notes_for_llm_df[ent_notes_mask]

    if final_ent_notes.empty:
        continue # Skip patient if they have no relevant notes before the decision point

    # Combine and censor notes
    longitudinal_summary = "\n---\n".join(final_ent_notes['text'].dropna())
    censored_summary = censor_surgical_plans(longitudinal_summary)

    # Use the helper function to keep only relevant CT sinus reports
    radiology_text = extract_radiology_report(patient_radiology.to_dict('records'))

    # Combine data for individual patient
    processed_patient_records.append({
        'patient_id': patient_id,
        'age': patient_demographics.get('age'), 'race': patient_demographics.get('race'),
        'ethnicity': patient_demographics.get('ethnicity'), 'sex': patient_demographics.get('legal_sex'),
        'longitudinal_summary_statement': longitudinal_summary,
        'censored_summary_statement': censored_summary,
        'radiology_text': radiology_text,
        'has_radiology_report': bool(radiology_text),
        'Had_Surgery': pd.notna(earliest_surgery_date),
    })



In [None]:
# Preprocessing - New, Clean DataFrame!

# One row per retained patient; case_index mirrors the positional index so it
# can serve as a join key for the LLM results.
cases_df = pd.DataFrame(processed_patient_records).assign(case_index=lambda frame: frame.index)

# Counts used in the filtering report below
initial_patient_count = len(unique_patient_ids)
final_patient_count = len(cases_df)
excluded_cases = initial_patient_count - final_patient_count

print(f"\nPreprocessing complete. Created {final_patient_count} unique patient longitudinal records.")

# Report on filtering results
print(f"\nPreprocessing complete.")
print(f"Started with {initial_patient_count} unique patients.")
print(f"Excluded {excluded_cases} records that did not meet criteria (e.g., no relevant pre-decision ENT notes).")
print(f"Created a clean dataset with {final_patient_count} valid patient cases for evaluation.")
print(f"Cases with relevant radiology reports: {cases_df['has_radiology_report'].sum()}")
print(f"Cases with recorded surgery: {cases_df['Had_Surgery'].sum()}")


In [None]:
# Analysis - Prompt Engineering

def generate_prompt(case):
    """Generates a structured prompt for the LLM.

    Args:
        case: A mapping (dict or DataFrame row) providing 'case_index', 'age',
            'sex', 'censored_summary_statement', 'has_radiology_report' and
            'radiology_text'.

    Returns:
        The prompt string. The radiology line carries the report text when
        'has_radiology_report' is truthy and "Not available." otherwise.
    """
    radiology_section = f"- Radiology Report: {case['radiology_text']}" if case['has_radiology_report'] else "- Radiology Report: Not available."

    prompt = f"""
    You are an expert otolaryngologist evaluating a case of chronic sinusitis.
    Based ONLY on the information provided below, make a recommendation on sinus surgery.

    --- Case Details ---
    - Case Index: {case['case_index']}
    - Age: {case['age']}
    - Sex: {case['sex']}
    - Censored Clinical Summary from ENT Notes: {case['censored_summary_statement']}
    {radiology_section}
    
    ---
    
    Provide your response as a JSON object with three keys:
    1. "decision": Your recommendation, either "Yes" or "No".
    2. "confidence": Your confidence level from 1 (not confident) to 10 (very confident).
    3. "reasoning": A brief, 2-4 sentence explanation for your decision.
    """
    return prompt

In [None]:
# Analysis - Feed the data into OpenAI's model for evaluation (documented procedure)

def _parse_llm_response(response_text):
    """Turn a raw model reply into the llm_decision / llm_confidence / llm_reasoning fields.

    - A falsy reply (failed API call) yields an 'Error' record.
    - An optional Markdown code fence (```json ... ``` or a bare ``` ... ```)
      is stripped before JSON parsing.
    - A decision of "yes"/"no" in any letter case or with surrounding
      whitespace is normalised to exactly 'Yes'/'No'.
    - A non-numeric confidence is converted with float(); if that fails it
      becomes 0.
    - A reply that is not a JSON object yields an 'Error' record whose
      reasoning carries the raw reply for debugging.
    """
    if not response_text:
        return {'llm_decision': 'Error', 'llm_confidence': 0, 'llm_reasoning': "API Call Failed"}

    payload = response_text
    fenced = re.search(r"```(?:json)?\s*(.*?)(?:```|\Z)", response_text, flags=re.DOTALL | re.IGNORECASE)
    if fenced:
        payload = fenced.group(1)

    try:
        llm_output = json.loads(payload)
        decision = llm_output.get('decision', 'Error')
        confidence = llm_output.get('confidence', 0)
        reasoning = llm_output.get('reasoning', 'Parsing Error')
    except (json.JSONDecodeError, AttributeError):
        # AttributeError: valid JSON that is not an object (e.g. a list or string).
        return {'llm_decision': 'Error', 'llm_confidence': 0, 'llm_reasoning': f"Parsing Failed: {response_text}"}

    if isinstance(decision, str) and decision.strip().lower() in ('yes', 'no'):
        decision = decision.strip().capitalize()

    if not isinstance(confidence, (int, float)):
        try:
            confidence = float(confidence)
        except (TypeError, ValueError):
            confidence = 0

    return {'llm_decision': decision, 'llm_confidence': confidence, 'llm_reasoning': reasoning}


print("\nStarting LLM evaluation for each case...")
results = []

# Iterate over the single, consolidated dataframe
for index, case in tqdm(cases_df.iterrows(), total=len(cases_df), desc="Evaluating Cases"):
    prompt = generate_prompt(case)
    response_text = query_openai(prompt)
    results.append({'case_index': case['case_index'], **_parse_llm_response(response_text)})

# Merge LLM results back into the main DataFrame
results_df = pd.DataFrame(results)
# Drop llm_* columns left over from a previous run of this cell so that
# re-running it does not produce duplicated '_x' / '_y' columns.
llm_columns = [col for col in results_df.columns if col != 'case_index']
cases_df = pd.merge(cases_df.drop(columns=llm_columns, errors='ignore'), results_df, on='case_index')


In [None]:
# Analysis - Evaluation Metrics

# Create a clean dataframe for evaluation, excluding any cases that had API/parsing errors.
# The comparison is case- and whitespace-insensitive so that replies such as
# "yes" or "No " are evaluated instead of being silently dropped.
decision_normalized = cases_df['llm_decision'].astype(str).str.strip().str.lower()
eval_df = cases_df[decision_normalized.isin(['yes', 'no'])].copy()

if not eval_df.empty:
    print("\n--- LLM Performance Evaluation ---")

    # Map the LLM decision to boolean for comparison (True = surgery recommended).
    eval_df['llm_decision_bool'] = eval_df['llm_decision'].astype(str).str.strip().str.lower() == 'yes'

    # Confidence may arrive as text (e.g. "8"); coerce so that means and the histogram work.
    eval_df['llm_confidence'] = pd.to_numeric(eval_df['llm_confidence'], errors='coerce')

    #  Use 'Had_Surgery' (the correct column name) for comparison
    y_true = eval_df['Had_Surgery']
    y_pred = eval_df['llm_decision_bool']

    # Pin both classes explicitly. Without `labels`, a run in which only one
    # class occurs gives a 1x1 confusion matrix and makes classification_report
    # reject the two target_names.
    class_labels = [False, True]

    # Accuracy, Classification Report, and Confusion Matrix
    accuracy = accuracy_score(y_true, y_pred)
    print(f"Overall Accuracy: {accuracy:.2%}\n")
    print("Classification Report:")
    print(classification_report(y_true, y_pred, labels=class_labels, target_names=['No Surgery', 'Surgery'], zero_division=0))
    print("\nConfusion Matrix:")
    print(confusion_matrix(y_true, y_pred, labels=class_labels))

    # --- Confidence Analysis ---
    eval_df['is_correct'] = (y_true == y_pred)

    print("\n--- Confidence Analysis ---")
    print(f"Average confidence on CORRECT predictions: {eval_df[eval_df['is_correct']]['llm_confidence'].mean():.2f}")
    print(f"Average confidence on INCORRECT predictions: {eval_df[~eval_df['is_correct']]['llm_confidence'].mean():.2f}")

    # Break down confidence by prediction outcome (TP, TN, FP, FN)
    tp_mask = (eval_df['is_correct']) & (eval_df['Had_Surgery'])
    tn_mask = (eval_df['is_correct']) & (~eval_df['Had_Surgery'])
    fp_mask = (~eval_df['is_correct']) & (eval_df['llm_decision_bool'])
    fn_mask = (~eval_df['is_correct']) & (~eval_df['llm_decision_bool'])

    # A group with no cases prints "nan".
    print("\n--- Average Confidence by Prediction Outcome ---")
    print(f"True Positives (Correctly recommended surgery):  {eval_df.loc[tp_mask, 'llm_confidence'].mean():.2f}")
    print(f"True Negatives (Correctly recommended no surgery): {eval_df.loc[tn_mask, 'llm_confidence'].mean():.2f}")
    print(f"False Positives (Wrongly recommended surgery):   {eval_df.loc[fp_mask, 'llm_confidence'].mean():.2f}")
    print(f"False Negatives (Missed needed surgery):       {eval_df.loc[fn_mask, 'llm_confidence'].mean():.2f}")

    # Plotting confidence distribution; bin edges at x.5 centre each bar on an integer score.
    plt.figure(figsize=(10, 6))
    eval_df['llm_confidence'].hist(bins=np.arange(0.5, 11.5, 1), edgecolor='black', rwidth=0.8)
    plt.title('Distribution of LLM Confidence Scores')
    plt.xlabel('Confidence Score (1-10)')
    plt.ylabel('Number of Cases')
    plt.xticks(range(1, 11))
    plt.grid(axis='y', alpha=0.75)
    plt.show()

else:
    print("\nNo cases were successfully evaluated by the LLM. Cannot generate performance metrics.")

In [None]:
# Write every case, with all columns, to a single CSV file.
full_results_path = "sinusitis_llm_full_results.csv"
cases_df.to_csv(full_results_path, index=False)
print(f"\nFull results with all columns saved to '{full_results_path}'")

# Write up to 200 cases (the first rows of the frame) to a second CSV for human evaluation.
total_cases = len(cases_df)
if total_cases > 0:
    sample_path = "sinusitis_llm_human_review_sample.csv"
    # Cap the sample at the number of rows that actually exist
    sample_size = min(200, total_cases)
    cases_df.iloc[:sample_size].to_csv(sample_path, index=False)
    print(f"A sample of {sample_size} cases for human review saved to '{sample_path}'")