In [1]:
OPENAI_API_KEY = "INSERT_OPEN_API_KEY_HERE"
ANTHROPIC_API_KEY = "INSERT_ANTHROPIC_API_KEY_HERE"

In [2]:
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.selectors import LLMSingleSelector
from llama_index.core import Settings
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import VectorStoreIndex
from llama_index.core import StorageContext, load_index_from_storage
import nest_asyncio
nest_asyncio.apply()

In [3]:
documents = SimpleDirectoryReader(input_files=["anatomybook.pdf"]).load_data()
nodes = SentenceSplitter(chunk_size=128, chunk_overlap=25).get_nodes_from_documents(documents)

In [4]:
Settings.llm = OpenAI(model="gpt-3.5-turbo", api_key=OPENAI_API_KEY)
Settings.embed_model = OpenAIEmbedding(model="text-embedding-ada-002", api_key=OPENAI_API_KEY)

In [5]:
index = VectorStoreIndex(nodes)

In [6]:
query_engine1 = RetrieverQueryEngine(
    retriever=VectorIndexRetriever(
    index=index,
    similarity_top_k=6,
),
    response_synthesizer=get_response_synthesizer(),
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
)

query_engine2 = RetrieverQueryEngine(
    retriever=VectorIndexRetriever(
    index=index,
    similarity_top_k=7,
),
    response_synthesizer=get_response_synthesizer(),
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
)

query_engine3 = RetrieverQueryEngine(
    retriever=VectorIndexRetriever(
    index=index,
    similarity_top_k=8,
),
    response_synthesizer=get_response_synthesizer(),
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
)

query_engine4 = RetrieverQueryEngine(
    retriever=VectorIndexRetriever(
    index=index,
    similarity_top_k=9,
),
    response_synthesizer=get_response_synthesizer(),
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
)

query_engine5 = RetrieverQueryEngine(
    retriever=VectorIndexRetriever(
    index=index,
    similarity_top_k=10,
),
    response_synthesizer=get_response_synthesizer(),
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
)

In [18]:
query_engine1 = RetrieverQueryEngine(
    retriever=VectorIndexRetriever(
    index=index,
    similarity_top_k=2,
),
    response_synthesizer=get_response_synthesizer(),
    node_postprocessors=[SimilarityPostprocessor(similarity_cutoff=0.7)],
)

In [7]:
from datasets import load_dataset

dataset_validation = load_dataset("openlifescienceai/medmcqa", split="validation").filter(lambda example: example["subject_name"] == "Anatomy")

# query engine 1

In [19]:
def predict1(prompt):
    return str(query_engine1.query(prompt))

def evaluate1():
  SAMPLE_CNT = len(dataset_validation)
  mismatch_cnt = 0
  predictions = []
  references = []

  for i in range(SAMPLE_CNT):
    example = dataset_validation[i]
    question, option_a, option_b, option_c, option_d = example["question"], example["opa"], example["opb"], example["opc"], example["opd"]
    prompt = f'''{question}

{option_a}
{option_b}
{option_c}
{option_d}

Respond with the correct choice from the list above verbatim.  Do not include any explanation.'''

    options = [example['opa'], example['opb'], example['opc'], example['opd']]
    reference = options[example['cop']]
    references.append(reference)

    prediction = predict1(prompt)
    if prediction not in options:
      prompt += prediction + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + example['opa'] + "\n\n" + example['opb'] + "\n\n" + example['opc'] + "\n\n" + example['opd']
      prediction = predict1(prompt)

    if prediction != reference:
       print(prompt)
       print('PREDICTION:', prediction)
       print('\n\n-----------------------------------\n\n')
       

    predictions.append(prediction)

    mismatch_cnt += prediction not in options

  exact_match = sum([prediction == reference for prediction, reference, in zip(predictions, references)]) / SAMPLE_CNT
  mismatch = mismatch_cnt / SAMPLE_CNT

  return exact_match, mismatch

exact_match, mismatch = evaluate1()

Which of the following are not a branch of external carotid Aery in Kiesselbach's plexus.

Sphenopalatine aery
Anterior ethmoidal aery
Greater palatine aery
Septal branch of superior labial aery

Respond with the correct choice from the list above verbatim.  Do not include any explanation.
PREDICTION: Greater palatine aery


-----------------------------------


Which pa of brachial plexus do not give branches

Root
Division
Cord
Trunk

Respond with the correct choice from the list above verbatim.  Do not include any explanation.
PREDICTION: Root


-----------------------------------


The cells which will proliferate from top to bottom of villi are:

Chief cells
Goblet cells
Paneth cells
Parietal cells

Respond with the correct choice from the list above verbatim.  Do not include any explanation.
PREDICTION: Goblet cells


-----------------------------------


Retraction of mandible is achieved by:

Lateral pterygoid
Temporalis
Medial pterygoid
Masseter

Respond with the correct choic

In [20]:
print("exact_match:", exact_match)
print("mismatch:", mismatch)

exact_match: 0.5897435897435898
mismatch: 0.017094017094017096


# query engine 2

In [10]:
def predict2(prompt):
    return str(query_engine2.query(prompt))

def evaluate2():
  SAMPLE_CNT = len(dataset_validation)
  mismatch_cnt = 0
  predictions = []
  references = []

  for i in range(SAMPLE_CNT):
    example = dataset_validation[i]
    question, option_a, option_b, option_c, option_d = example["question"], example["opa"], example["opb"], example["opc"], example["opd"]
    prompt = f'''{question}

{option_a}
{option_b}
{option_c}
{option_d}

Respond with the correct choice from the list above verbatim.  Do not include any explanation.'''

    options = [example['opa'], example['opb'], example['opc'], example['opd']]
    correct_option = options[example['cop']]
    references.append(correct_option)

    prediction = predict2(prompt)
    if prediction not in options:
      prompt += prediction + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + example['opa'] + "\n\n" + example['opb'] + "\n\n" + example['opc'] + "\n\n" + example['opd']
      prediction = predict2(prompt)

    predictions.append(prediction)

    mismatch_cnt += prediction not in options

  exact_match = sum([prediction == reference for prediction, reference, in zip(predictions, references)]) / SAMPLE_CNT
  mismatch = mismatch_cnt / SAMPLE_CNT

  return exact_match, mismatch

exact_match, mismatch = evaluate2()

In [11]:
print("exact_match:", exact_match)
print("mismatch:", mismatch)

exact_match: 0.5598290598290598
mismatch: 0.021367521367521368


# query engine 3

In [12]:
def predict3(prompt):
    return str(query_engine3.query(prompt))

def evaluate3():
  SAMPLE_CNT = len(dataset_validation)
  mismatch_cnt = 0
  predictions = []
  references = []

  for i in range(SAMPLE_CNT):
    example = dataset_validation[i]
    question, option_a, option_b, option_c, option_d = example["question"], example["opa"], example["opb"], example["opc"], example["opd"]
    prompt = f'''{question}

{option_a}
{option_b}
{option_c}
{option_d}

Respond with the correct choice from the list above verbatim.  Do not include any explanation.'''

    options = [example['opa'], example['opb'], example['opc'], example['opd']]
    correct_option = options[example['cop']]
    references.append(correct_option)

    prediction = predict3(prompt)
    if prediction not in options:
      prompt += prediction + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + example['opa'] + "\n\n" + example['opb'] + "\n\n" + example['opc'] + "\n\n" + example['opd']
      prediction = predict3(prompt)

    predictions.append(prediction)

    mismatch_cnt += prediction not in options

  exact_match = sum([prediction == reference for prediction, reference, in zip(predictions, references)]) / SAMPLE_CNT
  mismatch = mismatch_cnt / SAMPLE_CNT

  return exact_match, mismatch

exact_match, mismatch = evaluate3()

In [13]:
print("exact_match:", exact_match)
print("mismatch:", mismatch)

exact_match: 0.5726495726495726
mismatch: 0.02564102564102564


# query engine 4

In [14]:
def predict4(prompt):
    return str(query_engine4.query(prompt))

def evaluate4():
  SAMPLE_CNT = len(dataset_validation)
  mismatch_cnt = 0
  predictions = []
  references = []

  for i in range(SAMPLE_CNT):
    example = dataset_validation[i]
    question, option_a, option_b, option_c, option_d = example["question"], example["opa"], example["opb"], example["opc"], example["opd"]
    prompt = f'''{question}

{option_a}
{option_b}
{option_c}
{option_d}

Respond with the correct choice from the list above verbatim.  Do not include any explanation.'''

    options = [example['opa'], example['opb'], example['opc'], example['opd']]
    correct_option = options[example['cop']]
    references.append(correct_option)

    prediction = predict4(prompt)
    if prediction not in options:
      prompt += prediction + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + example['opa'] + "\n\n" + example['opb'] + "\n\n" + example['opc'] + "\n\n" + example['opd']
      prediction = predict4(prompt)

    predictions.append(prediction)

    mismatch_cnt += prediction not in options

  exact_match = sum([prediction == reference for prediction, reference, in zip(predictions, references)]) / SAMPLE_CNT
  mismatch = mismatch_cnt / SAMPLE_CNT

  return exact_match, mismatch

exact_match, mismatch = evaluate4()

In [15]:
print("exact_match:", exact_match)
print("mismatch:", mismatch)

exact_match: 0.5683760683760684
mismatch: 0.02564102564102564


# query engine 5

In [26]:
def predict5(prompt):
    return str(query_engine5.query(prompt))

def evaluate5():
  SAMPLE_CNT = len(dataset_validation)
  mismatch_cnt = 0
  predictions = []
  references = []

  for i in range(SAMPLE_CNT):
    example = dataset_validation[i]
    question, option_a, option_b, option_c, option_d = example["question"], example["opa"], example["opb"], example["opc"], example["opd"]
    prompt = f'''{question}

{option_a}
{option_b}
{option_c}
{option_d}

Respond with the correct choice from the list above verbatim.  Do not include any explanation.'''

    options = [example['opa'], example['opb'], example['opc'], example['opd']]
    correct_option = options[example['cop']]
    references.append(correct_option)

    prediction = predict5(prompt)
    if prediction not in options:
      prompt += prediction + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + example['opa'] + "\n\n" + example['opb'] + "\n\n" + example['opc'] + "\n\n" + example['opd']
      prediction = predict5(prompt)

    predictions.append(prediction)

    mismatch_cnt += prediction not in options

  exact_match = sum([prediction == reference for prediction, reference, in zip(predictions, references)]) / SAMPLE_CNT
  mismatch = mismatch_cnt / SAMPLE_CNT

  return exact_match, mismatch

exact_match, mismatch = evaluate5()

In [27]:
print("exact_match:", exact_match)
print("mismatch:", mismatch)
# was 0.5811965811965812

exact_match: 0.5854700854700855
mismatch: 0.021367521367521368
