### Import Libraries

In [8]:
import os
import torch
import pandas as pd
import bitsandbytes as bnb
from sklearn.utils import shuffle

from datasets import load_dataset, load_from_disk, Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    DataCollatorForLanguageModeling
)

from peft import (
    PeftModel,
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training
)
from trl import setup_chat_format, SFTConfig, SFTTrainer


from tqdm import tqdm
# from accelerate import Accelerator

### Setting Environment Variables

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# set the wandb project where this run will be logged
os.environ["WANDB_PROJECT"]="cs769_llama"
# turn off watch to log faster
os.environ["WANDB_WATCH"]="false"

In [3]:
HF_TOKEN = "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"

### Setting Constants

In [4]:
base_model_name = "meta-llama/Llama-3.2-3b-Instruct"
root_model_dir = "Llama-3.2-3b-it-Open-medmcqa-baseline-curriculum_data_replay_full_run"
dataset_name = 'openlifescienceai/medmcqa'

### Loading the model and tokenizer

- setting the configurations for Q-LoRA using BitsAndBytes

In [5]:
bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16
    )

base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map="auto",
        token=HF_TOKEN,
        torch_dtype=torch.bfloat16,
    )

tokenizer = AutoTokenizer.from_pretrained(
    base_model_name,
    trust_remote_code=True
)
tokenizer.pad_token = tokenizer.eos_token

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.76s/it]


### Loading the Dataset

In [6]:
# easy_data = load_from_disk('./json_to_hf/subset1')#.select(range(2000))
# medium_data = load_from_disk('./json_to_hf/subset2')#.select(range(2000))
# hard_data = load_from_disk('./json_to_hf/subset3')#.select(range(2000))

val_data = load_dataset(dataset_name, split='validation', trust_remote_code=True)#.select(range(2000))

In [15]:
hard_trained_dataset = Dataset.from_pandas(pd.DataFrame(load_from_disk('./hard_train_dataset')))

### Formatting the Dataset

In [17]:
pd.DataFrame(val_data).head(2)

Unnamed: 0,id,question,opa,opb,opc,opd,cop,choice_type,exp,subject_name,topic_name
0,45258d3d-b974-44dd-a161-c3fccbdadd88,Which of the following is not true for myelina...,Impulse through myelinated fibers is slower th...,Membrane currents are generated at nodes of Ra...,Saltatory conduction of impulses is seen,Local anesthesia is effective only when the ne...,0,multi,,Physiology,
1,b944ada9-d776-4c2a-9180-3ae5f393f72d,Which of the following is not true about glome...,The oncotic pressure of the fluid leaving the ...,Glucose concentration in the capillaries is th...,Constriction of afferent aeriole decreases the...,Hematocrit of the fluid leaving the capillarie...,0,multi,Ans-a. The oncotic pressure of the fluid leavi...,Physiology,


In [18]:
def format_chat_template(row):

    instruction = """Answer the following multiple choice question by giving the most appropriate response. 
Answer should be one among [A, B, C, D]."""

    idx_to_ans_map = {0:"A", 1:"B", 2:"C", 3:"D"}
    

    a = row['opa']
    b = row['opb']
    c = row['opc']
    d = row['opd']

    user_instruction = f"""Question: {row['question']}
                A) {a}
                B) {b}
                C) {c}
                D) {d}
            """

    row_json = [{"role": "system", "content": instruction },
               {"role": "user", "content": user_instruction },
               {"role": "assistant", "content": idx_to_ans_map[row['cop']]}]
    
    row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False)

    return row


In [19]:
from datasets import Dataset

def get_mapped_dataset(easy_data, format_chat_template):
    easy_train_dataset = {col: [] for col in easy_data.column_names}
    easy_train_dataset.update({'text':[]})

    for data in tqdm(easy_data):
        transformed_example = format_chat_template(data)
        for col in easy_train_dataset.keys():
            easy_train_dataset[col].append(transformed_example[col])

    easy_train_dataset = Dataset.from_dict(easy_train_dataset)

    return easy_train_dataset


In [20]:
val_dataset = val_data.map(format_chat_template)

# easy_train_dataset = get_mapped_dataset(easy_data, format_chat_template)
# medium_train_dataset = get_mapped_dataset(medium_data, format_chat_template)
# hard_train_dataset = get_mapped_dataset(hard_data, format_chat_template)

# easy_train_dataset['text'][0]

### Finding the linear module names of the Base Model to train LoRA

In [21]:
def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:  # needed for 16 bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

modules = find_all_linear_names(base_model)

### Configuring LoRA

In [22]:
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=modules
)

# https://huggingface.co/docs/trl/en/sft_trainer
model = get_peft_model(base_model, peft_config)

In [23]:
print(type(model))
print(model.print_trainable_parameters())

<class 'peft.peft_model.PeftModelForCausalLM'>
trainable params: 24,313,856 || all params: 3,237,063,680 || trainable%: 0.7511
None


### Setting the Training Arguments

In [24]:
training_arguments = SFTConfig(
    output_dir=root_model_dir,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=2,
    eval_strategy="steps",
    eval_steps=250,
    logging_steps=250,
    warmup_ratio=0.03,
    logging_strategy='steps',
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    remove_unused_columns=True,
    report_to='wandb',
    max_seq_length=512,
    dataset_text_field='text',
    label_names=["labels"],


    load_best_model_at_end=True,
    
)

In [25]:
# import torch
# import torch.nn.functional as F

# def compute_mcqa_accuracy(model, tokenizer, dataset, max_samples=None):
#     model.eval()
#     correct = 0
#     total = 0
#     idx_to_ans_map = {0: "A", 1: "B", 2: "C", 3: "D"}
#     ans_to_idx_map = {"A": 0, "B": 1, "C": 2, "D": 3}
#     misclassified_samples = []

#     if max_samples:
#         dataset = dataset.select(range(min(max_samples, len(dataset))))

#     for row in dataset:
#         instruction = """Answer the following multiple choice question by giving the most appropriate response. 
#         Answer should be one among [A, B, C, D]."""

#         user_instruction = f"""Question: {row['question']}
#                 A) {row['opa']}
#                 B) {row['opb']}
#                 C) {row['opc']}
#                 D) {row['opd']}
#         """

#         messages = [
#             {"role": "system", "content": instruction},
#             {"role": "user", "content": user_instruction}
#         ]

#         prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#         inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)

#         with torch.no_grad():
#             outputs = model.generate(
#                 **inputs,
#                 max_new_tokens=1,
#                 do_sample=False,
#                 return_dict_in_generate=True,
#                 output_scores=True
#             )

#         # Decode prediction
#         decoded = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).strip()
#         pred_answer = decoded.split("assistant")[-1].strip()[:1]

#         correct_answer = idx_to_ans_map[row['cop']]
#         total += 1

#         # Compute classification loss over the predicted token
#         logits = outputs.scores[0][0]  # scores[0] is the logits of the generated token
#         target_token = tokenizer.convert_tokens_to_ids(pred_answer)
#         correct_token = tokenizer.convert_tokens_to_ids(correct_answer)

#         print(logits.shape)
#         print(target_token)
#         print(correct_token)
        
#         if correct_token in logits:
#             loss = F.cross_entropy(logits.unsqueeze(0), torch.tensor([correct_token], device=model.device), reduction='none')
#         else:
#             # Penalize unknown outputs
#             loss = torch.tensor([float('inf')], device=model.device)

#         if pred_answer == correct_answer:
#             correct += 1
#         else:
#             misclassified_samples.append({
#                 "question": row['question'],
#                 "options": {
#                     "A": row["opa"],
#                     "B": row["opb"],
#                     "C": row["opc"],
#                     "D": row["opd"]
#                 },
#                 "predicted": pred_answer,
#                 "correct": correct_answer,
#                 "loss": loss.item()
#             })

#     accuracy = correct / total if total > 0 else 0.0
#     print(f"Evaluated {total} samples")
#     print(f"Accuracy: {accuracy:.4f}")
#     print(f"Misclassified {len(misclassified_samples)} samples")

#     return accuracy, misclassified_samples


In [26]:
# write here


import torch
import torch.nn.functional as F

def compute_mcqa_accuracy(model, tokenizer, dataset, max_samples=None):
    model.eval()
    correct = 0
    total = 0
    idx_to_ans_map = {0: "A", 1: "B", 2: "C", 3: "D"}
    ans_to_idx_map = {"A": 0, "B": 1, "C": 2, "D": 3}
    misclassified_samples = []

    if max_samples:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    for row in dataset:
        instruction = """Answer the following multiple choice question by giving the most appropriate response. 
        Answer should be one among [A, B, C, D]."""

        user_instruction = f"""Question: {row['question']}
                A) {row['opa']}
                B) {row['opb']}
                C) {row['opc']}
                D) {row['opd']}
        """

        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": user_instruction}
        ]

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True
            )

        decoded = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).strip()
        pred_answer = decoded.split("assistant")[-1].strip()[:1]

        correct_answer = idx_to_ans_map[row['cop']]
        total += 1

        # Get logits of the generated token
        logits = outputs.scores[0][0]  # Shape: [vocab_size]

        # Get token ID for the correct multiple-choice letter (e.g., "A", "B", ...)
        correct_token_id = tokenizer.convert_tokens_to_ids(correct_answer)  # <-- CHANGED

        # Compute cross-entropy loss on the logits (batch size 1)
        loss = F.cross_entropy(logits.unsqueeze(0), torch.tensor([correct_token_id], device=logits.device))  # <-- CHANGED

        if pred_answer == correct_answer:
            correct += 1
        else:
            misclassified_samples.append({
                "question": row['question'],
                "options": {
                    "A": row["opa"],
                    "B": row["opb"],
                    "C": row["opc"],
                    "D": row["opd"]
                },
                "predicted": pred_answer,
                "correct": correct_answer,
                "loss": loss.item()  # <-- CHANGED
            })

    accuracy = correct / total if total > 0 else 0.0
    print(f"Evaluated {total} samples")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Misclassified {len(misclassified_samples)} samples")

    return accuracy, misclassified_samples


### Trainer

In [18]:
print(len(easy_train_dataset))
# print(len(medium_train_dataset))
# print(len(hard_train_dataset))
print(len(val_dataset))

38800
4183


In [19]:
training_arguments.output_dir = os.path.join(root_model_dir, 'easy')

easy_trainer = SFTTrainer(
    model=model,
    train_dataset=easy_train_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=training_arguments,
)
easy_trainer.train()

easy_trainer.save_model(os.path.join(root_model_dir, 'easy', 'best')) 

Converting train dataset to ChatML: 100%|██████████| 38800/38800 [00:06<00:00, 6123.28 examples/s]
Adding EOS to train dataset: 100%|██████████| 38800/38800 [00:06<00:00, 6124.98 examples/s]
Tokenizing train dataset: 100%|██████████| 38800/38800 [00:17<00:00, 2251.31 examples/s]
Truncating train dataset: 100%|██████████| 38800/38800 [00:00<00:00, 213265.69 examples/s]
[34m[1mwandb[0m: Currently logged in as: [33msyammohan2103[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss
250,0.9196,0.963144
500,0.6968,0.947099
750,0.6847,0.944149
1000,0.6703,0.935745
1250,0.6586,0.93891
1500,0.593,0.940048
1750,0.5812,0.934269
2000,0.5837,0.937233
2250,0.5732,0.937829


In [26]:
# load the trained easy model
model_path = os.path.join(root_model_dir, 'easy', 'best')
trained_model = AutoModelForCausalLM.from_pretrained(model_path) #device='cuda'
trained_model.to('cuda') 
trained_model.eval()
trained_model.device

Loading checkpoint shards: 100%|██████████| 2/2 [00:25<00:00, 12.86s/it]


device(type='cuda', index=0)

In [27]:
# get the misclassified data samples

accuracy, missclassified_samples = compute_mcqa_accuracy(trained_model, tokenizer, easy_train_dataset)

print(len(missclassified_samples))

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for

Evaluated 38800 samples
Accuracy: 0.9070
Misclassified 3609 samples
3609


In [None]:
# based on threshhold on loss

# threshold = 0.7

# misclassified_samples = [
#     sample for sample in misclassified_samples
#     if sample['loss'] > threshold
# ]
# len(misclassified_samples)

In [28]:
# take top n samples as replay data from misclassified data

top = 2500 # take 15-20% of the total data or 

misclassified_samples_sorted_easy = sorted(
    missclassified_samples,
    key=lambda x: x['loss'],
    reverse=True
)

top_n_easy = misclassified_samples_sorted_easy[:top]

print(len(top_n_easy))

2500


In [29]:
# formatting the dataset to hugging face dataset

easy_temp = easy_data.to_pandas()

top_n_df = pd.DataFrame(top_n_easy)

selected_subset_easy = easy_temp[easy_temp['question'].isin(top_n_df['question'])]

print(len(selected_subset_easy))

2500


In [30]:
# creating the medium training dataset replaying data from easy subset

medium_tmp = medium_data.to_pandas()

concatenated_df = pd.concat([medium_tmp, selected_subset_easy])

shuffled_df = shuffle(concatenated_df, random_state=42)

shuffled_df = shuffled_df.reset_index(drop=True)

medium_data = Dataset.from_pandas(shuffled_df)

medium_train_dataset = get_mapped_dataset(medium_data, format_chat_template)


100%|██████████| 75832/75832 [00:16<00:00, 4667.68it/s]


In [37]:
training_arguments.output_dir = os.path.join(root_model_dir, 'medium')

easy_model = easy_trainer.model

medium_trainer = SFTTrainer(
    model=easy_model,
    train_dataset=medium_train_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=training_arguments,
)
medium_trainer.train()

medium_trainer.save_model(os.path.join(root_model_dir, 'medium', 'best'))

Converting train dataset to ChatML: 100%|██████████| 75832/75832 [00:10<00:00, 7267.38 examples/s] 
Adding EOS to train dataset: 100%|██████████| 75832/75832 [00:07<00:00, 10563.86 examples/s]
Tokenizing train dataset: 100%|██████████| 75832/75832 [00:35<00:00, 2123.94 examples/s]
Truncating train dataset: 100%|██████████| 75832/75832 [00:00<00:00, 294610.30 examples/s]


Step,Training Loss,Validation Loss
250,0.6591,0.931719
500,0.6567,0.927856
750,0.6457,0.924912
1000,0.6375,0.921349
1250,0.6386,0.91534
1500,0.6307,0.919731
1750,0.6235,0.91552
2000,0.6206,0.909792


In [13]:
# load the trained medium model
model_path = os.path.join(root_model_dir, 'medium', 'best')
trained_model_medium = AutoModelForCausalLM.from_pretrained(model_path) 
trained_model_medium.to('cuda') 
trained_model_medium.eval()
trained_model_medium.device

Loading checkpoint shards: 100%|██████████| 2/2 [00:23<00:00, 11.74s/it]


device(type='cuda', index=0)

In [None]:
base_model = AutoModelForCausalLM.from_pretrained(base_model)
trained_model = Peft_Model(base_model, 'checkpoint_path')
trained_model.eval()

# load the trained medium model
model_path = os.path.join(root_model_dir, 'medium', 'best')
trained_model_medium = AutoModelForCausalLM.from_pretrained(model_path) 
trained_model_medium.to('cuda') 
trained_model_medium.eval()
trained_model_medium.device

Loading checkpoint shards: 100%|██████████| 2/2 [00:23<00:00, 11.74s/it]


device(type='cuda', index=0)

In [18]:
medium_train_dataset = get_mapped_dataset(medium_data, format_chat_template)


100%|██████████| 73332/73332 [00:12<00:00, 5857.85it/s]


In [19]:
# get the misclassified data samples

accuracy, missclassified_samples_medium = compute_mcqa_accuracy(trained_model_medium, tokenizer, medium_train_dataset)

print(len(missclassified_samples_medium))

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for

Evaluated 73332 samples
Accuracy: 0.8654
Misclassified 9870 samples
9870


In [20]:
# take top n samples as replay data from misclassified data

top = 6000 # take 15-20% of the total data or 

misclassified_samples_sorted_medium = sorted(
    missclassified_samples_medium,
    key=lambda x: x['loss'],
    reverse=True
)

top_n_medium = misclassified_samples_sorted_medium[:top]

print(len(top_n_medium))

6000


In [21]:
# formatting the dataset to hugging face dataset

medium_temp = medium_train_dataset.to_pandas()

top_n_df_medium = pd.DataFrame(top_n_medium)

selected_subset_medium = medium_temp[medium_temp['question'].isin(top_n_df_medium['question'])]

print(len(selected_subset_medium))

6000


In [22]:
# creating the medium training dataset replaying data from easy subset

hard_tmp = hard_data.to_pandas()

concatenated_df_medium = pd.concat([hard_tmp, selected_subset_medium])

shuffled_df_medium = shuffle(concatenated_df_medium, random_state=42)

shuffled_df_medium = shuffled_df_medium.reset_index(drop=True)

hard_data = Dataset.from_pandas(shuffled_df_medium)

hard_train_dataset = get_mapped_dataset(hard_data, format_chat_template)

100%|██████████| 76690/76690 [00:14<00:00, 5427.09it/s]


In [27]:
# hard_train_dataset.save_to_disk('./hard_train_dataset')

Saving the dataset (1/1 shards): 100%|██████████| 76690/76690 [00:00<00:00, 88970.79 examples/s] 


In [17]:
hard_train_dataset = load_from_disk('./hard_train_dataset')

In [27]:
checkpoint_path = "Llama-3.2-3b-it-Open-medmcqa-baseline-curriculum_data_replay_full_run/medium/checkpoint-4740"
new_base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=bnb_config,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Load LoRA adapter
trained_model = PeftModel.from_pretrained(new_base_model, checkpoint_path)


Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.99s/it]


In [28]:
trained_model.train()

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 3072)
        (layers): ModuleList(
          (0-27): 28 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=3072, out_features=3072, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=3072, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=3072, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lor

In [29]:
for name, param in trained_model.named_parameters():
    if 'lora' in name:
        param.requires_grad = True

In [30]:
for name, param in trained_model.named_parameters():
    if param.requires_grad:
        print(name)

base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight
base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight
base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight
base_model.model.model.layers.0.mlp.gate_proj.lora_A.default.weight
base_model.model.model.layers.0.mlp.gate_proj.lora_B.default.weight
base_model.model.model.layers.0.mlp.up_proj.lora_A.default.weight
base_model.model.model.layers.0.mlp.up_proj.lora_B.default.weight
base_model.model.model.layers.0.mlp.down_proj.lora_A.default.weight
base_model.model.model.layers.0.mlp.down_proj.lora_B.default.weight
base_model.model.model.layer

In [32]:
training_arguments.output_dir = os.path.join(root_model_dir, 'hard')

# medium_model = medium_trainer.model

medium_model = trained_model



hard_trainer = SFTTrainer(
    model=medium_model,
    train_dataset=hard_trained_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=training_arguments,
)
hard_trainer.train()

hard_trainer.save_model(os.path.join(root_model_dir, 'hard', 'best'))

Converting train dataset to ChatML: 100%|██████████| 76690/76690 [00:09<00:00, 8309.84 examples/s] 
Adding EOS to train dataset: 100%|██████████| 76690/76690 [00:07<00:00, 9601.60 examples/s] 
Tokenizing train dataset: 100%|██████████| 76690/76690 [00:36<00:00, 2090.78 examples/s]
Truncating train dataset: 100%|██████████| 76690/76690 [00:00<00:00, 209946.12 examples/s]
[34m[1mwandb[0m: Currently logged in as: [33msyammohan2103[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss
250,0.65,0.916858
500,0.6431,0.919009
750,0.6385,0.906907
1000,0.6381,0.913609
1250,0.6269,0.902891
1500,0.6165,0.900518
1750,0.6116,0.902203
2000,0.6113,0.899162
2250,0.6086,0.900683
2500,0.5705,0.930314


### Model Inference

#### Load the Peft Model

In [None]:
# checkpoint_path = "Llama-3.2-3b-it-Open-medmcqa-baseline/checkpoint-5500"
# new_base_model = AutoModelForCausalLM.from_pretrained(
#     base_model_name,
#     quantization_config=bnb_config,
#     device_map="auto",
#     torch_dtype=torch.bfloat16,
# )

# # Load LoRA adapter
# trained_model = PeftModel.from_pretrained(new_base_model, checkpoint_path)
# trained_model.eval()


mdoel_path = os.path.join(root_model_dir, 'hard', 'best')
trained_model = AutoModelForCausalLM.from_pretrained(model_path) 
trained_model.eval()

In [96]:
# formatting for inference, the format should not have the answer

def format_chat_prompt_for_inference(row):
    instruction = """Answer the following multiple choice question by giving the most appropriate response. 
Answer should be one among [A, B, C, D]."""

    a = row['opa']
    b = row['opb']
    c = row['opc']
    d = row['opd']

    user_instruction = f"""Question: {row['question']}
                A) {a}
                B) {b}
                C) {c}
                D) {d}
            """

    # No assistant response!
    row_json = [
        {"role": "system", "content": instruction},
        {"role": "user", "content": user_instruction}
    ]

    return tokenizer.apply_chat_template(row_json, tokenize=False, add_generation_prompt=True)


In [97]:
sample_row = {
    "question": "A 35-year-old man has sudden severe chest pain radiating to his back. What is the most likely diagnosis?",
    "opa": "Myocardial infarction",
    "opb": "Pulmonary embolism",
    "opc": "Aortic dissection",
    "opd": "Pneumothorax",
    "cop": 2  
}

prompt = format_chat_prompt_for_inference(sample_row)

In [None]:
prompt

In [None]:
inputs = tokenizer(prompt, return_tensors='pt', truncation=True, padding=True).to(model.device)

with torch.no_grad():
    outputs = trained_model.generate(
        **inputs,
        max_new_tokens=1,
        do_sample=False
    )

decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)

predicted_answer = decoded.split("assistant")[-1].strip()
print("Predicted answer:", predicted_answer)

### Evaluate model metrics

In [100]:
# Train Accuracy: 0.7316
# Validation Accuracy: 0.5802

In [101]:
def compute_mcqa_accuracy(model, tokenizer, dataset, max_samples=None):
    model.eval()
    correct = 0
    total = 0
    idx_to_ans_map = {0: "A", 1: "B", 2: "C", 3: "D"}

    if max_samples:
        dataset = dataset.select(range(min(max_samples, len(dataset))))

    for row in dataset:
        # Prepare the prompt
        instruction = """Answer the following multiple choice question by giving the most appropriate response. 
        Answer should be one among [A, B, C, D]."""

        user_instruction = f"""Question: {row['question']}
                A) {row['opa']}
                B) {row['opb']}
                C) {row['opc']}
                D) {row['opd']}
            """

        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": user_instruction}
        ]

        prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=1,  # Just want the answer token (A/B/C/D)
                do_sample=False
            )

        decoded = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        pred_answer = decoded.split("assistant")[-1].strip()[:1]  # Get the first character after "assistant"

        correct_answer = idx_to_ans_map[row['cop']]
        print(pred_answer, correct_answer)
        if pred_answer == correct_answer:
            correct += 1
        total += 1

    accuracy = correct / total if total > 0 else 0.0
    print(f"Evaluated {total} samples")
    print(f"Accuracy: {accuracy:.4f}")
    return accuracy


In [None]:
acc = compute_mcqa_accuracy(trained_model, tokenizer, val_dataset, max_samples=2)

In [None]:
acc

In [56]:

# Validation Accuracy = 57%