In [None]:
%%capture
# Install Unsloth from GitHub with its Colab extras; %%capture hides the long pip output.
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
# --no-deps installs these packages without resolving their dependencies (presumably to
# leave the runtime's preinstalled torch untouched — confirm). TODO: pin versions so a
# rerun installs the same library versions that produced the saved results.
!pip install --no-deps xformers trl peft accelerate bitsandbytes

In [None]:
import torch
from datasets import load_dataset, ReadInstruction
from huggingface_hub import notebook_login
from transformers import TrainingArguments
from trl import SFTTrainer
from unsloth import FastLanguageModel
from transformers import AutoTokenizer

# Load Llama-3-8B-Instruct, pre-quantized to 4-bit.
# dtype=None lets Unsloth auto-detect the compute dtype (bfloat16 where the GPU supports
# it, float16 otherwise). This matches the fp16/bf16 switch in the TrainingArguments of
# the training cell. A hard-coded float16 conflicts with bf16=True on bf16-capable GPUs.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit",
    max_seq_length = 2048,
    dtype = None,
    load_in_4bit = True,
)

# Wrap the model with LoRA adapters on every attention and MLP projection.
model = FastLanguageModel.get_peft_model(
    model,
    r = 64,
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 128,  # alpha / r = 2
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

# Left padding for generation: decoder-only models continue from the right edge of the input.
tokenizer.padding_side = "left"

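# Evaluation before fine-tuning

Baseline: exact-match accuracy of the un-tuned instruct model on the Anatomy questions of the MedMCQA validation split.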

In [None]:
# Held-out set: the Anatomy questions of the MedMCQA validation split.
dataset_validation = (
    load_dataset("openlifescienceai/medmcqa", split="validation")
    .filter(lambda row: row["subject_name"] == "Anatomy")
)
# Switch the Unsloth model into inference mode before generate() is called.
FastLanguageModel.for_inference(model)

In [None]:
def predict(prompt, max_new_tokens = 256):
    """Generate the model's reply to one fully formatted chat prompt.

    Args:
        prompt: complete prompt text, including the leading ``<|begin_of_text|>`` and
            the trailing assistant header.
        max_new_tokens: generation budget for the reply (default 256).

    Returns:
        The generated continuation only (prompt tokens removed, special tokens skipped),
        with surrounding whitespace stripped.
    """
    # The prompt already starts with <|begin_of_text|>. Don't let a tokenizer that
    # auto-adds special tokens prepend a second BOS.
    inputs = tokenizer([prompt], return_tensors = "pt", add_special_tokens = False).to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens = max_new_tokens,
        use_cache = True,
        # Greedy decoding: the evaluation must not depend on any sampling defaults
        # shipped in the model's generation config.
        do_sample = False,
    )
    prompt_len = inputs.input_ids.shape[1]
    return tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens = True).strip()

def _mcq_options(example):
    """Return the four answer options ("opa".."opd") of a MedMCQA row, whitespace-stripped.

    predict() strips its output, so the options are stripped too. Otherwise an option with
    stray leading/trailing whitespace could never be matched exactly.
    """
    return [example[key].strip() for key in ("opa", "opb", "opc", "opd")]


def _mcq_prompt(question, options):
    """Build the Llama-3 chat prompt that asks for one of ``options`` verbatim.

    The prompt ends with the assistant header plus the blank line that the user turn
    also uses before its content, so the reply starts directly with the answer text.
    """
    option_lines = "\n".join(options)
    return (
        "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
        f"{question}\n\n"
        f"{option_lines}\n\n"
        "Respond with the correct choice from the list above verbatim.  Do not include any explanation."
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )


def _mcq_retry_prompt(prompt, rejected, options):
    """Extend ``prompt`` with the model's off-list reply and a corrective user turn."""
    option_block = "\n\n".join(options)
    return (
        f"{prompt}{rejected}<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n"
        "Your response does not exactly match one of the choices from the list. "
        "Do not apologise or include any text other than one of the options from the list "
        "verbatim without any label. Here are the options again\n\n"
        f"{option_block}\n\n"
        "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
    )


def evaluate(dataset = None):
    """Score the model on MedMCQA rows by exact match against the correct option text.

    Each question is asked once. If the reply is not one of the four options verbatim,
    the model gets exactly one corrective follow-up turn, and that second reply is final.

    Args:
        dataset: iterable of MedMCQA rows with keys "question", "opa".."opd" and "cop"
            (index of the correct option). Defaults to the global ``dataset_validation``.

    Returns:
        ``(exact_match, mismatch)``:
            - exact_match: the fraction of rows whose final prediction equals the correct option.
            - mismatch: the fraction whose final prediction matches none of the options.
        Both are 0.0 for an empty dataset.
    """
    if dataset is None:
        dataset = dataset_validation

    sample_cnt = 0
    match_cnt = 0
    mismatch_cnt = 0
    for example in dataset:
        options = _mcq_options(example)
        correct_option = options[example["cop"]]
        prompt = _mcq_prompt(example["question"], options)

        prediction = predict(prompt)
        if prediction not in options:
            prediction = predict(_mcq_retry_prompt(prompt, prediction, options))

        sample_cnt += 1
        match_cnt += prediction == correct_option
        mismatch_cnt += prediction not in options

    if sample_cnt == 0:
        return 0.0, 0.0
    return match_cnt / sample_cnt, mismatch_cnt / sample_cnt

# Baseline: score the model on the validation set before any fine-tuning.
exact_match, mismatch = evaluate()

In [None]:
# Baseline scores (before fine-tuning).
print(f"exact_match score: {exact_match}")
print(f"mismatch: {mismatch}")

exact_match score: 0.5854700854700855
mismatch: 0.01282051282051282


# Training

In [None]:
# Switch the model back to training mode (it was put in inference mode for the baseline).
FastLanguageModel.for_training(model)
dataset_train = (
    load_dataset("openlifescienceai/medmcqa", split="train")
    .filter(lambda row: row["subject_name"] == "Anatomy")
    .select(range(3800))  # first 3800 Anatomy questions
)
dataset_train

In [None]:
from datasets import ReadInstruction
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from transformers import AutoModelForCausalLM, AutoTokenizer

def formatting_prompts_func(examples):
    """Render a batch of MedMCQA rows as Llama-3 chat transcripts for supervised fine-tuning.

    Intended for ``Dataset.map(..., batched=True)``: ``examples`` maps each column name to
    the list of that column's values in the batch.

    The user turn uses the same wording as the evaluation prompt: the question, the four
    options one per line, and the instruction to answer verbatim. The assistant turn is
    the text of the correct option, selected by the ``cop`` index. That is the string the
    evaluation compares predictions against. (The free-text explanation in ``exp`` is
    not a valid answer under that instruction.)

    Returns:
        ``{"text": [...]}`` with one formatted transcript per row.
    """
    # NOTE(review): the text starts with <|begin_of_text|>. If SFTTrainer's tokenization
    # also adds a BOS token, sequences begin with two. Confirm against the tokenizer.
    texts = []
    for question, option_a, option_b, option_c, option_d, cop in zip(
        examples["question"], examples["opa"], examples["opb"],
        examples["opc"], examples["opd"], examples["cop"],
    ):
        # Strip options so the target matches the whitespace-stripped output of predict().
        options = [option_a.strip(), option_b.strip(), option_c.strip(), option_d.strip()]
        answer = options[cop]
        option_lines = "\n".join(options)
        text = (
            "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n"
            f"{question}\n\n"
            f"{option_lines}\n\n"
            "Respond with the correct choice from the list above verbatim.  Do not include any explanation."
            "<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
            f"{answer}<|eot_id|>"
        )
        texts.append(text)

    return {"text": texts}

# Add the formatted chat transcript of every row as a "text" column.
dataset_train = dataset_train.map(formatting_prompts_func, batched = True)

# TRL's SFTTrainer warns that non-right padding can cause overflow issues when training
# in half precision. predict() tokenizes a single prompt at a time (no padding), so this
# does not affect evaluation.
tokenizer.padding_side = "right"

# Mask the prompt out of the loss so only the assistant's reply is learned.
# DataCollatorForCompletionOnlyLM only works with packing=False.
completion_collator = DataCollatorForCompletionOnlyLM(
    response_template = "<|start_header_id|>assistant<|end_header_id|>",
    tokenizer = tokenizer,
)

trainer = SFTTrainer(
    model = model,
    # Pass the tokenizer returned by FastLanguageModel explicitly. Without it,
    # SFTTrainer loads a tokenizer on its own from the model's name.
    tokenizer = tokenizer,
    train_dataset = dataset_train,
    # The "text" column already holds the formatted transcripts, so no formatting_func
    # is passed. (TRL expects a formatting_func to return a list of strings, whereas
    # formatting_prompts_func returns a dict.)
    dataset_text_field = "text",
    data_collator = completion_collator,
    max_seq_length = 2048,
    dataset_num_proc = 2,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 32,  # 2 x 32 = 64 examples per optimizer step per device
        warmup_steps = 5,
        num_train_epochs = 1,  # max_steps left at its default, so the epoch count governs
        learning_rate = 4e-5,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 42,
        output_dir = "outputs",
    ),
)

In [None]:
# Run LoRA fine-tuning with the trainer configured above; keep the returned stats object for inspection.
trainer_stats = trainer.train()

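# Evaluation after fine-tuning

Note: the saved scores below are identical to the pre-training baseline. This suggests the fine-tuning had no measurable effect in this run, so investigate before drawing conclusions.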

In [None]:
# Switch Unsloth back to inference mode (training mode was enabled before fine-tuning).
FastLanguageModel.for_inference(model)

In [None]:
# Re-score on the same validation set after fine-tuning.
# Note: this overwrites the baseline exact_match/mismatch values computed earlier.
exact_match, mismatch = evaluate()

In [None]:
# Scores after fine-tuning (same metrics as the baseline).
for label, value in (("exact_match:", exact_match), ("mismatch:", mismatch)):
    print(label, value)

exact_match: 0.5854700854700855
mismatch: 0.01282051282051282
