In [None]:
# %% [markdown]
# # 合并 Qwen2-VL 多任务消融实验与 Grad-CAM 可视化 Notebook
# 本 Notebook 包含：
# 1. 环境依赖安装
# 2. Google Drive HF 缓存配置（当前版本未包含该部分）
# 3. 库导入与模型/数据处理函数定义
# 4. 加载模型、Processor 与 Prompt 池
# 5. 加载数据集并转为 Dataset
# 6. 评估回调与指标定义
# 7. 训练 & 消融实验（评估由每个 epoch 结束时的回调完成）
# 8. Grad-CAM 可视化

# %% [markdown]
# ## 1. 环境依赖安装
# `%pip` installs into the environment of the running kernel (a bare `!pip` may hit another interpreter).
# The extras specifier is quoted so the shell cannot treat the brackets as a glob pattern.
# `nltk` is listed explicitly because the import cell uses `nltk.translate.bleu_score`.
# NOTE(review): versions are unpinned; pin them once a known-good combination is established.
%pip install --upgrade transformers datasets peft evaluate qwen-vl-utils bitsandbytes captum matplotlib pillow rouge_score "huggingface_hub[hf_xet]" bert_score nltk
%pip install git+https://github.com/salaniz/pycocoevalcap.git
# The `pytorch_grad_cam` module is distributed on PyPI under the name `grad-cam`.
%pip install grad-cam


# %% [markdown]
# ## 3. 库导入与模型/数据处理函数定义
import random
import torch
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from datasets import load_dataset, Dataset
from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, BitsAndBytesConfig, Seq2SeqTrainer, Seq2SeqTrainingArguments
from peft import LoraConfig, get_peft_model
from qwen_vl_utils import process_vision_info
import evaluate
from nltk.translate.bleu_score import SmoothingFunction
from pycocoevalcap.cider.cider import Cider
from bert_score import BERTScorer
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# LoRA + 多任务分类 Head
class Qwen2VLWithClassifier(torch.nn.Module):
    """Wrap a (LoRA-adapted) Qwen2-VL model and add a linear classification head.

    The wrapped model produces the generation loss/logits; the head maps one
    pooled image feature vector per image to ``num_labels`` logits trained with
    ``BCEWithLogitsLoss``.

    Args:
        pretrained_qwen: model exposing ``config.hidden_size``, ``device`` and,
            somewhere in its module tree, a ``get_image_features`` method.
        num_labels: output size of the classification head.
        cls_loss_weight: multiplier applied to the classification loss when it
            is added to the generation loss.
    """

    def __init__(self, pretrained_qwen, num_labels=1, cls_loss_weight=1.0):
        super().__init__()
        self.qwen = pretrained_qwen
        self.cls_loss_weight = cls_loss_weight
        hidden_size = self.qwen.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, num_labels)
        self.cls_loss_fn = torch.nn.BCEWithLogitsLoss()

    def _pooled_image_features(self, pixel_values, image_grid_thw=None):
        """Return image features for the head, mean-pooled to one row per image when possible."""
        encoder = next((m for m in self.qwen.modules() if hasattr(m, "get_image_features")), None)
        if encoder is None:
            raise AttributeError("wrapped model exposes no `get_image_features` method")

        pixel_values = pixel_values.to(self.qwen.device)
        if image_grid_thw is None:
            feats = encoder.get_image_features(pixel_values)
        else:
            feats = encoder.get_image_features(
                pixel_values, image_grid_thw=image_grid_thw.to(self.qwen.device)
            )

        if torch.is_tensor(feats):
            if image_grid_thw is None or feats.dim() != 2:
                # Nothing to split on: hand the features to the head unchanged.
                return feats
            # A flat (total_tokens, hidden) tensor: recover each image's share of
            # rows from its share of patches (the merge factor cancels out).
            patches_per_image = image_grid_thw.prod(-1)
            rows = (patches_per_image * feats.shape[0] // patches_per_image.sum()).tolist()
            feats = torch.split(feats, rows)
        # One tensor per image -> one mean-pooled vector per image.
        return torch.stack([f.mean(dim=0) for f in feats])

    def forward(self, pixel_values=None, input_ids=None, attention_mask=None, labels=None, cls_labels=None, output_attentions=True, return_dict=True, **kwargs):
        """Run generation and classification jointly.

        ``image_grid_thw`` (if present in ``kwargs``) is forwarded to the wrapped
        model and to the image encoder; other extra keyword arguments are dropped.

        Returns:
            dict with ``loss`` (generation loss, classification loss, their
            weighted sum, or ``None``), ``gen_logits`` and ``cls_logits``
            (``None`` when no ``pixel_values`` were given).
        """
        # The vision tower needs the patch grid of every image, so keep it
        # instead of discarding it with the other extras.
        image_grid_thw = kwargs.pop("image_grid_thw", None)
        # Drop extra arguments the Trainer / processor may inject.
        for k in ("num_items_in_batch", "pixel_mask", "use_cache"):
            kwargs.pop(k, None)

        qwen_kwargs = {}
        if image_grid_thw is not None:
            qwen_kwargs["image_grid_thw"] = image_grid_thw
        # NOTE(review): output_attentions defaults to True (kept for interface
        # compatibility) although the attentions are not returned; pass False to save memory.
        gen_out = self.qwen(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_attentions=output_attentions,
            return_dict=return_dict,
            **qwen_kwargs,
        )

        # Classification head (the image encoder is run a second time here).
        cls_logits = None
        if pixel_values is not None:
            pooled = self._pooled_image_features(pixel_values, image_grid_thw)
            head_weight = self.classifier.weight
            # Encoder output may be half precision / on another device than the head.
            cls_logits = self.classifier(pooled.to(device=head_weight.device, dtype=head_weight.dtype))

        # Joint loss.
        gen_loss = gen_out.loss if labels is not None else None
        cls_loss = None
        if cls_labels is not None and cls_logits is not None:
            targets = cls_labels.view(-1).to(device=cls_logits.device, dtype=cls_logits.dtype)
            cls_loss = self.cls_loss_fn(cls_logits.view(-1), targets)

        total_loss = None
        if gen_loss is not None and cls_loss is not None:
            total_loss = gen_loss + self.cls_loss_weight * cls_loss
        elif gen_loss is not None:
            total_loss = gen_loss
        elif cls_loss is not None:
            total_loss = cls_loss
        return {"loss": total_loss, "gen_logits": gen_out.logits, "cls_logits": cls_logits}

# Collate 函数
def collate_fn_train(examples):
    """Build a training batch of chat-formatted image/prompt/caption examples.

    Reads the processor and the prompt pool from the function attributes
    ``collate_fn_train.processor`` and ``collate_fn_train.prompts``, which must
    be assigned before the first call.

    Args:
        examples: iterable of dicts with keys ``image`` (PIL image, a dict with a
            ``path`` entry, or something ``Image.open`` accepts), ``caption``
            and ``cls_label``.

    Returns:
        The processor output with two extra entries: ``labels`` (a copy of
        ``input_ids`` with vision-marker, image-placeholder and padding tokens
        set to -100) and ``cls_labels`` (float tensor).
    """
    processor = collate_fn_train.processor
    texts, imgs, cls_lbs = [], [], []
    for ex in examples:
        prompt = random.choice(collate_fn_train.prompts)
        img = ex["image"]
        if isinstance(img, dict):
            img = Image.open(img["path"]).convert("RGB")
        elif not isinstance(img, Image.Image):
            img = Image.open(img).convert("RGB")
        msgs = [
            {"role": "user", "content": [{"type": "image", "image": img}, {"type": "text", "text": prompt}]},
            {"role": "assistant", "content": ex["caption"]},
        ]
        txt = processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=False)
        texts.append(txt)
        imgs.append(img)
        cls_lbs.append(ex["cls_label"])

    batch = processor(text=texts, images=imgs, padding=True, return_tensors="pt")
    labels = batch.input_ids.clone()

    tok = processor.tokenizer
    # The placeholder token is "<|image_pad|>" with no trailing newline; with the
    # newline the lookup fails and image positions stay in the LM loss.
    ignore_ids = tok.convert_tokens_to_ids(["<|vision_start|>", "<|vision_end|>", "<|image_pad|>"]) + [tok.pad_token_id]
    # A tokenizer may return None for a token/pad id it does not define; skip those.
    ignore_ids = [i for i in ignore_ids if i is not None]
    if ignore_ids:
        ignore = torch.tensor(ignore_ids, dtype=labels.dtype, device=labels.device)
        labels[torch.isin(labels, ignore)] = -100

    batch["labels"] = labels
    batch["cls_labels"] = torch.tensor(cls_lbs, dtype=torch.float)
    return batch

def collate_fn_eval(examples):
    """Build a generation-time batch: image + prompt only, ending in the generation prompt.

    Uses the processor and prompt pool stored on ``collate_fn_train``. The
    reference captions and classification labels are attached to the returned
    batch as plain Python lists under ``captions`` and ``cls_labels``.
    """
    proc = collate_fn_train.processor

    def _prepare(sample):
        # Draw the prompt first, then resolve the image, then render the chat text.
        question = random.choice(collate_fn_train.prompts)
        picture = sample["image"]
        if isinstance(picture, dict):
            picture = Image.open(picture["path"]).convert("RGB")
        elif not isinstance(picture, Image.Image):
            picture = Image.open(picture).convert("RGB")
        chat = [{"role": "user", "content": [{"type": "image", "image": picture}, {"type": "text", "text": question}]}]
        rendered = proc.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        return rendered, picture, sample["caption"], sample["cls_label"]

    rows = [_prepare(sample) for sample in examples]
    texts = [row[0] for row in rows]
    imgs = [row[1] for row in rows]
    caps = [row[2] for row in rows]
    cls_lbs = [row[3] for row in rows]

    batch = proc(text=texts, images=imgs, padding=True, return_tensors="pt")
    batch["captions"] = caps
    batch["cls_labels"] = cls_lbs
    return batch

# %% [markdown]
# ## 4. 加载模型、Processor 与 Prompt 池
# Checkpoint shared by the processor and the model.
repo_id = "Qwen/Qwen2-VL-2B-Instruct"
processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)

# 4-bit NF4 quantisation with double quantisation; compute dtype is bfloat16.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = Qwen2VLForConditionalGeneration.from_pretrained(
    repo_id,
    trust_remote_code=True,
    quantization_config=bnb_cfg,
    device_map="auto",
)

# LoRA adapters on modules named q_proj / k_proj / v_proj / o_proj.
peft_cfg = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
base_model = get_peft_model(base_model, peft_cfg)

# Add the classification head on top of the adapted model.
wrapped_model = Qwen2VLWithClassifier(pretrained_qwen=base_model)

# Prompt 池
# Both collate functions read the processor and the prompt pool from attributes
# stored on `collate_fn_train`, so these must be assigned before any batch is built.
collate_fn_train.processor = processor
collate_fn_train.prompts = [
    "Q: Does this image support the statement? Explain.",
    "Is the description accurate? Why or why not?",
    # ... more prompt variants can be added here
]

# %% [markdown]
# ## 5. 加载数据集 & 转为 Dataset
from datasets import load_dataset
import itertools
# Open the train and dev splits in streaming mode, in that order.
raw_train, raw_dev = (
    load_dataset("J1mb0o/e-snli-ve", split=split_name, streaming=True)
    for split_name in ("train", "dev")
)

def preprocess(ex):
    """Map a raw record to the fields the collate functions read.

    Keeps ``image``, exposes ``hypothesis`` as ``caption`` and turns
    ``gold_label`` into a float ``cls_label``: 1.0 when it equals 0, else 0.0.
    NOTE(review): presumably label 0 is the positive class -- confirm against the dataset card.
    """
    record = {"image": ex["image"]}
    record["caption"] = ex["hypothesis"]
    if ex["gold_label"] == 0:
        record["cls_label"] = 1.0
    else:
        record["cls_label"] = 0.0
    return record
# Take the first 80 training and 20 dev records from the streams, preprocess
# them, and materialise each list as an in-memory Dataset.
train_list = list(map(preprocess, itertools.islice(raw_train, 80)))
eval_list = list(map(preprocess, itertools.islice(raw_dev, 20)))
train_ds, eval_ds = Dataset.from_list(train_list), Dataset.from_list(eval_list)

# %% [markdown]
# ## 6. 定义评估回调 & 指标
# Metric objects, created once at module level.
_smooth = SmoothingFunction().method1
rouge, bleu, meteor = (
    evaluate.load(metric_name) for metric_name in ("rouge", "bleu", "meteor")
)
cider = Cider()
berts = BERTScorer(
    lang="en",
    rescale_with_baseline=True,
)

from transformers import TrainerCallback
class LowMemEvalCallback(TrainerCallback):
    """At the end of every epoch, generate captions for ``eval_ds`` and print ROUGE-L, BLEU and accuracy.

    Relies on the module-level ``wrapped_model``, ``eval_ds``, ``processor``,
    ``collate_fn_eval``, ``rouge`` and ``bleu`` objects.
    """

    def on_epoch_end(self, args, state, control, **kwargs):
        device = wrapped_model.qwen.device
        all_pred, all_ref = [], []
        cls_corr, cls_tot = 0, 0
        loader = torch.utils.data.DataLoader(eval_ds, batch_size=4, collate_fn=collate_fn_eval)

        # Remember the current mode so the model is handed back unchanged.
        was_training = wrapped_model.training
        wrapped_model.eval()
        try:
            with torch.no_grad():
                for batch in loader:
                    # image_grid_thw travels with pixel_values for both generation and classification.
                    inputs = {
                        k: batch[k].to(device)
                        for k in ("pixel_values", "input_ids", "attention_mask", "image_grid_thw")
                        if k in batch
                    }
                    outs = wrapped_model.qwen.generate(**inputs, max_new_tokens=100, num_beams=4)
                    # Keep only the newly generated tokens (drop the prompt part).
                    txts = processor.batch_decode(outs[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
                    all_pred.extend(txts)
                    all_ref.extend(batch["captions"])

                    # Classification: attentions are not needed here, so do not materialise them.
                    fwd = wrapped_model(**inputs, cls_labels=None, output_attentions=False)
                    preds = (torch.sigmoid(fwd["cls_logits"].view(-1)) > 0.5).cpu().tolist()
                    for pred, gold in zip(preds, batch["cls_labels"]):
                        cls_corr += int(int(pred) == int(gold))
                        cls_tot += 1
        finally:
            if was_training:
                wrapped_model.train()

        if not all_pred or cls_tot == 0:
            # Nothing was evaluated; avoid dividing by zero below.
            print(f"Epoch {int(state.epoch)}: no evaluation examples, metrics skipped")
            return control

        # Print the metrics.
        rouge_l = rouge.compute(predictions=all_pred, references=all_ref, use_stemmer=True)['rougeL']
        bleu_score = bleu.compute(predictions=all_pred, references=[[r] for r in all_ref])['bleu']
        acc = cls_corr / cls_tot
        print(f"Epoch {int(state.epoch)}: ROUGE-L {rouge_l:.4f}, BLEU {bleu_score:.4f}, ACC {acc:.4f}")
        return control

# %% [markdown]
# ## 7. 训练 & 消融实验
# Log, evaluate and checkpoint once per epoch; dataset columns are left
# untouched because the collate function reads the raw fields.
training_args = Seq2SeqTrainingArguments(
    output_dir="./output",
    num_train_epochs=5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=5e-5,
    warmup_steps=20,
    weight_decay=0.01,
    logging_strategy="epoch",
    eval_strategy="epoch",
    save_strategy="epoch",
    remove_unused_columns=False,
    predict_with_generate=False,
)

# The per-epoch generation metrics come from the callback.
trainer = Seq2SeqTrainer(
    model=wrapped_model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    data_collator=collate_fn_train,
    callbacks=[LowMemEvalCallback()],
)
trainer.train()

# %% [markdown]
# ## 8. Grad-CAM 可视化
# 取一张样例图
# Pick one evaluation example and make sure the image is an RGB PIL image.
ex = eval_ds[0]
img = ex['image']
if isinstance(img, dict):
    img = Image.open(img['path']).convert('RGB')
elif not isinstance(img, Image.Image):
    img = Image.open(img).convert('RGB')
original_size = img.size

prompt = random.choice(collate_fn_train.prompts)
msgs = [{"role":"user","content":[{"type":"image","image":img},{"type":"text","text":prompt}]}]
inputs = processor(text=processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True), images=[img], padding=True, return_tensors='pt')

# Hook the patch-embedding projection of the vision tower.
wrapped = wrapped_model.qwen
vision = wrapped.model.visual
# `use_cuda` is no longer passed: recent pytorch-grad-cam releases dropped that
# constructor argument, and the input tensor is moved to the model's device below.
cam = GradCAM(model=wrapped, target_layers=[vision.patch_embed.proj])

# NOTE(review): GradCAM appears to call `model(input_tensor)` with the tensor as the only
# positional argument, while this model is otherwise called with input_ids/attention_mask/
# image_grid_thw as well. This cell likely needs a thin wrapper module (pixel_values -> logits),
# explicit `targets`, and a `reshape_transform` for the patch layout -- confirm before relying on it.
grayscale_cam = cam(input_tensor=inputs['pixel_values'].to(wrapped.device))[0]
heatmap = np.uint8(255 * grayscale_cam)
heatmap = Image.fromarray(heatmap).resize(original_size)

# Show the original image next to the heat-map overlay.
fig, (ax_orig, ax_cam) = plt.subplots(1, 2, figsize=(12, 6))
ax_orig.imshow(img)
ax_orig.axis('off')
ax_orig.set_title('Original')
ax_cam.imshow(img)
ax_cam.imshow(np.array(heatmap), cmap='jet', alpha=0.5)
ax_cam.axis('off')
ax_cam.set_title('Grad-CAM')
plt.show()


Collecting transformers
  Downloading transformers-4.53.2-py3-none-any.whl.metadata (40 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.9/40.9 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Collecting datasets
  Downloading datasets-4.0.0-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Collecting qwen-vl-utils
  Downloading qwen_vl_utils-0.0.11-py3-none-any.whl.metadata (6.3 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.46.1-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting captum
  Downloading captum-0.8.0-py3-none-any.whl.metadata (26 kB)
Collecting matplotlib
  Downloading matplotlib-3.10.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting pillow
  Downloading pillow-11.3.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (9.0 kB)
Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing m

Collecting git+https://github.com/salaniz/pycocoevalcap.git
  Cloning https://github.com/salaniz/pycocoevalcap.git to /tmp/pip-req-build-5hzso47h
  Running command git clone --filter=blob:none --quiet https://github.com/salaniz/pycocoevalcap.git /tmp/pip-req-build-5hzso47h
^C
