In [9]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2,3"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch

from transformers import TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
from modelscope import snapshot_download
from modelscope import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
import re
from transformers.models.qwen2_vl.video_processing_qwen2_vl import Qwen2VLVideoProcessor
from qwen_omni_utils import process_mm_info
import json
from torch.utils.data import Dataset
from transformers.image_utils import SizeDict

def replace_time_tokens_with_percentage(text, time_map, duration):

    if not time_map or duration is None:
        return text

    def repl(match):
        token = match.group(0)
        if token not in time_map:
            return token
        t = time_map[token]
        pct = t / duration * 100.0
        return f"{pct:.1f}%"

    return re.sub(r"<s\d+>|<e\d+>", repl, text)


class OmniVideoConversationDataset(Dataset):
    def __init__(self, json_path: str, video_root: str):
        with open(json_path, "r") as f:
            raw_data = json.load(f)

        self.video_root = video_root
        self.samples = []

        for item in raw_data:
            video_id = item["id"]
            video_path = os.path.join(video_root, f"{video_id}.mp4")
            audio_path = video_path.replace(".mp4", ".wav")

            convs = item["conversations"]
            meta = item.get("meta", {})
            duration = meta.get("duration", None)
            time_map = meta.get("token", {})
            

            # 遍历 human / gpt 成对
            for i in range(0, len(convs) - 1, 2):
                if convs[i]["from"] != "human" or convs[i + 1]["from"] != "gpt":
                    continue

                self.samples.append({
                    "video_path": video_path,
                    "audio_path": audio_path,
                    "question": convs[i]["value"],
                    "answer": convs[i + 1]["value"],
                    "duration": duration,
                    "time_map": time_map,
                })

    def __len__(self):
        return len(self.samples)


    def _build_text(self, conversations):
        messages = []
        for turn in conversations:
            if turn["from"] == "human":
                role = "user"
            elif turn["from"] == "gpt":
                role = "assistant"
            else:
                continue

            messages.append({
                "role": role,
                "content": turn["value"]
            })

        return messages

    def __getitem__(self, idx):
        s = self.samples[idx]
        conversation = [
            {
                "role": "system",
                "content": [
                    {"type": "text", "text": "You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech."}
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "video", "video": s["video_path"]},
                    {"type": "audio", "audio": s["audio_path"]},
                    {"type": "text", "text": s["question"]},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": s["answer"]},
                ],
            },
        ]

        return {
            "conversation": conversation,
            "duration": s["duration"],
            "time_map": s["time_map"],
            }


class QwenOmniDataCollator:
    def __init__(self, processor):
        self.processor = processor
        self.tokenizer = processor.tokenizer
        self.video_cache = {}

    def __call__(self, features):
        texts = []
        videos = []
        audios = []

        for f in features:
            conversation = f["conversation"]

            for msg in conversation:
                if msg["role"] in ("user", "assistant"):
                    for ele in msg["content"]:
                        if ele.get("type") == "text":
                            ele["text"] = replace_time_tokens_with_percentage(
                                ele["text"],
                                f["time_map"],
                                f["duration"],
                            )

            # ---------- 1. 拼完整 prompt ----------
            full_text = self.processor.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=False,
            )

            # ---------- 2. 构造 labels（后缀 assistant） ----------
            texts.append(full_text)

            # ---------- 3. 多模态 ----------
            for msg in conversation:
                if msg["role"] == "user":
                    for ele in msg["content"]:
                        if ele.get("type") == "video":
                            ele["fps"] = 0.5
                            ele["max_frames"] = 50

                            video_path = ele["video"]

                            # ---------- 使用缓存 ----------
                            if video_path not in self.video_cache:
                                audios_, _, videos_ = process_mm_info(
                                    conversation, use_audio_in_video=False
                                )
                                self.video_cache[video_path] = {
                                    "video": videos_[0] if videos_ else None,
                                    "audio": audios_[0] if audios_ else None,
                                }

                            video_tensor = self.video_cache[video_path]["video"]
                            audio_tensor = self.video_cache[video_path]["audio"]


            # audios_, _, videos_ = process_mm_info(
            #     conversation, use_audio_in_video=True
            # )

            # videos.append(videos_[0] if videos_ else None)
            # audios.append(audios_[0] if audios_ else None)

            videos.append(video_tensor)
            audios.append(audio_tensor)


        # ---------- 4. 一次性 processor ----------
        batch = self.processor(
            text=texts,
            videos=videos,
            audio=audios,
            padding=True,
            return_tensors="pt",
            use_audio_in_video=False
        )


        labels = batch["input_ids"].clone()
        labels[:] = -100

        im_start_id = self.tokenizer.convert_tokens_to_ids("<|im_start|>")
        assistant_id = self.tokenizer.convert_tokens_to_ids("assistant")
        im_end_id = self.tokenizer.convert_tokens_to_ids("<|im_end|>")

        for b in range(labels.size(0)):
            input_ids = batch["input_ids"][b]

            start = None
            for i in range(len(input_ids) - 1):
                if input_ids[i] == im_start_id and input_ids[i + 1] == assistant_id:
                    start = i + 3
                    break

            if start is None:
                raise RuntimeError("No <|im_start|> assistant found")

            end = None
            for i in range(start, len(input_ids)):
                if input_ids[i] == im_end_id:
                    end = i
                    break

            if end is None:
                end = len(input_ids)

            labels[b, start:end] = input_ids[start:end]

        batch["labels"] = labels

        print(batch["labels"])

        print(batch["video_grid_thw"].shape) 

        print(batch["pixel_values_videos"].shape) 
        for k, v in batch.items(): 
            if isinstance(v, torch.Tensor): 
                print(k, v.shape, v.numel() * v.element_size() / 1024**3, "GB")

        return batch


train_dataset = OmniVideoConversationDataset(
    json_path="../../LongVALE/data/longvale-sft-bp-7k.json",
    video_root="../../LongVALE/raw_videos_train/video_train_7240/"
)


model_path = snapshot_download(
    'Qwen/Qwen2.5-Omni-3B',
    cache_dir="../../Qwen/cache/modelscope"
)

model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
    model_path,
    dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    use_safetensors=True
)

class FixedResQwen2VLVideoProcessor(Qwen2VLVideoProcessor):
    def _preprocess(
        self, videos, do_resize=True, size=None, interpolation=None, **kwargs
    ):
        # 固定分辨率
        fixed_size = SizeDict(height=224, width=224)
        for i, video in enumerate(videos):
            videos[i] = self.resize(video, size=fixed_size, interpolation=interpolation)
        return super()._preprocess(videos, do_resize=False, size=fixed_size, interpolation=interpolation, **kwargs)
    
video_processor = FixedResQwen2VLVideoProcessor.from_pretrained(model_path)

processor = Qwen2_5OmniProcessor.from_pretrained(
    model_path,
    video_processor=video_processor,
)


# 配置LoRA
config = LoraConfig(
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    # task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)


for name, param in model.named_parameters():
    if (
        "audio_tower" in name
        or "visual" in name
    ):
        param.requires_grad = False

model.gradient_checkpointing_enable()
model.enable_input_require_grads()
model.config.use_cache = False

model.print_trainable_parameters()

# 检查模型是否在训练模式
model.train()
print(f"Model is in training mode: {model.training}")

batch_size = 1

args = TrainingArguments(
    output_dir="./r_models",
    remove_unused_columns=False,
    eval_strategy="no",
    save_strategy="epoch",
    learning_rate=1e-4,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=2,
    bf16=True,
    fp16=False,
    num_train_epochs=2,
    logging_steps=5,
    load_best_model_at_end=False,
)

data_collator = QwenOmniDataCollator(processor)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)



2025-12-22 21:55:50,481 - modelscope - INFO - Target directory already exists, skipping creation.
Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}


Downloading Model from https://www.modelscope.cn to directory: ../../Qwen/cache/modelscope/Qwen/Qwen2.5-Omni-3B


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Detected kernel version 4.15.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


trainable params: 22,413,312 || all params: 4,734,622,720 || trainable%: 0.4734
Model is in training mode: True


In [10]:
trainer.train()

tensor([[ -100,  -100,  -100,  ..., 14360,  -100,  -100]])
torch.Size([1, 3])
torch.Size([3584, 1176])
input_ids torch.Size([1, 2466]) 1.837313175201416e-05 GB
attention_mask torch.Size([1, 2466]) 1.837313175201416e-05 GB
pixel_values_videos torch.Size([3584, 1176]) 0.0157012939453125 GB
video_grid_thw torch.Size([1, 3]) 2.2351741790771484e-08 GB
video_second_per_grid torch.Size([1]) 3.725290298461914e-09 GB
feature_attention_mask torch.Size([1, 30000]) 0.00011175870895385742 GB
input_features torch.Size([1, 128, 30000]) 0.01430511474609375 GB
labels torch.Size([1, 2466]) 1.837313175201416e-05 GB
tensor([[ -100,  -100,  -100,  ..., 14360,  -100,  -100]])
torch.Size([1, 3])
torch.Size([6400, 1176])
input_ids torch.Size([1, 4866]) 3.625452518463135e-05 GB
attention_mask torch.Size([1, 4866]) 3.625452518463135e-05 GB
pixel_values_videos torch.Size([6400, 1176]) 0.02803802490234375 GB
video_grid_thw torch.Size([1, 3]) 2.2351741790771484e-08 GB
video_second_per_grid torch.Size([1]) 3.725290

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss


[h264 @ 0xd39c1e40] mmco: unref short failure
[h264 @ 0xd39c1e40] mmco: unref short failure
[h264 @ 0xd39c1e40] mmco: unref short failure
[h264 @ 0xd39c1e40] mmco: unref short failure
[h264 @ 0xd39c1e40] mmco: unref short failure
[h264 @ 0xd39c1e40] mmco: unref short failure


tensor([[ -100,  -100,  -100,  ..., 14360,  -100,  -100]])
torch.Size([1, 3])
torch.Size([6400, 1176])
input_ids torch.Size([1, 7112]) 5.2988529205322266e-05 GB
attention_mask torch.Size([1, 7112]) 5.2988529205322266e-05 GB
pixel_values_videos torch.Size([6400, 1176]) 0.02803802490234375 GB
video_grid_thw torch.Size([1, 3]) 2.2351741790771484e-08 GB
video_second_per_grid torch.Size([1]) 3.725290298461914e-09 GB
feature_attention_mask torch.Size([1, 30000]) 0.00011175870895385742 GB
input_features torch.Size([1, 128, 30000]) 0.01430511474609375 GB
labels torch.Size([1, 7112]) 5.2988529205322266e-05 GB
tensor([[-100, -100, -100,  ...,   13, -100, -100]])
torch.Size([1, 3])
torch.Size([6400, 1176])
input_ids torch.Size([1, 7259]) 5.408376455307007e-05 GB
attention_mask torch.Size([1, 7259]) 5.408376455307007e-05 GB
pixel_values_videos torch.Size([6400, 1176]) 0.02803802490234375 GB
video_grid_thw torch.Size([1, 3]) 2.2351741790771484e-08 GB
video_second_per_grid torch.Size([1]) 3.72529029

OutOfMemoryError: CUDA out of memory. Tried to allocate 5.18 GiB. GPU 2 has a total capacity of 23.69 GiB of which 4.10 GiB is free. Process 112003 has 8.25 GiB memory in use. Including non-PyTorch memory, this process has 11.31 GiB memory in use. Of the allocated memory 10.96 GiB is allocated by PyTorch, and 41.44 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
from inspect import signature
print(signature(Qwen2_5OmniForConditionalGeneration.forward))
print(signature(Qwen2_5OmniThinkerForConditionalGeneration.forward))

In [None]:
def verify_lora_gradients(model):
    """验证LoRA梯度流"""
    model.train()
    
    # 创建测试数据
    test_inputs = {
        'input_ids': torch.randint(0, 1000, (1, 10)).cuda(),
        'attention_mask': torch.ones(1, 10).cuda(),
        'labels': torch.randint(0, 1000, (1, 10)).cuda(),
    }
    
    # 前向传播
    outputs = model(**test_inputs)
    loss = outputs.loss
    
    print(f"Loss: {loss.item():.4f}")
    print(f"Loss requires_grad: {loss.requires_grad}")
    
    if loss.requires_grad:
        # 反向传播
        loss.backward()
        
        # 检查梯度
        gradients_found = False
        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                gradients_found = True
                grad_norm = param.grad.norm().item()
                if 'lora' in name:
                    print(f"  ✓ LoRA梯度: {name} | norm={grad_norm:.6f}")
        
        if not gradients_found:
            print("  ⚠ 没有找到梯度")
    else:
        print("  ✗ Loss没有requires_grad")

verify_lora_gradients(model)

In [4]:
# 取训练集第一个样本
sample = train_dataset[0]

# 通过 data_collator 生成 batch
batch = data_collator([sample])

labels = batch["labels"]       # [batch_size, seq_len]
input_ids = batch["input_ids"] # [batch_size, seq_len]
tokenizer = processor.tokenizer

for i in range(labels.shape[0]):
    effective_ids = []
    ignored_ids = []
    
    for j in range(labels.shape[1]):
        token_id = input_ids[i, j].item()
        label = labels[i, j].item()
        if label != -100:
            effective_ids.append(token_id)
        else:
            ignored_ids.append(token_id)
    
    effective_text = tokenizer.decode(effective_ids)
    ignored_text = tokenizer.decode(ignored_ids)
    
    print(f"=== Sample {i} ===")
    print("有效 token 拼接文本:")
    print(effective_text)
    print("无效 token 拼接文本:")
    print(ignored_text)



StopIteration: 

=== Sample 0 ===
有效 token 拼接文本:
From <s4> to <e4>.

这里有个问题，<s4> 转化成什么时间
- 用真实时间？
    因为 Qwen 实际上能学习到真实的时间，但是评估的时候要算成相对时间吗？还是直接和数据集中的计算？但是这样和 LongVale 算是不是不太公平
- 和 LongVale 一样用相对时间百分比
    感觉这个是对应的固定 100 帧，如果 Qwen 也固定一百帧的话（不太确定这样能不能对应成绝对时间？），而且固定 100帧 目前显存也不够


In [None]:
print("\n=== 进行前向传播测试 ===")
with torch.set_grad_enabled(True):

    sample = next(iter(train_dataset))
    batch = data_collator([sample])
    

    device = model.device
    batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
    
    outputs = model(**batch)
    loss = outputs.loss
    
    print(f"Loss: {loss}")
    print(f"Loss requires_grad: {loss.requires_grad}")
    
    if loss.requires_grad:
        loss.backward()
        print("反向传播成功")
        

        for name, param in model.named_parameters():
            if param.requires_grad and param.grad is not None:
                print(f"参数 '{name}' 有梯度")
                break
        else:
            print("没有发现任何参数的梯度")
    else:
        print("loss没有requires_grad属性")

In [3]:
print(len(train_dataset))

75075


In [4]:
trainer.get_train_dataloader()

<torch.utils.data.dataloader.DataLoader at 0x7f0a304428f0>

In [5]:
sample = train_dataset[0]
print(sample.keys())
print(sample["conversation"])

dict_keys(['conversation', 'duration', 'time_map'])
[{'role': 'system', 'content': [{'type': 'text', 'text': 'You are Qwen, a virtual human developed by the Qwen Team, Alibaba Group, capable of perceiving auditory and visual inputs, as well as generating text and speech.'}]}, {'role': 'user', 'content': [{'type': 'video', 'video': '../../LongVALE/raw_videos_train/video_train_7240/cNj_TMPKa10.mp4'}, {'type': 'audio', 'audio': '../../LongVALE/raw_videos_train/video_train_7240/cNj_TMPKa10.wav'}, {'type': 'text', 'text': '<video>\nDuring which frames in the video can we observe someone lifts a freshly prepared taco, adorned with cilantro and a dollop of white sauce, from a wooden platter, a juicy lime wedge resting beside it, as a man excitedly remarks on an "East connection" amidst the lively chatter happening?'}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': 'From <s4> to <e4>.'}]}]


In [6]:
collator = QwenOmniDataCollator(processor)
batch = collator([sample])
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape)

StopIteration: 

In [8]:
train_loader = trainer.get_train_dataloader()
for b in train_loader:
    print(b.keys())
    break