In [1]:
import pandas as pd
from langchain_community.llms import Ollama
from scripts import constants
from scripts import utils

#### Preliminaries

In [2]:
# Read the test set from the CSV path configured in scripts.constants,
# then preview its first rows.
test_set_path = constants.TEST_SET_PATH
test_df = pd.read_csv(test_set_path)
test_df.head()

Unnamed: 0,hemoglobin,ferritin,ret_count,segmented_neutrophils,tibc,mcv,serum_iron,rbc,gender,creatinine,cholestrol,copper,ethanol,folate,glucose,hematocrit,tsat,label
0,7.116363,-1.0,3.781573,2.738413,-1.0,95.904198,68.457895,2.226085,0,1.892912,39.80855,110.329197,64.40435,21.654404,73.787009,21.349089,-1.0,5
1,8.12532,92.230003,4.231419,1.188039,143.365567,104.057204,204.747831,2.342554,0,0.652614,13.478089,-1.0,32.705481,-1.0,43.520272,24.375961,142.815207,1
2,11.30945,38.324563,-1.0,-1.0,455.077909,76.402602,-1.0,4.440732,0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,33.92835,-1.0,4
3,13.763858,253.513394,2.262606,0.551444,453.772884,82.781943,90.101466,4.987993,0,0.853521,104.005514,34.639227,0.963866,22.083012,88.891838,41.291574,19.856071,0
4,11.464002,-1.0,-1.0,-1.0,320.964653,104.287127,-1.0,3.297819,0,1.163516,121.616315,105.895897,-1.0,9.337462,-1.0,34.392007,-1.0,7


In [3]:
# Local Ollama model used by the "Base" sections below.
# Set model_name to "llama3" to rerun them against the other model.
model_name = "mistral"
myllm = Ollama(model=model_name)

#### Base Version

In [4]:
# Base prompt: take the labels of the first 1000 rows of test_df, translate
# them through constants.CLASS_DICT, and pair each with the LLM's prediction.
actual_labels = test_df['label'].iloc[:1000]
llm_results_df = pd.DataFrame({'y_actual': actual_labels}).replace(
    {'y_actual': constants.CLASS_DICT}
)
# get_diagnosis receives the row's index label (what `row.name` would be);
# presumably it uses it to look the patient up in the test set — TODO confirm.
llm_results_df['y_pred'] = [
    utils.get_diagnosis(row_label, myllm, constants.base_prompt)
    for row_label in llm_results_df.index
]
llm_results_df.head()

Unnamed: 0,y_actual,y_pred
0,hemolytic anemia,hemolytic anemia
1,vitamin b12/folate deficiency anemia,inconclusive diagnosis
2,iron deficiency anemia,hemolytic anemia
3,no anemia,no anemia
4,inconclusive diagnosis,inconclusive diagnosis


In [5]:
# Score the base-prompt predictions against the true labels.
# utils.test yields three values, unpacked here as accuracy, F1 and ROC-AUC
# (they look like percentages in the output — confirm in scripts.utils).
base_scores = utils.test(llm_results_df['y_actual'], llm_results_df['y_pred'])
acc_spec, f1_spec, roc_auc_spec = base_scores
acc_spec, f1_spec, roc_auc_spec

(11.3, 6.507839421421463, 49.97502961487874)

In [6]:
# Display the base-prompt results; pandas abbreviates the long frame to its
# first and last rows.
llm_results_df

Unnamed: 0,y_actual,y_pred
0,hemolytic anemia,hemolytic anemia
1,vitamin b12/folate deficiency anemia,inconclusive diagnosis
2,iron deficiency anemia,hemolytic anemia
3,no anemia,no anemia
4,inconclusive diagnosis,inconclusive diagnosis
...,...,...
995,anemia of chronic disease,inconclusive diagnosis
996,anemia of chronic disease,inconclusive diagnosis
997,aplastic anemia,inconclusive diagnosis
998,hemolytic anemia,hemolytic anemia


In [7]:
# Save the base-prompt results as a bare two-column CSV (no header row, no
# index column). Use 'base_llama.csv' instead when myllm is the llama3 model.
base_output_csv = 'base_mistral.csv'
llm_results_df.to_csv(base_output_csv, header=False, index=False)

#### Base prompt plus one-shot example

In [8]:
# One-shot prompt: build a fresh results frame rather than mutating the one
# left behind by the base-prompt run, so this cell works on its own and never
# shows stale base predictions.
# Labels of the first 1000 rows of test_df are translated through
# constants.CLASS_DICT and paired with the LLM's prediction.
llm_results_df = pd.DataFrame({'y_actual': test_df['label'].iloc[:1000]})
llm_results_df = llm_results_df.replace({"y_actual": constants.CLASS_DICT})
# get_diagnosis receives the row's index label (what `row.name` would be);
# presumably it uses it to look the patient up in the test set — TODO confirm.
llm_results_df['y_pred'] = [
    utils.get_diagnosis(row_label, myllm, constants.base_1shot_prompt)
    for row_label in llm_results_df.index
]
llm_results_df.head()

Unnamed: 0,y_actual,y_pred
0,hemolytic anemia,hemolytic anemia
1,vitamin b12/folate deficiency anemia,hemolytic anemia
2,iron deficiency anemia,iron deficiency anemia
3,no anemia,anemia of chronic disease
4,inconclusive diagnosis,unspecified anemia


In [9]:
# Score the one-shot predictions against the true labels.
# utils.test yields three values, unpacked here as accuracy, F1 and ROC-AUC
# (they look like percentages in the output — confirm in scripts.utils).
one_shot_scores = utils.test(llm_results_df['y_actual'], llm_results_df['y_pred'])
acc_spec, f1_spec, roc_auc_spec = one_shot_scores
acc_spec, f1_spec, roc_auc_spec

(15.7, 12.930727061960138, 51.965324859185)

In [10]:
# Display the one-shot results; pandas abbreviates the long frame to its
# first and last rows.
llm_results_df

Unnamed: 0,y_actual,y_pred
0,hemolytic anemia,hemolytic anemia
1,vitamin b12/folate deficiency anemia,hemolytic anemia
2,iron deficiency anemia,iron deficiency anemia
3,no anemia,anemia of chronic disease
4,inconclusive diagnosis,unspecified anemia
...,...,...
995,anemia of chronic disease,iron deficiency anemia
996,anemia of chronic disease,inconclusive diagnosis
997,aplastic anemia,inconclusive diagnosis
998,hemolytic anemia,iron deficiency anemia


In [11]:
# Save the one-shot results as a bare two-column CSV (no header row, no index
# column). Use 'base_1shot_llama.csv' instead when myllm is the llama3 model.
one_shot_output_csv = 'base_1shot_mistral.csv'
llm_results_df.to_csv(one_shot_output_csv, header=False, index=False)

#### ChatGPT base

In [4]:
# Run the base prompt through ChatGPT and keep the returned results frame.
# The first argument is presumably the number of patients to process (the log
# prints one "Processing patient" line per item) — confirm in scripts.utils.
n_requested = 250
chatgpt_model = "gpt-4-turbo"
results_df = utils.get_results_chatgpt(
    n_requested,
    chatgpt_model,
    constants.base_prompt,
    save=True,
    filename='base_chatgpt.csv',
)

Processing patient:  0
Processing patient:  1
Processing patient:  2
Processing patient:  3
Processing patient:  4
Processing patient:  5
Processing patient:  6
Processing patient:  7
Processing patient:  8
Processing patient:  9
Processing patient:  10
Processing patient:  11
Processing patient:  12
Processing patient:  13
Processing patient:  14
Processing patient:  15
Processing patient:  16
Processing patient:  17
Processing patient:  18
Processing patient:  19
Processing patient:  20
Processing patient:  21
Processing patient:  22
Processing patient:  23
Processing patient:  24
Processing patient:  25
Processing patient:  26
Processing patient:  27
Processing patient:  28
Processing patient:  29
Processing patient:  30
Processing patient:  31
Processing patient:  32
Processing patient:  33
Processing patient:  34
Processing patient:  35
Processing patient:  36
Processing patient:  37
Processing patient:  38
Processing patient:  39
Processing patient:  40
Processing patient:  41
Pr

In [5]:
# Preview the first rows of the ChatGPT base-prompt results.
results_df.head()

Unnamed: 0,y_actual,y_pred
0,hemolytic anemia,iron deficiency anemia
1,vitamin b12/folate deficiency anemia,vitamin b12/folate deficiency anemia
2,iron deficiency anemia,iron deficiency anemia
3,no anemia,no anemia
4,inconclusive diagnosis,vitamin b12/folate deficiency anemia


In [6]:
# Score the ChatGPT base-prompt predictions against the true labels.
# utils.test yields three values, unpacked here as accuracy, F1 and ROC-AUC
# (they look like percentages in the output — confirm in scripts.utils).
chatgpt_base_scores = utils.test(results_df['y_actual'], results_df['y_pred'])
acc_spec, f1_spec, roc_auc_spec = chatgpt_base_scores
acc_spec, f1_spec, roc_auc_spec

(40.0, 29.404140196717503, 64.73674600740424)

#### ChatGPT base prompt plus one-shot example

In [7]:
# Run the one-shot prompt through ChatGPT and keep the returned results frame.
# The first argument is presumably the number of patients to process (the log
# prints one "Processing patient" line per item) — confirm in scripts.utils.
n_requested = 250
chatgpt_model = "gpt-4-turbo"
results_df = utils.get_results_chatgpt(
    n_requested,
    chatgpt_model,
    constants.base_1shot_prompt,
    save=True,
    filename='base_1shot_chatgpt.csv',
)

Processing patient:  0
Processing patient:  1
Processing patient:  2
Processing patient:  3
Processing patient:  4
Processing patient:  5
Processing patient:  6
Processing patient:  7
Processing patient:  8
Processing patient:  9
Processing patient:  10
Processing patient:  11
Processing patient:  12
Processing patient:  13
Processing patient:  14
Processing patient:  15
Processing patient:  16
Processing patient:  17
Processing patient:  18
Processing patient:  19
Processing patient:  20
Processing patient:  21
Processing patient:  22
Processing patient:  23
Processing patient:  24
Processing patient:  25
Processing patient:  26
Processing patient:  27
Processing patient:  28
Processing patient:  29
Processing patient:  30
Processing patient:  31
Processing patient:  32
Processing patient:  33
Processing patient:  34
Processing patient:  35
Processing patient:  36
Processing patient:  37
Processing patient:  38
Processing patient:  39
Processing patient:  40
Processing patient:  41
Pr

In [8]:
# Preview the first rows of the ChatGPT one-shot results.
results_df.head()

Unnamed: 0,y_actual,y_pred
0,hemolytic anemia,vitamin b12/folate deficiency anemia
1,vitamin b12/folate deficiency anemia,vitamin b12/folate deficiency anemia
2,iron deficiency anemia,iron deficiency anemia
3,no anemia,no anemia
4,inconclusive diagnosis,vitamin b12/folate deficiency anemia


In [9]:
# Score the ChatGPT one-shot predictions against the true labels.
# utils.test yields three values, unpacked here as accuracy, F1 and ROC-AUC
# (they look like percentages in the output — confirm in scripts.utils).
chatgpt_one_shot_scores = utils.test(results_df['y_actual'], results_df['y_pred'])
acc_a, f1_a, roc_auc_a = chatgpt_one_shot_scores
acc_a, f1_a, roc_auc_a

(44.0, 32.65213164174211, 66.35108610269049)