In [6]:
import os

# Read the API keys from the environment so real secrets never have to be typed
# into (and committed with) the notebook. The original placeholder strings are
# kept as fallbacks, so behaviour is unchanged when the variables are not set.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "INSERT_OPEN_API_KEY_HERE")
# NOTE(review): ANTHROPIC_API_KEY is not referenced by any visible cell.
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "INSERT_ANTHROPIC_API_KEY_HERE")

In [7]:
from llama_index.core import VectorStoreIndex, get_response_synthesizer
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.postprocessor import SimilarityPostprocessor
from llama_index.core import SimpleDirectoryReader
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.selectors import LLMSingleSelector
from llama_index.core import Settings
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.core import VectorStoreIndex
from llama_index.core import StorageContext, load_index_from_storage
import nest_asyncio
nest_asyncio.apply()

In [8]:
# Parse the anatomy textbook PDF into llama_index documents.
documents = SimpleDirectoryReader(
    input_files=["anatomybook.pdf"],
).load_data()

# Cut the documents into small overlapping chunks (size 128, overlap 25;
# presumably token counts -- confirm against the SentenceSplitter docs).
nodes = SentenceSplitter(
    chunk_size=128,
    chunk_overlap=25,
).get_nodes_from_documents(documents)

In [9]:
# Global llama_index defaults: the chat model used to answer queries ...
Settings.llm = OpenAI(
    model="gpt-3.5-turbo",
    api_key=OPENAI_API_KEY,
)
# ... and the embedding model used for the index and for query embedding.
Settings.embed_model = OpenAIEmbedding(
    model="text-embedding-ada-002",
    api_key=OPENAI_API_KEY,
)

In [11]:
# Build the vector index over the chunked nodes. Embeddings are presumably
# computed with the Settings.embed_model configured above -- TODO confirm.
index = VectorStoreIndex(nodes)

In [19]:
from llama_index.core.retrievers import QueryFusionRetriever

# Baseline engine: plain vector retrieval, two chunks per query.
query_engine1 = index.as_query_engine(similarity_top_k=2)

# Stand-alone retriever with the same top-k. It feeds the fusion retriever
# below and is also used by evaluate1 to print the chunks behind wrong answers.
vector_retriever = index.as_retriever(similarity_top_k=2)

# Fusion retriever: rewrites each question into several queries (the recorded
# output shows three generated queries per question), retrieves for each of
# them and merges the hits with reciprocal-rank reranking.
fusion_retriever = QueryFusionRetriever(
    [vector_retriever],
    similarity_top_k=2,
    num_queries=4,  # set this to 1 to disable query generation
    mode="reciprocal_rerank",
    use_async=True,  # relies on nest_asyncio.apply() from the imports cell
    verbose=True,  # prints the "Generated queries:" lines seen in the output
    # query_gen_prompt="...",  # we could override the query generation prompt here
)

# Backward-compatible alias for the original (misspelled) variable name.
next_retreiver = fusion_retriever

# RetrieverQueryEngine is already imported in the imports cell.
query_engine2 = RetrieverQueryEngine.from_args(fusion_retriever)


In [15]:
from datasets import load_dataset

# MedMCQA validation split, restricted to the rows whose subject is "Anatomy".
dataset_validation = load_dataset(
    "openlifescienceai/medmcqa",
    split="validation",
).filter(
    lambda example: example["subject_name"] == "Anatomy"
)

# query engine 1

[NodeWithScore(node=TextNode(id_='e01448b8-a435-425e-aea5-91e2285389af', embedding=None, metadata={'page_label': '478', 'file_name': 'anatomybook.pdf', 'file_path': 'anatomybook.pdf', 'file_type': 'application/pdf', 'file_size': 52581393, 'creation_date': '2024-05-27', 'last_modified_date': '2024-05-20'}, excluded_embed_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], excluded_llm_metadata_keys=['file_name', 'file_type', 'file_size', 'creation_date', 'last_modified_date', 'last_accessed_date'], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='fd3715d4-c0bd-4899-bf71-d4237dc4bafa', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'page_label': '478', 'file_name': 'anatomybook.pdf', 'file_path': 'anatomybook.pdf', 'file_type': 'application/pdf', 'file_size': 52581393, 'creation_date': '2024-05-27', 'last_modified_date': '2024-05-20'}, hash='722159b8a4f8ce8fab16792e14f41bb31dc3b78b5ec1e1ea64629801a6c9

In [23]:
def predict1(prompt):
    """Send ``prompt`` to ``query_engine1`` and return the response as a string."""
    response = query_engine1.query(prompt)
    return str(response)
    

def evaluate1():
  """Evaluate ``predict1`` on every row of ``dataset_validation``.

  For each row a multiple-choice prompt is built from the question and its
  four options ('opa'..'opd'); the reference answer is the option selected by
  the row's 'cop' index. When the first answer is not one of the options
  verbatim, the answer plus a corrective instruction is appended to the prompt
  and the engine is asked once more. Every wrong final answer is printed
  together with the prompt and the chunks ``vector_retriever`` returns for it.

  Returns:
    tuple[float, float]: ``(exact_match, mismatch)`` -- the fraction of rows
    whose final prediction equals the reference, and the fraction whose final
    prediction is still not one of the four options. ``(0.0, 0.0)`` is
    returned for an empty dataset instead of dividing by zero.
  """
  sample_cnt = len(dataset_validation)
  if sample_cnt == 0:
    # Nothing to score; the divisions below would raise ZeroDivisionError.
    return 0.0, 0.0

  match_cnt = 0
  mismatch_cnt = 0

  for example in dataset_validation:
    options = [example['opa'], example['opb'], example['opc'], example['opd']]
    reference = options[example['cop']]

    prompt = (
      f"{example['question']}\n\n"
      f"{options[0]}\n{options[1]}\n{options[2]}\n{options[3]}\n\n"
      "Respond with the correct choice from the list above verbatim.  Do not include any explanation."
    )

    prediction = predict1(prompt)
    if prediction not in options:
      # One retry. The rejected answer is glued straight onto the prompt (no
      # separator), exactly as before, so recorded scores stay comparable.
      prompt += (
        prediction
        + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n"
        + "\n\n".join(options)
      )
      prediction = predict1(prompt)

    if prediction != reference:
      # Debug aid for wrong answers: show the prompt, the retrieved chunks and
      # the prediction. At most two chunks are shown, and fewer than two no
      # longer raises IndexError.
      print(prompt)
      retrieved = vector_retriever.retrieve(prompt)
      for rank, hit in enumerate(retrieved[:2], start=1):
        print(f"RETRIEVAL {rank}", hit.text)
      print('PREDICTION:', prediction)
      print('\n\n-----------------------------------\n\n')
    else:
      match_cnt += 1

    if prediction not in options:
      mismatch_cnt += 1

  return match_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Run the baseline (query_engine1) evaluation over the whole validation subset.
# NOTE(review): `exact_match` / `mismatch` are reused by every evaluateN cell,
# so the print cells report whichever evaluation ran last.
exact_match, mismatch = evaluate1()

Which of the following are not a branch of external carotid Aery in Kiesselbach's plexus.

Sphenopalatine aery
Anterior ethmoidal aery
Greater palatine aery
Septal branch of superior labial aery

Respond with the correct choice from the list above verbatim.  Do not include any explanation.
RETRIEVAL 1 Maxillary a.
Posterior auricular a.
Facial a.
Lingual a.
Superior thyroid a.Ascending pharyngeal a.
Superficial temporal a.
Maxillary a.
Posterior auricular a.
Facial a.
Lingual a.
Ascending pharyngeal a.Occipital a.
Superior thyroid a. and superior laryngeal branch
FIGURE 8.51  External Carotid Artery and Branches.
RETRIEVAL 2 other ﬁbers also may course 
via the deep petrosal nerve  (postganglionic 
sympathetic ﬁbers on the internal carotid artery), 
which joins the greater petrosal nerve to become 
Lateral wall of nasal cavity
Greater palatine arteryLesser palatine arteryMaxillary artery
External carotid arterySphenopalatine artery Anterior lateral nasal branch
External nasal branch 
o

In [13]:
# Report the two scores produced by the most recently executed evaluation.
print(f"exact_match: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match: 0.5299145299145299
mismatch: 0.042735042735042736


# query engine 2

In [16]:
def predict2(prompt):
    """Send ``prompt`` to ``query_engine2`` and return the response as a string."""
    response = query_engine2.query(prompt)
    return str(response)

def evaluate2():
  """Evaluate ``predict2`` on every row of ``dataset_validation``.

  For each row a multiple-choice prompt is built from the question and its
  four options ('opa'..'opd'); the reference answer is the option selected by
  the row's 'cop' index. When the first answer is not one of the options
  verbatim, the answer plus a corrective instruction is appended to the prompt
  and the engine is asked once more. Each final prediction is printed.

  Returns:
    tuple[float, float]: ``(exact_match, mismatch)`` -- the fraction of rows
    whose final prediction equals the reference, and the fraction whose final
    prediction is still not one of the four options. ``(0.0, 0.0)`` is
    returned for an empty dataset instead of dividing by zero.
  """
  sample_cnt = len(dataset_validation)
  if sample_cnt == 0:
    # Nothing to score; the divisions below would raise ZeroDivisionError.
    return 0.0, 0.0

  match_cnt = 0
  mismatch_cnt = 0

  for example in dataset_validation:
    options = [example['opa'], example['opb'], example['opc'], example['opd']]
    correct_option = options[example['cop']]

    prompt = (
      f"{example['question']}\n\n"
      f"{options[0]}\n{options[1]}\n{options[2]}\n{options[3]}\n\n"
      "Respond with the correct choice from the list above verbatim.  Do not include any explanation."
    )

    prediction = predict2(prompt)
    if prediction not in options:
      # One retry. The rejected answer is glued straight onto the prompt (no
      # separator), exactly as before, so recorded scores stay comparable.
      prompt += (
        prediction
        + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n"
        + "\n\n".join(options)
      )
      prediction = predict2(prompt)

    # Progress output, kept as it appears in the recorded cell output.
    print("prediction ISSSSS", prediction)

    if prediction == correct_option:
      match_cnt += 1
    if prediction not in options:
      mismatch_cnt += 1

  return match_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Run the fusion-retriever (query_engine2) evaluation.
# NOTE(review): this overwrites the `exact_match` / `mismatch` values left by
# any previously executed evaluateN cell.
exact_match, mismatch = evaluate2()

Generated queries:
Which branches of the external carotid artery are in Kiesselbach's plexus?
Which arteries make up Kiesselbach's plexus?
What are the branches of the external carotid artery in Kiesselbach's plexus?
Correct choice: Septal branch of superior labial artery
Generated queries:
Sphenopalatine aery
Anterior ethmoidal aery
Greater palatine aery
prediction ISSSSS Sphenopalatine artery
Generated queries:
Cricoid cailage in respiratory tree
Thyroid cailage in respiratory tree
Cunieform cailage in respiratory tree
prediction ISSSSS Cricoid cailage
Generated queries:
Root
Division
Cord
prediction ISSSSS Root
Generated queries:
Chief cells proliferation in villi
Goblet cells function in villi
Paneth cells location in villi
prediction ISSSSS Goblet cells
Generated queries:
Lateral pterygoid
Temporalis
Medial pterygoid
prediction ISSSSS Lateral pterygoid
Generated queries:
Waldeyer's lymphatic chain is formed by all except Palatine tonsils
Waldeyer's lymphatic chain is formed by all

In [17]:
# Report the two scores produced by the most recently executed evaluation.
# Score recorded with the reciprocal-rerank retriever: 0.5641025641025641
print(f"exact_match: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match: 0.5641025641025641
mismatch: 0.02564102564102564


# query engine 3

In [12]:
def predict3(prompt):
    """Send ``prompt`` to ``query_engine3`` and return the response as a string.

    NOTE(review): ``query_engine3`` is not defined in any visible cell; it
    presumably comes from a removed or unseen cell -- confirm before re-running.
    """
    response = query_engine3.query(prompt)
    return str(response)

def evaluate3():
  """Evaluate ``predict3`` on every row of ``dataset_validation``.

  For each row a multiple-choice prompt is built from the question and its
  four options ('opa'..'opd'); the reference answer is the option selected by
  the row's 'cop' index. When the first answer is not one of the options
  verbatim, the answer plus a corrective instruction is appended to the prompt
  and the engine is asked once more.

  Returns:
    tuple[float, float]: ``(exact_match, mismatch)`` -- the fraction of rows
    whose final prediction equals the reference, and the fraction whose final
    prediction is still not one of the four options. ``(0.0, 0.0)`` is
    returned for an empty dataset instead of dividing by zero.
  """
  sample_cnt = len(dataset_validation)
  if sample_cnt == 0:
    # Nothing to score; the divisions below would raise ZeroDivisionError.
    return 0.0, 0.0

  match_cnt = 0
  mismatch_cnt = 0

  for example in dataset_validation:
    options = [example['opa'], example['opb'], example['opc'], example['opd']]
    correct_option = options[example['cop']]

    prompt = (
      f"{example['question']}\n\n"
      f"{options[0]}\n{options[1]}\n{options[2]}\n{options[3]}\n\n"
      "Respond with the correct choice from the list above verbatim.  Do not include any explanation."
    )

    prediction = predict3(prompt)
    if prediction not in options:
      # One retry. The rejected answer is glued straight onto the prompt (no
      # separator), exactly as before, so recorded scores stay comparable.
      prompt += (
        prediction
        + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n"
        + "\n\n".join(options)
      )
      prediction = predict3(prompt)

    if prediction == correct_option:
      match_cnt += 1
    if prediction not in options:
      mismatch_cnt += 1

  return match_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Run the query_engine3 evaluation.
# NOTE(review): this overwrites the `exact_match` / `mismatch` values left by
# any previously executed evaluateN cell.
exact_match, mismatch = evaluate3()

In [13]:
# Report the two scores produced by the most recently executed evaluation.
print(f"exact_match: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match: 0.5726495726495726
mismatch: 0.02564102564102564


# query engine 4

In [14]:
def predict4(prompt):
    """Send ``prompt`` to ``query_engine4`` and return the response as a string.

    NOTE(review): ``query_engine4`` is not defined in any visible cell; it
    presumably comes from a removed or unseen cell -- confirm before re-running.
    """
    response = query_engine4.query(prompt)
    return str(response)

def evaluate4():
  """Evaluate ``predict4`` on every row of ``dataset_validation``.

  For each row a multiple-choice prompt is built from the question and its
  four options ('opa'..'opd'); the reference answer is the option selected by
  the row's 'cop' index. When the first answer is not one of the options
  verbatim, the answer plus a corrective instruction is appended to the prompt
  and the engine is asked once more.

  Returns:
    tuple[float, float]: ``(exact_match, mismatch)`` -- the fraction of rows
    whose final prediction equals the reference, and the fraction whose final
    prediction is still not one of the four options. ``(0.0, 0.0)`` is
    returned for an empty dataset instead of dividing by zero.
  """
  sample_cnt = len(dataset_validation)
  if sample_cnt == 0:
    # Nothing to score; the divisions below would raise ZeroDivisionError.
    return 0.0, 0.0

  match_cnt = 0
  mismatch_cnt = 0

  for example in dataset_validation:
    options = [example['opa'], example['opb'], example['opc'], example['opd']]
    correct_option = options[example['cop']]

    prompt = (
      f"{example['question']}\n\n"
      f"{options[0]}\n{options[1]}\n{options[2]}\n{options[3]}\n\n"
      "Respond with the correct choice from the list above verbatim.  Do not include any explanation."
    )

    prediction = predict4(prompt)
    if prediction not in options:
      # One retry. The rejected answer is glued straight onto the prompt (no
      # separator), exactly as before, so recorded scores stay comparable.
      prompt += (
        prediction
        + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n"
        + "\n\n".join(options)
      )
      prediction = predict4(prompt)

    if prediction == correct_option:
      match_cnt += 1
    if prediction not in options:
      mismatch_cnt += 1

  return match_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Run the query_engine4 evaluation.
# NOTE(review): this overwrites the `exact_match` / `mismatch` values left by
# any previously executed evaluateN cell.
exact_match, mismatch = evaluate4()

In [None]:
# Report the two scores produced by the most recently executed evaluation.
print(f"exact_match: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match: 0.5769230769230769
mismatch: 0.02564102564102564


# query engine 5

In [None]:
def predict5(prompt):
    """Send ``prompt`` to ``query_engine5`` and return the response as a string.

    NOTE(review): ``query_engine5`` is not defined in any visible cell; it
    presumably comes from a removed or unseen cell -- confirm before re-running.
    """
    response = query_engine5.query(prompt)
    return str(response)

def evaluate5():
  """Evaluate ``predict5`` on every row of ``dataset_validation``.

  For each row a multiple-choice prompt is built from the question and its
  four options ('opa'..'opd'); the reference answer is the option selected by
  the row's 'cop' index. When the first answer is not one of the options
  verbatim, the answer plus a corrective instruction is appended to the prompt
  and the engine is asked once more.

  Returns:
    tuple[float, float]: ``(exact_match, mismatch)`` -- the fraction of rows
    whose final prediction equals the reference, and the fraction whose final
    prediction is still not one of the four options. ``(0.0, 0.0)`` is
    returned for an empty dataset instead of dividing by zero.
  """
  sample_cnt = len(dataset_validation)
  if sample_cnt == 0:
    # Nothing to score; the divisions below would raise ZeroDivisionError.
    return 0.0, 0.0

  match_cnt = 0
  mismatch_cnt = 0

  for example in dataset_validation:
    options = [example['opa'], example['opb'], example['opc'], example['opd']]
    correct_option = options[example['cop']]

    prompt = (
      f"{example['question']}\n\n"
      f"{options[0]}\n{options[1]}\n{options[2]}\n{options[3]}\n\n"
      "Respond with the correct choice from the list above verbatim.  Do not include any explanation."
    )

    prediction = predict5(prompt)
    if prediction not in options:
      # One retry. The rejected answer is glued straight onto the prompt (no
      # separator), exactly as before, so recorded scores stay comparable.
      prompt += (
        prediction
        + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n"
        + "\n\n".join(options)
      )
      prediction = predict5(prompt)

    if prediction == correct_option:
      match_cnt += 1
    if prediction not in options:
      mismatch_cnt += 1

  return match_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Run the query_engine5 evaluation.
# NOTE(review): this overwrites the `exact_match` / `mismatch` values left by
# any previously executed evaluateN cell.
exact_match, mismatch = evaluate5()

In [None]:
# Report the two scores produced by the most recently executed evaluation.
print(f"exact_match: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match: 0.5726495726495726
mismatch: 0.017094017094017096
