# Setup

In [1]:
from vlm_injector_train import HeatmapInjectionExperiment

config = "/workspace/config/qwen2.5_heat.yaml"
experiment = HeatmapInjectionExperiment(config)

In [2]:
from src.qwen2_5.fa_model import Qwen2_5_VLForConditionalGenerationWithHeatmap

import torch
from transformers import AutoProcessor, AutoConfig, PreTrainedModel, ProcessorMixin, TrainingArguments, Trainer
from datasets import Dataset
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLCausalLMOutputWithPast
)

hf_config = AutoConfig.from_pretrained(experiment.cfg.model.name, trust_remote_code=True)
hf_config.vision_config.latent_dim = 512

model = Qwen2_5_VLForConditionalGenerationWithHeatmap.from_pretrained(
    experiment.cfg.model.name,
    config=hf_config,
    # ignore_mismatched_sizes=True,
    **experiment.cfg.model.kwargs
)
processor = AutoProcessor.from_pretrained(experiment.cfg.model.name)
processor.tokenizer.padding_side = 'left'

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of Qwen2_5_VLForConditionalGenerationWithHeatmap were not initialized from the model checkpoint at ./Qwen2.5-VL-3B-Instruct and are newly initialized: ['heat_embedding.linear1.bias', 'heat_embedding.linear1.weight', 'heat_embedding.linear2.bias', 'heat_embedding.linear2.weight', 'visual.blocks.32.attn.kv_proj.bias', 'visual.blocks.32.attn.kv_proj.weight', 'visual.blocks.32.attn.proj.bias', 'visual.blocks.32.attn.proj.weight', 'visual.blocks.32.attn.q_proj.bias', 'visual.blocks.32.attn.q_proj.weight', 'visual.blocks.32.mlp.down_proj.bias', 'visual.blocks.32.mlp.down_proj.weight', 'visual.blocks.32.mlp.gate_proj.bias', 'visual.blocks.32.mlp.gate_proj.weight', 'visual.blocks.32.mlp.up_proj.bias', 'visual.blocks.32.mlp.up_proj.weight', 'visual.blocks.32.norm0.weight', 'visual.blocks.32.norm1.weight', 'visual.blocks.32.norm2.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using a slow image processor as `us

In [3]:
# import torch
# from transformers import AutoProcessor, AutoConfig, PreTrainedModel, ProcessorMixin, TrainingArguments, Trainer
# from datasets import Dataset
# from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
#     Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLCausalLMOutputWithPast
# )

# hf_config = AutoConfig.from_pretrained(experiment.cfg.model.name, trust_remote_code=True)
# hf_config.vision_config.latent_dim = 512

# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#     experiment.cfg.model.name,
#     config=hf_config,
#     # ignore_mismatched_sizes=True,
#     **experiment.cfg.model.kwargs
# )
# processor = AutoProcessor.from_pretrained(experiment.cfg.model.name)
# processor.tokenizer.padding_side = 'left'

In [4]:
from src.common.dataset import DataCollator

experiment.prepare_dataset()
data_collator = DataCollator(processor).data_collator

In [5]:
batch = experiment.train_dataset.select(range(5))
# batch = data_collator(batch)

In [6]:
from typing import List, Dict, Any

import torch
from qwen_vl_utils import process_vision_info
from transformers import PreTrainedTokenizer

from src.common.templates import messages_template, answer_template
from src.common.transforms import get_heatmap_transformation


def find_substring(input_ids: torch.Tensor, ref_ids: List[int]):
    start_index = -1
    for i in range(len(input_ids) - len(ref_ids) + 1):
        if input_ids[i: i + len(ref_ids)].tolist() == ref_ids:
            start_index = i
            break
    if start_index == -1:
        raise ValueError("Target sequence not found.")
    end_index = start_index + len(ref_ids)
    return start_index, end_index


def create_labels(input_ids: torch.Tensor, answers: List[str], tokenizer: PreTrainedTokenizer) -> torch.Tensor:
    """
    Create labels for SFT training. It masks all tokens after the start token with excluding_probability
    and after end token for the rest.
    Args:
        input_ids: ids from tokenizer output
        ....

    Returns: tensor with masks  for each input_ids
    """

    labels = torch.full_like(input_ids, fill_value=-100)

    for i, row in enumerate(input_ids):
        start_index, end_index = find_substring(
            row, tokenizer(answers[i], add_special_tokens=False)["input_ids"]
        )
        labels[i, start_index:end_index] = row[start_index:end_index]

    return labels


def process_injection(image_grid_thw, features):
    heatmap_flat = []
    for thw, feature in zip(image_grid_thw, features):
        _, h, w = thw
        transformation = get_heatmap_transformation(h, w)
        heatmap_flat.append(transformation(feature["heatmap"]).unsqueeze(1))

    return torch.stack(heatmap_flat)


class DataCollator:
    def __init__(self, processor):
        self.processor = processor

    def data_collator(self, features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        if not features:
            return {}

        messages = []
        answers = []
        for feature in features:
            messages.append(messages_template(feature["image"], feature["transcribation"]))
            answers.append(answer_template.format(ans_text=feature["transcribation"]))

        texts = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
        image_inputs, _ = process_vision_info(messages)
        batch = self.processor(
            text=texts, images=image_inputs, padding=True, return_tensors="pt"
        )  # ['input_ids', 'attention_mask', 'pixel_values', 'image_grid_thw']

        batch["labels"] = create_labels(batch["input_ids"], answers, self.processor.tokenizer)
        batch["heatmap_flat"] = process_injection(batch["image_grid_thw"], features)

        return batch
data_collator = DataCollator(processor).data_collator

In [7]:
batch = data_collator(batch)

In [8]:
model(
    input_ids=batch["input_ids"].cuda(), 
    pixel_values=batch["pixel_values"].to(dtype=model.dtype).cuda(),
    heatmap_flat=batch['heatmap_flat'].to(dtype=model.dtype).cuda(),
    **{key: batch[key].cuda() for key in ['attention_mask', 'image_grid_thw', 'labels']})

CrossAttn: attn_output shape after flash_attn: torch.Size([16000, 16, 80])
CrossAttn: final attn_output shape after projection: torch.Size([16000, 1280])


Qwen2_5_VLCausalLMOutputWithPast(loss=tensor(3.3421, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[[11.0000, 12.1875, 10.8750,  ...,  1.4297,  1.4297,  1.4297],
         [11.0000, 17.6250, 17.5000,  ...,  8.5625,  8.5625,  8.5625],
         [10.3750, 15.3125,  9.5000,  ...,  6.9062,  6.9062,  6.9062],
         ...,
         [10.9375,  7.6250,  9.4375,  ...,  3.5000,  3.5000,  3.5000],
         [10.4375,  3.6562,  3.3125,  ...,  1.0859,  1.0859,  1.0859],
         [ 6.4688,  4.0312, -0.9492,  ..., -0.1016, -0.1016, -0.1016]],

        [[11.0000, 17.6250, 17.5000,  ...,  8.5625,  8.5625,  8.5625],
         [10.3750, 15.3125,  9.5000,  ...,  6.9062,  6.9062,  6.9062],
         [15.5000, 18.3750, 20.8750,  ...,  7.0000,  7.0000,  7.0000],
         ...,
         [10.5000,  7.1875,  8.3750,  ...,  3.2656,  3.2656,  3.2656],
         [10.2500,  3.1875,  3.2500,  ...,  1.0469,  1.0469,  1.0469],
         [ 5.6562,  2.8438, -2.3594,  ..., -0.5625, -0.5625, -0.5625]],

        [[

In [41]:
import torch
import torch.nn as nn

class HeatmapEmbeddingLayer(nn.Module):
    def __init__(self, hidden_state: int):
        super().__init__()
        self.linear1 = nn.Linear(1, hidden_state)
        self.activation = nn.SiLU()
        self.linear2 = nn.Linear(hidden_state, hidden_state)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h1 = self.linear1(x)
        output = self.linear2(self.activation(h1))
        
        return output

layer = TransformHiddenStateLayer(1028)
out = layer(batch["heatmap_flat"].view(-1, 1))
out.shape

torch.Size([32000, 1028])

In [5]:
self = model
output_attentions = None
output_hidden_states = None
return_dict = None
inputs_embeds = None

input_ids = batch["input_ids"].cuda()
pixel_values = batch["pixel_values"].cuda()
grid_thw=batch["image_grid_thw"]


output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if inputs_embeds is None:
    inputs_embeds = self.model.embed_tokens(input_ids)
    if pixel_values is not None:
        pixel_values = pixel_values.type(self.visual.dtype)
        # heatmap_flat = heatmap_flat.reshape(-1, 1).type(self.visual.dtype)

In [6]:
model.visual

Qwen2_5_VisionTransformerPretrainedModel(
  (patch_embed): Qwen2_5_VisionPatchEmbed(
    (proj): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
  )
  (rotary_pos_emb): Qwen2_5_VisionRotaryEmbedding()
  (blocks): ModuleList(
    (0-31): 32 x Qwen2_5_VLVisionBlock(
      (norm1): Qwen2RMSNorm((1280,), eps=1e-06)
      (norm2): Qwen2RMSNorm((1280,), eps=1e-06)
      (attn): Qwen2_5_VLVisionFlashAttention2(
        (qkv): Linear(in_features=1280, out_features=3840, bias=True)
        (proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (mlp): Qwen2_5_VLMLP(
        (gate_proj): Linear(in_features=1280, out_features=3420, bias=True)
        (up_proj): Linear(in_features=1280, out_features=3420, bias=True)
        (down_proj): Linear(in_features=3420, out_features=1280, bias=True)
        (act_fn): SiLU()
      )
    )
  )
  (merger): Qwen2_5_VLPatchMerger(
    (ln_q): Qwen2RMSNorm((1280,), eps=1e-06)
    (mlp): Sequential(
      (0): Linear(

In [7]:
self = model.visual
hidden_states = pixel_values

In [8]:
import math
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.nn import CrossEntropyLoss

# from ...activations import ACT2FN
# from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache
# from ...generation import GenerationMixin
# from ...modeling_attn_mask_utils import AttentionMaskConverter
# from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
# from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
# from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
# from ...modeling_utils import PreTrainedModel
# from ...utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
# from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig


hidden_states = self.patch_embed(hidden_states)
rotary_pos_emb = self.rot_pos_emb(grid_thw)
window_index, cu_window_seqlens = self.get_window_index(grid_thw)
cu_window_seqlens = torch.tensor(
    cu_window_seqlens,
    device=hidden_states.device,
    dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

seq_len, _ = hidden_states.size()
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
hidden_states = hidden_states[window_index, :, :]
hidden_states = hidden_states.reshape(seq_len, -1)
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
position_embeddings = (emb.cos(), emb.sin())

cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
    dim=0,
    # Select dtype based on the following factors:
    #  - FA2 requires that cu_seqlens_q must have dtype int32
    #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
    # See https://github.com/huggingface/transformers/pull/34852 for more information
    dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

In [9]:
cu_seqlens.shape

torch.Size([11])

In [10]:
for layer_num, blk in enumerate(self.blocks):
    break

In [11]:
cu_seqlens.shape, cu_window_seqlens.shape

(torch.Size([11]), torch.Size([501]))

In [12]:
hidden_states = blk(hidden_states, cu_seqlens=cu_window_seqlens, position_embeddings=position_embeddings)
hidden_states.shape

torch.Size([32000, 1280])

In [13]:
# for layer_num, blk in enumerate(self.blocks):
#     if layer_num in self.fullatt_block_indexes:
#         cu_seqlens_now = cu_seqlens
#     else:
#         cu_seqlens_now = cu_window_seqlens
#     if self.gradient_checkpointing and self.training:
#         hidden_states = self._gradient_checkpointing_func(
#             blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings
#         )
#     else:
#         hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)

# hidden_states = self.merger(hidden_states)
# reverse_indices = torch.argsort(window_index)
# hidden_states = hidden_states[reverse_indices, :]

In [14]:
hidden_states.shape

torch.Size([32000, 1280])

In [15]:
self.blocks[0]

Qwen2_5_VLVisionBlock(
  (norm1): Qwen2RMSNorm((1280,), eps=1e-06)
  (norm2): Qwen2RMSNorm((1280,), eps=1e-06)
  (attn): Qwen2_5_VLVisionFlashAttention2(
    (qkv): Linear(in_features=1280, out_features=3840, bias=True)
    (proj): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (mlp): Qwen2_5_VLMLP(
    (gate_proj): Linear(in_features=1280, out_features=3420, bias=True)
    (up_proj): Linear(in_features=1280, out_features=3420, bias=True)
    (down_proj): Linear(in_features=3420, out_features=1280, bias=True)
    (act_fn): SiLU()
  )
)

# HeatFA

In [16]:
context_features = torch.rand(hidden_states.shape, dtype=hidden_states.dtype)

In [17]:
context_features, context_features.shape, hidden_states.shape

(tensor([[0.4375, 0.9492, 0.4141,  ..., 0.8008, 0.9961, 0.5117],
         [0.3008, 0.6992, 0.1992,  ..., 0.3906, 0.5195, 0.5820],
         [0.9961, 0.3008, 0.1914,  ..., 0.8398, 0.0664, 0.1250],
         ...,
         [0.6133, 0.0625, 0.8594,  ..., 0.8633, 0.8828, 0.7188],
         [0.8125, 0.3711, 0.5625,  ..., 0.5078, 0.7422, 0.2734],
         [0.5000, 0.5156, 0.9766,  ..., 0.4570, 0.4258, 0.4297]],
        dtype=torch.bfloat16),
 torch.Size([32000, 1280]),
 torch.Size([32000, 1280]))

In [18]:
import torch
import torch.nn as nn
from typing import Optional, Tuple
import logging

from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import flash_attn_varlen_func, apply_rotary_pos_emb_flashatt


logger = logging.getLogger(__name__)


class Qwen2_5_VLCrossAttentionFlashAttention2(nn.Module):
    """
    Модифицированный Attention модуль для Cross-Attention с использованием FlashAttention v2.
    Query (Q) берется из `context_features` (например, признаки тепловых карт).
    Key (K) и Value (V) берутся из `hidden_states` (например, визуальные признаки).
    Ротационные эмбеддинги (`position_embeddings`), предназначенные для `hidden_states`,
    применяются к Q и K для сохранения структуры (требует совпадения длин!).
    """
    def __init__(self, dim: int, dim_context: Optional[int] = None, num_heads: int = 16, bias: bool = True) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        if dim % num_heads != 0:
            raise ValueError(f"`dim` ({dim}) должен быть кратен `num_heads` ({num_heads})")

        if dim_context is None:
            dim_context = dim # По умолчанию контекст имеет ту же размерность, что и hidden_states

        # Линейный слой для проекции context_features -> Q
        # Вход: dim_context, Выход: dim (чтобы соответствовать K/V по размерности головы)
        self.q_proj = nn.Linear(dim_context, dim, bias=bias)
        # Линейный слой для проекции hidden_states -> K, V
        # Вход: dim, Выход: dim * 2
        self.kv_proj = nn.Linear(dim, dim * 2, bias=bias)
        # Выходной проекционный слой
        self.proj = nn.Linear(dim, dim) # Выходная размерность соответствует Q

    def forward(
        self,
        hidden_states: torch.Tensor,                # Фичи для K, V, shape: (total_seq_len_kv, dim)
        context_features: torch.Tensor,             # Фичи для Q, shape: (total_seq_len_q, dim_context)
        cu_seqlens: torch.Tensor,                   # Кумулятивные длины для K/V (hidden_states), shape: (batch_size + 1,)
        # cu_seqlens_context: torch.Tensor,           # Кумулятивные длины для Q (context), shape: (batch_size + 1,)
        rotary_pos_emb: Optional[torch.Tensor] = None, # Устаревший способ передачи RoPE (для K/V)
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # (cos, sin) для RoPE (для K/V)
                                                                                # shape: (total_seq_len_kv, rotary_dim) или совместимые
    ) -> torch.Tensor:
        """
        Args:
            hidden_states: Тензор для вычисления Key и Value. Форма (total_seq_len_kv, dim).
            context_features: Тензор для вычисления Query. Форма (total_seq_len_q, dim_context).
            cu_seqlens: Кумулятивные длины последовательностей для `hidden_states`. Форма (batch_size + 1,).
            cu_seqlens_context: Кумулятивные длины последовательностей для `context_features`. Форма (batch_size + 1,).
            rotary_pos_emb: Theta значения RoPE (устарело). Используется, если position_embeddings is None.
            position_embeddings: Кортеж (cos, sin) для RoPE. Ожидается, что они рассчитаны для `hidden_states`.

        Returns:
            Тензор выхода attention. Форма (total_seq_len_q, dim).
        """
        seq_length_kv = hidden_states.shape[0]     # Длина последовательности для K, V
        seq_length_q = context_features.shape[0]   # Длина последовательности для Q

        # --- ВАЖНАЯ ПРОВЕРКА ---
        if seq_length_q != seq_length_kv:
            logger.error(
                f"Длина последовательности для Q ({seq_length_q} из context_features) "
                f"не совпадает с длиной для K/V ({seq_length_kv} из hidden_states). "
                f"Применение одинаковых position_embeddings к Q и K в этом случае вызовет ошибку "
                f"или будет некорректным! Убедитесь, что длины совпадают, или измените логику RoPE."
            )
            # Можно либо падать с ошибкой, либо продолжить с предупреждением,
            # рискуя получить ошибку размерности в apply_rotary_pos_emb_flashatt
            # raise ValueError("Sequence lengths for Q and K/V must match for this RoPE application strategy.")
            logger.warning("Продолжение работы с несовпадающими длинами Q и K/V, возможна ошибка в RoPE!")


        # 1. Проецируем K, V из hidden_states
        # (total_kv, dim) -> (total_kv, 2 * dim) -> (total_kv, 2, num_heads, head_dim)
        kv = self.kv_proj(hidden_states).reshape(seq_length_kv, 2, self.num_heads, self.head_dim)
        # k, v shapes: (total_kv, num_heads, head_dim)
        k, v = kv.unbind(1)
        print(f"CrossAttn: k shape: {k.shape}, v shape: {v.shape} (from hidden_states)")

        # 2. Проецируем Q из context_features
        # (total_q, dim_context) -> (total_q, dim) -> (total_q, num_heads, head_dim)
        q = self.q_proj(context_features).reshape(seq_length_q, self.num_heads, self.head_dim)
        print(f"CrossAttn: q shape: {q.shape} (from context_features)")

        # 3. Получаем и применяем RoPE (к Q и K, используя эмбеддинги от K/V)
        if position_embeddings is None:
            if rotary_pos_emb is None:
                 raise ValueError("Необходимо предоставить либо position_embeddings, либо rotary_pos_emb")
            logger.warning_once(
                "Используется устаревший `rotary_pos_emb` для вычисления RoPE в CrossAttention."
            )
            # Убедимся, что rotary_pos_emb имеет правильную длину seq_length_kv
            if rotary_pos_emb.shape[0] != seq_length_kv:
                 raise ValueError(f"rotary_pos_emb имеет длину {rotary_pos_emb.shape[0]}, ожидалось {seq_length_kv}")
            emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
            cos = emb.cos()
            sin = emb.sin()
        else:
            cos, sin = position_embeddings
            # Убедимся, что cos/sin имеют правильную длину seq_length_kv
            if cos.shape[0] != seq_length_kv or sin.shape[0] != seq_length_kv:
                 raise ValueError(f"position_embeddings имеют длину {cos.shape[0]}/{sin.shape[0]}, ожидалось {seq_length_kv}")

        print(f"CrossAttn: Применяем RoPE (cos: {cos.shape}, sin: {sin.shape}) к Q ({q.shape}) и K ({k.shape})")
        # Эта функция должна применить cos/sin (длины seq_length_kv) к q (длины seq_length_q) и k (длины seq_length_kv)
        # Это будет работать без ошибок размерности только если seq_length_q == seq_length_kv
        q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
        print(f"CrossAttn: q shape after RoPE: {q.shape}, k shape after RoPE: {k.shape}")


        # 4. Вычисляем максимальные длины последовательностей
        # max_seqlen_q = (cu_seqlens_context[1:] - cu_seqlens_context[:-1]).max().item()
        max_seqlen_kv = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
        print(f"CrossAttn: max_seqlen_q: {max_seqlen_kv}, max_seqlen_kv: {max_seqlen_kv}")

        # 5. Вызываем FlashAttention для cross-attention
        # q: (total_q, num_heads, head_dim)
        # k: (total_kv, num_heads, head_dim)
        # v: (total_kv, num_heads, head_dim)

        q = q.squeeze(0)
        k = k.squeeze(0)
        print("q:", q.shape, "(total_q, num_heads, head_dim)")
        print("k:", k.shape, "(total_q, num_heads, head_dim)")
        print("v:", v.shape, "(total_q, num_heads, head_dim)")

        
        attn_output = flash_attn_varlen_func(
            q, k, v, cu_seqlens, cu_seqlens, max_seqlen_kv, max_seqlen_kv,
            causal=False, # Cross-attention не каузальное
        )
        # attn_output shape: (total_q, num_heads, head_dim)
        print(f"CrossAttn: attn_output shape after flash_attn: {attn_output.shape}")

        # 6. Решейпим и проецируем выход
        # (total_q, num_heads, head_dim) -> (total_q, dim)
        attn_output = attn_output.reshape(seq_length_q, -1)
        attn_output = self.proj(attn_output)
        print(f"CrossAttn: final attn_output shape after projection: {attn_output.shape}")

        return attn_output

In [19]:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2RMSNorm, Qwen2_5_VLMLP



class Qwen2_5_VLVisionBlockHeat(nn.Module):
    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm0 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
        self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
        self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
        self.attn = Qwen2_5_VLCrossAttentionFlashAttention2(
            config.hidden_size, num_heads=config.num_heads
        )
        self.mlp = Qwen2_5_VLMLP(config, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        context_features: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            self.norm0(context_features),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states


fa_heat = Qwen2_5_VLVisionBlockHeat(self.config)
fa_heat

Qwen2_5_VLVisionBlockHeat(
  (norm0): Qwen2RMSNorm((1280,), eps=1e-06)
  (norm1): Qwen2RMSNorm((1280,), eps=1e-06)
  (norm2): Qwen2RMSNorm((1280,), eps=1e-06)
  (attn): Qwen2_5_VLCrossAttentionFlashAttention2(
    (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
    (kv_proj): Linear(in_features=1280, out_features=2560, bias=True)
    (proj): Linear(in_features=1280, out_features=1280, bias=True)
  )
  (mlp): Qwen2_5_VLMLP(
    (gate_proj): Linear(in_features=1280, out_features=3420, bias=True)
    (up_proj): Linear(in_features=1280, out_features=3420, bias=True)
    (down_proj): Linear(in_features=3420, out_features=1280, bias=True)
    (act_fn): SiLU()
  )
)

In [20]:
fa_heat = fa_heat.to(dtype=hidden_states.dtype).cuda()

In [21]:
fa_heat(hidden_states, context_features.cuda(), cu_seqlens=cu_window_seqlens, position_embeddings=position_embeddings)

CrossAttn: k shape: torch.Size([32000, 16, 80]), v shape: torch.Size([32000, 16, 80]) (from hidden_states)
CrossAttn: q shape: torch.Size([32000, 16, 80]) (from context_features)
CrossAttn: Применяем RoPE (cos: torch.Size([32000, 80]), sin: torch.Size([32000, 80])) к Q (torch.Size([32000, 16, 80])) и K (torch.Size([32000, 16, 80]))
CrossAttn: q shape after RoPE: torch.Size([1, 32000, 16, 80]), k shape after RoPE: torch.Size([1, 32000, 16, 80])
CrossAttn: max_seqlen_q: 64, max_seqlen_kv: 64
q: torch.Size([32000, 16, 80]) (total_q, num_heads, head_dim)
k: torch.Size([32000, 16, 80]) (total_q, num_heads, head_dim)
v: torch.Size([32000, 16, 80]) (total_q, num_heads, head_dim)
CrossAttn: attn_output shape after flash_attn: torch.Size([32000, 16, 80])
CrossAttn: final attn_output shape after projection: torch.Size([32000, 1280])


tensor([[ 0.4297, -0.2520, -0.0967,  ...,  0.1201,  0.2891,  0.2500],
        [ 0.4297, -0.2520, -0.0967,  ...,  0.1201,  0.2891,  0.2500],
        [ 0.4297, -0.2520, -0.0967,  ...,  0.1201,  0.2891,  0.2500],
        ...,
        [ 0.4297, -0.2520, -0.0967,  ...,  0.1201,  0.2891,  0.2500],
        [ 0.4297, -0.2520, -0.0967,  ...,  0.1201,  0.2891,  0.2500],
        [ 0.4297, -0.2520, -0.0967,  ...,  0.1201,  0.2891,  0.2500]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<AddBackward0>)

In [22]:
context_features.shape

torch.Size([32000, 1280])

In [None]:
data

In [16]:
import torch
import torch.nn.functional as F
# Предполагаем, что hidden_states и grid_thw уже существуют
# print(f"Начальный hidden_states shape: {hidden_states.shape}") # Можно добавить, если нужно видеть вход
# print(f"Начальный grid_thw shape: {grid_thw.shape}")      # Можно добавить, если нужно видеть вход

hidden_states = self.patch_embed(hidden_states)
print(f"После patch_embed, hidden_states shape: {hidden_states.shape}")

rotary_pos_emb = self.rot_pos_emb(grid_thw)
print(f"После rot_pos_emb, rotary_pos_emb shape: {rotary_pos_emb.shape}")
# Убедимся что rotary_pos_emb на том же девайсе
rotary_pos_emb = rotary_pos_emb.to(hidden_states.device, hidden_states.dtype)

window_index, cu_window_seqlens_list = self.get_window_index(grid_thw)
print(f"После get_window_index, window_index shape: {window_index.shape}, len(cu_window_seqlens_list): {len(cu_window_seqlens_list)}")
# Перенесем window_index на нужный девайс, если он еще не там
window_index = window_index.to(hidden_states.device)


cu_window_seqlens = torch.tensor(
    cu_window_seqlens_list, # Используем полученный список
    device=hidden_states.device,
    dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
)
print(f"После torch.tensor, cu_window_seqlens shape: {cu_window_seqlens.shape}, dtype: {cu_window_seqlens.dtype}")

cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)
print(f"После unique_consecutive, cu_window_seqlens shape: {cu_window_seqlens.shape}")


seq_len, hidden_dim = hidden_states.size() # Получаем размерность после patch_embed
print(f"Определены seq_len: {seq_len}, hidden_dim: {hidden_dim}")

# --- Решейпинг и индексация hidden_states ---
print(f"hidden_states перед решейпом в окна: {hidden_states.shape}")
# Проверка на делимость перед решейпом
if seq_len % self.spatial_merge_unit != 0:
    raise ValueError(f"seq_len ({seq_len}) не делится на spatial_merge_unit ({self.spatial_merge_unit})")
hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
print(f"hidden_states после решейпа в окна: {hidden_states.shape}")

print(f"hidden_states перед индексацией окнами (window_index): {hidden_states.shape}")
hidden_states = hidden_states[window_index, :, :]
print(f"hidden_states после индексации окнами (перегруппировка): {hidden_states.shape}")

print(f"hidden_states перед финальным решейпом: {hidden_states.shape}")
hidden_states = hidden_states.reshape(seq_len, -1) # Используем seq_len и выводим hidden_dim
print(f"hidden_states после финального решейпа (готово для блоков): {hidden_states.shape}")

# --- Решейпинг и индексация rotary_pos_emb ---
seq_len_pos, pos_dim = rotary_pos_emb.size()
print(f"rotary_pos_emb перед решейпом в окна: {rotary_pos_emb.shape} (seq_len_pos={seq_len_pos}, pos_dim={pos_dim})")
# Проверка совпадения seq_len (важно!)
if seq_len_pos != seq_len:
     # Можно добавить обработку ошибки или предупреждение
     print(f"[WARN/ERROR] seq_len у hidden_states ({seq_len}) не совпадает с seq_len у rotary_pos_emb ({seq_len_pos})!")
     # Попытка исправить размер rotary_pos_emb (может быть неверной логикой)
     if rotary_pos_emb.shape[0] > seq_len:
         rotary_pos_emb = rotary_pos_emb[:seq_len, :]
     else:
         padding = torch.zeros((seq_len - rotary_pos_emb.shape[0], pos_dim), dtype=rotary_pos_emb.dtype, device=rotary_pos_emb.device)
         rotary_pos_emb = torch.cat([rotary_pos_emb, padding], dim=0)
     print(f"[FIX (POTENTIALLY WRONG)] Изменен размер rotary_pos_emb на: {rotary_pos_emb.shape}")


# Проверка на делимость перед решейпом
if seq_len % self.spatial_merge_unit != 0:
     # Эта ошибка должна была возникнуть раньше для hidden_states, но дублируем для ясности
    raise ValueError(f"seq_len ({seq_len}) не делится на spatial_merge_unit ({self.spatial_merge_unit}) для rotary_pos_emb")
rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
print(f"rotary_pos_emb после решейпа в окна: {rotary_pos_emb.shape}")

print(f"rotary_pos_emb перед индексацией окнами (window_index): {rotary_pos_emb.shape}")
rotary_pos_emb = rotary_pos_emb[window_index, :, :]
print(f"rotary_pos_emb после индексации окнами (перегруппировка): {rotary_pos_emb.shape}")

print(f"rotary_pos_emb перед финальным решейпом: {rotary_pos_emb.shape}")
rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) # Используем seq_len и выводим pos_dim
print(f"rotary_pos_emb после финального решейпа: {rotary_pos_emb.shape}")


# --- Создание position_embeddings ---
print(f"rotary_pos_emb shape перед cat: {rotary_pos_emb.shape}")
emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
print(f"emb shape после cat: {emb.shape}")

position_embeddings = (emb.cos(), emb.sin())
print(f"position_embeddings[0] (cos) shape: {position_embeddings[0].shape}")
print(f"position_embeddings[1] (sin) shape: {position_embeddings[1].shape}")

# --- Вычисление cu_seqlens для full attention ---
# Эта часть вычисляет cu_seqlens специфичным образом
print(f"grid_thw для full attention cu_seqlens: {grid_thw.shape}, dtype: {grid_thw.dtype}")
if grid_thw.dim() > 1 and grid_thw.shape[0] > 0 and grid_thw.shape[1] == 3:
    # Вычисление на основе repeat_interleave
    repeated_hw = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0].long()) # .long() для repeat_interleave
    print(f"Full attention: repeated_hw shape: {repeated_hw.shape}")
    cu_seqlens_full = repeated_hw.cumsum(
        dim=0,
        dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
    )
    print(f"Full attention: cu_seqlens (до pad) shape: {cu_seqlens_full.shape}, dtype: {cu_seqlens_full.dtype}")

    cu_seqlens_full = F.pad(cu_seqlens_full, (1, 0), value=0)
    print(f"Full attention: cu_seqlens (после pad) shape: {cu_seqlens_full.shape}")
else:
    # Заглушка или обработка ошибки, если grid_thw не подходит
    print(f"[WARN] Не удалось вычислить full attention cu_seqlens из grid_thw формы {grid_thw.shape}. Создается заглушка.")
    # Стандартная заглушка для батча из 1 элемента
    cu_seqlens_full = torch.tensor([0, seq_len], dtype=torch.int32, device=hidden_states.device)
    print(f"Full attention: cu_seqlens (заглушка) shape: {cu_seqlens_full.shape}")


# --- Цикл по блокам ---
for layer_num, blk in enumerate(self.blocks):
    print(f"\n--- Блок {layer_num} ---")
    print(f"Вход в блок {layer_num}, hidden_states shape: {hidden_states.shape}")

    if layer_num in self.fullatt_block_indexes:
        cu_seqlens_now = cu_seqlens_full # Используем рассчитанные выше
        print(f"Используем Full Attention cu_seqlens (форма {cu_seqlens_now.shape})")
    else:
        cu_seqlens_now = cu_window_seqlens # Используем оконные
        print(f"Используем Windowed Attention cu_window_seqlens (форма {cu_seqlens_now.shape})")

    # Приведение типа для FA2, если не трассировка
    if not torch.jit.is_tracing():
        cu_seqlens_now = cu_seqlens_now.to(torch.int32)
        # print(f"Привели cu_seqlens_now к dtype: {cu_seqlens_now.dtype}") # Можно раскомментировать для отладки

    if self.gradient_checkpointing and self.training:
        print(f"Блок {layer_num}: Используем Gradient Checkpointing")
        hidden_states = self._gradient_checkpointing_func(
            blk.__call__, hidden_states, cu_seqlens_now, None, position_embeddings
        )
    else:
        # Передаем аргументы в блок
        hidden_states = blk(hidden_states, cu_seqlens=cu_seqlens_now, position_embeddings=position_embeddings)

    print(f"Выход из блока {layer_num}, hidden_states shape: {hidden_states.shape}")

print("\n--- После всех блоков ---")
print(f"hidden_states перед merger: {hidden_states.shape}")
hidden_states = self.merger(hidden_states)
print(f"hidden_states после merger: {hidden_states.shape}")

# --- Восстановление порядка ---
print(f"hidden_states перед обратной индексацией: {hidden_states.shape}")
reverse_indices = torch.argsort(window_index)
print(f"reverse_indices shape: {reverse_indices.shape}")
hidden_states = hidden_states[reverse_indices, :]
print(f"hidden_states после обратной индексации (финальный выход сегмента): {hidden_states.shape}")

# return hidden_states # Возвращаем результат

После patch_embed, hidden_states shape: torch.Size([32000, 1280])
После rot_pos_emb, rotary_pos_emb shape: torch.Size([32000, 40])
После get_window_index, window_index shape: torch.Size([8000]), len(cu_window_seqlens_list): 661
После torch.tensor, cu_window_seqlens shape: torch.Size([661]), dtype: torch.int32
После unique_consecutive, cu_window_seqlens shape: torch.Size([501])
Определены seq_len: 32000, hidden_dim: 1280
hidden_states перед решейпом в окна: torch.Size([32000, 1280])
hidden_states после решейпа в окна: torch.Size([8000, 4, 1280])
hidden_states перед индексацией окнами (window_index): torch.Size([8000, 4, 1280])
hidden_states после индексации окнами (перегруппировка): torch.Size([8000, 4, 1280])
hidden_states перед финальным решейпом: torch.Size([8000, 4, 1280])
hidden_states после финального решейпа (готово для блоков): torch.Size([32000, 1280])
rotary_pos_emb перед решейпом в окна: torch.Size([32000, 40]) (seq_len_pos=32000, pos_dim=40)
rotary_pos_emb после решейпа в ок

RuntimeError: cu_seqlens_q must be on CUDA

In [24]:
cu_seqlens_full

tensor([    0,  3200,  6400,  9600, 12800, 16000, 19200, 22400, 25600, 28800,
        32000], dtype=torch.int32)