In [1]:
import os
import base64
from io import BytesIO
from PIL import Image
import re
import numpy as np
from collections import defaultdict, Counter
import json
import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import AutoProcessor, TrainingArguments, Trainer, Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

  import pynvml  # type: ignore[import]
  from .autonotebook import tqdm as notebook_tqdm


In [12]:
###### 基础工具
def decode_base64_to_image(base64_str: str) -> Image.Image:
    """Decode a URL-safe base64 string into an RGB PIL image.

    Args:
        base64_str: Encoded image file bytes, using the URL-safe base64 alphabet.

    Returns:
        The decoded image, converted to "RGB" mode.
    """
    raw_bytes = base64.urlsafe_b64decode(base64_str)
    buffer = BytesIO(raw_bytes)
    decoded = Image.open(buffer)
    return decoded.convert("RGB")

def chinese_char_tokenize(text: str) -> list:
    """
    Tokenize text character by character, keeping ASCII alphanumeric runs whole.

    A maximal run of characters in [a-zA-Z0-9] forms a single token; every other
    character (CJK, punctuation, whitespace, ...) is emitted as its own token.
    Example: "V领包臀裙80后" -> ["V", "领", "包", "臀", "裙", "80", "后"]
    """
    # First alternative grabs a whole alphanumeric run, second grabs exactly one
    # character of anything else (a negated class also matches newlines).
    return re.findall(r'[a-zA-Z0-9]+|[^a-zA-Z0-9]', text)

In [13]:
class ECommerceCaptionDataset(Dataset):
    """Image-caption pairs for supervised fine-tuning of a chat-style vision-language model.

    Images are read from a headerless TSV with two columns (image id, base64 image)
    and captions from a JSONL file whose records carry ``image_id`` and a list of
    reference captions under ``text``. Only JSONL records whose id appears in the
    loaded TSV rows and that have at least one caption are kept.

    Each item is a tuple ``(inputs, all_captions)``: ``inputs`` is a dict of tensors
    produced by the processor plus ``labels``; ``all_captions`` is the full reference
    list for that image.

    Args:
        tsv_path: Path to the image TSV.
        jsonl_path: Path to the caption JSONL.
        processor: Multimodal processor exposing ``.tokenizer`` and being callable
            with ``images=`` / ``text=``.
        max_length: Stored on the instance only; nothing in this class passes it to
            the processor, so it does not limit the sequence length.
        nrows: Number of TSV rows to read (``None`` reads all). The default of 2
            loads just two images, i.e. a smoke-test sized dataset.
    """

    def __init__(self, tsv_path, jsonl_path, processor, max_length=512, nrows=2):
        self.processor = processor
        self.max_length = max_length

        # Read both columns as strings: without an explicit dtype pandas turns
        # numeric-looking ids into ints, which then never match JSONL string ids.
        df_img = pd.read_csv(
            tsv_path,
            sep='\t',
            header=None,
            names=['img_id', 'img_b64'],
            dtype=str,
            nrows=nrows,
        )
        self.img_dict = dict(zip(df_img['img_id'], df_img['img_b64']))

        # Each sample is (img_id, target_caption, all_captions).
        self.samples = []
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at end of file).
                    continue
                data = json.loads(line)
                img_id = str(data['image_id'])
                captions = data['text']
                if img_id not in self.img_dict or not captions:
                    continue
                # One caption is drawn once, at construction time, as the SFT target.
                # np.random.choice returns numpy.str_; store a plain str instead.
                target = str(np.random.choice(captions))
                self.samples.append((img_id, target, captions))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_id, target_caption, all_captions = self.samples[idx]
        image = decode_base64_to_image(self.img_dict[img_id])
        tokenizer = self.processor.tokenizer

        # Chat-format conversation: one user turn (image + instruction) followed by
        # the assistant turn holding the target caption.
        messages = [
            {"role": "user", "content": [
                {"type": "image"},
                {"type": "text", "text": "请为这件商品生成一段吸引人的描述。"}
            ]},
            {"role": "assistant", "content": target_caption}
        ]
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        # The same conversation without the answer, ending where generation would
        # start. Used only to locate the prompt/response boundary.
        prompt_only = tokenizer.apply_chat_template(
            messages[:1], tokenize=False, add_generation_prompt=True
        )

        inputs = self.processor(
            images=image,
            text=prompt,
            return_tensors="pt",
            padding=True,
            truncation=True
        )

        labels = inputs["input_ids"].clone()
        labels[labels == tokenizer.pad_token_id] = -100
        # Supervise only the assistant response: everything before it (system/user
        # text and the expanded image placeholder tokens) is masked out of the loss.
        # The response length is measured by tokenizing the response text alone,
        # which assumes the tokenizer splits at the prompt/response boundary.
        # If the template output is not a prefix of the full text, keep the
        # unmasked labels rather than guess.
        if prompt.startswith(prompt_only):
            response_ids = tokenizer(
                prompt[len(prompt_only):], add_special_tokens=False
            )["input_ids"]
            prompt_len = labels.shape[-1] - len(response_ids)
            if 0 < prompt_len < labels.shape[-1]:
                labels[..., :prompt_len] = -100
        inputs["labels"] = labels

        # Drop the leading batch dimension of size 1 wherever the processor added one.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        return inputs, all_captions  # all_captions is kept for a later SCST stage

In [3]:
# Load the model and its processor.
model_id = "/mnt/d/HuggingFaceModels/models--Qwen--Qwen2.5-VL-3B-Instruct/snapshots/66285546d2b821cf421d4f5eb2576359d3770cd3"
# Cap the image resolution to reduce memory use. The Qwen2-VL image processor is
# configured through a pixel budget (min_pixels / max_pixels), not `image_size`;
# 224 * 224 equals 64 * 28 * 28, i.e. at most 64 visual tokens per image.
processor = Qwen2_5_VLProcessor.from_pretrained(
    model_id,
    min_pixels=4 * 28 * 28,
    max_pixels=224 * 224,
)

# Model loading settings aimed at a 16 GB GPU.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_id,
    device_map="cuda",  # place the whole model on the CUDA device
    dtype=torch.float16,  # half-precision weights (same footprint as bfloat16, half of float32)
    low_cpu_mem_usage=True,
    # attn_implementation="flash_attention_2"  # enable if supported
)

# Trade compute for memory: recompute activations during the backward pass.
model.gradient_checkpointing_enable()

The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████| 2/2 [00:12<00:00,  6.02s/it]


In [4]:
# Wrap the base model with LoRA adapters.
# NOTE(review): this cell rebinds `model`; re-running it without reloading the base
# model would wrap an already wrapped model.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "v_proj"],  # adapt only the query and value projections
    r=8,                                  # adapter rank
    lora_alpha=16,                        # scaling numerator (alpha / r = 2)
    lora_dropout=0.1,
    bias="none",                          # leave bias terms frozen
)
model = get_peft_model(model, lora_config)
# Report how many parameters remain trainable after wrapping.
model.print_trainable_parameters()

trainable params: 1,843,200 || all params: 3,756,466,176 || trainable%: 0.0491


In [5]:
# Build the training dataset from the ECommerce-IC train split.
# base_dir ends with a path separator, so file names are appended directly.
base_dir = "/mnt/d/forCoding_data/Tianchi_MUGE/originalData/ECommerce-IC/"

train_dataset = ECommerceCaptionDataset(
    tsv_path=f"{base_dir}IC_train.tsv",
    jsonl_path=f"{base_dir}IC_train.jsonl",
    processor=processor,
)

In [6]:
# 数据整理函数，确保正确合并batch - 特别处理Qwen2.5-VL需要的图像特征
def data_collator(batch, pad_token_id=None):
    """Merge dataset items into one batch of model inputs.

    Args:
        batch: Sequence of ``(inputs, all_captions)`` tuples; only ``inputs``
            (a dict of tensors) is used here.
        pad_token_id: Value used to pad ``input_ids``. Defaults to the pad id of
            the module-level ``processor``'s tokenizer.

    Returns:
        Dict with the keys of the first item's ``inputs``:
        * ``input_ids`` / ``attention_mask`` / ``labels`` are right-padded to the
          longest sequence, using ``pad_token_id`` / 0 / -100 respectively so that
          padding is neither attended to nor counted in the loss.
        * ``pixel_values`` are concatenated along dim 0, so images with different
          numbers of patches can share a batch.
        * every other key is stacked along a new leading dimension.
    """
    if pad_token_id is None:
        pad_token_id = processor.tokenizer.pad_token_id

    inputs_list = [item[0] for item in batch]

    # Each padded sequence key needs its own fill value; reusing the token pad id
    # for the mask or the labels would mark padding as real tokens.
    pad_values = {
        "input_ids": pad_token_id,
        "attention_mask": 0,
        "labels": -100,
    }

    merged_inputs = {}
    for key in inputs_list[0].keys():
        tensors = [item[key] for item in inputs_list]
        if key in pad_values:
            merged_inputs[key] = torch.nn.utils.rnn.pad_sequence(
                tensors,
                batch_first=True,
                padding_value=pad_values[key],
            )
        elif key == "pixel_values":
            # Per-image patch matrices differ in their first dimension, so they
            # are joined end to end instead of stacked.
            merged_inputs[key] = torch.cat(tensors, dim=0)
        else:
            merged_inputs[key] = torch.stack(tensors)

    return merged_inputs

In [18]:
# Training arguments, with the batch size chosen for a 16 GB GPU.
training_args = TrainingArguments(
    output_dir="sft_qwen25vl_lora",
    # Schedule
    num_train_epochs=3,
    learning_rate=2e-4,
    # One sample per step, accumulated over 8 steps to make up for the small batch.
    per_device_train_batch_size=1,
    gradient_accumulation_steps=8,
    # Mixed-precision training to reduce memory use.
    fp16=True,
    # Logging and checkpointing
    logging_steps=10,
    save_strategy="epoch",
    # Data loading
    remove_unused_columns=False,
    dataloader_pin_memory=True,
)

In [19]:
# training_args = TrainingArguments(
#     output_dir="./sft_qwen25vl_lora",
#     per_device_train_batch_size=1,
#     gradient_accumulation_steps=8,
#     learning_rate=1e-4,          # LoRA 学习率可稍高
#     num_train_epochs=3,
#     logging_steps=20,
#     save_strategy="epoch",
#     eval_strategy="epoch",
#     fp16=False,                  # bfloat16 已由 bnb 指定
#     bf16=True,
#     optim="paged_adamw_8bit",    # 8-bit 优化器，省显存
#     lr_scheduler_type="cosine",
#     warmup_ratio=0.03,
#     remove_unused_columns=False,
#     dataloader_pin_memory=True,
#     report_to="none",
#     ddp_find_unused_parameters=False,
# )

In [None]:
# 自定义 Trainer 类来实现自定义 loss
class CustomTrainer(Trainer):
    """Trainer that computes the next-token cross-entropy loss itself."""

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the shifted causal language-modeling loss for one batch.

        Args:
            model: Model to run; its output must expose ``logits``.
            inputs: Dict of model inputs that includes ``labels``; positions set
                to -100 are ignored.
            return_outputs: When true, return ``(loss, outputs)``.
            num_items_in_batch: Number of supervised tokens across the whole
                accumulated batch, when the Trainer supplies it. If given, the
                loss is the summed token loss divided by this count; otherwise
                it is the mean over this batch's supervised tokens.

        Returns:
            The scalar loss, or ``(loss, outputs)`` when ``return_outputs`` is true.
        """
        labels = inputs["labels"]
        # Keep labels out of the forward pass so the model does not compute a
        # second, discarded loss; `inputs` itself is left untouched.
        model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
        outputs = model(**model_inputs)
        logits = outputs.get("logits")

        # Position t predicts token t + 1. Upcast so the softmax over the
        # vocabulary is evaluated in float32.
        shift_logits = logits[..., :-1, :].float().contiguous()
        shift_labels = labels[..., 1:].contiguous()

        reduction = "mean" if num_items_in_batch is None else "sum"
        loss = torch.nn.functional.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
            ignore_index=-100,
            reduction=reduction,
        )
        if num_items_in_batch is not None:
            # Normalise by the global token count so gradient accumulation yields
            # a true per-token mean instead of a sum of per-step means.
            if torch.is_tensor(num_items_in_batch):
                num_items_in_batch = num_items_in_batch.to(loss.device)
            loss = loss / num_items_in_batch
        return (loss, outputs) if return_outputs else loss

# Build the trainer with the custom collator defined in this notebook.
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

# Decide whether SFT can be skipped: either a "final" directory or any
# "checkpoint-*" directory under the output folder counts as an existing model.
model_dir = "sft_qwen25vl_lora"
final_model_dir = os.path.join(model_dir, "final")

checkpoint_exists = os.path.exists(final_model_dir)
if checkpoint_exists:
    print(f"检测到已存在final模型目录: {final_model_dir}，跳过SFT训练")
else:
    import glob
    checkpoint_dirs = glob.glob(os.path.join(model_dir, "checkpoint-*"))
    checkpoint_exists = bool(checkpoint_dirs)
    if checkpoint_exists:
        # Reports the first match only; glob does not guarantee any ordering.
        print(f"检测到已存在checkpoint目录: {checkpoint_dirs[0]}，跳过SFT训练")

# Train only when nothing was found on disk.
# NOTE(review): the skip branch only prints; no weights are loaded in this cell, so
# `model` keeps its freshly initialised adapters unless a later cell loads them.
if checkpoint_exists:
    print("模型已存在，直接加载模型进行后续操作")
else:
    trainer.train()

In [20]:
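## NOTE: the training cell above runs very slowly; the cause has not been identified yet.
## Hypothesis: quantization may not be set up properly (the cells above load the model in float16 and configure no quantization at all).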