In [None]:
"""
CPU-only LoRA fine-tuning for:
  Base model: Qwen/Qwen2.5-1.5B-Instruct
  Dataset:    Medical_QA_Dataset.csv with columns [qtype, Question, Answer]
------------------------------------------------------------
RUN (script):
python finetune_qwen25_medqa_cpu.py
------------------------------------------------------------
"""

import os
import math
import time
import random
import inspect
from dataclasses import dataclass
from typing import Dict, List, Any, Tuple

import pandas as pd
import torch
from packaging import version

from datasets import Dataset
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    set_seed,
)

from peft import LoraConfig, get_peft_model, PeftModel


# -----------------------------
# Your paths / defaults
# -----------------------------
# These values are only the argparse defaults used by parse_args(); each one can be
# overridden on the command line (--data_path, --model_name, --output_dir, ...).

# Placeholder path: edit it (or pass --data_path) before running, otherwise
# load_and_clean_csv() raises FileNotFoundError.
DEFAULT_DATA_PATH = r"YOUR DIRECTORY to the Dataset\Medical_QA_Dataset.csv"
DEFAULT_MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
DEFAULT_OUTPUT_DIR = "./qwen25_1p5b_medqa_lora_cpu"

# Maximum number of tokens kept per example (passed as max_length with truncation=True).
DEFAULT_MAX_SEQ_LEN = 512
# Value passed as test_size to Dataset.train_test_split for the evaluation split.
DEFAULT_TEST_SIZE = 0.05
# Seed used for set_seed / random / torch and for the train/eval split.
DEFAULT_SEED = 42

# LoRA defaults (forwarded to LoraConfig as r, lora_alpha and lora_dropout)
DEFAULT_LORA_R = 8
DEFAULT_LORA_ALPHA = 16
DEFAULT_LORA_DROPOUT = 0.05


# -----------------------------
# Safety-oriented system prompt
# -----------------------------
# Used as the "system" message both when building training examples and in the
# optional test generation, so training and inference see the same instructions.
SYSTEM_PROMPT = (
    "You are a careful medical information assistant. Provide general educational information, "
    "not personal medical advice. Encourage consulting qualified clinicians for diagnosis and treatment. "
    "If symptoms suggest an emergency, advise seeking urgent care. If unsure, say you don't know."
)


def parse_args():
    """
    Parse command-line options for the fine-tuning run.

    The parser is deliberately tolerant: it uses parse_known_args() so that
    extra arguments injected by the host environment (for example the
    ``-f <kernel.json>`` pair that Jupyter/IPykernel adds) are reported and
    ignored instead of aborting the run.

    Returns:
        argparse.Namespace holding every option declared below.
    """
    import argparse

    p = argparse.ArgumentParser()

    # (flag, type, default) -- registered in this exact order.
    valued_options = (
        # Paths / model
        ("--data_path", str, DEFAULT_DATA_PATH),
        ("--model_name", str, DEFAULT_MODEL_NAME),
        ("--output_dir", str, DEFAULT_OUTPUT_DIR),
        # Data handling
        ("--max_seq_len", int, DEFAULT_MAX_SEQ_LEN),
        ("--test_size", float, DEFAULT_TEST_SIZE),
        ("--seed", int, DEFAULT_SEED),
        # Training hyperparams (CPU-friendly)
        ("--per_device_train_batch_size", int, 1),
        ("--per_device_eval_batch_size", int, 1),
        ("--gradient_accumulation_steps", int, 16),
        ("--learning_rate", float, 2e-4),
        ("--weight_decay", float, 0.0),
        ("--num_train_epochs", float, 1.0),
        ("--max_steps", int, -1),  # set >0 to cap steps for quick test
        ("--warmup_ratio", float, 0.03),
        ("--lr_scheduler_type", str, "cosine"),
        # LoRA hyperparams
        ("--lora_r", int, DEFAULT_LORA_R),
        ("--lora_alpha", int, DEFAULT_LORA_ALPHA),
        ("--lora_dropout", float, DEFAULT_LORA_DROPOUT),
        # Data limiting for smoke tests (values <= 0 mean "use everything")
        ("--max_train_samples", int, -1),
        ("--max_eval_samples", int, -1),
    )
    for flag, value_type, default in valued_options:
        p.add_argument(flag, type=value_type, default=default)

    # Boolean switches: merge the adapter into the base weights / run a sample generation.
    for switch in ("--merge_model", "--run_test_prompt"):
        p.add_argument(switch, action="store_true")

    args, unknown = p.parse_known_args()
    if unknown:
        print(f"[INFO] Ignoring unknown argv passed by environment: {unknown}")
    return args


def require_min_transformers():
    """
    Fail fast when the installed Transformers is older than 4.37.0.

    Older releases may not know the Qwen2 architecture and can fail later
    with errors such as KeyError: 'qwen2'.

    Raises:
        RuntimeError: if transformers.__version__ is below the minimum.
    """
    min_ver = version.parse("4.37.0")
    if version.parse(transformers.__version__) >= min_ver:
        return

    raise RuntimeError(
        f"transformers>={min_ver} required for Qwen2.5. You have transformers=={transformers.__version__}. "
        f"Run: pip install -U transformers"
    )


def safe_filter_kwargs_for_callable(callable_obj, kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Keep only the kwargs that the callable can actually receive by keyword.

    This makes the script robust across Transformers versions, where
    constructor parameters are added, renamed or removed between releases.

    Args:
        callable_obj: Function, method or class whose signature is inspected.
        kwargs: Candidate keyword arguments.

    Returns:
        A new dict with the supported entries. If the callable declares a
        ``**kwargs`` catch-all, every entry is kept. If the signature cannot
        be inspected, the original ``kwargs`` object is returned unchanged.
    """
    try:
        params = list(inspect.signature(callable_obj).parameters.values())
    except (TypeError, ValueError):
        # If signature can't be inspected, return original kwargs
        return kwargs

    # A **kwargs parameter accepts any keyword, so nothing should be dropped.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return dict(kwargs)

    # Positional-only parameters cannot be passed by keyword, so they do not count.
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    accepted = {p.name for p in params if p.kind in keyword_kinds}
    return {k: v for k, v in kwargs.items() if k in accepted}


def load_and_clean_csv(data_path: str) -> Dataset:
    """
    Read the Q/A CSV and return it as a cleaned Hugging Face Dataset.

    Only the columns qtype, Question and Answer are kept. Missing values
    become empty strings, every value is cast to str and stripped, and rows
    whose Question or Answer is empty after stripping are removed.

    Args:
        data_path: Path to the CSV file.

    Returns:
        Dataset with the three string columns and no pandas index column.

    Raises:
        FileNotFoundError: if data_path is not an existing file.
        ValueError: if any of the required columns is absent.
    """
    if not os.path.isfile(data_path):
        raise FileNotFoundError(f"CSV not found: {data_path}")

    required = ["qtype", "Question", "Answer"]
    df = pd.read_csv(data_path)

    missing = [c for c in required if c not in df.columns]
    if missing:
        raise ValueError(f"Missing columns in CSV: {missing}. Found columns: {list(df.columns)}")

    # Work on a copy restricted to the needed columns and normalise each to stripped text.
    df = df[required].copy()
    for column in required:
        df[column] = df[column].fillna("").astype(str).str.strip()

    # Drop empty Q/A
    has_question = df["Question"].ne("")
    has_answer = df["Answer"].ne("")
    df = df.loc[has_question & has_answer].reset_index(drop=True)

    return Dataset.from_pandas(df, preserve_index=False)


def build_user_content(qtype: str, question: str) -> str:
    """
    Build the text of the user turn.

    When a meaningful question type is available it is prepended to the
    question; an empty type, or one that is just the text "nan"/"none"
    (case-insensitive), is ignored and the bare question is returned.
    """
    if not qtype:
        return question
    if qtype.lower() in {"nan", "none"}:
        return question
    return f"Question type: {qtype}\n\nQuestion: {question}"


def find_lora_target_modules(model: torch.nn.Module) -> List[str]:
    """
    Collect the LoRA target module names present in the model.

    Walks model.named_modules() and keeps the final path component of every
    torch.nn.Linear whose name is one of the usual attention/MLP projection
    names of Qwen/Llama-like architectures.

    Returns:
        Sorted list of the distinct projection names that were found.

    Raises:
        RuntimeError: if none of the known projection names is present.
    """
    known_projections = frozenset(
        ("q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj")
    )

    # Last dotted component of every Linear layer, e.g. "...self_attn.q_proj" -> "q_proj".
    linear_leaf_names = {
        path.rpartition(".")[2]
        for path, layer in model.named_modules()
        if isinstance(layer, torch.nn.Linear)
    }
    matches = linear_leaf_names & known_projections

    if matches:
        return sorted(matches)

    raise RuntimeError(
        "Could not auto-detect LoRA target modules. "
        "Inspect model.named_modules() and set target_modules manually."
    )


@dataclass
class DataCollatorForCausalLMWithLabels:
    """
    Collate tokenized examples into a padded batch with aligned labels.

    input_ids and attention_mask are padded by ``tokenizer.pad``; labels are
    padded here with -100 so the padded positions are ignored by the loss.
    Label padding is applied on the same side the tokenizer pads on
    (``tokenizer.padding_side``), so labels stay aligned with input_ids for
    both right- and left-padding tokenizers.
    """

    # Tokenizer-like object providing ``pad(...)`` and, optionally, ``padding_side``.
    tokenizer: Any

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """
        Build one batch.

        Args:
            features: Examples, each with "input_ids", "attention_mask" and "labels".

        Returns:
            The padded batch from ``tokenizer.pad`` with an extra "labels"
            tensor of dtype long and the same width as "input_ids".
        """
        batch = self.tokenizer.pad(
            {
                "input_ids": [f["input_ids"] for f in features],
                "attention_mask": [f["attention_mask"] for f in features],
            },
            padding=True,
            return_tensors="pt",
        )

        max_len = batch["input_ids"].shape[1]
        # Follow the tokenizer's padding side; default to right when it is not exposed.
        pad_on_left = getattr(self.tokenizer, "padding_side", "right") == "left"

        labels = []
        for f in features:
            lab = list(f["labels"])
            shortfall = max_len - len(lab)
            if shortfall > 0:
                filler = [-100] * shortfall
                lab = filler + lab if pad_on_left else lab + filler
            else:
                # Never emit a row wider than the padded inputs.
                lab = lab[:max_len]
            labels.append(lab)

        batch["labels"] = torch.tensor(labels, dtype=torch.long)
        return batch


def make_tokenize_fn(tokenizer: Any, max_seq_len: int):
    """
    Build the per-example tokenization function used with Dataset.map.

    The returned function:
    - renders the system + user turns (with the generation prompt appended)
      and the full system + user + assistant conversation through
      tokenizer.apply_chat_template,
    - tokenizes both texts without extra special tokens, truncated to
      max_seq_len,
    - copies the full input_ids into labels and sets the first
      len(prompt tokens) positions to -100 so that the loss only covers the
      assistant answer.

    Note: when the prompt alone already fills the truncated window, every
    label of that example ends up as -100.

    Args:
        tokenizer: Tokenizer exposing apply_chat_template and __call__.
        max_seq_len: Truncation length applied to both tokenizations.

    Returns:
        A function mapping one raw example (keys qtype, Question, Answer) to
        a dict with input_ids, attention_mask and labels.
    """

    def _render(messages, with_generation_prompt):
        # Chat-template rendering to plain text (no tokenization at this stage).
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=with_generation_prompt,
        )

    def _encode(text):
        # The template already inserted the special tokens, so none are added here.
        return tokenizer(
            text,
            add_special_tokens=False,
            truncation=True,
            max_length=max_seq_len,
        )

    def tokenize_example(ex: Dict[str, Any]) -> Dict[str, Any]:
        user_content = build_user_content(
            qtype=ex.get("qtype", ""),
            question=ex.get("Question", ""),
        )
        answer = ex.get("Answer", "")

        messages_prompt = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ]
        messages_full = messages_prompt + [{"role": "assistant", "content": answer}]

        prompt_text = _render(messages_prompt, True)
        full_text = _render(messages_full, False)

        prompt_ids = _encode(prompt_text)["input_ids"]
        full = _encode(full_text)

        input_ids = full["input_ids"]

        # Mask the prompt part; min() guards against a prompt longer than the truncated sequence.
        masked = min(len(prompt_ids), len(input_ids))
        labels = input_ids.copy()
        labels[:masked] = [-100] * masked

        return {
            "input_ids": input_ids,
            "attention_mask": full["attention_mask"],
            "labels": labels,
        }

    return tokenize_example


def maybe_make_output_dir_unique(output_dir: str) -> str:
    """
    Return a directory that is safe to write into, creating it if needed.

    If output_dir exists and is non-empty, a sibling directory named
    ``<output_dir>_<YYYYmmdd_HHMMSS>`` is created and returned instead, so
    nothing already on disk is deleted or overwritten. Otherwise output_dir
    itself is created (if missing) and returned.

    Args:
        output_dir: Requested output directory.

    Returns:
        The path that should actually be used for outputs.
    """
    if os.path.isdir(output_dir) and len(os.listdir(output_dir)) > 0:
        stamp = time.strftime("%Y%m%d_%H%M%S")
        # Strip trailing path separators outside the f-string: a backslash inside an
        # f-string replacement field is a SyntaxError before Python 3.12.
        trimmed = output_dir.rstrip("/\\")
        new_dir = f"{trimmed}_{stamp}"
        print(f"[INFO] output_dir is not empty. Writing to a new directory:\n       {new_dir}")
        os.makedirs(new_dir, exist_ok=True)
        return new_dir

    os.makedirs(output_dir, exist_ok=True)
    return output_dir


def build_training_arguments(args) -> TrainingArguments:
    """
    Build TrainingArguments for a CPU-only run across Transformers versions.

    A superset of options is prepared and then reduced to the parameters the
    installed TrainingArguments actually declares. Where an option exists
    under an old and a new name, the new name is preferred so that versions
    accepting both do not emit deprecation warnings:
    - ``use_cpu`` is preferred over the legacy ``no_cuda``;
    - ``eval_strategy`` is preferred over the legacy ``evaluation_strategy``.

    Options that the installed version does not accept are reported instead
    of being dropped silently.

    Side effect: ``args.output_dir`` is replaced by the directory returned by
    maybe_make_output_dir_unique, so an existing non-empty directory is never
    written into.

    Args:
        args: Parsed CLI namespace (see parse_args).

    Returns:
        A TrainingArguments instance.
    """
    # To avoid failing or overwriting user data, ensure a unique output directory if needed.
    args.output_dir = maybe_make_output_dir_unique(args.output_dir)

    accepted = set(inspect.signature(TrainingArguments.__init__).parameters)

    base_kwargs = {
        "output_dir": args.output_dir,

        # Batch & accumulation
        "per_device_train_batch_size": args.per_device_train_batch_size,
        "per_device_eval_batch_size": args.per_device_eval_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,

        # Optim schedule
        "learning_rate": args.learning_rate,
        "weight_decay": args.weight_decay,
        "num_train_epochs": args.num_train_epochs,
        "max_steps": args.max_steps,
        # NOTE(review): the run log reports warmup_ratio as deprecated in favour of
        # warmup_steps; once it is removed it will show up in the "not accepted" message below.
        "warmup_ratio": args.warmup_ratio,
        "lr_scheduler_type": args.lr_scheduler_type,

        # Logging/saving
        "logging_steps": 10,
        "save_total_limit": 2,

        # Trainer behavior
        "remove_unused_columns": False,
        "report_to": "none",

        # Windows / CPU friendliness
        "dataloader_num_workers": 0,
        "dataloader_pin_memory": False,

        # Precision flags (should be off on CPU)
        "fp16": False,
        "bf16": False,

        # Helpful on low-memory (slower but reduces RAM)
        "gradient_checkpointing": True,

        # Optimizer choice (if supported)
        "optim": "adamw_torch",
    }

    # Force CPU: pass exactly one of the two spellings, preferring the current one.
    if "use_cpu" in accepted:
        base_kwargs["use_cpu"] = True
    elif "no_cuda" in accepted:
        base_kwargs["no_cuda"] = True

    # Evaluate once per epoch, using whichever parameter name this version declares.
    if "eval_strategy" in accepted:
        base_kwargs["eval_strategy"] = "epoch"
    elif "evaluation_strategy" in accepted:
        base_kwargs["evaluation_strategy"] = "epoch"

    if "save_strategy" in accepted:
        base_kwargs["save_strategy"] = "epoch"

    # overwrite_output_dir is optional; only pass it if accepted.
    if "overwrite_output_dir" in accepted:
        base_kwargs["overwrite_output_dir"] = True

    # Surface anything this version cannot take, so a changed schedule/setting is visible.
    dropped = sorted(k for k in base_kwargs if k not in accepted)
    if dropped:
        print(f"[INFO] TrainingArguments does not accept these options in this version; skipping: {dropped}")

    filtered = {k: v for k, v in base_kwargs.items() if k in accepted}
    return TrainingArguments(**filtered)


def build_trainer(model, training_args, train_dataset, eval_dataset, data_collator, tokenizer) -> Trainer:
    """
    Create a Trainer using only the constructor parameters this version declares.

    The tokenizer is passed as ``tokenizer`` when Trainer.__init__ declares
    that parameter, otherwise as ``processing_class`` when that one is
    declared, and is omitted when neither exists.

    Returns:
        The constructed Trainer.
    """
    supported = inspect.signature(Trainer.__init__).parameters

    candidate_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
    )

    # First matching name wins: "tokenizer" first, then "processing_class".
    tokenizer_param = next(
        (name for name in ("tokenizer", "processing_class") if name in supported),
        None,
    )
    if tokenizer_param is not None:
        candidate_kwargs[tokenizer_param] = tokenizer

    usable_kwargs = {
        name: value for name, value in candidate_kwargs.items() if name in supported
    }
    return Trainer(**usable_kwargs)


def load_model_cpu(model_name: str):
    """
    Load the causal LM in float32 without trusting remote code.

    Newer Transformers releases take the precision as ``dtype``; older ones
    only know ``torch_dtype``. ``dtype`` is tried first and ``torch_dtype``
    is used as a fallback, but only when the first call fails with a
    TypeError/ValueError (the way an unsupported keyword is expected to
    surface). Other failures -- missing model, network errors, out of memory
    -- propagate immediately instead of triggering a second full load that
    would fail the same way and hide the original error.

    No device placement is requested here; the model stays on whatever
    device from_pretrained uses by default.

    Args:
        model_name: Hub id or local path of the model.

    Returns:
        The loaded model.
    """
    try:
        return AutoModelForCausalLM.from_pretrained(
            model_name,
            dtype=torch.float32,
            trust_remote_code=False,
        )
    except (TypeError, ValueError) as err:
        print(f"[INFO] Loading with `dtype` failed ({err}); retrying with `torch_dtype`.")
        return AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float32,
            trust_remote_code=False,
        )


def main():
    """
    Run the whole fine-tuning pipeline end to end.

    Steps, in order: parse arguments, check the Transformers version, seed the
    RNGs, load and split the CSV, load tokenizer and base model, wrap the
    model with LoRA adapters, tokenize both splits, train, evaluate (adding
    perplexity = exp(eval_loss)), and save the adapter plus tokenizer under
    ``<output_dir>/adapter``. Optionally (``--merge_model``) a merged
    standalone model is written to ``<output_dir>/merged_model`` and
    (``--run_test_prompt``) one sample answer is generated and printed.
    """
    args = parse_args()

    print(f"[INFO] torch: {torch.__version__}")
    print(f"[INFO] transformers: {transformers.__version__}")

    require_min_transformers()

    # Seed everything with the same value for reproducibility.
    set_seed(args.seed)
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Load dataset
    raw_ds = load_and_clean_csv(args.data_path)
    split = raw_ds.train_test_split(test_size=args.test_size, seed=args.seed)
    train_ds = split["train"]
    eval_ds = split["test"]

    # Optional: limit samples for CPU speed
    # (values <= 0 leave the split untouched; the first N rows are kept otherwise)
    if args.max_train_samples and args.max_train_samples > 0:
        train_ds = train_ds.select(range(min(args.max_train_samples, len(train_ds))))
    if args.max_eval_samples and args.max_eval_samples > 0:
        eval_ds = eval_ds.select(range(min(args.max_eval_samples, len(eval_ds))))

    # Tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name, trust_remote_code=False)
    # Padding needs a pad token; fall back to EOS when the tokenizer defines none.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    # The data collator pads labels on the right, so inputs must be padded on the right too.
    tokenizer.padding_side = "right"

    # Model (CPU float32)
    model = load_model_cpu(args.model_name)

    # Training settings
    model.config.use_cache = False
    if hasattr(model, "gradient_checkpointing_enable"):
        model.gradient_checkpointing_enable()
    # NOTE(review): presumably required so gradients flow through checkpointed blocks
    # when the base weights are frozen by LoRA -- confirm against the PEFT docs.
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()

    # LoRA
    target_modules = find_lora_target_modules(model)
    print(f"[INFO] LoRA target_modules detected: {target_modules}")

    lora_config = LoraConfig(
        r=args.lora_r,
        lora_alpha=args.lora_alpha,
        lora_dropout=args.lora_dropout,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=target_modules,
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    # Tokenize
    # The raw text columns are removed so only input_ids / attention_mask / labels remain.
    tokenize_fn = make_tokenize_fn(tokenizer, args.max_seq_len)
    train_tok = train_ds.map(tokenize_fn, remove_columns=train_ds.column_names)
    eval_tok = eval_ds.map(tokenize_fn, remove_columns=eval_ds.column_names)

    data_collator = DataCollatorForCausalLMWithLabels(tokenizer)

    # TrainingArguments (robust)
    # Note: this may redirect args.output_dir to a fresh time-stamped directory.
    training_args = build_training_arguments(args)

    # Trainer (robust)
    trainer = build_trainer(
        model=model,
        training_args=training_args,
        train_dataset=train_tok,
        eval_dataset=eval_tok,
        data_collator=data_collator,
        tokenizer=tokenizer,
    )

    # Train
    print("[INFO] Starting training...")
    train_result = trainer.train()
    if hasattr(train_result, "metrics"):
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
    trainer.save_state()

    # Eval
    print("[INFO] Evaluating...")
    eval_metrics = trainer.evaluate()
    # Perplexity is exp(eval_loss); a very large loss overflows math.exp, reported as inf.
    if "eval_loss" in eval_metrics and eval_metrics["eval_loss"] is not None:
        try:
            eval_metrics["perplexity"] = math.exp(eval_metrics["eval_loss"])
        except OverflowError:
            eval_metrics["perplexity"] = float("inf")
    trainer.log_metrics("eval", eval_metrics)
    trainer.save_metrics("eval", eval_metrics)

    # Save LoRA adapter
    # The tokenizer is stored next to the adapter so the directory is self-contained.
    adapter_dir = os.path.join(training_args.output_dir, "adapter")
    os.makedirs(adapter_dir, exist_ok=True)
    print(f"[INFO] Saving LoRA adapter to: {adapter_dir}")
    model.save_pretrained(adapter_dir)
    tokenizer.save_pretrained(adapter_dir)

    # Optional: merge LoRA into base model
    # A fresh copy of the base model is loaded and the adapter just saved is applied to it.
    if args.merge_model:
        print("[INFO] Merging LoRA adapter into base model (CPU)...")
        base = load_model_cpu(args.model_name)
        merged = PeftModel.from_pretrained(base, adapter_dir)
        merged = merged.merge_and_unload()

        merged_dir = os.path.join(training_args.output_dir, "merged_model")
        os.makedirs(merged_dir, exist_ok=True)
        print(f"[INFO] Saving merged model to: {merged_dir}")
        merged.save_pretrained(merged_dir, safe_serialization=True)
        tokenizer.save_pretrained(merged_dir)

    # Optional: quick generation test
    # Uses the in-memory LoRA-wrapped model (not the merged copy) with sampling enabled.
    if args.run_test_prompt:
        print("[INFO] Running a quick generation test...")
        model.eval()

        test_question = "What are common symptoms of influenza (flu), and when should someone seek urgent care?"
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": test_question},
        ]
        text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = tokenizer([text], return_tensors="pt")

        with torch.no_grad():
            out = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        # Skip the prompt tokens so only the newly generated continuation is decoded.
        gen = out[0][inputs["input_ids"].shape[1] :]
        print("\n=== MODEL OUTPUT ===")
        print(tokenizer.decode(gen, skip_special_tokens=True))
        print("====================\n")

    print("[DONE] Training complete.")
    print(f"[DONE] Adapter saved at: {adapter_dir}")
    if args.merge_model:
        print(f"[DONE] Merged model saved at: {os.path.join(training_args.output_dir, 'merged_model')}")


# Run the pipeline only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()


[INFO] Ignoring unknown argv passed by environment: ['-f', 'C:\\Users\\Jaber\\AppData\\Roaming\\jupyter\\runtime\\kernel-5fa8135d-4c2d-4c81-b489-a61ed1f53f92.json']
[INFO] torch: 2.10.0+cpu
[INFO] transformers: 5.1.0


Loading weights:   0%|          | 0/338 [00:00<?, ?it/s]

[INFO] LoRA target_modules detected: ['down_proj', 'gate_proj', 'k_proj', 'o_proj', 'q_proj', 'up_proj', 'v_proj']
trainable params: 9,232,384 || all params: 1,552,946,688 || trainable%: 0.5945


Map:   0%|          | 0/15586 [00:00<?, ? examples/s]

Map:   0%|          | 0/821 [00:00<?, ? examples/s]

warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


[INFO] Starting training...


Epoch,Training Loss,Validation Loss
1,1.089835,1.093011


***** train metrics *****
  epoch                    =                1.0
  total_flos               =         34264056GF
  train_loss               =             1.0867
  train_runtime            = 3 days, 6:56:47.67
  train_samples_per_second =              0.055
  train_steps_per_second   =              0.003
[INFO] Evaluating...


***** eval metrics *****
  epoch                   =        1.0
  eval_loss               =      1.093
  eval_runtime            = 0:59:45.62
  eval_samples_per_second =      0.229
  eval_steps_per_second   =      0.229
  perplexity              =     2.9832
[INFO] Saving LoRA adapter to: ./qwen25_1p5b_medqa_lora_cpu\adapter
[DONE] Training complete.
[DONE] Adapter saved at: ./qwen25_1p5b_medqa_lora_cpu\adapter
