### Import Libraries

In [None]:
import os

import bitsandbytes as bnb
import pandas as pd
import torch
from accelerate import Accelerator
from datasets import Dataset, load_dataset, load_from_disk
from peft import (
    PeftModel,
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling
)
from trl import setup_chat_format, SFTConfig, SFTTrainer

### Setting Environment Variables

In [2]:
# Process-wide settings, applied in one shot before any model is loaded.
os.environ.update(
    {
        # Expose only the first GPU to this process.
        "CUDA_VISIBLE_DEVICES": "0",
        # Disable tokenizers' internal parallelism.
        "TOKENIZERS_PARALLELISM": "false",
        # W&B project that receives this run's logs.
        "WANDB_PROJECT": "cs769_llama",
        # Turn off W&B watch so logging is faster.
        "WANDB_WATCH": "false",
    }
)

In [3]:
# SECURITY: never commit access tokens. The literal token that used to live in
# this cell must be treated as leaked and revoked in the Hugging Face settings.
# Export HF_TOKEN in the shell that launches Jupyter instead. When the variable
# is unset this is None, and `from_pretrained(token=None)` is expected to fall
# back to credentials cached by `huggingface-cli login` (see HF docs).
HF_TOKEN = os.environ.get("HF_TOKEN")

### Setting Constants

In [4]:
# Hub id of the validation dataset source.
dataset_name = "openlifescienceai/medmcqa"
# Hub id of the instruction-tuned base checkpoint that gets the LoRA adapter.
base_model_name = "meta-llama/Llama-3.2-3b-Instruct"
# Root directory for this run; each training stage writes to a sub-folder of it.
root_model_dir = "Llama-3.2-3b-it-Open-medmcqa-baseline-curriculum"

### Loading the model and tokenizer

- Configure 4-bit quantization (QLoRA) with BitsAndBytes, then load the base model and its tokenizer.

In [None]:
# 4-bit NF4 quantization with double quantization; compute runs in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# Quantized base model; `device_map="auto"` lets the loader place the weights.
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16,
)

# Pass the same token as the model load above: both calls hit the same repo,
# so the tokenizer download needs the same credentials.
tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    trust_remote_code=True,
    token=HF_TOKEN,
)
# Reuse EOS as the padding token.
tokenizer.pad_token = tokenizer.eos_token

### Loading the Dataset

In [6]:
# Per-stage training subsets saved on disk (presumably ordered by difficulty,
# going by the variable names). They are kept as Hugging Face Datasets because
# get_mapped_dataset() reads `.column_names` and iterates rows as dicts, which
# a pandas DataFrame does not support.
easy_data = load_from_disk('./json_to_hf/subset1')
medium_data = load_from_disk('./json_to_hf/subset2')
hard_data = load_from_disk('./json_to_hf/subset3')

# All three subsets stacked into one Dataset. The pandas frames are temporaries
# so the per-stage names above are never rebound to DataFrames.
concatinated_dataset = Dataset.from_pandas(
    pd.concat([subset.to_pandas() for subset in (easy_data, medium_data, hard_data)]),
    preserve_index=False,
)

# Validation split pulled from the hub dataset named in the constants cell.
val_data = load_dataset(dataset_name, split='validation', trust_remote_code=True)

### Formatting the Dataset

In [None]:
# Preview the first two rows of the easy subset.
pd.DataFrame(easy_data).head(2)

In [None]:
# Preview the first two rows of the validation split.
pd.DataFrame(val_data).head(2)

In [10]:
def format_chat_template(row):
    """Add a chat-formatted training string to ``row`` under the ``text`` key.

    The conversation has a fixed system instruction, a user turn holding the
    question with its four options, and an assistant turn holding the gold
    answer letter. It is rendered with the module-level ``tokenizer``.

    Args:
        row: mapping with ``question``, ``opa``..``opd`` and ``cop`` (0-3).

    Returns:
        The same ``row`` object, mutated to include ``text``.

    Raises:
        KeyError: if ``row['cop']`` is not one of 0, 1, 2, 3.
    """
    system_prompt = (
        "Answer the following multiple choice question by giving the most appropriate response. \n"
        "Answer should be one among [A, B, C, D]."
    )
    answer_letters = {0: "A", 1: "B", 2: "C", 3: "D"}

    # The whitespace is part of the prompt: each option line starts with 16
    # spaces and the message ends with a newline followed by 12 spaces.
    option_pad = " " * 16
    user_prompt = (
        f"Question: {row['question']}\n"
        f"{option_pad}A) {row['opa']}\n"
        f"{option_pad}B) {row['opb']}\n"
        f"{option_pad}C) {row['opc']}\n"
        f"{option_pad}D) {row['opd']}\n"
        + " " * 12
    )

    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": answer_letters[row['cop']]},
    ]

    # tokenize=False: keep the rendered template as a plain string.
    row["text"] = tokenizer.apply_chat_template(conversation, tokenize=False)
    return row


In [14]:
from datasets import Dataset

def get_mapped_dataset(easy_data, format_chat_template):
    """Apply ``format_chat_template`` to every row and return a new Dataset.

    Args:
        easy_data: the split to transform. Either a Hugging Face ``Dataset``
            (uses ``.column_names`` and row iteration) or a pandas
            ``DataFrame`` (uses ``.columns`` and per-row record dicts).
            Despite the name, any split can be passed.
        format_chat_template: callable taking one row mapping and returning a
            mapping that has every original column plus ``text``.

    Returns:
        A ``Dataset`` with the original columns and a ``text`` column.
    """
    # A DataFrame has no `.column_names` and iterating it yields column labels,
    # not rows, so normalise it to (column list, list of row dicts) first.
    if isinstance(easy_data, pd.DataFrame):
        column_names = list(easy_data.columns)
        rows = easy_data.to_dict(orient="records")
    else:
        column_names = list(easy_data.column_names)
        rows = easy_data

    columns = {col: [] for col in column_names}
    columns["text"] = []

    for row in rows:
        formatted = format_chat_template(row)
        for col, values in columns.items():
            values.append(formatted[col])

    return Dataset.from_dict(columns)


In [None]:
# The validation split is formatted through datasets' own `.map`.
val_dataset = val_data.map(format_chat_template)

# The three training subsets go through the hand-rolled mapper, in stage order.
easy_train_dataset, medium_train_dataset, hard_train_dataset = [
    get_mapped_dataset(subset, format_chat_template)
    for subset in (easy_data, medium_data, hard_data)
]

# Show one rendered training example.
easy_train_dataset['text'][0]

### Finding the linear module names of the Base Model to train LoRA

In [16]:
def find_all_linear_names(model):
    """Collect the leaf names of every 4-bit linear layer in ``model``.

    Walks ``model.named_modules()``, keeps modules that are instances of
    ``bnb.nn.Linear4bit``, and records the last dotted component of each
    qualified name. ``lm_head`` is dropped from the result if present.

    Args:
        model: module exposing ``named_modules()``.

    Returns:
        A list of unique leaf module names (order not guaranteed).
    """
    leaf_names = {
        qualified_name.split('.')[-1]
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, bnb.nn.Linear4bit)
    }
    # needed for 16 bit
    leaf_names.discard('lm_head')
    return list(leaf_names)

# Leaf names of the quantized linear layers; used below as the LoRA target modules.
modules = find_all_linear_names(base_model)

### Configuring LoRA

In [17]:
# LoRA adapter settings: what to adapt first, then how.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=modules,  # every 4-bit linear layer found above
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Wrap the quantized base model with the adapter.
# https://huggingface.co/docs/trl/en/sft_trainer
model = get_peft_model(base_model, peft_config)

In [None]:
print(type(model))
# print_trainable_parameters() prints its own summary, so it is not wrapped in
# print(); wrapping it would print the call's return value as an extra line.
model.print_trainable_parameters()

### Setting the Training Arguments

In [19]:
training_arguments = SFTConfig(
    # --- where results go (overridden per stage before each trainer is built)
    output_dir=root_model_dir,
    report_to='wandb',

    # --- data handling
    dataset_text_field='text',
    max_seq_length=512,
    group_by_length=True,
    remove_unused_columns=True,
    label_names=["labels"],

    # --- batching: 16 samples x 2 accumulation steps per optimizer update
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,

    # --- optimisation
    optim="paged_adamw_32bit",
    learning_rate=2e-4,
    warmup_ratio=0.03,
    num_train_epochs=3,
    fp16=False,
    bf16=False,

    # --- logging and evaluation cadence
    logging_strategy='steps',
    logging_steps=1,
    eval_strategy="steps",
    eval_steps=250,
    load_best_model_at_end=True,
)

### Trainer

In [None]:
# Row counts, one per line: easy, medium, hard, validation.
print(
    *map(len, (easy_train_dataset, medium_train_dataset, hard_train_dataset, val_dataset)),
    sep="\n",
)

In [None]:
# Stage 1: train on the easy subset, writing checkpoints under <root>/easy.
training_arguments.output_dir = os.path.join(root_model_dir, 'easy')

easy_trainer = SFTTrainer(
    args=training_arguments,
    model=model,
    peft_config=peft_config,
    processing_class=tokenizer,
    train_dataset=easy_train_dataset,
    eval_dataset=val_dataset,
)
easy_trainer.train()

# Save the stage's final model next to its checkpoints.
easy_trainer.save_model(os.path.join(root_model_dir, 'easy', 'best'))

In [None]:
# Stage 2: continue on the medium subset, writing checkpoints under <root>/medium.
training_arguments.output_dir = os.path.join(root_model_dir, 'medium')

# Start from the model the easy stage finished with. The previous cell names
# its trainer `easy_trainer`; there is no bare `trainer` in this notebook.
easy_model = easy_trainer.model

medium_trainer = SFTTrainer(
    model=easy_model,
    train_dataset=medium_train_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=training_arguments,
)
medium_trainer.train()

# Save the stage's final model next to its checkpoints.
medium_trainer.save_model(os.path.join(root_model_dir, 'medium', 'best'))

In [None]:
# Stage 3: continue on the hard subset, writing checkpoints under <root>/hard.
training_arguments.output_dir = os.path.join(root_model_dir, 'hard')

# Start from the model the medium stage finished with. The previous cell names
# its trainer `medium_trainer`; there is no bare `trainer` in this notebook.
medium_model = medium_trainer.model

hard_trainer = SFTTrainer(
    model=medium_model,
    train_dataset=hard_train_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=training_arguments,
)
hard_trainer.train()

# Save the stage's final model next to its checkpoints.
hard_trainer.save_model(os.path.join(root_model_dir, 'hard', 'best'))

### Model Inference

#### Load the Peft Model

In [None]:
# NOTE(review): this path sits under "...-baseline", not under `root_model_dir`
# ("...-baseline-curriculum") where the stages above save. Confirm which adapter
# is meant to be evaluated here.
checkpoint_path = "Llama-3.2-3b-it-Open-medmcqa-baseline/checkpoint-5500"

# Fresh quantized copy of the base model, loaded with the same settings and
# credentials as the training copy.
new_base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16,
)

# Load LoRA adapter
trained_model = PeftModel.from_pretrained(new_base_model, checkpoint_path)
trained_model.eval()


In [96]:
# formatting for inference, the format should not have the answer

def format_chat_prompt_for_inference(row):
    """Render the system + user turns for ``row`` as a generation prompt.

    No assistant turn is included, so the gold answer never appears in the
    prompt; ``add_generation_prompt=True`` is passed to the template instead.
    Rendering uses the module-level ``tokenizer``.

    Args:
        row: mapping with ``question`` and ``opa``..``opd``.

    Returns:
        The rendered prompt string.
    """
    system_prompt = (
        "Answer the following multiple choice question by giving the most appropriate response. \n"
        "Answer should be one among [A, B, C, D]."
    )

    # The whitespace is part of the prompt: each option line starts with 16
    # spaces and the message ends with a newline followed by 12 spaces.
    option_pad = " " * 16
    user_prompt = (
        f"Question: {row['question']}\n"
        f"{option_pad}A) {row['opa']}\n"
        f"{option_pad}B) {row['opb']}\n"
        f"{option_pad}C) {row['opc']}\n"
        f"{option_pad}D) {row['opd']}\n"
        + " " * 12
    )

    # No assistant response!
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )


In [97]:
# Hand-written example in the same schema as the dataset rows.
sample_row = dict(
    question="A 35-year-old man has sudden severe chest pain radiating to his back. What is the most likely diagnosis?",
    opa="Myocardial infarction",
    opb="Pulmonary embolism",
    opc="Aortic dissection",
    opd="Pneumothorax",
    cop=2,
)

# Render it without the answer so the model has to produce one.
prompt = format_chat_prompt_for_inference(sample_row)

In [None]:
# Display the rendered prompt string.
prompt

In [None]:
# Put the inputs on the device of the model that generates (`trained_model`),
# not the separate training-time `model` object.
inputs = tokenizer(prompt, return_tensors='pt', truncation=True, padding=True).to(trained_model.device)

# Greedy decoding of a single new token: the answer letter.
with torch.no_grad():
    outputs = trained_model.generate(
        **inputs,
        max_new_tokens=1,
        do_sample=False
    )

decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

# The decoded text still contains the prompt; keep what follows the last
# "assistant" marker.
predicted_answer = decoded.split("assistant")[-1].strip()
print("Predicted answer:", predicted_answer)

### Evaluate model metrics

In [100]:
# Train Accuracy: 0.7316
# Validation Accuracy: 0.5802

In [101]:
def compute_mcqa_accuracy(model, tokenizer, dataset, max_samples=None):
    """Score ``model`` on multiple-choice rows by greedy-decoding one token.

    For each row a system + user prompt is rendered with the tokenizer's chat
    template, one new token is generated, and the first character after the
    last "assistant" marker in the decoded text is compared with the gold
    letter.

    Args:
        model: object exposing ``eval()``, ``device`` and ``generate(...)``.
        tokenizer: object exposing ``apply_chat_template``, ``__call__`` and
            ``decode``.
        dataset: iterable of mappings with ``question``, ``opa``..``opd`` and
            ``cop`` (0-3). Must also support ``len()`` and ``.select()`` when
            ``max_samples`` is given.
        max_samples: if not None, evaluate only the first ``max_samples`` rows
            (0 evaluates nothing).

    Returns:
        Fraction of rows answered correctly; 0.0 when no rows were evaluated.
    """
    model.eval()
    correct = 0
    total = 0
    idx_to_ans_map = {0: "A", 1: "B", 2: "C", 3: "D"}

    # Same system prompt as the training and single-prompt formatters in this
    # notebook: the second line must NOT be indented, otherwise the evaluation
    # prompt differs from what the model was fine-tuned on.
    instruction = (
        "Answer the following multiple choice question by giving the most appropriate response. \n"
        "Answer should be one among [A, B, C, D]."
    )
    # The whitespace is part of the prompt: each option line starts with 16
    # spaces and the message ends with a newline followed by 12 spaces.
    option_pad = " " * 16
    trailer = " " * 12

    # `is not None` so an explicit 0 means "no rows" rather than "all rows".
    if max_samples is not None:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    for row in dataset:
        # Prepare the prompt
        user_instruction = (
            f"Question: {row['question']}\n"
            f"{option_pad}A) {row['opa']}\n"
            f"{option_pad}B) {row['opb']}\n"
            f"{option_pad}C) {row['opc']}\n"
            f"{option_pad}D) {row['opd']}\n"
            f"{trailer}"
        )

        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": user_instruction}
        ]

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,  # Just want the answer token (A/B/C/D)
                do_sample=False
            )

        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        pred_answer = decoded.split("assistant")[-1].strip()[:1]  # Get the first character after "assistant"

        correct_answer = idx_to_ans_map[row['cop']]
        print(pred_answer, correct_answer)
        if pred_answer == correct_answer:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0.0
    print(f"Evaluated {total} samples")
    print(f"Accuracy: {accuracy:.4f}")
    return accuracy


In [None]:
# Quick check on the first two validation rows only (max_samples=2).
compute_mcqa_accuracy(trained_model, tokenizer, val_dataset, max_samples=2)

In [56]:
# Validation Accuracy = 57%