In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# ============ Environment setup (must stay at the very top, before torch/transformers are imported) ============
import os
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")  # avoid hf_transfer-related download errors
# setdefault keeps any value the user already exported in the environment.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# To make CUDA errors surface synchronously at the failing op, uncomment the next line (much slower).
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import json, math, random, re, inspect
from collections import Counter
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from transformers import (
    BlipProcessor,
    BlipForQuestionAnswering,
    get_linear_schedule_with_warmup,
)

# =========================================================
# 0) CONFIG
# =========================================================

# -- Data: Hugging Face PathVQA -- (recommended: get this path working first)
USE_HF_DATASET = True
HF_DATASET_NAME = "flaviagiammarino/path-vqa"
HF_TRAIN_SPLIT = "train"
HF_VAL_SPLIT = "validation"
HF_TEST_SPLIT = "test"
HF_IMAGE_COL = "image"
HF_QUESTION_COL = "question"
HF_ANSWER_COL = "answer"

# -- Local JSONL (optional; used only when USE_HF_DATASET is False) --
DATA_DIR = None
IMAGE_ROOT = None  # main() falls back to DATA_DIR when this is None
TRAIN_JSONL = "train.jsonl"
VAL_JSONL = "val.jsonl"
TEST_JSONL = "test.jsonl"

# Model
MODEL_NAME = "Salesforce/blip-vqa-base"

# Output
OUTPUT_DIR = "./runs/blip_pathvqa_best"

# Training hyperparameters (conservative values to get a first run through)
SEED = 42
EPOCHS = 20
BATCH_SIZE = 4
GRAD_ACCUM = 2        # micro-batches per optimizer step
LR = 2e-5
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.06   # fraction of total optimizer steps used for linear warmup

MAX_QUESTION_LEN = 32   # token limit incl. special tokens; can raise to 64 once training is stable
MAX_ANSWER_LEN = 8      # label token limit incl. special tokens; can raise to 16 once stable
USE_PROMPT = True       # wrap questions as "Question: {q} Answer:"

NUM_WORKERS = 0         # most reliable in notebooks/shared machines; raise to 2/4 once it runs
PIN_MEMORY = True

LOG_EVERY = 50          # optimizer steps between progress-bar loss updates
MAX_GRAD_NORM = 1.0

# Mixed precision (at most one of these may be True; main() raises otherwise)
USE_FP16 = True
USE_BF16 = False

# Constrained decoding (affects evaluation only)
CONSTRAINED_DECODE = True
ANSWER_VOCAB_SIZE = -1  # <= 0: every distinct training answer; > 0: only the N most frequent ("yes"/"no" always kept)
NUM_BEAMS = 3

# Debug / preflight checks
CPU_PREFLIGHT_CHECK = True   # run one batch on CPU first so out-of-range token ids fail with a readable error
DEBUG_SANITY_STEPS = 5       # print/assert token-id ranges during the first N optimizer steps

# =========================================================
# 1) utils
# =========================================================

def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and PyTorch (CPU and every CUDA device) with *seed*."""
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def ensure_dir(path: str) -> None:
    """Create directory *path* plus any missing parents; an existing directory is left as is."""
    os.makedirs(name=path, exist_ok=True)

def normalize_answer(s: str) -> str:
    """Canonicalize an answer string for exact-match comparison.

    Lower-cases, turns newlines/tabs into spaces, collapses whitespace runs, strips
    surrounding punctuation, quotes and brackets, and maps common yes/no synonyms
    ("y", "yeah", "yep", "true" / "n", "nope", "nah", "false") onto "yes"/"no".
    ``None`` becomes the empty string; other non-strings go through ``str``.
    """
    if s is None:
        return ""
    text = str(s).strip().lower()
    text = text.replace("\n", " ").replace("\t", " ")
    text = re.sub(r"\s+", " ", text).strip().strip(" .,:;!?\"'()[]{}")
    synonyms = {
        "y": "yes", "yeah": "yes", "yep": "yes", "true": "yes",
        "n": "no", "nope": "no", "nah": "no", "false": "no",
    }
    return synonyms.get(text, text)

# =========================================================
# 2) dataset
# =========================================================

class HFVQADataset(Dataset):
    """Map-style VQA dataset backed by a Hugging Face ``datasets`` split.

    Items are ``(RGB PIL image, question, answer)``. The answer column is read once in
    ``__init__`` and cached (passed through ``normalize_answer`` when *normalize* is True,
    otherwise through ``str``), so ``get_answers()`` never has to decode an image.

    Raises:
        ValueError: if any requested column is missing from the split.
    """
    def __init__(self, name: str, split: str, image_col: str, q_col: str, a_col: str, normalize: bool = True):
        # Imported here rather than at module level, so `datasets` is only needed when this class is used.
        from datasets import load_dataset
        self.ds = load_dataset(name, split=split)
        self.image_col = image_col
        self.q_col = q_col
        self.a_col = a_col
        self.normalize = normalize

        for c in [image_col, q_col, a_col]:
            if c not in self.ds.column_names:
                raise ValueError(f"Column '{c}' not in {self.ds.column_names}")

        self._answers = [normalize_answer(a) if normalize else str(a) for a in self.ds[a_col]]

    def get_answers(self) -> List[str]:
        """Return the cached answer strings, index-aligned with the dataset."""
        return self._answers

    def __len__(self) -> int:
        return len(self.ds)

    def __getitem__(self, idx: int) -> Tuple[Image.Image, str, str]:
        ex = self.ds[idx]
        img = ex[self.image_col]
        if not isinstance(img, Image.Image):
            # presumably a numpy array when the image feature is not decoded to PIL — TODO confirm
            img = Image.fromarray(img)
        # Convert at most once: Image.convert always returns a new image, and the previous
        # code converted twice (after fromarray and again on return), copying pixels needlessly.
        if img.mode != "RGB":
            img = img.convert("RGB")
        return img, str(ex[self.q_col]), self._answers[idx]


class JsonlVQADataset(Dataset):
    """Map-style VQA dataset backed by a local JSONL file.

    Every non-blank line must be a JSON object with at least the keys ``"image"``
    (a path; joined onto *image_root* when relative and *image_root* is set),
    ``"question"`` and ``"answer"``. Items are ``(RGB PIL image, question, answer)``.

    Args:
        jsonl_path: path of the JSONL file.
        image_root: directory relative image paths are resolved against (``None``/empty
            leaves them relative to the current working directory).
        normalize: apply ``normalize_answer`` to every answer at load time; otherwise
            answers are stored as ``str``.

    Raises:
        FileNotFoundError: if the JSONL file is missing (or, at access time, an image).
        ValueError: for malformed JSON, a non-object line, or a missing required key.
    """
    def __init__(self, jsonl_path: str, image_root: Optional[str], normalize: bool = True):
        if not os.path.isfile(jsonl_path):
            raise FileNotFoundError(f"JSONL not found: {jsonl_path}")
        self.samples: List[Dict[str, Any]] = []
        self.image_root = image_root
        self.normalize = normalize
        self._answers: List[str] = []

        with open(jsonl_path, "r", encoding="utf-8") as f:
            # 1-based line numbers so error messages match what editors display.
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                try:
                    ex = json.loads(line)
                except json.JSONDecodeError as e:
                    raise ValueError(f"{jsonl_path} line {lineno}: invalid JSON ({e})") from e
                if not isinstance(ex, dict):
                    raise ValueError(f"{jsonl_path} line {lineno}: expected a JSON object, got {type(ex).__name__}")
                for k in ("image", "question", "answer"):
                    if k not in ex:
                        raise ValueError(f"{jsonl_path} line {lineno} missing key '{k}', got keys={list(ex.keys())}")
                # Always store a string so get_answers() agrees with __getitem__ even when
                # normalization is off and the JSON answer is, e.g., a number.
                ex["answer"] = normalize_answer(ex["answer"]) if self.normalize else str(ex["answer"])
                self.samples.append(ex)
                self._answers.append(ex["answer"])

    def get_answers(self) -> List[str]:
        """Return the stored answer strings, index-aligned with the samples."""
        return self._answers

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Image.Image, str, str]:
        ex = self.samples[idx]
        img_path = ex["image"]
        if self.image_root and not os.path.isabs(img_path):
            img_path = os.path.join(self.image_root, img_path)
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        # The context manager closes the file handle deterministically; convert() loads the
        # pixels into a new, independent image before the file is closed.
        with Image.open(img_path) as im:
            img = im.convert("RGB")
        return img, str(ex["question"]), str(ex["answer"])

# =========================================================
# 3) constrained decoding trie
# =========================================================

class TrieNode:
    """A single node of a token-id prefix tree."""
    __slots__ = ("children", "is_end")

    def __init__(self):
        # Outgoing edges, keyed by token id.
        self.children: Dict[int, "TrieNode"] = {}
        # True when the root-to-here path spells a complete inserted sequence.
        self.is_end: bool = False


class Trie:
    """Prefix tree over token-id sequences (used to restrict decoding to known answers)."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, token_ids: List[int]) -> None:
        """Record *token_ids* as a complete sequence (an empty list marks the root)."""
        cursor = self.root
        for tid in token_ids:
            nxt = cursor.children.get(tid)
            if nxt is None:
                nxt = cursor.children[tid] = TrieNode()
            cursor = nxt
        cursor.is_end = True

    def next_tokens(self, prefix: List[int]) -> Tuple[List[int], bool]:
        """Return (ids that may follow *prefix*, whether *prefix* itself is complete).

        A prefix that leaves the tree yields ``([], False)``.
        """
        cursor = self.root
        for tid in prefix:
            cursor = cursor.children.get(tid)
            if cursor is None:
                return [], False
        return [*cursor.children], cursor.is_end

def build_answer_set(ds: Dataset, max_answers: int) -> List[str]:
    """Return the candidate answers used for constrained decoding.

    Answers come from ``ds.get_answers()`` when the dataset provides it (no image
    decoding), otherwise from element ``[2]`` of every item. If ``max_answers > 0`` the
    ``max_answers`` most frequent answers are kept, most frequent first; otherwise every
    distinct answer is kept in first-seen order. "yes" and "no" are appended if absent.
    """
    pool = ds.get_answers() if hasattr(ds, "get_answers") else [ds[i][2] for i in range(len(ds))]
    freq = Counter(pool)
    keep_top = max_answers is not None and max_answers > 0
    result = [ans for ans, _ in freq.most_common(max_answers)] if keep_top else list(freq)
    for must_have in ("yes", "no"):
        if must_have not in result:
            result.append(must_have)
    return result

def build_trie_and_lookup(answer_list: List[str], tokenizer) -> Tuple[Trie, Dict[Tuple[int, ...], str], int]:
    """Tokenize every answer (without special tokens) and index the id sequences.

    Returns:
        ``(trie, seq2ans, max_len)``: a Trie holding every non-empty token sequence, a map
        from token-id tuple back to its answer (a later answer wins if two tokenize the
        same), and the longest sequence length (0 when nothing was inserted).
    """
    trie = Trie()
    seq2ans: Dict[Tuple[int, ...], str] = {}
    for ans in answer_list:
        token_ids = tokenizer(ans, add_special_tokens=False).input_ids
        if token_ids:
            trie.insert(token_ids)
            seq2ans[tuple(token_ids)] = ans
    longest = max(map(len, seq2ans), default=0)
    return trie, seq2ans, longest

def make_prefix_allowed_tokens_fn(trie: Trie, bos_id: Optional[int], eos_id: int):
    """Build a ``prefix_allowed_tokens_fn`` for ``generate`` that keeps beams on trie paths.

    The callback drops one leading *bos_id* (when given) from the generated ids, then
    allows the trie children of the remaining prefix, plus *eos_id* once that prefix is a
    complete answer. If nothing is allowed (unknown prefix or dead end) only *eos_id* is
    returned, so the beam terminates.
    """
    def _fn(batch_id: int, input_ids: torch.LongTensor) -> List[int]:
        seq = input_ids.tolist()
        starts_with_bos = bos_id is not None and seq and seq[0] == bos_id
        children, complete = trie.next_tokens(seq[1:] if starts_with_bos else seq)
        allowed = [*children, eos_id] if complete else list(children)
        return allowed or [eos_id]
    return _fn

# =========================================================
# 4) collate（关键：图像/文本分开处理，确保 truncation 生效）
# =========================================================

@dataclass
class Batch:
    """One collated batch as produced by the collate function.

    ``labels`` holds answer token ids with padding replaced by -100 (ignored by the
    loss); ``answers`` keeps the answer strings as returned by the dataset.
    """
    pixel_values: torch.FloatTensor
    input_ids: torch.LongTensor       # question token ids
    attention_mask: torch.LongTensor  # 1 for real question tokens, 0 for padding
    labels: torch.LongTensor
    answers: List[str]

    def pin_memory(self) -> "Batch":
        """Return a copy whose tensors live in page-locked (pinned) host memory.

        ``DataLoader(pin_memory=True)`` only pins custom batch types that define this
        method. Without it, the flag silently does nothing for this dataclass, and the
        later ``.to(device, non_blocking=True)`` copies are not asynchronous.
        """
        return Batch(
            pixel_values=self.pixel_values.pin_memory(),
            input_ids=self.input_ids.pin_memory(),
            attention_mask=self.attention_mask.pin_memory(),
            labels=self.labels.pin_memory(),
            answers=self.answers,
        )

def make_collate_fn(processor: BlipProcessor, max_q_len: int, max_a_len: int,
                    use_prompt: Optional[bool] = None):
    """Build a DataLoader ``collate_fn`` turning ``(image, question, answer)`` triples into a Batch.

    Images and text are processed separately so that truncation to ``max_length`` is
    actually applied to the question (*max_q_len*) and to the answer labels (*max_a_len*);
    both limits include the tokenizer's special tokens.

    Args:
        processor: BLIP processor; its image processor and ``tokenizer`` are used.
        max_q_len: maximum question length in tokens.
        max_a_len: maximum answer-label length in tokens.
        use_prompt: wrap each question as ``"Question: {q} Answer:"``. ``None`` (default)
            reads the module-level ``USE_PROMPT`` flag at collate time, as before.
    """
    tok = processor.tokenizer

    def _collate(examples: List[Tuple[Image.Image, str, str]]) -> Batch:
        images, questions, answers = zip(*examples)
        prompt = USE_PROMPT if use_prompt is None else use_prompt
        if prompt:
            questions = [f"Question: {q} Answer:" for q in questions]

        # Images
        pixel = processor(images=list(images), return_tensors="pt")

        # Question text: call the tokenizer directly so truncation is enforced.
        txt = tok(
            list(questions),
            padding=True,
            truncation=True,
            max_length=max_q_len,
            return_tensors="pt",
        )

        # Answer labels (with [CLS]/[SEP] special tokens).
        ans_tok = tok(
            list(answers),
            padding=True,
            truncation=True,
            max_length=max_a_len,
            return_tensors="pt",
            add_special_tokens=True,
        )
        # The tokenizer returns a fresh tensor, so masking it in place is safe (no clone needed).
        labels = ans_tok.input_ids
        pad_id = tok.pad_token_id
        if pad_id is not None:
            labels[labels == pad_id] = -100  # padding is ignored by the loss

        return Batch(
            pixel_values=pixel["pixel_values"],
            input_ids=txt["input_ids"],
            attention_mask=txt["attention_mask"],
            labels=labels,
            answers=list(answers),
        )

    return _collate

# =========================================================
# 5) eval
# =========================================================

@torch.no_grad()
def evaluate(model, processor, loader, device, use_amp, amp_dtype,
             constrained, prefix_allowed_tokens_fn, seq2ans, max_gen_len, num_beams):
    """Generate an answer for every batch in *loader* and score exact match.

    Both the prediction and the ground truth go through ``normalize_answer`` before being
    compared. Questions whose ground truth is "yes"/"no" are also scored separately from
    the remaining ("open") ones.

    Args:
        model: BLIP VQA model; ``generate(pixel_values=, input_ids=, attention_mask=, **kw)`` is used.
        processor: processor whose ``tokenizer`` decodes the generated ids.
        loader: iterable of ``Batch`` objects.
        device: ``torch.device`` or device string the batch tensors are moved to.
        use_amp: run generation under CUDA autocast (ignored on non-CUDA devices).
        amp_dtype: autocast dtype when *use_amp* is set.
        constrained: pass *prefix_allowed_tokens_fn* to ``generate`` and map generated
            id sequences back to answer strings through *seq2ans* when possible.
        max_gen_len: maximum number of generated answer tokens.
        num_beams: beam width (``do_sample`` is always False).

    Returns:
        dict with "acc", "acc_yesno", "acc_open" and the counts "n", "n_yesno", "n_open".
    """
    import contextlib

    model.eval()
    tok = processor.tokenizer
    device = torch.device(device)

    # Tokens generate() may place before the answer: the tokenizer's [CLS]/BOS and the model's
    # own decoder start token (for BLIP the latter need not equal [CLS]; stripping only [CLS]
    # made every seq2ans lookup miss).
    start_ids = {
        t for t in (tok.cls_token_id, tok.bos_token_id, getattr(model, "decoder_start_token_id", None))
        if t is not None
    }
    eos = tok.sep_token_id if tok.sep_token_id is not None else tok.eos_token_id
    pad = tok.pad_token_id
    use_lookup = constrained and seq2ans is not None

    # Generation settings do not change between batches.
    gen_kwargs = dict(
        max_length=1 + max_gen_len,  # +1 for the decoder start token
        num_beams=num_beams,
        do_sample=False,
    )
    if constrained:
        gen_kwargs["prefix_allowed_tokens_fn"] = prefix_allowed_tokens_fn

    def amp_ctx():
        # CUDA-only autocast as before, via torch.autocast (torch.cuda.amp.autocast is deprecated).
        if use_amp and device.type == "cuda":
            return torch.autocast(device_type="cuda", dtype=amp_dtype)
        return contextlib.nullcontext()

    total = correct = 0
    total_yesno = correct_yesno = 0
    total_open = correct_open = 0

    for batch in tqdm(loader, desc="eval", leave=False):
        with amp_ctx():
            gen_ids = model.generate(
                pixel_values=batch.pixel_values.to(device, non_blocking=True),
                input_ids=batch.input_ids.to(device, non_blocking=True),
                attention_mask=batch.attention_mask.to(device, non_blocking=True),
                **gen_kwargs,
            )

        for ids, gt in zip(gen_ids.tolist(), batch.answers):
            if ids and ids[0] in start_ids:
                ids = ids[1:]
            if eos is not None and eos in ids:
                ids = ids[:ids.index(eos)]
            if pad is not None:
                ids = [t for t in ids if t != pad]

            pred = seq2ans.get(tuple(ids)) if use_lookup else None
            if pred is None:
                pred = tok.decode(ids, skip_special_tokens=True)
            pred = normalize_answer(pred)
            gt = normalize_answer(gt)  # robust to datasets built with normalize=False

            ok = pred == gt
            total += 1
            correct += int(ok)
            if gt in ("yes", "no"):
                total_yesno += 1
                correct_yesno += int(ok)
            else:
                total_open += 1
                correct_open += int(ok)

    return {
        "acc": correct / max(1, total),
        "acc_yesno": correct_yesno / max(1, total_yesno),
        "acc_open": correct_open / max(1, total_open),
        "n": total,
        "n_yesno": total_yesno,
        "n_open": total_open,
    }

# =========================================================
# 6) main
# =========================================================

def _decoder_inputs_from_labels(
    labels: torch.LongTensor, pad_id: int, start_id: int
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """Build ``(decoder_input_ids, decoder_attention_mask)`` aligned with *labels*.

    transformers' BLIP answer decoder shifts internally: it scores ``logits[:, :-1]`` against
    ``labels[:, 1:]``. The decoder inputs must therefore have the SAME alignment as the
    labels. The previous extra ``torch.roll`` shift trained each position to predict the
    token two steps ahead. Ignored (-100) positions become *pad_id*, and column 0 (the
    tokenizer's [CLS]) becomes *start_id*, the token ``generate()`` starts from.
    """
    dec = labels.masked_fill(labels == -100, pad_id)
    dec[:, 0] = start_id
    return dec, (dec != pad_id).long()


def main():
    """Fine-tune BLIP-VQA, keep the best validation checkpoint, then evaluate it on test.

    Steps:
    1. Load the processor and model, and sync the pad/eos ids into the configs.
    2. Build the datasets and loaders, and optionally run a one-batch CPU preflight.
    3. Train with AdamW, linear warmup and optional fp16/bf16 CUDA autocast.
    4. After every epoch, evaluate on val and save the best exact-match model to
       ``OUTPUT_DIR/best``.
    5. Reload that checkpoint and evaluate it on test.

    Raises:
        ValueError: conflicting precision flags, undeterminable special tokens, or a
            missing local data directory.
    """
    import contextlib

    if USE_FP16 and USE_BF16:
        raise ValueError("Choose only one: USE_FP16 or USE_BF16")

    set_seed(SEED)
    ensure_dir(OUTPUT_DIR)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    on_cuda = device.type == "cuda"
    # Autocast/GradScaler only act on CUDA (the torch.cuda.amp calls were no-ops on CPU anyway).
    use_amp = (USE_FP16 or USE_BF16) and on_cuda
    use_fp16 = USE_FP16 and on_cuda
    amp_dtype = torch.float16 if USE_FP16 else (torch.bfloat16 if USE_BF16 else torch.float32)
    print(f"Device={device}  AMP={use_amp}  dtype={amp_dtype}")

    def autocast_ctx():
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        if use_amp:
            return torch.autocast(device_type=device.type, dtype=amp_dtype)
        return contextlib.nullcontext()

    # Load on CPU first so the optional preflight runs before anything touches the GPU.
    processor = BlipProcessor.from_pretrained(MODEL_NAME)
    model = BlipForQuestionAnswering.from_pretrained(MODEL_NAME)
    tok = processor.tokenizer

    # Padding requires a pad token (pad label positions are later set to -100).
    if tok.pad_token is None:
        # NOTE(review): resizing embeddings on BlipForQuestionAnswering is unverified here; the
        # BERT tokenizer shipped with BLIP already defines [PAD], so this branch is normally unused.
        tok.add_special_tokens({"pad_token": "[PAD]"})
        model.resize_token_embeddings(len(tok))

    pad_id = tok.pad_token_id
    bos_id = tok.cls_token_id if tok.cls_token_id is not None else tok.bos_token_id
    eos_id = tok.sep_token_id if tok.sep_token_id is not None else tok.eos_token_id
    if bos_id is None:
        raise ValueError("Cannot determine bos/cls token id for BLIP tokenizer.")
    if eos_id is None:
        raise ValueError("Cannot determine eos/sep token id for BLIP tokenizer.")

    # BLIP's generate() seeds the answer decoder with model.decoder_start_token_id (derived from
    # text_config.bos_token_id in transformers), which need not equal [CLS]. Training and the
    # constrained-decoding callback use that same start token. The bos/decoder-start config
    # entries are deliberately NOT overwritten with [CLS]: doing so made training start from
    # [CLS] while generation started elsewhere, so the trie never matched and beams were forced
    # straight to EOS.
    text_cfg = getattr(model.config, "text_config", None)
    dec_start_id = getattr(model, "decoder_start_token_id", None)
    if dec_start_id is None and text_cfg is not None:
        dec_start_id = getattr(text_cfg, "bos_token_id", None)
    if dec_start_id is None:
        dec_start_id = bos_id

    # Sync pad/eos ids into every config that may be consulted.
    model.config.pad_token_id = pad_id
    model.config.eos_token_id = eos_id
    if text_cfg is not None:
        text_cfg.pad_token_id = pad_id
        text_cfg.eos_token_id = eos_id
    if hasattr(model, "text_encoder") and hasattr(model.text_encoder, "config"):
        model.text_encoder.config.pad_token_id = pad_id
    if hasattr(model, "text_decoder") and hasattr(model.text_decoder, "config"):
        model.text_decoder.config.pad_token_id = pad_id
    if getattr(model, "generation_config", None) is not None:
        model.generation_config.pad_token_id = pad_id
        model.generation_config.eos_token_id = eos_id

    # Length caps (stay within the position-embedding table).
    max_pos = int(model.text_encoder.config.max_position_embeddings)
    max_q_len = min(MAX_QUESTION_LEN, max_pos)
    max_a_len = min(MAX_ANSWER_LEN, max_pos)
    print(f"[Len] max_pos={max_pos}, use max_q_len={max_q_len}, max_a_len={max_a_len}")
    print(f"[Tokens] pad={pad_id} bos={bos_id} eos={eos_id} decoder_start={dec_start_id} "
          f"vocab(len(tokenizer))={len(tok)}")

    # Load dataset
    if USE_HF_DATASET:
        train_ds = HFVQADataset(HF_DATASET_NAME, HF_TRAIN_SPLIT, HF_IMAGE_COL, HF_QUESTION_COL, HF_ANSWER_COL, normalize=True)
        val_ds = HFVQADataset(HF_DATASET_NAME, HF_VAL_SPLIT, HF_IMAGE_COL, HF_QUESTION_COL, HF_ANSWER_COL, normalize=True)
        test_ds = HFVQADataset(HF_DATASET_NAME, HF_TEST_SPLIT, HF_IMAGE_COL, HF_QUESTION_COL, HF_ANSWER_COL, normalize=True)
    else:
        if DATA_DIR is None or not os.path.isdir(DATA_DIR):
            raise ValueError("Set USE_HF_DATASET=True or provide valid DATA_DIR for local JSONL.")
        image_root = IMAGE_ROOT if IMAGE_ROOT is not None else DATA_DIR
        train_ds = JsonlVQADataset(os.path.join(DATA_DIR, TRAIN_JSONL), image_root=image_root, normalize=True)
        val_ds = JsonlVQADataset(os.path.join(DATA_DIR, VAL_JSONL), image_root=image_root, normalize=True)
        test_ds = JsonlVQADataset(os.path.join(DATA_DIR, TEST_JSONL), image_root=image_root, normalize=True)

    print(f"Train={len(train_ds)}  Val={len(val_ds)}  Test={len(test_ds)}")

    collate_fn = make_collate_fn(processor, max_q_len=max_q_len, max_a_len=max_a_len)

    def make_loader(ds, shuffle):
        return DataLoader(
            ds, batch_size=BATCH_SIZE, shuffle=shuffle,
            num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY and on_cuda,
            collate_fn=collate_fn,
        )

    train_loader = make_loader(train_ds, shuffle=True)
    val_loader = make_loader(val_ds, shuffle=False)
    test_loader = make_loader(test_ds, shuffle=False)

    # Whether forward() accepts explicit decoder inputs (varies across transformers versions).
    # Without them, BLIP feeds `labels` (which contain -100) to the decoder embedding.
    fwd_params = inspect.signature(model.forward).parameters
    supports_decoder = "decoder_input_ids" in fwd_params
    supports_decoder_attn = "decoder_attention_mask" in fwd_params

    def build_forward_kwargs(pixel_values, input_ids, attention_mask, labels):
        """Assemble model(**kwargs) for a training/preflight forward pass."""
        kwargs = dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        if supports_decoder:
            dec, dec_attn = _decoder_inputs_from_labels(labels, pad_id, dec_start_id)
            kwargs["decoder_input_ids"] = dec
            if supports_decoder_attn:
                kwargs["decoder_attention_mask"] = dec_attn
        return kwargs

    def check_token_ranges(input_ids, labels, log_prefix=None):
        """Assert question/label ids are valid tokenizer ids and fit the position table."""
        vocab = len(tok)  # len(tok), not tok.vocab_size, so added tokens count
        if log_prefix:
            print(f"{log_prefix} q_len={input_ids.shape[1]} "
                  f"input_ids(min,max)=({int(input_ids.min())},{int(input_ids.max())}) V={vocab} max_pos={max_pos}")
        assert input_ids.shape[1] <= max_pos
        assert int(input_ids.min()) >= 0
        assert int(input_ids.max()) < vocab
        mask = labels != -100
        if mask.any():
            lb_min = int(labels[mask].min())
            lb_max = int(labels[mask].max())
            if log_prefix:
                print(f"{log_prefix} labels(min,max)={lb_min},{lb_max}")
            assert lb_min >= 0
            assert lb_max < vocab

    # ---------- CPU preflight: surface index errors with a readable traceback ----------
    if CPU_PREFLIGHT_CHECK:
        print("[Preflight] Running 1 batch on CPU to catch index errors...")
        model.eval()
        batch = next(iter(train_loader))
        check_token_ranges(batch.input_ids, batch.labels)
        with torch.no_grad():
            model(**build_forward_kwargs(batch.pixel_values, batch.input_ids,
                                         batch.attention_mask, batch.labels))
        print("[Preflight] OK on CPU.")

    model.to(device)
    model.config.use_cache = False

    # Constrained decoding (evaluation only). The callback strips the decoder start token
    # that generate() actually emits first.
    prefix_allowed_tokens_fn = None
    seq2ans = None
    trie_max_len = max_a_len
    if CONSTRAINED_DECODE:
        ans_list = build_answer_set(train_ds, ANSWER_VOCAB_SIZE)
        trie, seq2ans, trie_max_len = build_trie_and_lookup(ans_list, tok)
        prefix_allowed_tokens_fn = make_prefix_allowed_tokens_fn(trie, bos_id=dec_start_id, eos_id=eos_id)
        print(f"[ConstrainedDecode] answer_set={len(ans_list)} trie_max_len={trie_max_len}")
    max_gen_len = min(max_a_len, trie_max_len)

    # optimizer / scheduler
    accum = max(1, GRAD_ACCUM)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    total_steps = math.ceil(len(train_loader) / accum) * EPOCHS
    warmup_steps = int(total_steps * WARMUP_RATIO)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    amp_mod = getattr(torch, "amp", None)
    if amp_mod is not None and hasattr(amp_mod, "GradScaler"):
        scaler = amp_mod.GradScaler("cuda", enabled=use_fp16)
    else:
        scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)

    best_val = -1.0
    global_step = 0
    best_dir = os.path.join(OUTPUT_DIR, "best")

    # Save the run configuration (all UPPER_CASE globals).
    with open(os.path.join(OUTPUT_DIR, "config.json"), "w", encoding="utf-8") as f:
        cfg = {k: v for k, v in globals().items() if k.isupper()}
        json.dump(cfg, f, ensure_ascii=False, indent=2)

    for epoch in range(1, EPOCHS + 1):
        model.train()
        optimizer.zero_grad(set_to_none=True)
        running = 0.0  # sum of per-micro-batch losses in this epoch
        n_batches = len(train_loader)

        pbar = tqdm(train_loader, desc=f"train epoch {epoch}", leave=False)
        for step, batch in enumerate(pbar, start=1):
            pixel_values = batch.pixel_values.to(device, non_blocking=True)
            input_ids = batch.input_ids.to(device, non_blocking=True)
            attention_mask = batch.attention_mask.to(device, non_blocking=True)
            labels = batch.labels.to(device, non_blocking=True)

            # Hard range checks during the first few optimizer steps.
            if DEBUG_SANITY_STEPS > 0 and global_step < DEBUG_SANITY_STEPS:
                check_token_ranges(input_ids, labels, log_prefix=f"[SANITY step={global_step}]")

            kwargs = build_forward_kwargs(pixel_values, input_ids, attention_mask, labels)
            with autocast_ctx():
                out = model(**kwargs)
                loss = out.loss / accum

            if use_fp16:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            running += loss.item() * accum

            # Also step on the epoch's last (possibly partial) accumulation window; otherwise its
            # gradients would be discarded by the next epoch's zero_grad.
            if step % accum == 0 or step == n_batches:
                if use_fp16:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                    optimizer.step()

                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                if global_step % LOG_EVERY == 0:
                    # Mean micro-batch loss this epoch (was divided by the global step count).
                    pbar.set_postfix(loss=f"{running / step:.4f}")

        # Validate
        val_metrics = evaluate(
            model, processor, val_loader, device, use_amp, amp_dtype,
            CONSTRAINED_DECODE, prefix_allowed_tokens_fn, seq2ans,
            max_gen_len=max_gen_len,
            num_beams=NUM_BEAMS,
        )
        print(f"[Epoch {epoch}] VAL acc={val_metrics['acc']:.4f}  yes/no={val_metrics['acc_yesno']:.4f}  open={val_metrics['acc_open']:.4f}  n={val_metrics['n']}")

        # Save best
        if val_metrics["acc"] > best_val:
            best_val = val_metrics["acc"]
            ensure_dir(best_dir)
            model.save_pretrained(best_dir)
            processor.save_pretrained(best_dir)
            with open(os.path.join(best_dir, "val_metrics.json"), "w", encoding="utf-8") as f:
                json.dump(val_metrics, f, ensure_ascii=False, indent=2)
            print(f"  -> saved BEST to {best_dir}")

    if best_val < 0:
        print("No checkpoint was saved (EPOCHS < 1?); skipping TEST.")
        return

    # Test best
    print("\nLoading BEST checkpoint for TEST...")
    # Release the training model/optimizer first so two model copies don't sit on the GPU.
    del optimizer, scheduler, scaler, model
    if on_cuda:
        torch.cuda.empty_cache()
    model = BlipForQuestionAnswering.from_pretrained(best_dir).to(device)
    processor = BlipProcessor.from_pretrained(best_dir)

    test_metrics = evaluate(
        model, processor, test_loader, device, use_amp, amp_dtype,
        CONSTRAINED_DECODE, prefix_allowed_tokens_fn, seq2ans,
        max_gen_len=max_gen_len,
        num_beams=NUM_BEAMS,
    )
    print(f"[BEST] TEST acc={test_metrics['acc']:.4f}  yes/no={test_metrics['acc_yesno']:.4f}  open={test_metrics['acc_open']:.4f}  n={test_metrics['n']}")

    with open(os.path.join(OUTPUT_DIR, "test_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(test_metrics, f, ensure_ascii=False, indent=2)

# Script entry point. In an IPython/Jupyter cell __name__ is also "__main__", so main() runs there too.
if __name__ == "__main__":
    main()


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Device=cuda  AMP=True  dtype=torch.float16
[Len] max_pos=512, use max_q_len=32, max_a_len=8
[Tokens] pad=0 bos=101 eos=102 vocab(len(tokenizer))=30522
Train=19654  Val=6259  Test=6719
[Preflight] Running 1 batch on CPU to catch index errors...
[Preflight] OK on CPU.


  scaler = torch.cuda.amp.GradScaler(enabled=USE_FP16)


[ConstrainedDecode] answer_set=3225 trie_max_len=55


train epoch 1:   0%|          | 0/4914 [00:00<?, ?it/s]

[SANITY step=0] q_len=16 input_ids(min,max)=(0,27179) V=30522 max_pos=512
[SANITY step=0] labels(min,max)=101,24471


  with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
train epoch 1:   0%|          | 1/4914 [00:00<59:00,  1.39it/s]

[SANITY step=0] q_len=12 input_ids(min,max)=(0,16464) V=30522 max_pos=512
[SANITY step=0] labels(min,max)=101,11467


train epoch 1:   0%|          | 3/4914 [00:01<27:44,  2.95it/s]

[SANITY step=1] q_len=31 input_ids(min,max)=(0,25032) V=30522 max_pos=512
[SANITY step=1] labels(min,max)=101,28873
[SANITY step=1] q_len=22 input_ids(min,max)=(0,22520) V=30522 max_pos=512
[SANITY step=1] labels(min,max)=101,21716


train epoch 1:   0%|          | 5/4914 [00:01<18:31,  4.42it/s]

[SANITY step=2] q_len=32 input_ids(min,max)=(0,19466) V=30522 max_pos=512
[SANITY step=2] labels(min,max)=101,20360
[SANITY step=2] q_len=26 input_ids(min,max)=(0,26775) V=30522 max_pos=512
[SANITY step=2] labels(min,max)=101,29181


train epoch 1:   0%|          | 7/4914 [00:01<15:25,  5.30it/s]

[SANITY step=3] q_len=12 input_ids(min,max)=(0,21101) V=30522 max_pos=512
[SANITY step=3] labels(min,max)=101,22935
[SANITY step=3] q_len=21 input_ids(min,max)=(0,19962) V=30522 max_pos=512
[SANITY step=3] labels(min,max)=101,2748


train epoch 1:   0%|          | 9/4914 [00:02<14:13,  5.74it/s]

[SANITY step=4] q_len=10 input_ids(min,max)=(101,5648) V=30522 max_pos=512
[SANITY step=4] labels(min,max)=101,16464
[SANITY step=4] q_len=17 input_ids(min,max)=(0,27584) V=30522 max_pos=512
[SANITY step=4] labels(min,max)=101,25147


  with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
                                                         

[Epoch 1] VAL acc=0.0000  yes/no=0.0000  open=0.0000  n=6259
  -> saved BEST to ./runs/blip_pathvqa_best/best


                                                                               

[Epoch 2] VAL acc=0.0000  yes/no=0.0000  open=0.0000  n=6259


                                                                               

[Epoch 3] VAL acc=0.0000  yes/no=0.0000  open=0.0000  n=6259


                                                                               

[Epoch 4] VAL acc=0.0000  yes/no=0.0000  open=0.0000  n=6259


                                                                               

[Epoch 5] VAL acc=0.0000  yes/no=0.0000  open=0.0000  n=6259


                                                                               

[Epoch 6] VAL acc=0.0000  yes/no=0.0000  open=0.0000  n=6259


                                                                               

[Epoch 7] VAL acc=0.0000  yes/no=0.0000  open=0.0000  n=6259


                                                                               

[Epoch 8] VAL acc=0.0000  yes/no=0.0000  open=0.0000  n=6259


train epoch 9:  20%|██        | 993/4914 [02:24<09:21,  6.98it/s, loss=0.0137]

In [None]:
# Debug: show which data source the globals currently defined in this session point at.
print("DEBUG HF_DATASET_NAME =", HF_DATASET_NAME)
print("DEBUG DATA_DIR =", DATA_DIR)


In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# ============ Environment setup (recommended to keep at the very top, before torch is imported) ============
import os
os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "0")  # avoid hf_transfer-related download errors
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
# To force synchronous CUDA error reporting (to locate the failing op), uncomment (noticeably slower).
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import json, math, random, re, inspect, warnings
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple
from collections import Counter

import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from PIL import Image, ImageFile

from transformers import (
    BlipProcessor,
    BlipForQuestionAnswering,
    get_linear_schedule_with_warmup,
)

# Allow loading truncated image files (occasionally present in PathVQA / some HF datasets).
# NOTE: this is a module-level PIL setting, so it applies to every image opened in this process.
ImageFile.LOAD_TRUNCATED_IMAGES = True
warnings.filterwarnings("ignore", message="Truncated File Read")

# =========================================================
# 0) CONFIG (edit only this section)
# =========================================================

# Data: Hugging Face PathVQA
USE_HF_DATASET = True
HF_DATASET_NAME = "flaviagiammarino/path-vqa"
HF_TRAIN_SPLIT = "train"
HF_VAL_SPLIT = "validation"
HF_TEST_SPLIT = "test"
HF_IMAGE_COL = "image"
HF_QUESTION_COL = "question"
HF_ANSWER_COL = "answer"

# Local JSONL (optional)
DATA_DIR = None
IMAGE_ROOT = None
TRAIN_JSONL = "train.jsonl"
VAL_JSONL = "val.jsonl"
TEST_JSONL = "test.jsonl"

# Model
MODEL_NAME = "Salesforce/blip-vqa-base"

# Output
OUTPUT_DIR = "./runs/blip_pathvqa_best"
BEST_SUBDIR = "best"

# Training hyperparameters
SEED = 42
EPOCHS = 10
BATCH_SIZE = 4
GRAD_ACCUM = 2
LR = 2e-5
WEIGHT_DECAY = 0.01
WARMUP_RATIO = 0.06
MAX_GRAD_NORM = 1.0

# Text lengths (tokens)
MAX_QUESTION_LEN = 32
MAX_ANSWER_LEN = 8   # teacher-forcing label length (training)
MAX_GEN_LEN = 8      # generation length at inference (evaluation); keep equal to MAX_ANSWER_LEN

USE_PROMPT = True

# DataLoader
NUM_WORKERS = 0         # most reliable in shared/notebook environments; raise to 2/4 once stable
PIN_MEMORY = True

# Mixed precision
USE_FP16 = True
USE_BF16 = False

# Generation parameters for evaluation (aligned with the standalone evaluation script)
NUM_BEAMS = 3

# Number of validation examples to print during training (same style as the standalone evaluator)
EVAL_PRINT_SAMPLES = 10

# Metric used to select the "best" checkpoint (accuracy first)
# One of: "exact_acc", "token_f1", "open_exact_acc", "yesno_acc"
BEST_METRIC = "exact_acc"

# Whether evaluation uses constrained decoding (off by default, to match the standalone evaluator exactly)
EVAL_CONSTRAINED_DECODE = False
ANSWER_VOCAB_SIZE = -1   # used by constrained decoding; -1 = all training answers

# Debug / preflight
CPU_PREFLIGHT_CHECK = True
DEBUG_SANITY_STEPS = 3   # assert token-id ranges during the first few steps

# =========================================================
# 1) 评估指标（完全对齐你独立评测脚本）
# =========================================================

def normalize_answer(s: str) -> str:
    """Normalize an answer for exact-match / token-F1 scoring.

    Lower-cases the text, replaces every character that is not a-z, 0-9 or whitespace
    with a space, removes the whole words "the", "a" and "an", and collapses
    whitespace. ``None`` maps to the empty string.
    """
    if s is None:
        return ""
    lowered = str(s).lower()
    alnum_only = re.sub(r"[^a-z0-9\s]", " ", lowered)
    no_articles = re.sub(r"\b(the|a|an)\b", " ", alnum_only)
    return " ".join(no_articles.split())

def token_f1(pred: str, gt: str) -> float:
    """Token-level F1 between whitespace-tokenized *pred* and *gt* (SQuAD-style).

    Overlap is a multiset intersection: a repeated token matches at most as many times
    as it occurs in both strings. Returns 0.0 when either side has no tokens or when
    nothing overlaps.
    """
    p = pred.split()
    g = gt.split()
    if not p or not g:
        return 0.0
    # Counter & Counter keeps the minimum count per token. The previous set intersection
    # counted each repeated token once, so e.g. token_f1("x x", "x x") scored 0.5, not 1.0.
    num_same = sum((Counter(p) & Counter(g)).values())
    if num_same == 0:
        return 0.0
    prec = num_same / len(p)
    rec = num_same / len(g)
    return 2 * prec * rec / (prec + rec)

# =========================================================
# 2) utils
# =========================================================

def set_seed(seed: int) -> None:
    """Make Python ``random`` and PyTorch (CPU plus every CUDA device) reproducible from *seed*."""
    random.seed(seed)
    for torch_seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        torch_seeder(seed)

def ensure_dir(path: str) -> None:
    """Create directory *path* together with any missing parents; existing directories are left as-is."""
    os.makedirs(name=path, exist_ok=True)

# =========================================================
# 3) dataset
# =========================================================

class HFVQADataset(Dataset):
    """VQA dataset backed by a Hugging Face ``datasets`` split.

    Items are ``(image, question, answer_raw)``: an RGB PIL image, the
    question as a string, and the answer string exactly as stored. Answers
    are not normalized because training tokenizes the original answer text.
    """

    def __init__(self, name: str, split: str, image_col: str, q_col: str, a_col: str):
        """Load split *split* of dataset *name* and check the required columns exist.

        Raises:
            ValueError: if any of *image_col*, *q_col*, *a_col* is not a column.
        """
        from datasets import load_dataset  # lazy: only needed for the HF data source
        self.ds = load_dataset(name, split=split)
        self.image_col = image_col
        self.q_col = q_col
        self.a_col = a_col

        for c in (image_col, q_col, a_col):
            if c not in self.ds.column_names:
                raise ValueError(f"Column '{c}' not in {self.ds.column_names}")

        # Cache the answers as strings up front. Selecting only the answer
        # column means building the answer list (e.g. for the constrained-
        # decoding trie) does not decode every image.
        self._answers_raw = [str(x) for x in self.ds[a_col]]

    def get_answers_raw(self) -> List[str]:
        """Return the cached, un-normalized answer strings (index-aligned with the dataset)."""
        return self._answers_raw

    def __len__(self) -> int:
        return len(self.ds)

    def __getitem__(self, idx: int) -> Tuple[Image.Image, str, str]:
        """Return ``(rgb_image, question, answer_raw)`` for row *idx*."""
        ex = self.ds[idx]
        img = ex[self.image_col]
        if not isinstance(img, Image.Image):
            # Non-PIL payloads (e.g. arrays) are wrapped first.
            img = Image.fromarray(img)
        # Convert exactly once: Image.convert always returns a copy, so the
        # previous fromarray(...).convert("RGB").convert("RGB") chain copied
        # the full image twice.
        return img.convert("RGB"), str(ex[self.q_col]), self._answers_raw[idx]


class JsonlVQADataset(Dataset):
    """VQA dataset read from a JSON-Lines file.

    Every non-blank line must be a JSON object containing at least the keys
    ``"image"`` (a path, joined onto *image_root* unless absolute),
    ``"question"`` and ``"answer"``. Items are ``(image, question, answer_raw)``
    with the image converted to RGB and the answer left un-normalized.
    """

    def __init__(self, jsonl_path: str, image_root: Optional[str]):
        """Parse *jsonl_path* eagerly; images are only opened in ``__getitem__``.

        Raises:
            FileNotFoundError: if *jsonl_path* does not exist.
            ValueError: if a line lacks a required key (reported with a 1-based line number).
        """
        if not os.path.isfile(jsonl_path):
            raise FileNotFoundError(f"JSONL not found: {jsonl_path}")
        self.samples: List[Dict[str, Any]] = []
        self.image_root = image_root
        self._answers_raw: List[str] = []

        with open(jsonl_path, "r", encoding="utf-8") as f:
            # start=1 so error messages match the line numbers editors display.
            for lineno, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    continue
                ex = json.loads(line)
                for k in ("image", "question", "answer"):
                    if k not in ex:
                        raise ValueError(f"{jsonl_path} line {lineno} missing key '{k}', got keys={list(ex.keys())}")
                self.samples.append(ex)
                self._answers_raw.append(str(ex["answer"]))

    def get_answers_raw(self) -> List[str]:
        """Return the un-normalized answer strings (index-aligned with the dataset)."""
        return self._answers_raw

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[Image.Image, str, str]:
        """Return ``(rgb_image, question, answer_raw)`` for sample *idx*.

        Raises:
            FileNotFoundError: if the resolved image path does not exist.
        """
        ex = self.samples[idx]
        img_path = ex["image"]
        if self.image_root and not os.path.isabs(img_path):
            img_path = os.path.join(self.image_root, img_path)
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        # The context manager releases the file handle even if decoding fails;
        # convert() returns an independent, fully loaded image.
        with Image.open(img_path) as im:
            img = im.convert("RGB")
        # Reuse the string cached at load time (same value as str(ex["answer"])).
        return img, str(ex["question"]), self._answers_raw[idx]

# =========================================================
# 4) constrained decoding（可选）
# =========================================================

class TrieNode:
    """A single node of a prefix tree keyed by token id."""

    __slots__ = ("children", "is_end")

    def __init__(self):
        # True when some inserted sequence ends exactly at this node.
        self.is_end: bool = False
        # Next token id -> child node.
        self.children: Dict[int, "TrieNode"] = {}

class Trie:
    """Prefix tree over token-id sequences."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, token_ids: List[int]) -> None:
        """Add *token_ids* as a path from the root and mark its last node as a valid end."""
        cur = self.root
        for tid in token_ids:
            nxt = cur.children.get(tid)
            if nxt is None:
                nxt = cur.children[tid] = TrieNode()
            cur = nxt
        cur.is_end = True

    def next_tokens(self, prefix: List[int]) -> Tuple[List[int], bool]:
        """Return ``(continuation_ids, prefix_is_complete)`` for *prefix*.

        Continuations are listed in insertion order. A prefix that leaves the
        tree yields ``([], False)``.
        """
        cur = self.root
        for tid in prefix:
            cur = cur.children.get(tid)
            if cur is None:
                return [], False
        return [*cur.children], cur.is_end

def build_answer_set_raw(ds: Dataset, max_answers: int) -> List[str]:
    """Collect the distinct raw answers of *ds*, optionally keeping only the most frequent.

    Answers come from ``ds.get_answers_raw()`` when available, otherwise from
    element ``[2]`` of every item. If *max_answers* is a positive int the
    ``max_answers`` most common answers are kept (most frequent first);
    otherwise all distinct answers in first-seen order. "yes", "no", "Yes",
    "No" are then appended if absent, so yes/no questions stay answerable.
    """
    if not hasattr(ds, "get_answers_raw"):
        answers = [ds[i][2] for i in range(len(ds))]
    else:
        answers = ds.get_answers_raw()

    counts = Counter(answers)
    limited = max_answers is not None and max_answers > 0
    ans_list = [a for a, _ in counts.most_common(max_answers)] if limited else list(counts)

    ans_list.extend(a for a in ("yes", "no", "Yes", "No") if a not in ans_list)
    return ans_list

def build_trie(answer_list: List[str], tokenizer) -> Trie:
    """Tokenize every answer without special tokens and insert the non-empty id sequences into a new Trie."""
    trie = Trie()
    encoded = (tokenizer(ans, add_special_tokens=False).input_ids for ans in answer_list)
    for ids in encoded:
        if not ids:
            continue
        trie.insert(ids)
    return trie

def make_prefix_allowed_tokens_fn(trie: Trie, bos_id: Optional[int], eos_id: int):
    """Build a ``prefix_allowed_tokens_fn`` that restricts generation to answers stored in *trie*.

    The returned callable gets the ids generated so far for one hypothesis.
    It drops one leading *bos_id* if present, then allows the trie's
    continuations of the remaining prefix, plus *eos_id* when the prefix is a
    complete answer. If nothing would be allowed (the prefix left the trie),
    it allows only *eos_id* so the hypothesis can end.

    NOTE(review): only a first token equal to *bos_id* is stripped. If the
    model's generate() starts decoding from a different start token, every
    prefix falls outside the trie and only EOS is ever allowed. Confirm the
    first token of model.generate output before enabling constrained decoding.
    """
    def allowed_tokens(batch_id: int, input_ids: torch.LongTensor) -> List[int]:
        seq = input_ids.tolist()
        has_bos = bos_id is not None and len(seq) > 0 and seq[0] == bos_id
        candidates, complete = trie.next_tokens(seq[1:] if has_bos else seq)
        allowed = candidates + [eos_id] if complete else list(candidates)
        return allowed or [eos_id]

    return allowed_tokens

# =========================================================
# 5) collate（图像/文本分开处理，确保 truncation 生效）
# =========================================================

@dataclass
class Batch:
    """One collated mini-batch.

    Tensor fields are produced by the collate function; the ``*_raw`` lists
    keep the original strings for decoding and metrics.
    """

    pixel_values: torch.FloatTensor   # image tensor from the BLIP processor
    input_ids: torch.LongTensor       # padded question token ids
    attention_mask: torch.LongTensor  # tokenizer attention mask for input_ids
    labels: torch.LongTensor          # answer token ids; padding positions are -100
    answers_raw: List[str]            # ground-truth answers as stored in the dataset
    questions_raw: List[str]          # questions before any prompt wrapping

    def pin_memory(self) -> "Batch":
        """Return a copy of this batch whose tensors live in page-locked host memory.

        ``DataLoader(pin_memory=True)`` only pins tensors, standard containers,
        and custom objects that define ``pin_memory()``. Without this method a
        ``Batch`` would be passed through unpinned, making the pin_memory flag
        a no-op and the later ``.to(device, non_blocking=True)`` copies
        effectively synchronous.
        """
        return Batch(
            pixel_values=self.pixel_values.pin_memory(),
            input_ids=self.input_ids.pin_memory(),
            attention_mask=self.attention_mask.pin_memory(),
            labels=self.labels.pin_memory(),
            answers_raw=self.answers_raw,
            questions_raw=self.questions_raw,
        )
    questions_raw: List[str]

def make_collate_fn(processor: BlipProcessor, max_q_len: int, max_a_len: int):
    """Create the DataLoader ``collate_fn`` that turns dataset items into a ``Batch``.

    Images go through the processor's image path only. Questions and answers
    are tokenized separately with the tokenizer so truncation is applied
    explicitly: *max_q_len* for questions, *max_a_len* for answers. When
    ``USE_PROMPT`` is set (read each time a batch is collated), every question
    is wrapped as ``"Question: <q> Answer:"``. In ``labels``, answer padding
    positions are replaced by -100.
    """
    tokenizer = processor.tokenizer

    def _collate(examples: List[Tuple[Image.Image, str, str]]) -> Batch:
        images, questions, answers = map(list, zip(*examples))
        prompts = [f"Question: {q} Answer:" for q in questions] if USE_PROMPT else list(questions)

        pixel_values = processor(images=images, return_tensors="pt")["pixel_values"]

        question_enc = tokenizer(
            prompts,
            padding=True,
            truncation=True,
            max_length=max_q_len,
            return_tensors="pt",
        )
        answer_enc = tokenizer(
            answers,
            padding=True,
            truncation=True,
            max_length=max_a_len,
            return_tensors="pt",
            add_special_tokens=True,
        )
        labels = answer_enc.input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100  # mark padding as ignored

        return Batch(
            pixel_values=pixel_values,
            input_ids=question_enc["input_ids"],
            attention_mask=question_enc["attention_mask"],
            labels=labels,
            answers_raw=answers,
            questions_raw=questions,
        )

    return _collate

# =========================================================
# 6) decode + evaluate（对齐独立评测）
# =========================================================

def decode_one(gen_ids_1d: List[int], tokenizer, bos_id, eos_id, pad_id) -> str:
    """Decode one generated id sequence into text.

    The steps are:
    1. Drop a single leading *bos_id*.
    2. Cut at the first *eos_id*.
    3. Remove every *pad_id*.
    4. Call ``tokenizer.decode(..., skip_special_tokens=True)``.

    Passing ``None`` for an id skips the matching step.
    """
    ids = list(gen_ids_1d)
    starts_with_bos = bos_id is not None and bool(ids) and ids[0] == bos_id
    body = ids[1:] if starts_with_bos else ids
    end = body.index(eos_id) if (eos_id is not None and eos_id in body) else len(body)
    body = body[:end]
    kept = body if pad_id is None else [t for t in body if t != pad_id]
    return tokenizer.decode(kept, skip_special_tokens=True)

@torch.no_grad()
def evaluate_generation(
    model,
    processor,
    loader,
    device,
    use_amp,
    amp_dtype,
    max_gen_len: int,
    num_beams: int,
    prefix_allowed_tokens_fn=None,
    print_samples: int = 0,
):
    """Generate an answer for every example in *loader* and score it against the ground truth.

    Predictions and ground truths both go through ``normalize_answer`` before
    scoring. An example whose normalized ground truth is exactly "yes" or
    "no" counts toward the yes/no bucket; every other example counts toward
    the open bucket.

    Args:
        model: Model exposing ``generate(pixel_values=..., input_ids=..., attention_mask=..., ...)``.
        processor: Processor whose ``tokenizer`` decodes the generated ids.
        loader: Iterable of ``Batch`` objects.
        device: Target device (``torch.device`` or string).
        use_amp: Request autocast; it is only enabled on CUDA devices.
        amp_dtype: Autocast dtype.
        max_gen_len: Maximum number of generated answer tokens.
        num_beams: Beam width (greedy when 1).
        prefix_allowed_tokens_fn: Optional constrained-decoding callback for ``generate``.
        print_samples: Print Q / GT / prediction for the first this-many examples.

    Returns:
        Dict with ``exact_acc``, ``token_f1``, ``yesno_acc``, ``open_exact_acc``
        and ``open_token_f1`` (each 0 for an empty bucket), plus the counts
        ``n``, ``n_yesno`` and ``n_open``.
    """
    tok = processor.tokenizer
    pad_id = tok.pad_token_id
    bos_id = tok.cls_token_id if tok.cls_token_id is not None else tok.bos_token_id
    eos_id = tok.sep_token_id if tok.sep_token_id is not None else tok.eos_token_id

    # torch.cuda.amp.autocast is deprecated in favour of the device-generic
    # torch.autocast. It is enabled only on CUDA, matching the old call, which
    # disabled itself when CUDA was unavailable.
    device_type = torch.device(device).type
    amp_enabled = bool(use_amp) and device_type == "cuda"

    # Generation settings are identical for every batch: build them once.
    # max_length counts one extra position, presumably the decoder start
    # token that decode_one strips from the output.
    gen_kwargs = dict(
        max_length=1 + max_gen_len,
        num_beams=num_beams,
        do_sample=False,
    )
    if prefix_allowed_tokens_fn is not None:
        gen_kwargs["prefix_allowed_tokens_fn"] = prefix_allowed_tokens_fn

    total = 0
    exact = 0
    f1_sum = 0.0

    yesno_total = 0
    yesno_correct = 0

    open_total = 0
    open_exact = 0
    open_f1_sum = 0.0

    printed = 0

    model.eval()
    for batch in tqdm(loader, desc="eval(gen)", leave=False):
        pixel_values = batch.pixel_values.to(device, non_blocking=True)
        input_ids = batch.input_ids.to(device, non_blocking=True)
        attention_mask = batch.attention_mask.to(device, non_blocking=True)

        with torch.autocast(device_type=device_type, dtype=amp_dtype, enabled=amp_enabled):
            gen_out = model.generate(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                **gen_kwargs,
            )

        for seq, gt_raw, q_raw in zip(gen_out, batch.answers_raw, batch.questions_raw):
            pred_raw = decode_one(seq.tolist(), tok, bos_id, eos_id, pad_id)

            pred = normalize_answer(pred_raw)
            gt = normalize_answer(gt_raw)
            is_exact = pred == gt
            f1 = token_f1(pred, gt)

            total += 1
            exact += int(is_exact)
            f1_sum += f1

            if gt in ("yes", "no"):
                yesno_total += 1
                yesno_correct += int(is_exact)
            else:
                open_total += 1
                open_exact += int(is_exact)
                open_f1_sum += f1

            if printed < print_samples:
                print("=" * 60)
                print("Q :", q_raw)
                print("GT:", gt)
                print("PR:", pred)
                printed += 1

    metrics = {
        "exact_acc": exact / max(1, total),
        "token_f1": f1_sum / max(1, total),
        "yesno_acc": yesno_correct / max(1, yesno_total),
        "open_exact_acc": open_exact / max(1, open_total),
        "open_token_f1": open_f1_sum / max(1, open_total),
        "n": total,
        "n_yesno": yesno_total,
        "n_open": open_total,
    }
    return metrics

# =========================================================
# 7) main
# =========================================================

def _decoder_inputs_from_labels(
    labels: torch.LongTensor, pad_id: int
) -> Tuple[torch.LongTensor, torch.LongTensor]:
    """Turn teacher-forcing labels into ``(decoder_input_ids, decoder_attention_mask)``.

    transformers' BLIP text decoder already performs the next-token shift
    internally: it scores ``logits[:, :-1]`` against ``labels[:, 1:]``. The
    decoder inputs must therefore be the *unshifted* label ids. Rolling them
    right as well shifts twice and trains the decoder to predict two tokens
    ahead, which shows up at inference as repeated tokens ("cells cells of").
    Only the -100 ignore markers are replaced, because they are not valid
    embedding indices.
    """
    ignore = labels == -100
    return labels.masked_fill(ignore, pad_id), (~ignore).long()


def _check_token_range(input_ids: torch.Tensor, labels: torch.Tensor, vocab_size: int) -> None:
    """Raise ValueError if a question id or a non-ignored label id is outside ``[0, vocab_size)``.

    Out-of-range ids usually surface on GPU as an opaque device-side assert
    in the embedding lookup; checking up front gives a readable error.
    """
    lo, hi = int(input_ids.min()), int(input_ids.max())
    if lo < 0 or hi >= vocab_size:
        raise ValueError(f"input_ids out of range [0, {vocab_size}): min={lo}, max={hi}")
    valid = labels[labels != -100]
    if valid.numel() > 0:
        lo, hi = int(valid.min()), int(valid.max())
        if lo < 0 or hi >= vocab_size:
            raise ValueError(f"labels out of range [0, {vocab_size}): min={lo}, max={hi}")


def _format_metrics(m: Dict[str, Any]) -> str:
    """Render evaluation metrics in the one-line format shared by the VAL and TEST logs."""
    return (
        f"exact={m['exact_acc']:.4f} "
        f"f1={m['token_f1']:.4f} "
        f"yesno={m['yesno_acc']:.4f} "
        f"open_exact={m['open_exact_acc']:.4f} "
        f"open_f1={m['open_token_f1']:.4f} "
        f"(n={m['n']}, yesno={m['n_yesno']}, open={m['n_open']})"
    )


def main():
    """Fine-tune BLIP for VQA, keep the best VAL checkpoint by BEST_METRIC, then evaluate it on TEST.

    All settings come from the module-level config constants. Files written:
    - ``config.json`` and ``test_metrics.json`` under OUTPUT_DIR;
    - the best model, processor and ``val_metrics.json`` under OUTPUT_DIR/BEST_SUBDIR.

    Raises:
        ValueError: on inconsistent config (both FP16 and BF16, unknown
            BEST_METRIC, missing BOS/EOS ids, missing local data dir) or
            out-of-range token ids.
        RuntimeError: if no checkpoint was saved (EPOCHS < 1).
    """
    if USE_FP16 and USE_BF16:
        raise ValueError("Choose only one: USE_FP16 or USE_BF16")
    metric_keys = ("exact_acc", "token_f1", "yesno_acc", "open_exact_acc", "open_token_f1")
    if BEST_METRIC not in metric_keys:
        # Fail fast: an unknown key would never save a checkpoint and only
        # crash when loading it for TEST, after the whole training run.
        raise ValueError(f"BEST_METRIC must be one of {metric_keys}, got {BEST_METRIC!r}")

    set_seed(SEED)
    ensure_dir(OUTPUT_DIR)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = USE_FP16 or USE_BF16
    amp_dtype = torch.float16 if USE_FP16 else (torch.bfloat16 if USE_BF16 else torch.float32)
    # Autocast and loss scaling only take effect on CUDA. The deprecated
    # torch.cuda.amp helpers also disabled themselves elsewhere.
    amp_active = use_amp and device.type == "cuda"
    fp16_scaling = USE_FP16 and device.type == "cuda"
    accum = max(1, GRAD_ACCUM)
    print(f"Device={device}  AMP={use_amp}  dtype={amp_dtype}")

    # Load on CPU first so pad/config fixes and the preflight run before moving to GPU.
    processor = BlipProcessor.from_pretrained(MODEL_NAME)
    model = BlipForQuestionAnswering.from_pretrained(MODEL_NAME)
    tok = processor.tokenizer

    # Make sure a pad token exists (labels/decoder inputs rely on it).
    if tok.pad_token is None:
        tok.add_special_tokens({"pad_token": "[PAD]"})
        model.resize_token_embeddings(len(tok))

    pad_id = tok.pad_token_id
    bos_id = tok.cls_token_id if tok.cls_token_id is not None else tok.bos_token_id
    eos_id = tok.sep_token_id if tok.sep_token_id is not None else tok.eos_token_id
    if bos_id is None or eos_id is None:
        raise ValueError("Cannot determine BOS/EOS token id for BLIP tokenizer.")

    # Keep model/generation config consistent with the tokenizer.
    model.config.pad_token_id = pad_id
    model.config.bos_token_id = bos_id
    model.config.eos_token_id = eos_id
    model.config.decoder_start_token_id = bos_id
    if getattr(model, "generation_config", None) is not None:
        model.generation_config.pad_token_id = pad_id
        model.generation_config.bos_token_id = bos_id
        model.generation_config.eos_token_id = eos_id

    # Cap lengths so position embeddings are never indexed out of range.
    max_pos = int(model.text_encoder.config.max_position_embeddings)
    max_q_len = min(MAX_QUESTION_LEN, max_pos)
    max_a_len = min(MAX_ANSWER_LEN, max_pos)
    print(f"[Len] max_pos={max_pos}, use max_q_len={max_q_len}, max_a_len={max_a_len}, max_gen_len={MAX_GEN_LEN}")
    print(f"[Tokens] pad={pad_id} bos={bos_id} eos={eos_id} vocab(len(tok))={len(tok)}")

    # Load datasets
    if USE_HF_DATASET:
        train_ds = HFVQADataset(HF_DATASET_NAME, HF_TRAIN_SPLIT, HF_IMAGE_COL, HF_QUESTION_COL, HF_ANSWER_COL)
        val_ds = HFVQADataset(HF_DATASET_NAME, HF_VAL_SPLIT, HF_IMAGE_COL, HF_QUESTION_COL, HF_ANSWER_COL)
        test_ds = HFVQADataset(HF_DATASET_NAME, HF_TEST_SPLIT, HF_IMAGE_COL, HF_QUESTION_COL, HF_ANSWER_COL)
    else:
        if DATA_DIR is None or not os.path.isdir(DATA_DIR):
            raise ValueError("Set USE_HF_DATASET=True or provide valid DATA_DIR for local JSONL.")
        image_root = IMAGE_ROOT if IMAGE_ROOT is not None else DATA_DIR
        train_ds = JsonlVQADataset(os.path.join(DATA_DIR, TRAIN_JSONL), image_root=image_root)
        val_ds = JsonlVQADataset(os.path.join(DATA_DIR, VAL_JSONL), image_root=image_root)
        test_ds = JsonlVQADataset(os.path.join(DATA_DIR, TEST_JSONL), image_root=image_root)

    print(f"Train={len(train_ds)}  Val={len(val_ds)}  Test={len(test_ds)}")

    collate_fn = make_collate_fn(processor, max_q_len=max_q_len, max_a_len=max_a_len)
    # Pinning is only meaningful with CUDA (DataLoader just warns otherwise).
    pin = PIN_MEMORY and torch.cuda.is_available()

    def _loader(ds: Dataset, shuffle: bool) -> DataLoader:
        return DataLoader(
            ds, batch_size=BATCH_SIZE, shuffle=shuffle,
            num_workers=NUM_WORKERS, pin_memory=pin,
            collate_fn=collate_fn,
        )

    train_loader = _loader(train_ds, shuffle=True)
    val_loader = _loader(val_ds, shuffle=False)
    test_loader = _loader(test_ds, shuffle=False)

    # Inspect forward() once; the signature does not change when the model moves device.
    fwd_params = inspect.signature(model.forward).parameters
    supports_decoder = "decoder_input_ids" in fwd_params
    supports_decoder_attn = "decoder_attention_mask" in fwd_params

    def _forward_kwargs(pixel_values, input_ids, attention_mask, labels) -> Dict[str, Any]:
        """Assemble ``model(**kwargs)``, adding decoder inputs when forward() accepts them."""
        kwargs: Dict[str, Any] = dict(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        if supports_decoder:
            dec_ids, dec_mask = _decoder_inputs_from_labels(labels, pad_id)
            kwargs["decoder_input_ids"] = dec_ids
            if supports_decoder_attn:
                kwargs["decoder_attention_mask"] = dec_mask
        return kwargs

    vocab_size = len(tok)

    # ---------- CPU preflight: catch out-of-range indices with a readable error ----------
    if CPU_PREFLIGHT_CHECK:
        print("[Preflight] Running 1 batch on CPU...")
        model.eval()
        batch = next(iter(train_loader))
        _check_token_range(batch.input_ids, batch.labels, vocab_size)
        with torch.no_grad():
            model(**_forward_kwargs(batch.pixel_values, batch.input_ids, batch.attention_mask, batch.labels))
        print("[Preflight] OK.")

    # Optional constrained decoding for evaluation (off by default).
    prefix_allowed_tokens_fn = None
    if EVAL_CONSTRAINED_DECODE:
        ans_list = build_answer_set_raw(train_ds, ANSWER_VOCAB_SIZE)
        trie = build_trie(ans_list, tok)
        prefix_allowed_tokens_fn = make_prefix_allowed_tokens_fn(trie, bos_id=bos_id, eos_id=eos_id)
        print(f"[Eval ConstrainedDecode] answer_set={len(ans_list)}")

    model.to(device)
    model.config.use_cache = False

    optimizer = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    n_batches = len(train_loader)
    total_steps = math.ceil(n_batches / accum) * EPOCHS
    warmup_steps = int(total_steps * WARMUP_RATIO)
    scheduler = get_linear_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    # torch.amp.GradScaler replaces the deprecated torch.cuda.amp.GradScaler (torch >= 2.3).
    scaler = torch.amp.GradScaler("cuda", enabled=fp16_scaling)

    with open(os.path.join(OUTPUT_DIR, "config.json"), "w", encoding="utf-8") as f:
        cfg = {k: v for k, v in globals().items() if k.isupper()}
        json.dump(cfg, f, ensure_ascii=False, indent=2)

    best_score = -1e9
    best_metrics = None
    global_step = 0
    best_dir = os.path.join(OUTPUT_DIR, BEST_SUBDIR)

    for epoch in range(1, EPOCHS + 1):
        model.train()
        optimizer.zero_grad(set_to_none=True)
        running_loss = 0.0  # sum of un-normalized micro-batch losses in this epoch

        pbar = tqdm(train_loader, desc=f"train epoch {epoch}", leave=False)
        for step, batch in enumerate(pbar, start=1):
            pixel_values = batch.pixel_values.to(device, non_blocking=True)
            input_ids = batch.input_ids.to(device, non_blocking=True)
            attention_mask = batch.attention_mask.to(device, non_blocking=True)
            labels = batch.labels.to(device, non_blocking=True)

            # Sanity check during the first few optimizer steps only.
            if DEBUG_SANITY_STEPS > 0 and global_step < DEBUG_SANITY_STEPS:
                _check_token_range(input_ids, labels, vocab_size)

            with torch.autocast(device_type=device.type, dtype=amp_dtype, enabled=amp_active):
                out = model(**_forward_kwargs(pixel_values, input_ids, attention_mask, labels))
                loss = out.loss / accum

            if fp16_scaling:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            running_loss += loss.item() * accum

            # Also step after the last, possibly partial, accumulation window:
            # otherwise its gradients are discarded by the next epoch's
            # zero_grad, and the scheduler never reaches total_steps.
            if step % accum == 0 or step == n_batches:
                if fp16_scaling:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
                    optimizer.step()

                scheduler.step()
                optimizer.zero_grad(set_to_none=True)
                global_step += 1

                if global_step % 50 == 0:
                    # Average over this epoch's micro-batches. Dividing the
                    # per-epoch sum by the cumulative global_step made the
                    # displayed loss shrink artificially every epoch.
                    pbar.set_postfix(loss=f"{running_loss / step:.4f}")

        # ====== VAL evaluation (same logic as the standalone evaluation) ======
        val_metrics = evaluate_generation(
            model, processor, val_loader, device,
            use_amp=use_amp, amp_dtype=amp_dtype,
            max_gen_len=MAX_GEN_LEN, num_beams=NUM_BEAMS,
            prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
            print_samples=EVAL_PRINT_SAMPLES if epoch == 1 else 0,
        )
        print(f"[Epoch {epoch}] VAL " + _format_metrics(val_metrics))

        score = float(val_metrics[BEST_METRIC])
        if score > best_score:
            best_score = score
            best_metrics = val_metrics

            ensure_dir(best_dir)
            model.save_pretrained(best_dir)
            processor.save_pretrained(best_dir)

            with open(os.path.join(best_dir, "val_metrics.json"), "w", encoding="utf-8") as f:
                json.dump(val_metrics, f, ensure_ascii=False, indent=2)

            print(f"  -> saved BEST by {BEST_METRIC}={best_score:.4f} to {best_dir}")

    # ====== TEST with the best checkpoint ======
    if best_metrics is None:
        raise RuntimeError("No best checkpoint was saved (EPOCHS < 1?); nothing to evaluate on TEST.")
    print("[BEST] VAL " + _format_metrics(best_metrics))

    print("\nLoading BEST checkpoint for TEST...")
    # Release the training model and optimizer state before loading a second copy of the weights.
    del optimizer, scheduler, scaler, model
    if device.type == "cuda":
        torch.cuda.empty_cache()
    model = BlipForQuestionAnswering.from_pretrained(best_dir).to(device)
    processor = BlipProcessor.from_pretrained(best_dir)

    test_metrics = evaluate_generation(
        model, processor, test_loader, device,
        use_amp=use_amp, amp_dtype=amp_dtype,
        max_gen_len=MAX_GEN_LEN, num_beams=NUM_BEAMS,
        prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
        print_samples=EVAL_PRINT_SAMPLES,
    )
    print("[BEST] TEST " + _format_metrics(test_metrics))

    with open(os.path.join(OUTPUT_DIR, "test_metrics.json"), "w", encoding="utf-8") as f:
        json.dump(test_metrics, f, ensure_ascii=False, indent=2)

# Script entry point. When this code runs in a notebook cell, `__name__` is
# normally "__main__" as well, so executing the cell starts a full run.
if __name__ == "__main__":
    main()


Device=cuda  AMP=True  dtype=torch.float16
[Len] max_pos=512, use max_q_len=32, max_a_len=8, max_gen_len=8
[Tokens] pad=0 bos=101 eos=102 vocab(len(tok))=30522
Train=19654  Val=6259  Test=6719
[Preflight] Running 1 batch on CPU...
[Preflight] OK.


  scaler = torch.cuda.amp.GradScaler(enabled=USE_FP16)
  with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
  with torch.cuda.amp.autocast(enabled=use_amp, dtype=amp_dtype):
eval(gen):   0%|          | 2/1565 [00:00<03:19,  7.82it/s]

Q : what have lost their nuclei?
GT: neutrophils
PR: cells cells of
Q : whose nuclei have been lost?
GT: neutrophils
PR: one
Q : are two small pulmonary arterioles packed with laminated swirls of fetal squamous cells?
GT: yes
PR: yes
Q : what is acute viral hepatitis characterized by?
GT: predominantly lymphocytic infiltrate
PR: cells cells of
Q : what do the cells have?
GT: wavy nuclei
PR: granmas
Q : do the cells have wavy nuclei?
GT: yes
PR: yes
Q : do individual myocardial fibres have wavy nuclei?
GT: no
PR: yes
Q : where is this area in the body?
GT: abdomen
PR: abdomen


eval(gen):   0%|          | 4/1565 [00:00<03:46,  6.90it/s]

Q : what does this image show?
GT: peritoneal carcinomatosis
PR: close up close up view of
Q : does this image show peritoneal carcinomatosis, metastatic tumor covering all of the abdominal viscera?
GT: yes
PR: yes


                                                              

[Epoch 1] VAL exact=0.4459 f1=0.4667 yesno=0.8166 open_exact=0.0763 open_f1=0.1178 (n=6259, yesno=3125, open=3134)
  -> saved BEST by exact_acc=0.4459 to ./runs/blip_pathvqa_best/best


                                                                               

[Epoch 2] VAL exact=0.4907 f1=0.5133 yesno=0.8627 open_exact=0.1197 open_f1=0.1649 (n=6259, yesno=3125, open=3134)
  -> saved BEST by exact_acc=0.4907 to ./runs/blip_pathvqa_best/best


                                                                               

[Epoch 3] VAL exact=0.4918 f1=0.5168 yesno=0.8637 open_exact=0.1209 open_f1=0.1710 (n=6259, yesno=3125, open=3134)
  -> saved BEST by exact_acc=0.4918 to ./runs/blip_pathvqa_best/best


                                                                               

[Epoch 4] VAL exact=0.5001 f1=0.5264 yesno=0.8774 open_exact=0.1238 open_f1=0.1764 (n=6259, yesno=3125, open=3134)
  -> saved BEST by exact_acc=0.5001 to ./runs/blip_pathvqa_best/best


                                                                               

[Epoch 5] VAL exact=0.4967 f1=0.5250 yesno=0.8746 open_exact=0.1200 open_f1=0.1765 (n=6259, yesno=3125, open=3134)


                                                                               

[Epoch 6] VAL exact=0.5113 f1=0.5431 yesno=0.8810 open_exact=0.1426 open_f1=0.2061 (n=6259, yesno=3125, open=3134)
  -> saved BEST by exact_acc=0.5113 to ./runs/blip_pathvqa_best/best


train epoch 7:   1%|▏         | 68/4914 [00:09<11:17,  7.15it/s, loss=0.0003]