In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
torch.cuda.empty_cache()
!export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True


In [2]:
# 清理掉 ~/.cache/huggingface 里所有 datasets、hub、modules 的旧缓存
!rm -rf ~/.cache/huggingface/datasets
!rm -rf ~/.cache/huggingface/hub
!rm -rf ~/.cache/huggingface/modules


In [3]:
# 基础库
!pip install --upgrade transformers datasets peft evaluate qwen-vl-utils



# 1) 卸载可能已经装过但元数据不全的版本
!pip uninstall -y bitsandbytes

# 2) 从 PyPI 装最新稳定版
!pip install --upgrade bitsandbytes



# 评估与可视化
!pip install scikit-learn matplotlib

# 图像处理
!pip install pillow
!pip install rouge_score
!pip install huggingface_hub[hf_xet]

Found existing installation: bitsandbytes 0.46.0
Uninstalling bitsandbytes-0.46.0:
  Successfully uninstalled bitsandbytes-0.46.0
Collecting bitsandbytes
  Using cached bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Using cached bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl (67.0 MB)
Installing collected packages: bitsandbytes
Successfully installed bitsandbytes-0.46.0


In [4]:
!pip install grad-cam




In [9]:
# ───────────────────────────────────────────────────────────────────
# 第一部分：把 HF 缓存目录重定向到 Google Drive（或其他大盘），
#          以彻底避免 “No space left on device”
# ───────────────────────────────────────────────────────────────────
import os
from google.colab import drive

# 1) 挂载 Google Drive（在 Colab cell 里先执行一次）
drive.mount('/content/drive')

# 2) 指定一个位于 Drive 上的文件夹，专门用来存 HF 的缓存
hf_cache = "/content/drive/MyDrive/hf_cache"
os.makedirs(hf_cache, exist_ok=True)

# 3) 把 HF 相关的缓存环境变量都指向这个目录
os.environ["HF_HOME"]            = hf_cache
os.environ["HF_DATASETS_CACHE"]  = os.path.join(hf_cache, "datasets")
os.environ["TRANSFORMERS_CACHE"] = os.path.join(hf_cache, "transformers")
# 到此，HF 无论是下载数据集、model weights 还是 tokenizer，都不会写到根盘


# ───────────────────────────────────────────────────────────────────
# 第二部分：导入各种库
# ───────────────────────────────────────────────────────────────────
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from PIL import Image
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from datasets import load_dataset, Dataset
from peft import LoraConfig, get_peft_model
from transformers import (
    AutoProcessor,
    Qwen2VLForConditionalGeneration,
    BitsAndBytesConfig,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    TrainerCallback
)
import evaluate
from nltk.translate.bleu_score import SmoothingFunction
from pycocoevalcap.cider.cider import Cider
from bert_score import BERTScorer

# Grad-CAM 相关（可选，如果不需要可视化就注释掉下面两行）
# pip install grad-cam
from pytorch_grad_cam import GradCAM as _GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# ───────────────────────────────────────────────────────────────────
# 第三部分：用 streaming=True 从 e-SNLI-VE 抽小子集，
#         完全不往本地写大文件
# ───────────────────────────────────────────────────────────────────
def convert_to_binary_label(gold_label):
    return 1.0 if gold_label == 0 else 0.0

def preprocess_stream_example(ex):
    return {
        "image": ex["image"],            # 这里保留原始值（PIL URL/dict/path）
        "caption": ex["hypothesis"],
        "cls_label": convert_to_binary_label(ex["gold_label"])
    }

# 流式读取 train/dev split（这样不会下载整份 Parquet 到磁盘）
train_iter = load_dataset("J1mb0o/e-snli-ve", split="train", streaming=True)
eval_iter  = load_dataset("J1mb0o/e-snli-ve", split="dev", streaming=True)

# 只取前 80 条训练，20 条评估（为了演示和测试）
import itertools
train_list = []
for ex in itertools.islice(train_iter, 80):
    train_list.append(preprocess_stream_example(ex))
eval_list = []
for ex in itertools.islice(eval_iter, 20):
    eval_list.append(preprocess_stream_example(ex))

print(f"✔️ 从流式数据集中抽取：Train {len(train_list)} | Eval {len(eval_list)}")

# 把 Python list 转成 HuggingFace Dataset（不再往磁盘写大文件）
train_ds = Dataset.from_list(train_list)
eval_ds  = Dataset.from_list(eval_list)


# ───────────────────────────────────────────────────────────────────
# 第四部分：加载 Qwen2-VL-2B-Instruct 所需的 Processor 和 Model
#           加上 trust_remote_code=True，避免 404 且让 HF 在内存里直接使用
# ───────────────────────────────────────────────────────────────────
repo_id = "Qwen/Qwen2-VL-2B-Instruct"

# 1) 先加载 Processor：
processor = AutoProcessor.from_pretrained(
    repo_id,
    trust_remote_code=True    # 这一步让 HF 使用仓库里真正存在的 preprocessor_config.json
)

# 2) 再加载 4-bit NF4 量化后的 Qwen2-VL-2B，并注入 LoRA
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
base_model = Qwen2VLForConditionalGeneration.from_pretrained(
    repo_id,
    trust_remote_code=True,
    quantization_config=bnb_cfg,
    device_map="auto"
)

peft_cfg = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)
base_model = get_peft_model(base_model, peft_cfg)

print("✔️ Processor 与 Base Model（4-bit NF4 + LoRA）已加载完毕")
class Qwen2VLWithClassifier(torch.nn.Module):
    def __init__(self, pretrained_qwen, num_labels=1, cls_loss_weight=1.0):
        super().__init__()
        self.qwen = pretrained_qwen
        self.cls_loss_weight = cls_loss_weight

        hidden_size = self.qwen.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, num_labels)
        self.cls_loss_fn = torch.nn.BCEWithLogitsLoss()

    def forward(
        self,
        pixel_values=None,
        input_ids=None,
        attention_mask=None,
        labels=None,
        cls_labels=None,
        output_attentions=True,
        return_dict=True,
        **kwargs
    ):
        # ——1) pop 掉 Trainer 或其他环节可能传入的无关 kwargs
        kwargs.pop("num_items_in_batch", None)
        kwargs.pop("pixel_mask", None)
        kwargs.pop("use_cache", None)
        # ——2) **关键**：把 image_grid_thw pop 掉，避免传 None 给模型
        kwargs.pop("image_grid_thw", None)

        # ——3) 将 pixel_values, input_ids, attention_mask, labels 传给底层 Qwen2-VL
        gen_outputs = self.qwen(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=output_attentions,
            return_dict=return_dict,
            **kwargs
        )
        # gen_outputs 包含：
        #   - 如果 labels 不为 None，则 gen_outputs.loss 存放语言生成 loss
        #   - gen_outputs.logits （语言生成 logits）
        #   - gen_outputs.last_hidden_state, cross_attentions 等

        # ——3) 提取视觉特征：在底层 self.qwen 里找到第一个带 get_image_features() 的子模块
        vision_feats = None
        for module in self.qwen.modules():
            if hasattr(module, "get_image_features"):
                vision_feats = module.get_image_features(pixel_values.to(self.qwen.device))
                break
        if vision_feats is None:
            raise AttributeError(
                "未能在 Qwen2-VL 模型内部找到任何子模块包含 get_image_features()，请检查模型版本是否正确。"
            )
        # vision_feats 的 shape 应为 [batch_size, hidden_size]

        # ——4) 分类 head
        cls_logits = self.classifier(vision_feats)  # [batch_size, num_labels]

        # ——5) 计算联合 loss：gen_loss + λ * cls_loss
        gen_loss = gen_outputs.loss if (labels is not None) else None
        cls_loss = None
        if cls_labels is not None:
            # cls_labels: [batch_size] 或 [batch_size,1]
            cls_loss = self.cls_loss_fn(cls_logits.view(-1), cls_labels.view(-1))

        total_loss = None
        if gen_loss is not None and cls_loss is not None:
            total_loss = gen_loss + self.cls_loss_weight * cls_loss
        elif gen_loss is not None:
            total_loss = gen_loss
        elif cls_loss is not None:
            total_loss = cls_loss

        outputs = {
            "loss":       total_loss,
            "gen_loss":   gen_loss,
            "cls_loss":   cls_loss,
            "gen_logits": gen_outputs.logits,
            "cls_logits": cls_logits,
        }
        if output_attentions:
            outputs["cross_attentions"] = gen_outputs.get("cross_attentions", None)
        return outputs

    def generate(self, *args, **kwargs):
        """
        生成阶段直接调用底层 Qwen2-VL 的 generate，不要传 image_grid_thw。
        同样 pop 掉不需要的 kwargs。
        """
        kwargs.pop("num_items_in_batch", None)
        kwargs.pop("pixel_mask", None)
        kwargs.pop("use_cache", None)
        return self.qwen.generate(*args, **kwargs)


def collate_fn_train(examples):
    """
    训练时的 batch 构造：只保留原始 PIL Image，让 processor 负责生成 pixel_values（和 image_grid_thw）。
    脚本中后续的 forward() 只会用 pixel_values，不再传 image_grid_thw。
    """
    texts, images, cls_labels = [], [], []
    for ex in examples:
        prompt = random.choice(prompts)
        img = ex["image"]
        # ex["image"] 可能是 dict{'path': ...}、也可能是 PIL.Image.Image
        if isinstance(img, dict):
            img = Image.open(img["path"]).convert("RGB")
        elif not isinstance(img, Image.Image):
            # 例如一开始就可能是路径字符串，这里做一次兜底转换
            img = Image.open(img).convert("RGB")

        # 1) 用 apply_chat_template 构建对话式 prompt + answer
        msgs = [
            {"role": "user", "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt}
            ]},
            {"role": "assistant", "content": ex["caption"]}
        ]
        txt = processor.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=False
        )
        texts.append(txt)

        # 2) 直接把 PIL Image 传给 processor：
        images.append(img)
        cls_labels.append(ex["cls_label"])

    # ③ 让 processor(texts, images) 自己做 resize/normalize → pixel_values, image_grid_thw
    batch = processor(text=texts, images=images, padding=True, return_tensors="pt")

    # ④ 构造 labels：把所有视觉 token（<|vision_start|> / <|vision_end|> / <|image_pad|>）和 pad_token 都置为 -100
    labels = batch.input_ids.clone()
    vs = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
    ve = processor.tokenizer.convert_tokens_to_ids("<|vision_end|>")
    vp = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    pad = processor.tokenizer.pad_token_id
    mask = (labels == vs) | (labels == ve) | (labels == vp) | (labels == pad)
    labels[mask] = -100

    batch["labels"]     = labels
    batch["cls_labels"] = torch.tensor(cls_labels, dtype=torch.float)
    return batch


def collate_fn_eval(examples):
    """
    验证/测试时的 batch 构造，同理只让 processor 处理 PIL Image → pixel_values。
    """
    texts, images, captions, cls_labels = [], [], [], []
    for ex in examples:
        prompt = random.choice(prompts)
        img = ex["image"]
        if isinstance(img, dict):
            img = Image.open(img["path"]).convert("RGB")
        elif not isinstance(img, Image.Image):
            img = Image.open(img).convert("RGB")

        msgs = [
            {"role": "user", "content": [
                {"type": "image", "image": img},
                {"type": "text", "text": prompt}
            ]}
        ]
        txt = processor.apply_chat_template(
            msgs, tokenize=False, add_generation_prompt=True
        )
        texts.append(txt)

        images.append(img)
        captions.append(ex["caption"])
        cls_labels.append(ex["cls_label"])

    batch = processor(text=texts, images=images, padding=True, return_tensors="pt")

    batch["captions"]   = captions
    batch["cls_labels"] = cls_labels
    return batch



# 自定义 Eval Callback，输出分类 & 文本指标
_smooth = SmoothingFunction().method1
rouge_metric  = evaluate.load("rouge")
bleu_metric   = evaluate.load("bleu")
meteor_metric = evaluate.load("meteor")
cider_scorer  = Cider()
bert_scorer   = BERTScorer(lang="en", rescale_with_baseline=True)

class LowMemEvalCallback(TrainerCallback):
    def __init__(
        self,
        eval_dataset,
        collate_fn_eval,
        processor,
        model,
        batch_size: int = 4,
        max_eval_samples: int = None
    ):
        self.processor       = processor
        self.model           = model
        self.collate_fn_eval = collate_fn_eval

        if max_eval_samples is not None:
            eval_dataset = eval_dataset.select(range(max_eval_samples))

        self.loader = torch.utils.data.DataLoader(
            eval_dataset,
            batch_size=batch_size,
            collate_fn=collate_fn_eval,
            shuffle=False
        )

        self.rouge  = evaluate.load("rouge")
        self.bleu   = evaluate.load("bleu")
        self.meteor = evaluate.load("meteor")
        self.cider  = Cider()
        self.bert   = BERTScorer(lang="en", rescale_with_baseline=True)

    def on_epoch_end(self, args, state, control, **kwargs):
        device = self.model.qwen.device
        all_preds, all_refs = [], []
        total_cls_correct = 0
        total_cls_samples = 0

        self.model.eval()
        with torch.no_grad():
            for batch in self.loader:
                # —— 文本生成阶段：必须传入 pixel_values，如果有 image_grid_thw 就一并传给 generate()
                inputs = {
                    "pixel_values": batch["pixel_values"].to(device),
                    "input_ids":    batch["input_ids"].to(device),
                    "attention_mask": batch["attention_mask"].to(device),
                }
                # 某些版本底层支持 image_grid_thw，故尝试加入：
                if "image_grid_thw" in batch:
                    inputs["image_grid_thw"] = batch["image_grid_thw"].to(device)

                # 用我们的 generate()（内部自带 try/except，兼容各种版本）
                outs = self.model.generate(
                    **inputs,
                    max_new_tokens=100,
                    num_beams=4,
                    early_stopping=True
                )
                prefix_len = inputs["input_ids"].shape[-1]
                texts = self.processor.batch_decode(
                    outs[:, prefix_len:], skip_special_tokens=True
                )
                all_preds.extend(texts)
                all_refs.extend(batch["captions"])

                # —— 分类推理阶段：forward() 也要带上 image_grid_thw（如果有）
                fwd_kwargs = {
                    "pixel_values": batch["pixel_values"].to(device),
                    "input_ids":    inputs["input_ids"],
                    "attention_mask": inputs["attention_mask"],
                    "labels":       None,
                    "cls_labels":   None,
                    "output_attentions": False,
                    "return_dict":  True
                }
                if "image_grid_thw" in batch:
                    fwd_kwargs["image_grid_thw"] = batch["image_grid_thw"].to(device)

                model_outputs = self.model(**fwd_kwargs)
                cls_logits = model_outputs["cls_logits"]
                cls_probs = torch.sigmoid(cls_logits.view(-1))
                preds_binary = (cls_probs > 0.5).long().cpu().tolist()
                golds_binary = batch["cls_labels"]
                for p, g in zip(preds_binary, golds_binary):
                    if p == int(g):
                        total_cls_correct += 1
                    total_cls_samples += 1

                del inputs, outs, texts, model_outputs, cls_logits, cls_probs, preds_binary
                torch.cuda.empty_cache()

        # 文本指标
        refs_for_bleu = [[r] for r in all_refs]
        rL  = self.rouge.compute(predictions=all_preds, references=all_refs, use_stemmer=True)["rougeL"]
        b1  = self.bleu.compute(predictions=all_preds, references=refs_for_bleu, max_order=1)["bleu"]
        b2  = self.bleu.compute(predictions=all_preds, references=refs_for_bleu, max_order=2)["bleu"]
        m   = self.meteor.compute(predictions=all_preds, references=all_refs)["meteor"]
        c,_ = self.cider.compute_score(
            {i:[r] for i,r in enumerate(all_refs)},
            {i:[p] for i,p in enumerate(all_preds)}
        )
        bs = self.bert.compute(predictions=all_preds, references=all_refs, lang="en", rescale_with_baseline=True)
        f1 = float(sum(bs["f1"]) / len(bs["f1"]))

        # 分类准确率
        cls_acc = total_cls_correct / total_cls_samples if total_cls_samples > 0 else 0.0

        print(f"\n**** Eval Results (Epoch {int(state.epoch)}) ****")
        print(f"Text Metrics ➜ ROUGE-L: {rL:.4f}, BLEU-1: {b1:.4f}, BLEU-2: {b2:.4f}, METEOR: {m:.4f}, CIDEr: {c:.4f}, BERT-F1: {f1:.4f}")
        print(f"Classification ➜ Accuracy: {cls_acc:.4f} ({total_cls_correct}/{total_cls_samples})\n")
        return control


# ──────────────────────────────────────────────────────────────
# 第七部分：Prompt 池（加入更多变体以增强多样性）
# ──────────────────────────────────────────────────────────────
prompts = [
    # 原有示例
    "Q: Does this image support the given statement? Explain your reasoning.",
    "Question: Is the sentence true for the picture shown? Give your reasons.",
    "Is the description accurate for this image? Explain why or why not.",
    "Please determine whether the image entails the sentence, and provide a detailed explanation.",
    "Analyze the image–text pair and explain if they match or conflict.",
    "Assess whether the statement is consistent with the image, and justify your answer.",
    "I want you to judge if this sentence correctly describes the image and explain your judgment.",
    "Your task: verify the sentence against the image and articulate the reasons behind your decision.",

    # 新增变体，覆盖不同问法
    "请判断下面这张图片是否经过篡改，并说明你的理由。",
    "请指出图像中哪些部分可能被篡改，并给出解释。",
    "图中是否存在异常编辑？请说明具体区域和原因。",
    "判断图像是否真实，若不真实，请标出篡改位置并解释。",
    "你认为这张图片描述是否准确？如果不准确，请指出不符之处并说明原因。",
    "请验证该陈述是否与所示图片匹配，并详细说明原因。",
    "请分析图片与文字描述是否一致，如不一致，请指出并解释。",
    "这张图和给定句子是否矛盾？请说明你的判断依据。"
]


# ──────────────────────────────────────────────────────────────
# 第八部分：实例化模型、构造 Trainer、开始训练
# ──────────────────────────────────────────────────────────────
wrapped_model = Qwen2VLWithClassifier(
    pretrained_qwen=base_model,
    num_labels=1,
    cls_loss_weight=1.0
)

training_args = Seq2SeqTrainingArguments(
    output_dir="output",
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=5e-5,
    warmup_steps=20,
    weight_decay=0.01,

    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",

    # ← 把下面这一行改成 False，让 Trainer 不再自动走一次 generate()
    predict_with_generate=False,

    remove_unused_columns=False,
)

trainer = Seq2SeqTrainer(
    model=wrapped_model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collate_fn_train,
    callbacks=[
        LowMemEvalCallback(
            eval_dataset=eval_ds,
            collate_fn_eval=collate_fn_eval,
            processor=processor,
            model=wrapped_model,
            batch_size=4,
            max_eval_samples=None
        )
    ],
)

# 这样 trainer.train() 只会跑我们定义的数据 collator + forward + callback 中自己做生成的逻辑
trainer.train()

# 完成训练后保存 LoRA & 分类头
wrapped_model.qwen.save_pretrained("./qwen2vl_instruct_loRA_checkpoint")
torch.save(wrapped_model.classifier.state_dict(), "./classifier_head.pt")


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


Resolving data files:   0%|          | 0/116 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/116 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/116 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/116 [00:00<?, ?it/s]

✔️ 从流式数据集中抽取：Train 80 | Eval 20


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✔️ Processor 与 Base Model（4-bit NF4 + LoRA）已加载完毕


[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
Some weights of RobertaMod

TypeError: 'NoneType' object is not iterable