# imports and setup

In [None]:
%%capture
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
!pip install "xformers<0.0.26" trl peft accelerate bitsandbytes pinecone-client sentence-transformers langchain-openai openai

In [None]:
import os

import torch
from datasets import load_dataset
from huggingface_hub import notebook_login
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel
from transformers import AutoTokenizer
import numpy as np

In [None]:
# Load the 4-bit checkpoint "unsloth/llama-3-8b-Instruct-bnb-4bit" and its tokenizer
# through Unsloth, with a 2048-token maximum sequence length and float16 dtype.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit",
    max_seq_length = 2048,
    dtype = torch.float16,
    load_in_4bit = True,
)

# Wrap the model with LoRA adapters (r=256, alpha=512, no dropout, no bias terms)
# on the listed projection modules.
# NOTE(review): no training step is visible in this notebook, so these adapters look
# unused by the evaluations below — confirm whether this call is still needed.
model = FastLanguageModel.get_peft_model(
    model,
    r = 256,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 512,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

# Pad on the left side.
# NOTE(review): predict() tokenizes one prompt at a time, so no padding seems to
# happen in the visible code — presumably kept for batched generation; confirm.
tokenizer.padding_side = "left"

In [None]:
from pinecone import Pinecone, ServerlessSpec

# Prefer the PINECONE_API_KEY environment variable so the secret never has to be
# pasted into (and saved with) the notebook. The original placeholder is kept as
# the fallback, so behaviour is unchanged when the variable is not set.
pc = Pinecone(api_key=os.environ.get("PINECONE_API_KEY", 'INSERT_PINECONE_API_KEY_HERE'))

In [None]:
# Handle to the Pinecone index "medmcqa-train-1"; get_relevant_context() queries it.
# NOTE(review): the name suggests it was built from the MedMCQA training split — confirm.
index = pc.Index("medmcqa-train-1")

In [None]:
from langchain_openai import OpenAIEmbeddings

# Embedding client used to vectorise retrieval queries. The key is taken from the
# OPENAI_API_KEY environment variable when set, so it does not need to be stored in
# the notebook; the original placeholder remains the fallback.
embed_model = OpenAIEmbeddings(
    model="text-embedding-ada-002",
    openai_api_key=os.environ.get("OPENAI_API_KEY", "INSERT_OPEN_API_KEY_HERE"),
)

def get_relevant_context(text, k=1):
  """Return the stored text of the `k` index entries closest to `text`.

  `text` is embedded with `embed_model`, the resulting vector is used to query
  `index`, and the `metadata["text"]` field of every returned match is
  collected in the order the matches are returned.

  Args:
    text: the query string to embed.
    k: number of matches to request from the index (default 1).

  Returns:
    A list of strings, one per match; empty when the query returns no matches.
  """
  query_vector = embed_model.embed_documents([text])[0]
  response = index.query(
      vector=query_vector,
      top_k=k,
      include_metadata=True
  )
  return [match["metadata"]["text"] for match in response["matches"]]

# load datasets

In [None]:
# MedMCQA validation split: the first 234 rows ("mixed") and, separately, every row
# whose subject_name is "Anatomy" (filtered over the whole split, not the 234-row slice).
medmcqa = load_dataset("openlifescienceai/medmcqa", split="validation")
medmcqa_mixed = medmcqa.select(range(234))
medmcqa_anatomy = medmcqa.filter(lambda example: example["subject_name"] == "Anatomy")

# MedQA (USMLE, 4 options): first 234 rows of the test split.
medqa = load_dataset("GBaker/MedQA-USMLE-4-options", split="test").select(range(234))

# PubMedQA: first 234 rows of the validation split.
# NOTE(review): no config name is passed, so the dataset's default config is used —
# confirm it is the intended one (eval_pubmedqa only accepts "yes"/"no" answers).
pubmedqa = load_dataset("bigbio/pubmed_qa", split="validation").select(range(234))

# MMLU anatomy: the full test split (no row limit applied).
mmlu_anatomy = load_dataset("openlifescienceai/mmlu_anatomy", split="test")

In [None]:
# Show the summary of every evaluation dataset, one per line.
print(medmcqa_mixed, medmcqa_anatomy, medqa, pubmedqa, mmlu_anatomy, sep="\n")

In [None]:
# Switch the model to inference mode through Unsloth before any predict() call.
FastLanguageModel.for_inference(model)

In [None]:
def predict(prompt):
    """Generate the model's completion for a single, fully formatted prompt.

    The prompt is tokenized and moved to "cuda", up to 256 new tokens are
    generated, and only the tokens after the prompt are decoded (special tokens
    skipped) and stripped of surrounding whitespace.

    Args:
        prompt: the complete prompt string, including any chat-template markers.

    Returns:
        The decoded completion text without the prompt.
    """
    # NOTE(review): callers in this notebook build prompts that already start with
    # "<|begin_of_text|>" — check whether the tokenizer prepends a BOS token as well.
    encoded = tokenizer([prompt], return_tensors="pt").to("cuda")
    prompt_len = encoded.input_ids.shape[1]

    # NOTE(review): no sampling settings are passed, so decoding follows the model's
    # generation config — confirm it is deterministic if scores must be reproducible.
    generated = model.generate(**encoded, max_new_tokens=256, use_cache=True)

    # Slice off the echoed prompt tokens so only the newly generated part is decoded.
    completions = tokenizer.batch_decode(generated[:, prompt_len:], skip_special_tokens=True)
    return completions[0].strip()

# medmcqa mixed

In [None]:
def eval_medmcqa_mixed():
  """Evaluate the model on `medmcqa_mixed` with retrieved context.

  For every example the question and its four options are sent to `predict`,
  together with the first passage from `get_relevant_context` when one is
  returned. If the answer is not one of the options verbatim, the model is
  re-prompted once with the options restated.

  Returns:
    (em, mismatch): the fraction of examples whose final prediction equals the
    reference option text, and the fraction whose final prediction is still not
    one of the four options after the retry.

  Raises:
    ZeroDivisionError: if `medmcqa_mixed` is empty.
  """
  references = []
  predictions = []
  SAMPLE_CNT = len(medmcqa_mixed)
  mismatch_cnt = 0

  for i in range(SAMPLE_CNT):
    example = medmcqa_mixed[i]
    options = [example['opa'], example['opb'], example['opc'], example['opd']]

    # Question plus one option per line; shared by the prompt and the retrieval query.
    question_block = example["question"] + "\n\n" + "\n".join(options)

    prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + question_block + "\n\nRespond with the correct choice from the list above verbatim. Do not include any explanation."

    # Retrieval may return no matches; in that case ask without the extra passage
    # instead of failing on context[0].
    context = get_relevant_context(question_block)
    if context:
      prompt += " You may use the following information only if it is helpful: \n" + context[0]

    prompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    prediction = predict(prompt)

    # One retry: feed the invalid answer back and restate the options.
    if prediction not in options:
      prompt += prediction + "<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + "\n\n".join(options) + "\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
      prediction = predict(prompt)

    if prediction not in options:
      mismatch_cnt += 1

    # 'cop' is used as a position into the options list.
    references.append(options[example['cop']])
    predictions.append(prediction)

  em = sum(prediction == reference for prediction, reference in zip(predictions, references)) / SAMPLE_CNT
  mismatch = mismatch_cnt / SAMPLE_CNT

  return em, mismatch

# Run the mixed-subject MedMCQA evaluation. `em` and `mismatch` are rebound by each
# later evaluation cell, so print them before moving on.
em, mismatch = eval_medmcqa_mixed()

In [None]:
# Report the scores of the most recent evaluation run.
print(f"exact match: {em}")
print(f"mismatch: {mismatch}")

exact match: 0.5854700854700855
mismatch: 0.01282051282051282


# medmcqa anatomy

In [None]:
def eval_medmcqa_anatomy():
  """Evaluate the model on `medmcqa_anatomy` with retrieved context.

  For every example the question and its four options are sent to `predict`,
  together with the first passage from `get_relevant_context` when one is
  returned. If the answer is not one of the options verbatim, the model is
  re-prompted once with the options restated.

  Returns:
    (em, mismatch): the fraction of examples whose final prediction equals the
    reference option text, and the fraction whose final prediction is still not
    one of the four options after the retry.

  Raises:
    ZeroDivisionError: if `medmcqa_anatomy` is empty.
  """
  references = []
  predictions = []
  SAMPLE_CNT = len(medmcqa_anatomy)
  mismatch_cnt = 0

  for i in range(SAMPLE_CNT):
    example = medmcqa_anatomy[i]
    options = [example['opa'], example['opb'], example['opc'], example['opd']]

    # Question plus one option per line; shared by the prompt and the retrieval query.
    question_block = example["question"] + "\n\n" + "\n".join(options)

    prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + question_block + "\n\nRespond with the correct choice from the list above verbatim. Do not include any explanation."

    # Retrieval may return no matches; in that case ask without the extra passage
    # instead of failing on context[0].
    context = get_relevant_context(question_block)
    if context:
      prompt += " You may use the following information only if it is helpful: \n" + context[0]

    prompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    prediction = predict(prompt)

    # One retry: feed the invalid answer back and restate the options.
    if prediction not in options:
      prompt += prediction + "<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + "\n\n".join(options) + "\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
      prediction = predict(prompt)

    if prediction not in options:
      mismatch_cnt += 1

    # 'cop' is used as a position into the options list.
    references.append(options[example['cop']])
    predictions.append(prediction)

  em = sum(prediction == reference for prediction, reference in zip(predictions, references)) / SAMPLE_CNT
  mismatch = mismatch_cnt / SAMPLE_CNT

  return em, mismatch

# Run the MedMCQA anatomy evaluation. This rebinds `em` and `mismatch`, replacing
# the values from the previous evaluation cell.
em, mismatch = eval_medmcqa_anatomy()

In [None]:
# Report the scores of the most recent evaluation run.
print(f"exact match: {em}")
print(f"mismatch: {mismatch}")

exact match: 0.6367521367521367
mismatch: 0.021367521367521368


# medqa

In [None]:
def eval_medqa():
  """Evaluate the model on `medqa` with retrieved context.

  For every example the question and its options A-D are sent to `predict`,
  together with the first passage from `get_relevant_context` when one is
  returned. If the answer is not one of the options verbatim, the model is
  re-prompted once with the options restated.

  Returns:
    (em, mismatch): the fraction of examples whose final prediction equals the
    reference option text, and the fraction whose final prediction is still not
    one of the four options after the retry.

  Raises:
    ZeroDivisionError: if `medqa` is empty.
  """
  references = []
  predictions = []
  SAMPLE_CNT = len(medqa)
  mismatch_cnt = 0

  for i in range(SAMPLE_CNT):
    example = medqa[i]
    options = [example['options']['A'], example['options']['B'], example['options']['C'], example['options']['D']]

    # Question plus one option per line; shared by the prompt and the retrieval query.
    question_block = example["question"] + "\n\n" + "\n".join(options)

    prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + question_block + "\n\nRespond with the correct choice from the list above verbatim. Do not include any explanation."

    # Retrieval may return no matches; in that case ask without the extra passage
    # instead of failing on context[0].
    context = get_relevant_context(question_block)
    if context:
      prompt += " You may use the following information only if it is helpful: \n" + context[0]

    prompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    prediction = predict(prompt)

    # One retry: feed the invalid answer back and restate the options.
    if prediction not in options:
      prompt += prediction + "<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + "\n\n".join(options) + "\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
      prediction = predict(prompt)

    if prediction not in options:
      mismatch_cnt += 1

    # 'answer_idx' is used as the key of the correct entry in 'options'.
    references.append(example['options'][example['answer_idx']])
    predictions.append(prediction)

  em = sum(prediction == reference for prediction, reference in zip(predictions, references)) / SAMPLE_CNT
  mismatch = mismatch_cnt / SAMPLE_CNT

  return em, mismatch

# Run the MedQA evaluation. This rebinds `em` and `mismatch`, replacing the values
# from the previous evaluation cell.
em, mismatch = eval_medqa()

In [None]:
# Report the scores of the most recent evaluation run.
print(f"exact match: {em}")
print(f"mismatch: {mismatch}")

exact match: 0.5811965811965812
mismatch: 0.01282051282051282


# pubmedqa

In [None]:
def eval_pubmedqa():
  """Evaluate the model on `pubmedqa` yes/no questions with retrieved context.

  Each question is sent to `predict`, together with the first passage from
  `get_relevant_context` when one is returned. If the answer is not exactly
  "yes" or "no", the model is re-prompted once.

  Returns:
    (em, mismatch): the fraction of examples whose final prediction equals
    `final_decision`, and the fraction whose final prediction is still neither
    "yes" nor "no" after the retry.

  Raises:
    ZeroDivisionError: if `pubmedqa` is empty.
  """
  # NOTE(review): only "yes"/"no" are accepted; if the loaded config can contain other
  # labels in 'final_decision', those examples can never be matched — confirm.
  valid_answers = ("yes", "no")
  references = []
  predictions = []
  SAMPLE_CNT = len(pubmedqa)
  mismatch_cnt = 0

  for i in range(SAMPLE_CNT):
    example = pubmedqa[i]

    prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + example["QUESTION"] + "\n\nRespond with only lower case yes or lowercase no."

    # Retrieval may return no matches; in that case ask without the extra passage
    # instead of failing on context[0].
    context = get_relevant_context(example["QUESTION"] + "\n\n")
    if context:
      prompt += " You may use the following information only if it is helpful: \n" + context[0]

    prompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    prediction = predict(prompt)

    # One retry: feed the invalid answer back and ask again for yes or no.
    if prediction not in valid_answers:
      prompt += prediction + "<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nYour response does not exactly match yes or no. Do not apologise, simply respond with yes or no\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
      prediction = predict(prompt)

    if prediction not in valid_answers:
      mismatch_cnt += 1

    references.append(example['final_decision'])
    predictions.append(prediction)

  # Metrics are computed once after the loop rather than on every iteration.
  em = sum(prediction == reference for prediction, reference in zip(predictions, references)) / SAMPLE_CNT
  mismatch = mismatch_cnt / SAMPLE_CNT

  return em, mismatch

# Run the PubMedQA evaluation. This rebinds `em` and `mismatch`, replacing the values
# from the previous evaluation cell.
em, mismatch = eval_pubmedqa()

In [None]:
# Report the scores of the most recent evaluation run.
print(f"exact match: {em}")
print(f"mismatch: {mismatch}")

# mmlu anatomy

In [None]:
def eval_mmlu_anatomy():
  """Evaluate the model on `mmlu_anatomy` with retrieved context.

  For every example (read from the row's "data" field) the question and its
  options A-D are sent to `predict`, together with the first passage from
  `get_relevant_context` when one is returned. If the answer is not one of the
  options verbatim, the model is re-prompted once with the options restated.

  Returns:
    (em, mismatch): the fraction of examples whose final prediction equals the
    row's "Correct Answer" value, and the fraction whose final prediction is
    still not one of the four options after the retry.

  Raises:
    ZeroDivisionError: if `mmlu_anatomy` is empty.
  """
  references = []
  predictions = []
  SAMPLE_CNT = len(mmlu_anatomy)
  mismatch_cnt = 0

  for i in range(SAMPLE_CNT):
    example = mmlu_anatomy[i]["data"]
    options = [example["Options"]["A"], example["Options"]["B"], example["Options"]["C"], example["Options"]["D"]]

    # Question plus one option per line; shared by the prompt and the retrieval query.
    question_block = example["Question"] + "\n\n" + "\n".join(options)

    prompt = "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + question_block + "\n\nRespond with the correct choice from the list above verbatim. Do not include any explanation."

    # Retrieval may return no matches; in that case ask without the extra passage
    # instead of failing on context[0].
    context = get_relevant_context(question_block)
    if context:
      prompt += " You may use the following information only if it is helpful: \n" + context[0]

    prompt += "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"

    prediction = predict(prompt)

    # One retry: feed the invalid answer back and restate the options.
    if prediction not in options:
      prompt += prediction + "<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\nYour response does not exactly match one of the choices from the list. Do not apologise or include any text other than one of the options from the list verbatim without any label. Here are the options again\n\n" + "\n\n".join(options) + "\n\n<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
      prediction = predict(prompt)

    if prediction not in options:
      mismatch_cnt += 1

    # NOTE(review): assumes "Correct Answer" holds the option text rather than its
    # letter, since it is compared against the predicted option text — confirm.
    references.append(example["Correct Answer"])
    predictions.append(prediction)

  # Metrics are computed once after the loop rather than on every iteration.
  em = sum(prediction == reference for prediction, reference in zip(predictions, references)) / SAMPLE_CNT
  mismatch = mismatch_cnt / SAMPLE_CNT

  return em, mismatch

# Run the MMLU anatomy evaluation. This rebinds `em` and `mismatch`, replacing the
# values from the previous evaluation cell.
em, mismatch = eval_mmlu_anatomy()

In [None]:
# Report the scores of the most recent evaluation run.
print(f"exact match: {em}")
print(f"mismatch {mismatch}")

exact match: 0.5777777777777777
mismatch 0.05185185185185185
