
### This script runs hyperparam selection for one-hop traversal on graph using natural language. Before running this notebook, make sure to run: codes/py_scripts/rag_based_text_generation/GPT/two_hop_traversal_hyperparameter_tuning.py 
### This will save the csv files that are used in this notebook


In [1]:
import pandas as pd
import numpy as np
import json
import ast
from tqdm import tqdm
import re
import os
import seaborn as sns
import matplotlib.pyplot as plt
from scipy.stats import sem


In [2]:
def jaccard_similarity(list1, list2):
    set1 = set(list1)
    set2 = set(list2)
    intersection = len(set1.intersection(set2))
    union = len(set1) + len(set2) - intersection    
    if union == 0:
        return 0.0
    else:
        jaccard_similarity = intersection / union
        return jaccard_similarity
    
def extract_answer(text):
    pattern = r'{[^{}]*}'
    match = re.search(pattern, text)
    if match:
        return match.group()
    else:
        return None
    
def extract_by_splitting(text):
    compound_list = text.split(':')[1].split("Diseases")[0].split("], ")[0]+"]"
    disease_list = text.split(':')[-1].split("}")[0]
    resp = {}
    resp["Compounds"] = ast.literal_eval(compound_list)
    resp["Diseases"] = ast.literal_eval(disease_list)
    return resp
    

In [55]:
PARENT_PATH = "../../../data/analysis_results/"

QUESTION_PATH = os.path.join(PARENT_PATH, "drug_reporposing_questions.csv")

FILES = [
    "gpt_4_node_retrieval_rag_based_two_hop_questions_parameter_tuning_round_1_4.csv",
    "gpt_4_node_retrieval_rag_based_two_hop_questions_parameter_tuning_round_2_4.csv",
    "gpt_4_node_retrieval_rag_based_two_hop_questions_parameter_tuning_round_4_4.csv",
    "gpt_4_node_retrieval_rag_based_two_hop_questions_parameter_tuning_round_4_4.csv",
]


In [72]:
df = pd.read_csv(os.path.join(PARENT_PATH, FILES[2]))
df

Unnamed: 0,disease_1,disease_2,central_nodes_groundTruth,text,llm_answer,max_node_hits,context_similarity_threshold
0,optic atrophy,psammomatous meningioma,"['arachnoid mater', 'internal carotid artery',...",What are the Anatomy that are commonly associa...,"{""Nodes"": [""optic tract"", ""pia mater"", ""subara...",30,10
1,tongue disease,Crouzon syndrome,"['mandible', 'foramen magnum', 'lower jaw regi...",What are the Anatomy that are commonly associa...,"{""Nodes"": [""lower jaw region"", ""tongue""]}",30,10
2,Dandy-Walker syndrome,optic atrophy,"['skull', 'arachnoid mater']",What are the Anatomy that are commonly associa...,"{""Nodes"":[]}",30,10
3,goiter,Human papillomavirus infectious disease,"['Voice Disorders', 'Respiratory Sounds']",What are the Symptoms that are commonly associ...,"{""Nodes"":[]}",30,10
4,insulinoma,congenital diaphragmatic hernia,['gastric juice'],What are the Anatomy that are commonly associa...,"{""Nodes"":[""pancreas"", ""diaphragm""]}",30,10
...,...,...,...,...,...,...,...
355,refractive error,visual impairment and progressive phthisis bulbi,"['ciliary body', 'corneal epithelium', 'anteri...",What are the Anatomy that are commonly associa...,"{""Nodes"": [""optic tract"", ""eye""]}",30,90
356,septooptic dysplasia,disorder of sexual development,['Intellectual Disability'],What are the Symptoms that are commonly associ...,"{""Nodes"":[]}",30,90
357,hereditary night blindness,disease of anatomical entity,"['CABP4', 'NYX', 'GNB3', 'TRPM1', 'RHO', 'GRM6...",What are the Genes that are commonly associate...,"{""Nodes"": [""RDH5"", ""SLC24A1"", ""SAG""]}",30,90
358,Pelizaeus-Merzbacher disease,primary progressive multiple sclerosis,"['brain', 'central nervous system', 'entire my...",What are the Anatomy that are commonly associa...,"{""Nodes"": [""brain"", ""central nervous system""]}",30,90


In [127]:
ind = 99
df = pd.read_csv(os.path.join(PARENT_PATH, FILES[3]))
df = df[df.text.str.contains("Symptom")]
print(df.iloc[ind].central_nodes_groundTruth)
print(df.iloc[ind].llm_answer)
print(df.iloc[ind].max_node_hits)
print(df.iloc[ind].context_similarity_threshold)
print(df.iloc[ind].text)

jaccard_similarity(json.loads(df.iloc[ind].llm_answer)["Nodes"], ast.literal_eval(df.iloc[ind].central_nodes_groundTruth))



['Abnormality of movement', 'Speech Disorders', 'Intellectual Disability', 'Abnormality of the nervous system', 'Abnormality of the gastrointestinal tract', 'Abnormality of the eye', 'Muscle Hypotonia', 'Behavioral abnormality', 'Morphological central nervous system abnormality', 'Anxiety', 'Abnormal heart morphology', 'Abnormality of the cardiovascular system', 'Muscle Hypertonia', 'Abnormality of the genitourinary system']
{"Nodes": ["Behavioral abnormality"]}
30
90
What are the Symptoms that are commonly associated with both Lynch syndrome and Tonne-Kalscheuer syndrome?


0.07142857142857142

(120, 7)

In [108]:

ll_performance_list = []
for file in FILES:
    df = pd.read_csv(os.path.join(PARENT_PATH, file))
    df = df[df.text.str.contains("Anatomy")]
    cntxt_simlarity_thresh = df.context_similarity_threshold.unique()
    llm_performance_list_single_combination = []
    for i in cntxt_simlarity_thresh:
        df_sub = df[df.context_similarity_threshold == i]
        llm_performance_list_across_questions = []
        for index, row in tqdm(df_sub.iterrows()):
            ground_truth = ast.literal_eval(row["central_nodes_groundTruth"])
            try:
                llm_answer = json.loads(row["llm_answer"])
            except:
                try:
                    llm_answer = ast.literal_eval(row["llm_answer"].split("Nodes:")[-1])
                except:
                    llm_answer = []
            if not isinstance(llm_answer, list):
                llm_result = llm_answer["Nodes"]
            else:
                llm_result = llm_answer
            llm_performance_list_across_questions.append(jaccard_similarity(ground_truth, llm_result))
        llm_performance_list_single_combination.append((np.mean(llm_performance_list_across_questions), row["max_node_hits"], i))
    ll_performance_list.extend(llm_performance_list_single_combination)


30it [00:00, 6782.88it/s]
30it [00:00, 9963.51it/s]
30it [00:00, 11382.10it/s]
30it [00:00, 11910.00it/s]
30it [00:00, 16041.45it/s]
30it [00:00, 16084.51it/s]
30it [00:00, 18705.09it/s]
30it [00:00, 21461.56it/s]
30it [00:00, 19505.37it/s]
30it [00:00, 23075.21it/s]
30it [00:00, 22441.43it/s]
30it [00:00, 20818.85it/s]
30it [00:00, 24409.14it/s]
30it [00:00, 21849.13it/s]
30it [00:00, 24961.14it/s]
30it [00:00, 24193.26it/s]


In [109]:
ll_performance_list

[(0.14263581763581765, 1, 10),
 (0.15255500955500956, 1, 30),
 (0.14541359541359541, 1, 60),
 (0.104839097671915, 1, 90),
 (0.2280648926237161, 10, 10),
 (0.21056489262371614, 10, 30),
 (0.21204637410519758, 10, 60),
 (0.19934484905073138, 10, 90),
 (0.1868280602839426, 30, 10),
 (0.19332249177837413, 30, 30),
 (0.19332249177837416, 30, 60),
 (0.19530661876250113, 30, 90),
 (0.1868280602839426, 30, 10),
 (0.19332249177837413, 30, 30),
 (0.19332249177837416, 30, 60),
 (0.19530661876250113, 30, 90)]

In [67]:
hyperparam_perf = pd.DataFrame(llm_performance_list, columns=["performance_mean", "performance_std", "performance_sem", "max_node_hits", "context_similarity_threshold"])
hyperparam_perf


Unnamed: 0,performance_mean,performance_std,performance_sem,max_node_hits,context_similarity_threshold
0,0.177949,0.28256,0.030121,1,10
1,0.177131,0.275137,0.020622,1,30
2,0.171196,0.263375,0.016118,1,60
3,0.153279,0.247199,0.013083,1,90
4,0.17662,0.296064,0.031383,10,10
5,0.17662,0.296064,0.031383,10,10
6,0.17212,0.294827,0.022036,10,30
7,0.169671,0.288527,0.017592,10,60
8,0.167662,0.263877,0.027971,20,10
9,0.163276,0.255107,0.019068,20,30


In [48]:
MAX_NODE_HITS_LIST = [1, 10, 20, 30]
QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD_LIST = [10, 30, 60, 90]

for node_hit_index, MAX_NODE_HITS in enumerate(MAX_NODE_HITS_LIST):
    answer_list = []
    for threshold_index, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD in enumerate(QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD_LIST):     
        print(MAX_NODE_HITS, QUESTION_VS_CONTEXT_SIMILARITY_PERCENTILE_THRESHOLD)

1 10
1 30
1 60
1 90
10 10
10 30
10 60
10 90
20 10
20 30
20 60
20 90
30 10
30 30
30 60
30 90
