## Constants

In [1]:
TRAIN_SIZE = 300  # number of examples sampled into the train split
TEST_SIZE = 30  # number of examples sampled into the test split
SEED = 123  # seed passed to train_test_split so the sampling is repeatable

## Dataset

In [2]:
# pip install datasets
from datasets import load_dataset

# Load the MedQA "medical meadow" dataset. As the cell output shows, it ships a
# single "train" split with 'input', 'instruction' and 'output' columns.
dataset = load_dataset("medalpaca/medical_meadow_medqa")

Downloading readme:   0%|          | 0.00/1.77k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/10.7M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10178 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 10178
    })
})

### Preprocessing

We're not going to focus much on the preprocessing in this tutorial, so feel free to skim over or skip this subsection.

In [3]:
# Sample train and test sets
display(dataset)

# Draw TRAIN_SIZE / TEST_SIZE shuffled examples from the original "train" split,
# using a fixed seed.
# NOTE(review): this rebinds `dataset`, so the cell is not idempotent -- re-running
# it without reloading would operate on the sampled subset instead of the full data.
dataset = dataset["train"].train_test_split(
    train_size=TRAIN_SIZE,
    test_size=TEST_SIZE,
    shuffle=True,
    seed=SEED
)
display(dataset)

DatasetDict({
    train: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 300
    })
    test: Dataset({
        features: ['input', 'instruction', 'output'],
        num_rows: 30
    })
})

In [4]:
# original format of the dataset
def print_sample(sample: dict[str, str]):
    """Pretty-print a sample as one "# Field" section per key.

    Each key becomes a capitalized heading followed by its value on the next
    line; consecutive sections are separated by a blank line.

    Args:
        sample (dict[str, str]): mapping of field name to field text.
    """
    sections = [f"# {field.capitalize()}\n{text}" for field, text in sample.items()]
    print("\n\n".join(sections))

# Print the first training sample and bind it to `example`.
print_sample(example := dataset["train"][0])

# Input
Q:A 5-year-old boy is brought to the emergency department after he fell on the playground in kindergarten and was unable to get up. His right leg was found to be bent abnormally at the femur, and he was splinted on site by first responders. His past medical history is significant for multiple prior fractures in his left humerus and femur. Otherwise, he has been hitting normal developmental milestones and appears to be excelling in kindergarten. Physical exam also reveals the finding shown in figure A. Which of the following is the most likely cause of this patient's multiple fractures?? 
{'A': 'Abnormal collagen production', 'B': 'Decreased collagen hydroxylation', 'C': 'Increased adenylyl cyclase activity', 'D': 'Mutation in neurofibromin', 'E': 'Non-accidental trauma'},

# Instruction
Please answer with one of the option in the bracket

# Output
C: Increased adenylyl cyclase activity


In [5]:
# reformat the dataset
import json

def reformat_sample(sample: dict[str, str]) -> dict[str, str]:
    """Convert a raw sample into the prompt/target pair used in this notebook.

    The question is normalized to start with "Q: ", the trailing comma after
    the choices dict is dropped, and the JSON-answer instruction is inserted
    right before the choices. The answer ("<option>: <text>") is re-encoded
    as a JSON object with "option" and "text" keys.

    Args:
        sample (dict[str, str]): raw sample with at least "input" and "output".

    Returns:
        dict[str, str]: the reformatted "input" and JSON-encoded "output".
    """
    # Named `question` rather than `input` to avoid shadowing the builtin.
    question = "Q: " + sample["input"].removeprefix("Q:").removesuffix(",")
    question = question.replace(
        "\n{",
        'Provide your answer as a JSON dictionary with the "option" and answer "text".\n{'
    )

    raw_answer = sample["output"]
    # Split on the first colon instead of slicing at fixed offsets, so answers
    # such as "B:text" or "B:  text" do not lose or keep stray characters.
    answer_option, separator, answer_text = raw_answer.partition(":")
    if separator:
        answer_option = answer_option.strip()
        answer_text = answer_text.strip()
    else:
        # No colon at all: keep the previous fixed-offset behaviour.
        answer_option = raw_answer[0]
        answer_text = raw_answer[3:]
    output = json.dumps({"option": answer_option, "text": answer_text})

    return {"input": question, "output": output}


# Reformat every sample in both splits and drop the "instruction" column.
# NOTE(review): rebinding `dataset` makes this cell non-idempotent; re-running it
# without re-creating the splits would act on already-reformatted data.
dataset = dataset.map(reformat_sample).remove_columns("instruction")

display(dataset)
print_sample(example := dataset["train"][0])

Map:   0%|          | 0/300 [00:00<?, ? examples/s]

Map:   0%|          | 0/30 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['input', 'output'],
        num_rows: 300
    })
    test: Dataset({
        features: ['input', 'output'],
        num_rows: 30
    })
})

# Input
Q: A 5-year-old boy is brought to the emergency department after he fell on the playground in kindergarten and was unable to get up. His right leg was found to be bent abnormally at the femur, and he was splinted on site by first responders. His past medical history is significant for multiple prior fractures in his left humerus and femur. Otherwise, he has been hitting normal developmental milestones and appears to be excelling in kindergarten. Physical exam also reveals the finding shown in figure A. Which of the following is the most likely cause of this patient's multiple fractures?? Provide your answer as a JSON dictionary with the "option" and answer "text".
{'A': 'Abnormal collagen production', 'B': 'Decreased collagen hydroxylation', 'C': 'Increased adenylyl cyclase activity', 'D': 'Mutation in neurofibromin', 'E': 'Non-accidental trauma'}

# Output
{"option": "C", "text": "Increased adenylyl cyclase activity"}


### Final finetuning dataset

## Load Gemma 2B instruct model

In [6]:
# pip install torch transformers
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Single source of truth for the checkpoint, so the tokenizer and the model
# cannot accidentally be loaded from different ids.
MODEL_NAME = "google/gemma-2b-it"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

if torch.cuda.is_available():
    # 8-bit quantization is only configured when a CUDA device is present.
    # pip install bitsandbytes accelerate
    from transformers import BitsAndBytesConfig

    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
else:
    quantization_config = None

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=quantization_config
)
print(model.device)

`low_cpu_mem_usage` was None, now set to True since model is quantized.
`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

cuda:0


In [34]:
# Wrap the example question in a single-turn chat and tokenize it with the
# model's chat template, ending with the generation prompt for the model turn.
chat = [
    {"role": "user", "content": example["input"]},
]
input_ids = tokenizer.apply_chat_template(chat, add_generation_prompt=True, return_tensors="pt").to(model.device)

# Greedy decoding. The previous `do_sample=True, temperature=1e-3` only
# approximated this while still drawing unseeded random samples.
output_ids = model.generate(
    input_ids,
    do_sample=False,
    max_new_tokens=512,
)
print(tokenizer.decode(output_ids[0]))

<bos><start_of_turn>user
Q: A 5-year-old boy is brought to the emergency department after he fell on the playground in kindergarten and was unable to get up. His right leg was found to be bent abnormally at the femur, and he was splinted on site by first responders. His past medical history is significant for multiple prior fractures in his left humerus and femur. Otherwise, he has been hitting normal developmental milestones and appears to be excelling in kindergarten. Physical exam also reveals the finding shown in figure A. Which of the following is the most likely cause of this patient's multiple fractures?? Provide your answer as a JSON dictionary with the "option" and answer "text".
{'A': 'Abnormal collagen production', 'B': 'Decreased collagen hydroxylation', 'C': 'Increased adenylyl cyclase activity', 'D': 'Mutation in neurofibromin', 'E': 'Non-accidental trauma'}<end_of_turn>
<start_of_turn>model
{
"option": "A",
"answer": "Abnormal collagen production"
}<eos>


In [None]:
def get_generated_texts(input_ids: torch.Tensor, output_ids: torch.Tensor, remove_eos: bool = True) -> list[str]:
    """Retrieve only the generated text based on input and output ids.

    Each output row is expected to start with its input row, so the first
    ``len(in_seq)`` tokens are dropped before decoding.

    Args:
        input_ids (torch.Tensor): batch of input ids
        output_ids (torch.Tensor): corresponding output ids
        remove_eos (bool, optional): whether to remove the final <eos> token,
            together with any pad tokens that follow it in a batched
            generation. Defaults to True.

    Returns:
        list[str]: decoded generated text for each sequence, in batch order.
    """
    pad_token_id = getattr(tokenizer, "pad_token_id", None)

    texts = []
    for in_seq, out_seq in zip(input_ids, output_ids):
        generated = out_seq[len(in_seq):]
        if remove_eos and pad_token_id is not None:
            # In a batch, rows that finish early are filled with pad tokens
            # after <eos>; drop them so the <eos> suffix check below can match.
            end = len(generated)
            while end > 0 and generated[end - 1] == pad_token_id:
                end -= 1
            generated = generated[:end]
        text = tokenizer.decode(generated)
        if remove_eos:
            text = text.removesuffix("<eos>")
        texts.append(text)
    return texts


In [21]:
# Batched generation check: two prompts of different lengths are padded to a
# common length. The encoded batch is moved to the model's device, matching
# the single-prompt cell.
inputs = tokenizer(["Match the pattern: one two three", "my name is bob, what is your name?"], return_tensors="pt", padding=True).to(model.device)
# Greedy decoding instead of sampling with a near-zero temperature.
model.generate(
    **inputs,
    do_sample=False,
    max_new_tokens=200,
)

tensor([[     0,      0,      0,      2,  10914,    573,   6883, 235292,    974,
           1378,   2149,   2785,   4105,   4442,   6861,   8469,   9981,   2797,
            109,    651,   3448,    603,    573,  40027, 235265,    109,    651,
          40027,    603,    476,   6883,    576,  11187,    674,    708,   1671,
            577,   5598, 235265,    714,  11187,    708,  20677,    575,    476,
           3724,   2184, 235269,    578,    984,    798,    614,   1671,    577,
           1736,   3907, 235265,      1],
        [     2,   1723,   1503,    603,  23191, 235269,   1212,    603,    861,
           1503, 235336,    109,   4521,  11270, 235341,   1165, 235303, 235256,
           4866,    577,   4664,    692, 235265,   2439, 235303, 235256,    861,
           1503, 235336,      1,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,   

## Evaluate before finetuning

### Evaluation utils

In [None]:
import re

# pip install rapidfuzz
from rapidfuzz import fuzz
from rapidfuzz.utils import default_process

# Raw-format sample: a question followed by the stringified dict of answer choices.
example_passage = """
Q:A 67-year-old man with a past medical history of poorly-controlled type 2 diabetes mellitus (T2DM) is brought to the emergency department for acute onset nausea and vomiting. According to the patient, he suddenly experienced vertigo and began vomiting 3 hours ago while watching TV. He reports hiking in New Hampshire with his wife 2 days ago. Past medical history is significant for a myocardial infarction (MI) that was treated with cardiac stenting, T2DM, and hypertension. Medications include lisinopril, aspirin, atorvastatin, warfarin, and insulin. Physical examination demonstrates left-sided facial droop and decreased pinprick sensation at the right arm and leg. What is the most likely etiology of this patient’s symptoms?? {'A': 'Early disseminated Lyme disease', 'B': 'Embolic stroke at the posterior inferior cerebellar artery (PICA)', 'C': 'Hypoperfusion of the anterior spinal artery (ASA)', 'D': 'Labryrinthitis', 'E': 'Thrombotic stroke at the anterior inferior cerebellar artery (AICA)'},
""".strip()

# Fragment of one choice's text, used to exercise the fuzzy matching.
example_substr = "stroke at the anterior inferior cerebellar artery"

def get_available_qa_choices(passage: str) -> dict[str, str]:
    """Extract the answer choices (option letter -> choice text) from a passage.

    The choices appear in the passage as a stringified Python dict. That dict
    is parsed as a literal first, so choice texts containing an apostrophe
    (which Python renders with double quotes, e.g. "Crohn's disease") are kept.
    If no well-formed dict of strings is found, the original single-quote
    regex is used as a fallback.

    Args:
        passage (str): question text followed by the dict of choices.

    Returns:
        dict[str, str]: mapping of option to choice text; empty if none found.
    """
    import ast  # local import: only needed here

    start, end = passage.rfind("{"), passage.rfind("}")
    if 0 <= start < end:
        try:
            parsed = ast.literal_eval(passage[start : end + 1])
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            parsed = None
        if (
            isinstance(parsed, dict)
            and parsed
            and all(isinstance(k, str) and isinstance(v, str) for k, v in parsed.items())
        ):
            # Like the regex below, ignore entries with an empty key or text.
            return {k: v for k, v in parsed.items() if k and v}

    # Fallback: pairs of single-quoted strings anywhere in the passage.
    return {
        match.group(1): match.group(2)
        for match in re.finditer("'([^']+)': '([^']+)'", passage)
    }

def get_best_match(answer_text: str, passage: str) -> str:
    """Return the option whose choice text best matches ``answer_text``.

    Choices are taken from ``passage`` and scored with a token-set fuzzy
    ratio; when several choices share the top score, the one listed last wins.

    Args:
        answer_text (str): free-form answer text to match.
        passage (str): passage containing the available choices.

    Returns:
        str: the best-matching option, or "" when the passage has no choices.
    """
    best_option = ""  # stays "" when no choices are found
    best_score = None
    for option, choice_text in get_available_qa_choices(passage).items():
        score = fuzz.token_set_ratio(choice_text, answer_text, processor=default_process)
        # `>=` so that a later choice replaces an earlier one on equal score.
        if best_score is None or score >= best_score:
            best_option, best_score = option, score
    return best_option

# Sanity check: show the parsed choices, then map a partial answer text to its option.
display(get_available_qa_choices(example_passage))
display(get_best_match(example_substr, example_passage))

{'A': 'Early disseminated Lyme disease',
 'B': 'Embolic stroke at the posterior inferior cerebellar artery (PICA)',
 'C': 'Hypoperfusion of the anterior spinal artery (ASA)',
 'D': 'Labryrinthitis',
 'E': 'Thrombotic stroke at the anterior inferior cerebellar artery (AICA)'}

'E'

In [None]:
# Example of raw model output: a JSON object followed by the <eos> marker.
example_pred_text = """
{
"option": "A",
"answer": "Abnormal collagen production"
}<eos>
""".strip()

def parse_prediction(pred_text: str, passage: str = "") -> str:
    """Parse the predicted answer based on the output text, with text matching as backup.

    Args:
        pred_text (str): text outputted from the language model.
        passage (str, optional): The input passage for text matching in case the LLM does
            not output parseable JSON. Useful for evaluating the model before finetuning.
            Defaults to "" (no passage); in this case the backup prediction will be "".

    Returns:
        str: option (a letter or "") predicted by the model.
    """
    json_text = pred_text
    first_brace = json_text.find("{")
    if first_brace != -1:  # remove anything before first {
        json_text = json_text[first_brace:]
    last_brace = json_text.rfind("}")
    if last_brace != -1:  # remove anything after last }
        json_text = json_text[: last_brace + 1]

    try:
        parsed = json.loads(json_text)
    except json.JSONDecodeError:
        parsed = None

    # Output without braces can still be valid JSON that is not an object
    # (e.g. "42" or "null"); indexing it used to raise an uncaught TypeError.
    if isinstance(parsed, dict):
        option = parsed.get("option")
        if isinstance(option, str):  # honour the declared `str` return type
            return option

    if passage:  # no usable JSON option, but a passage is supplied: get the best match
        return get_best_match(pred_text, passage)
    return ""

display(parse_prediction(example_pred_text))  # parseable JSON -> its "option"
display(parse_prediction(example_substr, example_passage))  # no JSON -> best fuzzy match in the passage
display(parse_prediction(example_substr))  # no JSON and no passage -> ""

'A'

'E'

''