In [2]:
from transformers import AutoProcessor

# Hugging Face Hub id of the checkpoint; the model-loading cell reuses this name.
model_name = "Qwen/Qwen3-VL-4B-Instruct"

# Processor for this checkpoint. Later cells use it to render the chat template,
# build the model inputs and decode generated token ids.
processor = AutoProcessor.from_pretrained(model_name)

In [None]:
from qwen_vl_utils import process_vision_info

# Single-turn request: one video clip followed by a text instruction.
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "video",
                "video": "test_video.mp4",
                # Pixel-budget settings read by qwen_vl_utils.
                # NOTE(review): exact semantics are not visible in this notebook - confirm.
                "min_pixels": 4 * 32 * 32,
                "max_pixels": 256 * 32 * 32,
                "total_pixels": 20480 * 32 * 32,
            },
            {"type": "text", "text": "Describe ball movements in great details."},
        ],
    },
]

# Render the conversation as a prompt string (not token ids) and append the
# generation prompt so the model continues with the assistant turn.
text = processor.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=False,
)

# Resolve the media referenced in `messages`; also request the extra processor
# kwargs and the per-video metadata.
images, videos, video_kwargs = process_vision_info(
    messages,
    image_patch_size=16,
    return_video_kwargs=True,
    return_video_metadata=True,
)

In [3]:
from transformers import Qwen3VLForConditionalGeneration

# Load the checkpoint named by `model_name` in bfloat16 and let `device_map="auto"`
# decide weight placement.
# NOTE(review): attn_implementation="flash_attention_2" presumably needs the
# flash-attn package to be installed - confirm for the target environment.
model = Qwen3VLForConditionalGeneration.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    device_map="auto",
    dtype="bfloat16",
)

# Dump the module tree for inspection.
print(model)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Qwen3VLForConditionalGeneration(
  (model): Qwen3VLModel(
    (visual): Qwen3VLVisionModel(
      (patch_embed): Qwen3VLVisionPatchEmbed(
        (proj): Conv3d(3, 1024, kernel_size=(2, 16, 16), stride=(2, 16, 16))
      )
      (pos_embed): Embedding(2304, 1024)
      (rotary_pos_emb): Qwen3VLVisionRotaryEmbedding()
      (blocks): ModuleList(
        (0-23): 24 x Qwen3VLVisionBlock(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): Qwen3VLVisionAttention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (mlp): Qwen3VLVisionMLP(
            (linear_fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (linear_fc2): Linear(in_features=4096, out_features=1024, bias=True)
            (act_fn): GELUTanh()
          )
        )
      )
 

In [None]:
# `videos` was requested together with metadata, so each entry is unpacked as a
# (video, metadata) pair; separate the pairs into two parallel lists.
if videos is None:
    video_metadatas = None
else:
    videos, video_metadatas = map(list, zip(*videos))

# qwen-vl-utils has already resized the images/videos, so pass do_resize=False
# to keep the processor from resizing them a second time.
inputs = processor(
    text=text,
    images=images,
    videos=videos,
    video_metadata=video_metadatas,
    do_resize=False,
    return_tensors="pt",
    **video_kwargs,
)
# Move the prepared batch to the device the model reports.
inputs = inputs.to(model.device)

In [None]:
# Inference: generate at most 128 new tokens for the prepared inputs.
generated_ids = model.generate(max_new_tokens=128, **inputs)

# For every row, drop as many leading ids as the corresponding prompt has, so
# only the newly generated part is decoded.
generated_ids_trimmed = [
    full_ids[len(prompt_ids):]
    for prompt_ids, full_ids in zip(inputs.input_ids, generated_ids)
]

output_text = processor.batch_decode(
    generated_ids_trimmed,
    clean_up_tokenization_spaces=False,
    skip_special_tokens=True,
)
print(output_text[0])

In [None]:
from pprint import pprint
# Show the first decoded response again, this time through pprint instead of print.
pprint(output_text[0])

In [None]:
import torch

def index_to_letter(index):
    """Map a zero-based choice index to its uppercase letter (0 -> "A", 25 -> "Z").

    Raises:
        AssertionError: if ``index`` is outside the range 0..25.
    """
    in_alphabet = 0 <= index < 26
    assert in_alphabet, f"Index must be between 0 and 25, got {index}"
    base = ord("A")
    return chr(base + index)

def letter_to_index(letter):
    """Map a single ASCII letter (either case) to its zero-based index ("A"/"a" -> 0).

    This is the inverse of ``index_to_letter``, so only the 26 ASCII letters are
    accepted. ``str.isalpha`` alone is not enough: it is also true for letters
    such as "é" (which would yield an index outside 0..25) and "ß" (whose
    uppercase form is two characters, making ``ord`` raise ``TypeError``).

    Raises:
        AssertionError: if ``letter`` is not exactly one ASCII alphabetic character.
    """
    assert len(letter) == 1 and letter.isascii() and letter.isalpha(), (
        f"Letter must be a single alphabetic character, got {letter}"
    )
    return ord(letter.upper()) - ord("A")

def get_assistant_mask(input_ids, header_ids=(151644, 77091, 198), end_token_id=None):
    """Return a boolean mask selecting the tokens that follow the assistant header.

    Args:
        input_ids: 2-D integer tensor of shape (batch, seq_len).
        header_ids: token ids of the assistant header to search for. The default
            is the sequence for "<|im_start|>assistant\\n" -> [151644, 77091, 198].
        end_token_id: optional id of the end-of-turn token. When given, the mask
            stops at (and includes) the first such token after the header, so
            anything after it (template trailer, padding, later turns) is left
            out. When ``None`` the mask runs to the end of the sequence.

    Returns:
        Bool tensor shaped like ``input_ids``. A row is ``True`` from the first
        position after the *first* header occurrence onward; rows without a
        header are all ``False``.

    Raises:
        ValueError: if ``header_ids`` is empty.
    """
    pattern = torch.as_tensor(header_ids, device=input_ids.device, dtype=input_ids.dtype)
    k = pattern.numel()
    if k == 0:
        raise ValueError("header_ids must not be empty")
    batch_size, seq_len = input_ids.shape

    # A sequence shorter than the header cannot contain it.
    if seq_len < k:
        return torch.zeros_like(input_ids, dtype=torch.bool)

    # Create all sliding windows of length k: shape (B, T-k+1, k)
    windows = input_ids.unfold(1, k, 1)
    # Match windows against the pattern
    matches = (windows == pattern).all(dim=-1)  # (B, T-k+1)

    any_match = matches.any(dim=1)
    # argmax gives the first True; it is meaningless for rows without a match,
    # so those rows get seq_len, which pushes the start past the last position.
    first_pos = torch.where(
        any_match,
        matches.int().argmax(dim=1),
        torch.full((batch_size,), seq_len, device=input_ids.device, dtype=torch.long),
    )

    start_after = first_pos + k
    positions = torch.arange(seq_len, device=input_ids.device).unsqueeze(0)
    mask = positions >= start_after.unsqueeze(1)

    if end_token_id is not None:
        # Only end tokens inside the assistant region count; earlier ones
        # (e.g. the end of a preceding turn) are ignored.
        is_end = ((input_ids == end_token_id) & mask).long()
        # Number of end tokens strictly before each position; keep positions
        # up to and including the first end token.
        ends_before = is_end.cumsum(dim=1) - is_end
        mask = mask & (ends_before == 0)

    return mask

def _build_video_qa_conversation(sample):
    """Build the two-turn chat (user: video + lettered question, assistant: answer letters) for one sample."""
    question = sample["question"]
    choices = sample["choices"]

    # Question, a blank line, then one "<letter>. <choice>" line per option.
    question_prompt = f"{question}\n\n" + "".join(
        f"{index_to_letter(i)}. {choice}\n" for i, choice in enumerate(choices)
    )

    ground_truth_indices = sample["ground_truth"]  # List[int] indices
    ground_truth_answer = "".join(index_to_letter(i) for i in ground_truth_indices)

    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "video",
                    "video": sample["video"],
                    "min_pixels": 4 * 32 * 32,
                    "max_pixels": 256 * 32 * 32,
                    "total_pixels": 20480 * 32 * 32,
                },
                {"type": "text", "text": question_prompt},
            ],
        },
        {
            "role": "assistant",
            "content": ground_truth_answer,
        },
    ]


def video_assistant_data_collator(samples):
    """Collate multiple-choice video QA samples into a supervised training batch.

    Relies on the module-level ``processor`` and ``process_vision_info``.

    Args:
        samples: iterable of mappings with the keys ``"question"``,
            ``"choices"`` (iterated to build the lettered options),
            ``"ground_truth"`` (indices of the correct choices) and ``"video"``.

    Returns:
        dict with everything the processor returned plus ``"labels"``: a copy of
        ``input_ids`` where every position outside the assistant answer, and
        every padding position, is set to -100.
    """
    conversations = [_build_video_qa_conversation(sample) for sample in samples]

    text = processor.apply_chat_template(conversations, tokenize=False, add_generation_prompt=False)
    images, videos, video_kwargs = process_vision_info(
        conversations,
        image_patch_size=16,
        return_video_kwargs=True,
        return_video_metadata=True,
    )

    # Videos come back paired with their metadata; split them into parallel lists.
    if videos is not None:
        videos, video_metadatas = zip(*videos)
        videos, video_metadatas = list(videos), list(video_metadatas)
    else:
        video_metadatas = None

    # qwen-vl-utils has already resized the images/videos, so pass
    # do_resize=False to avoid resizing them again in the processor.
    # padding=True is needed so that samples of different token length can be
    # stacked into one tensor with return_tensors="pt".
    inputs = processor(
        text=text,
        images=images,
        videos=videos,
        video_metadata=video_metadatas,
        return_tensors="pt",
        padding=True,
        do_resize=False,
        **video_kwargs,
    )

    input_ids = inputs.input_ids
    assistant_mask = get_assistant_mask(input_ids)
    # TODO: there's a bug here, the assistant mask doesn't mask the turn end token
    # NOTE(review): the mask covers every position after the assistant header, so
    # whatever the chat template emits after the answer is still supervised
    # (apart from the padding handled below) - confirm the intended behaviour.

    labels = input_ids.clone()
    labels[~assistant_mask] = -100

    # Padding added for batching must not contribute to the loss.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        labels[attention_mask == 0] = -100

    return {
        **inputs,
        "labels": labels,
    }
    