In [None]:
# Qwen2.5-VL 3B regression fine-tuning (image + caption → alignment_score)
# ✅ LoRA (FEATURE_EXTRACTION)
# ✅ Single-process / single-GPU
# ✅ Lots of progress/debug printing
# ✅ Fixes common CUDA device-side assert causes:
#    - GPU poisoned -> restart kernel
#    - unexpected keys passed to model
#    - overly long sequences (truncation)
#    - bf16 instability -> fp16
#    - fp16 loss instability -> compute MSE in fp32

# %% [markdown]
# ## 0) (Optional) install deps
# %%
# !pip -q install -U "transformers>=4.45" datasets accelerate peft pillow

# %% [markdown]
# ## 1) IMPORTANT: env + cache (RUN FIRST; restart kernel if you previously imported torch/transformers)
# %%
from pathlib import Path
import os

# Make CUDA errors synchronous (must be set before CUDA init / torch import for best results)
# NOTE(review): synchronous kernel launches slow every CUDA call; consider
# removing this once the device-side assert has been diagnosed.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Redirect all Hugging Face caches (home, hub, transformers, datasets) under one
# root. Per the cell header, these only take effect if set before the HF
# libraries are imported (restart the kernel otherwise).
HF_CACHE_ROOT = "/work/hdd/bfrc"
os.environ["HF_HOME"] = str(Path(HF_CACHE_ROOT) / "hf_home")
os.environ["HUGGINGFACE_HUB_CACHE"] = str(Path(HF_CACHE_ROOT) / "hub")
# NOTE(review): recent transformers versions deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME -- confirm for the installed version.
os.environ["TRANSFORMERS_CACHE"] = str(Path(HF_CACHE_ROOT) / "transformers")
os.environ["HF_DATASETS_CACHE"] = str(Path(HF_CACHE_ROOT) / "datasets")
# Passed explicitly as cache_dir= to load_dataset / from_pretrained in later cells.
CACHE_DIR = str(Path(HF_CACHE_ROOT) / "cache_dir")

# Create every cache directory up front (no-op if it already exists).
for p in [
    os.environ["HF_HOME"],
    os.environ["HUGGINGFACE_HUB_CACHE"],
    os.environ["TRANSFORMERS_CACHE"],
    os.environ["HF_DATASETS_CACHE"],
    CACHE_DIR,
]:
    Path(p).mkdir(parents=True, exist_ok=True)

# Echo the effective settings into the run log.
print("[env] CUDA_LAUNCH_BLOCKING:", os.environ["CUDA_LAUNCH_BLOCKING"])
print("[cache] HF_HOME:", os.environ["HF_HOME"])
print("[cache] HUGGINGFACE_HUB_CACHE:", os.environ["HUGGINGFACE_HUB_CACHE"])
print("[cache] TRANSFORMERS_CACHE:", os.environ["TRANSFORMERS_CACHE"])
print("[cache] HF_DATASETS_CACHE:", os.environ["HF_DATASETS_CACHE"])
print("[cache] CACHE_DIR:", CACHE_DIR)

# %% [markdown]
# ## 2) Config
# %%
from pathlib import Path
import torch

# Input data: fill these in. EVAL_CSV may stay "" to train without evaluation.
TRAIN_CSV = ""
EVAL_CSV  = ""
IMG_ROOT  = ""  # img_name values in the CSVs are joined onto this directory

MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

# Checkpoints and final artifacts are written here (created if missing).
OUTPUT_DIR = str(Path(HF_CACHE_ROOT) / "finetunes" / "qwen2vl-regression-out")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# training
# Effective batch per optimizer step = PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS (32 on one GPU).
EPOCHS = 2
PER_DEVICE_BATCH_SIZE = 4
GRAD_ACCUM_STEPS = 8
LR = 2e-4
WARMUP_RATIO = 0.03
WEIGHT_DECAY = 0.0
MAX_STEPS = -1  # -1: run length is determined by EPOCHS (TrainingArguments convention)

# prompt + truncation
PROMPT_PREFIX = "Predict the alignment score (a real number). Caption:"
MAX_TEXT_LEN = 2048  # try 1024 if still unstable/oom

# LoRA
LORA_R = 16
LORA_ALPHA = 32  # adapter scaling is alpha / r (= 2.0 here)
LORA_DROPOUT = 0.05

SEED = 42
NUM_WORKERS = 2  # DataLoader worker processes

print("[cfg] OUTPUT_DIR:", OUTPUT_DIR)
print("[cfg] CUDA:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("[cfg] GPU:", torch.cuda.get_device_name(0))
    print("[cfg] device_count:", torch.cuda.device_count())

# %% [markdown]
# ## 3) Load dataset
# %%
from datasets import load_dataset

print("[data] Loading CSV(s)...")
# The "eval" split only exists when EVAL_CSV is set.
data_files = {"train": TRAIN_CSV}
if EVAL_CSV:
    data_files["eval"] = EVAL_CSV

raw = load_dataset("csv", data_files=data_files, cache_dir=CACHE_DIR)

print("[data] splits:", list(raw.keys()))
print("[data] columns:", raw["train"].column_names)
print("[data] train rows:", len(raw["train"]))
if EVAL_CSV:
    print("[data] eval rows:", len(raw["eval"]))

# Fail fast if any split lacks a column that the torch Dataset reads per row.
required_cols = {"img_name", "caption", "alignment_score"}
for split in raw.keys():
    missing = required_cols - set(raw[split].column_names)
    if missing:
        raise ValueError(f"{split} is missing columns: {missing}")

print("[data] example row:", raw["train"][0])

# %% [markdown]
# ## 4) Torch dataset (image load on the fly)
# %%
import os
import math
from PIL import Image
from torch.utils.data import Dataset

def load_image(path: str) -> Image.Image:
    """Read the image at *path* and return it as an RGB ``PIL.Image``.

    ``convert`` materialises a new image in memory, so the file handle can be
    closed by the ``with`` block before the result is returned.
    """
    with Image.open(path) as handle:
        rgb = handle.convert("RGB")
    return rgb

class CSVImageCaptionRegressionDataset(Dataset):
    """Map-style dataset over rows with ``img_name``, ``caption`` and ``alignment_score``.

    Each item is built on access: the image is read from
    ``os.path.join(img_root, img_name)`` via ``load_image`` and the score is
    returned as a Python float under the ``labels`` key.

    Args:
        hf_dataset: indexable collection of row dicts (e.g. a ``datasets`` split).
        img_root: directory that ``img_name`` values are relative to.

    ``__getitem__`` raises ValueError when the score is missing or non-finite.
    """

    def __init__(self, hf_dataset, img_root: str):
        self.ds = hf_dataset
        self.img_root = img_root

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        row = self.ds[idx]
        img_path = os.path.join(self.img_root, str(row["img_name"]))

        # Empty CSV cells typically load as None; use "" rather than the literal text "None".
        raw_caption = row["caption"]
        caption = "" if raw_caption is None else str(raw_caption)

        # A missing score would otherwise surface as an opaque TypeError from float(None).
        raw_label = row["alignment_score"]
        if raw_label is None:
            raise ValueError(f"Missing label at idx={idx}")
        label = float(raw_label)
        if not math.isfinite(label):
            raise ValueError(f"Non-finite label at idx={idx}: {label}")

        image = load_image(img_path)
        return {"image": image, "caption": caption, "labels": label, "img_path": img_path}

# Wrap the HF splits; images are loaded lazily per example. eval_ds is None without EVAL_CSV.
train_ds = CSVImageCaptionRegressionDataset(raw["train"], IMG_ROOT)
eval_ds  = CSVImageCaptionRegressionDataset(raw["eval"], IMG_ROOT) if EVAL_CSV else None
print("[ds] train:", len(train_ds), "| eval:", (len(eval_ds) if eval_ds else 0))

# %% [markdown]
# ## 5) Load processor + base model (fp16) + regression head + LoRA (FEATURE_EXTRACTION)
# %%
from transformers import AutoProcessor, AutoModel
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model, TaskType

print("[model] Loading processor...")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)

print("[model] Loading base model (fp16)...")
# fp16 weights on CUDA, fp32 on CPU.
# NOTE(review): the regression wrapper needs this AutoModel class to accept the
# processor's image inputs and to return last_hidden_state; which class
# AutoModel resolves to for Qwen2.5-VL varies across transformers versions -- confirm.
base = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    cache_dir=CACHE_DIR,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
)
print("[model] Base loaded.")

class Qwen2VLForRegression(nn.Module):
    """Scalar regressor: masked-mean-pooled backbone hidden states -> Linear(H, 1).

    forward() returns a dict with
      - "loss": fp32 MSE against ``labels`` (None when labels are not given)
      - "predictions": fp32 tensor of shape [B]
    Pooling, the head and the loss all run in fp32, so an fp16 backbone cannot
    overflow the token sum, and the head also works outside autocast.
    """

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base = base_model
        hidden_size = getattr(base_model.config, "hidden_size", None)
        if hidden_size is None:
            # Multimodal configs may only expose hidden_size on the text sub-config.
            hidden_size = getattr(getattr(base_model.config, "text_config", None), "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not infer hidden_size from model config.")
        self.reg_head = nn.Linear(hidden_size, 1)

    def _head(self, pooled: torch.Tensor) -> torch.Tensor:
        """Run reg_head in its own parameter dtype and return fp32 output of shape [B, 1]."""
        return self.reg_head(pooled.to(self.reg_head.weight.dtype)).float()

    def forward(self, labels=None, **inputs):
        # Only pass tensor inputs relevant to the base model (safety).
        base_inputs = {k: v for k, v in inputs.items() if torch.is_tensor(v) and k != "labels"}

        out = self.base(**base_inputs, output_hidden_states=False, return_dict=True)
        # Upcast before pooling: summing thousands of fp16 token states can exceed
        # the fp16 maximum (65504) and turn the pooled vector into inf/NaN.
        last_hidden = out.last_hidden_state.float()  # [B, T, H]

        attn = base_inputs.get("attention_mask", None)
        if attn is None:
            pooled = last_hidden.mean(dim=1)
        else:
            attn = attn.to(last_hidden.dtype)
            denom = attn.sum(dim=1).clamp_min(1.0).unsqueeze(-1)
            pooled = (last_hidden * attn.unsqueeze(-1)).sum(dim=1) / denom  # masked mean, [B, H]

        if pooled.is_cuda:
            # Under Trainer(fp16=True) autocast would otherwise run the Linear in fp16.
            with torch.autocast(device_type="cuda", enabled=False):
                pred = self._head(pooled).squeeze(-1)  # [B]
        else:
            pred = self._head(pooled).squeeze(-1)  # [B]

        loss = None
        if labels is not None:
            # fp32 loss for stability
            loss = F.mse_loss(pred, labels.to(pred.device).float())

        return {"loss": loss, "predictions": pred}

model = Qwen2VLForRegression(base)

print("[lora] Applying LoRA (FEATURE_EXTRACTION)...")
# Attention + MLP projection names. PEFT matches target_modules by module-name
# suffix, so presumably any submodule (text or vision) with these names gets an
# adapter -- check the print_trainable_parameters() output below.
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
lora_cfg = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    target_modules=target_modules,
    task_type=TaskType.FEATURE_EXTRACTION,
)
# Only the backbone is wrapped; get_peft_model freezes its original weights.
model.base = get_peft_model(model.base, lora_cfg)
model.base.print_trainable_parameters()

# reg_head lives outside the PEFT model; keep it explicitly trainable.
for p in model.reg_head.parameters():
    p.requires_grad = True
print("[model] reg_head params:", sum(p.numel() for p in model.reg_head.parameters()))

# %% [markdown]
# ## 6) Collator (returns only processor outputs + labels; example keys such as img_path are dropped), with truncation + token-range checks
# %%
from typing import Any, Dict, List
import torch

class Qwen2VLRegressionCollator:
    """Build one padded multimodal batch (processor outputs + fp32 ``labels``).

    Each caption becomes "<prompt_prefix>\\n<caption>\\nScore:" inside a single
    user chat turn preceded by an image placeholder. The turn is rendered with
    ``processor.apply_chat_template`` and tokenized together with the images.

    Args:
        processor: multimodal processor exposing ``apply_chat_template`` and
            ``__call__(text=..., images=...)``; its ``tokenizer`` (if present and
            sized) bounds the token-id range check.
        prompt_prefix: instruction text placed before each caption.
        max_text_len: ``max_length`` passed to the processor with truncation on.
        debug: print a preview of the first three batches.
    """

    def __init__(self, processor, prompt_prefix: str, max_text_len: int, debug: bool = False):
        self.processor = processor
        self.prompt_prefix = prompt_prefix
        self.max_text_len = max_text_len
        self.debug = debug
        self._seen = 0

        # best-effort vocab size for safety checks: len(tokenizer) includes
        # added/special tokens. Test for __len__ -- the attribute actually used.
        self.vocab_size = None
        tok = getattr(processor, "tokenizer", None)
        if tok is not None and hasattr(tok, "__len__"):
            self.vocab_size = int(len(tok))

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        images = [ex["image"] for ex in batch]
        captions = [ex["caption"] for ex in batch]
        labels = torch.tensor([float(ex["labels"]) for ex in batch], dtype=torch.float32)

        texts = []
        for cap in captions:
            prompt = f"{self.prompt_prefix}\n{cap}\nScore:"
            messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}]
            text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
            texts.append(text)

        # NOTE(review): image placeholder tokens presumably count toward max_length;
        # if an image expands past it, truncation can cut them and the batch may be
        # rejected downstream -- confirm image resolution limits.
        model_inputs = self.processor(
            text=texts,
            images=images,
            padding=True,
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt",
        )
        model_inputs["labels"] = labels

        # Catch out-of-range token ids on the CPU; an id past the embedding table
        # would otherwise show up as a device-side assert on the GPU.
        if "input_ids" in model_inputs:
            max_id = int(model_inputs["input_ids"].max())
            if self.vocab_size is not None and max_id >= self.vocab_size:
                raise ValueError(f"input_ids has id {max_id} >= vocab_size {self.vocab_size}")

        if self.debug and self._seen < 3:
            print("[collator] prompt preview:\n", texts[0][:400], "...")
            if "input_ids" in model_inputs:
                print("[collator] input_ids:", tuple(model_inputs["input_ids"].shape))
            print("[collator] labels:", labels[: min(4, len(labels))].tolist())
            self._seen += 1

        return model_inputs

# debug=True would print a preview of the first few batches.
collator = Qwen2VLRegressionCollator(processor, PROMPT_PREFIX, MAX_TEXT_LEN, debug=False)

# sanity check: collate one example outside the Trainer so processor errors surface early
b = collator([train_ds[0]])
print("[collator] keys:", list(b.keys()))

# %% [markdown]
# ## 7) Trainer + progress prints
# %%
from transformers import TrainingArguments, Trainer, set_seed, TrainerCallback
import numpy as np
import time

# Seed the RNGs (transformers.set_seed) before the Trainer is built so runs are repeatable.
set_seed(SEED)

def compute_metrics(eval_pred):
    """Regression metrics for ``Trainer``: MSE, MAE and R^2.

    Predictions and labels are flattened, so ``[N, 1]`` predictions line up
    with ``[N]`` labels instead of broadcasting into an N x N matrix. If any
    prediction is non-finite, all metrics are NaN and ``pred_nan_frac``
    reports the fraction of bad predictions.

    R^2 is NaN when the labels have zero variance.
    """
    preds, labels = eval_pred
    preds = np.asarray(preds, dtype=np.float32).reshape(-1)
    labels = np.asarray(labels, dtype=np.float32).reshape(-1)

    bad = ~np.isfinite(preds)
    if bad.any():
        return {
            "mse": float("nan"),
            "mae": float("nan"),
            "r2": float("nan"),
            "pred_nan_frac": float(bad.mean()),
        }

    err = preds - labels
    mse = float(np.mean(err ** 2))
    mae = float(np.mean(np.abs(err)))
    var = float(np.var(labels))
    r2 = float("nan") if var == 0.0 else float(1.0 - mse / var)
    return {"mse": mse, "mae": mae, "r2": r2}

class PrintProgressCallback(TrainerCallback):
    """Print compact progress lines and, periodically, one sample prediction.

    Args:
        trainer_ref: Trainer whose ``model`` is used for the sample prediction;
            when None, only log lines are printed.
        sample_every_steps: print a sample every this many optimizer steps
            (values <= 0 disable sampling).

    The sample uses the notebook-level ``train_ds[0]`` and ``collator``.
    """

    def __init__(self, trainer_ref=None, sample_every_steps: int = 100):
        self.trainer_ref = trainer_ref
        self.sample_every_steps = sample_every_steps
        self.t0 = None

    def on_train_begin(self, args, state, control, **kwargs):
        self.t0 = time.time()
        print(f"[train] begin | max_steps={state.max_steps} | epochs={args.num_train_epochs}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        elapsed = (time.time() - self.t0) if self.t0 else 0.0
        # state.epoch is None when logging happens outside train() (e.g. a standalone evaluate()).
        epoch = f"{state.epoch:.3f}" if state.epoch is not None else "n/a"
        msg = f"[log] step={state.global_step} epoch={epoch} elapsed={elapsed/60:.1f}m"
        if "loss" in logs:
            msg += f" loss={logs['loss']:.4f}"
        if "learning_rate" in logs:
            msg += f" lr={logs['learning_rate']:.2e}"
        if "grad_norm" in logs:
            msg += f" grad_norm={logs['grad_norm']:.3f}"
        print(msg)

    def on_step_end(self, args, state, control, **kwargs):
        if self.trainer_ref is None or self.sample_every_steps <= 0:
            return
        if state.global_step <= 0 or state.global_step % self.sample_every_steps != 0:
            return

        m = self.trainer_ref.model
        was_training = m.training
        try:
            m.eval()
            ex = train_ds[0]
            batch = collator([ex])
            device = next(m.parameters()).device
            batch = {k: v.to(device) for k, v in batch.items() if torch.is_tensor(v)}
            # on_step_end runs after the optimizer step, outside the autocast context
            # used for forward passes; re-enter it so dtypes match training.
            use_amp = bool(args.fp16) and device.type == "cuda"
            with torch.no_grad():
                if use_amp:
                    with torch.autocast(device_type="cuda", dtype=torch.float16):
                        out = m(**batch)
                else:
                    out = m(**batch)
            pred = float(out["predictions"].detach().float().cpu().item())
            print(f"[sample] step={state.global_step} pred={pred:.4f} label={float(ex['labels']):.4f}")
        except Exception as e:
            # Best-effort diagnostics only: a failed sample must not abort training.
            print("[sample] failed:", repr(e))
        finally:
            # Restore the mode the Trainer had the model in (dropout active during training).
            m.train(was_training)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LR,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=WEIGHT_DECAY,
    logging_steps=10,
    save_steps=200,
    # Evaluation only runs when an eval split was provided.
    eval_steps=200 if eval_ds is not None else None,
    eval_strategy="steps" if eval_ds is not None else "no",
    save_total_limit=2,  # keep only the two most recent checkpoint-* dirs
    bf16=False,
    fp16=True,  # mixed precision: autocast forward + gradient scaling
    dataloader_num_workers=NUM_WORKERS,
    # Keep the raw "image"/"caption" keys for the collator; otherwise Trainer
    # would filter example keys against forward()'s signature.
    remove_unused_columns=False,
    max_steps=MAX_STEPS,
    report_to="none",
)

# NOTE(review): the wrapper is a plain nn.Module, not a PreTrainedModel, so
# Trainer presumably checkpoints its full state_dict (frozen fp16 backbone
# included) -- expect large checkpoint dirs; confirm.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
    compute_metrics=compute_metrics if eval_ds is not None else None,
)

# Logs progress and prints one sample prediction every 100 optimizer steps.
trainer.add_callback(PrintProgressCallback(trainer_ref=trainer, sample_every_steps=100))

print("[train] Starting training...")
trainer.train()
print("[train] Training finished.")

# %% [markdown]
# ## 8) Save
# %%
from pathlib import Path

# Full wrapper weights (backbone + adapters + reg_head) and the processor.
print("[save] Saving full model + processor to:", OUTPUT_DIR)
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

# The PEFT adapter only covers the backbone.
lora_dir = str(Path(OUTPUT_DIR) / "lora_adapter")
print("[save] Saving LoRA adapter to:", lora_dir)
model.base.save_pretrained(lora_dir)

# reg_head is trained but lives outside the PEFT model, so save it next to the
# adapter; without it the adapter directory cannot rebuild the regressor.
# Tensors are moved to CPU so the file loads without a GPU.
reg_head_path = str(Path(lora_dir) / "reg_head.pt")
print("[save] Saving regression head to:", reg_head_path)
torch.save({k: v.detach().cpu() for k, v in model.reg_head.state_dict().items()}, reg_head_path)

print("[save] Done.")


In [None]:
# Qwen2.5-VL 3B regression fine-tuning (image + caption -> alignment_score)
# ✅ LoRA (FEATURE_EXTRACTION)
# ✅ Single-process / single-GPU
# ✅ Resume from latest checkpoint automatically
# ✅ Validation during training (every eval_steps)
# ✅ Stability fixes: fp32 pooling + fp32 head + fp32 MSE loss
# ✅ Trainer-compatible outputs: returns "logits" (not "predictions")

# (Optional) install deps:
# !pip -q install -U "transformers>=4.45" datasets accelerate peft pillow numpy

# ----------------------------
# 1) IMPORTANT: env + cache
# ----------------------------
from pathlib import Path
import os

# Make CUDA errors synchronous for debugging device-side asserts.
# NOTE(review): this slows every CUDA call; consider removing it for long runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Redirect all Hugging Face caches under one root (must happen before HF imports).
HF_CACHE_ROOT = "/work/hdd/bfrc"
os.environ["HF_HOME"] = str(Path(HF_CACHE_ROOT) / "hf_home")
os.environ["HUGGINGFACE_HUB_CACHE"] = str(Path(HF_CACHE_ROOT) / "hub")
# NOTE(review): recent transformers versions deprecate TRANSFORMERS_CACHE in favour of HF_HOME -- confirm.
os.environ["TRANSFORMERS_CACHE"] = str(Path(HF_CACHE_ROOT) / "transformers")
os.environ["HF_DATASETS_CACHE"] = str(Path(HF_CACHE_ROOT) / "datasets")
# Passed explicitly as cache_dir= to load_dataset / from_pretrained below.
CACHE_DIR = str(Path(HF_CACHE_ROOT) / "cache_dir")

# Create every cache directory up front (no-op if it already exists).
for p in [
    os.environ["HF_HOME"],
    os.environ["HUGGINGFACE_HUB_CACHE"],
    os.environ["TRANSFORMERS_CACHE"],
    os.environ["HF_DATASETS_CACHE"],
    CACHE_DIR,
]:
    Path(p).mkdir(parents=True, exist_ok=True)

print("[env] CUDA_LAUNCH_BLOCKING:", os.environ["CUDA_LAUNCH_BLOCKING"])
print("[cache] HF_HOME:", os.environ["HF_HOME"])
print("[cache] HUGGINGFACE_HUB_CACHE:", os.environ["HUGGINGFACE_HUB_CACHE"])
print("[cache] TRANSFORMERS_CACHE:", os.environ["TRANSFORMERS_CACHE"])
print("[cache] HF_DATASETS_CACHE:", os.environ["HF_DATASETS_CACHE"])
print("[cache] CACHE_DIR:", CACHE_DIR)

# ----------------------------
# 2) Config
# ----------------------------
import torch

# Input data -- fill these in before running (a bare `X =` is a SyntaxError).
TRAIN_CSV = ""  # training CSV with img_name, caption, alignment_score columns
EVAL_CSV  = ""  # optional validation CSV; "" disables evaluation during training
IMG_ROOT  = ""  # directory that img_name values are joined onto

MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

# Reusing the same OUTPUT_DIR means existing checkpoint-* dirs are picked up for resume.
OUTPUT_DIR = str(Path(HF_CACHE_ROOT) / "finetunes" / "qwen2vl-regression-out")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# training
# Effective batch per optimizer step = PER_DEVICE_BATCH_SIZE * GRAD_ACCUM_STEPS.
EPOCHS = 2
PER_DEVICE_BATCH_SIZE = 4
GRAD_ACCUM_STEPS = 8
LR = 2e-4
WARMUP_RATIO = 0.03
WEIGHT_DECAY = 0.0
MAX_STEPS = -1  # -1: run length is determined by EPOCHS

# prompt + truncation
PROMPT_PREFIX = "Predict the alignment score (a real number). Caption:"
MAX_TEXT_LEN = 2048  # try 1024 if still unstable/oom

# LoRA
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05

SEED = 42
NUM_WORKERS = 2

print("[cfg] OUTPUT_DIR:", OUTPUT_DIR)
print("[cfg] CUDA:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("[cfg] GPU:", torch.cuda.get_device_name(0))
    print("[cfg] device_count:", torch.cuda.device_count())

# ----------------------------
# 3) Load dataset
# ----------------------------
from datasets import load_dataset

print("[data] Loading CSV(s)...")
# The "eval" split only exists when EVAL_CSV is set.
data_files = {"train": TRAIN_CSV}
if EVAL_CSV:
    data_files["eval"] = EVAL_CSV

raw = load_dataset("csv", data_files=data_files, cache_dir=CACHE_DIR)

print("[data] splits:", list(raw.keys()))
print("[data] columns:", raw["train"].column_names)
print("[data] train rows:", len(raw["train"]))
if EVAL_CSV:
    print("[data] eval rows:", len(raw["eval"]))

# Fail fast if any split lacks a column the torch Dataset reads per row.
required_cols = {"img_name", "caption", "alignment_score"}
for split in raw.keys():
    missing = required_cols - set(raw[split].column_names)
    if missing:
        raise ValueError(f"{split} is missing columns: {missing}")

print("[data] example row:", raw["train"][0])

# ----------------------------
# 4) Torch dataset (image load on the fly)
# ----------------------------
import math
from PIL import Image
from torch.utils.data import Dataset

def load_image(path: str) -> Image.Image:
    """Read the image at *path* and return it as an RGB ``PIL.Image``.

    ``convert`` materialises a new image in memory, so the file handle can be
    closed by the ``with`` block before the result is returned.
    """
    with Image.open(path) as handle:
        rgb = handle.convert("RGB")
    return rgb

class CSVImageCaptionRegressionDataset(Dataset):
    """Map-style dataset over rows with ``img_name``, ``caption`` and ``alignment_score``.

    Each item is built on access: the image is read from
    ``os.path.join(img_root, img_name)`` via ``load_image`` and the score is
    returned as a Python float under the ``labels`` key.

    Args:
        hf_dataset: indexable collection of row dicts (e.g. a ``datasets`` split).
        img_root: directory that ``img_name`` values are relative to.

    ``__getitem__`` raises ValueError when the score is missing or non-finite.
    """

    def __init__(self, hf_dataset, img_root: str):
        self.ds = hf_dataset
        self.img_root = img_root

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        row = self.ds[idx]
        img_path = os.path.join(self.img_root, str(row["img_name"]))

        # Empty CSV cells typically load as None; use "" rather than the literal text "None".
        raw_caption = row["caption"]
        caption = "" if raw_caption is None else str(raw_caption)

        # A missing score would otherwise surface as an opaque TypeError from float(None).
        raw_label = row["alignment_score"]
        if raw_label is None:
            raise ValueError(f"Missing label at idx={idx}")
        label = float(raw_label)
        if not math.isfinite(label):
            raise ValueError(f"Non-finite label at idx={idx}: {label}")

        image = load_image(img_path)
        return {"image": image, "caption": caption, "labels": label, "img_path": img_path}

# Wrap the HF splits; images are loaded lazily per example. eval_ds is None without EVAL_CSV.
train_ds = CSVImageCaptionRegressionDataset(raw["train"], IMG_ROOT)
eval_ds  = CSVImageCaptionRegressionDataset(raw["eval"], IMG_ROOT) if EVAL_CSV else None
print("[ds] train:", len(train_ds), "| eval:", (len(eval_ds) if eval_ds else 0))

# ----------------------------
# 5) Load processor + base model + regression head + LoRA
# ----------------------------
from transformers import AutoProcessor, AutoModel
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model, TaskType

print("[model] Loading processor...")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)

print("[model] Loading base model (fp16)...")
# fp16 weights on CUDA, fp32 on CPU.
# NOTE(review): the wrapper needs this AutoModel class to accept image inputs and
# return last_hidden_state; the class it resolves to varies by transformers version -- confirm.
base = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    cache_dir=CACHE_DIR,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
)
print("[model] Base loaded.")

class Qwen2VLForRegression(nn.Module):
    """
    Trainer-friendly regression wrapper:
      - returns dict with keys: loss, logits
      - logits shape: [B, 1]
    Stability:
      - pooling in fp32
      - regression head run in fp32 (autocast disabled around it on CUDA,
        otherwise Trainer(fp16=True) would run the Linear in fp16)
      - mse loss in fp32
    """
    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base = base_model
        hidden_size = getattr(base_model.config, "hidden_size", None)
        if hidden_size is None:
            # Multimodal configs may only expose hidden_size on the text sub-config.
            hidden_size = getattr(getattr(base_model.config, "text_config", None), "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not infer hidden_size from model config.")
        self.reg_head = nn.Linear(hidden_size, 1)

    def _head(self, pooled: torch.Tensor) -> torch.Tensor:
        """Run reg_head in its own parameter dtype and return fp32 logits of shape [B, 1]."""
        return self.reg_head(pooled.to(self.reg_head.weight.dtype)).float()

    def forward(self, labels=None, **inputs):
        base_inputs = {k: v for k, v in inputs.items() if torch.is_tensor(v) and k != "labels"}

        out = self.base(**base_inputs, output_hidden_states=False, return_dict=True)
        last_hidden = out.last_hidden_state.float()  # [B, T, H] upcast before pooling

        attn = base_inputs.get("attention_mask", None)
        if attn is None:
            pooled = last_hidden.mean(dim=1)
        else:
            attn_f = attn.to(last_hidden.dtype)
            denom = attn_f.sum(dim=1).clamp_min(1.0).unsqueeze(-1)
            pooled = (last_hidden * attn_f.unsqueeze(-1)).sum(dim=1) / denom  # masked mean, [B, H]

        if pooled.is_cuda:
            with torch.autocast(device_type="cuda", enabled=False):
                logits = self._head(pooled)  # [B, 1] fp32
        else:
            logits = self._head(pooled)  # [B, 1] fp32

        loss = None
        if labels is not None:
            loss = F.mse_loss(logits.squeeze(-1), labels.to(logits.device).float())

        return {"loss": loss, "logits": logits}

model = Qwen2VLForRegression(base)

print("[lora] Applying LoRA (FEATURE_EXTRACTION)...")
# PEFT matches these by module-name suffix; presumably both text and vision
# submodules with these names get adapters -- check the trainable-param printout.
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
lora_cfg = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    target_modules=target_modules,
    task_type=TaskType.FEATURE_EXTRACTION,
)
# Only the backbone is wrapped; get_peft_model freezes its original weights.
model.base = get_peft_model(model.base, lora_cfg)
model.base.print_trainable_parameters()

# reg_head lives outside the PEFT model; keep it explicitly trainable.
for p in model.reg_head.parameters():
    p.requires_grad = True
print("[model] reg_head params:", sum(p.numel() for p in model.reg_head.parameters()))

# ----------------------------
# 6) Collator (truncation + token-range checks)
# ----------------------------
from typing import Any, Dict, List, Optional

class Qwen2VLRegressionCollator:
    """Collate raw examples into one padded processor batch plus fp32 labels.

    Every caption is rendered as a single user chat turn: an image placeholder
    followed by "<prompt_prefix>\\n<caption>\\nScore:". All turns are then
    tokenized together with their images, padded and truncated to
    ``max_text_len``. ``labels`` (shape [B], float32) is attached to the result.
    """

    def __init__(self, processor, prompt_prefix: str, max_text_len: int, debug: bool = False):
        self.processor = processor
        self.prompt_prefix = prompt_prefix
        self.max_text_len = max_text_len
        self.debug = debug
        self._seen = 0

        # len(tokenizer) counts added/special tokens too; used as the id upper bound.
        tokenizer = getattr(processor, "tokenizer", None)
        sized = tokenizer is not None and hasattr(tokenizer, "__len__")
        self.vocab_size = int(len(tokenizer)) if sized else None

    def _chat_text(self, caption: str) -> str:
        """Render one caption as a single user turn (image placeholder + prompt)."""
        prompt = f"{self.prompt_prefix}\n{caption}\nScore:"
        user_turn = {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]}
        return self.processor.apply_chat_template([user_turn], tokenize=False, add_generation_prompt=False)

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        images = [example["image"] for example in batch]
        captions = [example["caption"] for example in batch]
        labels = torch.tensor([float(example["labels"]) for example in batch], dtype=torch.float32)
        texts = [self._chat_text(caption) for caption in captions]

        encoded = self.processor(
            text=texts,
            images=images,
            padding=True,
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt",
        )
        encoded["labels"] = labels

        # Reject out-of-range token ids on the CPU, before they can reach an
        # embedding lookup on the GPU.
        if "input_ids" in encoded:
            max_id = int(encoded["input_ids"].max())
            too_large = self.vocab_size is not None and max_id >= self.vocab_size
            if too_large:
                raise ValueError(f"input_ids has id {max_id} >= vocab_size {self.vocab_size}")

        should_preview = self.debug and self._seen < 3
        if should_preview:
            print("[collator] prompt preview:\n", texts[0][:400], "...")
            print("[collator] input_ids:", tuple(encoded["input_ids"].shape))
            print("[collator] labels:", labels[:4].tolist())
            self._seen += 1

        return encoded

# debug=True would print a preview of the first few batches.
collator = Qwen2VLRegressionCollator(processor, PROMPT_PREFIX, MAX_TEXT_LEN, debug=False)

# Sanity check: collate one example outside the Trainer so processor errors surface early.
b = collator([train_ds[0]])
print("[collator] keys:", list(b.keys()))

# ----------------------------
# 7) Trainer + progress prints + resume
# ----------------------------
from transformers import TrainingArguments, Trainer, set_seed, TrainerCallback
import numpy as np
import time
import re

# Seed the RNGs (transformers.set_seed) before the Trainer is built so runs are repeatable.
set_seed(SEED)

def compute_metrics(eval_pred):
    """Compute MSE, MAE and R^2 for the Trainer's evaluation loop.

    Singleton dimensions are squeezed away, so ``[N, 1]`` logits compare
    element-wise with ``[N]`` labels. Any non-finite prediction makes every
    metric NaN and adds ``pred_nan_frac`` (the fraction of bad predictions).
    R^2 is NaN when the labels have zero variance.
    """
    raw_preds, raw_labels = eval_pred
    y_hat = np.asarray(raw_preds, dtype=np.float32).squeeze()
    y = np.asarray(raw_labels, dtype=np.float32).squeeze()

    non_finite = ~np.isfinite(y_hat)
    if non_finite.any():
        nan = float("nan")
        return {"mse": nan, "mae": nan, "r2": nan, "pred_nan_frac": float(non_finite.mean())}

    residual = y_hat - y
    mse = float(np.mean(residual ** 2))
    mae = float(np.mean(np.abs(residual)))
    label_var = float(np.var(y))
    if label_var == 0.0:
        r2 = float("nan")
    else:
        r2 = float(1.0 - mse / label_var)
    return {"mse": mse, "mae": mae, "r2": r2}

class PrintProgressCallback(TrainerCallback):
    """Print compact progress lines and, periodically, one sample prediction.

    Args:
        trainer_ref: Trainer whose ``model`` is used for the sample prediction;
            when None, only log lines are printed.
        sample_every_steps: print a sample every this many optimizer steps
            (values <= 0 disable sampling).

    The sample uses the notebook-level ``train_ds[0]`` and ``collator``.
    """

    def __init__(self, trainer_ref=None, sample_every_steps: int = 100):
        self.trainer_ref = trainer_ref
        self.sample_every_steps = sample_every_steps
        self.t0 = None

    def on_train_begin(self, args, state, control, **kwargs):
        self.t0 = time.time()
        print(f"[train] begin | max_steps={state.max_steps} | epochs={args.num_train_epochs}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        elapsed = (time.time() - self.t0) if self.t0 else 0.0
        # state.epoch is None when logging happens outside train() (e.g. a standalone evaluate()).
        epoch = f"{state.epoch:.3f}" if state.epoch is not None else "n/a"
        msg = f"[log] step={state.global_step} epoch={epoch} elapsed={elapsed/60:.1f}m"
        if "loss" in logs:
            msg += f" loss={logs['loss']:.4f}"
        if "learning_rate" in logs:
            msg += f" lr={logs['learning_rate']:.2e}"
        if "grad_norm" in logs:
            msg += f" grad_norm={logs['grad_norm']:.3f}"
        print(msg)

    def on_step_end(self, args, state, control, **kwargs):
        if self.trainer_ref is None or self.sample_every_steps <= 0:
            return
        if state.global_step <= 0 or state.global_step % self.sample_every_steps != 0:
            return

        m = self.trainer_ref.model
        was_training = m.training
        try:
            m.eval()
            ex = train_ds[0]
            batch = collator([ex])
            device = next(m.parameters()).device
            batch = {k: v.to(device) for k, v in batch.items() if torch.is_tensor(v)}
            # on_step_end runs after the optimizer step, outside the autocast context
            # used for forward passes; re-enter it so dtypes match training.
            use_amp = bool(args.fp16) and device.type == "cuda"
            with torch.no_grad():
                if use_amp:
                    with torch.autocast(device_type="cuda", dtype=torch.float16):
                        out = m(**batch)
                else:
                    out = m(**batch)
            pred = float(out["logits"].detach().float().cpu().squeeze().item())
            print(f"[sample] step={state.global_step} pred={pred:.4f} label={float(ex['labels']):.4f}")
        except Exception as e:
            # Best-effort diagnostics only: a failed sample must not abort training.
            print("[sample] failed:", repr(e))
        finally:
            # Restore the mode the Trainer had the model in (dropout active during training).
            m.train(was_training)

def get_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-step ``checkpoint-<N>`` directory, or None.

    Only directories whose name is exactly ``checkpoint-<digits>`` (the format
    Trainer writes) are considered. Stray entries such as ``checkpoint-best``
    or ``checkpoint-12-old`` are ignored instead of being ranked with a bogus
    step number.

    Args:
        output_dir: directory to scan (a missing directory yields None).
    """
    out = Path(output_dir)
    if not out.is_dir():
        return None

    pattern = re.compile(r"checkpoint-(\d+)")
    best_step, best_path = -1, None
    for p in out.glob("checkpoint-*"):
        m = pattern.fullmatch(p.name)
        if m is None or not p.is_dir():
            continue
        step = int(m.group(1))
        if step > best_step:
            best_step, best_path = step, p

    return str(best_path) if best_path is not None else None

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LR,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=WEIGHT_DECAY,
    logging_steps=10,

    save_steps=200,
    save_total_limit=2,  # keep only the two most recent checkpoint-* dirs

    # Validation every 200 steps, only when an eval split was provided.
    eval_strategy="steps" if eval_ds is not None else "no",
    eval_steps=200 if eval_ds is not None else None,

    bf16=False,
    fp16=True,  # mixed precision: autocast forward + gradient scaling
    dataloader_num_workers=NUM_WORKERS,
    # Keep the raw "image"/"caption" keys for the collator.
    remove_unused_columns=False,
    max_steps=MAX_STEPS,
    report_to="none",
)

# NOTE(review): the wrapper is a plain nn.Module, so Trainer presumably
# checkpoints its full state_dict (frozen backbone included) -- confirm disk usage.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collator,
    compute_metrics=compute_metrics if eval_ds is not None else None,
)

trainer.add_callback(PrintProgressCallback(trainer_ref=trainer, sample_every_steps=100))

# Resume from the newest checkpoint-<N> in OUTPUT_DIR; start fresh when there is none.
latest_ckpt = get_latest_checkpoint(OUTPUT_DIR)
print("[resume] latest_ckpt:", latest_ckpt)

print("[train] Starting training...")
trainer.train(resume_from_checkpoint=latest_ckpt if latest_ckpt else None)
print("[train] Training finished.")

# ----------------------------
# 8) Save
# ----------------------------
# Full wrapper weights (backbone + adapters + reg_head) and the processor.
print("[save] Saving full model + processor to:", OUTPUT_DIR)
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

# The PEFT adapter only covers the backbone.
lora_dir = str(Path(OUTPUT_DIR) / "lora_adapter")
print("[save] Saving LoRA adapter to:", lora_dir)
model.base.save_pretrained(lora_dir)

# reg_head is trained but lives outside the PEFT model, so save it next to the
# adapter; without it the adapter directory cannot rebuild the regressor.
# Tensors are moved to CPU so the file loads without a GPU.
reg_head_path = str(Path(lora_dir) / "reg_head.pt")
print("[save] Saving regression head to:", reg_head_path)
torch.save({k: v.detach().cpu() for k, v in model.reg_head.state_dict().items()}, reg_head_path)

print("[save] Done.")


In [None]:
# Force-kill this user's leftover distributed launchers (torchrun / accelerate / deepspeed).
!pkill -9 -u $USER -f "torchrun|accelerate|deepspeed"
# WARNING: "python.*ipykernel" matches Jupyter kernel processes owned by $USER,
# presumably including the kernel running this notebook -- expect to restart it.
!pkill -9 -u $USER -f "python.*train|python.*ipykernel"


In [None]:
# Show GPU memory use and the processes still holding the device.
!nvidia-smi


In [None]:
import torch
# Report whether CUDA is visible to torch and how much memory the current device has free.
cuda_ok = torch.cuda.is_available()
print("cuda available:", cuda_ok)
if cuda_ok:
    # mem_get_info() returns (free, total) in bytes for the current device.
    free, total = torch.cuda.mem_get_info()
    print("torch free GB:", free / 1024 ** 3, "total GB:", total / 1024 ** 3)


In [None]:
# Load latest checkpoint + eval on val + return arrays (labels, preds)
# ✅ prints progress during prediction/eval

from pathlib import Path
import os, re, math
from typing import Any, Dict, List, Optional

import torch
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import AutoProcessor, AutoModel, TrainingArguments, Trainer
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model, TaskType


# -----------------------------
# 0) Paths / config (edit if needed)
# -----------------------------
# This cell sets no HF_* env vars; downloads are cached via explicit cache_dir= arguments.
HF_CACHE_ROOT = "/work/hdd/bfrc"
CACHE_DIR = str(Path(HF_CACHE_ROOT) / "cache_dir")

EVAL_CSV  = ""  # validation CSV (img_name, caption, alignment_score) -- fill in
IMG_ROOT  = ""  # directory that img_name values are joined onto
MODEL_ID  = "Qwen/Qwen2.5-VL-3B-Instruct"

# Where the training cells wrote checkpoint-* directories.
OUTPUT_DIR = str(Path(HF_CACHE_ROOT) / "finetunes" / "qwen2vl-regression-out")

# Prompt format and truncation should match training so inputs look the same.
PROMPT_PREFIX = "Predict the alignment score (a real number). Caption:"
MAX_TEXT_LEN = 2048

PER_DEVICE_EVAL_BATCH_SIZE = 4
NUM_WORKERS = 2

# LoRA (MUST match training)
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]


# -----------------------------
# 1) Find latest checkpoint
# -----------------------------
def get_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-step ``checkpoint-<N>`` directory, or None.

    Only directories whose name is exactly ``checkpoint-<digits>`` (the format
    Trainer writes) are considered. Stray entries such as ``checkpoint-best``
    or ``checkpoint-12-old`` are ignored instead of being ranked with a bogus
    step number.

    Args:
        output_dir: directory to scan (a missing directory yields None).
    """
    out = Path(output_dir)
    if not out.is_dir():
        return None

    pattern = re.compile(r"checkpoint-(\d+)")
    best_step, best_path = -1, None
    for p in out.glob("checkpoint-*"):
        m = pattern.fullmatch(p.name)
        if m is None or not p.is_dir():
            continue
        step = int(m.group(1))
        if step > best_step:
            best_step, best_path = step, p

    return str(best_path) if best_path is not None else None

# Evaluation needs trained weights: fail loudly if no checkpoint-* directory exists.
ckpt_dir = get_latest_checkpoint(OUTPUT_DIR)
if ckpt_dir is None:
    raise FileNotFoundError(f"No checkpoint-* found under: {OUTPUT_DIR}")
print("[ckpt] latest:", ckpt_dir)


# -----------------------------
# 2) Val dataset
# -----------------------------
# Load only the eval CSV and verify the columns the Dataset reads per row.
required_cols = {"img_name", "caption", "alignment_score"}
raw = load_dataset("csv", data_files={"eval": EVAL_CSV}, cache_dir=CACHE_DIR)
missing = required_cols - set(raw["eval"].column_names)
if missing:
    raise ValueError(f"eval is missing columns: {missing}")

def load_image(path: str) -> Image.Image:
    """Read the image at *path* and return it as an RGB ``PIL.Image``.

    ``convert`` materialises a new image in memory, so the file handle can be
    closed by the ``with`` block before the result is returned.
    """
    with Image.open(path) as handle:
        rgb = handle.convert("RGB")
    return rgb

class CSVImageCaptionRegressionDataset(Dataset):
    """Map-style dataset over rows with ``img_name``, ``caption`` and ``alignment_score``.

    Each item is built on access: the image is read from
    ``os.path.join(img_root, img_name)`` via ``load_image`` and the score is
    returned as a Python float under the ``labels`` key.

    Args:
        hf_dataset: indexable collection of row dicts (e.g. a ``datasets`` split).
        img_root: directory that ``img_name`` values are relative to.

    ``__getitem__`` raises ValueError when the score is missing or non-finite.
    """

    def __init__(self, hf_dataset, img_root: str):
        self.ds = hf_dataset
        self.img_root = img_root

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        row = self.ds[idx]
        img_path = os.path.join(self.img_root, str(row["img_name"]))

        # Empty CSV cells typically load as None; use "" rather than the literal text "None".
        raw_caption = row["caption"]
        caption = "" if raw_caption is None else str(raw_caption)

        # A missing score would otherwise surface as an opaque TypeError from float(None).
        raw_label = row["alignment_score"]
        if raw_label is None:
            raise ValueError(f"Missing label at idx={idx}")
        label = float(raw_label)
        if not math.isfinite(label):
            raise ValueError(f"Non-finite label at idx={idx}: {label}")

        image = load_image(img_path)
        return {"image": image, "caption": caption, "labels": label, "img_path": img_path}

# Images are loaded lazily per example when the dataset is indexed.
eval_ds = CSVImageCaptionRegressionDataset(raw["eval"], IMG_ROOT)
print("[data] eval rows:", len(eval_ds))


# -----------------------------
# 3) Processor + model (must match training)
# -----------------------------
print("[model] Loading processor...")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)

print("[model] Loading base...")
# Same base model and dtype policy as the training cells (fp16 on CUDA, fp32 on CPU).
base = AutoModel.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    cache_dir=CACHE_DIR,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    low_cpu_mem_usage=True,
)

class Qwen2VLForRegression(nn.Module):
    """Scalar regression head on top of a (LoRA-wrapped) Qwen2.5-VL backbone.

    The backbone's ``last_hidden_state`` is mean-pooled over the sequence in
    float32 (masked by ``attention_mask`` when one is supplied) and mapped to a
    single value by ``reg_head``. ``forward`` returns a dict with ``"logits"``
    of shape ``[B, 1]`` (float32) and ``"loss"``: the MSE against ``labels``, or
    ``None`` when no labels are given.
    """

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base = base_model
        # Multimodal configs may only expose hidden_size under text_config.
        config = base_model.config
        hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is None:
            text_config = getattr(config, "text_config", None)
            hidden_size = getattr(text_config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not infer hidden_size from model config.")
        self.reg_head = nn.Linear(hidden_size, 1)

    def forward(self, labels=None, **inputs):
        # Only tensors reach the backbone; non-tensor extras (e.g. paths) are dropped.
        tensor_inputs = {
            name: value
            for name, value in inputs.items()
            if name != "labels" and torch.is_tensor(value)
        }
        outputs = self.base(**tensor_inputs, output_hidden_states=False, return_dict=True)
        hidden = outputs.last_hidden_state.float()  # [B,T,H], upcast for a stable mean

        mask = tensor_inputs.get("attention_mask")
        if mask is not None:
            mask_f = mask.float()
            counts = mask_f.sum(dim=1).clamp_min(1.0).unsqueeze(-1)  # avoid /0 on empty rows
            pooled = (hidden * mask_f.unsqueeze(-1)).sum(dim=1) / counts
        else:
            pooled = hidden.mean(dim=1)

        logits = self.reg_head(pooled).float()  # [B,1]
        loss = None if labels is None else F.mse_loss(logits.squeeze(-1), labels.float())
        return {"loss": loss, "logits": logits}

model = Qwen2VLForRegression(base)

print("[lora] Applying LoRA...")
# Adapter shape (r, targets) must mirror training so checkpoint tensors fit.
lora_cfg = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    target_modules=TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)
model.base = get_peft_model(model.base, lora_cfg)


# -----------------------------
# 4) Collator
# -----------------------------
class Qwen2VLRegressionCollator:
    """Turn dataset items into processor inputs plus float32 ``labels``.

    Each caption is wrapped in a single chat-formatted user turn (an image slot
    followed by ``"<prompt_prefix>\\n<caption>\\nScore:"``). The texts and images
    are then encoded together with padding and truncation to ``max_text_len``.
    NOTE(review): if an image expands to more tokens than ``max_text_len``,
    truncation may cut image placeholder tokens; confirm with the processor.
    """

    def __init__(self, processor, prompt_prefix: str, max_text_len: int):
        self.processor = processor
        self.prompt_prefix = prompt_prefix
        self.max_text_len = max_text_len

    def _format_prompt(self, caption) -> str:
        """Render one user turn (image slot + scoring prompt) via the chat template."""
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"{self.prompt_prefix}\n{caption}\nScore:"},
            ],
        }
        return self.processor.apply_chat_template(
            [user_turn], tokenize=False, add_generation_prompt=False
        )

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        images = [example["image"] for example in batch]
        captions = [example["caption"] for example in batch]
        label_values = [float(example["labels"]) for example in batch]
        labels = torch.tensor(label_values, dtype=torch.float32)
        texts = [self._format_prompt(caption) for caption in captions]

        encoded = self.processor(
            text=texts,
            images=images,
            padding=True,
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt",
        )
        encoded["labels"] = labels
        return encoded

collator = Qwen2VLRegressionCollator(processor=processor, prompt_prefix=PROMPT_PREFIX, max_text_len=MAX_TEXT_LEN)


# -----------------------------
# 5) Load checkpoint weights
# -----------------------------
def load_checkpoint_state(model: nn.Module, ckpt_dir: str):
    """Load a Trainer checkpoint's full state dict into ``model`` (non-strict).

    Looks for ``model.safetensors`` first, then ``pytorch_model.bin`` inside
    ``ckpt_dir``, and loads it with ``load_state_dict(strict=False)``. Because
    non-strict loading never fails on name mismatches, it prints a warning when
    nothing matched or when the regression head was not in the file.

    Returns:
        ``(missing, unexpected)`` key lists reported by ``load_state_dict``.

    Raises:
        FileNotFoundError: if neither file exists in ``ckpt_dir``.
    """
    ckpt = Path(ckpt_dir)
    st_path = ckpt / "model.safetensors"
    pt_path = ckpt / "pytorch_model.bin"

    if st_path.exists():
        from safetensors.torch import load_file
        print("[ckpt] loading safetensors:", st_path)
        state = load_file(str(st_path))
    elif pt_path.exists():
        print("[ckpt] loading pytorch bin:", pt_path)
        # torch.load unpickles the file; only point it at checkpoints you produced.
        state = torch.load(str(pt_path), map_location="cpu")
    else:
        raise FileNotFoundError(f"No model.safetensors or pytorch_model.bin in {ckpt_dir}")

    missing, unexpected = model.load_state_dict(state, strict=False)
    print("[ckpt] missing keys:", len(missing), "| unexpected keys:", len(unexpected))

    # Surface the silent failure modes of strict=False instead of only printing counts.
    if len(state) - len(unexpected) == 0:
        print("[ckpt] WARNING: no checkpoint tensor matched a model parameter; "
              "the model keeps the weights it had before this call.")
    missing_head = [k for k in missing if k.startswith("reg_head.")]
    if missing_head:
        print(f"[ckpt] WARNING: regression head missing from checkpoint ({missing_head}); "
              "predictions come from a randomly initialised head.")
    return missing, unexpected

load_checkpoint_state(model=model, ckpt_dir=ckpt_dir)


# -----------------------------
# 6) Evaluate + get arrays with progress printing
# -----------------------------
def compute_metrics(eval_pred):
    """Regression metrics for a Trainer ``(predictions, label_ids)`` pair.

    Both arrays are cast to float32 and squeezed, so ``[N, 1]`` logits line up
    with ``[N]`` labels. Returns MSE, MAE and R^2; R^2 is NaN when the labels
    have zero variance.
    """
    raw_preds, raw_labels = eval_pred
    y_hat = np.asarray(raw_preds, dtype=np.float32).squeeze()
    y = np.asarray(raw_labels, dtype=np.float32).squeeze()
    residual = y_hat - y
    mse = float(np.mean(residual ** 2))
    mae = float(np.mean(np.abs(residual)))
    label_var = float(np.var(y))
    r2 = float("nan") if label_var == 0.0 else float(1.0 - mse / label_var)
    return {"mse": mse, "mae": mae, "r2": r2}

# Trainer wrapper with the evaluation settings (fp16 only when CUDA is present).
# NOTE(review): predictions below come from the manual loop, not from
# trainer.predict/evaluate, so this object is not used further in this cell.
eval_args = TrainingArguments(
    output_dir=str(Path(OUTPUT_DIR) / "_eval_tmp"),
    report_to="none",
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    dataloader_num_workers=NUM_WORKERS,
    bf16=False,
    fp16=torch.cuda.is_available(),
)

trainer = Trainer(
    model=model,
    args=eval_args,
    data_collator=collator,
    eval_dataset=eval_ds,
    compute_metrics=compute_metrics,
)

# ---- Progress-printing manual predict loop ----
# Runs inference directly (instead of trainer.predict) so each batch can be
# reported. shuffle=False keeps prediction order aligned with eval_ds rows.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

from torch.utils.data import DataLoader

loader = DataLoader(
    eval_ds,
    batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collator,
    pin_memory=torch.cuda.is_available(),
)

all_preds = []
all_labels = []

print("[predict] running manual loop with progress prints...")
total = len(loader)

with torch.no_grad():
    for i, batch in enumerate(loader, start=1):
        labels = batch["labels"].cpu().numpy().astype(np.float32).reshape(-1)
        batch = {k: v.to(device) for k, v in batch.items() if torch.is_tensor(v)}

        out = model(**batch)
        # reshape(-1), not squeeze(): a final batch of size 1 would squeeze to a
        # 0-d array, which np.concatenate refuses to join.
        logits = out["logits"].detach().float().cpu().numpy().astype(np.float32).reshape(-1)

        all_labels.append(labels)
        all_preds.append(logits)

        # progress print
        if i == 1 or i % 10 == 0 or i == total:
            done = min(i * loader.batch_size, len(eval_ds))
            print(f"[predict] batch {i}/{total} | seen {done}/{len(eval_ds)}")

if not all_preds:
    raise ValueError("[predict] eval set is empty; nothing to predict.")

# Always 1-D (length N), even for N == 1, so len() and indexing below work.
val_labels = np.concatenate(all_labels, axis=0).reshape(-1)
val_predictions = np.concatenate(all_preds, axis=0).reshape(-1)

print("[predict] done. shapes:", val_labels.shape, val_predictions.shape)

# Compute metrics on the arrays (same as compute_metrics)
mse = float(np.mean((val_predictions - val_labels) ** 2))
mae = float(np.mean(np.abs(val_predictions - val_labels)))
var = float(np.var(val_labels))
r2 = float("nan") if var == 0.0 else float(1.0 - mse / var)

print("[metrics] mse:", mse)
print("[metrics] mae:", mae)
print("[metrics] r2 :", r2)

print("[sample] first 5 (label, pred):")
for j in range(min(5, len(val_labels))):
    print(f"  {j}: label={val_labels[j]:.4f} pred={val_predictions[j]:.4f}")




In [None]:
# Show up to the first 100 (label, prediction) pairs; header matches the loop bound.
print("[sample] first 100 (label, pred):")
for j in range(min(100, len(val_labels))):
    print(f"  {j}: label={val_labels[j]:.4f} pred={val_predictions[j]:.4f}")


In [None]:
# Load latest checkpoint + run predictions on test_set_final.csv + save CSV with new column `regressor_pred`
# ✅ prints progress during prediction/eval

from pathlib import Path
import os, re, math
from typing import Any, Dict, List, Optional

import torch
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoProcessor, AutoModel, TrainingArguments, Trainer
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model, TaskType


# -----------------------------
# 0) Paths / config (edit if needed)
# -----------------------------
HF_CACHE_ROOT = "/work/hdd/bfrc"
CACHE_DIR = str(Path(HF_CACHE_ROOT) / "cache_dir")

# Evaluation CSV (needs img_name, caption, alignment_score columns) and the folder
# that img_name values are joined onto. Empty strings are placeholders: set both
# before running.
EVAL_CSV  = ""
IMG_ROOT  = ""
MODEL_ID  = "Qwen/Qwen2.5-VL-3B-Instruct"

# Trainer output folder; the newest checkpoint-* under it is evaluated.
OUTPUT_DIR = str(Path(HF_CACHE_ROOT) / "finetunes" / "qwen2vl-regression-out")

# Prompt and truncation should match the training cell.
PROMPT_PREFIX = "Predict the alignment score (a real number). Caption:"
MAX_TEXT_LEN = 2048

PER_DEVICE_EVAL_BATCH_SIZE = 4
NUM_WORKERS = 2

# LoRA (MUST match training). Dropout is inert here: prediction runs under model.eval().
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]


# -----------------------------
# 1) Find latest checkpoint
# -----------------------------
def get_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered ``checkpoint-<step>`` subdirectory.

    Steps are compared numerically (``checkpoint-10`` beats ``checkpoint-9``).
    Only directories are considered. Returns ``None`` when ``output_dir`` does
    not exist or holds no checkpoint directories.
    """
    root = Path(output_dir)
    if not root.exists():
        return None

    def _step(path: Path) -> int:
        match = re.search(r"checkpoint-(\d+)", path.name)
        return -1 if match is None else int(match.group(1))

    candidates = [d for d in root.glob("checkpoint-*") if d.is_dir()]
    return str(max(candidates, key=_step)) if candidates else None

# Resolve the newest Trainer checkpoint; evaluation is pointless without one.
ckpt_dir = get_latest_checkpoint(OUTPUT_DIR)
if ckpt_dir is not None:
    print("[ckpt] latest:", ckpt_dir)
else:
    raise FileNotFoundError(f"No checkpoint-* found under: {OUTPUT_DIR}")


# -----------------------------
# 2) Dataset
# -----------------------------
# Columns read by CSVImageCaptionRegressionDataset.__getitem__.
required_cols = {"img_name", "caption", "alignment_score"}
raw = load_dataset(
    "csv",
    data_files={"eval": EVAL_CSV},
    cache_dir=CACHE_DIR,
)
missing = required_cols - set(raw["eval"].column_names)
if len(missing) > 0:
    raise ValueError(f"eval is missing columns: {missing}")

def load_image(path: str) -> Image.Image:
    """Open the image at ``path`` and return an RGB copy.

    The source file handle is closed before returning; the converted image is
    an independent in-memory copy.
    """
    with Image.open(path) as handle:
        rgb = handle.convert("RGB")
    return rgb

class CSVImageCaptionRegressionDataset(Dataset):
    """Map-style dataset over rows with ``img_name``, ``caption``, ``alignment_score``.

    Each item is a dict holding the RGB ``image`` (loaded from
    ``img_root/img_name``), the ``caption`` as a string, the float target under
    ``labels`` and the resolved ``img_path``.

    Raises:
        ValueError: from ``__getitem__`` when a row's score is NaN or infinite.
    """

    def __init__(self, hf_dataset, img_root: str):
        self.ds = hf_dataset  # any indexable, sized collection of row dicts
        self.img_root = img_root

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        path = os.path.join(self.img_root, str(record["img_name"]))
        text = str(record["caption"])
        target = float(record["alignment_score"])
        # Reject NaN/inf targets before touching the image; they would poison the MSE.
        if not math.isfinite(target):
            raise ValueError(f"Non-finite label at idx={idx}: {target}")
        return {"image": load_image(path), "caption": text, "labels": target, "img_path": path}

eval_ds = CSVImageCaptionRegressionDataset(raw["eval"], IMG_ROOT)
print("[data] rows:", len(eval_ds))


# -----------------------------
# 3) Processor + model (must match training)
# -----------------------------
print("[model] Loading processor...")
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    cache_dir=CACHE_DIR,
    trust_remote_code=True,
)

print("[model] Loading base...")
# fp16 weights on GPU, fp32 on CPU.
base = AutoModel.from_pretrained(
    MODEL_ID,
    cache_dir=CACHE_DIR,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
)

class Qwen2VLForRegression(nn.Module):
    """Scalar regression head on top of a (LoRA-wrapped) Qwen2.5-VL backbone.

    The backbone's ``last_hidden_state`` is mean-pooled over the sequence in
    float32 (masked by ``attention_mask`` when one is supplied) and mapped to a
    single value by ``reg_head``. ``forward`` returns a dict with ``"logits"``
    of shape ``[B, 1]`` (float32) and ``"loss"``: the MSE against ``labels``, or
    ``None`` when no labels are given.
    """

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base = base_model
        # Multimodal configs may only expose hidden_size under text_config.
        config = base_model.config
        hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is None:
            text_config = getattr(config, "text_config", None)
            hidden_size = getattr(text_config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not infer hidden_size from model config.")
        self.reg_head = nn.Linear(hidden_size, 1)

    def forward(self, labels=None, **inputs):
        # Only tensors reach the backbone; non-tensor extras (e.g. paths) are dropped.
        tensor_inputs = {
            name: value
            for name, value in inputs.items()
            if name != "labels" and torch.is_tensor(value)
        }
        outputs = self.base(**tensor_inputs, output_hidden_states=False, return_dict=True)
        hidden = outputs.last_hidden_state.float()  # [B,T,H], upcast for a stable mean

        mask = tensor_inputs.get("attention_mask")
        if mask is not None:
            mask_f = mask.float()
            counts = mask_f.sum(dim=1).clamp_min(1.0).unsqueeze(-1)  # avoid /0 on empty rows
            pooled = (hidden * mask_f.unsqueeze(-1)).sum(dim=1) / counts
        else:
            pooled = hidden.mean(dim=1)

        logits = self.reg_head(pooled).float()  # [B,1]
        loss = None if labels is None else F.mse_loss(logits.squeeze(-1), labels.float())
        return {"loss": loss, "logits": logits}

model = Qwen2VLForRegression(base)

print("[lora] Applying LoRA...")
# Adapter shape (r, targets) must mirror training so checkpoint tensors fit.
lora_cfg = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    target_modules=TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)
model.base = get_peft_model(model.base, lora_cfg)


# -----------------------------
# 4) Collator
# -----------------------------
class Qwen2VLRegressionCollator:
    """Turn dataset items into processor inputs plus float32 ``labels``.

    Each caption is wrapped in a single chat-formatted user turn (an image slot
    followed by ``"<prompt_prefix>\\n<caption>\\nScore:"``). The texts and images
    are then encoded together with padding and truncation to ``max_text_len``.
    NOTE(review): if an image expands to more tokens than ``max_text_len``,
    truncation may cut image placeholder tokens; confirm with the processor.
    """

    def __init__(self, processor, prompt_prefix: str, max_text_len: int):
        self.processor = processor
        self.prompt_prefix = prompt_prefix
        self.max_text_len = max_text_len

    def _format_prompt(self, caption) -> str:
        """Render one user turn (image slot + scoring prompt) via the chat template."""
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"{self.prompt_prefix}\n{caption}\nScore:"},
            ],
        }
        return self.processor.apply_chat_template(
            [user_turn], tokenize=False, add_generation_prompt=False
        )

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        images = [example["image"] for example in batch]
        captions = [example["caption"] for example in batch]
        label_values = [float(example["labels"]) for example in batch]
        labels = torch.tensor(label_values, dtype=torch.float32)
        texts = [self._format_prompt(caption) for caption in captions]

        encoded = self.processor(
            text=texts,
            images=images,
            padding=True,
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt",
        )
        encoded["labels"] = labels
        return encoded

collator = Qwen2VLRegressionCollator(processor=processor, prompt_prefix=PROMPT_PREFIX, max_text_len=MAX_TEXT_LEN)


# -----------------------------
# 5) Load checkpoint weights (robust)
# -----------------------------
def load_checkpoint_state(model: nn.Module, ckpt_dir: str):
    """Load weights from a Trainer checkpoint directory into ``model`` (non-strict).

    Tries, in order, ``model.safetensors``, ``pytorch_model.bin``,
    ``adapter_model.safetensors`` and ``adapter_model.bin``. The first file found
    is loaded with ``load_state_dict(strict=False)``. Because non-strict loading
    never fails on name mismatches, it prints a warning when nothing matched or
    when the regression head was not in the file.

    Returns:
        ``(missing, unexpected)`` key lists reported by ``load_state_dict``.

    Raises:
        FileNotFoundError: if none of the candidate files exists.
    """
    ckpt = Path(ckpt_dir)

    candidates = [
        ckpt / "model.safetensors",
        ckpt / "pytorch_model.bin",
        ckpt / "adapter_model.safetensors",
        ckpt / "adapter_model.bin",
    ]

    found = next((p for p in candidates if p.exists()), None)
    if found is None:
        raise FileNotFoundError(
            f"No supported checkpoint file found in {ckpt_dir}. "
            f"Tried: {[str(p.name) for p in candidates]}"
        )

    if found.suffix == ".safetensors":
        from safetensors.torch import load_file
        print("[ckpt] loading safetensors:", found)
        state = load_file(str(found))
    else:
        print("[ckpt] loading pytorch bin:", found)
        # torch.load unpickles the file; only point it at checkpoints you produced.
        state = torch.load(str(found), map_location="cpu")

    # An adapter-only file never contains reg_head and may use PEFT key names that
    # do not line up with this wrapper (NOTE(review): confirm on a real adapter
    # checkpoint), so the warnings below matter most in that case.
    missing, unexpected = model.load_state_dict(state, strict=False)
    print("[ckpt] missing keys:", len(missing), "| unexpected keys:", len(unexpected))

    if len(state) - len(unexpected) == 0:
        print("[ckpt] WARNING: no checkpoint tensor matched a model parameter; "
              "the model keeps the weights it had before this call.")
    missing_head = [k for k in missing if k.startswith("reg_head.")]
    if missing_head:
        print(f"[ckpt] WARNING: regression head missing from checkpoint ({missing_head}); "
              "predictions come from a randomly initialised head.")
    return missing, unexpected

load_checkpoint_state(model=model, ckpt_dir=ckpt_dir)


# -----------------------------
# 6) Manual predict loop (with progress) + metrics
# -----------------------------
# Runs inference directly so each batch can be reported; shuffle=False keeps
# prediction order aligned with eval_ds (and therefore CSV) rows.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

loader = DataLoader(
    eval_ds,
    batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collator,
    pin_memory=torch.cuda.is_available(),
)

all_preds = []
all_labels = []

print("[predict] running manual loop with progress prints...")
total = len(loader)

with torch.no_grad():
    for i, batch in enumerate(loader, start=1):
        labels = batch["labels"].cpu().numpy().astype(np.float32).reshape(-1)
        batch = {k: v.to(device) for k, v in batch.items() if torch.is_tensor(v)}

        out = model(**batch)
        # reshape(-1), not squeeze(): a final batch of size 1 would squeeze to a
        # 0-d array, which np.concatenate refuses to join.
        logits = out["logits"].detach().float().cpu().numpy().astype(np.float32).reshape(-1)

        all_labels.append(labels)
        all_preds.append(logits)

        if i == 1 or i % 10 == 0 or i == total:
            done = min(i * loader.batch_size, len(eval_ds))
            print(f"[predict] batch {i}/{total} | seen {done}/{len(eval_ds)}")

if not all_preds:
    raise ValueError("[predict] eval set is empty; nothing to predict.")

# Always 1-D (length N), even for N == 1, so len() and indexing below work.
val_labels = np.concatenate(all_labels, axis=0).reshape(-1)
val_predictions = np.concatenate(all_preds, axis=0).reshape(-1)

print("[predict] done. shapes:", val_labels.shape, val_predictions.shape)

mse = float(np.mean((val_predictions - val_labels) ** 2))
mae = float(np.mean(np.abs(val_predictions - val_labels)))
var = float(np.var(val_labels))
r2 = float("nan") if var == 0.0 else float(1.0 - mse / var)

print("[metrics] mse:", mse)
print("[metrics] mae:", mae)
print("[metrics] r2 :", r2)

print("[sample] first 5 (label, pred):")
for j in range(min(5, len(val_labels))):
    print(f"  {j}: label={val_labels[j]:.4f} pred={val_predictions[j]:.4f}")


# -----------------------------
# 7) Append predictions to CSV + save (prints full absolute paths)
# -----------------------------
pred_col = "regressor_pred"

# Output goes next to the input as "<stem>_with_preds<suffix>".
in_csv_path = Path(EVAL_CSV).resolve()
out_csv_path = in_csv_path.with_name(in_csv_path.stem + "_with_preds" + in_csv_path.suffix)

print(f"[csv] input : {in_csv_path}")
print(f"[csv] output: {out_csv_path}")

df = pd.read_csv(in_csv_path)

# Flatten first: a one-row eval set can leave val_predictions 0-d, and len() of a
# 0-d array raises TypeError.
pred_values = np.asarray(val_predictions, dtype=np.float32).reshape(-1)

if len(df) != len(pred_values):
    raise ValueError(
        f"Row mismatch: CSV has {len(df)} rows but val_predictions has {len(pred_values)}. "
        "Ensure eval_ds order matches CSV order and shuffle=False."
    )

df[pred_col] = pred_values
df.to_csv(out_csv_path, index=False)

print(f"[csv] wrote: {out_csv_path} (added column: {pred_col})")


roco

In [None]:
from pathlib import Path
from PIL import Image
import numpy as np

root = Path("")  # e.g., "/mnt/data/dataset" or r"C:\data\dataset"; Path("") means the current directory

# pick the first folder (sorted for determinism)
subfolders = sorted([p for p in root.iterdir() if p.is_dir()])
if not subfolders:
    raise FileNotFoundError(f"No subfolders found in: {root}")

first_folder = subfolders[0]

# pick the first image in that folder (common extensions)
img_exts = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
images = sorted([p for p in first_folder.iterdir() if p.is_file() and p.suffix.lower() in img_exts])
if not images:
    raise FileNotFoundError(f"No image files found in: {first_folder}")

img_path = images[0]
# Read the pixels inside the context manager so the file handle is released;
# mode and size remain available on the closed image object.
with Image.open(img_path) as img:
    arr = np.array(img)

print("First folder:", first_folder)
print("Image path:", img_path)
print("PIL mode:", img.mode)          # e.g., RGB, L
print("PIL size (W,H):", img.size)    # (width, height)
print("NumPy shape:", arr.shape)      # (H, W, C) or (H, W)
print("dtype:", arr.dtype)


In [None]:
from pathlib import Path
from PIL import Image

def resize_images_in_folder(
    in_dir: str,
    out_dir: str,
    size=(256, 256),          # (width, height)
    keep_aspect: bool = False,
    pad_color=(0, 0, 0),      # used only if keep_aspect=True
):
    """Resize every image directly inside ``in_dir`` and write it to ``out_dir``.

    Files are selected by extension (case-insensitive), processed in sorted
    order, and converted to RGB. With ``keep_aspect=False`` each image is
    stretched to ``size``. With ``keep_aspect=True`` it is thumbnailed to fit
    inside ``size`` and centred on a ``pad_color`` canvas of exactly ``size``.
    Outputs keep their file names; JPEGs are saved with quality 95.

    Raises:
        FileNotFoundError: if ``in_dir`` contains no matching image files.
    """
    src_root = Path(in_dir)
    dst_root = Path(out_dir)
    dst_root.mkdir(parents=True, exist_ok=True)

    allowed = {".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp"}
    sources = sorted(f for f in src_root.iterdir() if f.is_file() and f.suffix.lower() in allowed)
    if len(sources) == 0:
        raise FileNotFoundError(f"No images found in: {src_root}")

    width, height = size

    for src in sources:
        with Image.open(src) as opened:
            rgb = opened.convert("RGB")  # consistent output; remove if you want to preserve mode

            if not keep_aspect:
                # Direct resize (may distort aspect ratio)
                result = rgb.resize((width, height), Image.Resampling.LANCZOS)
            else:
                # Fit inside the target box in place, then centre on a padded canvas.
                rgb.thumbnail((width, height), Image.Resampling.LANCZOS)
                result = Image.new("RGB", (width, height), pad_color)
                result.paste(rgb, ((width - rgb.width) // 2, (height - rgb.height) // 2))

            destination = dst_root / src.name
            if destination.suffix.lower() in {".jpg", ".jpeg"}:
                result.save(destination, quality=95)
            else:
                result.save(destination)

    print(f"Resized {len(sources)} images from {src_root} -> {dst_root} to {size}.")

# --------- EDIT THESE ----------
input_folder = "roco"
output_folder = "roco2"
# ------------------------------

# Stretch every image to 256x256 (no aspect-preserving padding).
resize_images_in_folder(
    in_dir=input_folder,
    out_dir=output_folder,
    size=(256, 256),
    keep_aspect=False,
)


In [None]:
# Load latest checkpoint + run predictions on test_set_final.csv + save CSV with new column `regressor_pred`
# ✅ prints progress during prediction/eval

from pathlib import Path
import os, re, math
from typing import Any, Dict, List, Optional

import torch
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from transformers import AutoProcessor, AutoModel, TrainingArguments, Trainer
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model, TaskType


# -----------------------------
# 0) Paths / config (edit if needed)
# -----------------------------
HF_CACHE_ROOT = "/work/hdd/bfrc"
CACHE_DIR = str(Path(HF_CACHE_ROOT) / "cache_dir")

# Evaluation CSV to score (needs img_name, caption, alignment_score columns).
EVAL_CSV  = "ROCO.csv"

# Folder that img_name values are joined onto. Empty string is a placeholder: set before running.
IMG_ROOT  = ""
MODEL_ID  = "Qwen/Qwen2.5-VL-3B-Instruct"

# Trainer output folder; the newest checkpoint-* under it is evaluated.
OUTPUT_DIR = str(Path(HF_CACHE_ROOT) / "finetunes" / "qwen2vl-regression-out")

# Prompt and truncation should match the training cell.
PROMPT_PREFIX = "Predict the alignment score (a real number). Caption:"
MAX_TEXT_LEN = 2048

PER_DEVICE_EVAL_BATCH_SIZE = 4
NUM_WORKERS = 2

# LoRA (MUST match training). Dropout is inert here: prediction runs under model.eval().
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]


# -----------------------------
# 1) Find latest checkpoint
# -----------------------------
def get_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered ``checkpoint-<step>`` subdirectory.

    Steps are compared numerically (``checkpoint-10`` beats ``checkpoint-9``).
    Only directories are considered. Returns ``None`` when ``output_dir`` does
    not exist or holds no checkpoint directories.
    """
    root = Path(output_dir)
    if not root.exists():
        return None

    def _step(path: Path) -> int:
        match = re.search(r"checkpoint-(\d+)", path.name)
        return -1 if match is None else int(match.group(1))

    candidates = [d for d in root.glob("checkpoint-*") if d.is_dir()]
    return str(max(candidates, key=_step)) if candidates else None

# Resolve the newest Trainer checkpoint; evaluation is pointless without one.
ckpt_dir = get_latest_checkpoint(OUTPUT_DIR)
if ckpt_dir is not None:
    print("[ckpt] latest:", ckpt_dir)
else:
    raise FileNotFoundError(f"No checkpoint-* found under: {OUTPUT_DIR}")


# -----------------------------
# 2) Dataset
# -----------------------------
# Columns read by CSVImageCaptionRegressionDataset.__getitem__.
required_cols = {"img_name", "caption", "alignment_score"}
raw = load_dataset(
    "csv",
    data_files={"eval": EVAL_CSV},
    cache_dir=CACHE_DIR,
)
missing = required_cols - set(raw["eval"].column_names)
if len(missing) > 0:
    raise ValueError(f"eval is missing columns: {missing}")

def load_image(path: str) -> Image.Image:
    """Open the image at ``path`` and return an RGB copy.

    The source file handle is closed before returning; the converted image is
    an independent in-memory copy.
    """
    with Image.open(path) as handle:
        rgb = handle.convert("RGB")
    return rgb

class CSVImageCaptionRegressionDataset(Dataset):
    """Map-style dataset over rows with ``img_name``, ``caption``, ``alignment_score``.

    Each item is a dict holding the RGB ``image`` (loaded from
    ``img_root/img_name``), the ``caption`` as a string, the float target under
    ``labels`` and the resolved ``img_path``.

    Raises:
        ValueError: from ``__getitem__`` when a row's score is NaN or infinite.
    """

    def __init__(self, hf_dataset, img_root: str):
        self.ds = hf_dataset  # any indexable, sized collection of row dicts
        self.img_root = img_root

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        path = os.path.join(self.img_root, str(record["img_name"]))
        text = str(record["caption"])
        target = float(record["alignment_score"])
        # Reject NaN/inf targets before touching the image; they would poison the MSE.
        if not math.isfinite(target):
            raise ValueError(f"Non-finite label at idx={idx}: {target}")
        return {"image": load_image(path), "caption": text, "labels": target, "img_path": path}

eval_ds = CSVImageCaptionRegressionDataset(raw["eval"], IMG_ROOT)
print("[data] rows:", len(eval_ds))


# -----------------------------
# 3) Processor + model (must match training)
# -----------------------------
print("[model] Loading processor...")
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    cache_dir=CACHE_DIR,
    trust_remote_code=True,
)

print("[model] Loading base...")
# fp16 weights on GPU, fp32 on CPU.
base = AutoModel.from_pretrained(
    MODEL_ID,
    cache_dir=CACHE_DIR,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=(torch.float16 if torch.cuda.is_available() else torch.float32),
)

class Qwen2VLForRegression(nn.Module):
    """Scalar regression head on top of a (LoRA-wrapped) Qwen2.5-VL backbone.

    The backbone's ``last_hidden_state`` is mean-pooled over the sequence in
    float32 (masked by ``attention_mask`` when one is supplied) and mapped to a
    single value by ``reg_head``. ``forward`` returns a dict with ``"logits"``
    of shape ``[B, 1]`` (float32) and ``"loss"``: the MSE against ``labels``, or
    ``None`` when no labels are given.
    """

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base = base_model
        # Multimodal configs may only expose hidden_size under text_config.
        config = base_model.config
        hidden_size = getattr(config, "hidden_size", None)
        if hidden_size is None:
            text_config = getattr(config, "text_config", None)
            hidden_size = getattr(text_config, "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not infer hidden_size from model config.")
        self.reg_head = nn.Linear(hidden_size, 1)

    def forward(self, labels=None, **inputs):
        # Only tensors reach the backbone; non-tensor extras (e.g. paths) are dropped.
        tensor_inputs = {
            name: value
            for name, value in inputs.items()
            if name != "labels" and torch.is_tensor(value)
        }
        outputs = self.base(**tensor_inputs, output_hidden_states=False, return_dict=True)
        hidden = outputs.last_hidden_state.float()  # [B,T,H], upcast for a stable mean

        mask = tensor_inputs.get("attention_mask")
        if mask is not None:
            mask_f = mask.float()
            counts = mask_f.sum(dim=1).clamp_min(1.0).unsqueeze(-1)  # avoid /0 on empty rows
            pooled = (hidden * mask_f.unsqueeze(-1)).sum(dim=1) / counts
        else:
            pooled = hidden.mean(dim=1)

        logits = self.reg_head(pooled).float()  # [B,1]
        loss = None if labels is None else F.mse_loss(logits.squeeze(-1), labels.float())
        return {"loss": loss, "logits": logits}

model = Qwen2VLForRegression(base)

print("[lora] Applying LoRA...")
# Adapter shape (r, targets) must mirror training so checkpoint tensors fit.
lora_cfg = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    target_modules=TARGET_MODULES,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)
model.base = get_peft_model(model.base, lora_cfg)


# -----------------------------
# 4) Collator
# -----------------------------
class Qwen2VLRegressionCollator:
    """Turn dataset items into processor inputs plus float32 ``labels``.

    Each caption is wrapped in a single chat-formatted user turn (an image slot
    followed by ``"<prompt_prefix>\\n<caption>\\nScore:"``). The texts and images
    are then encoded together with padding and truncation to ``max_text_len``.
    NOTE(review): if an image expands to more tokens than ``max_text_len``,
    truncation may cut image placeholder tokens; confirm with the processor.
    """

    def __init__(self, processor, prompt_prefix: str, max_text_len: int):
        self.processor = processor
        self.prompt_prefix = prompt_prefix
        self.max_text_len = max_text_len

    def _format_prompt(self, caption) -> str:
        """Render one user turn (image slot + scoring prompt) via the chat template."""
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"{self.prompt_prefix}\n{caption}\nScore:"},
            ],
        }
        return self.processor.apply_chat_template(
            [user_turn], tokenize=False, add_generation_prompt=False
        )

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        images = [example["image"] for example in batch]
        captions = [example["caption"] for example in batch]
        label_values = [float(example["labels"]) for example in batch]
        labels = torch.tensor(label_values, dtype=torch.float32)
        texts = [self._format_prompt(caption) for caption in captions]

        encoded = self.processor(
            text=texts,
            images=images,
            padding=True,
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt",
        )
        encoded["labels"] = labels
        return encoded

collator = Qwen2VLRegressionCollator(processor=processor, prompt_prefix=PROMPT_PREFIX, max_text_len=MAX_TEXT_LEN)


# -----------------------------
# 5) Load checkpoint weights (robust)
# -----------------------------
def load_checkpoint_state(model: nn.Module, ckpt_dir: str):
    """Load weights from a Trainer checkpoint directory into ``model`` (non-strict).

    Tries, in order, ``model.safetensors``, ``pytorch_model.bin``,
    ``adapter_model.safetensors`` and ``adapter_model.bin``. The first file found
    is loaded with ``load_state_dict(strict=False)``. Because non-strict loading
    never fails on name mismatches, it prints a warning when nothing matched or
    when the regression head was not in the file.

    Returns:
        ``(missing, unexpected)`` key lists reported by ``load_state_dict``.

    Raises:
        FileNotFoundError: if none of the candidate files exists.
    """
    ckpt = Path(ckpt_dir)

    candidates = [
        ckpt / "model.safetensors",
        ckpt / "pytorch_model.bin",
        ckpt / "adapter_model.safetensors",
        ckpt / "adapter_model.bin",
    ]

    found = next((p for p in candidates if p.exists()), None)
    if found is None:
        raise FileNotFoundError(
            f"No supported checkpoint file found in {ckpt_dir}. "
            f"Tried: {[str(p.name) for p in candidates]}"
        )

    if found.suffix == ".safetensors":
        from safetensors.torch import load_file
        print("[ckpt] loading safetensors:", found)
        state = load_file(str(found))
    else:
        print("[ckpt] loading pytorch bin:", found)
        # torch.load unpickles the file; only point it at checkpoints you produced.
        state = torch.load(str(found), map_location="cpu")

    # An adapter-only file never contains reg_head and may use PEFT key names that
    # do not line up with this wrapper (NOTE(review): confirm on a real adapter
    # checkpoint), so the warnings below matter most in that case.
    missing, unexpected = model.load_state_dict(state, strict=False)
    print("[ckpt] missing keys:", len(missing), "| unexpected keys:", len(unexpected))

    if len(state) - len(unexpected) == 0:
        print("[ckpt] WARNING: no checkpoint tensor matched a model parameter; "
              "the model keeps the weights it had before this call.")
    missing_head = [k for k in missing if k.startswith("reg_head.")]
    if missing_head:
        print(f"[ckpt] WARNING: regression head missing from checkpoint ({missing_head}); "
              "predictions come from a randomly initialised head.")
    return missing, unexpected

load_checkpoint_state(model=model, ckpt_dir=ckpt_dir)


# -----------------------------
# 6) Manual predict loop (with progress) + metrics
# -----------------------------
# Runs inference directly so each batch can be reported; shuffle=False keeps
# prediction order aligned with eval_ds (and therefore CSV) rows.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

loader = DataLoader(
    eval_ds,
    batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collator,
    pin_memory=torch.cuda.is_available(),
)

all_preds = []
all_labels = []

print("[predict] running manual loop with progress prints...")
total = len(loader)

with torch.no_grad():
    for i, batch in enumerate(loader, start=1):
        labels = batch["labels"].cpu().numpy().astype(np.float32).reshape(-1)
        batch = {k: v.to(device) for k, v in batch.items() if torch.is_tensor(v)}

        out = model(**batch)
        # reshape(-1), not squeeze(): a final batch of size 1 would squeeze to a
        # 0-d array, which np.concatenate refuses to join.
        logits = out["logits"].detach().float().cpu().numpy().astype(np.float32).reshape(-1)

        all_labels.append(labels)
        all_preds.append(logits)

        if i == 1 or i % 10 == 0 or i == total:
            done = min(i * loader.batch_size, len(eval_ds))
            print(f"[predict] batch {i}/{total} | seen {done}/{len(eval_ds)}")

if not all_preds:
    raise ValueError("[predict] eval set is empty; nothing to predict.")

# Always 1-D (length N), even for N == 1, so len() and indexing below work.
val_labels = np.concatenate(all_labels, axis=0).reshape(-1)
val_predictions = np.concatenate(all_preds, axis=0).reshape(-1)

print("[predict] done. shapes:", val_labels.shape, val_predictions.shape)

mse = float(np.mean((val_predictions - val_labels) ** 2))
mae = float(np.mean(np.abs(val_predictions - val_labels)))
var = float(np.var(val_labels))
r2 = float("nan") if var == 0.0 else float(1.0 - mse / var)

print("[metrics] mse:", mse)
print("[metrics] mae:", mae)
print("[metrics] r2 :", r2)

print("[sample] first 5 (label, pred):")
for j in range(min(5, len(val_labels))):
    print(f"  {j}: label={val_labels[j]:.4f} pred={val_predictions[j]:.4f}")


# -----------------------------
# 7) Append predictions to CSV + save (prints full absolute paths)
# -----------------------------
pred_col = "regressor_pred"

# Output goes next to the input as "<stem>_with_preds<suffix>".
in_csv_path = Path(EVAL_CSV).resolve()
out_csv_path = in_csv_path.with_name(in_csv_path.stem + "_with_preds" + in_csv_path.suffix)

print(f"[csv] input : {in_csv_path}")
print(f"[csv] output: {out_csv_path}")

df = pd.read_csv(in_csv_path)

# Flatten first: a one-row eval set can leave val_predictions 0-d, and len() of a
# 0-d array raises TypeError.
pred_values = np.asarray(val_predictions, dtype=np.float32).reshape(-1)

if len(df) != len(pred_values):
    raise ValueError(
        f"Row mismatch: CSV has {len(df)} rows but val_predictions has {len(pred_values)}. "
        "Ensure eval_ds order matches CSV order and shuffle=False."
    )

df[pred_col] = pred_values
df.to_csv(out_csv_path, index=False)

print(f"[csv] wrote: {out_csv_path} (added column: {pred_col})")


In [None]:
# Qwen2.5-VL 3B regression fine-tuning (image + caption -> alignment_score)
# ✅ LoRA (FEATURE_EXTRACTION)
# ✅ Single-process / single-GPU
# ✅ Resume from latest checkpoint automatically
# ✅ Validation during training (every eval_steps)
# ✅ Stability fixes: fp32 pooling + fp32 head + fp32 MSE loss
# ✅ Trainer-compatible outputs: returns "logits" (not "predictions")
#
# CHANGES MADE (hyperparams only):
# - LR: 2e-4 -> 5e-5
# - WARMUP_RATIO: 0.03 -> 0.10
# - WEIGHT_DECAY: 0.0 -> 0.01
# - MAX_TEXT_LEN: listed as 2048 -> 1024, but the config section below still sets 2048
# - LORA_DROPOUT: 0.05 -> 0.10
# - NUM_WORKERS: 2 -> 0
# - TrainingArguments: max_grad_norm=1.0, lr_scheduler_type="cosine", optim="adamw_torch_fused"

from pathlib import Path
import os

# ----------------------------
# 1) IMPORTANT: env + cache
# ----------------------------
# These variables only take effect if set before torch/transformers are first
# imported in the kernel; restart the kernel if they were already imported.
# Synchronous CUDA launches make device-side asserts point at the failing op
# (a debugging aid that also slows execution).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Optional: helps avoid tokenizer fork warning / deadlocks when using multiprocessing
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Route every Hugging Face cache under one root.
HF_CACHE_ROOT = "/work/hdd/bfrc"
os.environ["HF_HOME"] = str(Path(HF_CACHE_ROOT) / "hf_home")
os.environ["HUGGINGFACE_HUB_CACHE"] = str(Path(HF_CACHE_ROOT) / "hub")
os.environ["TRANSFORMERS_CACHE"] = str(Path(HF_CACHE_ROOT) / "transformers")
os.environ["HF_DATASETS_CACHE"] = str(Path(HF_CACHE_ROOT) / "datasets")
# Passed explicitly as cache_dir= to from_pretrained / load_dataset.
CACHE_DIR = str(Path(HF_CACHE_ROOT) / "cache_dir")

# Create all cache folders up front (idempotent). Note: the loop variable `p`
# stays bound at module level afterwards.
for p in [
    os.environ["HF_HOME"],
    os.environ["HUGGINGFACE_HUB_CACHE"],
    os.environ["TRANSFORMERS_CACHE"],
    os.environ["HF_DATASETS_CACHE"],
    CACHE_DIR,
]:
    Path(p).mkdir(parents=True, exist_ok=True)

print("[env] CUDA_LAUNCH_BLOCKING:", os.environ["CUDA_LAUNCH_BLOCKING"])
print("[env] TOKENIZERS_PARALLELISM:", os.environ.get("TOKENIZERS_PARALLELISM"))
print("[cache] HF_HOME:", os.environ["HF_HOME"])
print("[cache] HUGGINGFACE_HUB_CACHE:", os.environ["HUGGINGFACE_HUB_CACHE"])
print("[cache] TRANSFORMERS_CACHE:", os.environ["TRANSFORMERS_CACHE"])
print("[cache] HF_DATASETS_CACHE:", os.environ["HF_DATASETS_CACHE"])
print("[cache] CACHE_DIR:", CACHE_DIR)

# ----------------------------
# 2) Config
# ----------------------------
import torch

# Data locations. These are placeholders: set them to the real CSV files and
# image folder before running. (The previous literals were unterminated
# strings, which made this whole cell a SyntaxError.) An empty EVAL_CSV
# disables the eval split when the data is loaded.
TRAIN_CSV = "/"
EVAL_CSV  = "/"
IMG_ROOT  = ""

MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

OUTPUT_DIR = str(Path(HF_CACHE_ROOT) / "finetunes" / "qwen2vl-regression-out")
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# training (UPDATED)
EPOCHS = 2
PER_DEVICE_BATCH_SIZE = 4
GRAD_ACCUM_STEPS = 8
LR = 5e-5                 # was 2e-4
WARMUP_RATIO = 0.10       # was 0.03
WEIGHT_DECAY = 0.01       # was 0.0
MAX_STEPS = -1

# prompt + truncation
PROMPT_PREFIX = "Predict the alignment score (a real number). Caption:"
# Kept at 2048 (not lowered): truncating shorter can cut Qwen2.5-VL image
# tokens and trigger an image-token mismatch.
MAX_TEXT_LEN = 2048

# LoRA (UPDATED)
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.10       # was 0.05

SEED = 42
NUM_WORKERS = 0           # was 2 (avoid tokenizers fork issues)

print("[cfg] OUTPUT_DIR:", OUTPUT_DIR)
print("[cfg] CUDA:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("[cfg] GPU:", torch.cuda.get_device_name(0))
    print("[cfg] device_count:", torch.cuda.device_count())

# ----------------------------
# 3) Load dataset
# ----------------------------
from datasets import load_dataset

print("[data] Loading CSV(s)...")
# An empty EVAL_CSV means train-only: no "eval" split is requested.
data_files = {"train": TRAIN_CSV, **({"eval": EVAL_CSV} if EVAL_CSV else {})}
raw = load_dataset("csv", data_files=data_files, cache_dir=CACHE_DIR)

train_split = raw["train"]
print("[data] splits:", list(raw.keys()))
print("[data] columns:", train_split.column_names)
print("[data] train rows:", len(train_split))
if EVAL_CSV:
    print("[data] eval rows:", len(raw["eval"]))

# Fail fast if any split lacks a column the dataset wrapper reads.
required_cols = {"img_name", "caption", "alignment_score"}
for split_name, split_ds in raw.items():
    missing = required_cols - set(split_ds.column_names)
    if missing:
        raise ValueError(f"{split_name} is missing columns: {missing}")

print("[data] example row:", train_split[0])

# ----------------------------
# 4) Torch dataset (image load on the fly)
# ----------------------------
import math
from PIL import Image
from torch.utils.data import Dataset

def load_image(path: str) -> Image.Image:
    """Open the image at ``path`` and return an RGB copy; the file handle is closed on return."""
    with Image.open(path) as handle:
        rgb_image = handle.convert("RGB")
    return rgb_image

class CSVImageCaptionRegressionDataset(Dataset):
    """Map-style dataset of (image, caption, alignment score) examples.

    Args:
        hf_dataset: indexable split whose rows expose ``img_name``, ``caption``
            and ``alignment_score``.
        img_root: directory joined in front of each row's ``img_name``.

    Images are decoded lazily in ``__getitem__``. Each item is a dict with keys
    ``image``, ``caption``, ``labels`` (float) and ``img_path``.
    """

    def __init__(self, hf_dataset, img_root: str):
        self.ds = hf_dataset
        self.img_root = img_root

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        path = os.path.join(self.img_root, str(record["img_name"]))
        text = str(record["caption"])
        score = float(record["alignment_score"])

        # Reject NaN/inf targets before paying for image decoding.
        if not math.isfinite(score):
            raise ValueError(f"Non-finite label at idx={idx}: {score}")

        return {
            "image": load_image(path),
            "caption": text,
            "labels": score,
            "img_path": path,
        }

train_ds = CSVImageCaptionRegressionDataset(raw["train"], IMG_ROOT)
eval_ds = CSVImageCaptionRegressionDataset(raw["eval"], IMG_ROOT) if EVAL_CSV else None
print("[ds] train:", len(train_ds), "| eval:", len(eval_ds) if eval_ds else 0)

# ----------------------------
# 5) Load processor + base model + regression head + LoRA
# ----------------------------
from transformers import AutoProcessor, AutoModel
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model, TaskType

print("[model] Loading processor...")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)

print("[model] Loading base model (fp16)...")
# Half-precision weights when a GPU is present, full precision otherwise.
base_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
base = AutoModel.from_pretrained(
    MODEL_ID,
    cache_dir=CACHE_DIR,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=base_dtype,
)
print("[model] Base loaded.")

class Qwen2VLForRegression(nn.Module):
    """Scalar-regression head on top of a vision-language backbone, usable with ``Trainer``.

    The backbone's ``last_hidden_state`` is mean-pooled over positions where
    ``attention_mask`` is 1 (over all positions when no mask is given). A
    single ``nn.Linear(hidden_size, 1)`` then maps the pooled vector to a score.

    Returns a dict with:
      - ``logits``: float32 tensor of shape [B, 1]
      - ``loss``: float32 MSE against ``labels`` (None when no labels are given)

    Stability: pooling, the head's matmul and the loss all run in fp32. The
    head runs with autocast disabled, because under fp16 mixed precision
    autocast would otherwise execute the Linear layer in half precision.
    """

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base = base_model
        # Multimodal configs may keep hidden_size under text_config instead of at top level.
        hidden_size = getattr(base_model.config, "hidden_size", None)
        if hidden_size is None:
            hidden_size = getattr(getattr(base_model.config, "text_config", None), "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not infer hidden_size from model config.")
        self.reg_head = nn.Linear(hidden_size, 1)

    def forward(self, labels=None, **inputs):
        """Run the backbone, mean-pool, and regress.

        Args:
            labels: optional targets of shape [B] or [B, 1].
            **inputs: collator outputs; only tensor values reach the backbone.
        """
        # Drop non-tensor extras so they never reach the backbone's forward.
        base_inputs = {k: v for k, v in inputs.items() if torch.is_tensor(v) and k != "labels"}

        out = self.base(**base_inputs, output_hidden_states=False, return_dict=True)
        last_hidden = out.last_hidden_state.float()  # [B, T, H], pooled in fp32

        attn = base_inputs.get("attention_mask", None)
        if attn is None:
            pooled = last_hidden.mean(dim=1)
        else:
            attn_f = attn.float().unsqueeze(-1)                  # [B, T, 1]
            denom = attn_f.sum(dim=1).clamp_min(1.0)            # [B, 1]; guards all-zero masks
            pooled = (last_hidden * attn_f).sum(dim=1) / denom  # [B, H]

        # Autocast (e.g. Trainer fp16) would run the Linear in fp16. Disable it
        # and cast the head's parameters so the projection is genuinely fp32.
        device_type = "cuda" if pooled.is_cuda else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            bias = self.reg_head.bias
            logits = F.linear(
                pooled,
                self.reg_head.weight.float(),
                None if bias is None else bias.float(),
            )  # [B, 1] fp32

        loss = None
        if labels is not None:
            # Flatten both sides so [B, 1] labels cannot broadcast against [B] into [B, B].
            loss = F.mse_loss(logits.reshape(-1), labels.float().reshape(-1))

        return {"loss": loss, "logits": logits}

model = Qwen2VLForRegression(base)

print("[lora] Applying LoRA (FEATURE_EXTRACTION)...")
# Module-name suffixes that receive LoRA adapters.
target_modules = [f"{name}_proj" for name in ("q", "k", "v", "o", "gate", "up", "down")]
lora_cfg = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    target_modules=target_modules,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)
# Only the backbone is wrapped with PEFT; the regression head stays outside the PeftModel.
model.base = get_peft_model(model.base, lora_cfg)
model.base.print_trainable_parameters()

# The head is trained from scratch: make sure all of its parameters receive gradients.
model.reg_head.requires_grad_(True)
print("[model] reg_head params:", sum(param.numel() for param in model.reg_head.parameters()))

# ----------------------------
# 6) Collator (truncation + token-range checks)
# ----------------------------
from typing import Any, Dict, List, Optional

class Qwen2VLRegressionCollator:
    """Turn dataset examples into one processor batch plus float32 regression labels.

    Each example must provide ``image``, ``caption`` and ``labels``. Each caption
    is wrapped in a single-turn chat message (one image slot plus the prompt
    text), and the batch goes through ``processor`` with padding and truncation
    to ``max_text_len``. The processor output is returned with ``labels`` added.

    Args:
        processor: object exposing ``apply_chat_template`` and ``__call__``; its
            ``tokenizer`` (when present and sized) bounds valid token ids.
        prompt_prefix: instruction text placed before each caption.
        max_text_len: ``max_length`` passed to the processor.
        debug: print previews for the first three batches.

    Raises:
        ValueError: if any ``input_ids`` value is >= ``len(processor.tokenizer)``.
    """

    def __init__(self, processor, prompt_prefix: str, max_text_len: int, debug: bool = False):
        self.processor = processor
        self.prompt_prefix = prompt_prefix
        self.max_text_len = max_text_len
        self.debug = debug
        self._seen = 0

        # len(tokenizer) includes added/special tokens, so it bounds every valid id.
        tok = getattr(processor, "tokenizer", None)
        sized = tok is not None and hasattr(tok, "__len__")
        self.vocab_size = int(len(tok)) if sized else None

    def _to_chat_text(self, caption: str) -> str:
        """Render one caption as a chat-formatted prompt string (no generation prompt)."""
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"{self.prompt_prefix}\n{caption}\nScore:"},
            ],
        }
        return self.processor.apply_chat_template(
            [user_turn], tokenize=False, add_generation_prompt=False
        )

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        images = [example["image"] for example in batch]
        targets = torch.tensor([float(example["labels"]) for example in batch], dtype=torch.float32)
        texts = [self._to_chat_text(example["caption"]) for example in batch]

        encoded = self.processor(
            text=texts,
            images=images,
            padding=True,
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt",
        )
        encoded["labels"] = targets

        # Catch out-of-range ids on the CPU instead of as a CUDA device-side assert.
        if "input_ids" in encoded:
            top_id = int(encoded["input_ids"].max())
            if self.vocab_size is not None and top_id >= self.vocab_size:
                raise ValueError(f"input_ids has id {top_id} >= vocab_size {self.vocab_size}")

        if self.debug and self._seen < 3:
            print("[collator] prompt preview:\n", texts[0][:400], "...")
            print("[collator] input_ids:", tuple(encoded["input_ids"].shape))
            print("[collator] labels:", targets[: min(4, len(targets))].tolist())
            self._seen += 1

        return encoded

collator = Qwen2VLRegressionCollator(processor, PROMPT_PREFIX, MAX_TEXT_LEN, debug=False)

# Smoke-test the collator on a single training example before training starts.
b = collator([train_ds[0]])
print("[collator] keys:", [*b.keys()])

# ----------------------------
# 7) Trainer + progress prints + resume
# ----------------------------
import re
import time

import numpy as np
from transformers import Trainer, TrainerCallback, TrainingArguments, set_seed

set_seed(SEED)

def compute_metrics(eval_pred):
    """Score regression predictions against labels.

    Args:
        eval_pred: ``(predictions, labels)`` pair; predictions may be [N, 1] or [N].

    Returns:
        dict with ``mse``, ``mae`` and ``r2``. If any prediction is non-finite,
        all three are NaN and ``pred_nan_frac`` gives the non-finite fraction.
        ``r2`` is NaN when the labels have zero variance.
    """
    raw_preds, raw_labels = eval_pred
    y_hat = np.asarray(raw_preds, dtype=np.float32).squeeze()
    y = np.asarray(raw_labels, dtype=np.float32).squeeze()

    non_finite = np.logical_not(np.isfinite(y_hat))
    if non_finite.any():
        nan = float("nan")
        return {"mse": nan, "mae": nan, "r2": nan, "pred_nan_frac": float(non_finite.mean())}

    residual = y_hat - y
    mse = float(np.mean(np.square(residual)))
    mae = float(np.mean(np.abs(residual)))
    label_var = float(np.var(y))
    r2 = float(1.0 - mse / label_var) if label_var != 0.0 else float("nan")
    return {"mse": mse, "mae": mae, "r2": r2}

class PrintProgressCallback(TrainerCallback):
    """Print compact progress lines and a periodic one-example prediction.

    Args:
        trainer_ref: Trainer whose ``model`` is probed in ``on_step_end``; when
            None the periodic sample prediction is skipped.
        sample_every_steps: run the sample prediction every this many global steps.

    The sample prediction reads the module-level ``train_ds`` and ``collator``.
    """

    def __init__(self, trainer_ref=None, sample_every_steps: int = 100):
        self.trainer_ref = trainer_ref
        self.sample_every_steps = sample_every_steps
        self.t0 = None  # wall-clock start, set in on_train_begin

    def on_train_begin(self, args, state, control, **kwargs):
        self.t0 = time.time()
        print(f"[train] begin | max_steps={state.max_steps} | epochs={args.num_train_epochs}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        elapsed = (time.time() - self.t0) if self.t0 else 0.0
        # state.epoch may be None if nothing has set it yet; don't crash formatting it.
        epoch_txt = f"{state.epoch:.3f}" if state.epoch is not None else "n/a"
        msg = f"[log] step={state.global_step} epoch={epoch_txt} elapsed={elapsed/60:.1f}m"
        # Fall back to "train_loss" for logs that report only that key.
        loss_val = logs.get("loss", logs.get("train_loss", None))
        if loss_val is not None:
            msg += f" loss={loss_val:.4f}"
        if "learning_rate" in logs:
            msg += f" lr={logs['learning_rate']:.2e}"
        if "grad_norm" in logs:
            msg += f" grad_norm={logs['grad_norm']:.3f}"
        print(msg)

    def on_step_end(self, args, state, control, **kwargs):
        if self.trainer_ref is None:
            return
        step = state.global_step
        if step <= 0 or step % self.sample_every_steps != 0:
            return
        m = self.trainer_ref.model
        was_training = m.training
        try:
            m.eval()
            ex = train_ds[0]
            batch = collator([ex])
            device = next(m.parameters()).device
            batch = {k: v.to(device) for k, v in batch.items() if torch.is_tensor(v)}
            with torch.no_grad():
                out = m(**batch)
            pred = float(out["logits"].detach().float().cpu().squeeze().item())
            print(f"[sample] step={step} pred={pred:.4f} label={float(ex['labels']):.4f}")
        except Exception as e:  # best-effort diagnostic: never interrupt training
            print("[sample] failed:", repr(e))
        finally:
            # Restore the mode we found; m.eval() above would otherwise leave
            # dropout disabled after the probe.
            m.train(was_training)

def get_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Args:
        output_dir: Trainer output directory to scan (non-recursively).

    Returns:
        ``str`` path of the checkpoint directory with the largest step number,
        or ``None`` if ``output_dir`` does not exist or contains no directory
        whose name is exactly ``checkpoint-<digits>``.
    """
    out = Path(output_dir)
    if not out.is_dir():
        return None

    # Only names that are exactly "checkpoint-<digits>" count; something like
    # "checkpoint-foo" must never be picked as the resume point.
    pattern = re.compile(r"checkpoint-(\d+)")
    steps_by_dir = {}
    for p in out.glob("checkpoint-*"):
        m = pattern.fullmatch(p.name)
        if m and p.is_dir():
            steps_by_dir[p] = int(m.group(1))

    if not steps_by_dir:
        return None
    return str(max(steps_by_dir, key=steps_by_dir.get))

# fp16 mixed precision and the fused AdamW kernel both assume a CUDA device.
# On a CPU-only host (where the base model is already loaded in fp32), fall
# back to fp32 + plain AdamW instead of letting TrainingArguments/optimizer
# setup reject the configuration.
_HAS_CUDA = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LR,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=WEIGHT_DECAY,
    logging_steps=10,

    save_steps=200,
    save_total_limit=2,

    eval_strategy="steps" if eval_ds is not None else "no",
    eval_steps=200 if eval_ds is not None else None,

    bf16=False,
    fp16=_HAS_CUDA,

    dataloader_num_workers=NUM_WORKERS,
    remove_unused_columns=False,  # the collator needs the raw "image"/"caption" fields
    max_steps=MAX_STEPS,
    report_to="none",

    # UPDATED stability knobs
    max_grad_norm=1.0,
    lr_scheduler_type="cosine",
    optim="adamw_torch_fused" if _HAS_CUDA else "adamw_torch",
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    # Metrics are only computed when an eval split exists.
    compute_metrics=None if eval_ds is None else compute_metrics,
)

trainer.add_callback(PrintProgressCallback(trainer_ref=trainer, sample_every_steps=100))

# Resume from the newest checkpoint in OUTPUT_DIR when one exists.
latest_ckpt = get_latest_checkpoint(OUTPUT_DIR)
print("[resume] latest_ckpt:", latest_ckpt)

print("[train] Starting training...")
trainer.train(resume_from_checkpoint=latest_ckpt or None)
print("[train] Training finished.")

# ----------------------------
# 8) Save
# ----------------------------
print("[save] Saving full model + processor to:", OUTPUT_DIR)
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

lora_dir = str(Path(OUTPUT_DIR) / "lora_adapter")
print("[save] Saving LoRA adapter to:", lora_dir)
model.base.save_pretrained(lora_dir)

# Only model.base is a PeftModel, so the adapter directory does not contain
# the regression head. Save the head next to it so that the adapter directory
# is enough to rebuild the regressor.
reg_head_path = Path(lora_dir) / "reg_head.pt"
print("[save] Saving regression head to:", reg_head_path)
torch.save(model.reg_head.state_dict(), reg_head_path)

print("[save] Done.")


In [None]:
# Qwen2.5-VL 3B regression fine-tuning (image + caption -> alignment_score)
# ✅ LoRA (FEATURE_EXTRACTION)
# ✅ Single-process / single-GPU
# ✅ Resume from latest checkpoint automatically
# ✅ Validation during training (every eval_steps)
# ✅ Stability fixes: fp32 pooling + fp32 head + fp32 MSE loss
# ✅ Trainer-compatible outputs: returns "logits" (not "predictions")
#
# MODS (hyperparams only, plus printing sanity):
# - Keep MAX_TEXT_LEN=2048 (avoid Qwen2.5-VL image-token truncation mismatch)
# - LR: 2e-4 -> 5e-5
# - WARMUP_RATIO: 0.03 -> 0.10
# - WEIGHT_DECAY: 0.0 -> 0.01
# - LORA_DROPOUT: 0.05 -> 0.10
# - NUM_WORKERS: 2 -> 0
# - TrainingArguments: max_grad_norm=1.0, lr_scheduler_type="cosine"
# - optim: try "adamw_torch_fused" (falls back to "adamw_torch" if unavailable)

from pathlib import Path
import os

# ----------------------------
# 1) IMPORTANT: env + cache
# ----------------------------
# Synchronous CUDA launches (errors point at the failing op) and no tokenizer
# parallelism (avoids fork warnings).
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Point every Hugging Face cache at subdirectories of one root.
HF_CACHE_ROOT = "/work/hdd/bfrc"
_cache_root = Path(HF_CACHE_ROOT)
_cache_env = {
    "HF_HOME": _cache_root / "hf_home",
    "HUGGINGFACE_HUB_CACHE": _cache_root / "hub",
    "TRANSFORMERS_CACHE": _cache_root / "transformers",
    "HF_DATASETS_CACHE": _cache_root / "datasets",
}
for _env_name, _env_dir in _cache_env.items():
    os.environ[_env_name] = str(_env_dir)
CACHE_DIR = str(_cache_root / "cache_dir")

# Create all cache directories up front (no-op if they already exist).
for _dir in [*(os.environ[_env_name] for _env_name in _cache_env), CACHE_DIR]:
    Path(_dir).mkdir(parents=True, exist_ok=True)

print("[env] CUDA_LAUNCH_BLOCKING:", os.environ["CUDA_LAUNCH_BLOCKING"])
print("[env] TOKENIZERS_PARALLELISM:", os.environ.get("TOKENIZERS_PARALLELISM"))
for _env_name in _cache_env:
    print(f"[cache] {_env_name}:", os.environ[_env_name])
print("[cache] CACHE_DIR:", CACHE_DIR)

# ----------------------------
# 2) Config
# ----------------------------
import torch

# --- data (placeholders: point these at the real CSVs / image folder) ---
TRAIN_CSV = "/"
EVAL_CSV  = "/"
IMG_ROOT  = ""

# --- model / output ---
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
OUTPUT_DIR = str(Path(HF_CACHE_ROOT, "finetunes", "qwen2vl-regression-out"))
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# --- optimisation (UPDATED; previous values in trailing comments) ---
EPOCHS = 2
PER_DEVICE_BATCH_SIZE = 4
GRAD_ACCUM_STEPS = 8
LR = 5e-5                 # was 2e-4
WARMUP_RATIO = 0.10       # was 0.03
WEIGHT_DECAY = 0.01       # was 0.0
MAX_STEPS = -1

# --- prompt + truncation (KEEP 2048 to avoid image-token mismatch) ---
PROMPT_PREFIX = "Predict the alignment score (a real number). Caption:"
MAX_TEXT_LEN = 2048

# --- LoRA (UPDATED) ---
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.10       # was 0.05

# --- misc ---
SEED = 42
NUM_WORKERS = 0           # was 2 (avoid tokenizers fork issues)

print("[cfg] OUTPUT_DIR:", OUTPUT_DIR)
_cuda_ok = torch.cuda.is_available()
print("[cfg] CUDA:", _cuda_ok)
if _cuda_ok:
    print("[cfg] GPU:", torch.cuda.get_device_name(0))
    print("[cfg] device_count:", torch.cuda.device_count())

# ----------------------------
# 3) Load dataset
# ----------------------------
from datasets import load_dataset

print("[data] Loading CSV(s)...")
# An empty EVAL_CSV means train-only: no "eval" split is requested.
data_files = {"train": TRAIN_CSV, **({"eval": EVAL_CSV} if EVAL_CSV else {})}
raw = load_dataset("csv", data_files=data_files, cache_dir=CACHE_DIR)

train_split = raw["train"]
print("[data] splits:", list(raw.keys()))
print("[data] columns:", train_split.column_names)
print("[data] train rows:", len(train_split))
if EVAL_CSV:
    print("[data] eval rows:", len(raw["eval"]))

# Fail fast if any split lacks a column the dataset wrapper reads.
required_cols = {"img_name", "caption", "alignment_score"}
for split_name, split_ds in raw.items():
    missing = required_cols - set(split_ds.column_names)
    if missing:
        raise ValueError(f"{split_name} is missing columns: {missing}")

print("[data] example row:", train_split[0])

# ----------------------------
# 4) Torch dataset (image load on the fly)
# ----------------------------
import math
from PIL import Image
from torch.utils.data import Dataset

def load_image(path: str) -> Image.Image:
    """Open the image at ``path`` and return an RGB copy; the file handle is closed on return."""
    with Image.open(path) as handle:
        rgb_image = handle.convert("RGB")
    return rgb_image

class CSVImageCaptionRegressionDataset(Dataset):
    """Map-style dataset of (image, caption, alignment score) examples.

    Args:
        hf_dataset: indexable split whose rows expose ``img_name``, ``caption``
            and ``alignment_score``.
        img_root: directory joined in front of each row's ``img_name``.

    Images are decoded lazily in ``__getitem__``. Each item is a dict with keys
    ``image``, ``caption``, ``labels`` (float) and ``img_path``.
    """

    def __init__(self, hf_dataset, img_root: str):
        self.ds = hf_dataset
        self.img_root = img_root

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        record = self.ds[idx]
        path = os.path.join(self.img_root, str(record["img_name"]))
        text = str(record["caption"])
        score = float(record["alignment_score"])

        # Reject NaN/inf targets before paying for image decoding.
        if not math.isfinite(score):
            raise ValueError(f"Non-finite label at idx={idx}: {score}")

        return {
            "image": load_image(path),
            "caption": text,
            "labels": score,
            "img_path": path,
        }

train_ds = CSVImageCaptionRegressionDataset(raw["train"], IMG_ROOT)
eval_ds = CSVImageCaptionRegressionDataset(raw["eval"], IMG_ROOT) if EVAL_CSV else None
print("[ds] train:", len(train_ds), "| eval:", len(eval_ds) if eval_ds else 0)

# ----------------------------
# 5) Load processor + base model + regression head + LoRA
# ----------------------------
from transformers import AutoProcessor, AutoModel
import torch.nn as nn
import torch.nn.functional as F
from peft import LoraConfig, get_peft_model, TaskType

print("[model] Loading processor...")
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR)

print("[model] Loading base model (fp16)...")
# Half-precision weights when a GPU is present, full precision otherwise.
base_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
base = AutoModel.from_pretrained(
    MODEL_ID,
    cache_dir=CACHE_DIR,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=base_dtype,
)
print("[model] Base loaded.")

class Qwen2VLForRegression(nn.Module):
    """Scalar-regression head on top of a vision-language backbone, usable with ``Trainer``.

    The backbone's ``last_hidden_state`` is mean-pooled over positions where
    ``attention_mask`` is 1 (over all positions when no mask is given). A
    single ``nn.Linear(hidden_size, 1)`` then maps the pooled vector to a score.

    Returns a dict with:
      - ``logits``: float32 tensor of shape [B, 1]
      - ``loss``: float32 MSE against ``labels`` (None when no labels are given)

    Stability: pooling, the head's matmul and the loss all run in fp32. The
    head runs with autocast disabled, because under fp16 mixed precision
    autocast would otherwise execute the Linear layer in half precision.
    """

    def __init__(self, base_model: nn.Module):
        super().__init__()
        self.base = base_model
        # Multimodal configs may keep hidden_size under text_config instead of at top level.
        hidden_size = getattr(base_model.config, "hidden_size", None)
        if hidden_size is None:
            hidden_size = getattr(getattr(base_model.config, "text_config", None), "hidden_size", None)
        if hidden_size is None:
            raise ValueError("Could not infer hidden_size from model config.")
        self.reg_head = nn.Linear(hidden_size, 1)

    def forward(self, labels=None, **inputs):
        """Run the backbone, mean-pool, and regress.

        Args:
            labels: optional targets of shape [B] or [B, 1].
            **inputs: collator outputs; only tensor values reach the backbone.
        """
        # Drop non-tensor extras so they never reach the backbone's forward.
        base_inputs = {k: v for k, v in inputs.items() if torch.is_tensor(v) and k != "labels"}

        out = self.base(**base_inputs, output_hidden_states=False, return_dict=True)
        last_hidden = out.last_hidden_state.float()  # [B, T, H], pooled in fp32

        attn = base_inputs.get("attention_mask", None)
        if attn is None:
            pooled = last_hidden.mean(dim=1)
        else:
            attn_f = attn.float().unsqueeze(-1)                  # [B, T, 1]
            denom = attn_f.sum(dim=1).clamp_min(1.0)            # [B, 1]; guards all-zero masks
            pooled = (last_hidden * attn_f).sum(dim=1) / denom  # [B, H]

        # Autocast (e.g. Trainer fp16) would run the Linear in fp16. Disable it
        # and cast the head's parameters so the projection is genuinely fp32.
        device_type = "cuda" if pooled.is_cuda else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            bias = self.reg_head.bias
            logits = F.linear(
                pooled,
                self.reg_head.weight.float(),
                None if bias is None else bias.float(),
            )  # [B, 1] fp32

        loss = None
        if labels is not None:
            # Flatten both sides so [B, 1] labels cannot broadcast against [B] into [B, B].
            loss = F.mse_loss(logits.reshape(-1), labels.float().reshape(-1))

        return {"loss": loss, "logits": logits}

model = Qwen2VLForRegression(base)

print("[lora] Applying LoRA (FEATURE_EXTRACTION)...")
# Module-name suffixes that receive LoRA adapters.
target_modules = [f"{name}_proj" for name in ("q", "k", "v", "o", "gate", "up", "down")]
lora_cfg = LoraConfig(
    task_type=TaskType.FEATURE_EXTRACTION,
    target_modules=target_modules,
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
)
# Only the backbone is wrapped with PEFT; the regression head stays outside the PeftModel.
model.base = get_peft_model(model.base, lora_cfg)
model.base.print_trainable_parameters()

# The head is trained from scratch: make sure all of its parameters receive gradients.
model.reg_head.requires_grad_(True)
print("[model] reg_head params:", sum(param.numel() for param in model.reg_head.parameters()))

# ----------------------------
# 6) Collator (keep your original truncation behavior)
# ----------------------------
from typing import Any, Dict, List, Optional

class Qwen2VLRegressionCollator:
    """Turn dataset examples into one processor batch plus float32 regression labels.

    Each example must provide ``image``, ``caption`` and ``labels``. Each caption
    is wrapped in a single-turn chat message (one image slot plus the prompt
    text), and the batch goes through ``processor`` with padding and truncation
    to ``max_text_len``. The processor output is returned with ``labels`` added.

    Args:
        processor: object exposing ``apply_chat_template`` and ``__call__``; its
            ``tokenizer`` (when present and sized) bounds valid token ids.
        prompt_prefix: instruction text placed before each caption.
        max_text_len: ``max_length`` passed to the processor.
        debug: print previews for the first three batches.

    Raises:
        ValueError: if any ``input_ids`` value is >= ``len(processor.tokenizer)``.
    """

    def __init__(self, processor, prompt_prefix: str, max_text_len: int, debug: bool = False):
        self.processor = processor
        self.prompt_prefix = prompt_prefix
        self.max_text_len = max_text_len
        self.debug = debug
        self._seen = 0

        # len(tokenizer) includes added/special tokens, so it bounds every valid id.
        tok = getattr(processor, "tokenizer", None)
        sized = tok is not None and hasattr(tok, "__len__")
        self.vocab_size = int(len(tok)) if sized else None

    def _to_chat_text(self, caption: str) -> str:
        """Render one caption as a chat-formatted prompt string (no generation prompt)."""
        user_turn = {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": f"{self.prompt_prefix}\n{caption}\nScore:"},
            ],
        }
        return self.processor.apply_chat_template(
            [user_turn], tokenize=False, add_generation_prompt=False
        )

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        images = [example["image"] for example in batch]
        targets = torch.tensor([float(example["labels"]) for example in batch], dtype=torch.float32)
        texts = [self._to_chat_text(example["caption"]) for example in batch]

        encoded = self.processor(
            text=texts,
            images=images,
            padding=True,
            truncation=True,
            max_length=self.max_text_len,  # keep 2048 to avoid image-token mismatch
            return_tensors="pt",
        )
        encoded["labels"] = targets

        # Catch out-of-range ids on the CPU instead of as a CUDA device-side assert.
        if "input_ids" in encoded:
            top_id = int(encoded["input_ids"].max())
            if self.vocab_size is not None and top_id >= self.vocab_size:
                raise ValueError(f"input_ids has id {top_id} >= vocab_size {self.vocab_size}")

        if self.debug and self._seen < 3:
            print("[collator] prompt preview:\n", texts[0][:400], "...")
            print("[collator] input_ids:", tuple(encoded["input_ids"].shape))
            print("[collator] labels:", targets[: min(4, len(targets))].tolist())
            self._seen += 1

        return encoded

collator = Qwen2VLRegressionCollator(processor, PROMPT_PREFIX, MAX_TEXT_LEN, debug=False)

# Smoke-test the collator on a single training example before training starts.
b = collator([train_ds[0]])
print("[collator] keys:", [*b.keys()])

# ----------------------------
# 7) Trainer + progress prints + resume
# ----------------------------
import re
import time

import numpy as np
from transformers import Trainer, TrainerCallback, TrainingArguments, set_seed

set_seed(SEED)

def compute_metrics(eval_pred):
    """Score regression predictions against labels.

    Args:
        eval_pred: ``(predictions, labels)`` pair; predictions may be [N, 1] or [N].

    Returns:
        dict with ``mse``, ``mae`` and ``r2``. If any prediction is non-finite,
        all three are NaN and ``pred_nan_frac`` gives the non-finite fraction.
        ``r2`` is NaN when the labels have zero variance.
    """
    raw_preds, raw_labels = eval_pred
    y_hat = np.asarray(raw_preds, dtype=np.float32).squeeze()
    y = np.asarray(raw_labels, dtype=np.float32).squeeze()

    non_finite = np.logical_not(np.isfinite(y_hat))
    if non_finite.any():
        nan = float("nan")
        return {"mse": nan, "mae": nan, "r2": nan, "pred_nan_frac": float(non_finite.mean())}

    residual = y_hat - y
    mse = float(np.mean(np.square(residual)))
    mae = float(np.mean(np.abs(residual)))
    label_var = float(np.var(y))
    r2 = float(1.0 - mse / label_var) if label_var != 0.0 else float("nan")
    return {"mse": mse, "mae": mae, "r2": r2}

class PrintProgressCallback(TrainerCallback):
    """Print compact progress lines and a periodic one-example prediction.

    Args:
        trainer_ref: Trainer whose ``model`` is probed in ``on_step_end``; when
            None the periodic sample prediction is skipped.
        sample_every_steps: run the sample prediction every this many global steps.

    The sample prediction reads the module-level ``train_ds`` and ``collator``.
    """

    def __init__(self, trainer_ref=None, sample_every_steps: int = 100):
        self.trainer_ref = trainer_ref
        self.sample_every_steps = sample_every_steps
        self.t0 = None  # wall-clock start, set in on_train_begin

    def on_train_begin(self, args, state, control, **kwargs):
        self.t0 = time.time()
        print(f"[train] begin | max_steps={state.max_steps} | epochs={args.num_train_epochs}")

    def on_log(self, args, state, control, logs=None, **kwargs):
        if not logs:
            return
        elapsed = (time.time() - self.t0) if self.t0 else 0.0
        # state.epoch may be None if nothing has set it yet; don't crash formatting it.
        epoch_txt = f"{state.epoch:.3f}" if state.epoch is not None else "n/a"
        msg = f"[log] step={state.global_step} epoch={epoch_txt} elapsed={elapsed/60:.1f}m"
        # Fall back to "train_loss" for logs that report only that key.
        loss_val = logs.get("loss", logs.get("train_loss", None))
        if loss_val is not None:
            msg += f" loss={loss_val:.4f}"
        if "learning_rate" in logs:
            msg += f" lr={logs['learning_rate']:.2e}"
        if "grad_norm" in logs:
            msg += f" grad_norm={logs['grad_norm']:.3f}"
        print(msg)

    def on_step_end(self, args, state, control, **kwargs):
        if self.trainer_ref is None:
            return
        step = state.global_step
        if step <= 0 or step % self.sample_every_steps != 0:
            return
        m = self.trainer_ref.model
        was_training = m.training
        try:
            m.eval()
            ex = train_ds[0]
            batch = collator([ex])
            device = next(m.parameters()).device
            batch = {k: v.to(device) for k, v in batch.items() if torch.is_tensor(v)}
            with torch.no_grad():
                out = m(**batch)
            pred = float(out["logits"].detach().float().cpu().squeeze().item())
            print(f"[sample] step={step} pred={pred:.4f} label={float(ex['labels']):.4f}")
        except Exception as e:  # best-effort diagnostic: never interrupt training
            print("[sample] failed:", repr(e))
        finally:
            # Restore the mode we found; m.eval() above would otherwise leave
            # dropout disabled after the probe.
            m.train(was_training)

def get_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Args:
        output_dir: Trainer output directory to scan (non-recursively).

    Returns:
        ``str`` path of the checkpoint directory with the largest step number,
        or ``None`` if ``output_dir`` does not exist or contains no directory
        whose name is exactly ``checkpoint-<digits>``.
    """
    out = Path(output_dir)
    if not out.is_dir():
        return None

    # Only names that are exactly "checkpoint-<digits>" count; something like
    # "checkpoint-foo" must never be picked as the resume point.
    pattern = re.compile(r"checkpoint-(\d+)")
    steps_by_dir = {}
    for p in out.glob("checkpoint-*"):
        m = pattern.fullmatch(p.name)
        if m and p.is_dir():
            steps_by_dir[p] = int(m.group(1))

    if not steps_by_dir:
        return None
    return str(max(steps_by_dir, key=steps_by_dir.get))

# Optim: use fused AdamW only when it can actually run, i.e. a CUDA device is
# present AND torch.optim.AdamW accepts `fused=`; otherwise use plain AdamW.
# A try/except around a plain string assignment can never raise, so the
# fallback has to be an explicit capability check.
import inspect

_HAS_CUDA = torch.cuda.is_available()
_FUSED_ADAMW = "fused" in inspect.signature(torch.optim.AdamW).parameters
_OPTIM = "adamw_torch_fused" if (_HAS_CUDA and _FUSED_ADAMW) else "adamw_torch"

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    learning_rate=LR,
    warmup_ratio=WARMUP_RATIO,
    weight_decay=WEIGHT_DECAY,
    logging_steps=10,

    save_steps=200,
    save_total_limit=2,

    # NOTE: if your transformers complains, rename eval_strategy -> evaluation_strategy
    eval_strategy="steps" if eval_ds is not None else "no",
    eval_steps=200 if eval_ds is not None else None,

    bf16=False,
    # fp16 mixed precision assumes a GPU; on CPU the base model is loaded in fp32 anyway.
    fp16=_HAS_CUDA,

    dataloader_num_workers=NUM_WORKERS,
    remove_unused_columns=False,  # the collator needs the raw "image"/"caption" fields
    max_steps=MAX_STEPS,
    report_to="none",

    # Stability knobs (IMPORTANT)
    max_grad_norm=1.0,
    lr_scheduler_type="cosine",
    optim=_OPTIM,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    # Metrics are only computed when an eval split exists.
    compute_metrics=None if eval_ds is None else compute_metrics,
)

trainer.add_callback(PrintProgressCallback(trainer_ref=trainer, sample_every_steps=100))

# Resume from the newest checkpoint in OUTPUT_DIR when one exists.
latest_ckpt = get_latest_checkpoint(OUTPUT_DIR)
print("[resume] latest_ckpt:", latest_ckpt)

# Sanity print (so you know you're running the new hyperparams)
for _arg_name in (
    "learning_rate",
    "warmup_ratio",
    "weight_decay",
    "max_grad_norm",
    "lr_scheduler_type",
    "optim",
):
    print(f"[sanity] trainer.args.{_arg_name}:", getattr(trainer.args, _arg_name))

print("[train] Starting training...")
trainer.train(resume_from_checkpoint=latest_ckpt or None)
print("[train] Training finished.")

# ----------------------------
# 8) Save
# ----------------------------
print("[save] Saving full model + processor to:", OUTPUT_DIR)
trainer.save_model(OUTPUT_DIR)
processor.save_pretrained(OUTPUT_DIR)

lora_dir = str(Path(OUTPUT_DIR) / "lora_adapter")
print("[save] Saving LoRA adapter to:", lora_dir)
model.base.save_pretrained(lora_dir)

# Only model.base is a PeftModel, so the adapter directory does not contain
# the regression head. Save the head next to it so that the adapter directory
# is enough to rebuild the regressor.
reg_head_path = Path(lora_dir) / "reg_head.pt"
print("[save] Saving regression head to:", reg_head_path)
torch.save(model.reg_head.state_dict(), reg_head_path)

print("[save] Done.")
