In [1]:
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from transformers.modeling_outputs import BaseModelOutputWithPast, ModelOutput

In [3]:
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

@dataclass
class Qwen2_5_VLCausalLMOutputWithPast(ModelOutput):
    """
    Base class for Qwen2_5_VL causal language model (or autoregressive) outputs.

    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`)

            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
            The rope index difference between sequence length and multimodal rope.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    past_key_values: Optional[List[torch.FloatTensor]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    rope_deltas: Optional[torch.LongTensor] = None

In [4]:

from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
base_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    cache_dir=HF_HOME_DIR,
)


You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  4.38it/s]


In [5]:
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
# The default range for the number of visual tokens per image in the model is 4-16384.
# You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)

# messages = [
#     {
#         "role": "user",
#         "content": [
#             {
#                 "type": "image",
#                 "image": "latex.png",
#                 "min_pixels": 4*28*28,
#                 "max_pixels": 4*28*28,
#             },
#             {"type": "text", "text": "Describe this image."},
#         ],
#     }
# ]

# # Preparation for inference
# text = processor.apply_chat_template(
#     messages, tokenize=False, add_generation_prompt=True
# )
# image_inputs, video_inputs = process_vision_info(messages)
# inputs = processor(
#     text=[text],
#     images=image_inputs,
#     videos=video_inputs,
#     padding=True,
#     return_tensors="pt",
# )
# inputs = inputs.to("cuda")

# # # Inference: Generation of the output
# generated_ids = model.generate(**inputs, max_new_tokens=512)
# generated_ids_trimmed = [
#     out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
# ]
# output_text = processor.batch_decode(
#     generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
# )
# print(output_text)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.50, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [6]:
N_IMG_TOKENS = 512

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Can you describe these images?"},
            {
                "type": "image",
                "image": "latex.png",
                "min_pixels": N_IMG_TOKENS*28*28,
                "max_pixels": N_IMG_TOKENS*28*28,
            },
            {
                "type": "image",
                "image": "latex_2.png",
                "min_pixels": N_IMG_TOKENS*28*28,
                "max_pixels": N_IMG_TOKENS*28*28,
            },
            {"type": "text", "text": "These are from my notes."},
        ],
    },
]

In [7]:
text = processor.apply_chat_template(
    messages, add_generation_prompt=True, add_vision_id=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)

inputs = inputs.to('cuda')

In [8]:
sequence = inputs['input_ids'].tolist()[0]

decoded_text = processor.tokenizer.decode(sequence)
for i in range(len(sequence)):
    print('token:', sequence[i], 'decoded:', processor.tokenizer.decode([sequence[i]]))


token: 151644 decoded: <|im_start|>
token: 8948 decoded: system
token: 198 decoded: 

token: 2610 decoded: You
token: 525 decoded:  are
token: 264 decoded:  a
token: 10950 decoded:  helpful
token: 17847 decoded:  assistant
token: 13 decoded: .
token: 151645 decoded: <|im_end|>
token: 198 decoded: 

token: 151644 decoded: <|im_start|>
token: 872 decoded: user
token: 198 decoded: 

token: 6713 decoded: Can
token: 498 decoded:  you
token: 7512 decoded:  describe
token: 1493 decoded:  these
token: 5335 decoded:  images
token: 30 decoded: ?
token: 151652 decoded: <|vision_start|>
token: 151655 decoded: <|image_pad|>
token: 151655 decoded: <|image_pad|>
token: 151655 decoded: <|image_pad|>
token: 151655 decoded: <|image_pad|>
token: 151655 decoded: <|image_pad|>
token: 151655 decoded: <|image_pad|>
token: 151655 decoded: <|image_pad|>
token: 151655 decoded: <|image_pad|>
token: 151655 decoded: <|image_pad|>
token: 151655 decoded: <|image_pad|>
token: 151655 decoded: <|image_pad|>
token: 1516

In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
from transformers import Qwen2_5_VLForConditionalGeneration
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss

class AnyToAnyQwen2_5_VL(Qwen2_5_VLForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        # Number of vision tokens to generate after <|vision_start|>
        self.n_img_tokens = 10  # Assuming 10 image embeddings as mentioned in requirements
        # Special token IDs
        self.vision_start_token_id = 151652  # <|vision_start|>
        self.vision_end_token_id = 151653    # <|vision_end|>
        # Setup vision head
        self.setup_vision_head()
        # Initialize weights
        self.post_init()

    def setup_vision_head(self):
        """Initialize the vision head that predicts vision embeddings."""
        self.vision_head = nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=False)

    def generate_labels(self, input_ids, vision_positions=None):
        """
        Generate labels for language modeling (same as input_ids for next-token prediction).
        Masks out vision positions with -100 so they don't contribute to text loss.
        """
        labels = input_ids.clone()
        
        # Mask out vision positions
        if vision_positions is not None:
            labels = labels.masked_fill(vision_positions, -100)
        
        print(f"labels: {labels}, shape: {labels.shape}")
        print(f"vision_positions: {vision_positions}, shape: {vision_positions.shape}")
        
        return labels
        
    def identify_vision_positions(self, input_ids):
        """
        Create a boolean mask identifying positions where vision tokens should be predicted.
        These positions include:
        1. Image token positions (marked by self.config.image_token_id)
        2. Positions between vision_start_token_id and vision_end_token_id
        
        Args:
            input_ids: Tensor of shape (batch_size, seq_length)
            
        Returns:
            is_vision_position: Boolean mask of shape (batch_size, seq_length)
        """
        # First identify positions with image_token_id
        is_vision_position = (input_ids == self.config.image_token_id)
        
        # Also identify positions with video_token_id if applicable
        if hasattr(self.config, 'video_token_id'):
            is_vision_position = is_vision_position | (input_ids == self.config.video_token_id)
        
        # Next, identify positions between vision_start_token_id and vision_end_token_id
        batch_size, seq_length = input_ids.shape
        
        # Find positions where we need to predict vision tokens
        for b in range(batch_size):
            # Find start tokens in the sequence
            starts = (input_ids[b] == self.vision_start_token_id).nonzero(as_tuple=True)[0]
            ends = (input_ids[b] == self.vision_end_token_id).nonzero(as_tuple=True)[0]
            
            # For each start token, mark the positions up to the corresponding end token
            for i in range(len(starts)):
                start_pos = starts[i]
                # Find the corresponding end position or use sequence length if not found
                end_pos = seq_length
                for end_idx in ends:
                    if end_idx > start_pos:
                        end_pos = end_idx
                        break
                
                # Mark the positions between start and end as vision positions
                # (not including the start token itself, but including positions up to end token)
                if start_pos + 1 < seq_length:
                    is_vision_position[b, start_pos+1:min(end_pos+1, seq_length)] = True
        
        return is_vision_position

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        second_per_grid_ts: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Forward pass for the any-to-any model that can generate both text and vision tokens.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Store original input_embeds for vision loss calculation
        original_inputs_embeds = None
        
        # Process inputs and prepare embeddings
        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)
            
            # Process image inputs if available
            if pixel_values is not None:
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                mask = (input_ids == self.config.image_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                inputs_embeds = inputs_embeds.masked_scatter(
                    mask, image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                )
            
            # Process video inputs if available
            if pixel_values_videos is not None:
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                mask = (input_ids == self.config.video_token_id).unsqueeze(-1).expand_as(inputs_embeds)
                inputs_embeds = inputs_embeds.masked_scatter(
                    mask, video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                )
                
            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)
            
            # Save the original inputs_embeds for use in loss calculation
            # Using detach() to prevent gradient flow through the target values
            original_inputs_embeds = inputs_embeds.clone().detach()

        # Handle position_ids and rope_deltas as in the original model
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts,
                    attention_mask,
                )
                self.rope_deltas = rope_deltas
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = (
                    (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
                    if cache_position is not None
                    else 0
                )
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        # Forward pass through the LLM
        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]
        
        # Identify positions where we need to predict vision tokens
        vision_positions = None
        if input_ids is not None:
            vision_positions = self.identify_vision_positions(input_ids)
        print(f"vision_positions: {vision_positions}")
        
        # Get text logits for the entire sequence
        text_logits = self.lm_head(hidden_states)
        print(f"text_logits.shape: {text_logits.shape}")
        
        # Get vision predictions for the entire sequence
        vision_predictions = self.vision_head(hidden_states)
        print(f"vision_predictions.shape: {vision_predictions.shape}")
        
        # Calculate loss (if in training mode)
        loss = None
        
        # If labels are not provided but we have input_ids, generate them
        if labels is None and input_ids is not None:
            labels = self.generate_labels(input_ids, vision_positions)
            
        if labels is not None:
            # Initialize loss
            loss = 0.0
            
            # 1. Calculate text loss (cross-entropy)
            text_logits_float = text_logits.float()
            
            # Shift logits and labels for next-token prediction
            shift_logits = text_logits_float[..., :-1, :].contiguous()
            print(f"shift_logits: {shift_logits}, shape: {shift_logits.shape}")
            shift_labels = labels[..., 1:].contiguous()
            print(f"shift_labels: {shift_labels}, shape: {shift_labels.shape}")
            
            # Mask out vision positions in labels to avoid computing text loss for them
            if vision_positions is not None:
                # Shift vision positions to align with shifted labels
                shift_vision_positions = vision_positions[..., 1:].contiguous()
                shift_labels = shift_labels.masked_fill(shift_vision_positions, -100)
            
            # Calculate cross-entropy loss for text positions
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            text_loss = loss_fct(shift_logits, shift_labels)
            loss += text_loss
            print(f"text_loss: {text_loss}, shape: {text_loss.shape}")
            
            # 2. Calculate vision loss (MSE or cosine similarity)
            if vision_positions is not None and vision_positions.any() and original_inputs_embeds is not None:
                # Get the shifted vision positions for next-token prediction
                shift_vision_positions = vision_positions[..., :-1].contiguous()
                
                if shift_vision_positions.any():
                    # Get predictions and targets for vision positions
                    # Predictions: what our vision head predicts for the next position
                    # Targets: the actual embeddings from the input at the next position
                    
                    # Get predicted vision embeddings (from current positions)
                    pred_vision_embeds = vision_predictions[..., :-1, :][shift_vision_positions]
                    
                    # Get target embeddings (from next positions in original_inputs_embeds)
                    target_vision_embeds = original_inputs_embeds[..., 1:, :][shift_vision_positions]
                    
                    # Calculate MSE loss between predicted and target vision embeddings
                    vision_loss = F.mse_loss(
                        pred_vision_embeds, 
                        target_vision_embeds.to(pred_vision_embeds.dtype)
                    )
                    
                    # Alternative: Cosine similarity loss
                    # vision_loss = 1 - F.cosine_similarity(
                    #     pred_vision_embeds, 
                    #     target_vision_embeds.to(pred_vision_embeds.dtype), 
                    #     dim=-1
                    # ).mean()
                    
                    loss += vision_loss
                    print(f"vision_loss: {vision_loss}, shape: {vision_loss.shape}")

        # Return the appropriate output format
        if not return_dict:
            output = (text_logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        
        return Qwen2_5_VLCausalLMOutputWithPast(
            loss=loss,
            logits=text_logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            rope_deltas=self.rope_deltas,
        )

In [10]:
# Create an instance of your custom AnyToAnyQwen2_5_VL class
model = AnyToAnyQwen2_5_VL(base_model.config)

# Transfer the weights from the original model to your custom model
model.load_state_dict(base_model.state_dict(), strict=False)

model.setup_vision_head()

model = model.to('cuda')
model = model.to(torch.bfloat16)

In [11]:
# torch.set_printoptions(profile="default")

In [12]:
out = model(input_ids = inputs['input_ids'], pixel_values = inputs['pixel_values'], attention_mask =
inputs['attention_mask'], image_grid_thw = inputs['image_grid_thw'])


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


vision_positions: tensor([[False, False, False,  ..., False, False, False]], device='cuda:0')
text_logits.shape: torch.Size([1, 1076, 151936])
vision_predictions.shape: torch.Size([1, 1076, 2048])
labels: tensor([[151644,   8948,    198,  ..., 151644,  77091,    198]],
       device='cuda:0'), shape: torch.Size([1, 1076])
vision_positions: tensor([[False, False, False,  ..., False, False, False]], device='cuda:0'), shape: torch.Size([1, 1076])
shift_logits: tensor([[[10.5625, 18.8750, 18.0000,  ...,  8.5000,  8.5000,  8.5000],
         [ 9.8750, 14.8125,  9.0000,  ...,  6.7812,  6.7812,  6.7812],
         [15.6875, 18.5000, 21.1250,  ...,  7.1875,  7.1875,  7.1875],
         ...,
         [ 3.8906, -2.0469, -6.6250,  ...,  1.9844,  1.9844,  1.9844],
         [10.8125, 12.8125,  9.8750,  ...,  6.0625,  6.0625,  6.0625],
         [18.3750, 17.7500, 15.5625,  ...,  6.9062,  6.9062,  6.9375]]],
       device='cuda:0', grad_fn=<SliceBackward0>), shape: torch.Size([1, 1075, 151936])
shift_la

In [13]:
print("Output items:")
for key, value in out.items():
    if hasattr(value, 'shape'):
        print(f"{key}: shape={value.shape}, value={value}")
    else:
        print(f"{key}: {value}")

Output items:
loss: shape=torch.Size([]), value=13.872798919677734
logits: shape=torch.Size([1, 1076, 151936]), value=tensor([[[10.5625, 18.8750, 18.0000,  ...,  8.5000,  8.5000,  8.5000],
         [ 9.8750, 14.8125,  9.0000,  ...,  6.7812,  6.7812,  6.7812],
         [15.6875, 18.5000, 21.1250,  ...,  7.1875,  7.1875,  7.1875],
         ...,
         [10.8125, 12.8125,  9.8750,  ...,  6.0625,  6.0625,  6.0625],
         [18.3750, 17.7500, 15.5625,  ...,  6.9062,  6.9062,  6.9375],
         [15.8125, 18.3750, 18.0000,  ...,  6.5938,  6.5938,  6.5938]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>)
past_key_values: <transformers.cache_utils.DynamicCache object at 0x7f1eca928b60>
rope_deltas: shape=torch.Size([1, 1]), value=tensor([[-966]], device='cuda:0')


In [14]:
generated_ids = model.generate(**inputs, max_new_tokens=512)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)

Setting `pad_token_id` to `eos_token_id`:151645 for open-end generation.


vision_positions: tensor([[False, False, False,  ..., False, False, False]], device='cuda:0')
text_logits.shape: torch.Size([1, 1076, 151936])
vision_predictions.shape: torch.Size([1, 1076, 2048])
labels: tensor([[151644,   8948,    198,  ..., 151644,  77091,    198]],
       device='cuda:0'), shape: torch.Size([1, 1076])
vision_positions: tensor([[False, False, False,  ..., False, False, False]], device='cuda:0'), shape: torch.Size([1, 1076])
shift_logits: tensor([[[10.5625, 18.8750, 18.0000,  ...,  8.5000,  8.5000,  8.5000],
         [ 9.8750, 14.8125,  9.0000,  ...,  6.7812,  6.7812,  6.7812],
         [15.6875, 18.5000, 21.1250,  ...,  7.1875,  7.1875,  7.1875],
         ...,
         [ 3.8906, -2.0469, -6.6250,  ...,  1.9844,  1.9844,  1.9844],
         [10.8125, 12.8125,  9.8750,  ...,  6.0625,  6.0625,  6.0625],
         [18.3750, 17.7500, 15.5625,  ...,  6.9062,  6.9062,  6.9375]]],
       device='cuda:0'), shape: torch.Size([1, 1075, 151936])
shift_labels: tensor([[  8948,    