# GPT Evaluation

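This notebook evaluates a prompt for generating discharge documentation. GPT first drafts the documentation from a patient file. A second GPT call, using an evaluation prompt, then compares one section of that draft against the corresponding section of the discharge letter recorded in the EPD.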

In [None]:
import os

import numpy as np
import openai
import tiktoken
from dotenv import load_dotenv

from discharge_docs.processing.processing import (
    get_patient_file,
    get_splitted_discharge_docs_NICU,
    load_and_process_data_metavision,
)
from discharge_docs.prompt import load_prompts

# Re-import edited modules (such as the locally installed `discharge_docs` package)
# before each cell runs, so source changes apply without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Configure the Azure OpenAI client from environment variables (optionally loaded
# from a local .env file). Credentials must never be hard-coded in the notebook.
load_dotenv()

# Fail fast: a missing variable would otherwise become None here and only surface
# later as an opaque authentication/URL error on the first API call.
_missing_env_vars = [
    name for name in ("AZURE_OPENAI_KEY", "AZURE_OPENAI_ENDPOINT") if not os.getenv(name)
]
if _missing_env_vars:
    raise RuntimeError(
        "Missing required environment variable(s): " + ", ".join(_missing_env_vars)
    )

# Module-level client configuration (`api_base`, `api_type`) is the pre-1.0 openai
# SDK interface. NOTE(review): confirm the pinned openai version still supports it.
openai.api_key = os.getenv("AZURE_OPENAI_KEY")
openai.api_base = os.getenv("AZURE_OPENAI_ENDPOINT")
openai.api_type = "azure"
openai.api_version = "2024-02-01"  # this may change in the future

# Name of the Azure deployment that requests are routed to.
deployment_name = "aiva-gpt"

# An empty value disables tiktoken's on-disk cache of encoding files, so encodings
# are fetched when first loaded. NOTE(review): confirm this is intended.
os.environ["TIKTOKEN_CACHE_DIR"] = ""

In [None]:
# Sanity check: display the configured Azure endpoint. Clear this output before
# sharing or committing the notebook, since it exposes infrastructure details.
openai.api_base

In [None]:
# Load the generation prompts. The original cell called the misspelled
# `load_pompts`, which raised NameError; `load_prompts` is the imported name.
# NOTE(review): the unpack order assumes load_prompts() returns (user, system).
user_prompt, system_prompt = load_prompts()

In [None]:
# Select encounters whose serialized patient file fits in the prompt, i.e. has
# fewer than MAX_PROMPT_TOKENS tokens under tiktoken's `cl100k_base` encoding.
MAX_PROMPT_TOKENS = 16000


def split_enc_ids_by_token_limit(
    data, max_tokens=MAX_PROMPT_TOKENS, encoding_name="cl100k_base"
):
    """Partition encounters by whether their patient file fits in ``max_tokens``.

    Parameters
    ----------
    data
        Processed data as returned by ``load_and_process_data_metavision()``;
        must expose an ``enc_id`` column.
    max_tokens
        Exclusive upper bound on the token count of a patient file.
    encoding_name
        Name of the tiktoken encoding used to count tokens.

    Returns
    -------
    tuple[dict, dict]
        ``(within, outside)``. Each maps ``enc_id`` to the token count of that
        encounter's patient file, in the order of ``data.enc_id.unique()``.
    """
    # Load the encoding once rather than once per encounter.
    encoding = tiktoken.get_encoding(encoding_name)
    within, outside = {}, {}
    for enc_id in data.enc_id.unique():
        patient_data_string, _ = get_patient_file(enc_id, data)
        # Encode each (potentially long) patient file exactly once.
        n_tokens = len(encoding.encode(patient_data_string))
        target = within if n_tokens < max_tokens else outside
        target[enc_id] = n_tokens
    return within, outside


# Tokenizing every encounter is slow, so the result of the last run is cached
# here (176 encounters within the limit, 143 outside). To refresh it:
#   within, outside = split_enc_ids_by_token_limit(load_and_process_data_metavision())
#   ENC_IDS_WITHIN_TOKEN_LIMIT = list(within)
ENC_IDS_WITHIN_TOKEN_LIMIT = [
    20, 38, 39, 42, 48, 46, 55, 63, 67, 69, 68, 71, 75, 83, 91, 92, 93, 95, 99, 110,
    107, 111, 115, 116, 119, 120, 121, 126, 128, 129, 133, 138, 150, 149, 152, 165, 167,
    171, 168, 169, 175, 185, 187, 193, 200, 202, 211, 216, 217, 219, 223, 227, 228, 231,
    234, 235, 241, 243, 264, 273, 278, 279, 284, 286, 288, 292, 306, 304, 305, 310, 309,
    317, 319, 324, 327, 326, 336, 338, 341, 343, 345, 354, 355, 357, 364, 363, 370, 374,
    377, 381, 384, 396, 404, 405, 408, 413, 417, 420, 422, 423, 424, 425, 428, 430, 431,
    437, 443, 439, 442, 458, 454, 463, 465, 475, 477, 488, 482, 491, 487, 490, 494, 499,
    498, 497, 501, 510, 512, 518, 519, 522, 524, 530, 539, 540, 547, 544, 554, 557, 571,
    569, 577, 583, 584, 591, 601, 609, 610, 613, 614, 615, 616, 624, 621, 631, 639, 648,
    651, 662, 671, 677, 680, 682, 685, 690, 695, 697, 700, 703, 716, 714, 715, 723, 724,
    725, 729, 730,
]

# Pseudonymised and finished ("af") encounters: a subset, per unit.
NICU = [20, 107, 116, 129, 150]
IC = [48, 55, 63, 67, 69, 68, 71]

# The hand-picked subsets must fit in the prompt.
assert set(NICU) <= set(ENC_IDS_WITHIN_TOKEN_LIMIT)
assert set(IC) <= set(ENC_IDS_WITHIN_TOKEN_LIMIT)

In [None]:
# Encounter under evaluation; 20 is also listed in the NICU subset.
enc_id = 20

# Load all processed data, then pull this encounter's patient file and its
# discharge letter split into categorised sections.
data = load_and_process_data_metavision()
patient_data_string, patient_data_df = get_patient_file(enc_id, data)
discharge_letter_split_df = get_splitted_discharge_docs_NICU(enc_id, data)

In [None]:
# Generate discharge documentation for the selected encounter with GPT.
# FIXME: `get_chatgpt_output_beloop` is not imported anywhere in this notebook, so
# this cell raises NameError on a fresh kernel; import it from its defining module.
gpt_call_kwargs = dict(
    engine=deployment_name,
    user_prompt=user_prompt,
    system_prompt=system_prompt,
    temperature=0,  # lowest sampling temperature, for maximally repeatable output
)
GPT_returned = get_chatgpt_output_beloop(patient_data_string, **gpt_call_kwargs)
GPT_returned

In [None]:
# Pull the same discharge-letter section from the GPT output and from the EPD letter.
# FIXME: `get_category_output` is not imported in this notebook; add its import.
CATEGORY = "Neurologie"  # discharge-letter section under evaluation

GPT_output = get_category_output(GPT_returned, CATEGORY)

# Use .loc instead of chained indexing, and fail with a clear message (rather than
# a bare IndexError from .values[0]) when this encounter's letter lacks the section.
epd_matches = discharge_letter_split_df.loc[
    discharge_letter_split_df["category"] == CATEGORY, "text"
]
if epd_matches.empty:
    raise ValueError(
        f"No EPD discharge-letter section with category {CATEGORY!r} for this encounter"
    )
EPD_output = epd_matches.iloc[0]

print(f"The GPT output is: {GPT_output}")
print(f"The EPD output is: {EPD_output}")

In [None]:
# Evaluate the GPT output against the EPD text with an evaluation prompt.
# FIXME: `load_evaluatie_prompt` and `compare_GPT_output_with_EPD_output` are not
# imported in this notebook; add their imports so the cell runs on a fresh kernel.
evaluatie_prompt = load_evaluatie_prompt()

# Presumably repeated to gauge run-to-run variability even at temperature 0.
n_runs = 10
evaluations = []  # keep every run's result for later comparison, not just the last
for run in range(1, n_runs + 1):
    print(f"Run {run}")
    # Named `evaluation` (not `eval`) to avoid shadowing the built-in eval().
    evaluation = compare_GPT_output_with_EPD_output(
        GPT_output, EPD_output, evaluatie_prompt, engine=deployment_name, temperature=0
    )
    evaluations.append(evaluation)
    print(evaluation)