# Evaluating models on the MedMCQA dataset

Using the MedMCQA dataset (https://huggingface.co/datasets/medmcqa/viewer/default/validation?row=0), we assess the model's performance on the validation split of medical multiple-choice questions.

## Libraries & Setup

### Imports

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import seaborn as sns

from datasets import load_dataset

In [None]:
from chat_doc.inference.deploy_endpoint import SageMakerDeployment
from chat_doc.inference.chat import Chat

### Model Loading

In [None]:
s3_model_uri = "s3://sagemaker-eu-central-1-228610994900/huggingface-qlora-2023-12-08-15-01-13-2023-12-08-15-01-14-300/output/model.tar.gz"
deployment = SageMakerDeployment(s3_model_uri)
llm_endpoint = deployment.deploy_model()

### Setup

In [None]:
chat = Chat(llm_endpoint)

### Data Loading

In [2]:
dataset = load_dataset("medmcqa")
validation = dataset["validation"]

Downloading builder script: 100%|██████████| 5.35k/5.35k [00:00<00:00, 13.3MB/s]
Downloading metadata: 100%|██████████| 2.41k/2.41k [00:00<00:00, 10.8MB/s]
Downloading readme: 100%|██████████| 10.5k/10.5k [00:00<00:00, 50.8MB/s]
Downloading data: 100%|██████████| 55.3M/55.3M [00:02<00:00, 25.4MB/s]
Generating train split: 100%|██████████| 182822/182822 [00:05<00:00, 35740.06 examples/s]
Generating test split: 100%|██████████| 6150/6150 [00:00<00:00, 37087.74 examples/s]
Generating validation split: 100%|██████████| 4183/4183 [00:00<00:00, 35396.36 examples/s]


## Evaluation

In [3]:
validation

Dataset({
    features: ['id', 'question', 'opa', 'opb', 'opc', 'opd', 'cop', 'choice_type', 'exp', 'subject_name', 'topic_name'],
    num_rows: 4183
})
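The `cop` field holds the correct option, which is needed to score predictions. A minimal sketch, assuming `cop` is a 0-based index into the option columns `opa` to `opd` (0 → A, ..., 3 → D); the helper `correct_option` and the example row are hypothetical and not taken from the dataset:

```python
# Hypothetical helper: map MedMCQA's `cop` field to the correct option.
# Assumption: `cop` is a 0-based index into the four option columns.
OPTION_KEYS = ("opa", "opb", "opc", "opd")
OPTION_LETTERS = ("A", "B", "C", "D")


def correct_option(row: dict) -> tuple:
    """Return (letter, option text) of the correct answer for one row."""
    idx = row["cop"]
    if not 0 <= idx < len(OPTION_KEYS):
        # Guard against rows without a usable label (e.g. a missing answer index).
        raise ValueError(f"row has no valid answer index: cop={idx!r}")
    return OPTION_LETTERS[idx], row[OPTION_KEYS[idx]]


# Usage with a made-up row:
example = {"cop": 2, "opa": "w", "opb": "x", "opc": "y", "opd": "z"}
print(correct_option(example))  # → ('C', 'y')
```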

In [4]:
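def build_qa_query(row):
    # The subject is stored under 'subject_name' (see the features above);
    # row['subject'] does not exist and would raise a KeyError.
    # Lines are concatenated explicitly so no source indentation leaks into the prompt.
    return (
        f"Please answer the following {row['choice_type']}-choice question to the best of your "
        f"knowledge by returning only the correct option. The subject is {row['subject_name']}.\n"
        "\n"
        f"Question: {row['question']}\n"
        "Options:\n"
        f"A) {row['opa']}\n"
        f"B) {row['opb']}\n"
        f"C) {row['opc']}\n"
        f"D) {row['opd']}"
    )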

In [None]:
chat.predict(
    build_qa_query(validation[0])
)
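To turn single predictions into an accuracy score, the model's free-text reply has to be mapped to one of the four letters and compared with `cop`. A minimal sketch, assuming `chat.predict` returns the reply as a string and that the reply names the chosen option as a standalone capital letter (e.g. "B" or "B) Aspirin"); `extract_option`, `evaluate`, and the stub predictor below are hypothetical, and the letter matching is a simple heuristic, not a robust parser:

```python
import re
from typing import Callable, Iterable, Optional

LETTERS = "ABCD"
# First standalone capital A-D, e.g. "B", "B)", "(C)", "The answer is D."
_OPTION_RE = re.compile(r"\b([A-D])\b")


def extract_option(reply: str) -> Optional[str]:
    """Return the first standalone option letter A-D in the reply, or None."""
    match = _OPTION_RE.search(reply)
    return match.group(1) if match else None


def evaluate(rows: Iterable[dict],
             predict: Callable[[str], str],
             build_prompt: Callable[[dict], str]) -> float:
    """Fraction of rows whose extracted letter matches the `cop` label."""
    correct = total = 0
    for row in rows:
        predicted = extract_option(predict(build_prompt(row)))
        total += 1
        # Assumes `cop` is a 0-based index (0 -> A, ..., 3 -> D);
        # replies without a recognizable letter count as wrong.
        if predicted is not None and predicted == LETTERS[row["cop"]]:
            correct += 1
    return correct / total if total else 0.0


# Usage with a stub predictor that always answers "B":
rows = [{"question": "q1", "cop": 1}, {"question": "q2", "cop": 3}]


def fake_predict(prompt: str) -> str:
    return "B) something"


def fake_prompt(row: dict) -> str:
    return row["question"]


print(evaluate(rows, fake_predict, fake_prompt))  # → 0.5
```

Against the deployed endpoint, the same loop could be run on a subset, e.g. `evaluate(validation.select(range(100)), chat.predict, build_qa_query)`, again assuming `chat.predict` returns a string.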