# Extracting the Correct (Final) Diagnosis from Each Case Record into JSON

In [None]:
# --- Step 1: Install necessary library ---

!pip install PyMuPDF -q # -q for quiet installation

import os
import json
import re
import fitz  # PyMuPDF library
from google.colab import drive

# --- Functions (find_files, extract_text_from_pdf, extract_diagnosis_from_text) ---

def find_files(directory, extension=".pdf"):
    """Return the paths of files in *directory* whose names end with *extension*.

    The match is case-insensitive and non-recursive. Names are sorted before
    filtering, so the returned paths are in alphabetical order of file name.

    Args:
        directory (str): Folder to scan.
        extension (str): Suffix to keep, e.g. ".pdf".

    Returns:
        list[str] | None: Matching full paths (an empty list, with a warning
        printed, when nothing matches), or None when the directory is missing
        or listing it raises an error.
    """
    try:
        if not os.path.isdir(directory):
            print(f"Error: Directory not found or not accessible: {directory}")
            return None
        wanted_suffix = extension.lower()
        matches = [
            os.path.join(directory, entry)
            for entry in sorted(os.listdir(directory))
            if entry.lower().endswith(wanted_suffix)
        ]
        if not matches:
            print(f"Warning: No files with extension '{extension}' found in directory '{directory}'.")
        return matches
    except Exception as e:
        print(f"An error occurred listing directory contents: {e}")
        return None

def extract_text_from_pdf(filepath):
    """Extract and lightly normalise the text of every page of a PDF.

    Args:
        filepath (str): Path to the PDF file.

    Returns:
        str | None: The concatenated page text, with runs of blank lines
        collapsed to one blank line and runs of spaces collapsed to one
        space. Returns None (after printing an error) if the file does not
        exist or PyMuPDF fails to read it.
    """
    # Check for a missing file up front instead of catching a
    # PyMuPDF-specific exception class: referencing `fitz.fitz.<Error>` in an
    # except clause may itself fail on PyMuPDF builds without `fitz.fitz`.
    if not os.path.isfile(filepath):
        print(f"Error: File not found by PyMuPDF: {filepath}")
        return None
    try:
        # The context manager closes the document even if a page fails to load.
        with fitz.open(filepath) as doc:
            page_texts = [page.get_text("text") for page in doc]
    except Exception as e:
        print(f"Error extracting text from PDF {os.path.basename(filepath)}: {e}")
        return None

    full_text = "".join(page_texts)
    # Basic text cleaning: collapse blank-line runs, then runs of spaces.
    full_text = re.sub(r'\n\s*\n', '\n\n', full_text)
    full_text = re.sub(r' +', ' ', full_text)
    return full_text

def extract_diagnosis_from_text(full_text, marker="FINAL DIAGNOSIS", filename="Unknown"):
    """
    Return the diagnosis text that follows a standalone marker heading (NEJM layout).

    The marker must be the only content on its line (surrounding whitespace
    allowed) and is matched case-insensitively.

    Args:
        full_text (str): Text extracted from a case-record PDF.
        marker (str): Heading that introduces the diagnosis.
        filename (str): Source file name; not used by the current logic.

    Returns:
        str | None: None if full_text is empty/None or the marker is absent;
        "" if the marker is present but no text follows it; otherwise the
        collected lines joined with single spaces.

    Collection rules after the marker:
        * blank lines before the first text line are skipped;
        * a blank line after text has started ends collection;
        * a line containing "Disclosure forms provided", "References" or
          "This case was presented" (case-insensitive) ends collection, but
          only once at least one line has been collected;
        * at most six non-blank lines are collected.
    """
    if not full_text:
        return None

    heading = re.compile(r"^\s*" + re.escape(marker) + r"\s*$", re.IGNORECASE | re.MULTILINE)
    found = heading.search(full_text)
    if found is None:
        return None  # Marker not found

    stop_phrases = [
        phrase.lower()
        for phrase in ("Disclosure forms provided", "References", "This case was presented")
    ]
    collected = []
    for raw_line in full_text[found.end():].splitlines():
        text = raw_line.strip()
        if not text:
            if collected:
                break  # A blank line closes the diagnosis paragraph.
            continue  # Skip blank lines between the marker and the diagnosis.
        lowered = text.lower()
        if collected and any(phrase in lowered for phrase in stop_phrases):
            break
        collected.append(text)
        if len(collected) > 5:
            break

    if not collected:
        return ""  # Marker found but no text followed

    return re.sub(r'\s+', ' ', " ".join(collected)).strip()

# --- Configuration ---
DIAGNOSIS_MARKER = "FINAL DIAGNOSIS"  # Heading searched for in each case record
FILE_EXTENSION = ".pdf"
DEFAULT_OUTPUT_FILENAME = "ground_truth_diagnoses.json"  # Name of the output file

# --- Main Execution ---
if __name__ == "__main__":
    # Becomes True only after the output file is written, so the closing
    # reminder never points the user at a file this run did not produce.
    json_written = False

    # --- Google Drive Mounting ---
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        DRIVE_ROOT = "/content/drive/MyDrive/"
        print(f"Google Drive mounted. Your 'My Drive' is at: {DRIVE_ROOT}")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        # Raise SystemExit so execution really stops here. Continuing would
        # reach DRIVE_ROOT before it is defined.
        raise SystemExit(1)

    # --- Get User Input for Folder Path using Colab File Browser ---
    print("\n--- Please specify the folder containing your PDF files ---")
    print("1. Look at the 'Files' panel on the left side of Colab.")
    print("2. Navigate through 'drive' -> 'MyDrive' to find your folder.")
    print("3. Right-click on the correct folder.")
    print("4. Select 'Copy path' from the menu.")
    print("5. Paste the copied path into the prompt below and press Enter.")
    print("-" * 60)
    input_directory_path = input("Paste the full path to the folder here: ")

    # --- Clean and Validate the Input Path ---
    INPUT_DIRECTORY = input_directory_path.strip()
    if not INPUT_DIRECTORY:
        print("Error: No path provided. Exiting.")
        raise SystemExit(1)

    # Define Output Path (saving to the root of My Drive)
    OUTPUT_JSON_FILE = os.path.join(DRIVE_ROOT, DEFAULT_OUTPUT_FILENAME)

    print("-" * 30)
    print(f"Attempting to read files from: {INPUT_DIRECTORY}")
    print(f"Output JSON will be saved to: {OUTPUT_JSON_FILE}")
    print("-" * 30)

    # --- Check if Input Directory Exists ---
    if not os.path.isdir(INPUT_DIRECTORY):
        print(f"Error: The path you provided is not a valid directory: '{INPUT_DIRECTORY}'")
    else:
        # --- Find and Process Files ---
        case_files = find_files(INPUT_DIRECTORY, extension=FILE_EXTENSION)
        all_diagnoses = []
        files_processed_count = 0
        marker_not_found_count = 0
        extraction_error_count = 0
        total_files_attempted = 0

        if case_files:  # None (error) or [] (no matches) both skip processing
            total_files_attempted = len(case_files)
            print(f"Found {total_files_attempted} PDF files. Processing...")
            for idx, file_path in enumerate(case_files):
                filename = os.path.basename(file_path)
                print(f"Processing ({idx + 1}/{total_files_attempted}): {filename}...")

                # Step 1: Extract text from PDF
                full_pdf_text = extract_text_from_pdf(file_path)
                if not full_pdf_text:
                    # Details were already printed by extract_text_from_pdf.
                    print(f"--> Error: PDF text extraction failed for {filename}.")
                    extraction_error_count += 1
                    continue

                # Step 2: Extract the diagnosis following the marker heading
                diagnosis = extract_diagnosis_from_text(full_pdf_text, DIAGNOSIS_MARKER, filename)
                if diagnosis is None:  # None means the marker wasn't found
                    print(f"--> Warning: Marker '{DIAGNOSIS_MARKER}' not found in {filename}.")
                    marker_not_found_count += 1
                    continue

                # case_id is the exact file name without its extension.
                case_id, _ = os.path.splitext(filename)
                all_diagnoses.append({
                    "case_id": case_id,
                    "correct_diagnosis": diagnosis,  # may be "" (marker without text)
                })
                files_processed_count += 1
                if diagnosis == "":
                    print(f"--> Warning: Marker found but no diagnosis text extracted for {filename}.")

            # --- Save Results to JSON in Google Drive ---
            print("-" * 30)
            print("Processing Summary:")
            print(f" - Files successfully processed (marker found): {files_processed_count}")
            print(f" - Files where marker was not found: {marker_not_found_count}")
            print(f" - Files with PDF extraction errors: {extraction_error_count}")
            print(f" - Total files attempted: {total_files_attempted}")
            print("-" * 30)

            if all_diagnoses:
                try:
                    # Saved in processing order (alphabetical by file name).
                    with open(OUTPUT_JSON_FILE, 'w', encoding='utf-8') as f_out:
                        json.dump(all_diagnoses, f_out, indent=2, ensure_ascii=False)
                    json_written = True
                    print(f"Successfully generated JSON for {len(all_diagnoses)} cases.")
                    print(f"Results saved to: {OUTPUT_JSON_FILE}")
                except Exception as e:
                    print(f"\nError writing JSON output file: {e}")
            else:
                print("\nNo diagnoses were successfully extracted to save.")
        else:
            # Message already handled by find_files or the directory check.
            print("Processing complete. No PDF files found in the specified directory.")

    # Reminder for the user (only meaningful when a file was produced)
    if json_written:
        print("\nScript finished. IMPORTANT: Please review the generated JSON file carefully.")
        print(f"Verify the extracted diagnoses and case_ids in '{OUTPUT_JSON_FILE}'.")
    else:
        print("\nScript finished. No JSON file was written.")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted. Your 'My Drive' is at: /content/drive/MyDrive/

--- Please specify the folder containing your PDF files ---
1. Look at the 'Files' panel on the left side of Colab.
2. Navigate through 'drive' -> 'MyDrive' to find your folder.
3. Right-click on the correct folder.
4. Select 'Copy path' from the menu.
5. Paste the copied path into the prompt below and press Enter.
------------------------------------------------------------
Paste the full path to the folder here: /content/drive/MyDrive/NEJM Case Records
------------------------------
Attempting to read files from: /content/drive/MyDrive/NEJM Case Records
Output JSON will be saved to: /content/drive/MyDrive/ground_truth_diagnoses.json
------------------------------
Found 25 PDF files. Processing...
Processing (1/25): NEJMcpc2100279.pdf...
Processing (2/25): NEJMcpc2300900.pdf...
Processing (3/25): NEJMcpc2309383.pdf...
Processing (4/25): NEJMcpc2309500.pdf...
Process

# Collating All Gemini 2.5 Pro Differential Diagnoses (Full Case Record)

In [None]:
# --- Step 1: Ensure necessary libraries are available ---

import os
import json
from google.colab import drive
import glob # Useful for finding files matching a pattern

# --- Function to find JSON files ---
def find_json_files(directory):
    """Finds all .json files in a given directory, sorted alphabetically.

    Args:
        directory (str): Folder to search (non-recursive).

    Returns:
        list[str] | None: Sorted full paths of the '*.json' files directly
        inside *directory* (an empty list, with a warning, if there are none),
        or None if the directory is missing or cannot be searched.
    """
    try:
        if not os.path.isdir(directory):
            print(f"Error: Directory not found or not accessible: {directory}")
            return None
        # Escape the directory part so characters like '[', ']', '*' or '?'
        # in folder names are matched literally, not as glob wildcards.
        json_pattern = os.path.join(glob.escape(directory), '*.json')
        files = sorted(glob.glob(json_pattern))
        if not files:
            print(f"Warning: No .json files found in directory '{directory}'.")
        return files
    except Exception as e:
        print(f"An error occurred listing or searching directory contents: {e}")
        return None

# --- Function to parse a single prediction JSON file ---
def parse_prediction_file(filepath):
    """
    Parses a single Gemini prediction JSON file and extracts required fields,
    assigning rank based on order.

    Args:
        filepath (str): Path to the individual JSON file.

    Returns:
        dict: A dictionary containing the extracted 'case_id' and
              'differential_diagnosis' list (with only 'diagnosis' and
              'rank'), or None if parsing fails, the top-level JSON value
              is not an object, or required keys are missing.

    Notes:
        'rank' is the 1-based position of the entry in the file's original
        'differential_diagnosis' list. Malformed entries are skipped without
        renumbering, so ranks may have gaps (e.g. 1, 2, 4).
    """
    filename = os.path.basename(filepath)
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # A top-level array/string/number has no keys to read; report that
        # directly instead of failing later with an AttributeError.
        if not isinstance(data, dict):
            print(f"Warning: Top-level value is not a JSON object in file: {filename}. Skipping.")
            return None

        # --- Extract required top-level keys ---
        case_id = data.get('case_id')
        diff_diag_list = data.get('differential_diagnosis')

        if case_id is None:
            print(f"Warning: Missing 'case_id' key in file: {filename}. Skipping.")
            return None
        if diff_diag_list is None:
            print(f"Warning: Missing 'differential_diagnosis' key in file: {filename} (Case ID: {case_id}). Skipping.")
            return None
        if not isinstance(diff_diag_list, list):
            print(f"Warning: 'differential_diagnosis' is not a list in file: {filename} (Case ID: {case_id}). Skipping.")
            return None

        # --- Process the differential diagnosis list ---
        processed_diff_diag = []
        for index, item in enumerate(diff_diag_list):
            if not isinstance(item, dict):
                print(f"Warning: Item in 'differential_diagnosis' is not a dictionary in file: {filename} (Case ID: {case_id}). Skipping item.")
                continue

            # Only 'diagnosis' is kept; other per-item fields are dropped.
            diagnosis = item.get('diagnosis')
            if diagnosis is None:
                print(f"Warning: Missing 'diagnosis' key within an item in 'differential_diagnosis' in file: {filename} (Case ID: {case_id}). Skipping item.")
                continue

            processed_diff_diag.append({
                "diagnosis": diagnosis,
                "rank": index + 1,  # 1-based position in the original list
            })

        return {
            "case_id": case_id,
            "differential_diagnosis": processed_diff_diag
        }

    except json.JSONDecodeError:
        print(f"Error: Invalid JSON structure in file: {filename}. Skipping.")
        return None
    except FileNotFoundError:
        print(f"Error: File not found during processing: {filepath}. Skipping.")
        return None
    except Exception as e:
        print(f"An unexpected error occurred processing file {filename}: {e}. Skipping.")
        return None

# --- Configuration ---
DEFAULT_OUTPUT_FILENAME = "gemini_2.5_pro_predictions.json"  # Collated output, saved to the root of My Drive

# --- Main Execution ---
if __name__ == "__main__":
    # Becomes True only after the collated file is written. The closing
    # reminder relies on this instead of a counter that is undefined (or
    # stale from another notebook cell) when the input directory is invalid.
    json_written = False

    # --- Google Drive Mounting ---
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        DRIVE_ROOT = "/content/drive/MyDrive/"
        print(f"Google Drive mounted. Your 'My Drive' is at: {DRIVE_ROOT}")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        # Raise SystemExit so execution really stops here. Continuing would
        # reach DRIVE_ROOT before it is defined.
        raise SystemExit(1)

    # --- Get User Input for Folder Path ---
    print("\n--- Please specify the folder containing the 25 individual Gemini prediction JSON files ---")
    print("1. Use the 'Files' panel on the left to navigate to the folder.")
    print("2. Right-click on the folder.")
    print("3. Select 'Copy path'.")
    print("4. Paste the copied path below and press Enter.")
    print("-" * 60)
    input_directory_path = input("Paste the full path to the folder here: ")

    # --- Clean and Validate Input Path ---
    INPUT_DIRECTORY = input_directory_path.strip()
    if not INPUT_DIRECTORY:
        print("Error: No path provided. Exiting.")
        raise SystemExit(1)

    # Define Output Path (saving to the root of My Drive)
    OUTPUT_JSON_FILE = os.path.join(DRIVE_ROOT, DEFAULT_OUTPUT_FILENAME)

    print("-" * 30)
    print(f"Reading individual JSON files from: {INPUT_DIRECTORY}")
    print(f"Collated output with ranks will be saved to: {OUTPUT_JSON_FILE}")
    print("-" * 30)

    # --- Check if Input Directory Exists ---
    if not os.path.isdir(INPUT_DIRECTORY):
        print(f"Error: The path provided is not a valid directory: '{INPUT_DIRECTORY}'")
    else:
        # --- Find and Process JSON Files ---
        prediction_files = find_json_files(INPUT_DIRECTORY)
        collated_predictions = []
        files_processed_count = 0
        files_skipped_count = 0
        total_files_found = 0

        if prediction_files:  # None (error) or [] (no matches) both skip processing
            total_files_found = len(prediction_files)
            print(f"Found {total_files_found} JSON files. Processing...")

            for idx, file_path in enumerate(prediction_files):
                filename = os.path.basename(file_path)
                print(f"Processing ({idx + 1}/{total_files_found}): {filename}...")

                parsed_data = parse_prediction_file(file_path)
                if parsed_data:
                    collated_predictions.append(parsed_data)
                    files_processed_count += 1
                else:
                    # Error/warning message already printed by parse_prediction_file
                    files_skipped_count += 1

            # --- Save Collated Results ---
            print("-" * 30)
            print("Processing Summary:")
            print(f" - Files successfully parsed and included: {files_processed_count}")
            print(f" - Files skipped due to errors or missing keys: {files_skipped_count}")
            print(f" - Total JSON files found: {total_files_found}")
            print("-" * 30)

            if collated_predictions:
                try:
                    # Output order follows the alphabetically sorted file list.
                    with open(OUTPUT_JSON_FILE, 'w', encoding='utf-8') as f_out:
                        json.dump(collated_predictions, f_out, indent=2, ensure_ascii=False)
                    json_written = True
                    print("Successfully generated collated predictions file with ranks.")
                    print(f"Results saved to: {OUTPUT_JSON_FILE}")
                except Exception as e:
                    print(f"\nError writing final JSON output file: {e}")
            else:
                print("\nNo data was successfully parsed from any file. Output file not created.")
        else:
            # Message already handled by find_json_files or the directory check.
            print("Processing complete. No JSON files found in the specified directory.")

    # Reminder
    print("\nScript finished.")
    if json_written:
        print(f"Please check the generated file: {OUTPUT_JSON_FILE}")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted. Your 'My Drive' is at: /content/drive/MyDrive/

--- Please specify the folder containing the 25 individual Gemini prediction JSON files ---
1. Use the 'Files' panel on the left to navigate to the folder.
2. Right-click on the folder.
3. Select 'Copy path'.
4. Paste the copied path below and press Enter.
------------------------------------------------------------
Paste the full path to the folder here: /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/Gemini 2.5 Pro Full Case Records
------------------------------
Reading individual JSON files from: /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/Gemini 2.5 Pro Full Case Records
Collated output with ranks will be saved to: /content/drive/MyDrive/gemini_2.5_pro_predictions_ranked.json
------------------------------
Found 25 JSON files. Processing...
Processing (1/25): NEJMcpc2100279.json...
Processing (2

# Collating All Grok 3 Differential Diagnoses (Full Case Record)

In [None]:
# --- Step 1: Ensure necessary libraries are available ---

import os
import json
from google.colab import drive
import glob # Useful for finding files matching a pattern

# --- Function to find JSON files ---
def find_json_files(directory):
    """Finds all .json files in a given directory, sorted alphabetically.

    Args:
        directory (str): Folder to search (non-recursive).

    Returns:
        list[str] | None: Sorted full paths of the '*.json' files directly
        inside *directory* (an empty list, with a warning, if there are none),
        or None if the directory is missing or cannot be searched.
    """
    try:
        if not os.path.isdir(directory):
            print(f"Error: Directory not found or not accessible: {directory}")
            return None
        # Escape the directory part so characters like '[', ']', '*' or '?'
        # in folder names are matched literally, not as glob wildcards.
        json_pattern = os.path.join(glob.escape(directory), '*.json')
        files = sorted(glob.glob(json_pattern))
        if not files:
            print(f"Warning: No .json files found in directory '{directory}'.")
        return files
    except Exception as e:
        print(f"An error occurred listing or searching directory contents: {e}")
        return None

# --- Function to parse a single prediction JSON file ---
def parse_prediction_file(filepath):
    """
    Parses a single Grok prediction JSON file and extracts required fields,
    assigning rank based on order.

    Args:
        filepath (str): Path to the individual JSON file.

    Returns:
        dict: A dictionary containing the extracted 'case_id' and
              'differential_diagnosis' list (with only 'diagnosis' and
              'rank'), or None if parsing fails, the top-level JSON value
              is not an object, or required keys are missing.

    Notes:
        'rank' is the 1-based position of the entry in the file's original
        'differential_diagnosis' list. Malformed entries are skipped without
        renumbering, so ranks may have gaps (e.g. 1, 2, 4).
    """
    filename = os.path.basename(filepath)
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # A top-level array/string/number has no keys to read; report that
        # directly instead of failing later with an AttributeError.
        if not isinstance(data, dict):
            print(f"Warning: Top-level value is not a JSON object in file: {filename}. Skipping.")
            return None

        # --- Extract required top-level keys ---
        case_id = data.get('case_id')
        diff_diag_list = data.get('differential_diagnosis')

        if case_id is None:
            print(f"Warning: Missing 'case_id' key in file: {filename}. Skipping.")
            return None
        if diff_diag_list is None:
            print(f"Warning: Missing 'differential_diagnosis' key in file: {filename} (Case ID: {case_id}). Skipping.")
            return None
        if not isinstance(diff_diag_list, list):
            print(f"Warning: 'differential_diagnosis' is not a list in file: {filename} (Case ID: {case_id}). Skipping.")
            return None

        # --- Process the differential diagnosis list ---
        processed_diff_diag = []
        for index, item in enumerate(diff_diag_list):
            if not isinstance(item, dict):
                print(f"Warning: Item in 'differential_diagnosis' is not a dictionary in file: {filename} (Case ID: {case_id}). Skipping item.")
                continue

            # Only 'diagnosis' is kept; other per-item fields are dropped.
            diagnosis = item.get('diagnosis')
            if diagnosis is None:
                print(f"Warning: Missing 'diagnosis' key within an item in 'differential_diagnosis' in file: {filename} (Case ID: {case_id}). Skipping item.")
                continue

            processed_diff_diag.append({
                "diagnosis": diagnosis,
                "rank": index + 1,  # 1-based position in the original list
            })

        return {
            "case_id": case_id,
            "differential_diagnosis": processed_diff_diag
        }

    except json.JSONDecodeError:
        print(f"Error: Invalid JSON structure in file: {filename}. Skipping.")
        return None
    except FileNotFoundError:
        print(f"Error: File not found during processing: {filepath}. Skipping.")
        return None
    except Exception as e:
        print(f"An unexpected error occurred processing file {filename}: {e}. Skipping.")
        return None

# --- Configuration ---
DEFAULT_OUTPUT_FILENAME = "grok_3_predictions.json"  # Collated output, saved to the root of My Drive

# --- Main Execution ---
if __name__ == "__main__":
    # Becomes True only after the collated file is written. The closing
    # reminder relies on this instead of a counter that is undefined (or
    # stale from another notebook cell) when the input directory is invalid.
    json_written = False

    # --- Google Drive Mounting ---
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        DRIVE_ROOT = "/content/drive/MyDrive/"
        print(f"Google Drive mounted. Your 'My Drive' is at: {DRIVE_ROOT}")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        # Raise SystemExit so execution really stops here. Continuing would
        # reach DRIVE_ROOT before it is defined.
        raise SystemExit(1)

    # --- Get User Input for Folder Path ---
    print("\n--- Please specify the folder containing the 25 individual Grok prediction JSON files ---")
    print("1. Use the 'Files' panel on the left to navigate to the folder.")
    print("2. Right-click on the folder.")
    print("3. Select 'Copy path'.")
    print("4. Paste the copied path below and press Enter.")
    print("-" * 60)
    input_directory_path = input("Paste the full path to the folder here: ")

    # --- Clean and Validate Input Path ---
    INPUT_DIRECTORY = input_directory_path.strip()
    if not INPUT_DIRECTORY:
        print("Error: No path provided. Exiting.")
        raise SystemExit(1)

    # Define Output Path (saving to the root of My Drive)
    OUTPUT_JSON_FILE = os.path.join(DRIVE_ROOT, DEFAULT_OUTPUT_FILENAME)

    print("-" * 30)
    print(f"Reading individual JSON files from: {INPUT_DIRECTORY}")
    print(f"Collated output with ranks will be saved to: {OUTPUT_JSON_FILE}")
    print("-" * 30)

    # --- Check if Input Directory Exists ---
    if not os.path.isdir(INPUT_DIRECTORY):
        print(f"Error: The path provided is not a valid directory: '{INPUT_DIRECTORY}'")
    else:
        # --- Find and Process JSON Files ---
        prediction_files = find_json_files(INPUT_DIRECTORY)
        collated_predictions = []
        files_processed_count = 0
        files_skipped_count = 0
        total_files_found = 0

        if prediction_files:  # None (error) or [] (no matches) both skip processing
            total_files_found = len(prediction_files)
            print(f"Found {total_files_found} JSON files. Processing...")

            for idx, file_path in enumerate(prediction_files):
                filename = os.path.basename(file_path)
                print(f"Processing ({idx + 1}/{total_files_found}): {filename}...")

                parsed_data = parse_prediction_file(file_path)
                if parsed_data:
                    collated_predictions.append(parsed_data)
                    files_processed_count += 1
                else:
                    # Error/warning message already printed by parse_prediction_file
                    files_skipped_count += 1

            # --- Save Collated Results ---
            print("-" * 30)
            print("Processing Summary:")
            print(f" - Files successfully parsed and included: {files_processed_count}")
            print(f" - Files skipped due to errors or missing keys: {files_skipped_count}")
            print(f" - Total JSON files found: {total_files_found}")
            print("-" * 30)

            if collated_predictions:
                try:
                    # Output order follows the alphabetically sorted file list.
                    with open(OUTPUT_JSON_FILE, 'w', encoding='utf-8') as f_out:
                        json.dump(collated_predictions, f_out, indent=2, ensure_ascii=False)
                    json_written = True
                    print("Successfully generated collated predictions file with ranks.")
                    print(f"Results saved to: {OUTPUT_JSON_FILE}")
                except Exception as e:
                    print(f"\nError writing final JSON output file: {e}")
            else:
                print("\nNo data was successfully parsed from any file. Output file not created.")
        else:
            # Message already handled by find_json_files or the directory check.
            print("Processing complete. No JSON files found in the specified directory.")

    # Reminder
    print("\nScript finished.")
    if json_written:
        print(f"Please check the generated file: {OUTPUT_JSON_FILE}")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted. Your 'My Drive' is at: /content/drive/MyDrive/

--- Please specify the folder containing the 25 individual Grok prediction JSON files ---
1. Use the 'Files' panel on the left to navigate to the folder.
2. Right-click on the folder.
3. Select 'Copy path'.
4. Paste the copied path below and press Enter.
------------------------------------------------------------
Paste the full path to the folder here: /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/Grok 3 Full Case Records
------------------------------
Reading individual JSON files from: /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/Grok 3 Full Case Records
Collated output with ranks will be saved to: /content/drive/MyDrive/grok_3_predictions.json
------------------------------
Found 25 JSON files. Processing...
Processing (1/25): NEJMcpc2100279.json...
Processing (2/25): NEJMcpc2300900.json...
Proc

# Collating All ChatGPT o4-mini-high Differential Diagnoses (Full Case Record)

In [5]:
# --- Step 1: Ensure necessary libraries are available ---

import os
import json
from google.colab import drive
import glob # Useful for finding files matching a pattern

# --- Function to find JSON files ---
def find_json_files(directory):
    """Finds all .json files in a given directory, sorted alphabetically.

    Args:
        directory (str): Folder to search (non-recursive).

    Returns:
        list[str] | None: Sorted full paths of the '*.json' files directly
        inside *directory* (an empty list, with a warning, if there are none),
        or None if the directory is missing or cannot be searched.
    """
    try:
        if not os.path.isdir(directory):
            print(f"Error: Directory not found or not accessible: {directory}")
            return None
        # Escape the directory part so characters like '[', ']', '*' or '?'
        # in folder names are matched literally, not as glob wildcards.
        json_pattern = os.path.join(glob.escape(directory), '*.json')
        files = sorted(glob.glob(json_pattern))
        if not files:
            print(f"Warning: No .json files found in directory '{directory}'.")
        return files
    except Exception as e:
        print(f"An error occurred listing or searching directory contents: {e}")
        return None

# --- Function to parse a single prediction JSON file ---
def parse_prediction_file(filepath):
    """
    Parse one ChatGPT prediction JSON file, keeping only diagnoses and ranks.

    Expected input is a JSON object with a ``case_id`` and a
    ``differential_diagnosis`` list whose entries are objects with a
    ``diagnosis`` key. Every other key (e.g. ``confidence_level``) is dropped.

    Each kept entry gets a 1-based ``rank`` equal to its position in the
    ORIGINAL list. An entry that is skipped (not an object, or no
    ``diagnosis``) therefore leaves a gap in the ranks instead of promoting
    the entries after it.

    Args:
        filepath (str): Path to the individual JSON file.

    Returns:
        dict | None: ``{"case_id": ..., "differential_diagnosis":
        [{"diagnosis": ..., "rank": ...}, ...]}``. Returns None (after
        printing the reason) when the file cannot be read or decoded, its
        top-level value is not an object, or a required key is missing or
        has the wrong type.
    """
    filename = os.path.basename(filepath)
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # A top-level array/number/string has no .get(). Report that clearly
        # instead of letting an AttributeError reach the generic handler.
        if not isinstance(data, dict):
            print(f"Warning: Top-level JSON value is not an object in file: {filename}. Skipping.")
            return None

        # --- Extract required top-level keys ---
        case_id = data.get('case_id')
        diff_diag_list = data.get('differential_diagnosis')

        if case_id is None:
            print(f"Warning: Missing 'case_id' key in file: {filename}. Skipping.")
            return None
        if diff_diag_list is None:
            print(f"Warning: Missing 'differential_diagnosis' key in file: {filename} (Case ID: {case_id}). Skipping.")
            return None
        if not isinstance(diff_diag_list, list):
            print(f"Warning: 'differential_diagnosis' is not a list in file: {filename} (Case ID: {case_id}). Skipping.")
            return None

        # --- Process the differential diagnosis list ---
        processed_diff_diag = []
        for rank, item in enumerate(diff_diag_list, start=1):
            if not isinstance(item, dict):
                print(f"Warning: Item in 'differential_diagnosis' is not a dictionary in file: {filename} (Case ID: {case_id}). Skipping item.")
                continue

            diagnosis = item.get('diagnosis')
            if diagnosis is None:
                print(f"Warning: Missing 'diagnosis' key within an item in 'differential_diagnosis' in file: {filename} (Case ID: {case_id}). Skipping item.")
                continue

            processed_diff_diag.append({
                "diagnosis": diagnosis,
                "rank": rank,
            })

        return {
            "case_id": case_id,
            "differential_diagnosis": processed_diff_diag,
        }

    except json.JSONDecodeError:
        print(f"Error: Invalid JSON structure in file: {filename}. Skipping.")
        return None
    except FileNotFoundError:
        print(f"Error: File not found during processing: {filepath}. Skipping.")
        return None
    except Exception as e:
        print(f"An unexpected error occurred processing file {filename}: {e}. Skipping.")
        return None

# --- Configuration ---
DEFAULT_OUTPUT_FILENAME = "chatGPT_o4-mini-high_predictions.json" # Updated filename

# Notebook globals outlive a cell, so reset the state read by the closing
# reminder. Without this, a run that stops early (bad path, no files) would
# raise NameError the first time. On later runs it would reuse values left
# behind by an earlier cell and point the user at a file this run never wrote.
files_processed_count = 0
output_written = False

# --- Main Execution ---
# Interactive driver: mount Drive, ask for the folder of per-case JSON files,
# parse each with parse_prediction_file, and write the collated list to the
# root of My Drive.
if __name__ == "__main__":
    # --- Google Drive Mounting ---
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        DRIVE_ROOT = "/content/drive/MyDrive/"
        print(f"Google Drive mounted. Your 'My Drive' is at: {DRIVE_ROOT}")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        exit()

    # --- Get User Input for Folder Path ---
    print("\n--- Please specify the folder containing the 25 individual ChatGPT prediction JSON files ---")
    print("1. Use the 'Files' panel on the left to navigate to the folder.")
    print("2. Right-click on the folder.")
    print("3. Select 'Copy path'.")
    print("4. Paste the copied path below and press Enter.")
    print("-" * 60)
    input_directory_path = input("Paste the full path to the folder here: ")

    # --- Clean and Validate Input Path ---
    INPUT_DIRECTORY = input_directory_path.strip()
    if not INPUT_DIRECTORY:
        print("Error: No path provided. Exiting.")
        exit()

    # Define Output Path (saving to the root of My Drive)
    OUTPUT_JSON_FILE = os.path.join(DRIVE_ROOT, DEFAULT_OUTPUT_FILENAME)

    print("-" * 30)
    print(f"Reading individual JSON files from: {INPUT_DIRECTORY}")
    print(f"Collated output with ranks will be saved to: {OUTPUT_JSON_FILE}")
    print("-" * 30)

    # --- Check if Input Directory Exists ---
    if not os.path.isdir(INPUT_DIRECTORY):
        print(f"Error: The path provided is not a valid directory: '{INPUT_DIRECTORY}'")
    else:
        # --- Find and Process JSON Files ---
        prediction_files = find_json_files(INPUT_DIRECTORY)
        collated_predictions = []
        files_processed_count = 0
        files_skipped_count = 0
        total_files_found = 0

        # find_json_files returns None on a listing error and [] when nothing
        # matched; both fall through to the final else below.
        if prediction_files:
            total_files_found = len(prediction_files)
            print(f"Found {total_files_found} JSON files. Processing...")

            for idx, file_path in enumerate(prediction_files):
                filename = os.path.basename(file_path)
                print(f"Processing ({idx + 1}/{total_files_found}): {filename}...")

                parsed_data = parse_prediction_file(file_path)

                if parsed_data:
                    collated_predictions.append(parsed_data)
                    files_processed_count += 1
                else:
                    # Error/warning message already printed by parse_prediction_file
                    files_skipped_count += 1

            # --- Save Collated Results ---
            print("-" * 30)
            print("Processing Summary:")
            print(f" - Files successfully parsed and included: {files_processed_count}")
            print(f" - Files skipped due to errors or missing keys: {files_skipped_count}")
            print(f" - Total JSON files found: {total_files_found}")
            print("-" * 30)

            if collated_predictions:
                try:
                    # Optional: Sort the final list by case_id if needed
                    # collated_predictions.sort(key=lambda x: x.get('case_id', ''))

                    with open(OUTPUT_JSON_FILE, 'w', encoding='utf-8') as f_out:
                        json.dump(collated_predictions, f_out, indent=2, ensure_ascii=False)
                    # Only set once the dump has completed without raising.
                    output_written = True
                    print("Successfully generated collated predictions file with ranks.")
                    print(f"Results saved to: {OUTPUT_JSON_FILE}")
                except Exception as e:
                    print(f"\nError writing final JSON output file: {e}")
            else:
                print("\nNo data was successfully parsed from any file. Output file not created.")
        else:
            # Message handled by find_json_files or the initial directory check
            print("Processing complete. No JSON files found in the specified directory.")

# Reminder: only point at the output file if this run actually wrote it.
print("\nScript finished.")
if output_written:
    print(f"Please check the generated file: {OUTPUT_JSON_FILE}")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted. Your 'My Drive' is at: /content/drive/MyDrive/

--- Please specify the folder containing the 25 individual ChatGPT prediction JSON files ---
1. Use the 'Files' panel on the left to navigate to the folder.
2. Right-click on the folder.
3. Select 'Copy path'.
4. Paste the copied path below and press Enter.
------------------------------------------------------------
Paste the full path to the folder here: /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/Full Case Records/Untitled folder
------------------------------
Reading individual JSON files from: /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/Full Case Records/Untitled folder
Collated output with ranks will be saved to: /content/drive/MyDrive/chatGPT_o4-mini-high_predictions.json
------------------------------
Found 1 JSON files. Processing...
Processing (1/1): Sample Conversation.json...
--------

# Collating All Perplexity Research Differential Diagnoses (Full Case Record)

In [None]:
# --- Step 1: Ensure necessary libraries are available ---

import os
import json
from google.colab import drive
import glob # Useful for finding files matching a pattern

# --- Function to find JSON files ---
def find_json_files(directory):
    """
    Return the sorted paths of all ``*.json`` files directly inside *directory*.

    The directory part of the search pattern is passed through ``glob.escape``.
    Folder names containing glob metacharacters (``[``, ``]``, ``*``, ``?``)
    are therefore matched literally instead of being read as wildcards. The
    search is not recursive.

    Args:
        directory (str): Folder to search.

    Returns:
        list[str] | None: Sorted file paths. The list may be empty, in which
        case a warning is printed. Returns None (after printing the reason)
        when *directory* is not an accessible directory or the search raises.
    """
    try:
        if not os.path.isdir(directory):
            print(f"Error: Directory not found or not accessible: {directory}")
            return None
        # Escape only the directory; the '*.json' wildcard must stay active.
        json_pattern = os.path.join(glob.escape(directory), '*.json')
        files = sorted(glob.glob(json_pattern))
        if not files:
            print(f"Warning: No .json files found in directory '{directory}'.")
        return files
    except Exception as e:
        print(f"An error occurred listing or searching directory contents: {e}")
        return None

# --- Function to parse a single prediction JSON file ---
def parse_prediction_file(filepath):
    """
    Parse one Perplexity prediction JSON file, keeping only diagnoses and ranks.

    Expected input is a JSON object with a ``case_id`` and a
    ``differential_diagnosis`` list whose entries are objects with a
    ``diagnosis`` key. Every other key (e.g. ``confidence_level``) is dropped.

    Each kept entry gets a 1-based ``rank`` equal to its position in the
    ORIGINAL list. An entry that is skipped (not an object, or no
    ``diagnosis``) therefore leaves a gap in the ranks instead of promoting
    the entries after it.

    Args:
        filepath (str): Path to the individual JSON file.

    Returns:
        dict | None: ``{"case_id": ..., "differential_diagnosis":
        [{"diagnosis": ..., "rank": ...}, ...]}``. Returns None (after
        printing the reason) when the file cannot be read or decoded, its
        top-level value is not an object, or a required key is missing or
        has the wrong type.
    """
    filename = os.path.basename(filepath)
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # A top-level array/number/string has no .get(). Report that clearly
        # instead of letting an AttributeError reach the generic handler.
        if not isinstance(data, dict):
            print(f"Warning: Top-level JSON value is not an object in file: {filename}. Skipping.")
            return None

        # --- Extract required top-level keys ---
        case_id = data.get('case_id')
        diff_diag_list = data.get('differential_diagnosis')

        if case_id is None:
            print(f"Warning: Missing 'case_id' key in file: {filename}. Skipping.")
            return None
        if diff_diag_list is None:
            print(f"Warning: Missing 'differential_diagnosis' key in file: {filename} (Case ID: {case_id}). Skipping.")
            return None
        if not isinstance(diff_diag_list, list):
            print(f"Warning: 'differential_diagnosis' is not a list in file: {filename} (Case ID: {case_id}). Skipping.")
            return None

        # --- Process the differential diagnosis list ---
        processed_diff_diag = []
        for rank, item in enumerate(diff_diag_list, start=1):
            if not isinstance(item, dict):
                print(f"Warning: Item in 'differential_diagnosis' is not a dictionary in file: {filename} (Case ID: {case_id}). Skipping item.")
                continue

            diagnosis = item.get('diagnosis')
            if diagnosis is None:
                print(f"Warning: Missing 'diagnosis' key within an item in 'differential_diagnosis' in file: {filename} (Case ID: {case_id}). Skipping item.")
                continue

            processed_diff_diag.append({
                "diagnosis": diagnosis,
                "rank": rank,
            })

        return {
            "case_id": case_id,
            "differential_diagnosis": processed_diff_diag,
        }

    except json.JSONDecodeError:
        print(f"Error: Invalid JSON structure in file: {filename}. Skipping.")
        return None
    except FileNotFoundError:
        print(f"Error: File not found during processing: {filepath}. Skipping.")
        return None
    except Exception as e:
        print(f"An unexpected error occurred processing file {filename}: {e}. Skipping.")
        return None

# --- Configuration ---
DEFAULT_OUTPUT_FILENAME = "perplexity_research_predictions.json" # Updated filename

# Notebook globals outlive a cell, so reset the state read by the closing
# reminder. Without this, a run that stops early (bad path, no files) would
# raise NameError the first time. On later runs it would reuse values left
# behind by an earlier cell and point the user at a file this run never wrote.
files_processed_count = 0
output_written = False

# --- Main Execution ---
# Interactive driver: mount Drive, ask for the folder of per-case JSON files,
# parse each with parse_prediction_file, and write the collated list to the
# root of My Drive.
if __name__ == "__main__":
    # --- Google Drive Mounting ---
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        DRIVE_ROOT = "/content/drive/MyDrive/"
        print(f"Google Drive mounted. Your 'My Drive' is at: {DRIVE_ROOT}")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        exit()

    # --- Get User Input for Folder Path ---
    print("\n--- Please specify the folder containing the 25 individual Perplexity prediction JSON files ---")
    print("1. Use the 'Files' panel on the left to navigate to the folder.")
    print("2. Right-click on the folder.")
    print("3. Select 'Copy path'.")
    print("4. Paste the copied path below and press Enter.")
    print("-" * 60)
    input_directory_path = input("Paste the full path to the folder here: ")

    # --- Clean and Validate Input Path ---
    INPUT_DIRECTORY = input_directory_path.strip()
    if not INPUT_DIRECTORY:
        print("Error: No path provided. Exiting.")
        exit()

    # Define Output Path (saving to the root of My Drive)
    OUTPUT_JSON_FILE = os.path.join(DRIVE_ROOT, DEFAULT_OUTPUT_FILENAME)

    print("-" * 30)
    print(f"Reading individual JSON files from: {INPUT_DIRECTORY}")
    print(f"Collated output with ranks will be saved to: {OUTPUT_JSON_FILE}")
    print("-" * 30)

    # --- Check if Input Directory Exists ---
    if not os.path.isdir(INPUT_DIRECTORY):
        print(f"Error: The path provided is not a valid directory: '{INPUT_DIRECTORY}'")
    else:
        # --- Find and Process JSON Files ---
        prediction_files = find_json_files(INPUT_DIRECTORY)
        collated_predictions = []
        files_processed_count = 0
        files_skipped_count = 0
        total_files_found = 0

        # find_json_files returns None on a listing error and [] when nothing
        # matched; both fall through to the final else below.
        if prediction_files:
            total_files_found = len(prediction_files)
            print(f"Found {total_files_found} JSON files. Processing...")

            for idx, file_path in enumerate(prediction_files):
                filename = os.path.basename(file_path)
                print(f"Processing ({idx + 1}/{total_files_found}): {filename}...")

                parsed_data = parse_prediction_file(file_path)

                if parsed_data:
                    collated_predictions.append(parsed_data)
                    files_processed_count += 1
                else:
                    # Error/warning message already printed by parse_prediction_file
                    files_skipped_count += 1

            # --- Save Collated Results ---
            print("-" * 30)
            print("Processing Summary:")
            print(f" - Files successfully parsed and included: {files_processed_count}")
            print(f" - Files skipped due to errors or missing keys: {files_skipped_count}")
            print(f" - Total JSON files found: {total_files_found}")
            print("-" * 30)

            if collated_predictions:
                try:
                    # Optional: Sort the final list by case_id if needed
                    # collated_predictions.sort(key=lambda x: x.get('case_id', ''))

                    with open(OUTPUT_JSON_FILE, 'w', encoding='utf-8') as f_out:
                        json.dump(collated_predictions, f_out, indent=2, ensure_ascii=False)
                    # Only set once the dump has completed without raising.
                    output_written = True
                    print("Successfully generated collated predictions file with ranks.")
                    print(f"Results saved to: {OUTPUT_JSON_FILE}")
                except Exception as e:
                    print(f"\nError writing final JSON output file: {e}")
            else:
                print("\nNo data was successfully parsed from any file. Output file not created.")
        else:
            # Message handled by find_json_files or the initial directory check
            print("Processing complete. No JSON files found in the specified directory.")

# Reminder: only point at the output file if this run actually wrote it.
print("\nScript finished.")
if output_written:
    print(f"Please check the generated file: {OUTPUT_JSON_FILE}")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted. Your 'My Drive' is at: /content/drive/MyDrive/

--- Please specify the folder containing the 25 individual Perplexity prediction JSON files ---
1. Use the 'Files' panel on the left to navigate to the folder.
2. Right-click on the folder.
3. Select 'Copy path'.
4. Paste the copied path below and press Enter.
------------------------------------------------------------
Paste the full path to the folder here: /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/Perplexity Research Full Case Records
------------------------------
Reading individual JSON files from: /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/Perplexity Research Full Case Records
Collated output with ranks will be saved to: /content/drive/MyDrive/perplexity_research_predictions.json
------------------------------
Found 25 JSON files. Processing...
Processing (1/25): NEJMcpc2100279.json...
P

# Calculating Mean Reciprocal Rank (MRR), Normalized Discounted Cumulative Gain (NDCG@k), and Top-1 Accuracy for Gemini 2.5 Pro Responses

## Semantic Similarity using Embeddings

In [None]:
# --- Step 1: Install necessary libraries ---
!pip install sentence-transformers scikit-learn numpy -q

import json
import math
import numpy as np
from sentence_transformers import SentenceTransformer, util # For embeddings and similarity
import os
from google.colab import drive

def load_json_data(filepath):
    """
    Read *filepath* as UTF-8 JSON and return the decoded value.

    Nothing is raised to the caller. A missing file, undecodable JSON, or any
    other failure is reported on stdout and the function returns ``None``.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            parsed = json.load(handle)
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}")
        parsed = None
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {filepath}")
        parsed = None
    except Exception as err:
        print(f"An unexpected error occurred loading {filepath}: {err}")
        parsed = None
    return parsed

def _embedding_model_label(model):
    """
    Return a human-readable identifier for *model*, or 'Unknown' if none is found.

    Tries ``model.config.name_or_path`` first. SentenceTransformer objects have
    no top-level ``config`` attribute (that lookup alone yields 'Unknown' for
    them), so it then falls back to ``model[0].auto_model.config.name_or_path``.
    NOTE(review): the fallback relies on sentence-transformers internals (the
    first module wrapping a Hugging Face model); confirm it against the
    installed version.
    """
    name = getattr(getattr(model, 'config', None), 'name_or_path', None)
    if not name:
        try:
            name = model[0].auto_model.config.name_or_path
        except (TypeError, IndexError, KeyError, AttributeError):
            name = None
    return name or 'Unknown'


def calculate_metrics_semantic(predictions_data, ground_truth_data, model, similarity_threshold=0.85, k=5):
    """
    Calculates MRR, Mean NDCG@k, and Top-1 Accuracy using semantic similarity.

    For every case_id present in both inputs:

    - The predictions are ordered by ascending ``rank``; a missing rank sorts last.
    - The ground truth and each prediction are embedded with *model*.
    - The first prediction whose cosine similarity to the ground truth is
      >= *similarity_threshold* is taken as the single relevant item.

    Let r be that item's 1-based position in the ordering (0 when nothing
    matches). Then:

    - reciprocal rank = 1/r, or 0 when there is no match;
    - NDCG@k = 1/log2(r + 1) if 1 <= r <= k, else 0. The ideal DCG for one
      relevant item is 1/log2(2) = 1;
    - the case is a top-1 hit when r == 1.

    Cases with an empty ground truth are skipped. Cases with no usable
    prediction text, or whose embedding step raises, score 0 on every metric.
    The input lists are not modified; predictions are sorted into a copy.

    Args:
        predictions_data (list): List of prediction dicts (case_id, differential_diagnosis list).
        ground_truth_data (list): List of ground truth dicts (case_id, correct_diagnosis).
        model (SentenceTransformer): The loaded sentence embedding model.
        similarity_threshold (float): Cosine similarity threshold to consider diagnoses a match.
        k (int): The cutoff for calculating NDCG.

    Returns:
        dict | None: The calculated metrics. Returns None (after printing the
        reason) if an input is empty or malformed, no case_id is shared, or no
        case could be processed.
    """
    if not predictions_data or not ground_truth_data:
        print("Error: Input data is missing.")
        return None

    # Convert lists to dictionaries keyed by case_id for efficient lookup
    try:
        predictions_dict = {item['case_id']: item['differential_diagnosis'] for item in predictions_data}
        ground_truth_dict = {item['case_id']: item['correct_diagnosis'] for item in ground_truth_data}
    except KeyError as e:
        print(f"Error: Missing expected key '{e}' while structuring data. Check JSON formats.")
        return None
    except TypeError as e:
        print(f"Error: Problem accessing data, likely incorrect JSON structure: {e}")
        return None

    reciprocal_ranks = []
    ndcg_scores = []
    top1_correct_count = 0
    processed_cases = 0
    cases_with_match = 0

    common_case_ids = set(predictions_dict) & set(ground_truth_dict)
    if not common_case_ids:
        print("Error: No common case_ids found between the two files.")
        return None

    print(f"Processing {len(common_case_ids)} common cases...")

    # Sort (by string form, so mixed id types still compare) for a reproducible
    # order of per-case messages. The averaged metrics do not depend on order.
    for case_id in sorted(common_case_ids, key=str):
        correct_diagnosis_text = ground_truth_dict[case_id]
        raw_predictions = predictions_dict[case_id]

        if not correct_diagnosis_text:
            print(f"Warning: Empty correct_diagnosis for case {case_id}. Skipping.")
            continue
        if not raw_predictions:
            print(f"Warning: Empty differential_diagnosis list for case {case_id}. Assigning zero scores.")
            reciprocal_ranks.append(0)
            ndcg_scores.append(0)
            processed_cases += 1
            continue

        # Sort a COPY by rank so the caller's data is left untouched. Entries
        # that are not dicts cannot carry a diagnosis and are ignored.
        ranked_predictions = sorted(
            (item for item in raw_predictions if isinstance(item, dict)),
            key=lambda x: x.get('rank', float('inf')),
        )
        # A diagnosis of None is treated the same as an empty string.
        predicted_texts = [item.get('diagnosis') or '' for item in ranked_predictions]
        if not any(predicted_texts):
            print(f"Warning: All predicted diagnosis texts are empty for case {case_id}. Assigning zero scores.")
            reciprocal_ranks.append(0)
            ndcg_scores.append(0)
            processed_cases += 1
            continue

        # --- Semantic Similarity Calculation ---
        found_rank = 0  # 1-based position of the first match; 0 means no match
        try:
            correct_embedding = model.encode(correct_diagnosis_text, convert_to_tensor=True)
            predicted_embeddings = model.encode(predicted_texts, convert_to_tensor=True)
            # Row 0 holds the similarity of the ground truth to each prediction.
            cosine_scores = util.cos_sim(correct_embedding, predicted_embeddings)[0]
            for position, similarity in enumerate(cosine_scores.tolist(), start=1):
                if similarity >= similarity_threshold:
                    found_rank = position
                    break  # only the first match in ranked order counts
        except Exception as e:
            print(f"Error during embedding/similarity calculation for case {case_id}: {e}. Assigning zero scores.")
            found_rank = 0

        if found_rank > 0:
            cases_with_match += 1

        # --- Per-case metrics from found_rank ---
        reciprocal_ranks.append(1 / found_rank if found_rank > 0 else 0)

        # One relevant item: IDCG@k = 1/log2(1 + 1) = 1, so NDCG is just the
        # DCG contribution of the match, and 0 if it lies beyond the cutoff.
        if 0 < found_rank <= k:
            ndcg_scores.append(1.0 / math.log2(found_rank + 1))
        else:
            ndcg_scores.append(0.0)

        if found_rank == 1:
            top1_correct_count += 1

        processed_cases += 1

    # --- Aggregate Results ---
    if processed_cases == 0:
        print("Error: No cases were successfully processed.")
        return None

    mean_mrr = np.mean(reciprocal_ranks) if reciprocal_ranks else 0
    mean_ndcg_at_k = np.mean(ndcg_scores) if ndcg_scores else 0
    top_1_accuracy = top1_correct_count / processed_cases

    return {
        "total_cases_processed": processed_cases,
        "cases_with_semantic_match_found": cases_with_match,
        "similarity_threshold": similarity_threshold,
        "embedding_model_used": _embedding_model_label(model),
        "mrr": mean_mrr,
        f"mean_ndcg@{k}": mean_ndcg_at_k,
        "top_1_accuracy": top_1_accuracy,
        "k_for_ndcg": k
    }

# --- Configuration ---
# Model choice: 'all-MiniLM-L6-v2' is fast and general purpose.
# Consider 'emilyalsentzer/Bio_ClinicalBERT' or other biomedical models if available via sentence-transformers
# or if you install transformers separately, but start with a general one.
MODEL_NAME = 'all-MiniLM-L6-v2'
# Passed to calculate_metrics_semantic as similarity_threshold. A prediction whose
# cosine similarity to the ground truth is at least this value counts as a match.
SIMILARITY_THRESHOLD = 0.75 # Adjust this threshold based on results (0.8-0.9 is common)
NDCG_K = 5 # Evaluate NDCG up to rank 5

# --- Main Execution ---
# Interactive driver: mount Drive, ask for the ground-truth and prediction JSON
# paths, load the embedding model, score the predictions and print the metrics.
# Each setup failure prints a message and calls exit(). Every name assigned
# below is a notebook global, so later cells can still read e.g. `results`.
if __name__ == "__main__":
    # --- Mount Drive ---
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        # Not referenced again in this cell; both file paths come from input().
        DRIVE_ROOT = "/content/drive/MyDrive/"
        print("Google Drive mounted.")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        exit()

    # --- Get File Paths ---
    print("\n--- Specify Input JSON File Paths ---")
    print("Use the Colab file browser (left panel) to copy the *full path* for each file.")

    ground_truth_file_path = input("1. Paste the full path to 'ground_truth_diagnoses.json': ").strip()
    predictions_file_path = input("2. Paste the full path to 'gemini_2.5_pro_predictions.json': ").strip()

    if not ground_truth_file_path or not predictions_file_path:
        print("Error: One or both file paths were not provided. Exiting.")
        exit()

    # --- Load Data ---
    # load_json_data prints its own error message and returns None on failure.
    print("\nLoading data...")
    ground_truth_data = load_json_data(ground_truth_file_path)
    predictions_data = load_json_data(predictions_file_path)

    if ground_truth_data is None or predictions_data is None:
        print("Failed to load data. Exiting.")
        exit()

    # --- Load Embedding Model ---
    print(f"\nLoading sentence transformer model: {MODEL_NAME}...")
    try:
        model = SentenceTransformer(MODEL_NAME)
        print("Model loaded successfully.")
    except Exception as e:
        print(f"Error loading sentence transformer model: {e}")
        print("Please ensure the model name is correct and you have internet access.")
        exit()

    # --- Calculate Metrics ---
    print(f"\nCalculating metrics using semantic similarity (Threshold: {SIMILARITY_THRESHOLD})...")
    results = calculate_metrics_semantic(
        predictions_data,
        ground_truth_data,
        model,
        similarity_threshold=SIMILARITY_THRESHOLD,
        k=NDCG_K
    )

    # --- Display Results ---
    # calculate_metrics_semantic returns None when no case was processed, so
    # the match-rate division below never divides by zero.
    if results:
        print("\n--- Evaluation Results ---")
        print(f"Embedding Model: {results['embedding_model_used']}")
        print(f"Similarity Threshold: {results['similarity_threshold']}")
        print(f"Total Cases Processed: {results['total_cases_processed']}")
        print(f"Cases Where a Match Was Found: {results['cases_with_semantic_match_found']} ({results['cases_with_semantic_match_found']/results['total_cases_processed']:.1%})")
        print("-" * 25)
        print(f"MRR (Mean Reciprocal Rank): {results['mrr']:.4f}")

        # Corrected NDCG print statement:
        # The metrics key embeds the cutoff (e.g. 'mean_ndcg@5'), so rebuild it from k_for_ndcg.
        k_value = results['k_for_ndcg']
        ndcg_key = f"mean_ndcg@{k_value}"
        if ndcg_key in results:
             print(f"Mean NDCG@{k_value}: {results[ndcg_key]:.4f}")
        else:
             print(f"Mean NDCG@{k_value}: Key '{ndcg_key}' not found in results.") # Error handling

        print(f"Top-1 Accuracy (Correct diagnosis is semantically matched at rank 1): {results['top_1_accuracy']:.2%}")
        print("------------------------")
    else:
        print("\nFailed to calculate metrics.")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted.

--- Specify Input JSON File Paths ---
Use the Colab file browser (left panel) to copy the *full path* for each file.
1. Paste the full path to 'ground_truth_diagnoses.json': /content/drive/MyDrive/ground_truth_diagnoses.json
2. Paste the full path to 'gemini_2.5_pro_predictions.json': /content/drive/MyDrive/gemini_2.5_pro_predictions_ranked.json

Loading data...

Loading sentence transformer model: all-MiniLM-L6-v2...
Model loaded successfully.

Calculating metrics using semantic similarity (Threshold: 0.75)...
Processing 25 common cases...

--- Evaluation Results ---
Embedding Model: Unknown
Similarity Threshold: 0.75
Total Cases Processed: 25
Cases Where a Match Was Found: 6 (24.0%)
-------------------------
MRR (Mean Reciprocal Rank): 0.2133
Mean NDCG@5: 0.2200
Top-1 Accuracy (Correct diagnosis is semantically matched at rank 1): 20.00%
------------------------


## LLM-based Comparison

In [None]:
# --- Step 1: Install necessary libraries ---
!pip install google-generativeai numpy -q

import os
import json
import math
import numpy as np
import time
from google.colab import drive
from google.colab import userdata # For securely getting the API key
import google.generativeai as genai
import glob # Useful for finding files matching a pattern

# --- Function to load JSON data ---
def load_json_data(filepath):
    """
    Load and return the JSON document stored at *filepath* (decoded as UTF-8).

    Errors are reported, not raised. A missing file, malformed JSON, or any
    other exception prints a message and yields ``None``.
    """
    failure_message = None
    try:
        with open(filepath, 'r', encoding='utf-8') as stream:
            return json.load(stream)
    except FileNotFoundError:
        failure_message = f"Error: File not found at {filepath}"
    except json.JSONDecodeError:
        failure_message = f"Error: Could not decode JSON from {filepath}"
    except Exception as exc:
        failure_message = f"An unexpected error occurred loading {filepath}: {exc}"
    print(failure_message)
    return None

# --- Function to ask Gemini for semantic match ---
# (Using the version with temperature=0.0)
def check_diagnosis_match_with_gemini(ground_truth_dx, predicted_dx, model, retries=2, delay=5):
    """
    Ask a Gemini model whether a predicted diagnosis matches the ground truth.

    The model is prompted at temperature 0.0 to answer only YES or NO. The
    question is whether the prediction is the same condition, a very close
    subtype, or clearly encompassed by the ground truth.

    The reply is split into words, and it counts as a match only when it
    contains the word YES and not the word NO. Whole-word matching avoids
    false hits such as "EYES", "NOT" or "NONE". A reply containing both words
    is ambiguous and treated as no match.

    Args:
        ground_truth_dx (str): Reference diagnosis.
        predicted_dx (str): Candidate diagnosis. If empty or None, returns
            False without calling the API.
        model (genai.GenerativeModel): Initialized Gemini model.
        retries (int): Extra attempts after the first failed API call.
        delay (float): Seconds to sleep between attempts.

    Returns:
        bool: True only for an unambiguous YES. False for NO, an unclear or
        ambiguous reply, an empty prediction, or when every attempt raises.
    """
    import re  # local import: this notebook cell does not import `re` at top level

    if not predicted_dx:
        return False

    prompt = f"""Compare the following two medical diagnoses.
Diagnosis 1 (Ground Truth): "{ground_truth_dx}"
Diagnosis 2 (Prediction): "{predicted_dx}"

Do these two diagnoses refer to essentially the same condition, a very close subtype, or is the prediction clearly encompassed within the ground truth, such that the prediction could be considered correct in this context?

Answer ONLY with the word 'YES' or 'NO'.
"""
    for attempt in range(retries + 1):
        # Reset every attempt so a failure never reports the block reason of a
        # response left over from an earlier attempt.
        response = None
        try:
            response = model.generate_content(
                prompt,
                generation_config=genai.types.GenerationConfig(temperature=0.0)
            )
            # response.text raises when the reply was blocked; that is handled below.
            words = set(re.findall(r"[A-Z]+", response.text.upper()))
            return "YES" in words and "NO" not in words
        except Exception as e:
            # getattr with a default tolerates a missing response or missing feedback fields.
            feedback = getattr(response, 'prompt_feedback', None)
            reason = getattr(feedback, 'block_reason', None)
            block_reason = f" (Block Reason: {reason})" if reason else ""

            print(f"      Error calling Gemini API (Attempt {attempt + 1}/{retries + 1}): {e}{block_reason}")
            if attempt < retries:
                print(f"      Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("      Max retries reached. Treating as NO match.")
                return False
    # Only reachable when retries < 0 (no attempt is made).
    return False


# --- Function to calculate metrics AND find matching ranks using LLM ---
def calculate_metrics_and_ranks_llm(predictions_data, ground_truth_data, model, k=5):
    """
    Calculate MRR, mean NDCG@k and Top-1 accuracy, using an LLM to decide matches.

    For every case_id present in both inputs, the predicted differential
    diagnoses are compared in rank order with the ground-truth diagnosis via
    `check_diagnosis_match_with_gemini`. The first prediction judged a match
    fixes the case's rank; later predictions are not checked. Details of that
    first match are stored so the LLM's judgements can be validated by hand.

    Each case has a single relevant answer, so its ideal DCG is 1/log2(2) = 1
    and NDCG@k equals 1/log2(rank + 1) when the match lies within the top k
    (0 otherwise). The reciprocal rank uses the match position in the full list.

    Args:
        predictions_data (list): Dicts with 'case_id' and 'differential_diagnosis'
            (a list of dicts holding 'diagnosis' and, optionally, 'rank').
        ground_truth_data (list): Dicts with 'case_id' and 'correct_diagnosis'.
        model (genai.GenerativeModel): The initialized Gemini model; its
            `model_name` is reported in the metrics.
        k (int): The cutoff for calculating NDCG.

    Returns:
        tuple: (dict: metrics, dict: individual match details keyed by case_id),
        or (None, None) if the inputs are empty or malformed, share no case_id,
        or no case could be scored. The input lists are not modified.
    """
    if not predictions_data or not ground_truth_data:
        print("Error: Input data is missing.")
        return None, None

    try:
        predictions_dict = {item['case_id']: item['differential_diagnosis'] for item in predictions_data}
        ground_truth_dict = {item['case_id']: item['correct_diagnosis'] for item in ground_truth_data}
    except KeyError as e:
        print(f"Error: Missing expected key '{e}' while structuring data.")
        return None, None
    except TypeError as e:
        print(f"Error: Problem accessing data, likely incorrect JSON structure: {e}")
        return None, None

    reciprocal_ranks = []
    ndcg_scores = []
    top1_correct_count = 0
    cases_with_match = 0
    individual_match_details = {}  # case_id -> first-match details or a status dict

    # Only case_ids present in both files are scored, so every id below is a
    # key of both dicts (no "missing data" check is needed inside the loop).
    common_case_ids = sorted(predictions_dict.keys() & ground_truth_dict.keys())

    if not common_case_ids:
        print("Error: No common case_ids found.")
        return None, None

    print(f"\nProcessing {len(common_case_ids)} common cases using LLM ({model.model_name})...")
    print("This will take time...\n")

    for case_id in common_case_ids:
        print(f"--- Processing Case: {case_id} ---")
        correct_diagnosis_text = ground_truth_dict[case_id]
        predicted_diagnoses_list = predictions_dict[case_id]

        if not correct_diagnosis_text:
            print("  Skipped - No Ground Truth diagnosis text.")
            individual_match_details[case_id] = {"status": "Skipped - No Ground Truth"}
            print("-" * 30)
            continue
        if not predicted_diagnoses_list:
            print("  No Match Possible (Prediction list is empty). Assigning zero scores.")
            reciprocal_ranks.append(0)
            ndcg_scores.append(0)
            individual_match_details[case_id] = {"status": "No Match Found (Empty Predictions)"}
            print("-" * 30)
            continue

        # sorted() rather than list.sort(): ordering by rank must not reorder
        # the caller's prediction data as a side effect.
        ranked_predictions = sorted(predicted_diagnoses_list, key=lambda x: x.get('rank', float('inf')))

        # --- LLM Comparison Loop ---
        found_rank = 0
        predicted_text = ''
        for current_rank, prediction_item in enumerate(ranked_predictions, start=1):
            predicted_text = prediction_item.get('diagnosis', '')
            print(f"  Rank {current_rank}: Comparing...")  # Keep output concise

            is_match = check_diagnosis_match_with_gemini(correct_diagnosis_text, predicted_text, model)
            time.sleep(1.1)  # IMPORTANT: Rate limiting between API calls

            if is_match:
                print(f"    --> Match found by LLM at rank {current_rank}!")
                found_rank = current_rank
                break  # Stop at the first match

        # --- Store results for this case ---
        if found_rank:
            individual_match_details[case_id] = {
                "rank": found_rank,
                "ground_truth": correct_diagnosis_text,
                "prediction": predicted_text,
            }
            cases_with_match += 1
        else:
            individual_match_details[case_id] = {"status": "No Match Found"}
            print("  No Match Found for this case.")

        # --- Metrics based on found_rank ---
        reciprocal_ranks.append(1 / found_rank if found_rank else 0)
        # Single relevant item => IDCG@k == 1, so NDCG@k == DCG@k.
        ndcg_scores.append(1.0 / math.log2(found_rank + 1) if 0 < found_rank <= k else 0.0)
        if found_rank == 1:
            top1_correct_count += 1

        print("-" * 30)  # Separator between cases

    # --- Aggregate Results ---
    # Every scored case (including empty-prediction ones) appended exactly one RR.
    processed_cases = len(reciprocal_ranks)
    if processed_cases == 0:
        print("Error: No cases were successfully processed.")
        return None, None

    metrics = {
        "total_cases_processed": processed_cases,
        "cases_with_llm_match_found": cases_with_match,
        "llm_model_used": model.model_name,
        "mrr": np.mean(reciprocal_ranks),
        f"mean_ndcg@{k}": np.mean(ndcg_scores),
        "top_1_accuracy": top1_correct_count / processed_cases,
        "k_for_ndcg": k,
    }
    return metrics, individual_match_details

# --- Configuration ---
LLM_MODEL_NAME = 'gemini-2.5-pro-preview-03-25' # Your specified model
NDCG_K = 5 # Rank cutoff for NDCG calculation

# --- Main Execution ---
# Mounts Google Drive, configures the Gemini API from the Colab secret
# 'GOOGLE_API_KEY', asks for the ground-truth and prediction JSON paths, then
# scores the predictions and prints per-case details and overall metrics.
if __name__ == "__main__":
    # Abort with sys.exit(): it raises SystemExit, which stops the script or the
    # running notebook cell. The bare exit() helper is meant for interactive use
    # only; under IPython/Colab it does not raise, so the rest of the cell would
    # keep running with undefined names (e.g. `model`).
    import sys

    # --- Mount Drive ---
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        DRIVE_ROOT = "/content/drive/MyDrive/"
        print("Google Drive mounted.")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        sys.exit(1)

    # --- Configure Gemini API ---
    try:
        GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
        if not GOOGLE_API_KEY:
            raise ValueError("API Key not found in Colab Secrets")
        genai.configure(api_key=GOOGLE_API_KEY)
        print("Gemini API configured.")
    except Exception as e:
        print(f"Error configuring Gemini API: {e}")
        print("Please ensure you have set the 'GOOGLE_API_KEY' secret in Colab.")
        sys.exit(1)

    # --- Initialize the Gemini Model ---
    try:
        print(f"Initializing Gemini model: {LLM_MODEL_NAME}...")
        model = genai.GenerativeModel(LLM_MODEL_NAME)
        print("Model initialized.")
    except Exception as e:
        print(f"Error initializing Gemini model: {e}")
        sys.exit(1)

    # --- Get File Paths ---
    print("\n--- Specify Input JSON File Paths ---")
    print("Use the Colab file browser (left panel) to copy the *full path* for each file.")
    ground_truth_file_path = input("1. Paste the full path to 'ground_truth_diagnoses.json': ").strip()
    predictions_file_path = input("2. Paste the full path to 'gemini_2.5_pro_predictions.json': ").strip() # Use the ranked one

    if not ground_truth_file_path or not predictions_file_path:
        print("Error: One or both file paths were not provided. Exiting.")
        sys.exit(1)

    # --- Load Data ---
    print("\nLoading data...")
    ground_truth_data = load_json_data(ground_truth_file_path)
    predictions_data = load_json_data(predictions_file_path)

    if ground_truth_data is None or predictions_data is None:
        print("Failed to load data. Exiting.")
        sys.exit(1)

    # --- Calculate Metrics and Find Ranks using LLM ---
    metrics_results, individual_match_info = calculate_metrics_and_ranks_llm(
        predictions_data,
        ground_truth_data,
        model,
        k=NDCG_K
    )

    # --- Display Individual Match Details (for Validation) ---
    if individual_match_info:
        print("\n" + "="*60)
        print("--- Individual Case Match Details (Rank of First LLM Match) ---")
        print("="*60)
        for case_id, details in individual_match_info.items():
            print(f"{case_id}:")
            if "rank" in details: # Check if a match was found
                print(f"  Rank {details['rank']}: '{details['ground_truth']}' vs '{details['prediction']}'")
            else:
                print(f"  {details.get('status', 'Unknown Status')}") # Print status like 'No Match Found' or 'Skipped'
            print("-" * 30)
    else:
        print("\nCould not retrieve individual match details.")


    # --- Display Overall Metrics ---
    if metrics_results:
        print("\n" + "="*60)
        print("--- Overall Evaluation Results (LLM-based) ---")
        print("="*60)
        print(f"LLM Model Used: {metrics_results['llm_model_used']}")
        print(f"Total Cases Processed: {metrics_results['total_cases_processed']}")
        print(f"Cases Where LLM Found a Match: {metrics_results['cases_with_llm_match_found']} ({metrics_results['cases_with_llm_match_found']/metrics_results['total_cases_processed']:.1%})")
        print("-" * 25)
        print(f"MRR (Mean Reciprocal Rank): {metrics_results['mrr']:.4f}")

        k_value = metrics_results['k_for_ndcg']
        ndcg_key = f"mean_ndcg@{k_value}"
        if ndcg_key in metrics_results:
             print(f"Mean NDCG@{k_value}: {metrics_results[ndcg_key]:.4f}")
        else:
             print(f"Mean NDCG@{k_value}: Key '{ndcg_key}' not found in results.")

        print(f"Top-1 Accuracy (LLM match at rank 1): {metrics_results['top_1_accuracy']:.2%}")
        print("="*60)
    else:
        print("\nFailed to calculate overall metrics.")

    print("\nScript finished.")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted.
Gemini API configured.
Initializing Gemini model: gemini-2.5-pro-preview-03-25...
Model initialized.

--- Specify Input JSON File Paths ---
Use the Colab file browser (left panel) to copy the *full path* for each file.
1. Paste the full path to 'ground_truth_diagnoses.json': /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/ground_truth_diagnoses.json
2. Paste the full path to 'gemini_2.5_pro_predictions.json': /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/gemini_2.5_pro_predictions.json

Loading data...

Processing 25 common cases using LLM (models/gemini-2.5-pro-preview-03-25)...
This will take time...

--- Processing Case: NEJMcpc2100279 ---
  Rank 1: Comparing...
  Rank 2: Comparing...
  Rank 3: Comparing...
  Rank 4: Comparing...
  Rank 5: Comparing...
  No Match Found for this case.
------------------------------
--- Processing Case: NEJMcpc2300

# Calculating Mean Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG) for Grok 3 Responses

## LLM-based Comparison

In [None]:
# --- Step 1: Install necessary libraries ---
!pip install google-generativeai numpy -q

import os
import json
import math
import numpy as np
import time
from google.colab import drive
from google.colab import userdata # For securely getting the API key
import google.generativeai as genai
import glob # Useful for finding files matching a pattern

# --- Function to load JSON data ---
def load_json_data(filepath):
    """
    Read a UTF-8 encoded JSON file and return the decoded value.

    Args:
        filepath (str): Path of the JSON file to read.

    Returns:
        The parsed JSON value, or None when the file is missing, is not valid
        JSON, or cannot be read for any other reason (the problem is printed).
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        return json.loads(raw_text)
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}")
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {filepath}")
    except Exception as e:
        print(f"An unexpected error occurred loading {filepath}: {e}")
    return None

# --- Function to ask Gemini for semantic match ---
# (Using the version with temperature=0.0)
def check_diagnosis_match_with_gemini(ground_truth_dx, predicted_dx, model, retries=2, delay=5):
    """
    Ask the Gemini model whether a predicted diagnosis matches the ground truth.

    The model is prompted (temperature 0.0) to answer only 'YES' or 'NO'. The
    first whole-word YES or NO in its reply decides the result; a reply that
    contains neither word is treated as no match.

    Args:
        ground_truth_dx (str): The reference (correct) diagnosis text.
        predicted_dx (str): A single predicted diagnosis. Empty/None never
            matches and makes no API call.
        model (genai.GenerativeModel): Initialized Gemini model used as the judge.
        retries (int): Extra attempts after a failed call (at most retries + 1 calls).
        delay (float): Seconds to sleep between failed attempts.

    Returns:
        bool: True if the model judged the two diagnoses a match; False
        otherwise, including when every attempt raised an error.
    """
    import re  # local import: this cell's import list does not include `re`

    if not predicted_dx:
        return False

    prompt = f"""Compare the following two medical diagnoses.
Diagnosis 1 (Ground Truth): "{ground_truth_dx}"
Diagnosis 2 (Prediction): "{predicted_dx}"

Do these two diagnoses refer to essentially the same condition, a very close subtype, or is the prediction clearly encompassed within the ground truth, such that the prediction could be considered correct in this context?

Answer ONLY with the word 'YES' or 'NO'.
"""
    for attempt in range(retries + 1):
        # Reset per attempt so an error never reports the block reason of an
        # earlier attempt's (stale) response.
        response = None
        try:
            response = model.generate_content(
                prompt,
                generation_config=genai.types.GenerationConfig(temperature=0.0)
                )

            answer_text = response.text.strip().upper()
            # Whole-word matching: a plain substring test treats "NOT", "KNOWN"
            # or "EYES" as answers and reads "NO, NOT YES" as YES.
            answers = re.findall(r"\b(YES|NO)\b", answer_text)
            return bool(answers) and answers[0] == "YES"
        except Exception as e:
            # response may be None (call itself failed) or lack feedback fields.
            feedback = getattr(response, 'prompt_feedback', None)
            reason = getattr(feedback, 'block_reason', None)
            block_reason = f" (Block Reason: {reason})" if reason else ""

            print(f"      Error calling Gemini API (Attempt {attempt + 1}/{retries + 1}): {e}{block_reason}")
            if attempt < retries:
                print(f"      Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("      Max retries reached. Treating as NO match.")
                return False
    return False


# --- Function to calculate metrics AND find matching ranks using LLM ---
def calculate_metrics_and_ranks_llm(predictions_data, ground_truth_data, model, k=5):
    """
    Calculate MRR, mean NDCG@k and Top-1 accuracy, using an LLM to decide matches.

    For every case_id present in both inputs, the predicted differential
    diagnoses are compared in rank order with the ground-truth diagnosis via
    `check_diagnosis_match_with_gemini`. The first prediction judged a match
    fixes the case's rank; later predictions are not checked. Details of that
    first match are stored so the LLM's judgements can be validated by hand.

    Each case has a single relevant answer, so its ideal DCG is 1/log2(2) = 1
    and NDCG@k equals 1/log2(rank + 1) when the match lies within the top k
    (0 otherwise). The reciprocal rank uses the match position in the full list.

    Args:
        predictions_data (list): Dicts with 'case_id' and 'differential_diagnosis'
            (a list of dicts holding 'diagnosis' and, optionally, 'rank').
        ground_truth_data (list): Dicts with 'case_id' and 'correct_diagnosis'.
        model (genai.GenerativeModel): The initialized Gemini model; its
            `model_name` is reported in the metrics.
        k (int): The cutoff for calculating NDCG.

    Returns:
        tuple: (dict: metrics, dict: individual match details keyed by case_id),
        or (None, None) if the inputs are empty or malformed, share no case_id,
        or no case could be scored. The input lists are not modified.
    """
    if not predictions_data or not ground_truth_data:
        print("Error: Input data is missing.")
        return None, None

    try:
        predictions_dict = {item['case_id']: item['differential_diagnosis'] for item in predictions_data}
        ground_truth_dict = {item['case_id']: item['correct_diagnosis'] for item in ground_truth_data}
    except KeyError as e:
        print(f"Error: Missing expected key '{e}' while structuring data.")
        return None, None
    except TypeError as e:
        print(f"Error: Problem accessing data, likely incorrect JSON structure: {e}")
        return None, None

    reciprocal_ranks = []
    ndcg_scores = []
    top1_correct_count = 0
    cases_with_match = 0
    individual_match_details = {}  # case_id -> first-match details or a status dict

    # Only case_ids present in both files are scored, so every id below is a
    # key of both dicts (no "missing data" check is needed inside the loop).
    common_case_ids = sorted(predictions_dict.keys() & ground_truth_dict.keys())

    if not common_case_ids:
        print("Error: No common case_ids found.")
        return None, None

    print(f"\nProcessing {len(common_case_ids)} common cases using LLM ({model.model_name})...")
    print("This will take time...\n")

    for case_id in common_case_ids:
        print(f"--- Processing Case: {case_id} ---")
        correct_diagnosis_text = ground_truth_dict[case_id]
        predicted_diagnoses_list = predictions_dict[case_id]

        if not correct_diagnosis_text:
            print("  Skipped - No Ground Truth diagnosis text.")
            individual_match_details[case_id] = {"status": "Skipped - No Ground Truth"}
            print("-" * 30)
            continue
        if not predicted_diagnoses_list:
            print("  No Match Possible (Prediction list is empty). Assigning zero scores.")
            reciprocal_ranks.append(0)
            ndcg_scores.append(0)
            individual_match_details[case_id] = {"status": "No Match Found (Empty Predictions)"}
            print("-" * 30)
            continue

        # sorted() rather than list.sort(): ordering by rank must not reorder
        # the caller's prediction data as a side effect.
        ranked_predictions = sorted(predicted_diagnoses_list, key=lambda x: x.get('rank', float('inf')))

        # --- LLM Comparison Loop ---
        found_rank = 0
        predicted_text = ''
        for current_rank, prediction_item in enumerate(ranked_predictions, start=1):
            predicted_text = prediction_item.get('diagnosis', '')
            print(f"  Rank {current_rank}: Comparing...")  # Keep output concise

            is_match = check_diagnosis_match_with_gemini(correct_diagnosis_text, predicted_text, model)
            time.sleep(1.1)  # IMPORTANT: Rate limiting between API calls

            if is_match:
                print(f"    --> Match found by LLM at rank {current_rank}!")
                found_rank = current_rank
                break  # Stop at the first match

        # --- Store results for this case ---
        if found_rank:
            individual_match_details[case_id] = {
                "rank": found_rank,
                "ground_truth": correct_diagnosis_text,
                "prediction": predicted_text,
            }
            cases_with_match += 1
        else:
            individual_match_details[case_id] = {"status": "No Match Found"}
            print("  No Match Found for this case.")

        # --- Metrics based on found_rank ---
        reciprocal_ranks.append(1 / found_rank if found_rank else 0)
        # Single relevant item => IDCG@k == 1, so NDCG@k == DCG@k.
        ndcg_scores.append(1.0 / math.log2(found_rank + 1) if 0 < found_rank <= k else 0.0)
        if found_rank == 1:
            top1_correct_count += 1

        print("-" * 30)  # Separator between cases

    # --- Aggregate Results ---
    # Every scored case (including empty-prediction ones) appended exactly one RR.
    processed_cases = len(reciprocal_ranks)
    if processed_cases == 0:
        print("Error: No cases were successfully processed.")
        return None, None

    metrics = {
        "total_cases_processed": processed_cases,
        "cases_with_llm_match_found": cases_with_match,
        "llm_model_used": model.model_name,
        "mrr": np.mean(reciprocal_ranks),
        f"mean_ndcg@{k}": np.mean(ndcg_scores),
        "top_1_accuracy": top1_correct_count / processed_cases,
        "k_for_ndcg": k,
    }
    return metrics, individual_match_details

# --- Configuration ---
LLM_MODEL_NAME = 'gemini-2.5-pro-preview-03-25' # Your specified model
NDCG_K = 5 # Rank cutoff for NDCG calculation

# --- Main Execution ---
# Mounts Google Drive, configures the Gemini API from the Colab secret
# 'GOOGLE_API_KEY', asks for the ground-truth and prediction JSON paths, then
# scores the predictions and prints per-case details and overall metrics.
if __name__ == "__main__":
    # Abort with sys.exit(): it raises SystemExit, which stops the script or the
    # running notebook cell. The bare exit() helper is meant for interactive use
    # only; under IPython/Colab it does not raise, so the rest of the cell would
    # keep running with undefined names (e.g. `model`).
    import sys

    # --- Mount Drive ---
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        DRIVE_ROOT = "/content/drive/MyDrive/"
        print("Google Drive mounted.")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        sys.exit(1)

    # --- Configure Gemini API ---
    try:
        GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
        if not GOOGLE_API_KEY:
            raise ValueError("API Key not found in Colab Secrets")
        genai.configure(api_key=GOOGLE_API_KEY)
        print("Gemini API configured.")
    except Exception as e:
        print(f"Error configuring Gemini API: {e}")
        print("Please ensure you have set the 'GOOGLE_API_KEY' secret in Colab.")
        sys.exit(1)

    # --- Initialize the Gemini Model ---
    try:
        print(f"Initializing Gemini model: {LLM_MODEL_NAME}...")
        model = genai.GenerativeModel(LLM_MODEL_NAME)
        print("Model initialized.")
    except Exception as e:
        print(f"Error initializing Gemini model: {e}")
        sys.exit(1)

    # --- Get File Paths ---
    print("\n--- Specify Input JSON File Paths ---")
    print("Use the Colab file browser (left panel) to copy the *full path* for each file.")
    ground_truth_file_path = input("1. Paste the full path to 'ground_truth_diagnoses.json': ").strip()
    predictions_file_path = input("2. Paste the full path to 'grok_3_predictions.json': ").strip() # Use the ranked one

    if not ground_truth_file_path or not predictions_file_path:
        print("Error: One or both file paths were not provided. Exiting.")
        sys.exit(1)

    # --- Load Data ---
    print("\nLoading data...")
    ground_truth_data = load_json_data(ground_truth_file_path)
    predictions_data = load_json_data(predictions_file_path)

    if ground_truth_data is None or predictions_data is None:
        print("Failed to load data. Exiting.")
        sys.exit(1)

    # --- Calculate Metrics and Find Ranks using LLM ---
    metrics_results, individual_match_info = calculate_metrics_and_ranks_llm(
        predictions_data,
        ground_truth_data,
        model,
        k=NDCG_K
    )

    # --- Display Individual Match Details (for Validation) ---
    if individual_match_info:
        print("\n" + "="*60)
        print("--- Individual Case Match Details (Rank of First LLM Match) ---")
        print("="*60)
        for case_id, details in individual_match_info.items():
            print(f"{case_id}:")
            if "rank" in details: # Check if a match was found
                print(f"  Rank {details['rank']}: '{details['ground_truth']}' vs '{details['prediction']}'")
            else:
                print(f"  {details.get('status', 'Unknown Status')}") # Print status like 'No Match Found' or 'Skipped'
            print("-" * 30)
    else:
        print("\nCould not retrieve individual match details.")


    # --- Display Overall Metrics ---
    if metrics_results:
        print("\n" + "="*60)
        print("--- Overall Evaluation Results (LLM-based) ---")
        print("="*60)
        print(f"LLM Model Used: {metrics_results['llm_model_used']}")
        print(f"Total Cases Processed: {metrics_results['total_cases_processed']}")
        print(f"Cases Where LLM Found a Match: {metrics_results['cases_with_llm_match_found']} ({metrics_results['cases_with_llm_match_found']/metrics_results['total_cases_processed']:.1%})")
        print("-" * 25)
        print(f"MRR (Mean Reciprocal Rank): {metrics_results['mrr']:.4f}")

        k_value = metrics_results['k_for_ndcg']
        ndcg_key = f"mean_ndcg@{k_value}"
        if ndcg_key in metrics_results:
             print(f"Mean NDCG@{k_value}: {metrics_results[ndcg_key]:.4f}")
        else:
             print(f"Mean NDCG@{k_value}: Key '{ndcg_key}' not found in results.")

        print(f"Top-1 Accuracy (LLM match at rank 1): {metrics_results['top_1_accuracy']:.2%}")
        print("="*60)
    else:
        print("\nFailed to calculate overall metrics.")

    print("\nScript finished.")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted.
Gemini API configured.
Initializing Gemini model: gemini-2.5-pro-preview-03-25...
Model initialized.

--- Specify Input JSON File Paths ---
Use the Colab file browser (left panel) to copy the *full path* for each file.
1. Paste the full path to 'ground_truth_diagnoses.json': /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/ground_truth_diagnoses.json
2. Paste the full path to 'grok_3_predictions.json': /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/grok_3_predictions.json

Loading data...

Processing 25 common cases using LLM (models/gemini-2.5-pro-preview-03-25)...
This will take time...

--- Processing Case: NEJMcpc2100279 ---
  Rank 1: Comparing...
    --> Match found by LLM at rank 1!
------------------------------
--- Processing Case: NEJMcpc2300900 ---
  Rank 1: Comparing...
  Rank 2: Comparing...
  Rank 3: Comparing...
  Rank 4: Comparing...
  

# Calculating Mean Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG) for ChatGPT o4-mini-high Responses

In [None]:
# --- Step 1: Install necessary libraries ---
!pip install google-generativeai numpy -q

import os
import json
import math
import numpy as np
import time
from google.colab import drive
from google.colab import userdata # For securely getting the API key
import google.generativeai as genai
import glob # Useful for finding files matching a pattern

# --- Function to load JSON data ---
def load_json_data(filepath):
    """
    Read a UTF-8 encoded JSON file and return the decoded value.

    Args:
        filepath (str): Path of the JSON file to read.

    Returns:
        The parsed JSON value, or None when the file is missing, is not valid
        JSON, or cannot be read for any other reason (the problem is printed).
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        return json.loads(raw_text)
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}")
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {filepath}")
    except Exception as e:
        print(f"An unexpected error occurred loading {filepath}: {e}")
    return None

# --- Function to ask Gemini for semantic match ---
# (Using the version with temperature=0.0)
def check_diagnosis_match_with_gemini(ground_truth_dx, predicted_dx, model, retries=2, delay=5):
    """
    Ask the Gemini model whether a predicted diagnosis matches the ground truth.

    The model is prompted (temperature 0.0) to answer only 'YES' or 'NO'. The
    first whole-word YES or NO in its reply decides the result; a reply that
    contains neither word is treated as no match.

    Args:
        ground_truth_dx (str): The reference (correct) diagnosis text.
        predicted_dx (str): A single predicted diagnosis. Empty/None never
            matches and makes no API call.
        model (genai.GenerativeModel): Initialized Gemini model used as the judge.
        retries (int): Extra attempts after a failed call (at most retries + 1 calls).
        delay (float): Seconds to sleep between failed attempts.

    Returns:
        bool: True if the model judged the two diagnoses a match; False
        otherwise, including when every attempt raised an error.
    """
    import re  # local import: this cell's import list does not include `re`

    if not predicted_dx:
        return False

    prompt = f"""Compare the following two medical diagnoses.
Diagnosis 1 (Ground Truth): "{ground_truth_dx}"
Diagnosis 2 (Prediction): "{predicted_dx}"

Do these two diagnoses refer to essentially the same condition, a very close subtype, or is the prediction clearly encompassed within the ground truth, such that the prediction could be considered correct in this context?

Answer ONLY with the word 'YES' or 'NO'.
"""
    for attempt in range(retries + 1):
        # Reset per attempt so an error never reports the block reason of an
        # earlier attempt's (stale) response.
        response = None
        try:
            response = model.generate_content(
                prompt,
                generation_config=genai.types.GenerationConfig(temperature=0.0)
                )

            answer_text = response.text.strip().upper()
            # Whole-word matching: a plain substring test treats "NOT", "KNOWN"
            # or "EYES" as answers and reads "NO, NOT YES" as YES.
            answers = re.findall(r"\b(YES|NO)\b", answer_text)
            return bool(answers) and answers[0] == "YES"
        except Exception as e:
            # response may be None (call itself failed) or lack feedback fields.
            feedback = getattr(response, 'prompt_feedback', None)
            reason = getattr(feedback, 'block_reason', None)
            block_reason = f" (Block Reason: {reason})" if reason else ""

            print(f"      Error calling Gemini API (Attempt {attempt + 1}/{retries + 1}): {e}{block_reason}")
            if attempt < retries:
                print(f"      Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("      Max retries reached. Treating as NO match.")
                return False
    return False


# --- Function to calculate metrics AND find matching ranks using LLM ---
def calculate_metrics_and_ranks_llm(predictions_data, ground_truth_data, model, k=5):
    """
    Calculates MRR, Mean NDCG@k, Top-1 Accuracy using LLM for matching,
    AND stores details of the first match found for validation.

    Args:
        predictions_data (list): List of prediction dicts.
        ground_truth_data (list): List of ground truth dicts.
        model (genai.GenerativeModel): The initialized Gemini model.
        k (int): The cutoff for calculating NDCG.

    Returns:
        tuple: (dict: metrics, dict: individual match details) or (None, None)
    """
    if not predictions_data or not ground_truth_data:
        print("Error: Input data is missing.")
        return None, None

    try:
        predictions_dict = {item['case_id']: item['differential_diagnosis'] for item in predictions_data}
        ground_truth_dict = {item['case_id']: item['correct_diagnosis'] for item in ground_truth_data}
    except KeyError as e:
        print(f"Error: Missing expected key '{e}' while structuring data.")
        return None, None
    except TypeError as e:
        print(f"Error: Problem accessing data, likely incorrect JSON structure: {e}")
        return None, None

    reciprocal_ranks = []
    ndcg_scores = []
    top1_correct_count = 0
    processed_cases = 0
    cases_with_match = 0
    individual_match_details = {} # To store rank and text for validation

    common_case_ids = sorted(list(set(predictions_dict.keys()) & set(ground_truth_dict.keys())))

    if not common_case_ids:
        print("Error: No common case_ids found.")
        return None, None

    print(f"\nProcessing {len(common_case_ids)} common cases using LLM ({model.model_name})...")
    print("This will take time...\n")

    for case_id in common_case_ids:
        print(f"--- Processing Case: {case_id} ---")
        if case_id not in predictions_dict or case_id not in ground_truth_dict:
            print("  Skipped - Data Missing in one of the files.")
            individual_match_details[case_id] = {"status": "Skipped - Data Missing"}
            print("-" * 30)
            continue

        correct_diagnosis_text = ground_truth_dict[case_id]
        predicted_diagnoses_list = predictions_dict[case_id]

        if not correct_diagnosis_text:
            print("  Skipped - No Ground Truth diagnosis text.")
            individual_match_details[case_id] = {"status": "Skipped - No Ground Truth"}
            print("-" * 30)
            continue
        if not predicted_diagnoses_list:
            print("  No Match Possible (Prediction list is empty). Assigning zero scores.")
            reciprocal_ranks.append(0)
            ndcg_scores.append(0)
            individual_match_details[case_id] = {"status": "No Match Found (Empty Predictions)"}
            processed_cases += 1
            print("-" * 30)
            continue

        # Ensure predictions are sorted by rank
        predicted_diagnoses_list.sort(key=lambda x: x.get('rank', float('inf')))

        # --- LLM Comparison Loop ---
        found_rank = 0
        first_match_details = {}
        for i, prediction_item in enumerate(predicted_diagnoses_list):
            current_rank = i + 1
            predicted_text = prediction_item.get('diagnosis', '')
            print(f"  Rank {current_rank}: Comparing...") # Keep output concise

            # Call LLM to check match
            is_match = check_diagnosis_match_with_gemini(correct_diagnosis_text, predicted_text, model)
            time.sleep(1.1) # IMPORTANT: Rate limiting

            if is_match:
                print(f"    --> Match found by LLM at rank {current_rank}!")
                found_rank = current_rank
                first_match_details = {
                    "rank": found_rank,
                    "ground_truth": correct_diagnosis_text,
                    "prediction": predicted_text
                }
                break # Stop at the first match

        # --- Store results for this case ---
        if found_rank > 0:
             individual_match_details[case_id] = first_match_details
             cases_with_match += 1
        else:
             individual_match_details[case_id] = {"status": "No Match Found"}
             print("  No Match Found for this case.")


        # --- Calculate Metrics based on found_rank ---
        rr = 1 / found_rank if found_rank > 0 else 0
        reciprocal_ranks.append(rr)

        dcg = 0.0
        for i in range(min(k, len(predicted_diagnoses_list))):
            rank_in_list = i + 1
            relevance = 1 if rank_in_list == found_rank else 0
            dcg += relevance / math.log2(rank_in_list + 1)

        idcg = 1.0 / math.log2(1 + 1) if found_rank > 0 else 0.0
        ndcg = dcg / idcg if idcg > 0 else 0.0
        ndcg_scores.append(ndcg)

        if found_rank == 1:
            top1_correct_count += 1

        processed_cases += 1
        print("-" * 30) # Separator between cases

    # --- Aggregate Results ---
    if processed_cases == 0:
        print("Error: No cases were successfully processed.")
        return None, None

    mean_mrr = np.mean(reciprocal_ranks) if reciprocal_ranks else 0
    mean_ndcg_at_k = np.mean(ndcg_scores) if ndcg_scores else 0
    top_1_accuracy = top1_correct_count / processed_cases if processed_cases > 0 else 0

    metrics = {
        "total_cases_processed": processed_cases,
        "cases_with_llm_match_found": cases_with_match,
        "llm_model_used": model.model_name,
        "mrr": mean_mrr,
        f"mean_ndcg@{k}": mean_ndcg_at_k,
        "top_1_accuracy": top_1_accuracy,
        "k_for_ndcg": k
    }
    return metrics, individual_match_details

# --- Configuration ---
# Gemini model used as the semantic judge when comparing diagnoses.
# NOTE(review): preview model IDs can be retired; confirm this ID is still served.
LLM_MODEL_NAME: str = 'gemini-2.5-pro-preview-03-25'
# Only the top NDCG_K ranked predictions contribute to NDCG@k.
NDCG_K: int = 5

# --- Main Execution ---
if __name__ == "__main__":
    # Entry point:
    #   1. Mount Drive.
    #   2. Configure Gemini.
    #   3. Prompt for the two JSON file paths.
    #   4. Score the predictions with the LLM judge.
    #   5. Print per-case match details, then the aggregate metrics.
    #
    # Failures use `raise SystemExit(1)` instead of the `exit()` builtin.
    # `exit()` is intended for interactive shells. In an IPython/Colab kernel
    # it only requests a kernel shutdown and does not stop the running cell,
    # so the code below would keep executing after an error.

    # --- Mount Drive ---
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        DRIVE_ROOT = "/content/drive/MyDrive/"
        print("Google Drive mounted.")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        raise SystemExit(1) from e

    # --- Configure Gemini API ---
    try:
        GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
        if not GOOGLE_API_KEY:
            raise ValueError("API Key not found in Colab Secrets")
        genai.configure(api_key=GOOGLE_API_KEY)
        print("Gemini API configured.")
    except Exception as e:
        print(f"Error configuring Gemini API: {e}")
        print("Please ensure you have set the 'GOOGLE_API_KEY' secret in Colab.")
        raise SystemExit(1) from e

    # --- Initialize the Gemini Model ---
    try:
        print(f"Initializing Gemini model: {LLM_MODEL_NAME}...")
        model = genai.GenerativeModel(LLM_MODEL_NAME)
        print("Model initialized.")
    except Exception as e:
        print(f"Error initializing Gemini model: {e}")
        raise SystemExit(1) from e

    # --- Get File Paths ---
    print("\n--- Specify Input JSON File Paths ---")
    print("Use the Colab file browser (left panel) to copy the *full path* for each file.")
    ground_truth_file_path = input("1. Paste the full path to 'ground_truth_diagnoses.json': ").strip()
    predictions_file_path = input("2. Paste the full path to 'chatGPT_o4-mini-high_predictions.json': ").strip()  # Use the ranked one

    if not ground_truth_file_path or not predictions_file_path:
        print("Error: One or both file paths were not provided. Exiting.")
        raise SystemExit(1)

    # --- Load Data ---
    print("\nLoading data...")
    ground_truth_data = load_json_data(ground_truth_file_path)
    predictions_data = load_json_data(predictions_file_path)

    if ground_truth_data is None or predictions_data is None:
        print("Failed to load data. Exiting.")
        raise SystemExit(1)

    # --- Calculate Metrics and Find Ranks using LLM ---
    metrics_results, individual_match_info = calculate_metrics_and_ranks_llm(
        predictions_data,
        ground_truth_data,
        model,
        k=NDCG_K
    )

    # --- Display Individual Match Details (for Validation) ---
    if individual_match_info:
        print("\n" + "="*60)
        print("--- Individual Case Match Details (Rank of First LLM Match) ---")
        print("="*60)
        for case_id, details in individual_match_info.items():
            print(f"{case_id}:")
            if "rank" in details:  # A match was found for this case
                print(f"  Rank {details['rank']}: '{details['ground_truth']}' vs '{details['prediction']}'")
            else:
                print(f"  {details.get('status', 'Unknown Status')}")  # e.g. 'No Match Found' or 'Skipped'
            print("-" * 30)
    else:
        print("\nCould not retrieve individual match details.")

    # --- Display Overall Metrics ---
    if metrics_results:
        print("\n" + "="*60)
        print("--- Overall Evaluation Results (LLM-based) ---")
        print("="*60)
        print(f"LLM Model Used: {metrics_results['llm_model_used']}")
        print(f"Total Cases Processed: {metrics_results['total_cases_processed']}")
        # total_cases_processed is > 0 whenever metrics are returned.
        print(f"Cases Where LLM Found a Match: {metrics_results['cases_with_llm_match_found']} ({metrics_results['cases_with_llm_match_found']/metrics_results['total_cases_processed']:.1%})")
        print("-" * 25)
        print(f"MRR (Mean Reciprocal Rank): {metrics_results['mrr']:.4f}")

        k_value = metrics_results['k_for_ndcg']
        ndcg_key = f"mean_ndcg@{k_value}"
        if ndcg_key in metrics_results:
            print(f"Mean NDCG@{k_value}: {metrics_results[ndcg_key]:.4f}")
        else:
            print(f"Mean NDCG@{k_value}: Key '{ndcg_key}' not found in results.")

        print(f"Top-1 Accuracy (LLM match at rank 1): {metrics_results['top_1_accuracy']:.2%}")
        print("="*60)
    else:
        print("\nFailed to calculate overall metrics.")

    print("\nScript finished.")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted.
Gemini API configured.
Initializing Gemini model: gemini-2.5-pro-preview-03-25...
Model initialized.

--- Specify Input JSON File Paths ---
Use the Colab file browser (left panel) to copy the *full path* for each file.
1. Paste the full path to 'ground_truth_diagnoses.json': /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/ground_truth_diagnoses.json
2. Paste the full path to 'chatGPT_o4-mini-high_predictions.json': /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/chatGPT_o4-mini-high_predictions.json

Loading data...

Processing 25 common cases using LLM (models/gemini-2.5-pro-preview-03-25)...
This will take time...

--- Processing Case: NEJMcpc2100279 ---
  Rank 1: Comparing...
  Rank 2: Comparing...
  Rank 3: Comparing...
  Rank 4: Comparing...
  Rank 5: Comparing...
  Rank 6: Comparing...
  No Match Found for this case.
----------------------------

# Calculating Mean Reciprocal Rank and Normalized Discounted Cumulative Gain for Perplexity Research Responses

In [6]:
# --- Step 1: Install necessary libraries ---
!pip install google-generativeai numpy -q

import glob # Useful for finding files matching a pattern
import json
import math
import os
import re
import time

import numpy as np
import google.generativeai as genai
from google.colab import drive
from google.colab import userdata # For securely getting the API key

# --- Function to load JSON data ---
def load_json_data(filepath):
    """Read `filepath` as UTF-8 JSON and return the parsed object.

    On failure a message is printed and None is returned. This covers a
    missing file, invalid JSON, and any other error while opening or
    reading the file.
    """
    try:
        with open(filepath, 'r', encoding='utf-8') as handle:
            return json.load(handle)
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}")
    except json.JSONDecodeError:
        print(f"Error: Could not decode JSON from {filepath}")
    except Exception as exc:
        print(f"An unexpected error occurred loading {filepath}: {exc}")
    return None

# --- Function to ask Gemini for semantic match ---
# (Using the version with temperature=0.0)
def check_diagnosis_match_with_gemini(ground_truth_dx, predicted_dx, model, retries=2, delay=5):
    """
    Asks the Gemini model whether two diagnoses semantically match.

    Args:
        ground_truth_dx (str): The reference (correct) diagnosis text.
        predicted_dx (str): One candidate diagnosis from a prediction list.
        model (genai.GenerativeModel): The initialized Gemini model.
        retries (int): Extra attempts made after a failed API call.
        delay (float): Seconds to wait between failed attempts.

    Returns:
        bool: True only if the model's reply is judged YES by
        `parse_yes_no_response`. False for NO or an unclear reply, for an
        empty prediction, or when every attempt raises an exception.
    """
    if not predicted_dx:
        return False

    prompt = f"""Compare the following two medical diagnoses.
Diagnosis 1 (Ground Truth): "{ground_truth_dx}"
Diagnosis 2 (Prediction): "{predicted_dx}"

Do these two diagnoses refer to essentially the same condition, a very close subtype, or is the prediction clearly encompassed within the ground truth, such that the prediction could be considered correct in this context?

Answer ONLY with the word 'YES' or 'NO'.
"""
    for attempt in range(retries + 1):
        # Reset per attempt so an error never reports a stale, earlier response.
        response = None
        try:
            response = model.generate_content(
                prompt,
                generation_config=genai.types.GenerationConfig(temperature=0.0),
            )
            # If reading `.text` raises (e.g. for a blocked reply), the except
            # branch below reports any block reason and retries.
            return parse_yes_no_response(response.text)
        except Exception as e:
            block_reason = ""
            feedback = getattr(response, "prompt_feedback", None)
            reason = getattr(feedback, "block_reason", None)
            if reason:
                block_reason = f" (Block Reason: {reason})"

            print(f"      Error calling Gemini API (Attempt {attempt + 1}/{retries + 1}): {e}{block_reason}")
            if attempt < retries:
                print(f"      Retrying in {delay} seconds...")
                time.sleep(delay)
            else:
                print("      Max retries reached. Treating as NO match.")
                return False
    return False


# A standalone YES or NO word (case-insensitive). The word boundaries stop
# substrings such as "EYES", "NOT", "NONE" or "KNOWN" from counting as answers.
_YES_NO_PATTERN = re.compile(r"\b(YES|NO)\b", re.IGNORECASE)


def parse_yes_no_response(text):
    """Interpret a model reply as a YES/NO verdict.

    The first standalone YES or NO word (case-insensitive) decides the
    answer. Surrounding punctuation or markdown such as "**Yes.**" is
    tolerated. A reply containing neither word is treated as NO.

    Args:
        text (str | None): Raw reply text from the model.

    Returns:
        bool: True only if the first YES/NO word in `text` is YES.
    """
    if not text:
        return False
    found = _YES_NO_PATTERN.search(text)
    return bool(found) and found.group(1).upper() == "YES"


# --- Function to calculate metrics AND find matching ranks using LLM ---
def calculate_metrics_and_ranks_llm(predictions_data, ground_truth_data, model, k=5):
    """
    Calculates MRR, Mean NDCG@k, Top-1 Accuracy using LLM for matching,
    AND stores details of the first match found for validation.

    Args:
        predictions_data (list): List of prediction dicts.
        ground_truth_data (list): List of ground truth dicts.
        model (genai.GenerativeModel): The initialized Gemini model.
        k (int): The cutoff for calculating NDCG.

    Returns:
        tuple: (dict: metrics, dict: individual match details) or (None, None)
    """
    if not predictions_data or not ground_truth_data:
        print("Error: Input data is missing.")
        return None, None

    try:
        predictions_dict = {item['case_id']: item['differential_diagnosis'] for item in predictions_data}
        ground_truth_dict = {item['case_id']: item['correct_diagnosis'] for item in ground_truth_data}
    except KeyError as e:
        print(f"Error: Missing expected key '{e}' while structuring data.")
        return None, None
    except TypeError as e:
        print(f"Error: Problem accessing data, likely incorrect JSON structure: {e}")
        return None, None

    reciprocal_ranks = []
    ndcg_scores = []
    top1_correct_count = 0
    processed_cases = 0
    cases_with_match = 0
    individual_match_details = {} # To store rank and text for validation

    common_case_ids = sorted(list(set(predictions_dict.keys()) & set(ground_truth_dict.keys())))

    if not common_case_ids:
        print("Error: No common case_ids found.")
        return None, None

    print(f"\nProcessing {len(common_case_ids)} common cases using LLM ({model.model_name})...")
    print("This will take time...\n")

    for case_id in common_case_ids:
        print(f"--- Processing Case: {case_id} ---")
        if case_id not in predictions_dict or case_id not in ground_truth_dict:
            print("  Skipped - Data Missing in one of the files.")
            individual_match_details[case_id] = {"status": "Skipped - Data Missing"}
            print("-" * 30)
            continue

        correct_diagnosis_text = ground_truth_dict[case_id]
        predicted_diagnoses_list = predictions_dict[case_id]

        if not correct_diagnosis_text:
            print("  Skipped - No Ground Truth diagnosis text.")
            individual_match_details[case_id] = {"status": "Skipped - No Ground Truth"}
            print("-" * 30)
            continue
        if not predicted_diagnoses_list:
            print("  No Match Possible (Prediction list is empty). Assigning zero scores.")
            reciprocal_ranks.append(0)
            ndcg_scores.append(0)
            individual_match_details[case_id] = {"status": "No Match Found (Empty Predictions)"}
            processed_cases += 1
            print("-" * 30)
            continue

        # Ensure predictions are sorted by rank
        predicted_diagnoses_list.sort(key=lambda x: x.get('rank', float('inf')))

        # --- LLM Comparison Loop ---
        found_rank = 0
        first_match_details = {}
        for i, prediction_item in enumerate(predicted_diagnoses_list):
            current_rank = i + 1
            predicted_text = prediction_item.get('diagnosis', '')
            print(f"  Rank {current_rank}: Comparing...") # Keep output concise

            # Call LLM to check match
            is_match = check_diagnosis_match_with_gemini(correct_diagnosis_text, predicted_text, model)
            time.sleep(1.1) # IMPORTANT: Rate limiting

            if is_match:
                print(f"    --> Match found by LLM at rank {current_rank}!")
                found_rank = current_rank
                first_match_details = {
                    "rank": found_rank,
                    "ground_truth": correct_diagnosis_text,
                    "prediction": predicted_text
                }
                break # Stop at the first match

        # --- Store results for this case ---
        if found_rank > 0:
             individual_match_details[case_id] = first_match_details
             cases_with_match += 1
        else:
             individual_match_details[case_id] = {"status": "No Match Found"}
             print("  No Match Found for this case.")


        # --- Calculate Metrics based on found_rank ---
        rr = 1 / found_rank if found_rank > 0 else 0
        reciprocal_ranks.append(rr)

        dcg = 0.0
        for i in range(min(k, len(predicted_diagnoses_list))):
            rank_in_list = i + 1
            relevance = 1 if rank_in_list == found_rank else 0
            dcg += relevance / math.log2(rank_in_list + 1)

        idcg = 1.0 / math.log2(1 + 1) if found_rank > 0 else 0.0
        ndcg = dcg / idcg if idcg > 0 else 0.0
        ndcg_scores.append(ndcg)

        if found_rank == 1:
            top1_correct_count += 1

        processed_cases += 1
        print("-" * 30) # Separator between cases

    # --- Aggregate Results ---
    if processed_cases == 0:
        print("Error: No cases were successfully processed.")
        return None, None

    mean_mrr = np.mean(reciprocal_ranks) if reciprocal_ranks else 0
    mean_ndcg_at_k = np.mean(ndcg_scores) if ndcg_scores else 0
    top_1_accuracy = top1_correct_count / processed_cases if processed_cases > 0 else 0

    metrics = {
        "total_cases_processed": processed_cases,
        "cases_with_llm_match_found": cases_with_match,
        "llm_model_used": model.model_name,
        "mrr": mean_mrr,
        f"mean_ndcg@{k}": mean_ndcg_at_k,
        "top_1_accuracy": top_1_accuracy,
        "k_for_ndcg": k
    }
    return metrics, individual_match_details

# --- Configuration ---
# Gemini model used as the semantic judge when comparing diagnoses.
# NOTE(review): preview model IDs can be retired; confirm this ID is still served.
LLM_MODEL_NAME: str = 'gemini-2.5-pro-preview-03-25'
# Only the top NDCG_K ranked predictions contribute to NDCG@k.
NDCG_K: int = 5

# --- Main Execution ---
if __name__ == "__main__":
    # Entry point:
    #   1. Mount Drive.
    #   2. Configure Gemini.
    #   3. Prompt for the two JSON file paths.
    #   4. Score the predictions with the LLM judge.
    #   5. Print per-case match details, then the aggregate metrics.
    #
    # Failures use `raise SystemExit(1)` instead of the `exit()` builtin.
    # `exit()` is intended for interactive shells. In an IPython/Colab kernel
    # it only requests a kernel shutdown and does not stop the running cell,
    # so the code below would keep executing after an error.

    # --- Mount Drive ---
    try:
        print("Mounting Google Drive...")
        drive.mount('/content/drive', force_remount=True)
        DRIVE_ROOT = "/content/drive/MyDrive/"
        print("Google Drive mounted.")
    except Exception as e:
        print(f"Error mounting Google Drive: {e}")
        raise SystemExit(1) from e

    # --- Configure Gemini API ---
    try:
        GOOGLE_API_KEY = userdata.get('GOOGLE_API_KEY')
        if not GOOGLE_API_KEY:
            raise ValueError("API Key not found in Colab Secrets")
        genai.configure(api_key=GOOGLE_API_KEY)
        print("Gemini API configured.")
    except Exception as e:
        print(f"Error configuring Gemini API: {e}")
        print("Please ensure you have set the 'GOOGLE_API_KEY' secret in Colab.")
        raise SystemExit(1) from e

    # --- Initialize the Gemini Model ---
    try:
        print(f"Initializing Gemini model: {LLM_MODEL_NAME}...")
        model = genai.GenerativeModel(LLM_MODEL_NAME)
        print("Model initialized.")
    except Exception as e:
        print(f"Error initializing Gemini model: {e}")
        raise SystemExit(1) from e

    # --- Get File Paths ---
    print("\n--- Specify Input JSON File Paths ---")
    print("Use the Colab file browser (left panel) to copy the *full path* for each file.")
    ground_truth_file_path = input("1. Paste the full path to 'ground_truth_diagnoses.json': ").strip()
    predictions_file_path = input("2. Paste the full path to 'perplexity_research_predictions.json': ").strip()  # Use the ranked one

    if not ground_truth_file_path or not predictions_file_path:
        print("Error: One or both file paths were not provided. Exiting.")
        raise SystemExit(1)

    # --- Load Data ---
    print("\nLoading data...")
    ground_truth_data = load_json_data(ground_truth_file_path)
    predictions_data = load_json_data(predictions_file_path)

    if ground_truth_data is None or predictions_data is None:
        print("Failed to load data. Exiting.")
        raise SystemExit(1)

    # --- Calculate Metrics and Find Ranks using LLM ---
    metrics_results, individual_match_info = calculate_metrics_and_ranks_llm(
        predictions_data,
        ground_truth_data,
        model,
        k=NDCG_K
    )

    # --- Display Individual Match Details (for Validation) ---
    if individual_match_info:
        print("\n" + "="*60)
        print("--- Individual Case Match Details (Rank of First LLM Match) ---")
        print("="*60)
        for case_id, details in individual_match_info.items():
            print(f"{case_id}:")
            if "rank" in details:  # A match was found for this case
                print(f"  Rank {details['rank']}: '{details['ground_truth']}' vs '{details['prediction']}'")
            else:
                print(f"  {details.get('status', 'Unknown Status')}")  # e.g. 'No Match Found' or 'Skipped'
            print("-" * 30)
    else:
        print("\nCould not retrieve individual match details.")

    # --- Display Overall Metrics ---
    if metrics_results:
        print("\n" + "="*60)
        print("--- Overall Evaluation Results (LLM-based) ---")
        print("="*60)
        print(f"LLM Model Used: {metrics_results['llm_model_used']}")
        print(f"Total Cases Processed: {metrics_results['total_cases_processed']}")
        # total_cases_processed is > 0 whenever metrics are returned.
        print(f"Cases Where LLM Found a Match: {metrics_results['cases_with_llm_match_found']} ({metrics_results['cases_with_llm_match_found']/metrics_results['total_cases_processed']:.1%})")
        print("-" * 25)
        print(f"MRR (Mean Reciprocal Rank): {metrics_results['mrr']:.4f}")

        k_value = metrics_results['k_for_ndcg']
        ndcg_key = f"mean_ndcg@{k_value}"
        if ndcg_key in metrics_results:
            print(f"Mean NDCG@{k_value}: {metrics_results[ndcg_key]:.4f}")
        else:
            print(f"Mean NDCG@{k_value}: Key '{ndcg_key}' not found in results.")

        print(f"Top-1 Accuracy (LLM match at rank 1): {metrics_results['top_1_accuracy']:.2%}")
        print("="*60)
    else:
        print("\nFailed to calculate overall metrics.")

    print("\nScript finished.")

Mounting Google Drive...
Mounted at /content/drive
Google Drive mounted.
Gemini API configured.
Initializing Gemini model: gemini-2.5-pro-preview-03-25...
Model initialized.

--- Specify Input JSON File Paths ---
Use the Colab file browser (left panel) to copy the *full path* for each file.
1. Paste the full path to 'ground_truth_diagnoses.json': /content/drive/MyDrive/BADM550 - Wolters Kluwer Health - Language Model Project/Full Case Records/ground_truth_diagnoses.json
2. Paste the full path to 'perplexity_research_predictions.json': /content/drive/MyDrive/chatGPT_o4-mini-high_predictions.json

Loading data...

Processing 1 common cases using LLM (models/gemini-2.5-pro-preview-03-25)...
This will take time...

--- Processing Case: NEJMcpc2100279 ---
  Rank 1: Comparing...
  Rank 2: Comparing...
  Rank 3: Comparing...
  Rank 4: Comparing...
  Rank 5: Comparing...
  No Match Found for this case.
------------------------------

--- Individual Case Match Details (Rank of First LLM Match) 