In [4]:
import os
import glob
import sys
import time
import json

#from utils import medswitch
#from utils import contraceptives
#from utils import openai_query

import numpy as np
import pandas as pd
import numpy as np
import regex as re

import socket
# Look up the name of the machine this kernel runs on and print it in the cell output.
hostname = socket.gethostname()
print(hostname)

import dask
import dask.dataframe as dd
import dask.array as da
import dask.bag as db

from dask_jobqueue import SGECluster
from dask.distributed import Client

# NOTE(review): `i` is not referenced in any of the visible cells — confirm it is unused before removing.
i = 0
from dask.distributed import LocalCluster, Client

# Start a local Dask cluster and connect a client to it.
# NOTE(review): Dask documents `memory_limit` as a per-worker limit, so this
# configuration may request up to 24 x 24 GB — confirm the host can provide it.
cluster = LocalCluster(
    n_workers=24,
    memory_limit='24gb',
)
client = Client(cluster)

# params
rwd_output = './assets/data/'

def load_register_table(data_asset, table, **kwargs):
    """Open one table of a data asset via ``dd.read_parquet``.

    Parameters
    ----------
    data_asset : str
        Name of the asset directory under the parquet root (e.g. "DEID_CDW").
    table : str
        Name of the table directory inside that asset.
    **kwargs
        Passed through unchanged to ``dd.read_parquet``.

    Returns
    -------
    Whatever ``dd.read_parquet`` returns for the table directory.
    """
    table_dir = f'/wynton/protected/project/ic/data/parquet/{data_asset}/{table}/'
    return dd.read_parquet(table_dir, **kwargs)



ic-app.wynton.ucsf.edu


In [6]:
# Find the notes whose text mentions "gpt4" or "gpt3" (case-insensitive).
note_text_rdd = load_register_table("DEID_CDW", "note_text")

# na=False: a note with missing text counts as "no match", so the boolean mask
# never contains NA values when it is used to filter the frame.
mentions_gpt = note_text_rdd["note_text"].str.contains("gpt4|gpt3", case=False, na=False)

# .compute() materialises the filtered rows; `note_text_rdd` stays the lazy table.
note_rdd = note_text_rdd[mentions_gpt].compute()



In [2]:
!ls /wynton/protected/project/ic/data/parquet/DEID_CDW


addressdim				    fmsample
allergendim				    fmshortvariant
allergyfact				    fmtherapy
anesthesiarecordattributevaluedim	    fmtrial
anesthesiarecordfact			    fmvariantproperty
anesthesiaregistrymetricfact		    guarantordim
attendingproviderfact			    hospitaladmissionattributevaluedim
attributedim				    hospitaladmissionfact
billareadim				    icustayregistrydatamart
billingaccountencountermappingfact	    imagingfact
billingaccountfact			    immunizationdim
billingproceduredim			    immunizationeventfact
billingproceduresetdim			    immunizationsetdim
billingtransactionfact			    labcomponentdim
birthanesthesiabridge			    labcomponentresultfact
birthattributevaluedim			    labdim
birthaugmentationbridge			    labtestcomponentresultmappingfact
birthaugmentationindicationbridge	    labtestfact
birthcervicalripeningbridge		    manually_deidentified_note_concepts
birthcesareanindicationbridge		    manually_deidentified_note_metadata
birthepisiotomybridge			    manually_deidentified_

# Contraceptive cohort extraction

In [None]:
from utils import contraceptives
from utils.medswitch import unique_trajectory

# params
# Directory passed to the first three extraction steps below.
filepath= "./data/contraceptives/raw"

# Cohort-building steps, called in this order.
# NOTE(review): presumably each step relies on files written by the previous
# one — confirm in utils.contraceptives before reordering.
contraceptives.getMedications(filepath)
contraceptives.getDemographics(filepath)
contraceptives.addNotes(filepath)
# Unlike the calls above, this one is not given `filepath`.
contraceptives.finalWeakAnnotations()


In [16]:
import pandas as pd
# Load the annotated patient-demographics and medication tables from disk.
# NOTE(review): presumably these files are produced by the cohort-extraction cell above — confirm.
pt_demographics = pd.read_parquet("./data/contraceptives/annotated_pt_demographics.parquet.gzip")
med_values = pd.read_parquet("./data/contraceptives/annotated_medications.parquet.gzip")


# OpenAI querying

In [None]:
# `openai_query` is used below, but its import at the top of the notebook is
# commented out, so a fresh kernel would stop with a NameError. Import it here.
from utils import openai_query

# Determine order of values to put in prompt
labels = ["Intravaginal", "Oral", "Transdermal", "Injectable", "Intrauterine", "Implant"]
np.random.shuffle(labels)  # unseeded: every run prints a different order
print(labels)
# Order printed by an earlier run (presumably the one used in the prompt — confirm):
#['Injectable', 'Intravaginal', 'Transdermal', 'Intrauterine', 'Oral', 'Implant']

### OpenAI querying
engine = "gpt-4"
med_class_name = "contraceptives"
date = "2023-11-13"

# Record of the validation/test split and its reported sizes (not executed here):
#   # split into validation/test datasets
#   medswitch.split_validation_test(med_class_name, validation_size=0.05, seed=0)
#
#   patient split 76 1439
#   valid split (93, 138)
#   test split (1871, 138)

# Run prompt development set
openai_query.prompt_dev(
    med_class_name,
    engine,
    date,
    sys_config_values=["general", "specialist", "default"],
    task_config_values=["default", "manual-function"],
    function_config=None,
)

# Automated evaluation of prompt development set
with open("./data/contraceptives/raw/contraceptives.json") as med_file:
    med_mapping = json.load(med_file)

prompt_dev_df = pd.read_parquet("./data/contraceptives/gpt4/validation.parquet.gzip")
contraceptives.evaluate_prompt_dev(
    prompt_dev_df,
    med_mapping["med_class_mapping"],
    date=date,
    engine=engine,
    average="micro",
)

# The default prompt consistently showed better performance on both medication and reason extraction
# and the specialist system configuration showed slightly higher scores on average.

# Use GPT4 to extract test set
openai_query.gpt4_test_set(
    med_class_name=med_class_name,
    date=date,
    engine=engine,
    sys_config="specialist",
    task_config="default",
)

# Evaluate test
# Reuse `med_class_name` / `engine` instead of repeating their literal values,
# so the evaluation cannot drift from the settings used for querying.
test_labels_df = pd.read_parquet("./data/contraceptives/gpt4/test.parquet.gzip")
contraceptives.evaluate_test(
    test_labels_df,
    med_mapping["med_class_mapping"],
    date,
    sys_config="specialist",
    task_config="default",
    med_class_name=med_class_name,
    engine=engine,
    average="micro",
)



# Open source evaluation

In [6]:
from utils.benchmark.metrics import classification_metrics
from utils.medswitch import  map_generic

def _is_missing(value):
    """Return True for None, the literal string "None", and NaN/NA scalars."""
    if isinstance(value, str):
        return value == "None"
    if value is None:
        return True
    flag = pd.isna(value)
    # pd.isna returns an array for list-like input; only a scalar answer counts.
    return bool(flag) if isinstance(flag, (bool, np.bool_)) else False


def _blank_missing(values):
    """Return `values` as a list with every missing entry replaced by ""."""
    return ["" if _is_missing(v) else v for v in values]


def evaluate_open_source(preds_file,
                     labels_file,
                     med_mapping,
                     average="micro"):
    """
    Wrapper for evaluating open source model predictions against structured label baselines.

    Parameters
    ----------
    preds_file : str
        Path to a CSV of model responses; its first column is used as the index.
        The model name is taken from the second "_"-separated token of the file name.
    labels_file : str
        Path to a parquet file with the columns "note_deid_note_key",
        "mapped_med_generic_clean" and "prev_medication".
    med_mapping :
        Mapping handed to ``map_generic`` for every string prediction.
    average : str
        Passed to ``classification_metrics``.

    Returns
    -------
    dict
        "class": one row of metrics per label column.
        "pred_values": one row per prediction/label list, with the selected
        prediction index as columns.

    Raises
    ------
    KeyError
        If a key in "note_deid_note_key" is absent from the predictions index
        (raised by the ``.loc`` lookup).
    """
    pred_df = pd.read_csv(preds_file, index_col=0)
    labels_df = pd.read_parquet(labels_file)

    # Model name: second "_"-separated token of the file name.
    model = preds_file.split("/")[-1].split("_")[1]

    # Keep only the predictions for labelled notes, in label order.
    pred_df = pred_df.loc[list(labels_df["note_deid_note_key"])]

    class_metrics = {}
    pred_values = {}
    column_pairs = [("new_contraceptive", "mapped_med_generic_clean"),
                    ("last_contraceptive", "prev_medication")]
    for pred_col, label_col in column_pairs:
        # Map string predictions to their class; anything else is treated as missing.
        mapped = [map_generic(p, med_mapping, return_value=False) if isinstance(p, str) else None
                  for p in pred_df[pred_col]]
        preds = _blank_missing(mapped)

        # The previous `x == np.nan` test was always False (NaN never equals
        # itself), so NaN labels were not blanked; _blank_missing uses pd.isna.
        labels = _blank_missing(labels_df[label_col])

        class_metrics[label_col] = classification_metrics(preds=preds,
                                                          labels=labels,
                                                          average=average)

        pred_values[label_col + "_" + model] = preds
        pred_values[label_col + "_labels"] = labels

    # One row per dictionary key in both frames.
    class_df = pd.DataFrame.from_dict(class_metrics, orient="index")
    pred_values_df = pd.DataFrame.from_dict(pred_values, orient="index")
    pred_values_df.columns = pred_df.index
    return {"class": class_df, "pred_values": pred_values_df}


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
average = "micro"
with open("./data/contraceptives/raw/contraceptives.json") as med_file:
    med_mapping = json.load(med_file)["med_class_mapping"]

# Prompt set evaluation
# Explicit, ordered list of prompt-dev response files to evaluate. (A glob over
# ./data/contraceptives/open_source/*.csv used to be computed first, but its
# result was immediately overwritten by this list, so it has been dropped.)
# Note: starling-7b-alpha appears under two dates. Response rows are keyed by
# model name only, so only the first file's predictions are kept below, while
# both files contribute metric rows.
fpaths = ['./data/contraceptives/open_source/2024-08-31_starling-7b-alpha_prompt-dev_responses.csv',
 './data/contraceptives/open_source/2024-08-31_llama-3-8b-chat-hf_prompt-dev_responses.csv',
 './data/contraceptives/open_source/2024-08-31_gemma-7b-it_prompt-dev_responses.csv',
 './data/contraceptives/open_source/01-27-24_starling-7b-alpha_prompt-dev_responses.csv',
 './data/contraceptives/open_source/2024-08-31_Meta-Llama-3.1-8B-Instruct_prompt-dev_responses.csv',
 './data/contraceptives/open_source/2024-08-31_starling-7b-beta_prompt-dev_responses.csv',
 './data/contraceptives/open_source/2024-08-31_BioMistral-7B_prompt-dev_responses.csv',
 './data/contraceptives/open_source/2024-08-31_gemma2-9b-it_prompt-dev_responses.csv',
 './data/contraceptives/open_source/2024-08-31_JSL-MedMNX-7B-SFT_prompt-dev_responses.csv']

# Collect per-file frames and concatenate once after the loop.
metric_frames = []
response_frames = []
seen_rows = set()    # response row names already collected
model_meta = {}      # model -> (dataset, date) of the file whose predictions are kept
datasets_seen = []

for file_path in fpaths:
    print(file_path)
    prompt_metrics = evaluate_open_source(preds_file=file_path,
                                          labels_file="./data/contraceptives/gpt4/validation.parquet.gzip",
                                          med_mapping=med_mapping,
                                          average=average)

    # File names look like "<date>_<model>_<dataset>_responses.csv"
    date, model, dataset = file_path.split("/")[-1].split("_")[:3]
    datasets_seen.append(dataset)
    model_meta.setdefault(model, (dataset, date))

    # Metrics for this file, tagged with where they came from
    acc_metrics = prompt_metrics["class"]
    acc_metrics["dataset"] = dataset
    acc_metrics["model"] = model
    acc_metrics["date"] = date
    metric_frames.append(acc_metrics)

    # Responses: add only rows not collected yet (the label rows repeat in every file)
    pred_values = prompt_metrics["pred_values"]
    new_rows = [idx for idx in pred_values.index if idx not in seen_rows]
    seen_rows.update(new_rows)
    response_frames.append(pred_values.loc[new_rows])

if not metric_frames:
    raise ValueError("No prompt-dev response files to evaluate")
assert len(set(datasets_seen)) == 1, f"expected a single dataset, got {sorted(set(datasets_seen))}"
dataset = datasets_seen[0]

all_metrics_df = pd.concat(metric_frames)
all_responses = pd.concat(response_frames)

# Add human labels
annotated = pd.read_csv("./data/contraceptives/annotation/annotated_set_EE.csv", index_col=0)
annotated = annotated[["Contraceptive started", "Contraceptive stopped"]]
annotated.columns = ["mapped_med_generic_clean_expert", "prev_medication_expert"]
annotated = annotated.replace(np.nan, "")
annotated = annotated.replace("None", "")

# One row per note after transposing; the index-on-index merge keeps only notes present in both frames.
all_responses = all_responses.T
all_responses = all_responses.merge(annotated, left_index=True, right_index=True)

# Add human evaluation
human_frames = []
for model in all_metrics_df["model"].unique():
    # Keyword arguments make the roles explicit: the model output is `preds`,
    # the expert annotation is `labels` (the positional calls had them the other way round).
    start_f1 = classification_metrics(preds=list(all_responses[f"mapped_med_generic_clean_{model}"]),
                                      labels=list(all_responses["mapped_med_generic_clean_expert"]),
                                      average=average)

    stop_f1 = classification_metrics(preds=list(all_responses[f"prev_medication_{model}"]),
                                     labels=list(all_responses["prev_medication_expert"]),
                                     average=average)

    human_scores = pd.DataFrame.from_dict({"mapped_med_generic_clean_expert": start_f1,
                                           "prev_medication_expert": stop_f1}, orient="index")
    # Add metadata from the file that supplied this model's predictions, rather
    # than whatever `dataset` / `date` the file loop happened to end on.
    model_dataset, model_date = model_meta[model]
    human_scores["model"] = model
    human_scores["dataset"] = model_dataset
    human_scores["date"] = model_date
    human_frames.append(human_scores)

all_metrics_df = pd.concat([all_metrics_df] + human_frames)

# Save prompt dev evaluation
all_responses.to_csv(f"./data/contraceptives/open_source/{dataset}_evaluated_preds.csv")
all_metrics_df.to_csv(f"./data/contraceptives/open_source/{dataset}_classification_metrics_{average}.csv")




./data/contraceptives/open_source/2024-08-31_starling-7b-alpha_prompt-dev_responses.csv
./data/contraceptives/open_source/2024-08-31_llama-3-8b-chat-hf_prompt-dev_responses.csv
./data/contraceptives/open_source/2024-08-31_gemma-7b-it_prompt-dev_responses.csv
./data/contraceptives/open_source/2024-08-31_Meta-Llama-3.1-8B-Instruct_prompt-dev_responses.csv
./data/contraceptives/open_source/2024-08-31_starling-7b-beta_prompt-dev_responses.csv
./data/contraceptives/open_source/2024-08-31_BioMistral-7B_prompt-dev_responses.csv
./data/contraceptives/open_source/2024-08-31_gemma2-9b-it_prompt-dev_responses.csv
./data/contraceptives/open_source/2024-08-31_JSL-MedMNX-7B-SFT_prompt-dev_responses.csv


In [8]:

# Test set evaluation
# Uses `med_mapping` and `average` as defined by the prompt-dev evaluation cell.
# Explicit, ordered list of test-set response files to evaluate. (A glob over
# ./data/contraceptives/open_source/*.csv used to be computed first, but its
# result was immediately overwritten by this list, so it has been dropped.)
fpaths = ['./data/contraceptives/open_source/2024-08-31_starling-7b-beta_test_responses.csv',
          './data/contraceptives/open_source/2024-08-31_starling-7b-alpha_test_responses.csv',
          './data/contraceptives/open_source/2024-08-31_gemma2-9b-it_test_responses.csv',
          './data/contraceptives/open_source/2024-08-31_BioMistral-7B_test_responses.csv',
          './data/contraceptives/open_source/2024-08-31_JSL-MedMNX-7B-SFT_test_responses.csv',
          './data/contraceptives/open_source/2024-08-31_Meta-Llama-3.1-8B-Instruct_test_responses.csv']

# Collect per-file frames and concatenate once after the loop.
metric_frames = []
response_frames = []
seen_rows = set()    # response row names already collected
datasets_seen = []

for file_path in fpaths:
    prompt_metrics = evaluate_open_source(preds_file=file_path,
                                          labels_file="./data/contraceptives/gpt4/test.parquet.gzip",
                                          med_mapping=med_mapping,
                                          average=average)

    # File names look like "<date>_<model>_<dataset>_responses.csv"
    date, model, dataset = file_path.split("/")[-1].split("_")[:3]
    datasets_seen.append(dataset)

    # Metrics for this file, tagged with where they came from
    acc_metrics = prompt_metrics["class"]
    acc_metrics["dataset"] = dataset
    acc_metrics["model"] = model
    acc_metrics["date"] = date
    metric_frames.append(acc_metrics)

    # Responses: add only rows not collected yet (the label rows repeat in every file)
    pred_values = prompt_metrics["pred_values"]
    new_rows = [idx for idx in pred_values.index if idx not in seen_rows]
    seen_rows.update(new_rows)
    response_frames.append(pred_values.loc[new_rows])

# The output file names are built from the dataset token, so fail clearly when
# there is nothing to name them after instead of hitting a NameError below.
if not metric_frames:
    raise ValueError("No test-set response files to evaluate")
assert len(set(datasets_seen)) == 1, f"expected a single dataset, got {sorted(set(datasets_seen))}"
dataset = datasets_seen[0]

all_metrics_df = pd.concat(metric_frames)
all_responses = pd.concat(response_frames)

# Save test evaluation
# NOTE(review): unlike the prompt-dev cell, the responses are written without
# transposing (rows are prediction/label lists, columns are notes) — confirm this is intended.
all_responses.to_csv(f"./data/contraceptives/open_source/{dataset}_evaluated_preds.csv")
all_metrics_df.to_csv(f"./data/contraceptives/open_source/{dataset}_classification_metrics_{average}.csv")

