<h1>ESCARGOT benchmarking</h1>
<h2>Manually configured ESCARGOT for the public AlzKB knowledge graph and a private Weaviate server</h2>

In [2]:

# Settings for the chat/embedding model entry. The empty strings are left blank
# in this notebook and must be filled in before the cell is run.
azure_gpt35_settings = dict(
    model_id="gpt-35-turbo-16k",
    prompt_token_cost=0.001,
    response_token_cost=0.002,
    temperature=0.7,
    max_tokens=2000,
    stop=None,
    api_version="",
    api_base="",
    api_key="",
    embedding_id="text-embedding-ada-002",
)

# Connection settings for the Memgraph graph database.
memgraph_settings = dict(host="", port=7687)

# Connection settings for the Weaviate server.
weaviate_settings = dict(api_key="", url="", db="", limit=200)

# Full configuration handed to Escargot; the first key is the model name
# selected via `model_name` when Escargot is constructed.
config = {
    "azuregpt35-16k": azure_gpt35_settings,
    "memgraph": memgraph_settings,
    "weaviate": weaviate_settings,
}

from escargot import Escargot
# Node labels passed to Escargot as one comma-separated string.
alzkb_node_types = ", ".join([
    "BiologicalProcess",
    "BodyPart",
    "CellularComponent",
    "Datatype",
    "Disease",
    "Drug",
    "DrugClass",
    "Gene",
    "MolecularFunction",
    "Pathway",
    "Symptom",
])

# Relationship types passed to Escargot as one newline-separated string.
alzkb_relationship_types = "\n".join([
    "CHEMICALBINDSGENE",
    "CHEMICALDECREASESEXPRESSION",
    "CHEMICALINCREASESEXPRESSION",
    "DRUGINCLASS",
    "DRUGCAUSESEFFECT",
    "DRUGTREATSDISEASE",
    "GENEPARTICIPATESINBIOLOGICALPROCESS",
    "GENEINPATHWAY",
    "GENEINTERACTSWITHGENE",
    "GENEHASMOLECULARFUNCTION",
    "GENEASSOCIATEDWITHCELLULARCOMPONENT",
    "GENEASSOCIATESWITHDISEASE",
    "SYMPTOMMANIFESTATIONOFDISEASE",
    "BODYPARTUNDEREXPRESSESGENE",
    "BODYPARTOVEREXPRESSESGENE",
    "DISEASELOCALIZESTOANATOMY",
    "DISEASEASSOCIATESWITHDISEASET",
])

# Build the Escargot instance against the "azuregpt35-16k" entry of `config`.
escargot = Escargot(
    config,
    node_types=alzkb_node_types,
    relationship_types=alzkb_relationship_types,
    model_name="azuregpt35-16k",
)
# Alternative model kept from the original cell: model_name="azuregpt4o"
# (node label, property names) pairs, in the order they appear in the schema text.
schema_node_properties = [
    ("BiologicalProcess", ["commonName"]),
    ("BodyPart", ["commonName"]),
    ("CellularComponent", ["commonName"]),
    ("Disease", ["commonName"]),
    ("Drug", ["commonName"]),
    ("DrugClass", ["commonName"]),
    ("Gene", ["commonName", "geneSymbol", "typeOfGene"]),
    ("MolecularFunction", ["commonName"]),
    ("Pathway", ["commonName"]),
    ("Symptom", ["commonName"]),
]

# (source label, relationship type, target label) triples, in schema order.
schema_relationships = [
    ("Drug", "CHEMICALBINDSGENE", "Gene"),
    ("Drug", "CHEMICALDECREASESEXPRESSION", "Gene"),
    ("Drug", "CHEMICALINCREASESEXPRESSION", "Gene"),
    ("Drug", "DRUGINCLASS", "DrugClass"),
    ("Drug", "DRUGCAUSESEFFECT", "Disease"),
    ("Drug", "DRUGTREATSDISEASE", "Disease"),
    ("Gene", "GENEPARTICIPATESINBIOLOGICALPROCESS", "BiologicalProcess"),
    ("Gene", "GENEINPATHWAY", "Pathway"),
    ("Gene", "GENEINTERACTSWITHGENE", "Gene"),
    ("Gene", "GENEHASMOLECULARFUNCTION", "MolecularFunction"),
    ("Gene", "GENEASSOCIATEDWITHCELLULARCOMPONENT", "CellularComponent"),
    ("Gene", "GENEASSOCIATESWITHDISEASE", "Disease"),
    ("Symptom", "SYMPTOMMANIFESTATIONOFDISEASE", "Disease"),
    ("BodyPart", "BODYPARTUNDEREXPRESSESGENE", "Gene"),
    ("BodyPart", "BODYPARTOVEREXPRESSESGENE", "Gene"),
    ("Disease", "DISEASELOCALIZESTOANATOMY", "BodyPart"),
    ("Disease", "DISEASEASSOCIATESWITHDISEASET", "Disease"),
]

# Assemble the schema description line by line. Formatting a list with
# str.format yields its repr (e.g. ['commonName']), which is the exact
# spelling used in the schema text.
schema_lines = ["Node properties are the following:"]
for node_label, property_names in schema_node_properties:
    schema_lines.append(
        "Node name: '{}', Node properties: {}".format(node_label, property_names)
    )
schema_lines.append("Relationship properties are the following:")
schema_lines.append("The relationships are the following:")
for source_label, relationship_type, target_label in schema_relationships:
    schema_lines.append(
        "(:{})-[:{}]-(:{})".format(source_label, relationship_type, target_label)
    )

# Override the schema text held by the Memgraph client with the hand-written one.
escargot.memgraph_client.schema = "\n".join(schema_lines)





In [None]:
import json
import dill
# Ask ESCARGOT every question in each benchmark file and pickle the raw answers.
# `responses` maps file name -> {question text -> str(response)}.
json_files =  ['MCQ_1hop.json', 'MCQ_2hop.json', 'OpenEnded_1hop.json', 'OpenEnded_2hop.json', 'True_or_False_1hop.json', 'True_or_False_2hop.json']
max_tries = 3  # attempts per question before an empty response is recorded
responses = {}
for json_file in json_files:
    print(json_file)
    with open("../dataset/"+json_file) as f:
        data = json.load(f)
    responses[json_file] = {}
    for question in data:
        response = ''
        tries = 0
        # Retry while the answer is still the empty string.
        while response == '' and tries < max_tries:
            # Reset the Memgraph client cache before every attempt.
            escargot.memgraph_client.cache = {}
            try:
                response = escargot.ask(question['question'], answer_type= "array",debug_level = 0)
            except Exception as e:
                # Report the failure instead of hiding it, then fall through to retry.
                print('ESCARGOT request failed on attempt', tries + 1, 'of', max_tries, ':', e)
                response = ''
            tries += 1

        print('question:', question['question'], 'answer:', question['answer'], 'response:', response)
        print("------------------------------------------------------------------------------------------------------------------------------\n")
        responses[json_file][question['question']] = str(response)

# The file name matches the one the scoring cell loads ("Escargot_responses.pkl");
# the context manager closes the handle once the dump finishes.
with open('Escargot_responses.pkl', 'wb') as responses_file:
    dill.dump(responses, responses_file)

<h2>Base GPT-3.5</h2>

In [81]:
BASE_GENERATION_TEMPLATE = """
Answer the following question and return only the answer. If it's multiple choice, return the answer in the format "1", "2", "3", "4", etc. If it's a free text answer, return the answer as a string.
Question: {question}
"""

In [83]:
import json
import dill
# Ask the plain chat model (no knowledge graph) every benchmark question and
# pickle the cleaned answers. `responses` maps file name -> {question -> answer}.
json_files =  ['MCQ_1hop.json', 'MCQ_2hop.json', 'OpenEnded_1hop.json', 'OpenEnded_2hop.json', 'True_or_False_1hop.json', 'True_or_False_2hop.json']
answer_prefix = "Answer:"
responses = {}
for json_file in json_files:
    print(json_file)
    # NOTE(review): this cell opens the question files relative to the current
    # directory, while the other cells read them from "../dataset/" - confirm
    # the intended working directory.
    with open(json_file) as f:
        data = json.load(f)
    responses[json_file] = {}
    for question in data:
        response = ''
        formatted_question = BASE_GENERATION_TEMPLATE.format(question = question['question'])
        try:
            response = escargot.quick_chat(formatted_question)
            # Drop a leading "Answer:" label. Slice by the prefix length so an
            # answer that follows the colon without a space is not truncated.
            if response.startswith(answer_prefix):
                response = response[len(answer_prefix):].strip()

            # Remove ```cypher fences from the response
            response = response.replace("```cypher", "")

            # Remove any remaining ``` fences
            response = response.replace("```", "")

            # Remove newlines from the response
            response = response.replace("\n", "")

            print("question:",question['question'])
            print("request:",response)
        except Exception as e:
            # Report the failure and record an empty answer for this question.
            print("LM request failed",e)
            response = ''

        responses[json_file][question['question']] = str(response)

# Context manager so the pickle file is flushed and closed after the dump.
with open('results/Base_responses.pkl', 'wb') as responses_file:
    dill.dump(responses, responses_file)
    

MCQ_1hop.json
question: Which of the following binds to the drug Leucovorin? 1. CAD 2. PDS5B 3. SEL1L 4. ABCC2 5. RMI1
request: 2
question: Which of the following binds to the drug Chlormerodrin? 1. PDS5A 2. RMI1 3. CAD 4. PDS5B 5. SLC12A1
request: 1
question: Which of the following binds to the drug Papaverine? 1. SEL1L 2. CAD 3. PDS5B 4. PDS5A 5. PDE4B
request: 5
question: Which of the following binds to the drug Ethchlorvynol? 1. RMI1 2. GABRB3 3. SEL1L 4. PDS5B 5. CAD
request: 2
question: Which of the following binds to the drug Methimazole? 1. PDS5B 2. PDS5A 3. CYP3A4 4. RMI1 5. SEL1L
request: 2. PDS5A
question: Which of the following binds to the drug Amoxapine? 1. RMI1 2. SEL1L 3. CAD 4. PDS5A 5. HTR1A
request: 5
question: Which of the following binds to the drug Amobarbital? 1. PDS5B 2. RMI1 3. GABRA5 4. PDS5A 5. CAD
request: 4. PDS5A
question: Which of the following binds to the drug Doxazosin? 1. PDS5B 2. SEL1L 3. CAD 4. PDS5A 5. KCNH7
request: 1. PDS5B
question: Which of the

  dill.dump(responses, open('Base_responses.pkl', 'wb'))


<h2>RAG</h2>

'The length of this context is 10 words.'

<h2>Assessing the score</h2>

In [6]:
# Names of the six benchmark question files: three question kinds, each with
# a 1-hop and a 2-hop variant.
json_files = [
    "{}_{}hop.json".format(question_kind, hop_count)
    for question_kind in ("MCQ", "OpenEnded", "True_or_False")
    for hop_count in (1, 2)
]

In [7]:
# Prompt asking the model to grade an answer against the ground truth.
# Placeholders: {question}, {ground_truth_answer}, {answer}. The prompt asks
# for a bare 1 (correct) or 0 (incorrect).
GROUND_TRUTH_COMPARISON_TEMPLATE = """
Given this question: {question}
And the ground truth answer: {ground_truth_answer}
Determine if this answer by the student is correct. Give only a 1 for correct and 0 for incorrect: 
{answer}
"""
# Prompt asking the model which multiple-choice option an answer corresponds to.
# Placeholders: {question}, {options}, {answer}.
GET_ANSWER_MCQ_TEMPLATE = """Given the following multiple choice question and options:
{question}
{options}

And the student's answer:
{answer}

Which option did the student select?"""
# Prompt asking the model to rewrite a free-text answer as a Python list literal.
# Placeholders: {question}, {answer}.
CONVERT_ANSWER_TO_ARRAY = """Question:
{question}
Convert the following answer to a python array. return only the array as an evaluable string starting with [ and ending with ]:
{answer}
"""

In [9]:

#get the responses from the pickle files
import dill
# Load the saved baseline and RAG responses. The context managers close each
# file after loading (the bare open() calls left the handles open).
# NOTE(security): dill.load can execute arbitrary code - only load trusted files.
with open('../results/Base_responses.pkl', 'rb') as base_file:
    base_responses = dill.load(base_file)
with open('../results/RAG_responses.pkl', 'rb') as rag_file:
    rag_responses = dill.load(rag_file)

  responses = dill.load(open('../results/RAG_responses.pkl', 'rb'))


In [9]:
import dill 
import json
# Score one set of saved responses against the benchmark ground truth.
#   - MCQ / True_or_False: the chat model grades each answer (1 = correct).
#   - OpenEnded: the answer is parsed into a list and compared with the ground
#     truth list using a set-overlap score.
# Results: `scores[file]` holds correct/total/percentage; `score_data[file]`
# holds one row per graded MCQ / True_or_False question.
files =  [
    'MCQ_1hop.json',
    'MCQ_2hop.json', 
    'OpenEnded_1hop.json', 
    'OpenEnded_2hop.json', 
    'True_or_False_1hop.json',
    'True_or_False_2hop.json'
]

scores = {}
response_from = ''
score_data = {}

#testing different responses. comment out the responses you don't want to test
# NOTE(security): dill.load can execute arbitrary code - only load trusted files.
with open("../results/Base_responses.pkl", "rb") as responses_file:
    responses = dill.load(responses_file)
response_from = 'base'

# responses = dill.load(open("../results/RAG_responses.pkl", "rb"))

# responses = dill.load(open("Escargot_responses.pkl", "rb"))

for file in files:
    print(file)
    scores[file] = {}
    score_data[file] = []
    total = 0
    correct = 0
    with open("../dataset/"+file, "rb") as dataset_file:
        dataset = json.load(dataset_file)

    for question in dataset:
        if "MCQ" in file or "True_or_False" in file:
            # Questions without a saved response are skipped and not counted.
            if question['question'] not in responses[file]:
                continue
            total += 1
            #compare the ground truth answer with the student's answer
            response = responses[file][question['question']]
            if "MCQ" in file:
                # The options follow the first "? " in the question text.
                options = question['question'].split('? ')[1]
                formatted_answer_prompt = GET_ANSWER_MCQ_TEMPLATE.format(question=question['question'].split('? ')[0], options = options, answer = response)
                formatted_answer = ''
                try:
                    formatted_answer = escargot.quick_chat(formatted_answer_prompt)
                except Exception as e:
                    print("LM request failed",e)  
            else:
                formatted_answer = response

            formatted_question = GROUND_TRUTH_COMPARISON_TEMPLATE.format(question = question['question'], ground_truth_answer = question['answer'], answer = formatted_answer)
            # Reset the grade for every question so a failed request can neither
            # reuse the previous question's grade nor hit an undefined name.
            score_response = ''
            try:
                score_response = escargot.quick_chat(formatted_question)
            except Exception as e:
                print("LM request failed",e)
            if "1" in score_response:
                correct += 1
                print(question['answer'], formatted_answer)
            # elif len(score_response) != 1:
            #     total -= 1
            else:
                print(question, response, formatted_answer, score_response)
            score_data[file].append([question['question'], question['answer'], response, formatted_answer, score_response])
            print(correct,'/', total, correct/total *100)
        elif "OpenEnded" in file:
            
            total += 1
            print(question['question'])
            # The RAG pickle keeps open-ended answers under the keys
            # 'OpenEnded_1hop_Answers_ChatGPT_3.5_RAG' / 'OpenEnded_2hop_Answers_ChatGPT_3.5_RAG'.
            if 'OpenEnded_1hop_Answers_ChatGPT_3.5_RAG' in responses or response_from =='base':
                if response_from != 'base':
                    if file == 'OpenEnded_1hop.json':
                        response = responses['OpenEnded_1hop_Answers_ChatGPT_3.5_RAG'][question['question']]
                    elif file == 'OpenEnded_2hop.json':
                        response = responses['OpenEnded_2hop_Answers_ChatGPT_3.5_RAG'][question['question']]
                    array_answer = CONVERT_ANSWER_TO_ARRAY.format(question = question['question'], answer = response['response_by_llm'])
                else:
                    # Load this question's saved answer; without this the prompt
                    # was built from whatever `response` held from the previous
                    # iteration.
                    response = responses[file][question['question']]
                    array_answer = CONVERT_ANSWER_TO_ARRAY.format(question = question['question'], answer = response)
                
                # Reset so a failed request yields an empty answer, not a stale one.
                score_response = ''
                try:
                    score_response = escargot.quick_chat(array_answer)
                except Exception as e:
                    print("LM request failed",e)
                response = score_response
                #lowercase
                response = response.lower()
            else:
                response = responses[file][question['question']]
                print(response)
            
            try:
                if "{'np'" in response:
                    #remove from 'np:' to the first > in the response
                    response = '{' + response.split("'np':")[1].split('>, ')[1]
                    if "array([" in response:
                        #remove from 'array([' to the first ']' in the response
                        response = "[" + response.split("array([")[1].split(']')[0] + "]"
                try:
                    # NOTE(security): eval runs model-generated text as Python.
                    # Only use this on output you trust.
                    response = eval(response)
                except Exception as e:
                    print("response eval prob:",e)
                    # The text stays a string here; the bracket/comma splitting
                    # further down handles it.
                    response = response.replace("nan", "None")
                #check if the response is a dict
                if isinstance(response, dict):
                    if len(response) > 0:
                        # Keep the value stored under the last key.
                        response = response[list(response.keys())[-1]]
                    else:
                        response = []
                if isinstance(response, set):
                    response = list(response)
                if isinstance(response, tuple):
                    response = list(response[1])
                #make sure the response is a list
                if not isinstance(response, list):
                    response = [response]
                
                if len(response) == 1:
                    if "[" in response[0]:
                        response = response[0].replace("[", "").replace("]", "")
                        response = response.split(',')
                        response = [answer.strip() for answer in response]
                        #remove ' from the response
                        response = [answer.replace("'", "") for answer in response]
                        response = [answer.replace('"', "") for answer in response]
                    #if there is Alzheimer's Disease in the response, there may be a comma afterwards describing a specific type of Alzheimer's Disease
                    # such as if response = "Alzheimer's Disease, Focal Onset, Alzheimer's Disease, Early-Onset, Alzheimer's Disease, Late-Onset"
                    # we want this to look like the array ['alzheimer's disease, focal onset', 'alzheimer's disease, early-onset', 'alzheimer's disease, late-onset']
                    # so a single answer mentioning "Alzheimer's Disease" is not split on commas.
                    if "," in response[0] and "Alzheimer's Disease" not in response[0]:
                        response = response[0].split(',')
                        response = [answer.strip() for answer in response]

                #if "biological-region" is in the array
                if "biological-region" in response:
                    response = [answer.replace("biological-region", "biological-region gene") for answer in response]
                    #protein-coding
                if "protein-coding" in response:
                    response = [answer.replace("protein-coding", "protein-coding gene") for answer in response]
                if "ncRNA" in response:
                    response = [answer.replace("ncRNA", "ncRNA gene") for answer in response]
                response = [answer.lower() for answer in response]
                print(response)
                # Ground-truth disease names that contain a comma are pulled out
                # before the answer is split on commas, then added back whole.
                if "Alzheimer Disease, Late Onset" in question['answer'] or "Alzheimer Disease, Early Onset" in question['answer'] or "Alzheimer's Disease, Focal Onset" in question['answer']:
                    gt_array = []
                    if "Alzheimer Disease, Late Onset" in question['answer']:
                        gt_array.append("Alzheimer Disease, Late Onset")
                        #remove the Alzheimer Disease, Late Onset from the question['answer']
                        question['answer'] = question['answer'].replace("Alzheimer Disease, Late Onset", "")
                    if "Alzheimer Disease, Early Onset" in question['answer']:
                        gt_array.append("Alzheimer Disease, Early Onset")
                        #remove the Alzheimer Disease, Early Onset from the question['answer']
                        question['answer'] = question['answer'].replace("Alzheimer Disease, Early Onset", "")
                    if "Alzheimer's Disease, Focal Onset" in question['answer']:
                        gt_array.append("Alzheimer's Disease, Focal Onset")
                        #remove the Alzheimer Disease, Focal Onset from the question['answer']
                        question['answer'] = question['answer'].replace("Alzheimer's Disease, Focal Onset", "")
                    gt_answer = question['answer'].lower().split(',')
                    #append the gt_array to the gt_answer
                    gt_answer = gt_answer + gt_array
                else:
                    gt_answer = question['answer'].lower().split(',')
                gt_answer = [answer.strip() for answer in gt_answer]
                # remove empty strings from the list
                gt_answer = [answer.lower() for answer in gt_answer if answer != '']
                print(gt_answer)

                #metric only for the OpenEnded questions
                # intersection / union of the two answer sets; when the response
                # lists more items than the ground truth, every non-shared item
                # adds 5 to the denominator.
                count_of_intersection = len(set(response).intersection(set(gt_answer)))
                unique_union = len(set(response).union(set(gt_answer)))
                if len(response) > len(gt_answer):
                    score = count_of_intersection/(unique_union + 5*(unique_union - count_of_intersection))
                else:
                    score = count_of_intersection/unique_union
                if score < 0.6:
                    print("question:", question['question'])
                    # print("original response:", responses[file][question['question']])
                    print("response:", response)
                    print("ground truth:", gt_answer)
                    print("score:", score)
                    print("\n")
                print("score:", score)
                print("\n")
                correct += score
            
            except Exception as e:
                # Any parsing/scoring error stops the scoring of this file.
                print("response eval prob:",e)
                break

            print(correct,'/', total, correct/total *100)
        
    # Guard against files where no question was scored.
    percentage = correct/total *100 if total else 0.0
    scores[file]["correct"] = correct
    scores[file]["total"] = total
    scores[file]["percentage"] = percentage
    print('dataset_total:', len(dataset), 'correct:', correct, 'total:', total, 'percentage:', percentage)
    


  responses = dill.load(open("../results/Base_responses.pkl", "rb"))
  dataset = json.load(open("../dataset/"+file, "rb"))


MCQ_1hop.json
{'question': 'Which of the following binds to the drug Leucovorin? 1. CAD 2. PDS5B 3. SEL1L 4. ABCC2 5. RMI1', 'answer': '4'} 2 The student selected option 2. 0
0 / 1 0.0
{'question': 'Which of the following binds to the drug Chlormerodrin? 1. PDS5A 2. RMI1 3. CAD 4. PDS5B 5. SLC12A1', 'answer': '5'} 1 The student selected option 1. 0
0 / 2 0.0
5 The student selected option 5.
1 / 3 33.33333333333333
2 The student selected option 2.
2 / 4 50.0
{'question': 'Which of the following binds to the drug Methimazole? 1. PDS5B 2. PDS5A 3. CYP3A4 4. RMI1 5. SEL1L', 'answer': '3'} 2. PDS5A The student selected option 2. PDS5A. 0
2 / 5 40.0
5 The student selected option 5.
3 / 6 50.0
{'question': 'Which of the following binds to the drug Amobarbital? 1. PDS5B 2. RMI1 3. GABRA5 4. PDS5A 5. CAD', 'answer': '3'} 4. PDS5A The student selected option 4. PDS5A. 0
3 / 7 42.857142857142854
{'question': 'Which of the following binds to the drug Doxazosin? 1. PDS5B 2. SEL1L 3. CAD 4. PDS5A 5.