In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # 或你要用的 GPU 编号
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # 关闭 TF 日志
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"  # 禁用 TF 优化，避免影响

In [2]:
import os
import requests
from PIL import Image
from tqdm import tqdm
from torch.utils.data import Dataset
from transformers import AutoTokenizer
from datasets import load_dataset
import transformers
from peft import get_peft_model, PrefixTuningConfig, TaskType
from transformers import Trainer, TrainingArguments
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2.5-VL-7B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
# )
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

# # 配置 Prefix Tuning
# peft_config = PrefixTuningConfig(
#     task_type=TaskType.CAUSAL_LM,
#     inference_mode=False,
#     num_virtual_tokens=20,
#     prefix_projection=True
# )

# model = get_peft_model(model, peft_config)
# model.print_trainable_parameters()

In [4]:
from PIL import Image
from tqdm import tqdm
from transformers import AutoTokenizer
import numpy as np
from torchvision import transforms

def build_filtered_dataset(dataset_name='derek-thomas/ScienceQA',
                           split='train',
                           keep_grades='1-6'):
    """
    构建按年级和图像存在性过滤的数据集。

    参数:
        dataset_name (str): 数据集名称，例如 'derek-thomas/ScienceQA'。
        split (str): 数据分割，例如 'train', 'test', 'validation'。
        keep_grades (str or None): 筛选的年级段："1-6"、"7-12" 或 None 表示不过滤。

    返回:
        List[Dict]: 筛选后的样本列表。
    """

    def is_grade_allowed(grade_str):
        if keep_grades is None:
            return True
        try:
            grade_num = int(grade_str.replace("grade", ""))
            if keep_grades == "1-6":
                return 1 <= grade_num <= 6
            elif keep_grades == "7-12":
                return 7 <= grade_num <= 12
        except:
            return False
        return False



    data = load_dataset(dataset_name, split=split)
    dataset = []

    for i, sample in enumerate(data):
        try:
            if sample.get('question') is None:
                continue
            
            if sample.get("image", None) is None:
                continue

            if not is_grade_allowed(sample.get("grade", "")):
                continue

            solution = sample.get("solution", "")
            lecture = sample.get("lecture", "")
            solution_lecture = f"{solution}\n\n{lecture}".strip()
            
            image = sample["image"].convert("RGB")
            

            # image = np.array(image)
            # image = torch.tensor(image).permute(2, 0, 1)  # shape: (C, H, W)
            dataset.append({
                "image": image, 
                "question": sample["question"],
                "choices": sample["choices"],
                "hint": sample["hint"],
                "answer": sample["answer"],
                "solution_lecture": solution_lecture,
                'grade':sample["grade"],
            })
            
        except Exception as e:
            print(f"跳过第 {i} 个样本，错误：{e}")
            continue
    return dataset

dataset_train = build_filtered_dataset(split='train', keep_grades='1-6')
print(f"\n✅ 筛选后的样本数量: {len(dataset_train)}")


✅ 筛选后的样本数量: 4349


In [5]:
dataset_val = build_filtered_dataset(split='validation', keep_grades='1-6')
print(f"\n✅ 筛选后的样本数量: {len(dataset_val)}")


✅ 筛选后的样本数量: 1481


In [6]:
from random import choice

sample_1 = choice(dataset_train)
print(f"Question: {sample_1['question']}")
print(f"Choices: {sample_1['choices']}")
print(f"Hint: {sample_1['hint']}")
print(f"Grade: {sample_1['grade']}")
print(f"Answer: {sample_1['answer']}")
print(f"Explanation: {sample_1['solution_lecture']}")
print(f"Image type: {type(sample_1['image'])}")


Question: What is the name of the colony shown?
Choices: ['Connecticut', 'New Hampshire', 'Massachusetts', 'Wisconsin']
Hint: 
Grade: grade5
Answer: 0
Explanation: The colony is Connecticut.
Image type: <class 'PIL.Image.Image'>


In [7]:
# from PIL import Image
# import torch

# def build_training_sample(sample, processor, max_input_length=512, max_label_length=256, debug=False):
#     """
#     构建 prefix tuning 所需训练样本，适用于 Qwen2.5-VL。
    
#     参数:
#         sample (dict): 包含 image（PIL.Image 或路径）、question、choices、hint、answer、solution_lecture
#         processor: Qwen2.5-VL 对应的 AutoProcessor
#         debug (bool): 是否打印调试信息

#     返回:
#         dict: 包含 input_ids, attention_mask, labels, pixel_values
#     """
#     # 处理答案
#     if isinstance(sample["answer"], int):
#         answer_index = sample["answer"]
#     else:
#         answer_index = sample["choices"].index(sample["answer"])
#     answer_letter = chr(65 + answer_index)

#     # 构建问题文本
#     question_text = f"Question: {sample['question']}\nChoices:\n"
#     for idx, choice in enumerate(sample["choices"]):
#         question_text += f"{chr(65 + idx)}. {choice}\n"
#     if sample.get("hint"):
#         question_text += f"\nHint: {sample['hint']}\n"
#     question_text += (
#         "\nPlease select the correct answer. Then, explain your reasoning in detail. "
#         "Make sure your explanation is at least three sentences long, "
#         "refers to specific data from the image, and shows your step-by-step logic."
#     )

#     # 图像处理
#     image = sample.get("image", None)
#     if image is None:
#         raise ValueError("样本缺失图像字段 'image'")
#     if isinstance(image, str):
#         image = Image.open(image).convert("RGB")
#     elif isinstance(image, Image.Image):
#         image = image.convert("RGB")
#     else:
#         raise ValueError(f"image 类型错误，收到 {type(image)}")

#     if debug:
#         print(f"[DEBUG] SIZE: {image.size}, MODE: {image.mode}")

#     # 构建多模态消息
#     messages = [
#         {
#             "role": "user",
#             "content": [
#                 {"type": "image", "image": image},
#                 {"type": "text", "text": question_text}
#             ]
#         }
#     ]
#     prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

#     # 编码输入和标签
#     tokenizer = processor.tokenizer
#     prompt_ids = tokenizer(prompt, return_tensors="pt", max_length=max_input_length, padding="max_length", truncation=True)
#     label_text = f"Answer: {answer_letter}\nExplanation: {sample['solution_lecture']}"
#     label_ids = tokenizer(label_text, return_tensors="pt", max_length=max_label_length, padding="max_length", truncation=True)["input_ids"]
#     label_ids[label_ids == tokenizer.pad_token_id] = -100

#     # 图像转换为 tensor
#     processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
#     pixel_values = processor.image_processor(image, return_tensors="pt")["pixel_values"].squeeze(0)

#     if debug:
#         print(f"[DEBUG] pixel_values.shape: {pixel_values.shape}")

#     return {
#         "input_ids": prompt_ids["input_ids"].squeeze(0),
#         "attention_mask": prompt_ids["attention_mask"].squeeze(0),
#         "labels": label_ids.squeeze(0),
#         "pixel_values": pixel_values
# }



In [8]:
# from transformers import AutoProcessor
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
# train_sample = build_training_sample(sample, processor)

In [9]:
# train_sample = build_training_sample(sample, processor)

# print("keys:", train_sample.keys())
# print("input_ids shape:", train_sample["input_ids"].shape)
# print("attention_mask shape:", train_sample["attention_mask"].shape)
# print("labels shape:", train_sample["labels"].shape)
# print("pixel_values shape:", train_sample["pixel_values"].shape)

# 如需查看具体内容可取消注释
# print("input_ids:", train_sample["input_ids"])
# print("labels:", train_sample["labels"])

In [10]:
# from torch.utils.data import Dataset
# from PIL import Image
# import torch

# class QwenVLPrefixDataset(Dataset):
#     def __init__(self, data_list, processor, max_input_length=512, max_label_length=256, debug=False):
#         """
#         Qwen2.5-VL prefix tuning 数据集封装类。

#         参数:
#             data_list (List[Dict]): 每个样本是一个 dict，字段包括 image, question, choices, hint, answer, solution_lecture
#             processor: Qwen2.5-VL 对应的 AutoProcessor 实例
#             max_input_length (int): 输入文本最大长度
#             max_label_length (int): 输出标签最大长度
#             debug (bool): 是否启用调试信息
#         """
#         self.data_list = data_list
#         self.processor = processor
#         self.max_input_length = max_input_length
#         self.max_label_length = max_label_length
#         self.debug = debug

#     def __len__(self):
#         return len(self.data_list)

#     def __getitem__(self, idx):
#         sample = self.data_list[idx]
#         return self.build_training_sample(sample)

#     def build_training_sample(self, sample):
#         # 处理答案
#         if isinstance(sample["answer"], int):
#             answer_index = sample["answer"]
#         else:
#             answer_index = sample["choices"].index(sample["answer"])
#         answer_letter = chr(65 + answer_index)

#         # 构建问题文本
#         question_text = f"Question: {sample['question']}\nChoices:\n"
#         for idx, choice in enumerate(sample["choices"]):
#             question_text += f"{chr(65 + idx)}. {choice}\n"

#         if sample.get("hint"):
#             question_text += f"\nHint: {sample['hint']}\n"

#         question_text += (
#             "\nHere is a image: <img>\n"
#             "Please select the correct answer. Then, explain your reasoning in detail. "
#             "Make sure your explanation is at least three sentences long, "
#             "refers to specific data from the image, and shows your step-by-step logic."
#         )
        
#         chat = [
#             {"role": "user", "content": [
#                 {"type": "text", "text": question_text},
#                 {"type": "image","image": sample['image']}  # 这会被 processor 替换成 <img>
#                 ]}
#             ]
#         question_prompt = self.processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

#         # 图像对象
#         image = sample["image"]
#         image = image.resize((448, 448))
#         if not isinstance(image, Image.Image):
#             raise ValueError("image must be a PIL.Image.Image")

#         # 构建标签
#         label_text = f"Answer: {answer_letter}\nExplanation: {sample['solution_lecture']}"

#         # 使用 processor 统一处理图文（推荐方式）
#         inputs = self.processor(
#             text=question_prompt,
#             images=image,
#             return_tensors="pt",
#             padding="longest",  
#             truncation=False    
#             )
            
#         # 编码标签
#         tokenizer = self.processor.tokenizer
        
#         label_ids = tokenizer(
#             label_text,
#             return_tensors="pt",
#             # max_length=self.max_label_length,
#             padding=True,
#         )["input_ids"]
        
#         label_ids[label_ids == tokenizer.pad_token_id] = -100

#         if self.debug:
#             print(f"[DEBUG] Question:\n{question_text}")
#             print(f"[DEBUG] Label:\n{label_text}")
#             print(f"[DEBUG] pixel_values.shape: {inputs['pixel_values'].shape}")

#         return {
#             "input_ids": inputs["input_ids"].squeeze(0),
#             "attention_mask": inputs["attention_mask"].squeeze(0),
#             "labels": label_ids.squeeze(0),
#             "pixel_values": inputs["pixel_values"].squeeze(0),
#             "image_grid_thw": inputs["image_grid_thw"].squeeze(0)
#         }



In [11]:
from torch.utils.data import Dataset
from PIL import Image
import torch

class QwenVLPrefixDataset(Dataset):
    def __init__(self, data_list, processor, max_label_length=256, debug=False):
        self.data_list = data_list
        self.processor = processor
        self.max_label_length = max_label_length
        self.debug = debug

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        sample = self.data_list[idx]
        return self.build_training_sample(sample)

    def build_training_sample(self, sample):
        # 1. 处理答案
        if isinstance(sample["answer"], int):
            answer_index = sample["answer"]
        else:
            answer_index = sample["choices"].index(sample["answer"])
        answer_letter = chr(65 + answer_index)

        # 2. 构建问题内容
        question_text = f"Question: {sample['question']}\nChoices:\n"
        for idx, choice in enumerate(sample["choices"]):
            question_text += f"{chr(65 + idx)}. {choice}\n"
        if sample.get("hint"):
            question_text += f"\nHint: {sample['hint']}\n"
        question_text += (
            "\nHere is a image:\n"
            "Please select the correct answer. Then, explain your reasoning in detail. "
            "Make sure your explanation is at least three sentences long, "
            "refers to specific data from the image, and shows your step-by-step logic."
        )

        # 3. 构造 chat + 图像
        image = sample["image"]
        if not isinstance(image, Image.Image):
            raise ValueError("image must be a PIL.Image.Image")
        image = image.convert("RGB").resize((224, 224))

        chat = [
            {"role": "user", "content": [
                {"type": "text", "text": question_text},
                {"type": "image","image": image}
            ]}
        ]
        prompt = self.processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

        # 4. 编码图文输入（注意不要设置 truncation 或 max_length）
        inputs = self.processor(
            text=prompt,
            images=image,
            return_tensors="pt",
            padding="longest",   # ✅ 自动对齐
            truncation=False     # ✅ 不截断任何 token
        )

        # 5. 编码标签
        label_text = f"Answer: {answer_letter}\nExplanation: {sample['solution_lecture']}"
        tokenizer = self.processor.tokenizer
        
        input_len = inputs["input_ids"].shape[1]
        
        label_ids = tokenizer(
            label_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=input_len
        )["input_ids"]
        label_ids[label_ids == tokenizer.pad_token_id] = -100

        # 6. 返回项（确保 input_ids 和 attention_mask 等长）
        input_ids = inputs["input_ids"].squeeze(0)
        attention_mask = inputs["attention_mask"].squeeze(0)
        if input_ids.shape != attention_mask.shape:
            raise ValueError(f"input_ids shape {input_ids.shape} ≠ attention_mask shape {attention_mask.shape}")

        result = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": label_ids.squeeze(0),
            "pixel_values": inputs["pixel_values"].squeeze(0),
        }

        if "image_grid_thw" in inputs:
            result["image_grid_thw"] = inputs["image_grid_thw"].squeeze(0)

        if self.debug:
            print(f"[DEBUG] input_ids: {input_ids.shape}")
            print(f"[DEBUG] attention_mask: {attention_mask.shape}")

        return result


In [12]:
from transformers import AutoProcessor
from torch.utils.data import DataLoader

# 假设你已经有了 data_list，每个元素包含 image（路径或PIL对象）、question、choices、hint、answer、solution_lecture
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
train_dataset = QwenVLPrefixDataset(dataset_train, processor, debug=False)
dataloader = DataLoader(dataset_train, batch_size=6, shuffle=True)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [13]:
import torch
from torch.nn.utils.rnn import pad_sequence

def qwen_vl_collate_fn(batch):
    """
    可自动 pad 的 collate 函数，适配不同长度 input_ids / labels。
    """
    def pad_tensor_list(tensor_list, pad_value=0):
        return pad_sequence(tensor_list, batch_first=True, padding_value=pad_value)

    input_ids = [item["input_ids"] for item in batch]
    attention_mask = [item["attention_mask"] for item in batch]
    labels = [item["labels"] for item in batch]
    pixel_values = [item["pixel_values"] for item in batch]  # 通常 shape 一致，可直接 stack
    image_grid_thw = [item["image_grid_thw"] for item in batch]  # 通常 shape 一致，可直接 stack

    input_ids = pad_tensor_list(input_ids, pad_value=0)
    attention_mask = pad_tensor_list(attention_mask, pad_value=0)
    labels = pad_tensor_list(labels, pad_value=-100)  # 对 labels padding 用 -100 避免影响 loss
    pixel_values = torch.stack(pixel_values)
    image_grid_thw = torch.stack(image_grid_thw)
    return {
        "input_ids": input_ids,
        # "attention_mask": attention_mask,
        "labels": labels,
        "pixel_values": pixel_values,
        'image_grid_thw': image_grid_thw
        
    }

In [14]:
from torch.utils.data import DataLoader

# 正确使用 collate_fn
dataloader = DataLoader(train_dataset, batch_size=6, shuffle=False, collate_fn=qwen_vl_collate_fn)

# 获取第一个 batch
first_batch = next(iter(dataloader))

# 查看第一个样本的结构
print("==== 第一个 batch 的第一个样本 ====")
print(f"input_ids shape: {first_batch['input_ids'][0].shape}")
# print(f"attention_mask shape: {first_batch['attention_mask'][0].shape}")
print(f"labels shape: {first_batch['labels'][0].shape}")
print(f"pixel_values shape: {first_batch['pixel_values'][0].shape}")
print(f"image_grid_thw shape: {first_batch['image_grid_thw'][0].shape}")

# 可选：查看文本内容（需要 tokenizer）
tokenizer = processor.tokenizer
decoded_input = tokenizer.decode(first_batch['input_ids'][0], skip_special_tokens=True)
decoded_label = tokenizer.decode(
    [id for id in first_batch['labels'][0].tolist() if id != -100],
    skip_special_tokens=True
)

print("\n--- 解码后的 input_ids ---")
print(decoded_input)

print("\n--- 解码后的 labels ---")
print(decoded_label)


==== 第一个 batch 的第一个样本 ====
input_ids shape: torch.Size([227])
labels shape: torch.Size([227])
pixel_values shape: torch.Size([256, 1176])
image_grid_thw shape: torch.Size([3])

--- 解码后的 input_ids ---
system
You are a helpful assistant.
user
Question: Which of these states is farthest north?
Choices:
A. West Virginia
B. Louisiana
C. Arizona
D. Oklahoma

Here is a image:
Please select the correct answer. Then, explain your reasoning in detail. Make sure your explanation is at least three sentences long, refers to specific data from the image, and shows your step-by-step logic.
assistant
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

--- 解码后的 labels ---
Answer: A
Explanation: To find the answer, look at the compass rose. Look at which way the north arrow is pointing. West Virginia is farthest north.

Maps have four cardinal directions, or main directions. Those directions are north, south, east, and west.
A compass rose is a set of arrows that point to the cardinal di

In [15]:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct", device_map="auto", torch_dtype=torch.bfloat16
)
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

# # 配置 Prefix Tuning
# peft_config = PrefixTuningConfig(
#     task_type=TaskType.CAUSAL_LM,
#     inference_mode=False,
#     num_virtual_tokens=20,
#     prefix_projection=True
# )

# model = get_peft_model(model, peft_config)
# model.print_trainable_parameters()

Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.34s/it]


In [16]:
peft_config = PrefixTuningConfig(
    task_type='CAUSAL_LM',
    inference_mode=False,
    num_virtual_tokens=12,
    prefix_projection=True
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 4,805,888 || all params: 3,759,428,864 || trainable%: 0.1278


In [17]:
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):  # 屏蔽 num_items_in_batch
        if "num_items_in_batch" in kwargs:
            kwargs.pop("num_items_in_batch")
        outputs = model(**inputs)
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

In [18]:
import re
import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge import Rouge

def parse_output(output: str):
    """
    提取选择题答案和解释文本，容错支持 Answer: A, A., A: 等。
    返回 answer: int (0~3 对应 A-D), explanation: str
    """
    output = output.strip()
    answer_match = re.search(r"(?i)\banswer\s*[:\-]?\s*([A-D])\b", output)
    if not answer_match:
        answer_match = re.search(r"\b([A-D])[\.\:\-]", output)

    if answer_match:
        choice_char = answer_match.group(1).upper()
        answer = ord(choice_char) - ord("A")
    else:
        answer = -1

    explanation = ""
    if answer_match:
        idx = output.find(answer_match.group(0))
        if idx != -1:
            explanation = output[idx + len(answer_match.group(0)):].strip()

    return answer, explanation

def keyword_overlap(pred, ref):
    pred_keywords = set(pred.lower().split())
    ref_keywords = set(ref.lower().split())
    if not ref_keywords:
        return 0.0
    return len(pred_keywords & ref_keywords) / len(ref_keywords)

def compute_metrics(eval_preds, tokenizer):
    predictions = eval_preds.predictions
    label_ids = eval_preds.label_ids

    # 如果 predictions 是 tuple（例如包含 logits），取第一个作为 token ids
    if isinstance(predictions, tuple):
        predictions = predictions[0]

    # 如果是 logits（形如 [batch, seq_len, vocab_size]），则 argmax 得到 token ids
    if predictions.ndim == 3:
        predictions = predictions.argmax(-1)

    decoded_preds = []
    decoded_labels = []

    for pred_ids, label in zip(predictions, label_ids):
        label = [id for id in label if id != -100]
        decoded_pred = tokenizer.decode(pred_ids, skip_special_tokens=True)
        decoded_label = tokenizer.decode(label, skip_special_tokens=True)

        decoded_preds.append(decoded_pred.strip())
        decoded_labels.append(decoded_label.strip())

    # 打印第一个样本的预测和标签
    print("\n==== 示例输出 ====")
    print("预测：", decoded_preds[0])
    print("标签：", decoded_labels[0])

    smoothie = SmoothingFunction().method4
    bleu1_scores = []
    bleu4_scores = []
    rouge_l_scores = []
    keyword_overlaps = []
    choice_correct = []

    rouge = Rouge()

    for pred, label in zip(decoded_preds, decoded_labels):
        reference = label.split()
        candidate = pred.split()

        # BLEU-1 和 BLEU-4
        bleu1 = sentence_bleu([reference], candidate, weights=(1, 0, 0, 0), smoothing_function=smoothie)
        bleu4 = sentence_bleu([reference], candidate, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothie)
        bleu1_scores.append(bleu1)
        bleu4_scores.append(bleu4)

        # ROUGE-L
        try:
            rouge_l = rouge.get_scores(pred, label)[0]['rouge-l']['f']
        except ValueError:
            rouge_l = 0.0
        rouge_l_scores.append(rouge_l)

        # Keyword overlap
        keyword_acc = keyword_overlap(pred, label)
        keyword_overlaps.append(keyword_acc)

        # Choice accuracy
        pred_choice, _ = parse_output(pred)
        label_choice, _ = parse_output(label)
        is_correct = (pred_choice == label_choice) and (pred_choice != -1)
        choice_correct.append(int(is_correct))

    return {
        "BLEU-1": np.mean(bleu1_scores),
        "BLEU-4": np.mean(bleu4_scores),
        "ROUGE-L": np.mean(rouge_l_scores),
        "KeywordOverlap": np.mean(keyword_overlaps),
        "ChoiceAccuracy": np.mean(choice_correct),
    }



In [19]:
# from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

# def compute_metrics(eval_preds, tokenizer):
#     predictions = eval_preds.predictions
#     label_ids = eval_preds.label_ids

#     # 如果 predictions 是 tuple（例如包含 logits），取第一个作为 token ids
#     if isinstance(predictions, tuple):
#         predictions = predictions[0]

#     # 如果是 logits（形如 [batch, seq_len, vocab_size]），则 argmax 得到 token ids
#     if predictions.ndim == 3:
#         predictions = predictions.argmax(-1)

#     decoded_preds = []
#     decoded_labels = []

#     for pred_ids, label in zip(predictions, label_ids):
#         label = [id for id in label if id != -100]
#         decoded_pred = tokenizer.decode(pred_ids, skip_special_tokens=True)
#         decoded_label = tokenizer.decode(label, skip_special_tokens=True)

#         decoded_preds.append(decoded_pred.strip())
#         decoded_labels.append(decoded_label.strip())

#     # 打印第一个样本的预测和标签
#     print("\n==== 示例输出 ====")
#     print("预测：", decoded_preds[0])
#     print("标签：", decoded_labels[0])

#     smoothie = SmoothingFunction().method4
#     bleu_scores = []

#     for pred, label in zip(decoded_preds, decoded_labels):
#         reference = label.split()
#         candidate = pred.split()
#         bleu = sentence_bleu([reference], candidate, smoothing_function=smoothie)
#         bleu_scores.append(bleu)

#     average_bleu = sum(bleu_scores) / len(bleu_scores)
#     return {"bleu": average_bleu}




In [20]:
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
small_train_dataset=dataset_train[:100] 
small_val_dataset=dataset_val[:10]
train_dataset = QwenVLPrefixDataset(small_train_dataset, processor, debug=False)
val_dataset = QwenVLPrefixDataset(small_val_dataset, processor, debug=False)    
# dataloader = DataLoader(train_dataset, batch_size=6, shuffle=True)

In [21]:
from transformers import Trainer, TrainingArguments
import torch



training_args = TrainingArguments(
    output_dir="./qwen2.5vl-prefix",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=1, 
    dataloader_num_workers=0,
    eval_accumulation_steps=1,
    learning_rate=5e-4,
    num_train_epochs=100,
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    bf16=True,  # 如果使用的是支持 bfloat16 的 GPU，可改为 bf16=True
    gradient_accumulation_steps=4,
    remove_unused_columns=False
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    data_collator=qwen_vl_collate_fn,
    compute_metrics=lambda p: compute_metrics(p, tokenizer=processor.tokenizer)
)


Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
[codecarbon ERROR @ 20:11:16] Error: Another instance of codecarbon is probably running as we find `/tmp/.codecarbon.lock`. Turn off the other instance to be able to run this one or use `allow_multiple_runs` or delete the file. Exiting.


In [24]:
trainer.train()



Step,Training Loss


KeyboardInterrupt: 

In [23]:
# metrics = trainer.evaluate()
# print(metrics)