# MedGemma 4B LoRA Fine‑Tuning (FP16/BF16, No Quantization)

**Notebook version of `train_lora_4b.py`** — split into clean, runnable sections.

## What this does
- Loads **MedGemma 1.5 4B IT** in **fp16/bf16** (no 4‑bit / bitsandbytes)
- Adds **LoRA adapters** (PEFT)
- Uses **assistant‑only loss masking** (labels = -100 on prompt tokens)
- Uses **pad‑safe collator** (`labels` padded with -100)
- Adds **token_type_ids** (required for Gemma3 training)
- Avoids LoRA injection into the vision tower by filtering module paths
- Hard‑disables `use_cache` everywhere + uses non‑reentrant gradient checkpointing where possible

## Prereqs
- CUDA GPU required
- Each line of the train/val JSONL files must be a JSON object with text fields `prompt` (the user turn) and `target` (the expected assistant reply)
- You must have **`hf_auth.py`** available (same folder or in `PYTHONPATH`) providing:
  - `get_hf_token(hf_token_arg: str) -> str`
  - `try_with_token(fn, *args, token=..., **kwargs)`


## 1) Install dependencies (optional)
If you're in a fresh environment, run this cell once.
> If you already have these installed, you can skip.

In [None]:
# If needed (uncomment):
# %pip install -U "torch" "transformers" "datasets" "peft" "numpy"

# Optional environment report: interpreter, torch build, CUDA visibility and,
# when CUDA is present, the name of device 0.
import sys
import torch

print(f"Python: {sys.version}")
print(f"Torch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")


## 2) Configuration
Edit these values instead of CLI flags.

In [None]:
# =============================
# CONFIG (edit me)
# =============================
CFG = {
    # Data, base model, and output directory
    "train_jsonl": "data/train.jsonl",
    "val_jsonl": "data/val.jsonl",
    "model_name": "google/medgemma-1.5-4b-it",
    "out_dir": "medgemma4b_icu_lora",
    # Sequence length, optimisation, reproducibility
    "max_len": 1024,
    "eval_max_len": 0,      # 0 => use max_len; set smaller to reduce eval VRAM
    "epochs": 1,
    "lr": 2e-4,
    "batch": 1,             # per-device train batch size
    "grad_accum": 16,
    "seed": 7,
    # LoRA adapter hyper-parameters
    "lora_r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    # Evaluation / auth / schedule / logging / export
    "eval_steps": 0,        # 0 => eval (and save) each epoch; otherwise every N steps
    "hf_token": "",         # or set HF_TOKEN env var; see hf_auth.py behavior
    "warmup_ratio": 0.03,
    "scheduler": "cosine",  # cosine | linear | constant
    "log_steps": 10,
    "save_merged": False,   # True => also write a merged full model to <out_dir>/merged
    "no_eval": False,       # True => disable eval to avoid eval OOM
}


## 3) Imports

In [None]:
import math
import os
import random
import warnings
import logging
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Dict, List

import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoModelForImageTextToText,
    AutoProcessor,
    TrainingArguments,
    Trainer,
)

from hf_auth import get_hf_token, try_with_token


## 4) Repro + dtype helpers

In [None]:
def set_all_seeds(seed: int):
    """Seed Python's ``random``, NumPy, and PyTorch with the same value.

    When CUDA is available, every visible CUDA device is seeded as well.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def pick_compute_dtype():
    """Choose the training dtype.

    Returns ``torch.bfloat16`` on CUDA devices that report bf16 support,
    ``torch.float16`` on other CUDA devices, and ``torch.float32`` without CUDA.
    """
    if not torch.cuda.is_available():
        return torch.float32
    return torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16


def print_gpu_mem(prefix=""):
    """Print current, reserved and peak allocated CUDA memory in GiB.

    Does nothing when CUDA is unavailable. Any error raised while querying the
    allocator is deliberately ignored so this diagnostic never interrupts a run.

    Args:
        prefix: Text printed in front of the memory line.
    """
    if torch.cuda.is_available():
        gib = 1024**3
        try:
            torch.cuda.synchronize()  # make the counters reflect queued kernels
            allocated = torch.cuda.memory_allocated() / gib
            reserved = torch.cuda.memory_reserved() / gib
            peak = torch.cuda.max_memory_allocated() / gib
            print(f"{prefix}CUDA mem: allocated={allocated:.2f}GB reserved={reserved:.2f}GB max_alloc={peak:.2f}GB")
        except Exception:
            pass


## 5) Gradient checkpointing + `use_cache` hard-disable
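These helpers replicate the warning fixes from `train_lora_4b.py`: they default `torch.utils.checkpoint.checkpoint` to `use_reentrant=False`, enable gradient checkpointing, and force `use_cache=False` on every reachable config.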

In [None]:
# ------------------------------------------------------------
# PyTorch checkpoint warning fix (PyTorch 2.9+ may require explicit use_reentrant)
# ------------------------------------------------------------
def patch_torch_checkpoint_default_use_reentrant_false():
    """Make ``torch.utils.checkpoint.checkpoint`` default to ``use_reentrant=False``.

    The module attribute is replaced by a thin wrapper that fills in
    ``use_reentrant=False`` only when the caller did not pass it, so explicit
    choices are still honoured. ``functools.wraps`` keeps the original's name,
    docstring and ``__wrapped__`` (so ``inspect.signature`` still reports the
    real parameters).

    Best effort and idempotent: it does nothing when this torch build's
    ``checkpoint`` has no ``use_reentrant`` parameter, when the patch is already
    installed (e.g. the cell is re-run), or when anything fails while patching.

    NOTE(review): only code that looks up ``torch.utils.checkpoint.checkpoint``
    after this runs sees the wrapper; a module that already bound the name via
    ``from torch.utils.checkpoint import checkpoint`` keeps the original.
    """
    try:
        import functools
        import inspect
        import torch.utils.checkpoint as ckpt

        current = ckpt.checkpoint
        if getattr(current, "_patched_use_reentrant_default", False):
            return  # already patched
        if "use_reentrant" not in inspect.signature(current).parameters:
            return  # older torch: nothing to default

        @functools.wraps(current)
        def _wrapped(function, *args, **kwargs):
            kwargs.setdefault("use_reentrant", False)
            return current(function, *args, **kwargs)

        _wrapped._patched_use_reentrant_default = True
        ckpt.checkpoint = _wrapped
    except Exception:
        # Deliberately best-effort: a failed patch only means PyTorch keeps
        # emitting its warning; it must never stop the notebook.
        return


def enable_gc_no_reentrant(model):
    """Turn on gradient checkpointing, asking for non-reentrant checkpoints.

    Models without a ``gradient_checkpointing_enable`` attribute are left
    untouched. If the call with ``gradient_checkpointing_kwargs`` raises
    ``TypeError`` (an older signature without that parameter), checkpointing
    is enabled again with the method's defaults.
    """
    if hasattr(model, "gradient_checkpointing_enable"):
        nonreentrant = {"use_reentrant": False}
        try:
            model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=nonreentrant)
        except TypeError:
            model.gradient_checkpointing_enable()


def force_disable_use_cache_everywhere(model):
    """Set ``use_cache = False`` on ``model`` and on config objects reachable from it.

    Three sweeps are made:
      1. ``model`` itself, recursing into the config-like attributes
         ``config``, ``generation_config``, ``text_config``, ``language_config``,
         ``vision_config`` and ``model_config``.
      2. The wrapper attributes ``model``, ``base_model``, ``language_model``,
         ``vision_tower`` and ``vision_model`` of ``model`` (same recursion).
      3. ``config`` / ``generation_config`` of every entry of ``model.modules()``.

    Every object is visited at most once (tracked by ``id``), which also stops
    reference cycles between configs. Failures while setting the flag, during
    recursion, in sweep 2, or anywhere in sweep 3 are deliberately ignored.
    """
    nested_attrs = (
        "config",
        "generation_config",
        "text_config",
        "language_config",
        "vision_config",
        "model_config",
    )
    wrapper_attrs = ("model", "base_model", "language_model", "vision_tower", "vision_model")
    seen_ids = set()

    def _best_effort(action):
        # Run `action`, discarding any exception it raises.
        try:
            action()
        except Exception:
            pass

    def _disable(target):
        if target is None or id(target) in seen_ids:
            return
        seen_ids.add(id(target))
        if hasattr(target, "use_cache"):
            _best_effort(lambda: setattr(target, "use_cache", False))
        for attr in nested_attrs:
            if hasattr(target, attr):
                _best_effort(lambda attr=attr: _disable(getattr(target, attr)))

    def _sweep_submodules():
        for submodule in model.modules():
            if hasattr(submodule, "config"):
                _disable(submodule.config)
            if hasattr(submodule, "generation_config"):
                _disable(submodule.generation_config)

    _disable(model)
    for attr in wrapper_attrs:
        if hasattr(model, attr):
            _best_effort(lambda attr=attr: _disable(getattr(model, attr)))
    _best_effort(_sweep_submodules)


@contextmanager
def suppress_use_cache_gc_warning():
    """Temporarily set the ``transformers`` and ``transformers.modeling_utils`` loggers to ERROR.

    Intended to wrap gradient-checkpointing enablement so the transformers
    warning logged there is hidden. The previous levels are restored on exit,
    including when the wrapped code raises.
    """
    saved = [
        (logger, logger.level)
        for logger in map(logging.getLogger, ("transformers.modeling_utils", "transformers"))
    ]
    try:
        for logger, _previous in saved:
            logger.setLevel(logging.ERROR)
        yield
    finally:
        for logger, previous in saved:
            logger.setLevel(previous)


## 6) Chat-template tokenization (assistant-only loss)
Prompt tokens are labelled -100, so the loss covers only the assistant reply. `token_type_ids` are all zeros because the input is text-only.

In [None]:
def _messages_text_only(prompt: str, target: str | None = None) -> List[Dict[str, Any]]:
    msgs = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
    if target is not None:
        msgs.append({"role": "assistant", "content": [{"type": "text", "text": target}]})
    return msgs


def _ensure_1d_list(x):
    if isinstance(x, dict) and "input_ids" in x:
        x = x["input_ids"]
    if torch.is_tensor(x):
        x = x.tolist()
    if isinstance(x, list) and len(x) > 0 and isinstance(x[0], list):
        return x[0]
    return x


def _apply_chat(processor, messages, max_len: int, add_generation_prompt: bool):
    """Tokenize ``messages`` with the processor's chat template.

    Returns ``{"input_ids": [...], "attention_mask": [...]}`` as flat lists,
    truncated to ``max_len`` tokens. If the processor returns no attention mask,
    an all-ones mask of the same length as ``input_ids`` is synthesized.

    Raises:
        RuntimeError: if ``processor`` has no ``apply_chat_template``.
    """
    if not hasattr(processor, "apply_chat_template"):
        raise RuntimeError("Processor does not support apply_chat_template; please upgrade transformers.")

    encoded = processor.apply_chat_template(
        messages,
        add_generation_prompt=add_generation_prompt,
        tokenize=True,
        return_dict=True,
        truncation=True,
        max_length=max_len,
    )
    ids = _ensure_1d_list(encoded["input_ids"])
    raw_mask = encoded.get("attention_mask", None)
    mask = _ensure_1d_list(raw_mask) if raw_mask is not None else None
    return {
        "input_ids": ids,
        "attention_mask": mask if mask is not None else [1] * len(ids),
    }


def tokenize_fn(processor, example, max_len: int):
    """Tokenize one ``{"prompt", "target"}`` example for assistant-only loss.

    The full conversation (user prompt + assistant target) is tokenized, then
    the prompt alone with the generation prompt appended. The length of the
    second encoding marks where the assistant reply starts. Every label before
    that point is set to -100, so the loss covers only the reply.
    ``token_type_ids`` are all zeros because the input is text-only.

    NOTE(review): both encodings are truncated to ``max_len``. If the prompt
    alone reaches that length, every label ends up -100.
    """
    prompt, target = example["prompt"], example["target"]

    full = _apply_chat(processor, _messages_text_only(prompt, target), max_len=max_len, add_generation_prompt=False)
    prefix = _apply_chat(processor, _messages_text_only(prompt, None), max_len=max_len, add_generation_prompt=True)

    ids = full["input_ids"]
    n_masked = min(len(prefix["input_ids"]), len(ids))
    labels = [-100] * n_masked + list(ids[n_masked:])

    return {
        "input_ids": ids,
        "attention_mask": full["attention_mask"],
        "token_type_ids": [0] * len(ids),
        "labels": labels,
    }


## 7) Collator (pad-safe for labels + token_type_ids)

In [None]:
@dataclass
class CausalLMPadCollator:
    """Right-pad tokenized examples into a batch of ``torch.long`` tensors.

    Every row is padded to the longest ``input_ids`` in the batch.
    ``input_ids`` are padded with ``pad_token_id``, ``attention_mask`` and
    ``token_type_ids`` with 0, and ``labels`` with ``label_pad_id``. A missing
    ``attention_mask`` defaults to all ones and a missing ``token_type_ids`` to
    all zeros. Pad amounts are derived from each row's ``input_ids`` length.
    """

    pad_token_id: int
    label_pad_id: int = -100

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        width = max(len(feat["input_ids"]) for feat in features)
        columns = {"input_ids": [], "attention_mask": [], "token_type_ids": [], "labels": []}

        for feat in features:
            ids = feat["input_ids"]
            extra = width - len(ids)
            columns["input_ids"].append(ids + [self.pad_token_id] * extra)
            columns["attention_mask"].append(feat.get("attention_mask", [1] * len(ids)) + [0] * extra)
            columns["token_type_ids"].append(feat.get("token_type_ids", [0] * len(ids)) + [0] * extra)
            columns["labels"].append(feat["labels"] + [self.label_pad_id] * extra)

        return {key: torch.tensor(rows, dtype=torch.long) for key, rows in columns.items()}


## 8) LoRA target-module selection
Filters out vision/image/encoder paths and targets common projection layers.

In [None]:
def pick_target_module_paths(model):
    """Return the full module paths that should receive LoRA adapters.

    Selects, in ``model.named_modules()`` order, every module whose last path
    component is one of ``q_proj``, ``k_proj``, ``v_proj``, ``o_proj``,
    ``gate_proj``, ``up_proj`` or ``down_proj``. A path is skipped if, lower-cased,
    it contains ``vision``, ``visual``, ``image`` or ``encoder``, which keeps
    adapters out of the vision tower.

    If nothing matches, returns the bare names
    ``["q_proj", "k_proj", "v_proj", "o_proj"]``. NOTE(review): that fallback is
    returned without the banned-token filter above.
    """
    suffixes = {"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}
    banned_tokens = ("vision", "visual", "image", "encoder")

    def _is_language_projection(path):
        if path.rsplit(".", 1)[-1] not in suffixes:
            return False
        lowered = path.lower()
        return not any(token in lowered for token in banned_tokens)

    # A single pass is enough: the former q/k/v/o-only retry used a subset of
    # these suffixes with the same filter, so it could never find anything new.
    targets = [name for name, _module in model.named_modules() if _is_language_projection(name)]
    return targets or ["q_proj", "k_proj", "v_proj", "o_proj"]


def _supports_trainingargs_kw(cls, kw: str) -> bool:
    try:
        import inspect
        return kw in inspect.signature(cls.__init__).parameters
    except Exception:
        return False


## 9) Environment + model/processor load

In [None]:
# --- Safety checks ---
# Loads the processor/tokenizer and the full-precision (fp16/bf16) model onto
# GPU 0, then prepares it for training: use_cache off, gradient checkpointing on.
if not torch.cuda.is_available():
    raise SystemExit("CUDA not available.")

# Must run before any gradient checkpointing is used so the use_reentrant=False
# default is in place.
patch_torch_checkpoint_default_use_reentrant_false()

# NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when its CUDA allocator is
# initialized. CUDA has already been queried above (and possibly in the sanity
# cell), so confirm this setting actually takes effect. Setting it before
# `import torch` is the safe option. setdefault() keeps any value the user exported.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:128")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
# Allow TF32 for float32 matmuls (trades exact fp32 precision for speed).
torch.backends.cuda.matmul.allow_tf32 = True

set_all_seeds(int(CFG["seed"]))

# Token resolution (CFG value vs HF_TOKEN env var) is delegated to hf_auth.py.
token = get_hf_token(CFG.get("hf_token", ""))
compute_dtype = pick_compute_dtype()
print("compute_dtype:", compute_dtype)

# Processor
processor = try_with_token(AutoProcessor.from_pretrained, CFG["model_name"], token=token)
tok = getattr(processor, "tokenizer", None)
if tok is None:
    raise SystemExit("AutoProcessor did not expose a tokenizer. Please upgrade transformers.")

# Pad token: fall back to EOS when the tokenizer defines no pad token; the
# collator needs a concrete pad id.
if tok.pad_token_id is None:
    if tok.eos_token_id is None:
        raise SystemExit("Tokenizer has no pad_token_id and no eos_token_id; cannot pad.")
    tok.pad_token = tok.eos_token
pad_id = tok.pad_token_id

print_gpu_mem("[before load] ")

# Model (NO quantization): weights loaded directly in compute_dtype, with the
# whole model mapped to CUDA device 0 via device_map={"": 0}.
model = try_with_token(
    AutoModelForImageTextToText.from_pretrained,
    CFG["model_name"],
    device_map={"": 0},
    torch_dtype=compute_dtype,
    low_cpu_mem_usage=True,
    token=token,
)

print_gpu_mem("[after load]  ")

# Hard-disable use_cache everywhere
force_disable_use_cache_everywhere(model)

# Enable gradient checkpointing (silence transformers warning while enabling)
with suppress_use_cache_gc_warning():
    enable_gc_no_reentrant(model)

# Belt + suspenders: run the sweep again in case enabling GC touched any config.
force_disable_use_cache_everywhere(model)

# Helps some backbones under GC.
# NOTE(review): with frozen base weights (LoRA), checkpointed segments need
# inputs that require grad; per the transformers docs this hooks the input
# embeddings to provide that.
if hasattr(model, "enable_input_require_grads"):
    model.enable_input_require_grads()

print("Loaded model + processor OK.")


## 10) Attach LoRA adapters

In [None]:
# Choose the language-model projection layers that will receive adapters.
target_modules = pick_target_module_paths(model)
print(f"LoRA target modules count: {len(target_modules)}")
print(f"LoRA target sample: {target_modules[:8]}")

# Adapter config: rank/alpha/dropout from CFG; bias terms are not trained.
lora_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=target_modules,
    r=int(CFG["lora_r"]),
    lora_alpha=int(CFG["lora_alpha"]),
    lora_dropout=float(CFG["lora_dropout"]),
    bias="none",
)
model = get_peft_model(model, lora_cfg)

# `model` is now the PEFT wrapper; run the use_cache sweep on it as well.
force_disable_use_cache_everywhere(model)

# Informational only: a failing parameter summary must not break the cell.
try:
    model.print_trainable_parameters()
except Exception:
    pass


## 11) Load datasets + tokenize

In [None]:
def _has_supervised_tokens(example):
    """Return True if at least one label is not -100 (some target token survived)."""
    return any(label != -100 for label in example["labels"])


def _drop_unsupervised(ds, split_name):
    """Drop tokenized rows whose labels are all -100 and report how many were removed.

    tokenize_fn truncates the conversation to the token limit. When the prompt
    alone reaches that limit, the whole target is cut off and every label is
    masked. Such rows carry no training signal, and a loss averaged over zero
    supervised tokens can come out as NaN (e.g. eval with batch size 1), which
    would poison eval_loss and best-model selection.
    """
    kept = ds.filter(_has_supervised_tokens, desc=f"Dropping fully-masked {split_name} rows")
    dropped = len(ds) - len(kept)
    if dropped:
        print(
            f"[warn] dropped {dropped} {split_name} example(s) with no supervised tokens "
            "(prompt reaches the token limit, so the target was truncated away)"
        )
    return kept


train_ds = load_dataset("json", data_files={"train": CFG["train_jsonl"]})["train"].shuffle(seed=int(CFG["seed"]))

if not bool(CFG.get("no_eval", False)):
    val_ds = load_dataset("json", data_files={"val": CFG["val_jsonl"]})["val"]
else:
    val_ds = None

train_tok = train_ds.map(
    lambda ex: tokenize_fn(processor, ex, int(CFG["max_len"])),
    remove_columns=train_ds.column_names,
    desc="Tokenizing train",
)
train_tok = _drop_unsupervised(train_tok, "train")
if len(train_tok) == 0:
    raise SystemExit("No training examples with supervised tokens; increase CFG['max_len'] or shorten prompts.")

if val_ds is not None:
    eval_len = int(CFG["max_len"]) if int(CFG["eval_max_len"]) <= 0 else int(CFG["eval_max_len"])
    val_tok = val_ds.map(
        lambda ex: tokenize_fn(processor, ex, eval_len),
        remove_columns=val_ds.column_names,
        desc="Tokenizing val",
    )
    val_tok = _drop_unsupervised(val_tok, "val")
    if len(val_tok) == 0:
        raise SystemExit(
            "No validation examples with supervised tokens; increase CFG['eval_max_len'] "
            "(or CFG['max_len']) or set CFG['no_eval']=True."
        )
else:
    val_tok = None

collator = CausalLMPadCollator(pad_token_id=pad_id)

print("Tokenized train:", len(train_tok))
print("Tokenized val:", (len(val_tok) if val_tok is not None else "disabled"))


## 12) TrainingArguments + Trainer

In [None]:
# Eval cadence: disabled, once per epoch (eval_steps == 0), or every N steps.
# Checkpoint saving follows the same cadence.
eval_strategy = "no" if bool(CFG.get("no_eval", False)) else ("epoch" if int(CFG["eval_steps"]) == 0 else "steps")
save_strategy = "epoch" if int(CFG["eval_steps"]) == 0 else "steps"

# Optimizer steps per epoch (approximation used only to size the warmup).
steps_per_epoch = math.ceil(len(train_tok) / max(1, int(CFG["batch"]) * int(CFG["grad_accum"])))
total_steps = max(1, steps_per_epoch * max(1, int(CFG["epochs"])))
warmup_steps = int(total_steps * float(CFG["warmup_ratio"]))

ta_kwargs = dict(
    output_dir=CFG["out_dir"],
    num_train_epochs=int(CFG["epochs"]),
    per_device_train_batch_size=int(CFG["batch"]),
    gradient_accumulation_steps=int(CFG["grad_accum"]),
    learning_rate=float(CFG["lr"]),
    logging_steps=int(CFG["log_steps"]),
    save_total_limit=2,
    bf16=(compute_dtype == torch.bfloat16),
    fp16=(compute_dtype == torch.float16),
    optim="adamw_torch",
    lr_scheduler_type=str(CFG["scheduler"]),
    report_to="none",
    remove_unused_columns=False,  # keep token_type_ids / labels for the collator
    group_by_length=True,
    max_grad_norm=1.0,
    warmup_steps=warmup_steps,
    gradient_checkpointing=True,
)

# ---- Gradient checkpointing: keep it non-reentrant ----
# With gradient_checkpointing=True, Trainer calls
# model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=args.gradient_checkpointing_kwargs)
# when training starts. Left as None, transformers falls back to
# {"use_reentrant": True}, silently undoing the non-reentrant setup done at load.
if _supports_trainingargs_kw(TrainingArguments, "gradient_checkpointing_kwargs"):
    ta_kwargs["gradient_checkpointing_kwargs"] = {"use_reentrant": False}

# ---- Eval/Save strategy kw compat ----
if _supports_trainingargs_kw(TrainingArguments, "eval_strategy"):
    ta_kwargs["eval_strategy"] = eval_strategy
else:
    ta_kwargs["evaluation_strategy"] = eval_strategy

if _supports_trainingargs_kw(TrainingArguments, "save_strategy"):
    ta_kwargs["save_strategy"] = save_strategy

# Step-based saving uses eval_steps as its interval whether or not eval is
# enabled. Previously, no_eval=True with eval_steps>0 silently used the
# TrainingArguments default save_steps instead.
if save_strategy == "steps":
    ta_kwargs["save_steps"] = int(CFG["eval_steps"])

# ---- Eval memory reducers (only if eval is enabled) ----
if not bool(CFG.get("no_eval", False)):
    if _supports_trainingargs_kw(TrainingArguments, "per_device_eval_batch_size"):
        ta_kwargs["per_device_eval_batch_size"] = 1

    if _supports_trainingargs_kw(TrainingArguments, "fp16_full_eval"):
        ta_kwargs["fp16_full_eval"] = (compute_dtype == torch.float16)
    if _supports_trainingargs_kw(TrainingArguments, "bf16_full_eval"):
        ta_kwargs["bf16_full_eval"] = (compute_dtype == torch.bfloat16)

    if _supports_trainingargs_kw(TrainingArguments, "prediction_loss_only"):
        ta_kwargs["prediction_loss_only"] = True

    # Keep the checkpoint with the lowest eval loss (eval and save cadences match).
    ta_kwargs["load_best_model_at_end"] = True
    ta_kwargs["metric_for_best_model"] = "eval_loss"
    ta_kwargs["greater_is_better"] = False

    if int(CFG["eval_steps"]) != 0:
        ta_kwargs["eval_steps"] = int(CFG["eval_steps"])
else:
    ta_kwargs["load_best_model_at_end"] = False

train_args = TrainingArguments(**ta_kwargs)

trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    data_collator=collator,
)

# Final force-disable just before training
force_disable_use_cache_everywhere(trainer.model)

print("Trainer ready.")
print_gpu_mem("[before train] ")


## 13) Train + save artifacts

In [None]:
trainer.train()
print_gpu_mem("[after train]  ")

# Save adapter + processor
model.save_pretrained(CFG["out_dir"])
processor.save_pretrained(CFG["out_dir"])
# Trainer writes trainer_state.json only inside checkpoint-* folders. Write the
# final state (including log_history) to out_dir so the path printed below exists.
trainer.save_state()
print(f"[ok] saved LoRA adapter -> {CFG['out_dir']}")
print("[ok] training logs at:", os.path.join(CFG["out_dir"], "trainer_state.json"))

# Optional merge
if bool(CFG.get("save_merged", False)):
    merged_dir = os.path.join(CFG["out_dir"], "merged")
    os.makedirs(merged_dir, exist_ok=True)
    try:
        merged = model.merge_and_unload()
        # Training forced use_cache=False on every config. The merged model is an
        # inference artifact, so re-enable the KV cache before config.json /
        # generation_config.json are written; otherwise generation runs uncached.
        merged_config = getattr(merged, "config", None)
        for cfg_obj in (
            merged_config,
            getattr(merged_config, "text_config", None),
            getattr(merged, "generation_config", None),
        ):
            if cfg_obj is not None and hasattr(cfg_obj, "use_cache"):
                cfg_obj.use_cache = True
        merged.save_pretrained(merged_dir, safe_serialization=True)
        processor.save_pretrained(merged_dir)
        print(f"[ok] saved merged full model -> {merged_dir}")
    except Exception as e:
        print("[warn] merge failed (adapter still saved). Error:", repr(e))


## 14) Quick notes
- If you hit **eval OOM**, set `CFG['no_eval']=True` or reduce `CFG['eval_max_len']`.
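- If training runs out of memory, reduce `max_len` or `batch`. Raising `grad_accum` does not lower memory, but it keeps the effective batch size the same when you reduce `batch`.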
- Make sure your `hf_auth.py` is in the same folder as this notebook (or installed as a module).