In [None]:
# pip install -U transformers datasets accelerate peft bitsandbytes

import torch
from datasets import load_dataset
from transformers import (AutoTokenizer, AutoModelForCausalLM,
                          TrainingArguments, Trainer, DataCollatorForLanguageModeling,
                          DataCollatorForSeq2Seq, BitsAndBytesConfig)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
dataset = load_dataset("json", data_files="train.jsonl")["train"]  # expects {"text": "..."} lines

# 4-bit NF4 weights with double quantization; matmuls run in bf16 (QLoRA setup).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS so batches can be padded.
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
)
# The KV cache is never reused during training and is incompatible with the
# gradient checkpointing enabled in TrainingArguments below; disable it explicitly
# instead of relying on transformers to switch it off (with a warning) at run time.
model.config.use_cache = False

# PEFT helper that readies a k-bit quantized model for training (needed for QLoRA stability).
model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    # Attention projections only. For a larger adapter also include the MLP projections:
    # ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

def tokenize_fn(ex, tok=None, max_length: int = 2048):
    """Tokenize one ``{"text": ...}`` record for causal-LM fine-tuning.

    The tokenizer's own special tokens are added as usual, then the EOS token id
    is appended (unless the text already ends with it) so the model is trained to
    stop at the end of a document. The sequence is right-truncated to
    ``max_length`` only after EOS is appended. A document that does not fit
    therefore loses its EOS instead of being marked as finished.

    Args:
        ex: Dataset row with a ``"text"`` string.
        tok: Tokenizer to use; defaults to the module-level ``tokenizer``.
        max_length: Maximum number of tokens kept per example.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` (all ones, no padding) and
        ``labels`` (an independent copy of ``input_ids``; padding/masking of labels
        is left to the data collator).
    """
    tok = tokenizer if tok is None else tok

    input_ids = list(tok(ex["text"], padding=False)["input_ids"])

    eos_id = getattr(tok, "eos_token_id", None)
    if eos_id is not None and (not input_ids or input_ids[-1] != eos_id):
        input_ids.append(eos_id)

    input_ids = input_ids[:max_length]
    return {
        "input_ids": input_ids,
        "attention_mask": [1] * len(input_ids),
        "labels": list(input_ids),
    }

tokenized = dataset.map(tokenize_fn, remove_columns=dataset.column_names)

# Pad input_ids/attention_mask with the pad token and the precomputed "labels" with
# -100 (ignored by the loss). DataCollatorForLanguageModeling(mlm=False) is not used:
#   * it rebuilds labels from input_ids and sets every position equal to pad_token_id
#     to -100. When the pad token is the EOS token (as configured above if the
#     tokenizer had none), genuine EOS tokens are also removed from the loss, so the
#     model never learns to stop;
#   * it hands the ragged per-example "labels" lists to tokenizer.pad, which cannot
#     pad them, so it fails as soon as a batch holds sequences of different lengths.
collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True, label_pad_token_id=-100)

args = TrainingArguments(
    output_dir="./sft_peft_out",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,  # effective batch size 16 per device
    learning_rate=2e-4,
    num_train_epochs=3,
    logging_steps=10,
    save_steps=200,
    bf16=True,
    gradient_checkpointing=True,
    optim="paged_adamw_8bit",
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized,
    data_collator=collator,
)

trainer.train()

# `model` is the PEFT-wrapped model, so this writes the LoRA adapter; the tokenizer
# (including the pad-token choice made above) is saved next to it.
model.save_pretrained("./lora_adapter")
tokenizer.save_pretrained("./lora_adapter")


In [None]:
#!/usr/bin/env python3
"""
SFT with PEFT (QLoRA/LoRA) using your Arrow dataset that has:
  - query: full instruction + INPUT ... + OUTPUT:
  - answer: JSON string (target)

We build model input as:  query + "\n" + answer
Loss is computed ONLY on answer tokens (prompt tokens are masked to -100).
"""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import Dict, Any

import torch
from datasets import load_from_disk
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding,
    BitsAndBytesConfig,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training


# ---------------------------
# Tokenization / Label masking
# ---------------------------
def build_features(
    example: Dict[str, Any], tokenizer, max_length: int, add_eos: bool = True
) -> Dict[str, Any]:
    """
    Tokenize one (query, answer) pair for supervised fine-tuning.

    - input  = query + "\\n" + answer (+ EOS when ``add_eos`` and the tokenizer has one)
    - labels = -100 on prompt tokens; the token id itself on answer (and EOS) tokens

    The sequence is right-truncated to ``max_length``. EOS is appended *before*
    truncating, so an answer that gets cut off also loses its EOS rather than
    teaching the model to stop mid-answer. If the prompt alone fills
    ``max_length``, every label is -100 and the example contributes no loss.

    Args:
        example: Row with string fields ``"query"`` and ``"answer"``.
        tokenizer: HF-style tokenizer; called as ``tokenizer(text,
            add_special_tokens=False)["input_ids"]`` and read for ``eos_token_id``.
        max_length: Maximum number of tokens to keep (``None`` keeps all).
        add_eos: Append ``tokenizer.eos_token_id`` after the answer so the model
            learns where the answer ends.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` (all ones) and ``labels``.
    """
    query = example["query"].rstrip()
    answer = example["answer"].strip()

    # Separator between prompt and answer.
    prompt_text = query + "\n"
    full_text = prompt_text + answer

    prompt_ids = tokenizer(prompt_text, add_special_tokens=False)["input_ids"]
    input_ids = list(tokenizer(full_text, add_special_tokens=False)["input_ids"])

    eos_id = getattr(tokenizer, "eos_token_id", None)
    if add_eos and eos_id is not None:
        input_ids.append(eos_id)

    if max_length is not None:
        input_ids = input_ids[:max_length]

    # Supervision starts where the full-text tokenization diverges from the
    # prompt-only one. Normally that is len(prompt_ids) (capped by truncation); if
    # BPE merged the trailing "\n" with the first answer characters, the merged
    # token is supervised instead of being silently skipped.
    ans_start = 0
    for prompt_tok, full_tok in zip(prompt_ids, input_ids):
        if prompt_tok != full_tok:
            break
        ans_start += 1

    labels = [-100] * ans_start + input_ids[ans_start:]
    attention_mask = [1] * len(input_ids)  # no padding at this stage

    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


class SimpleCollator(DataCollatorWithPadding):
    """
    Pads input_ids/attention_mask/labels to the longest sequence in the batch.

    ``DataCollatorWithPadding`` delegates to ``tokenizer.pad``, which does not pad
    ``labels``. The labels are therefore removed before calling it and padded here
    with -100 (ignored by the loss). Label padding goes on the same side the
    tokenizer pads (``tokenizer.padding_side``), so each label stays aligned with
    its input token even for left-padding tokenizers.
    """

    def __call__(self, features):
        # Pull labels out first; accept any sequence type (list, tuple, ...).
        label_seqs = [list(f["labels"]) for f in features]
        inputs_only = [{k: v for k, v in f.items() if k != "labels"} for f in features]
        batch = super().__call__(inputs_only)

        # Use the padded input width as the target so labels match input_ids exactly.
        max_len = batch["input_ids"].shape[1]
        pad_left = getattr(self.tokenizer, "padding_side", "right") == "left"

        padded_labels = []
        for lab in label_seqs:
            fill = [-100] * (max_len - len(lab))
            padded_labels.append(fill + lab if pad_left else lab + fill)

        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
        return batch


# ---------------------------
# Main
# ---------------------------
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_path", type=str, 
                        help="Path to dataset saved by dataset.save_to_disk(). "
                             "Can be either a Dataset (no split) or a DatasetDict with train/test.",
                       default= "/home/lm2445/project_pi_sjf37/lm2445/PV_multiagent/benckmark/PV_benckmark/split_out/non_test/data-00000-of-00001.arrow")
    parser.add_argument("--model_name", type=str, 
                        help="Base causal LM name/path (e.g., meta-llama/Meta-Llama-3-8B-Instruct).",
                       default = "Qwen/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--output_dir", type=str, default="./sft_peft_out")
    parser.add_argument("--max_length", type=int, default=8192)

    # Training hyperparams (reasonable for 32GB + QLoRA)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--grad_accum", type=int, default=16)
    parser.add_argument("--epochs", type=float, default=1.0)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--logging_steps", type=int, default=10)
    parser.add_argument("--save_steps", type=int, default=200)

    # PEFT/QLoRA toggles
    parser.add_argument("--use_qlora", action="store_true", help="Enable 4-bit QLoRA (recommended for 32GB).")
    parser.add_argument("--bf16", action="store_true", help="Use bf16 (recommended if supported). If off, uses fp16.")
    parser.add_argument("--lora_r", type=int, default=16)
    parser.add_argument("--lora_alpha", type=int, default=32)
    parser.add_argument("--lora_dropout", type=float, default=0.05)

    # Optional eval
    parser.add_argument("--do_eval", action="store_true", help="Run eval if test split exists.")
    args = parser.parse_args()

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- Load dataset ----
    ds_obj = load_from_disk(args.dataset_path)
    if hasattr(ds_obj, "keys") and "train" in ds_obj:
        train_ds = ds_obj["train"]
        eval_ds = ds_obj.get("test", None)
    else:
        train_ds = ds_obj
        eval_ds = None

    # ---- Tokenizer ----
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ---- Model ----
    torch_dtype = torch.bfloat16 if args.bf16 else torch.float16

    quant_config = None
    if args.use_qlora:
        quant_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch_dtype,
            bnb_4bit_use_double_quant=True,
        )

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        device_map="auto",
        torch_dtype=torch_dtype,
        quantization_config=quant_config,
    )

    if args.use_qlora:
        model = prepare_model_for_kbit_training(model)

    # Target modules for Llama-like models; if you use a different architecture, change this list.
    #target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]

    lora_cfg = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )

    model = get_peft_model(model, lora_cfg)
    model.print_trainable_parameters()

    # ---- Preprocess dataset (tokenize + labels) ----
    def _map_fn(ex):
        return build_features(ex, tokenizer=tokenizer, max_length=args.max_length)

    train_tok = train_ds.map(_map_fn, remove_columns=train_ds.column_names, desc="Tokenizing train")
    if eval_ds is not None and args.do_eval:


In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

base = "meta-llama/Meta-Llama-3-8B-Instruct"
adapter = "./lora_adapter"

tok = AutoTokenizer.from_pretrained(base, use_fast=True)
# NOTE(review): the training cell fine-tuned the adapter on a 4-bit quantized base;
# here it is applied to an unquantized bf16 base, so outputs may differ slightly.
model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(model, adapter)

prompt = "Write a 2-sentence summary of patient-centered communication."
inputs = tok(prompt, return_tensors="pt").to(model.device)

# Give generate() an explicit pad id; fall back to EOS when the tokenizer has none.
pad_id = tok.pad_token_id if tok.pad_token_id is not None else tok.eos_token_id

with torch.no_grad():
    out = model.generate(**inputs, max_new_tokens=80, pad_token_id=pad_id)

# out[0] contains the prompt followed by the continuation; print only what the
# fine-tuned model generated.
prompt_len = inputs["input_ids"].shape[1]
print(tok.decode(out[0][prompt_len:], skip_special_tokens=True))
