# Evaluating models on the MedMCQA dataset

Using the MedMCQA dataset (https://huggingface.co/datasets/medmcqa/viewer/default/validation?row=0), we assess the model's performance on medical question answering (medQA).

## Libraries & Setup

In [None]:
%pip install datasets

In [None]:
%pip install transformers

In [None]:
%pip install pretty-errors

### Imports

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns

from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
from chat_doc.inference.deploy_endpoint import SageMakerDeployment
from chat_doc.inference.chat import Chat
from chat_doc.config import SEED

[14/12/2023 13:42:10] - INFO : Found credentials in shared credentials file: ~/.aws/credentials
sagemaker.config INFO - Not applying SDK defaults from location: /opt/homebrew/share/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /Users/tilmankerl/Library/Application Support/sagemaker/config.yaml


### Model loading

In [None]:
# S3 location of the packaged model artifact (model.tar.gz) handed to SageMakerDeployment.
# NOTE(review): the URI, including the bucket name, is hardcoded here — consider
# moving it into the project configuration.
s3_model_uri = "s3://sagemaker-eu-central-1-228610994900/huggingface-qlora-2023-12-08-15-01-13-2023-12-08-15-01-14-300/output/model.tar.gz"
deployment = SageMakerDeployment(s3_model_uri)
# The value returned by `deploy_model()` is what the Chat wrapper is constructed from.
# NOTE(review): presumably this starts a live (billable) SageMaker endpoint —
# confirm, and make sure it is torn down once the evaluation is finished.
llm_endpoint = deployment.deploy_model()

### Setup

In [None]:
# Wrap the deployed endpoint in the project's Chat helper; `chat.predict(...)`
# is the call used further down to send each question to the model.
chat = Chat(llm_endpoint)

### Data Loading

In [12]:
# Load the "medmcqa" dataset through Hugging Face `datasets` and keep a handle
# on its validation split, which is the only split used for this evaluation.
dataset = load_dataset("medmcqa")
validation = dataset["validation"]

## Evaluation

In [13]:
# Restrict the validation split to questions whose `choice_type` is "single".
def _is_single_choice(example):
    """Return True when the example is marked as a single-choice question."""
    return example["choice_type"] == "single"


validation = validation.filter(_is_single_choice)

Filter: 100%|██████████| 4183/4183 [00:00<00:00, 106575.47 examples/s]


In [15]:
# Draw a reproducible random subset of 100 questions from the filtered
# validation split; SEED keeps the selection stable across runs.
validation_frame = validation.to_pandas()
validation_samples = validation_frame.sample(n=100, random_state=SEED)

In [4]:
def build_qa_query(row):
    """Build the multiple-choice prompt sent to the model for one question.

    Args:
        row: Mapping-like record (for example a pandas Series or a dict) that
            provides the keys ``choice_type``, ``subject_name``, ``question``,
            ``opa``, ``opb``, ``opc`` and ``opd``.

    Returns:
        str: The instruction sentence, a blank line, the question, and the
        four answer options labelled ``A)`` to ``D)``, one per line.
    """
    # Assemble the prompt line by line instead of using an indented
    # triple-quoted f-string: that form leaked the source indentation (four
    # leading spaces on every line after the first) and a trailing space after
    # "Options:" into the text sent to the model.
    prompt_lines = [
        f"Please answer the {row['choice_type']}-choice question to the best of your knowledge by just returning the correct option. The subject is {row['subject_name']}.",
        "",
        f"Question: {row['question']}",
        "Options:",
        f"A) {row['opa']}",
        f"B) {row['opb']}",
        f"C) {row['opc']}",
        f"D) {row['opd']}",
    ]
    return "\n".join(prompt_lines)

In [None]:
# Query the model once per sampled question and collect the raw answers, in
# the same order as the rows of `validation_samples`.
# The answers are stored under their own name instead of being assigned back
# to `validation_samples`: rebinding the DataFrame to a list would discard the
# sampled question rows the answers have to be compared against, and would
# make this cell fail on a re-run (a list has no `.iterrows()`).
predictions = [
    chat.predict(build_qa_query(row))
    for _, row in validation_samples.iterrows()
]