In [20]:
pip install torch torchvision transformers accelerate datasets peft bitsandbytes



In [21]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
from datasets import load_dataset
import json

In [22]:
# Base checkpoint that both the model and the tokenizer are loaded from.
BASE_MODEL_ID = "BioMistral/BioMistral-7B"

# Load the base weights in 4-bit NF4; bitsandbytes computes in bfloat16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=quant_config,
    device_map="auto",  # let accelerate place the layers on the available device(s)
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

In [23]:
# Make the 4-bit model trainable with adapters (PEFT helper for k-bit training).
model = prepare_model_for_kbit_training(model)

# LoRA adapters on every linear layer of the base model.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules="all-linear",
    r=16,
    # NOTE(review): PEFT scales the adapter update by lora_alpha / r, so this is 0.5;
    # alpha >= r is the more common choice — confirm this is intentional.
    lora_alpha=8,
    lora_dropout=0.05,
    bias="none",
)

model = get_peft_model(model, lora_config)

In [24]:
from google.colab import drive

# Mount Google Drive (forcing a fresh mount) so the project folder is reachable.
DRIVE_MOUNT_POINT = "/content/drive"
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

Mounted at /content/drive


In [25]:
import os

# Work from the project folder on Drive: the data files and the training
# output_dir used below are relative paths and therefore resolve inside it.
PROJECT_DIR = "/content/drive/MyDrive/CS685 Final/"
os.chdir(PROJECT_DIR)
print(os.listdir(os.getcwd()))

['sft_data_train.jsonl', 'sft_data_test.jsonl', 'mistral-1b-qa-qlora2', 'wandb', 'unsloth_compiled_cache', 'huggingface_tokenizers_cache', 'mistral-1b-qa-qlora3', 'biomistral-7b-2']


In [26]:
# Load the raw JSONL files, then hold out 10% of the training file as a
# validation split (the "test" file is kept untouched for final evaluation).
ds = load_dataset(
    "json",
    data_files={
        "train": "sft_data_train.jsonl",
        "test": "sft_data_test.jsonl",
    },
)
split_ds = ds["train"].train_test_split(test_size=0.1, seed=42)
train_set, eval_set = split_ds["train"], split_ds["test"]


# ------------------------------------------------------------
# 0 .  One-time tokenizer fixes
# ------------------------------------------------------------
# Pad with the tokenizer's own EOS token. Padded positions are masked out by
# attention_mask=0 and labels=-100 in the collator, so the pad id only needs to be
# a valid vocabulary id.
# NOTE(review): a literal "<|im_end|>" is presumably not a single vocabulary entry
# for the BioMistral (Mistral) tokenizer, whose EOS is presumably "</s>"; as a pad
# token its id would silently resolve to <unk> — confirm before reverting.
tokenizer.pad_token = tokenizer.eos_token
# Only matters where the tokenizer itself pads (e.g. batched generation);
# `causal_lm_collate` right-pads training batches manually.
tokenizer.padding_side = "left"

# ------------------------------------------------------------
# 1 .  Chat‑style formatting helpers
# ------------------------------------------------------------
# System prompt placed at the start of every example built by `build_chat`
# (training and evaluation). Its exact text is part of what the adapter is
# trained on, so changing it implies retraining.
SYSTEM_MSG = (
    "You are a board‑certified physician. Answer a multiple‑choice "
    "USMLE‑style question that presents exactly four options (A–D). "
    "Respond in the form:\n\n"
    "Answer: <LETTER>\n"
    "Explanation: <ONE‑OR‑TWO SENTENCES>\n"
)

# ChatML-style turn markers used by `build_chat`.
# NOTE(review): these are presumably not registered special tokens in the
# BioMistral tokenizer, so each marker is tokenized as several ordinary tokens — confirm.
SYS_HDR  = "<|im_start|>system\n"
USR_HDR  = "<|im_start|>user\n"
ASST_HDR = "<|im_start|>assistant\n"
END      = "<|im_end|>"

def build_chat(question: str, answer: str | None = None) -> str:
    """Render one example as a single ChatML-style string.

    Layout: a system turn holding SYSTEM_MSG, a user turn holding `question`,
    then the assistant header. If `answer` is given it follows the header and is
    closed with END (a complete training example). If `answer` is None the
    string ends right after the assistant header, ready for generation.
    """
    turns = [
        f"{SYS_HDR}{SYSTEM_MSG}{END}\n",
        f"{USR_HDR}{question}{END}\n",
        ASST_HDR,
    ]
    if answer is not None:
        turns.append(f"{answer}{END}")
    return "".join(turns)

# ------------------------------------------------------------
# 2 .  Tokenize examples and mask the prompt part of the labels
# ------------------------------------------------------------
def preprocess(batch):
    """Tokenize prompt/response pairs and mask the prompt prefix of the labels.

    Args:
        batch: a batched ``datasets`` slice with parallel ``"prompt"`` and
            ``"response"`` lists of strings.

    Returns:
        The tokenizer encoding (``input_ids``, ``attention_mask``) plus
        ``labels``: per example, -100 over the prompt prefix followed by the
        remaining token ids. ``labels`` always has the same length as
        ``input_ids``.
    """
    texts = [build_chat(q, a) for q, a in zip(batch["prompt"], batch["response"])]
    enc = tokenizer(texts,
                    truncation=True,
                    max_length=640,          # 512 Q + 128 answer
                    padding=False,
                    return_attention_mask=True)

    # One tokenizer call for all prompt-only strings instead of one per example.
    prompt_ids = tokenizer(
        [build_chat(q, None) for q in batch["prompt"]],
        add_special_tokens=False,
    )["input_ids"]

    labels = []
    for ids, p_ids in zip(enc["input_ids"], prompt_ids):
        # NOTE(review): `enc` uses the tokenizer's default special tokens (presumably
        # a leading BOS) while the prompt length here excludes them, so the masked
        # prefix is one token shorter than the real prompt and the last prompt token
        # is supervised. Code that finds the prompt end from the first non-(-100)
        # label must account for this offset.
        # Clamp: when truncation cut into the prompt, the labels must not end up
        # longer than input_ids.
        prompt_len = min(len(p_ids), len(ids))
        labels.append([-100] * prompt_len + ids[prompt_len:])

    enc["labels"] = labels
    return enc

# Tokenize both splits identically; the raw text columns are dropped afterwards.
tokenized_train, tokenized_eval = (
    split.map(preprocess, batched=True, remove_columns=["prompt", "response"])
    for split in (train_set, eval_set)
)

In [27]:
from torch.nn.functional import pad

# Id written into padded input positions (they are masked via attention_mask).
pad_id = tokenizer.pad_token_id

def causal_lm_collate(examples):
    """Right-pad tokenized examples into one training batch.

    Each example provides ``input_ids`` and ``labels`` (lists of ints). Every
    sequence is padded on the right to the longest ``input_ids`` in the batch:
    input ids with ``pad_id``, labels with -100 (ignored by the loss), and a
    freshly built attention mask with 1 over real tokens and 0 over padding.

    Returns:
        dict of long tensors ``input_ids``, ``attention_mask``, ``labels``.
    """
    token_seqs = [torch.tensor(ex["input_ids"], dtype=torch.long) for ex in examples]
    label_seqs = [torch.tensor(ex["labels"], dtype=torch.long) for ex in examples]
    width = max(len(seq) for seq in token_seqs)

    def right_pad(seq, fill):
        # (left, right) padding of the last dim: everything goes on the right.
        return pad(seq, (0, width - seq.size(0)), value=fill)

    return {
        "input_ids": torch.stack([right_pad(seq, pad_id) for seq in token_seqs]),
        "attention_mask": torch.stack(
            [right_pad(torch.ones_like(seq), 0) for seq in token_seqs]
        ),
        "labels": torch.stack([right_pad(seq, -100) for seq in label_seqs]),
    }

In [28]:
from transformers import DataCollatorForLanguageModeling, TrainingArguments, Trainer

# NOTE: prepare_model_for_kbit_training (above) already enables gradient
# checkpointing by default (use_gradient_checkpointing=True); kept as a safeguard.
model.gradient_checkpointing_enable()

# Mixed precision should agree with bnb_4bit_compute_dtype=torch.bfloat16 set when
# loading the base model: use bf16 AMP where the GPU supports it, fp16 otherwise
# (e.g. pre-Ampere GPUs, where bf16=True would be rejected).
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir="biomistral-7b-2",
    per_device_train_batch_size=12,  # CHANGED FROM 8
    gradient_accumulation_steps=8,   # effective batch per device = 12 * 8 = 96
    num_train_epochs=3,              # CHANGED FROM 2
    learning_rate=2e-4,              # CHANGED FROM 4e-4
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    max_grad_norm=1.0,
    bf16=use_bf16,
    fp16=not use_bf16,
    logging_steps=10,
    save_steps=500,
)
# NOTE(review): `eval_dataset` is passed below but no evaluation strategy is set
# (TrainingArguments defaults to "no"), so it is never evaluated during training.
# Set `eval_strategy` (`evaluation_strategy` on older transformers) if wanted.

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_eval,
    data_collator=causal_lm_collate,
    tokenizer=tokenizer,
)

trainer.train()

# Save LoRA adapters and tokenizer
model.save_pretrained("biomistral-7b-2")
tokenizer.save_pretrained("biomistral-7b-2")

  trainer = Trainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33msuryamg13[0m ([33msuryamg13-umass-amherst[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss
10,0.9748
20,0.6233
30,0.5477
40,0.5327
50,0.5007
60,0.4864
70,0.4822
80,0.4784
90,0.4219
100,0.3861


('biomistral-7b-2/tokenizer_config.json',
 'biomistral-7b-2/special_tokens_map.json',
 'biomistral-7b-2/tokenizer.model',
 'biomistral-7b-2/added_tokens.json',
 'biomistral-7b-2/tokenizer.json')

In [30]:
import torch, re, math
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score
from transformers import (AutoModelForCausalLM,
                          AutoTokenizer)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Directory written by save_pretrained after training (LoRA adapter + tokenizer).
# NOTE(review): loading an adapter directory through AutoModelForCausalLM presumably
# relies on transformers' PEFT integration to fetch the base model — confirm.
model_id  = "biomistral-7b-2"

# Load in half precision on GPU: the default fp32 load doubles memory for a 7B
# model and slows generation. bf16 matches the compute dtype used in training;
# fp16 is the fallback on GPUs without bf16 support.
if device == "cuda":
    load_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    load_dtype = torch.float32

model     = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=load_dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id)
model.eval()

EOS_ID   = tokenizer.eos_token_id
PAD_ID   = tokenizer.pad_token_id

# Collator that pads input_ids, attention_mask, labels together

def collate_test(batch):
    """Pad tokenized test examples into one batch placed on `device`.

    ``input_ids`` are right-padded with PAD_ID, ``attention_mask`` with 0 and
    ``labels`` with -100; all three are long tensors of shape (batch, longest).
    """
    def stacked(key, fill):
        seqs = [torch.tensor(ex[key], dtype=torch.long) for ex in batch]
        return pad_sequence(seqs, batch_first=True, padding_value=fill).to(device)

    return {
        "input_ids":      stacked("input_ids", PAD_ID),
        "attention_mask": stacked("attention_mask", 0),
        "labels":         stacked("labels", -100),
    }


# ------------------------------------------------------------------
# Test-set evaluation: letter accuracy + explanation perplexity
# ------------------------------------------------------------------
# Prompts are rebuilt from the raw test records with `build_chat(q, None)` and
# tokenized with the tokenizer's default special tokens, as in `preprocess`.
# The prompt boundary is therefore known directly, instead of being inferred from
# the -100 prefix of the labels (whose offset depends on how `preprocess` counts
# special tokens).

MAX_NEW_TOKENS = 64
# "Answer: C" / "Answer: (C)" — the format SYSTEM_MSG asks for.
ANSWER_TAG_RE = re.compile(r"Answer:\s*\(?([A-D])\b")
# Fallback: first stand-alone capital A-D. The word boundaries stop the "A" of
# "Answer" from matching, which a start-anchored pattern would wrongly pick up.
BARE_LETTER_RE = re.compile(r"\b([A-D])\b")


def find_answer_letter(text: str):
    """Return the regex match whose group(1) is the answer letter, or None."""
    return ANSWER_TAG_RE.search(text) or BARE_LETTER_RE.search(text)


incorrect_answers = {}          # test row index -> {"gold", "pred", "generated"}
pred_letters, gold_letters, ppl_vals = [], [], []
skipped_no_gold = 0

with torch.inference_mode():
    for idx, row in enumerate(tqdm(ds["test"], desc="eval")):
        question, response = row["prompt"], row["response"]

        gold_match = find_answer_letter(response)
        if gold_match is None:
            # Cannot score this row; skip it rather than crash mid-run.
            skipped_no_gold += 1
            continue
        gold = gold_match.group(1)

        # ---- letter prediction: generate from the prompt only ----
        prompt_text = build_chat(question, None)
        prompt_enc = tokenizer(prompt_text, return_tensors="pt").to(device)
        prompt_len = prompt_enc["input_ids"].shape[1]

        # Greedy decoding (early_stopping dropped: it only applies to beam search).
        gen = model.generate(
            **prompt_enc,
            max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=PAD_ID if PAD_ID is not None else EOS_ID,
            eos_token_id=EOS_ID,
            repetition_penalty=1.5,       # discourages copying
            no_repeat_ngram_size=4,       # hard constraint
        )
        generated = tokenizer.decode(gen[0, prompt_len:], skip_special_tokens=True)
        # Keep only the assistant turn; END is plain text if it is not a special token.
        generated = generated.split(END, 1)[0]
        pred_match = find_answer_letter(generated)
        pred = pred_match.group(1) if pred_match else "X"

        gold_letters.append(gold)
        pred_letters.append(pred)
        if pred != gold:
            incorrect_answers[idx] = {"gold": gold, "pred": pred, "generated": generated}

        # ---- explanation perplexity (teacher-forced) ----
        # Context = prompt + response up to and including the gold letter; only the
        # tokens after it are scored. The boundary is measured in TOKENS by
        # tokenizing the context text, never by reusing a character offset.
        ctx_text = prompt_text + response[: gold_match.end()]
        full_ids = tokenizer(build_chat(question, response), return_tensors="pt")["input_ids"].to(device)
        ctx_len = len(tokenizer(ctx_text)["input_ids"])
        labels = full_ids.clone()
        labels[:, :ctx_len] = -100
        if (labels != -100).any():   # no explanation tokens -> loss would be NaN
            loss = model(input_ids=full_ids, labels=labels).loss
            ppl_vals.append(math.exp(loss.item()))

# ------------------------------------------------------------------
# Metrics
# ------------------------------------------------------------------
if not gold_letters:
    raise RuntimeError("No evaluable test rows: no gold letter found in any response.")

acc = accuracy_score(gold_letters, pred_letters)
# Mean of per-example perplexities; NaN if no row had explanation tokens.
ppl = sum(ppl_vals) / len(ppl_vals) if ppl_vals else float("nan")

print(f"Letter accuracy : {acc*100:.2f}%")
print(f"Avg explain PPL : {ppl:.2f}")
if skipped_no_gold:
    print(f"Skipped {skipped_no_gold} test rows without a gold letter")

KeyboardInterrupt: 

In [None]:
from google.colab import drive

# Re-mount Drive (e.g. after a runtime restart) before copying artifacts to it.
drive.mount("/content/drive")

In [None]:
cp -r biomistral-7b-2/ /content/drive/MyDrive

In [None]:
cp -r wandb/run-20250505_041334-beclxxig/ /content/drive/MyDrive