In [1]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"


In [2]:
import torch
from transformers import AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info

path = "Archistrax/Qwen2_5_VL"

processor = AutoProcessor.from_pretrained(path)
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [10]:
from transformers import AutoConfig

from src.qwen2_5.model import Qwen2_5_VLForConditionalGenerationWithHeatmap


config = AutoConfig.from_pretrained(path, trust_remote_code=True)
config.vision_config.latent_dim = 256

model = Qwen2_5_VLForConditionalGenerationWithHeatmap.from_pretrained(
    path,
    torch_dtype=torch.bfloat16,
    config=config,
    attn_implementation="flash_attention_2",
    device_map="auto",
    trust_remote_code=True
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
hidden_states = model.visual(inputs["pixel_values"].cuda(), inputs["image_grid_thw"].cuda())
patches, _ = hidden_states.shape
heatmap = torch.zeros((patches, 1))  # dummy heatmap [0..1]

heat_inputs = inputs.copy()
heat_inputs["heatmap_flat"] = heatmap.to(torch.bfloat16)  # 
heat_inputs = heat_inputs.to("cuda")
inputs = inputs.to("cuda")

In [5]:
inputs["pixel_values"].shape

torch.Size([14308, 1176])

In [6]:
hidden_states.shape

torch.Size([3577, 2048])

In [154]:
def gen_inputs(inputs):
    generated_ids = model.generate(**inputs, max_new_tokens=128)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    print(output_text)
    return generated_ids, generated_ids_trimmed

# generated_ids_heat, _ = gen_inputs(heat_inputs)

In [7]:
from datasets import Dataset

set_dataset_off = Dataset.load_from_disk("set_eye_dataset_off")

In [78]:
patches, image_grid_thw = self._preprocess(
    image,
    do_resize=self.do_resize,
    size=size,
    resample=resample,
    do_rescale=do_rescale,
    rescale_factor=rescale_factor,
    do_normalize=do_normalize,
    image_mean=[0.48145466],
    image_std=[0.26862954],
    patch_size=patch_size,
    temporal_patch_size=temporal_patch_size,
    merge_size=merge_size,
    do_convert_rgb=do_convert_rgb,
)

In [3]:
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

class ResizeTensor:
    """
    Простой класс-трансформ, который повторяет поведение F.interpolate(..., align_corners=False)
    для тензора формата (C, H, W).
    """
    def __init__(self, size=(14, 14), mode='bilinear', align_corners=False):
        self.size = size
        self.mode = mode
        self.align_corners = align_corners

    def __call__(self, tensor):
        tensor = tensor.unsqueeze(0)  # -> (1, C, H, W)
        tensor = F.interpolate(
            tensor, size=self.size,
            mode=self.mode,
            align_corners=self.align_corners
        )  # -> (1, C, newH, newW)
        return tensor.squeeze(0)      # -> (C, newH, newW)

def get_heatmap_transformation(h, w):
    return transforms.Compose([
        transforms.Lambda(lambda img: img.convert("L")),
        transforms.ToTensor(),  # теперь тензор формата (1, H, W)
        transforms.Lambda(lambda t: (t - t.mean()) / (t.std() + 1e-8)),
        ResizeTensor(size=(int(h), int(w)), mode='bilinear', align_corners=False),
        transforms.Lambda(lambda t: t.flatten())
    ])
# transformation = get_heatmap_transformation(h/2, w/2)
# transformation(set_dataset_off[0]["heatmap"]).shape

In [4]:
from typing import List, Dict, Any


image_folder = "./images/" 
messages_template = lambda image, transcription: [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": image_folder + image,
                "resized_height": 14*40,
                "resized_width": 14*80,
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }, 
    {
        "role": "assistant",
        "content": [
            {"type": "text", "text": transcription},
        ],
    }
]



def formatting_prompt(examples):
    # Допустим, в examples есть списки: ["image"], ["transcribation"].
    images = examples["image"]
    transcriptions = examples["transcribation"]
    heats = examples["heatmap"]

    input_ids_list = []
    attention_mask_list = []
    labels_list = []
    pixel_values_list = []  
    image_grid_thw = []
    heatmaps = []

    for img_path, ans_text, heat in zip(images, transcriptions, heats):
        messages = messages_template(img_path, ans_text)
        text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

        image_inputs, _ = process_vision_info(messages)
        inputs = processor(
            text=[text_prompt],
            images=image_inputs,
            padding="max_length",
            max_length=1050,
            return_tensors="pt",
        )

        input_ids = inputs["input_ids"][0].detach()
        attention_mask = inputs["attention_mask"][0].detach()            

        labels = input_ids.clone()
        labels[:] = -100

        answer_str = f"<|im_start|>assistant\n{ans_text}<|im_end|>\n"
        answer_ids = processor.tokenizer(
            answer_str, add_special_tokens=False
        )["input_ids"]

        start_index = -1
        for i in range(len(input_ids) - len(answer_ids) + 1):
            if all(input_ids[i + j].item() == answer_ids[j]
                   for j in range(len(answer_ids))):
                start_index = i
                break

        if start_index != -1:
            labels[start_index : start_index + len(answer_ids)] = \
                input_ids[start_index : start_index + len(answer_ids)]

        input_ids_list.append(input_ids)
        attention_mask_list.append(attention_mask)
        labels_list.append(labels)
        if "pixel_values" in inputs:
            pixel_values = inputs["pixel_values"].detach()  # shape: [3, H, W]
            thw = inputs["image_grid_thw"][0].detach()
            
            _, h, w = thw
            transformation = get_heatmap_transformation(h/2, w/2)
            pixel_values_list.append(pixel_values)
            image_grid_thw.append(thw)
            heatmap = transformation(heat).to(torch.bfloat16).unsqueeze(1)
            heatmaps.append(heatmap)

    return {
        "input_ids": input_ids_list,
        "attention_mask": attention_mask_list,
        "labels": labels_list,
        "pixel_values": pixel_values_list,
        "image_grid_thw": image_grid_thw,
        "heatmap_flat": heatmaps,
    }

In [82]:
f"<|im_start|>assistant\n{ans_text}<|im_end|>\n", text_prompt

('<|im_start|>assistant\nThe painting depicts some kind of flowers in an abstract form against a red background, with the flowers appearing multicolored.<|im_end|>\n',
 '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>\n<|im_start|>assistant\nThe painting depicts some kind of flowers in an abstract form against a red background, with the flowers appearing multicolored.<|im_end|>\n')

In [96]:
examples = set_dataset_off[:2]

In [56]:
images = examples["image"]
transcriptions = examples["transcribation"]
heats = examples["heatmap"]

img_path, ans_text = images[0], transcriptions[0]
messages = messages_template(img_path, ans_text)
text_prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)

image_inputs, _ = process_vision_info(messages)
inputs = processor(
    text=[text_prompt],
    images=image_inputs,
    padding="max_length",
    max_length=1050,
    return_tensors="pt",
)

input_ids = inputs["input_ids"][0].detach()
attention_mask = inputs["attention_mask"][0].detach()            

labels = input_ids.clone()
labels[:] = -100

In [161]:
from typing import List


def find_substring(input_ids:torch.Tensor, ref_ids: List[int]):
    start_index = -1
    for i in range(len(input_ids) - len(ref_ids) + 1):
        if input_ids[i : i + len(ref_ids)].tolist() == ref_ids:
            start_index = i
            break
    if start_index == -1:
        raise ValueError("Target sequence not found.")
    end_index = start_index + len(ref_ids)
    return start_index, end_index


def create_labels(input_ids: torch.Tensor, answers: List[str]) -> torch.Tensor:
    """
    Create labels for SFT training. It masks all tokens after the start token with excluding_probability
    and after end token for the rest.
    Args:
        input_ids: ids from tokenizer output
        start_token: token that indicates start of selected sequence (for example, simple talk)
        end_token: token that indicates end of selected sequence
        excluding_probability: probability of excluding simple talk from attention

    Returns: tensor with masks  for each input_ids
    """

    labels = torch.full_like(input_ids, fill_value=-100)
    
    for i, row in enumerate(input_ids):
        start_index, end_index = find_substring(row, processor.tokenizer(answers[i], add_special_tokens=False)["input_ids"])
        labels[i, start_index:end_index] = row[start_index:end_index]

    return labels

In [162]:
answer_template = "<|im_start|>assistant\n{ans_text}<|im_end|>\n"

def process_injection(image_grid_thw, features):
    heatmap_flat = []
    for thw, feature in zip(image_grid_thw, features):
        _, h, w = thw
        transformation = get_heatmap_transformation(h/2, w/2)
        heatmap_flat.append(transformation(feature["heatmap"]).unsqueeze(1))

    return torch.stack(heatmap_flat)


def data_collator(features: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
    if not features:
        return {}

    messages = []
    answers = []
    for feature in features:
        messages.append(messages_template(feature["image"], feature["transcribation"]))
        answers.append(answer_template.format(ans_text=feature["transcribation"]))

    texts = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
    image_inputs, _ = process_vision_info(messages)
    batch = processor(text=texts, images=image_inputs, padding=True, return_tensors="pt")  # ['input_ids', 'attention_mask', 'pixel_values', 'image_grid_thw']

    batch["labels"] = create_labels(batch["input_ids"], answers)
    batch["heatmap_flat"] = process_injection(batch["image_grid_thw"], features)

    return batch
    

In [163]:
train_data = set_dataset_off.select(range(10))

In [266]:
out_data.save_to_disk("exploded_dataset")

Saving the dataset (0/2 shards):   0%|          | 0/16380 [00:00<?, ? examples/s]

In [11]:
for param in model.parameters():
    param.requires_grad = False
for param in model.visual.post_merger_injector.heatmap_proj.parameters():
    param.requires_grad = True

In [5]:
from datasets import Dataset

out_data = Dataset.load_from_disk("exploded_dataset")

In [160]:
import torch
from trl import SFTTrainer
from transformers import TrainingArguments, TrainerCallback
# from clearml import Task
from dotenv import load_dotenv

load_dotenv()

training_args = TrainingArguments(
    output_dir="./qwen-checkpoints",
    num_train_epochs=40,
    per_device_train_batch_size=2, #8?
    gradient_accumulation_steps=2,
    learning_rate=3e-4,
    weight_decay=0.01,
    lr_scheduler_type="cosine", #linear?
    warmup_steps=100,
    logging_steps=10,
    save_steps=500,
    fp16=False,
    bf16=True,
    gradient_checkpointing=True,
    optim="adamw_8bit",
    push_to_hub=True,
    report_to=[],
    hub_model_id="Archistrax/Qwen2_5_VL-checkpoints",
)


# class ClearMLCallback(TrainerCallback):
#     def __init__(self, task):
#         self.task = task
#         self.logger = task.get_logger()

#     def on_log(self, args, state, control, logs=None, **kwargs):
#         if logs:
#             for key, value in logs.items():
#                 self.logger.report_scalar(title="Training", series=key, value=value, iteration=state.global_step)


# experiment_name = "qwen2.5-full"
# task = Task.init(
#     project_name="qwen2.5",
#     task_name=experiment_name,
#     output_uri=False
# )


tokenizer = processor.tokenizer
# clearml_callback = ClearMLCallback(task)
# logging.getLogger("clearml").setLevel(logging.CRITICAL)


trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_data,
    data_collator=data_collator,
    args=training_args,
    # callbacks=[clearml_callback]
)

  trainer = SFTTrainer(


Converting train dataset to ChatML:   0%|          | 0/10 [00:00<?, ? examples/s]

Applying chat template to train dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

Tokenizing train dataset:   0%|          | 0/10 [00:00<?, ? examples/s]

KeyError: 'text'

In [13]:

trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


ValueError: Expected input batch_size (2046) to match target batch_size (2098).

In [14]:
train_loader = trainer.get_train_dataloader()

# 2. Берём из него ровно один батч
batch = next(iter(train_loader))

# 3. Можем посмотреть, что в батче
for k, v in batch.items():
    print(k, type(v), v.shape if hasattr(v, "shape") else v)

# 4. Прогоняем через модель (или делаем что угодно)
with torch.no_grad():
    outputs = trainer.model(**batch)
    print(outputs)

input_ids <class 'torch.Tensor'> torch.Size([2, 1024])
labels <class 'torch.Tensor'> torch.Size([2, 1050])
attention_mask <class 'torch.Tensor'> torch.Size([2, 1024])
pixel_values <class 'torch.Tensor'> torch.Size([2, 3200, 1176])
image_grid_thw <class 'torch.Tensor'> torch.Size([2, 3])
heatmap_flat <class 'torch.Tensor'> torch.Size([2, 800, 1])


ValueError: Expected input batch_size (2046) to match target batch_size (2098).

In [16]:
self = trainer.model
input_ids = batch["input_ids"]
inputs_embeds = self.model.embed_tokens(input_ids)

mask = input_ids == self.config.image_token_id
mask_unsqueezed = mask.unsqueeze(-1)
mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
image_mask = mask_expanded.to(inputs_embeds.device)

In [53]:
from typing import Optional, Tuple, List, Union

import torch

input_ids: torch.LongTensor = batch["input_ids"]
attention_mask: Optional[torch.Tensor] = batch["attention_mask"]
position_ids: Optional[torch.LongTensor] = None
past_key_values: Optional[List[torch.FloatTensor]] = None
inputs_embeds: Optional[torch.FloatTensor] = None
labels: Optional[torch.LongTensor] = batch["labels"]
use_cache: Optional[bool] = None
output_attentions: Optional[bool] = None
output_hidden_states: Optional[bool] = None
return_dict: Optional[bool] = None
pixel_values: Optional[torch.Tensor] = batch["pixel_values"]
pixel_values_videos: Optional[torch.FloatTensor] = None
image_grid_thw: Optional[torch.LongTensor] = batch["image_grid_thw"]
video_grid_thw: Optional[torch.LongTensor] = None
rope_deltas: Optional[torch.LongTensor] = None
cache_position: Optional[torch.LongTensor] = None
second_per_grid_ts: Optional[torch.Tensor] = None
heatmap_flat=batch["heatmap_flat"]

In [54]:
self = trainer.model


In [55]:
batch["labels"].shape, batch["input_ids"].shape

(torch.Size([2, 1050]), torch.Size([2, 1024]))

In [23]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
    output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if inputs_embeds is None:
    inputs_embeds = self.model.embed_tokens(input_ids)
    if pixel_values is not None:
        pixel_values = pixel_values.type(self.visual.dtype)
        heatmap_flat = heatmap_flat.reshape(-1, 1).type(self.visual.dtype)  # (b*h*w, 1) Apply flattening for native injection as human attention
        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw, heatmap_flat=heatmap_flat)
        n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
        n_image_features = image_embeds.shape[0]
        if n_image_tokens != n_image_features:
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
            )

        mask = input_ids == self.config.image_token_id
        mask_unsqueezed = mask.unsqueeze(-1)
        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
        image_mask = mask_expanded.to(inputs_embeds.device)

        image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

    if pixel_values_videos is not None:
        pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
        video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
        n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
        n_video_features = video_embeds.shape[0]
        if n_video_tokens != n_video_features:
            raise ValueError(
                f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
            )

        mask = input_ids == self.config.video_token_id
        mask_unsqueezed = mask.unsqueeze(-1)
        mask_expanded = mask_unsqueezed.expand_as(inputs_embeds)
        video_mask = mask_expanded.to(inputs_embeds.device)

        video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
        inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

    if attention_mask is not None:
        attention_mask = attention_mask.to(inputs_embeds.device)

if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
    # calculate RoPE index once per generation in the pre-fill stage only
    if (
            (cache_position is not None and cache_position[0] == 0)
            or self.rope_deltas is None
            or (past_key_values is None or past_key_values.get_seq_length() == 0)
    ):
        position_ids, rope_deltas = self.get_rope_index(
            input_ids,
            image_grid_thw,
            video_grid_thw,
            second_per_grid_ts,
            attention_mask,
        )
        self.rope_deltas = rope_deltas
    # then use the prev pre-calculated rope-deltas to get the correct position ids
    else:
        batch_size, seq_length, _ = inputs_embeds.shape
        delta = (
            (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
            if cache_position is not None
            else 0
        )
        position_ids = torch.arange(seq_length, device=inputs_embeds.device)
        position_ids = position_ids.view(1, -1).expand(batch_size, -1)
        if cache_position is not None:  # otherwise `deltas` is an int `0`
            delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
        position_ids = position_ids.add(delta)
        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

outputs = self.model(
    input_ids=None,
    position_ids=position_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    inputs_embeds=inputs_embeds,
    use_cache=use_cache,
    output_attentions=output_attentions,
    output_hidden_states=output_hidden_states,
    return_dict=return_dict,
    cache_position=cache_position,
)

hidden_states = outputs[0]
logits = self.lm_head(hidden_states)

In [48]:
logits.shape, labels.shape, hidden_states.shape

(torch.Size([2, 1024, 151936]),
 torch.Size([1050]),
 torch.Size([2, 1024, 2048]))

In [51]:
mask_expanded.shape, image_mask.shape, mask_unsqueezed.shape

(torch.Size([2, 1024, 2048]),
 torch.Size([2, 1024, 2048]),
 torch.Size([2, 1024, 1]))

In [52]:
batch["heatmap_flat"].reshape(-1, 1).to(dtype=model.dtype)

tensor([[-0.3008],
        [-0.3008],
        [-0.3008],
        ...,
        [-0.3008],
        [-0.3008],
        [-0.3008]], device='cuda:0', dtype=torch.bfloat16)

In [39]:
pixel_values, image_grid_thw = batch["pixel_values"], batch["image_grid_thw"]

image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw, heatmap_flat=batch["heatmap_flat"].reshape(-1, 1).to(dtype=model.dtype))

In [41]:
image_embeds.shape

torch.Size([6400, 2048])

In [25]:
image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds.masked_scatter(image_mask, image_embeds).shape

torch.Size([8, 1024, 2048])

In [24]:
image_embeds.shape

torch.Size([6400, 2048])

In [47]:
import transformers
from typing import Sequence
from dataclasses import dataclass, field


@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, attention_mask, labels = tuple(
            [instance[key] for instance in instances]
            for key in ('input_ids', 'attention_mask', 'labels')
        )
        # input_ids = torch.nn.utils.rnn.pad_sequence(
        #     input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        # )
        # labels = torch.nn.utils.rnn.pad_sequence(
        #     labels, batch_first=True, padding_value=IGNORE_INDEX
        # )
        # input_ids = input_ids[:, : self.tokenizer.model_max_length]
        # labels = labels[:, : self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=attention_mask
        )
        images = list(
            itertools.chain(
                *(
                    instance["pixel_values"]
                    for instance in instances
                    if "pixel_values" in instance
                )
            )
        )
        videos = list(
            itertools.chain(
                *(
                    instance["pixel_values_videos"]
                    for instance in instances
                    if "pixel_values_videos" in instance
                )
            )
        )
        if len(images) != 0:
            concat_images = torch.cat([image for image in images], dim=0)
            grid_thw = list(
                itertools.chain(
                    *(
                        instance["image_grid_thw"]
                        for instance in instances
                        if "image_grid_thw" in instance
                    )
                )
            )
            grid_thw = torch.stack(grid_thw, dim=0)
        else:
            concat_images = None
            grid_thw = None

        if len(videos) != 0:
            concat_videos = torch.cat([video for video in videos], dim=0)
            video_grid_thw = list(
                itertools.chain(
                    *(
                        instance["video_grid_thw"]
                        for instance in instances
                        if "video_grid_thw" in instance
                    )
                )
            )
            video_grid_thw = torch.stack(video_grid_thw, dim=0)
        else:
            concat_videos = None
            video_grid_thw = None

        batch["pixel_values"] = concat_images
        batch["image_grid_thw"] = grid_thw
        batch["pixel_values_videos"] = concat_videos
        batch["video_grid_thw"] = video_grid_thw
        return batch

In [48]:
collator = DataCollatorForSupervisedDataset(tokenizer=processor.tokenizer)

In [51]:
import intertools

part = out_data.select(range(5))


ModuleNotFoundError: No module named 'intertools'

In [57]:
im = torch.tensor(part[0]['pixel_values'])

In [59]:
torch.cat([im], dim=0).shape

torch.Size([3200, 1176])

In [42]:
out_data = out_data.remove_columns(['human_id', 'image', 'suffix', 'heatmap','audio_file', 'asc_file'])

In [43]:
out_data

Dataset({
    features: ['transcribation', 'input_ids', 'attention_mask', 'labels', 'pixel_values', 'image_grid_thw', 'heatmap_flat'],
    num_rows: 100
})