In this notebook we are using Llama-2 hosted on deepinfra (https://deepinfra.com/) for evaluation on MedMCQA dataset

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from medplexity.benchmarks.medmcqa.medmcqa_loader import MedMCQALoader
from medplexity.benchmarks.medmcqa.medmcqa_dataset_builder import MedMCQADatasetBuilder
from medplexity.benchmarks.medmcqa.medmcqa_prompt_template import MedMCQAPromptTemplate


In [3]:
DEEPINFRA_API_KEY = ""

In [4]:
loader = MedMCQALoader()

In [5]:
dataset = MedMCQADatasetBuilder().build_dataset("validation")

In [6]:
example_data_point = next(dataset.__iter__())

In [7]:
example_data_point

MedMCQADataPoint(input=MedMCQAInput(question='Which of the following is not true for myelinated nerve fibers:', options=['Impulse through myelinated fibers is slower than non-myelinated fibers', 'Membrane currents are generated at nodes of Ranvier', 'Saltatory conduction of impulses is seen', 'Local anesthesia is effective only when the nerve is not covered by myelin sheath']), expected_output=0, metadata=MedMCQAOutputMetadata(explanation=None, subject_name='Physiology'))

In [8]:
from medplexity.benchmarks.medmcqa.medmcqa_dataset_builder import MedMCQAInput

def input_adapter(medmcqa_input: MedMCQAInput):
    prompt_template = MedMCQAPromptTemplate()

    sys_prompt = '<<SYS>>Always output a JSON of the format {"answer": "(A) | (B) | (C) | (D)", "explanation": "text explaining the choice"}<</SYS>> \n'

    instructions = "[INST]" + prompt_template.format(
        question=medmcqa_input.question,
        options=medmcqa_input.options
    ) + "[/INST]"

    return sys_prompt + instructions

In [9]:
from benchmarks.multiple_choice_utils import AnswerWithExplanation
import re


def extract_option(s):
    options = re.findall(r'\((A|B|C|D)\)', s)
    if len(options) > 1:
        raise ValueError("More than one option found!")
    elif options:
        return '(' + options[0] + ')'
    else:
        raise ValueError("No option provided in the answer")



def output_adapter(output_json: str) -> AnswerWithExplanation:
    parsed_output = AnswerWithExplanation.model_validate_json(output_json)

    # sometimes in addition the letter it returns also an explanation, so here we just extract the relevant letter
    parsed_output.answer = extract_option(parsed_output.answer)

    return parsed_output

In [10]:
from medplexity.llms.deepinfra import Deepinfra
from medplexity.chains.evaluation_adapter_chain import EvaluationAdapterChain

chain = EvaluationAdapterChain(
    llm=Deepinfra(
        api_token=DEEPINFRA_API_KEY
    ),
    input_adapter=input_adapter,
    output_adapter=output_adapter,
)

In [11]:
def comparator(expected_output: int, predicted_output: AnswerWithExplanation):
    letter_to_idx = { "(A)" : 0, "(B)": 1, "(C)": 2, "(D)": 3 }
    predicted_idx =  letter_to_idx[predicted_output.answer]

    return expected_output == predicted_idx

In [12]:
from medplexity.evaluators.sequential_evaluator import SequentialEvaluator

evaluator = SequentialEvaluator(
    chain=chain,
    comparator=comparator
)

In [13]:
dataset[0].input

MedMCQAInput(question='Which of the following is not true for myelinated nerve fibers:', options=['Impulse through myelinated fibers is slower than non-myelinated fibers', 'Membrane currents are generated at nodes of Ranvier', 'Saltatory conduction of impulses is seen', 'Local anesthesia is effective only when the nerve is not covered by myelin sheath'])

In [19]:
evaluation = evaluator.evaluate(dataset[2:4])

100%|██████████| 2/2 [00:11<00:00,  5.90s/it]


In [20]:
evaluation.accuracy()

0.5

In [21]:
correct, incorrect = evaluation.partition_by_correctness()

In [22]:
incorrect

[EvaluationResult(input=MedMCQAInput(question='Axonal transport is:', options=['Antegrade', 'Retrograde', 'Antegrade and retrograde', 'None']), input_metadata=MedMCQAOutputMetadata(explanation='Fast anterograde (400 mm/day) transport occurs by kinesin molecular motor and retrograde transport (200 mm/day) occurs by dynein molecular motor.', subject_name='Physiology'), expected_output=2, output=AnswerWithExplanation(answer='(A)', explanation='Axonal transport refers to the movement of materials along the length of a nerve fiber, and it is primarily antegrade, meaning that it moves from the cell body to the synapse.'), correct=False)]