In [1]:
import os

# Prefer keys supplied through the environment so real secrets never have to be
# typed into (and saved with) the notebook. The original placeholders are kept
# as fallbacks, so editing them by hand still works when no variable is set.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "INSERT_OPEN_API_KEY_HERE")
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "INSERT_ANTHROPIC_API_KEY_HERE")

In [2]:
import nest_asyncio

# nest_asyncio.apply() is documented to patch asyncio so that an event loop
# which is already running (as in a Jupyter kernel) can be re-entered.
# Presumably required by async-backed calls later in the notebook -- TODO confirm.
nest_asyncio.apply()

In [3]:
from llama_index.core import SimpleDirectoryReader

# Load the source PDF into `documents`. The path is relative, so
# "anatomybook.pdf" has to be present in the kernel's working directory.
documents = SimpleDirectoryReader(input_files=["anatomybook.pdf"]).load_data()


In [19]:
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.selectors import LLMSingleSelector

# Chunk the same documents five times, doubling the chunk size (and the
# overlap) at each step, so the retrieval granularities can be compared.
nodes1, nodes2, nodes3, nodes4, nodes5 = (
    SentenceSplitter(chunk_size=size, chunk_overlap=overlap).get_nodes_from_documents(documents)
    for size, overlap in [(128, 25), (256, 50), (512, 100), (1024, 200), (2048, 400)]
)

In [20]:
from llama_index.core import Settings
from llama_index.llms.openai import OpenAI
from llama_index.embeddings.openai import OpenAIEmbedding

# Set the LLM and the embedding model on llama_index's global Settings object;
# both are constructed with the notebook-level OPENAI_API_KEY.
Settings.llm = OpenAI(model="gpt-3.5-turbo", api_key=OPENAI_API_KEY)
Settings.embed_model = OpenAIEmbedding(model="text-embedding-ada-002", api_key=OPENAI_API_KEY)

In [21]:
from llama_index.core import VectorStoreIndex
from llama_index.core import StorageContext, load_index_from_storage

# One vector index per chunking configuration; index N is built from nodesN.
vector_index1, vector_index2, vector_index3, vector_index4, vector_index5 = (
    VectorStoreIndex(node_list)
    for node_list in (nodes1, nodes2, nodes3, nodes4, nodes5)
)

In [22]:
# One query engine per vector index, all answering with gpt-3.5-turbo.
# The LLM is given OPENAI_API_KEY explicitly, exactly like Settings.llm, so the
# engines no longer depend on an OPENAI_API_KEY environment variable that this
# notebook never sets.
# NOTE(review): `chat_mode` looks like a chat-engine option; confirm that
# as_query_engine() actually uses it rather than silently ignoring it.
query_engine1 = vector_index1.as_query_engine(chat_mode="best", llm=OpenAI(model="gpt-3.5-turbo", api_key=OPENAI_API_KEY), verbose=True)
query_engine2 = vector_index2.as_query_engine(chat_mode="best", llm=OpenAI(model="gpt-3.5-turbo", api_key=OPENAI_API_KEY), verbose=True)
query_engine3 = vector_index3.as_query_engine(chat_mode="best", llm=OpenAI(model="gpt-3.5-turbo", api_key=OPENAI_API_KEY), verbose=True)
query_engine4 = vector_index4.as_query_engine(chat_mode="best", llm=OpenAI(model="gpt-3.5-turbo", api_key=OPENAI_API_KEY), verbose=True)
query_engine5 = vector_index5.as_query_engine(chat_mode="best", llm=OpenAI(model="gpt-3.5-turbo", api_key=OPENAI_API_KEY), verbose=True)

In [23]:
from datasets import load_dataset

# Load the validation split of "openlifescienceai/medmcqa" and keep only the
# rows whose `subject_name` is exactly "Anatomy".
dataset_validation = load_dataset("openlifescienceai/medmcqa", split="validation").filter(lambda example: example["subject_name"] == "Anatomy")

# query engine 1

In [24]:
def predict1(prompt):
    """Send ``prompt`` to query engine 1 and return its response as a string."""
    response = query_engine1.query(prompt)
    return str(response)

def evaluate1(dataset=None, predict=None):
  """Measure how often query engine 1 picks the correct multiple-choice option.

  Each example becomes a prompt containing the question and its four options.
  If the reply is not exactly one of the options, the model is asked once more
  with the rejected reply and the options appended to the prompt.

  Args:
    dataset: Sized iterable of examples with the keys ``question``, ``opa``,
      ``opb``, ``opc``, ``opd`` and ``cop`` (index of the correct option).
      Defaults to the notebook-level ``dataset_validation``.
    predict: Callable mapping a prompt string to an answer string. Defaults
      to ``predict1``.

  Returns:
    ``(exact_match, mismatch)``: the fraction of examples whose final reply
    equals the correct option, and the fraction whose final reply is not any
    of the four options. Both are ``0.0`` for an empty dataset.
  """
  if dataset is None:
    dataset = dataset_validation
  if predict is None:
    predict = predict1

  sample_cnt = len(dataset)
  if sample_cnt == 0:
    # Nothing to score; return early instead of dividing by zero below.
    return 0.0, 0.0

  correct_cnt = 0
  mismatch_cnt = 0

  for example in dataset:
    question = example["question"]
    options = [example["opa"], example["opb"], example["opc"], example["opd"]]
    option_a, option_b, option_c, option_d = options
    correct_option = options[example["cop"]]

    prompt = f'''{question}

{option_a}
{option_b}
{option_c}
{option_d}

Respond with the correct choice from the list above verbatim.  Do not include any explanation.'''

    prediction = predict(prompt)
    if prediction not in options:
      # Single retry: show the model its rejected reply and repeat the options.
      prompt += prediction + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + "\n\n".join(options)
      prediction = predict(prompt)

    # Booleans add as 0/1, so these are running counts.
    correct_cnt += prediction == correct_option
    mismatch_cnt += prediction not in options

  return correct_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Run the evaluation for query engine 1. Every validation example triggers one
# query, plus a second one when the first reply is not one of the options.
exact_match, mismatch = evaluate1()

In [25]:
# Show the metrics returned by the most recent evaluation call.
print(f"exact_match: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match: 0.5854700854700855
mismatch: 0.021367521367521368


# query engine 2

In [26]:
def predict2(prompt):
    """Send ``prompt`` to query engine 2 and return its response as a string."""
    response = query_engine2.query(prompt)
    return str(response)

def evaluate2(dataset=None, predict=None):
  """Measure how often query engine 2 picks the correct multiple-choice option.

  Each example becomes a prompt containing the question and its four options.
  If the reply is not exactly one of the options, the model is asked once more
  with the rejected reply and the options appended to the prompt.

  Args:
    dataset: Sized iterable of examples with the keys ``question``, ``opa``,
      ``opb``, ``opc``, ``opd`` and ``cop`` (index of the correct option).
      Defaults to the notebook-level ``dataset_validation``.
    predict: Callable mapping a prompt string to an answer string. Defaults
      to ``predict2``.

  Returns:
    ``(exact_match, mismatch)``: the fraction of examples whose final reply
    equals the correct option, and the fraction whose final reply is not any
    of the four options. Both are ``0.0`` for an empty dataset.
  """
  if dataset is None:
    dataset = dataset_validation
  if predict is None:
    predict = predict2

  sample_cnt = len(dataset)
  if sample_cnt == 0:
    # Nothing to score; return early instead of dividing by zero below.
    return 0.0, 0.0

  correct_cnt = 0
  mismatch_cnt = 0

  for example in dataset:
    question = example["question"]
    options = [example["opa"], example["opb"], example["opc"], example["opd"]]
    option_a, option_b, option_c, option_d = options
    correct_option = options[example["cop"]]

    prompt = f'''{question}

{option_a}
{option_b}
{option_c}
{option_d}

Respond with the correct choice from the list above verbatim.  Do not include any explanation.'''

    prediction = predict(prompt)
    if prediction not in options:
      # Single retry: show the model its rejected reply and repeat the options.
      prompt += prediction + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + "\n\n".join(options)
      prediction = predict(prompt)

    # Booleans add as 0/1, so these are running counts.
    correct_cnt += prediction == correct_option
    mismatch_cnt += prediction not in options

  return correct_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Run the evaluation for query engine 2. This rebinds `exact_match` and
# `mismatch`, so the values from any earlier evaluation run are overwritten.
exact_match, mismatch = evaluate2()

In [27]:
# Show the metrics returned by the most recent evaluation call.
print(f"exact_match: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match: 0.5769230769230769
mismatch: 0.017094017094017096


# query engine 3

In [28]:
def predict3(prompt):
    """Send ``prompt`` to query engine 3 and return its response as a string."""
    response = query_engine3.query(prompt)
    return str(response)

def evaluate3(dataset=None, predict=None):
  """Measure how often query engine 3 picks the correct multiple-choice option.

  Each example becomes a prompt containing the question and its four options.
  If the reply is not exactly one of the options, the model is asked once more
  with the rejected reply and the options appended to the prompt.

  Args:
    dataset: Sized iterable of examples with the keys ``question``, ``opa``,
      ``opb``, ``opc``, ``opd`` and ``cop`` (index of the correct option).
      Defaults to the notebook-level ``dataset_validation``.
    predict: Callable mapping a prompt string to an answer string. Defaults
      to ``predict3``.

  Returns:
    ``(exact_match, mismatch)``: the fraction of examples whose final reply
    equals the correct option, and the fraction whose final reply is not any
    of the four options. Both are ``0.0`` for an empty dataset.
  """
  if dataset is None:
    dataset = dataset_validation
  if predict is None:
    predict = predict3

  sample_cnt = len(dataset)
  if sample_cnt == 0:
    # Nothing to score; return early instead of dividing by zero below.
    return 0.0, 0.0

  correct_cnt = 0
  mismatch_cnt = 0

  for example in dataset:
    question = example["question"]
    options = [example["opa"], example["opb"], example["opc"], example["opd"]]
    option_a, option_b, option_c, option_d = options
    correct_option = options[example["cop"]]

    prompt = f'''{question}

{option_a}
{option_b}
{option_c}
{option_d}

Respond with the correct choice from the list above verbatim.  Do not include any explanation.'''

    prediction = predict(prompt)
    if prediction not in options:
      # Single retry: show the model its rejected reply and repeat the options.
      prompt += prediction + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + "\n\n".join(options)
      prediction = predict(prompt)

    # Booleans add as 0/1, so these are running counts.
    correct_cnt += prediction == correct_option
    mismatch_cnt += prediction not in options

  return correct_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Run the evaluation for query engine 3. This rebinds `exact_match` and
# `mismatch`, so the values from any earlier evaluation run are overwritten.
exact_match, mismatch = evaluate3()

In [29]:
# Show the metrics returned by the most recent evaluation call.
print(f"exact_match: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match: 0.5683760683760684
mismatch: 0.03418803418803419


# query engine 4

In [30]:
def predict4(prompt):
    """Send ``prompt`` to query engine 4 and return its response as a string."""
    response = query_engine4.query(prompt)
    return str(response)

def evaluate4(dataset=None, predict=None):
  """Measure how often query engine 4 picks the correct multiple-choice option.

  Each example becomes a prompt containing the question and its four options.
  If the reply is not exactly one of the options, the model is asked once more
  with the rejected reply and the options appended to the prompt.

  Args:
    dataset: Sized iterable of examples with the keys ``question``, ``opa``,
      ``opb``, ``opc``, ``opd`` and ``cop`` (index of the correct option).
      Defaults to the notebook-level ``dataset_validation``.
    predict: Callable mapping a prompt string to an answer string. Defaults
      to ``predict4``.

  Returns:
    ``(exact_match, mismatch)``: the fraction of examples whose final reply
    equals the correct option, and the fraction whose final reply is not any
    of the four options. Both are ``0.0`` for an empty dataset.
  """
  if dataset is None:
    dataset = dataset_validation
  if predict is None:
    predict = predict4

  sample_cnt = len(dataset)
  if sample_cnt == 0:
    # Nothing to score; return early instead of dividing by zero below.
    return 0.0, 0.0

  correct_cnt = 0
  mismatch_cnt = 0

  for example in dataset:
    question = example["question"]
    options = [example["opa"], example["opb"], example["opc"], example["opd"]]
    option_a, option_b, option_c, option_d = options
    correct_option = options[example["cop"]]

    prompt = f'''{question}

{option_a}
{option_b}
{option_c}
{option_d}

Respond with the correct choice from the list above verbatim.  Do not include any explanation.'''

    prediction = predict(prompt)
    if prediction not in options:
      # Single retry: show the model its rejected reply and repeat the options.
      prompt += prediction + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + "\n\n".join(options)
      prediction = predict(prompt)

    # Booleans add as 0/1, so these are running counts.
    correct_cnt += prediction == correct_option
    mismatch_cnt += prediction not in options

  return correct_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Run the evaluation for query engine 4. This rebinds `exact_match` and
# `mismatch`, so the values from any earlier evaluation run are overwritten.
exact_match, mismatch = evaluate4()

In [31]:
# Show the metrics returned by the most recent evaluation call.
print(f"exact_match: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match: 0.5897435897435898
mismatch: 0.02564102564102564


# query engine 5

In [32]:
def predict5(prompt):
    """Send ``prompt`` to query engine 5 and return its response as a string."""
    response = query_engine5.query(prompt)
    return str(response)

def evaluate5(dataset=None, predict=None):
  """Measure how often query engine 5 picks the correct multiple-choice option.

  Each example becomes a prompt containing the question and its four options.
  If the reply is not exactly one of the options, the model is asked once more
  with the rejected reply and the options appended to the prompt.

  Args:
    dataset: Sized iterable of examples with the keys ``question``, ``opa``,
      ``opb``, ``opc``, ``opd`` and ``cop`` (index of the correct option).
      Defaults to the notebook-level ``dataset_validation``.
    predict: Callable mapping a prompt string to an answer string. Defaults
      to ``predict5``.

  Returns:
    ``(exact_match, mismatch)``: the fraction of examples whose final reply
    equals the correct option, and the fraction whose final reply is not any
    of the four options. Both are ``0.0`` for an empty dataset.
  """
  if dataset is None:
    dataset = dataset_validation
  if predict is None:
    predict = predict5

  sample_cnt = len(dataset)
  if sample_cnt == 0:
    # Nothing to score; return early instead of dividing by zero below.
    return 0.0, 0.0

  correct_cnt = 0
  mismatch_cnt = 0

  for example in dataset:
    question = example["question"]
    options = [example["opa"], example["opb"], example["opc"], example["opd"]]
    option_a, option_b, option_c, option_d = options
    correct_option = options[example["cop"]]

    prompt = f'''{question}

{option_a}
{option_b}
{option_c}
{option_d}

Respond with the correct choice from the list above verbatim.  Do not include any explanation.'''

    prediction = predict(prompt)
    if prediction not in options:
      # Single retry: show the model its rejected reply and repeat the options.
      prompt += prediction + "\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + "\n\n".join(options)
      prediction = predict(prompt)

    # Booleans add as 0/1, so these are running counts.
    correct_cnt += prediction == correct_option
    mismatch_cnt += prediction not in options

  return correct_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Run the evaluation for query engine 5. This rebinds `exact_match` and
# `mismatch`, so the values from any earlier evaluation run are overwritten.
exact_match, mismatch = evaluate5()

In [33]:
# Show the metrics returned by the most recent evaluation call.
print(f"exact_match: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match: 0.5982905982905983
mismatch: 0.029914529914529916
