In [1]:
import pandas as pd
import sys
sys.path.append('..')
from langchain_community.llms import Ollama 
from scripts import constants
from scripts import utils

In [2]:
# Load the test set from the path configured in scripts/constants and preview
# the first rows. The "label" column is decoded to class names further down.
test_df = pd.read_csv(constants.TEST_SET_PATH)
test_df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,7.116363,-1.0,3.781573,2.738413,-1.0,95.904198,68.457895,2.226085,0,1.892912,39.80855,110.329197,64.40435,21.654404,73.787009,21.349089,-1.0,5
1,8.12532,92.230003,4.231419,1.188039,143.365567,104.057204,204.747831,2.342554,0,0.652614,13.478089,-1.0,32.705481,-1.0,43.520272,24.375961,142.815207,1
2,11.30945,38.324563,-1.0,-1.0,455.077909,76.402602,-1.0,4.440732,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,33.92835,-1.0,4
3,13.763858,253.513394,2.262606,0.551444,453.772884,82.781943,90.101466,4.987993,0,0.853521,104.005514,34.639227,0.963866,22.083012,88.891838,41.291574,19.856071,0
4,11.464002,-1.0,-1.0,-1.0,320.964653,104.287127,-1.0,3.297819,0,1.163516,121.616315,105.895897,-1.0,9.337462,-1.0,34.392007,-1.0,7


In [3]:
# Ollama model evaluated in the "Sequential" sections below. To evaluate llama3
# instead, swap the two lines; note that the CSV file names written further down
# are hard-coded with a `_mistral` suffix and must be changed as well.
#myllm = Ollama(model="llama3")
myllm = Ollama(model="mistral")

#### Sequential without COT

In [4]:
# Ground truth: labels of the first 1000 test rows, decoded to class names
# through constants.CLASS_DICT.
llm_results_nocot_df = pd.DataFrame(columns=["y_actual", "y_pred"])
llm_results_nocot_df["y_actual"] = test_df[:1000]["label"]
llm_results_nocot_df = llm_results_nocot_df.replace({"y_actual": constants.CLASS_DICT})

# Query the LLM once per row without chain-of-thought. `row.name` is the row's
# index label; it is presumably used by the helper to look up the patient in the
# test set -- confirm against scripts/utils.
nocot_llm_outputs = llm_results_nocot_df.apply(
    lambda row: utils.get_llm_diagnosis_msg_pass(
        row.name, myllm, constants.sequentialNOCOT_prompt
    ),
    axis=1,
    result_type="expand",
)
llm_results_nocot_df[
    ["y_pred", "pathway_length", "trajectory", "pathway", "messages"]
] = nocot_llm_outputs
llm_results_nocot_df.head()

Unnamed: 0,y_actual,y_pred,pathway_length,trajectory,pathway,messages
0,hemolytic anemia,iron deficiency anemia,5,"[ Hemoglobin, Mean corpuscular volume, Mean ...","[ Hemoglobin, 7.12 g/dL, Mean corpuscular vol...",[content='Please name the first feature whose ...
1,vitamin b12/folate deficiency anemia,anemia of chronic disease,4,"[ Hemoglobin, Mean corpuscular volume, Ferri...","[ Hemoglobin, 8.13 g/dL, Mean corpuscular vol...",[content='Please name the first feature whose ...
2,iron deficiency anemia,iron deficiency anemia,4,"[ Hemoglobin, Mean corpuscular volume, Ferri...","[ Hemoglobin, 11.31 g/dL, Mean corpuscular vo...",[content='Please name the first feature whose ...
3,no anemia,no anemia,2,"[ Hemoglobin, No anemia]","[ Hemoglobin, 13.76 g/dL, No anemia]",[content='Please name the first feature whose ...
4,inconclusive diagnosis,inconclusive diagnosis,4,"[ Hemoglobin, Mean corpuscular volume, Retic...","[ Hemoglobin, 11.46 g/dL, Mean corpuscular vo...",[content='Please name the first feature whose ...


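The cell below is needed for Mistral, which adds an extra leading space to its predictions: it strips surrounding whitespace and maps any label outside the known classes to "inconclusive diagnosis".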

In [6]:
# Set of accepted class names (the values of constants.CLASS_DICT); predictions
# outside this set are replaced by the sanitizer below.
valid_classes = set(constants.CLASS_DICT.values())

# Function to sanitize predictions
def sanitize_predictions(pred, valid_labels=None, fallback="inconclusive diagnosis"):
    """Normalize a raw LLM prediction to one of the accepted class labels.

    Args:
        pred: Raw prediction. Non-string values (e.g. ``None`` or ``NaN`` left
            behind by a failed call) are treated as unrecognized instead of
            raising ``AttributeError`` on ``.strip()``.
        valid_labels: Collection of accepted labels. When omitted, the
            notebook-level ``valid_classes`` set is used.
        fallback: Label returned for unrecognized predictions.

    Returns:
        The whitespace-stripped prediction if it is an accepted label,
        otherwise ``fallback``.
    """
    if valid_labels is None:
        # Resolved at call time so the notebook-level set is always current.
        valid_labels = valid_classes
    if not isinstance(pred, str):
        return fallback
    pred = pred.strip()  # Remove leading and trailing whitespace
    return pred if pred in valid_labels else fallback

# Apply the sanitization to the predicted labels, overwriting y_pred in place.
# Re-running this cell is harmless: already-sanitized labels map to themselves.
llm_results_nocot_df["y_pred"] = llm_results_nocot_df["y_pred"].apply(sanitize_predictions)


In [None]:
# Export y_actual, y_pred and pathway for the no-CoT run.
# The subset is bound to its own name rather than overwriting
# llm_results_nocot_df, so the full frame (including pathway_length) is still
# available to the metrics cell that follows. compute_mean_pathway_length
# presumably reads pathway_length -- confirm against scripts/utils.
nocot_pathways_df = llm_results_nocot_df[["y_actual", "y_pred", "pathway"]]
# For llama3 runs, write to 'pathways_NOCOTsequential_llama.csv' instead.
nocot_pathways_df.to_csv('pathways_NOCOTsequential_mistral.csv', index=False, header=False)

In [7]:
# Score the no-CoT predictions against the ground truth. The three values are
# unpacked as accuracy, F1 and ROC AUC going by the variable names -- the exact
# definitions live in utils.test.
acc_r, f1_r, roc_auc_r = utils.test(
    llm_results_nocot_df["y_actual"],
    llm_results_nocot_df["y_pred"],
)
mean_pathway_length_r = utils.compute_mean_pathway_length(llm_results_nocot_df)
(acc_r, f1_r, roc_auc_r, mean_pathway_length_r)

(27.700000000000003, 22.89244494891416, 58.8934001960572, 3.633)

#### Sequential with COT

In [None]:
# Ground truth: labels of the first 1000 test rows, decoded to class names
# through constants.CLASS_DICT.
llm_results_cot_df = pd.DataFrame(columns=["y_actual", "y_pred"])
llm_results_cot_df["y_actual"] = test_df[:1000]["label"]
llm_results_cot_df = llm_results_cot_df.replace({"y_actual": constants.CLASS_DICT})

# Query the LLM once per row with the chain-of-thought prompt. `row.name` is the
# row's index label; it is presumably used by the helper to look up the patient
# in the test set -- confirm against scripts/utils.
cot_llm_outputs = llm_results_cot_df.apply(
    lambda row: utils.get_llm_diagnosis_msg_pass_cot(
        row.name, myllm, constants.sequentialCOT_prompt
    ),
    axis=1,
    result_type="expand",
)
llm_results_cot_df[
    ["y_pred", "pathway_length", "trajectory", "pathway", "messages"]
] = cot_llm_outputs
llm_results_cot_df.head()

In [None]:
# Select the rows whose prediction differs from the ground truth. The comparison
# uses whitespace-stripped predictions: Mistral prepends a space to its answers,
# and this cell runs before y_pred is sanitized, so comparing the raw strings
# would flag matching diagnoses as wrong.
cot_pred_stripped = llm_results_cot_df["y_pred"].astype(str).str.strip()
incorrect_predictions_df = llm_results_cot_df[llm_results_cot_df["y_actual"] != cot_pred_stripped]

# Frames exported by the next cell: misdiagnosed rows with their message
# history, all predictions with their pathway, and the trajectory column alone.
wrong_diagn_df = incorrect_predictions_df[["y_actual", "y_pred", "messages"]]
predictions_df = llm_results_cot_df[["y_actual", "y_pred", "pathway"]]
cleaned_df = llm_results_cot_df["trajectory"]

# Bare expression so the notebook renders the frame with its rich display.
wrong_diagn_df.head()

                                y_actual                  y_pred  \
7                       hemolytic anemia  inconclusive diagnosis   
8   vitamin b12/folate deficiency anemia      unspecified anemia   
9                              no anemia        hemolytic anemia   
13             anemia of chronic disease         aplastic anemia   
14                      hemolytic anemia         aplastic anemia   

                                             messages  
7   [content='Please name the first feature whose ...  
8   [content='Please name the first feature whose ...  
9   [content='Please name the first feature whose ...  
13  [content='Please name the first feature whose ...  
14  [content='Please name the first feature whose ...  


In [None]:
# Write the CoT artefacts to CSV, one file per frame, without the index.
# For llama3 runs, replace the `_mistral` part of each file name with `_llama`.
cot_exports = {
    'wrongpathway_cot_df_mistral.csv': wrong_diagn_df,
    'pathways_COTSequential_mistral.csv': predictions_df,
    'pathways_COTSequential_mistral_cleaned.csv': cleaned_df,
}
for csv_name, frame in cot_exports.items():
    frame.to_csv(csv_name, index=False)

In [None]:
# Ensure predictions are within the expected class set: the accepted labels are
# the values of constants.CLASS_DICT.
valid_classes = set(constants.CLASS_DICT.values())

# Function to sanitize predictions
def sanitize_predictions(pred, valid_labels=None, fallback="inconclusive diagnosis"):
    """Normalize a raw LLM prediction to one of the accepted class labels.

    Args:
        pred: Raw prediction. Non-string values (e.g. ``None`` or ``NaN`` left
            behind by a failed call) are treated as unrecognized instead of
            raising ``AttributeError`` on ``.strip()``.
        valid_labels: Collection of accepted labels. When omitted, the
            notebook-level ``valid_classes`` set is used.
        fallback: Label returned for unrecognized predictions.

    Returns:
        The whitespace-stripped prediction if it is an accepted label,
        otherwise ``fallback``.
    """
    if valid_labels is None:
        # Resolved at call time so the notebook-level set is always current.
        valid_labels = valid_classes
    if not isinstance(pred, str):
        return fallback
    pred = pred.strip()  # Remove leading and trailing whitespace
    return pred if pred in valid_labels else fallback

# Apply the sanitization to the predicted labels, overwriting y_pred in place.
# Re-running this cell is harmless: already-sanitized labels map to themselves.
llm_results_cot_df["y_pred"] = llm_results_cot_df["y_pred"].apply(sanitize_predictions)

In [None]:
# Score the CoT predictions against the ground truth. The three values are
# unpacked as accuracy, F1 and ROC AUC going by the variable names -- the exact
# definitions live in utils.test. These names are shared with the no-CoT
# metrics cell and are rebound here.
acc_r, f1_r, roc_auc_r = utils.test(
    llm_results_cot_df["y_actual"],
    llm_results_cot_df["y_pred"],
)
mean_pathway_length_r = utils.compute_mean_pathway_length(llm_results_cot_df)
(acc_r, f1_r, roc_auc_r, mean_pathway_length_r)

(56.699999999999996, 57.3000188414589, 76.12065619473907, 6.137)

#### ChatGPT Sequential No COT

In [4]:
# Run the sequential no-CoT prompt through gpt-4-turbo. The first argument (250)
# is presumably the number of patients to process -- confirm against
# scripts/utils. save/filename are passed so the helper can persist the results.
results_df = utils.get_results_chatgpt_sequential(
    250,
    "gpt-4-turbo",
    constants.sequentialNOCOT_prompt,
    save=True,
    filename='pathways_NOCOTsequential_chatgpt.csv',
)

Processing patient:  0
Processing patient:  1
Processing patient:  2
Processing patient:  3
Processing patient:  4
Processing patient:  5
Processing patient:  6
Processing patient:  7
Processing patient:  8
Processing patient:  9
Processing patient:  10
Processing patient:  11
Processing patient:  12
Processing patient:  13
Processing patient:  14
Processing patient:  15
Processing patient:  16
Processing patient:  17
Processing patient:  18
Processing patient:  19
Processing patient:  20
Processing patient:  21
Processing patient:  22
Processing patient:  23
Processing patient:  24
Processing patient:  25
Processing patient:  26
Processing patient:  27
Processing patient:  28
Processing patient:  29
Processing patient:  30
Processing patient:  31
Processing patient:  32
Processing patient:  33
Processing patient:  34
Processing patient:  35
Processing patient:  36
Processing patient:  37
Processing patient:  38
Processing patient:  39
Processing patient:  40
Processing patient:  41
Pr

In [5]:
# Preview the first rows of the ChatGPT no-CoT results.
results_df.head()

Unnamed: 0,y_actual,y_pred,pathway_length
0,hemolytic anemia,hemolytic anemia,4.0
1,vitamin b12/folate deficiency anemia,vitamin b12/folate deficiency anemia,4.0
2,iron deficiency anemia,inconclusive diagnosis,5.0
3,no anemia,no anemia,2.0
4,inconclusive diagnosis,inconclusive diagnosis,4.0


In [6]:
# Score the ChatGPT no-CoT predictions against the ground truth. The three
# values are unpacked as accuracy, F1 and ROC AUC going by the variable names --
# the exact definitions live in utils.test.
acc, f1, roc_auc = utils.test(results_df["y_actual"], results_df["y_pred"])
mean_pathway_length = utils.compute_mean_pathway_length(results_df)
(acc, f1, roc_auc, mean_pathway_length)

(74.0, 69.55868828341478, 85.60098158683833, 4.156)

In [7]:
# NOTE(review): this repeats the preview shown two cells earlier; results_df is
# not reassigned in between, so this cell looks like a candidate for removal.
results_df.head()

Unnamed: 0,y_actual,y_pred,pathway_length
0,hemolytic anemia,hemolytic anemia,4.0
1,vitamin b12/folate deficiency anemia,vitamin b12/folate deficiency anemia,4.0
2,iron deficiency anemia,inconclusive diagnosis,5.0
3,no anemia,no anemia,2.0
4,inconclusive diagnosis,inconclusive diagnosis,4.0


#### ChatGPT Sequential COT

In [8]:
# Run the sequential CoT prompt through gpt-4-turbo. The first argument (250) is
# presumably the number of patients to process -- confirm against scripts/utils.
# Note: this rebinds results_df, replacing the no-CoT results held under the
# same name.
results_df = utils.get_results_chatgpt_sequential_cot(
    250,
    "gpt-4-turbo",
    constants.sequentialCOT_prompt,
    save=True,
    filename='pathways_COTsequential_chatgpt.csv',
)

Processing patient:  0
Processing patient:  1
Processing patient:  2


Processing patient:  3
Processing patient:  4
Processing patient:  5
Processing patient:  6
Processing patient:  7
Processing patient:  8
Processing patient:  9
Processing patient:  10
Processing patient:  11
Processing patient:  12
Processing patient:  13
Processing patient:  14
Processing patient:  15
Processing patient:  16
Processing patient:  17
Processing patient:  18
Processing patient:  19
Processing patient:  20
Processing patient:  21
Processing patient:  22
Processing patient:  23
Processing patient:  24
Processing patient:  25
Processing patient:  26
Processing patient:  27
Processing patient:  28
Processing patient:  29
Processing patient:  30
Processing patient:  31
Processing patient:  32
Processing patient:  33
Processing patient:  34
Processing patient:  35
Processing patient:  36
Processing patient:  37
Processing patient:  38
Processing patient:  39
Processing patient:  40
Processing patient:  41
Processing patient:  42
Processing patient:  43
Processing patient:  44

In [9]:
# Preview the first rows of the ChatGPT CoT results.
results_df.head()

Unnamed: 0,y_actual,y_pred,pathway_length
0,hemolytic anemia,hemolytic anemia,4.0
1,vitamin b12/folate deficiency anemia,vitamin b12/folate deficiency anemia,4.0
2,iron deficiency anemia,iron deficiency anemia,5.0
3,no anemia,no anemia,2.0
4,inconclusive diagnosis,inconclusive diagnosis,4.0


In [10]:
# Score the ChatGPT CoT predictions against the ground truth. The three values
# are unpacked as accuracy, F1 and ROC AUC going by the variable names -- the
# exact definitions live in utils.test. These names are shared with the no-CoT
# ChatGPT metrics cell and are rebound here.
acc, f1, roc_auc = utils.test(results_df["y_actual"], results_df["y_pred"])
mean_pathway_length = utils.compute_mean_pathway_length(results_df)
(acc, f1, roc_auc, mean_pathway_length)

(92.80000000000001, 90.78695337598317, 95.59099116749978, 4.608)