In this notebook we work with the MedMCQA benchmark and use a LangChain agent that runs a Google search before answering each question.

You will need an OpenAI API key. For the Google search we use SerpAPI, so you will also need a SerpAPI key; you can get one [from SerpAPI](https://serpapi.com/).

In [1]:
import os

os.environ["OPENAI_API_KEY"] = ""
os.environ["SERPAPI_API_KEY"] = ""
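If you leave a key as an empty string, nothing fails until the first API call. As an optional convenience, here is a small stdlib-only sketch that lists which required variables are unset or blank before you run anything else. The helper name missing_keys is our own and is not part of medplexity or LangChain.

```python
import os

# Environment variables this notebook relies on.
REQUIRED_KEYS = ("OPENAI_API_KEY", "SERPAPI_API_KEY")


def missing_keys(required=REQUIRED_KEYS, env=None):
    """Return the names of required variables that are unset or blank.

    Hypothetical helper for this notebook; `env` defaults to os.environ.
    """
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name, "").strip()]


problems = missing_keys()
if problems:
    print("Please set:", ", ".join(problems))
```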

Let's start by using the dataset factory to create the dataset:

In [2]:
from medplexity.benchmarks.dataset_factory import DatasetFactory

from medplexity.benchmarks.medmcqa.medmcqa_dataset_builder import MedMCQADataPoint
from medplexity.datasets.dataset import Dataset

dataset: Dataset[MedMCQADataPoint] = DatasetFactory().build("medmcqa", "validation")

In [3]:
example_data_point = dataset[0]
example_data_point.input

MultipleChoiceInput(question='Which of the following is not true for myelinated nerve fibers:', options=['Impulse through myelinated fibers is slower than non-myelinated fibers', 'Membrane currents are generated at nodes of Ranvier', 'Saltatory conduction of impulses is seen', 'Local anesthesia is effective only when the nerve is not covered by myelin sheath'], context=None, examples=None)

We will also use the existing medplexity prompt for multiple-choice questions:

In [4]:
from medplexity.prompts.multiple_choice_prompt import MultipleChoiceChainOfThoughtPrompt

prompt = MultipleChoiceChainOfThoughtPrompt()

print(prompt.format(
    question=example_data_point.input.question,
    options=example_data_point.input.options
))

Instructions: The following are multiple choice questions about medical knowledge. Solve them in a step-by-step fashion,
starting by summarizing the available information. Output a single option from the given options as the final answer.
Use the following format for your answer:
Explanation: Reasoning for your answer
Answer: (A) | (B) | (C), etc. Give only the option and nothing else.


Question: Which of the following is not true for myelinated nerve fibers:
(A) Impulse through myelinated fibers is slower than non-myelinated fibers (B) Membrane currents are generated at nodes of Ranvier (C) Saltatory conduction of impulses is seen (D) Local anesthesia is effective only when the nerve is not covered by myelin sheath
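The printed prompt lays the options out as lettered choices, (A) through (D), on a single line after the question. To make that mapping explicit, here is a minimal sketch that rebuilds the question block above from the fields of MultipleChoiceInput. It is our own approximation for illustration; the real MultipleChoiceChainOfThoughtPrompt.format may differ in whitespace and other details, and option_letter and render_question are hypothetical names.

```python
from string import ascii_uppercase
from typing import List


def option_letter(index: int) -> str:
    """Map 0 -> '(A)', 1 -> '(B)', ... (supports up to 26 options)."""
    return f"({ascii_uppercase[index]})"


def render_question(question: str, options: List[str]) -> str:
    """Approximate the 'Question: ...' block printed by the medplexity prompt."""
    choices = " ".join(f"{option_letter(i)} {text}" for i, text in enumerate(options))
    return f"Question: {question}\n{choices}"


print(render_question(
    "Which of the following is not true for myelinated nerve fibers:",
    [
        "Impulse through myelinated fibers is slower than non-myelinated fibers",
        "Membrane currents are generated at nodes of Ranvier",
        "Saltatory conduction of impulses is seen",
        "Local anesthesia is effective only when the nerve is not covered by myelin sheath",
    ],
))
```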



## Prediction with LangChain

Now let's define the chain we want to use with LangChain. Here we use LangChain's OpenAI-functions agent and define a single tool that runs Google searches through SerpAPI:

In [5]:
from langchain.agents import initialize_agent, AgentType, Tool
from langchain.chat_models import ChatOpenAI
from langchain.utilities import SerpAPIWrapper

In [6]:
llm = ChatOpenAI(temperature=0, model="gpt-4")
search = SerpAPIWrapper()

In [7]:
tools = [
    # Search tool
    Tool(
        name="search",
        func=search.run,
        description="Use this before answering the question to get additional medical information. Ask targeted questions."
    ),
]

In [8]:
agent_executor = initialize_agent(
    tools,
    llm,
    agent=AgentType.OPENAI_FUNCTIONS,
    return_intermediate_steps=True,
    verbose=True,
)
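With AgentType.OPENAI_FUNCTIONS, the model does not run the search itself. It replies with a function call that names a tool and carries JSON arguments (the metadata printed at the end of this notebook shows one, with arguments {"__arg1": "myelinated nerve fibers"}), and the executor runs the matching tool and feeds the result back to the model. The stdlib sketch below illustrates only that dispatch step. fake_search, TOOLS and dispatch are hypothetical stand-ins, not LangChain internals, and no network call is made.

```python
import json
from typing import Callable, Dict


def fake_search(query: str) -> str:
    """Hypothetical offline stand-in for SerpAPIWrapper.run."""
    return f"[search results for: {query}]"


# Tool registry keyed by the tool name the model will call.
TOOLS: Dict[str, Callable[[str], str]] = {"search": fake_search}


def dispatch(function_call: dict) -> str:
    """Run the tool named in an OpenAI-style function_call payload."""
    tool = TOOLS[function_call["name"]]
    args = json.loads(function_call["arguments"])
    # The single string input arrives under "__arg1", as in this notebook's output.
    return tool(args["__arg1"])


call = {"name": "search", "arguments": '{\n  "__arg1": "myelinated nerve fibers"\n}'}
print(dispatch(call))  # prints "[search results for: myelinated nerve fibers]"
```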

The tricky part is mapping inputs and outputs to the right format. For now, medplexity supports callbacks that act as adapters: one turns the benchmark input into the prompt the chain expects, and the other turns the chain's output into the form the benchmark compares against.

In [9]:
from medplexity.benchmarks.multiple_choice_utils import MultipleChoiceInput
from medplexity.chains.chain import ChainOutput

def input_adapter(input: MultipleChoiceInput) -> str:
    """Converts the multiple-choice input into a prompt string for the agent."""

    return prompt.format(
        question=input.question,
        options=input.options,
    )

def chain_output_adapter(chain_output: ChainOutput) -> ChainOutput:
    """Parses the agent's reply into the option label, keeping the explanation in the metadata."""
    parsed_output = MultipleChoiceChainOfThoughtPrompt.extract_explanation_and_answer(chain_output.output)

    output_metadata = chain_output.output_metadata
    output_metadata["explanation"] = parsed_output.explanation

    return ChainOutput(
        output=parsed_output.answer,
        output_metadata=output_metadata,
    )
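chain_output_adapter relies on MultipleChoiceChainOfThoughtPrompt.extract_explanation_and_answer to split the model's reply into the explanation and the final option. We do not reproduce medplexity's implementation here; the sketch below is a hypothetical regex-based parser (ParsedReply and parse_reply are our own names) for the "Explanation: ... / Answer: (X)" format the prompt requests. It can be handy for checking how a given reply would be read.

```python
import re
from dataclasses import dataclass
from typing import Optional


@dataclass
class ParsedReply:
    explanation: Optional[str]
    answer: Optional[str]


# Text after "Explanation:" up to the "Answer:" marker (or the end of the reply).
_EXPLANATION = re.compile(r"Explanation:\s*(.*?)\s*(?=Answer:|\Z)", re.DOTALL)
# A parenthesised capital letter after "Answer:".
_ANSWER = re.compile(r"Answer:\s*(\([A-Z]\))")


def parse_reply(text: str) -> ParsedReply:
    """Split a reply in the requested format; missing parts come back as None."""
    explanation = _EXPLANATION.search(text)
    answer = _ANSWER.search(text)
    return ParsedReply(
        explanation=explanation.group(1) if explanation else None,
        answer=answer.group(1) if answer else None,
    )


reply = "Explanation: Myelinated fibers conduct faster.\nAnswer: (A)"
print(parse_reply(reply))
```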

To plug in these adapters, medplexity provides EvaluationAdapterChain. For LangChain specifically, there is also a LangchainWrapper that converts LangChain's output into a format consistent with the one used elsewhere in medplexity.

In [10]:
from medplexity.chains.evaluation_adapter_chain import EvaluationAdapterChain
from medplexity.chains.langchain_wrapper import LangchainWrapper

# EvaluationAdapterChain applies your adapters, mapping benchmark inputs/outputs to and from the format your chain uses
chain = EvaluationAdapterChain(
    # LangchainWrapper converts LangChain's output format into the one medplexity expects
    chain=LangchainWrapper(agent_executor, store_intermediate_steps=True),
    input_adapter=input_adapter,
    output_adapter=chain_output_adapter,
)
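Conceptually, the wrapped chain runs in three steps: each benchmark input passes through input_adapter, the inner chain runs on the result, and its output passes through output_adapter before comparison. The plain-Python sketch below shows that composition with toy callables. AdapterChain and its run method are hypothetical names chosen for illustration, not medplexity's actual class or interface.

```python
from dataclasses import dataclass
from typing import Any, Callable


@dataclass
class AdapterChain:
    """Hypothetical stand-in: input_adapter -> chain -> output_adapter."""
    chain: Callable[[Any], Any]
    input_adapter: Callable[[Any], Any]
    output_adapter: Callable[[Any], Any]

    def run(self, benchmark_input: Any) -> Any:
        prompt_text = self.input_adapter(benchmark_input)
        raw_output = self.chain(prompt_text)
        return self.output_adapter(raw_output)


# Toy components: format a question, "answer" it, then keep only the option label.
demo = AdapterChain(
    chain=lambda text: f"Explanation: echo of {text!r}\nAnswer: (B)",
    input_adapter=lambda q: f"Question: {q}",
    output_adapter=lambda out: out.rsplit("Answer:", 1)[1].strip(),
)
print(demo.run("Which statement is false?"))  # prints "(B)"
```

Keeping the adapters as plain callables means the inner chain never needs to know about the benchmark's data types.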

In [11]:
from medplexity.medharness import Medharness

medharness = Medharness(
    dataset=dataset,
    chain=chain,
)

Now we can evaluate on a single example and inspect both the result and the search the agent performed (which, in this case, was not very helpful):

In [12]:
evaluation = medharness.run(k=1)

  0%|          | 0/1 [00:00<?, ?it/s]

> Entering new AgentExecutor chain...

Invoking: `search` with `myelinated nerve fibers`

The myelinated nerve fiber is designed by evolution to provide the body with a rapid and efficient transfer of information from peripheral receptors to the central nervous system (CNS), from the CNS to peripheral effectors, and between different centers within the CNS.

100%|██████████| 1/1 [00:33<00:00, 33.87s/it]

The available information tells us that myelinated nerve fibers are designed to provide rapid and efficient transfer of information within the body.

Let's analyze the options:

(A) Impulse through myelinated fibers is slower than non-myelinated fibers: This is not true. Myelinated fibers conduct impulses faster than non-myelinated fibers due to the insulating properties of the myelin sheath.

(B) Membrane currents are generated at nodes of Ranvier: This is true. The nodes of Ranvier are small gaps in the myelin sheath where the axon is exposed. These nodes allow for the rapid conduction of nerve impulses along the axon.

(C) Saltatory conduction of impulses is seen: This is true. Saltatory conduction is a type of nerve impulse conduction that occurs in myelinated axons, where the impulse jumps from one node of Ranvier to the next.

(D) Local anesthesia is effective only when the nerve is not covered by myelin sheath: This is not necessarily true. Local anesthesia works b
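
Here we ran only k=1, but with a larger k you will usually want a summary score. Below is a minimal sketch of accuracy over predicted versus expected option labels. It works on plain lists, so how you pull the predictions and gold answers out of evaluation.evaluation_results and the dataset is an assumption you will need to adapt; normalize and accuracy are our own names, and the labels in the example are toy values, not results from this notebook.

```python
from typing import Sequence


def normalize(label: str) -> str:
    """Reduce '(a)', ' A ', 'A' and similar to a bare upper-case letter."""
    return label.strip().strip("()").strip().upper()


def accuracy(predicted: Sequence[str], expected: Sequence[str]) -> float:
    """Fraction of predictions matching the expected option labels."""
    if len(predicted) != len(expected):
        raise ValueError("predicted and expected must have the same length")
    if not predicted:
        return 0.0
    hits = sum(normalize(p) == normalize(e) for p, e in zip(predicted, expected))
    return hits / len(predicted)


# Toy labels for illustration only.
print(accuracy(["(A)", "(C)", "(B)", "(D)"], ["(A)", "(B)", "(B)", "(D)"]))  # prints 0.75
```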




In [13]:
example_output = evaluation.evaluation_results[0]

In [14]:
print(example_output.output)
print(example_output.output_metadata)

(A)
{<LangChainOutputDictKeys.intermediate_steps: 'intermediate_steps'>: [(AgentActionMessageLog(tool='search', tool_input='myelinated nerve fibers', log='\nInvoking: `search` with `myelinated nerve fibers`\n\n\n', message_log=[AIMessage(content='', additional_kwargs={'function_call': {'name': 'search', 'arguments': '{\n  "__arg1": "myelinated nerve fibers"\n}'}})]), 'The myelinated nerve fiber is designed by evolution to provide the body with a rapid and efficient transfer of information from peripheral receptors to the central nervous system (CNS), from the CNS to peripheral effectors, and between different centers within the CNS.')], 'explanation': 'Based on the information about myelinated nerve fibers, the statement that impulse through myelinated fibers is slower than non-myelinated fibers is not true. Myelinated fibers are designed for rapid and efficient transfer of information, which means they conduct impulses faster than non-myelinated fibers.'}
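The metadata stores the agent's intermediate steps as a list of (action, observation) pairs, where each action records the tool name and tool input. To read them more easily than the raw repr above, here is a small sketch that prints one line per tool call. It uses a namedtuple stand-in for LangChain's action objects so it runs offline; with the real output, we assume you would take the list from example_output.output_metadata and that the actions expose the tool and tool_input attributes the repr suggests. summarize_steps and Action are hypothetical names.

```python
from collections import namedtuple
from typing import Any, Iterable, List, Tuple

# Offline stand-in for LangChain's AgentActionMessageLog (only the fields used here).
Action = namedtuple("Action", ["tool", "tool_input"])


def summarize_steps(steps: Iterable[Tuple[Any, str]], max_chars: int = 80) -> List[str]:
    """One line per (action, observation) pair, truncating long observations."""
    lines = []
    for i, (action, observation) in enumerate(steps, start=1):
        if len(observation) > max_chars:
            observation = observation[:max_chars] + "..."
        lines.append(f"{i}. {action.tool}({action.tool_input!r}) -> {observation}")
    return lines


steps = [(
    Action("search", "myelinated nerve fibers"),
    "The myelinated nerve fiber is designed by evolution to provide the body with a "
    "rapid and efficient transfer of information from peripheral receptors to the "
    "central nervous system (CNS), from the CNS to peripheral effectors, and between "
    "different centers within the CNS.",
)]
for line in summarize_steps(steps):
    print(line)
```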
