# In [None]:
# ===============================
# STEP 1: Dependencies and Imports
# ===============================
# In a notebook, run this first:
# !pip install -q torch transformers datasets peft accelerate bitsandbytes sentencepiece jsonlines rouge-score evaluate

import ast
import json
import os
import pprint
import random
import re
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, List

import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from rouge_score import rouge_scorer
from torch.optim.lr_scheduler import LambdaLR
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    Trainer,
    TrainingArguments,
)

print("PyTorch:", torch.__version__)

# ===============================
# STEP 2: Configuration
# ===============================
# Hugging Face Hub checkpoint to fine-tune, and the directory that receives
# Trainer checkpoints plus the final adapter.
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
OUT_DIR = "outputs/llama3_medical_sft"

# --- Dataset Paths ---
# Local JSON-lines files read with `load_dataset("json", ...)`.
TRAIN_FILE = "data/train.jsonl"
VAL_FILE = "data/val.jsonl"
TEST_FILE = "data/test.jsonl"  # Must exist for the final evaluation step

# --- Model & Training Params ---
MAX_LENGTH = 2048  # token budget per training example (passed to the collator)

# LoRA adapter shape: rank, scaling, dropout, and which projections to adapt.
LORA_R, LORA_ALPHA, LORA_DROPOUT = 16, 16, 0.05
TARGET_MODULES = ["q_proj", "v_proj"]

# Per-device batch sizes; the effective train batch per device is
# BATCH_SIZE * GRAD_ACCUM.
BATCH_SIZE = EVAL_BATCH = 4
GRAD_ACCUM = 8
EPOCHS = 10
# NOTE(review): worth confirming 2e-5 is intended - LoRA fine-tunes are often
# run with a noticeably higher learning rate.
LEARNING_RATE = 2e-5
# When True, the training step swaps the linear schedule for a two-stage LambdaLR.
TWO_STAGE_LR = False

# Logging cadence; SAVE_STEPS is also used as the evaluation interval.
LOGGING_STEPS, SAVE_STEPS = 50, 500

SEED = 42
# Side effects: create the output directory, then seed Python's and torch's RNGs.
os.makedirs(OUT_DIR, exist_ok=True)
for _seed_fn in (random.seed, torch.manual_seed):
    _seed_fn(SEED)

# ===============================
# STEP 3: Data Collator for SFT
# ===============================
@dataclass
class SFTCollator:
    """Turn raw instruction records into a padded causal-LM training batch.

    Each feature is a mapping with ``instruction``, ``input`` and ``output``
    keys (missing keys become empty strings). A record is rendered with the
    tokenizer's chat template as a system / user / assistant exchange, and
    the rendered strings are tokenised, padded and truncated to
    ``max_length``.

    The returned batch holds ``input_ids``, ``attention_mask`` and
    ``labels``. ``labels`` copies ``input_ids`` but uses ``-100`` (ignored by
    the causal-LM loss) in two places:

    * every padding position;
    * when ``mask_prompt`` is true (the default), every system/user prompt
      token, so only the assistant response is learned.

    NOTE: if truncation cuts off the whole response, that example has no
    supervised tokens left.
    """

    tokenizer: Any  # Hugging Face tokenizer that provides a chat template
    max_length: int = 2048
    mask_prompt: bool = True

    @staticmethod
    def _response_text(output):
        """Return a record's ``output`` as the assistant message text.

        ``None`` becomes ``""`` and strings pass through unchanged.
        Structured targets (e.g. NER dicts/lists) are serialised as JSON, so
        the model learns a format ``json.loads`` can parse back; ``str()``
        would produce a single-quoted Python repr instead.
        """
        if output is None:
            return ""
        if isinstance(output, str):
            return output
        return json.dumps(output, ensure_ascii=False)

    def __call__(self, features: List[Dict]):
        texts, prompts = [], []
        for f in features:
            messages = [
                {"role": "system", "content": "You are a clinical NLP assistant."},
                {"role": "user", "content": f"{f.get('instruction', '')}\n\nINPUT:\n{f.get('input', '')}"},
                {"role": "assistant", "content": self._response_text(f.get('output', ''))},
            ]
            texts.append(
                self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            )
            # The prompt alone, ending with the assistant header, is used to
            # find how many leading tokens belong to the prompt.
            prompts.append(
                self.tokenizer.apply_chat_template(messages[:-1], tokenize=False, add_generation_prompt=True)
            )

        # The rendered template already contains the model's special tokens
        # (e.g. BOS); letting the tokenizer add its own would duplicate them.
        batch = self.tokenizer(
            texts, padding=True, truncation=True,
            max_length=self.max_length, add_special_tokens=False,
            return_tensors="pt",
        )
        labels = batch["input_ids"].clone()
        # The pad token may share its id with EOS (see tokenizer setup). Mask
        # by attention_mask, not pad_token_id, so real end-of-turn tokens
        # stay supervised.
        labels[batch["attention_mask"] == 0] = -100

        if self.mask_prompt:
            # Assumes the rendered prompt is a token-level prefix of the full
            # conversation. This holds for Llama-3-style templates; confirm
            # for other chat templates.
            prompt_ids = self.tokenizer(
                prompts, truncation=True, max_length=self.max_length,
                add_special_tokens=False,
            )["input_ids"]
            left_padded = getattr(self.tokenizer, "padding_side", "right") == "left"
            for row, ids in enumerate(prompt_ids):
                start = int((batch["attention_mask"][row] == 0).sum()) if left_padded else 0
                labels[row, start:start + len(ids)] = -100

        batch["labels"] = labels
        return batch

# ===============================
# STEP 4: Load Datasets
# ===============================
# Train/validation splits come from local JSON files; the test split is
# loaded separately at evaluation time.
data_files = dict(train=TRAIN_FILE, validation=VAL_FILE)
ds = load_dataset("json", data_files=data_files)

print("Loaded Training and Validation Datasets:")
print(ds)
first_train_example = ds["train"][0]
print("\nSample:", first_train_example)

# ===============================
# STEP 5: Tokenizer & Quantization
# ===============================
# 4-bit NF4 weights with nested ("double") quantization; matmuls run in fp16.
bnb_kwargs = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": torch.float16,
}
bnb_cfg = BitsAndBytesConfig(**bnb_kwargs)

tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
# Batched padding needs a pad token. When the tokenizer has none, reuse EOS.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token
# Right padding for training batches.
tok.padding_side = "right"

# ===============================
# STEP 6: Load Base Model + LoRA
# ===============================
# Load the base model with 4-bit weights (bnb_cfg) and let accelerate place
# layers on the available devices.
# Llama 3 is implemented natively in transformers, so `trust_remote_code` is
# not needed. Leaving it enabled would execute arbitrary Python from the Hub
# repository if such code were ever added there.
base = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_cfg,
    device_map="auto",
)
# PEFT helper that readies a k-bit model for training. Per the PEFT docs it
# upcasts non-quantized modules, enables input gradients and gradient
# checkpointing.
base = prepare_model_for_kbit_training(base)

# Wrap the base model with trainable low-rank adapters; only these are trained.
lora_cfg = LoraConfig(
    r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT,
    target_modules=TARGET_MODULES, bias="none", task_type="CAUSAL_LM"
)
model = get_peft_model(base, lora_cfg)
model.print_trainable_parameters()

# ===============================
# STEP 7: Training
# ===============================
train_args = TrainingArguments(
    output_dir=OUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=EVAL_BATCH,
    gradient_accumulation_steps=GRAD_ACCUM,
    num_train_epochs=EPOCHS,
    learning_rate=LEARNING_RATE,
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    # `evaluation_strategy` was renamed to `eval_strategy` (transformers 4.41)
    # and the old keyword has since been removed.
    eval_strategy="steps",
    eval_steps=SAVE_STEPS,
    save_total_limit=2,
    report_to="none",
    fp16=True,
    lr_scheduler_type="linear",
    optim="adamw_torch",
    # Keep the raw instruction/input/output columns: SFTCollator reads them,
    # and the Trainer would otherwise drop columns the model's forward()
    # does not accept.
    remove_unused_columns=False,
)

collator = SFTCollator(tok, max_length=MAX_LENGTH)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    tokenizer=tok,
    data_collator=collator
)

if TWO_STAGE_LR:
    # Build the optimizer now so a custom scheduler can wrap it. Trainer's
    # create_optimizer/create_scheduler only build an optimizer or scheduler
    # that is still unset, so train() keeps these.
    trainer.create_optimizer()

    def lr_lambda(step, low_lr=5e-6, high_lr=1e-4):
        """Return the LR multiplier: ``low_lr`` for the first half of training, ``high_lr`` after.

        LambdaLR multiplies the optimizer's base LR (LEARNING_RATE) by this
        value, so the target LR is divided by LEARNING_RATE. The previous
        code divided by 5e-6, which made both stages wrong.
        ``trainer.state.max_steps`` is read on every call because it is only
        filled in once training starts.
        """
        progress = step / max(1, trainer.state.max_steps)
        return (low_lr if progress < 0.5 else high_lr) / LEARNING_RATE

    trainer.lr_scheduler = LambdaLR(trainer.optimizer, lr_lambda)

print("\nStarting model training...")
train_result = trainer.train()
trainer.save_state()
print("Training finished.")

# ===============================
# STEP 8: Save Final Model & Tokenizer
# ===============================
# Write the PEFT-wrapped model and the tokenizer into one directory so they
# can be reloaded together.
final_adapter_dir = os.path.join(OUT_DIR, "final_adapter")
for artifact in (model, tok):
    artifact.save_pretrained(final_adapter_dir)
print(f"Saved final adapters and tokenizer to: {final_adapter_dir}")

# ===============================
# STEP 9: Metrics Definitions
# ===============================
def _to_items(y):
    try:
        obj = y if isinstance(y, (list, dict)) else json.loads(y)
    except Exception:
        return []
    items = []
    if isinstance(obj, dict):
        for k, v in obj.items():
            if isinstance(v, list):
                for vi in v: items.append((k, str(vi).lower()))
            else:
                items.append((k, str(v).lower()))
    elif isinstance(obj, list):
        for d in obj:
            if isinstance(d, dict):
                for k, v in d.items():
                    if isinstance(v, list):
                        for vi in v: items.append((k, str(vi).lower()))
                    else:
                        items.append((k, str(v).lower()))
    return items

def ner_prf1(preds, gts):
    """Micro-averaged precision / recall / F1 over NER ``(label, value)`` items.

    Each prediction/reference pair is flattened with ``_to_items`` into a
    multiset. Matches are counted per item up to the smaller multiplicity,
    and the counts are pooled across all pairs (``zip`` stops at the shorter
    list). A tiny epsilon keeps empty inputs from dividing by zero; they
    score 0.0.
    """
    eps = 1e-12
    tp = fp = fn = 0
    for pred, gold in zip(preds, gts):
        pred_bag = Counter(_to_items(pred))
        gold_bag = Counter(_to_items(gold))
        # Counter `&` keeps min counts; `-` keeps only positive surpluses.
        tp += sum((pred_bag & gold_bag).values())
        fp += sum((pred_bag - gold_bag).values())
        fn += sum((gold_bag - pred_bag).values())

    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return {"precision": precision, "recall": recall, "f1": f1}

def rouge_scores(preds, refs):
    """Mean ROUGE-1 / ROUGE-2 / ROUGE-L F-measure of ``preds`` against ``refs``.

    Pairs are scored with stemming enabled. Per-pair F-measures are summed
    and divided by ``len(preds)`` (at least 1, so empty input gives 0.0).
    """
    metric_names = ["rouge1", "rouge2", "rougeL"]
    scorer = rouge_scorer.RougeScorer(metric_names, use_stemmer=True)
    # RougeScorer.score takes (target, prediction).
    per_pair = [scorer.score(ref, pred) for pred, ref in zip(preds, refs)]
    denom = max(1, len(preds))
    return {
        name: sum(scores[name].fmeasure for scores in per_pair) / denom
        for name in metric_names
    }

def _tok(s): return re.findall(r"\w+", str(s).lower())

def qa_accuracy_f1(preds, refs):
    """Score free-text QA answers.

    Returns a dict with:

    * ``accuracy``: the number of pairs whose whitespace-stripped strings are
      identical (case-sensitive), divided by ``len(preds)`` (at least 1).
    * ``f1``: corpus-level (micro) token F1. Tokens come from ``_tok``.
      True-positive, false-positive and false-negative counts are pooled over
      all pairs before precision and recall are computed.
    """
    exact_hits = sum(pred.strip() == ref.strip() for pred, ref in zip(preds, refs))
    accuracy = exact_hits / max(1, len(preds))

    tp = fp = fn = 0
    for pred, ref in zip(preds, refs):
        pred_bag, ref_bag = Counter(_tok(pred)), Counter(_tok(ref))
        tp += sum((pred_bag & ref_bag).values())
        fp += sum((pred_bag - ref_bag).values())
        fn += sum((ref_bag - pred_bag).values())

    eps = 1e-12
    precision = tp / (tp + fp + eps)
    recall = tp / (tp + fn + eps)
    f1 = 2 * precision * recall / (precision + recall + eps)
    return {"accuracy": accuracy, "f1": f1}

# ===============================
# STEP 10: Evaluate on Test Set
# ===============================
print("\nStarting evaluation on the test set...")
model.eval()
ds_test = load_dataset("json", data_files={"test": TEST_FILE})["test"]


def _reference_text(value):
    """Return a test record's ``output`` as text for scoring.

    Strings pass through unchanged. Structured values (e.g. NER dicts/lists)
    are serialised as JSON so the NER metric can parse them; ``str()`` would
    produce a single-quoted Python repr that ``json.loads`` rejects.
    """
    if isinstance(value, str):
        return value
    return json.dumps(value, ensure_ascii=False)


preds, refs, tasks = [], [], []
for ex in ds_test:
    msgs = [
        {"role": "system", "content": "You are a clinical NLP assistant."},
        {"role": "user", "content": f"{ex.get('instruction', '')}\n\nINPUT:\n{ex.get('input', '')}"}
    ]
    prompt = tok.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
    # The rendered template already includes the model's special tokens
    # (e.g. BOS); letting the tokenizer add its own would duplicate BOS.
    inputs = tok(prompt, return_tensors="pt", add_special_tokens=False).to(model.device)
    with torch.inference_mode():
        out_ids = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,  # greedy decoding; `temperature` has no effect here
            pad_token_id=tok.pad_token_id,
        )
    # generate() returns prompt + continuation. Decode only the new tokens.
    # Splitting the full text on "assistant" cut the answer whenever that
    # word appeared in it.
    prompt_len = inputs["input_ids"].shape[-1]
    pred = tok.decode(out_ids[0, prompt_len:], skip_special_tokens=True).strip()

    preds.append(pred)
    refs.append(_reference_text(ex["output"]))
    tasks.append(ex.get("task", ""))

# --- Compute metrics by task ---
# task value in the data -> (label in the results dict, metric function)
task_metrics = {
    "ner": ("NER", ner_prf1),
    "summarization": ("Summarization", rouge_scores),
    "med_qa": ("MedicalQA", qa_accuracy_f1),
}
results = {}
for task_key, (label, metric_fn) in task_metrics.items():
    idx = [i for i, t in enumerate(tasks) if t == task_key]
    if idx:
        results[label] = metric_fn([preds[i] for i in idx], [refs[i] for i in idx])

unscored = sum(1 for t in tasks if t not in task_metrics)
if unscored:
    print(f"Note: {unscored} test example(s) had an unrecognised 'task' and were not scored.")

print("\n" + "="*50)
print("           EVALUATION RESULTS")
print("="*50)
pprint.pprint(results)
print("="*50)