In [1]:
%%capture
import os, re
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
    # !pip install --upgrade --no-cache-dir "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    import torch; v = re.match(r"[0-9\.]{3,}", str(torch.__version__)).group(0)
    xformers = "xformers==" + ("0.0.32.post2" if v == "2.8.0" else "0.0.29.post3")
    !pip install --no-deps bitsandbytes accelerate {xformers} peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth
!pip install transformers==4.56.2
!pip install --no-deps trl==0.22.2

In [7]:
import json
from unsloth import FastVisionModel
from datasets import load_dataset
from tqdm import tqdm
import torch

model, processor = FastVisionModel.from_pretrained(
    model_name = "unsloth/Llama-3.2-11B-Vision-Instruct",
    load_in_4bit = True,
)
FastVisionModel.for_inference(model)
# The model all the padding on the left side so we can extract the last token for sentence summary.
processor.tokenizer.padding_side = "left"

==((====))==  Unsloth 2025.11.6: Fast Mllama patching. Transformers: 4.56.2.
   \\   /|    NVIDIA A100-SXM4-80GB. Num GPUs = 1. Max memory: 79.318 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

MllamaForConditionalGeneration(
  (model): MllamaModel(
    (vision_model): MllamaVisionModel(
      (patch_embedding): Conv2d(3, 1280, kernel_size=(14, 14), stride=(14, 14), padding=valid, bias=False)
      (gated_positional_embedding): MllamaPrecomputedPositionEmbedding(
        (tile_embedding): Embedding(9, 8197120)
      )
      (pre_tile_positional_embedding): MllamaPrecomputedAspectRatioEmbedding(
        (embedding): Embedding(9, 5120)
      )
      (post_tile_positional_embedding): MllamaPrecomputedAspectRatioEmbedding(
        (embedding): Embedding(9, 5120)
      )
      (layernorm_pre): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (layernorm_post): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (transformer): MllamaVisionEncoder(
        (layers): ModuleList(
          (0-12): 13 x MllamaVisionEncoderLayer(
            (self_attn): MllamaVisionAttention(
              (q_proj): Linear4bit(in_features=1280, out_features=1280, bias=False)
       

In [4]:
from datasets import load_dataset
dataset = load_dataset("crag-mm-2025/crag-mm-single-turn-public")

Generating validation split:   0%|          | 0/1938 [00:00<?, ? examples/s]

Generating public_test split:   0%|          | 0/1936 [00:00<?, ? examples/s]

### **ELIP**

In [10]:
from torch.utils.data import Dataset, DataLoader

# A. Create a simple Dataset Wrapper
class LlamaAlignmentDataset(Dataset):
    def __init__(self, hf_dataset, processor):
        self.data = hf_dataset
        self.processor = processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        image = item['image']
        prompt = item['turns']['query'][0]
        answer = item['answers']['ans_full'][0]

        return image, prompt, answer

valid_data = dataset['validation'].filter(lambda x: x['image'] is not None)

# pin_memory=True: Speeds up transfer to GPU
train_ds = LlamaAlignmentDataset(valid_data, processor)
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    collate_fn=lambda x: x[0]
)

Filter:   0%|          | 0/1938 [00:00<?, ? examples/s]

In [11]:
import torch
import torch.nn as nn

class TextToVisionBridge(nn.Module):
    def __init__(self, base_model, num_extra_patches=10):
        super().__init__()
        self.model = base_model
        self.vision_model = base_model.vision_model
        self.language_model = base_model.language_model

        # Dynamically get device/dtype
        param = next(base_model.parameters())
        self.device = param.device
        self.dtype = param.dtype

        # --- Linear Bridge ---
        # Maps Text Dim (4096) -> Vision Dim (1280)
        self.text_to_patch = nn.Sequential(
            nn.Linear(4096, 5376),
            nn.GELU(),
            nn.Linear(5376, 6656),
            nn.GELU(),
            nn.Linear(6656, 9216),
            nn.GELU(),
            nn.Linear(9216, num_extra_patches * 1280)
        ).to(self.device, dtype=self.dtype)

        self.num_extra_patches = num_extra_patches

        # Freeze everything
        for p in self.vision_model.parameters(): p.requires_grad = False
        for p in self.language_model.parameters(): p.requires_grad = False

        # Unfreeze Bridge
        for p in self.text_to_patch.parameters(): p.requires_grad = True

    def forward(
        self,
        pixel_values: torch.Tensor,
        input_ids: torch.Tensor = None,
        attention_mask: torch.Tensor = None,
        aspect_ratio_ids: torch.Tensor = None,
        **kwargs,
    ) -> torch.Tensor:

        # Unpack and Flatten Images
        if pixel_values.dim() == 6:
            B, _, num_tiles, C, H, W = pixel_values.shape
        else:
            B, num_tiles, C, H, W = pixel_values.shape

        # Flatten Batch and Tiles together -> [B*Tiles, C, H, W]
        imgs = pixel_values.reshape(B * num_tiles, C, H, W)

        # 2. Get Text Summaries
        with torch.no_grad():
            text_emb = self.language_model.embed_tokens(input_ids)

        # Take last token state since it is decoder only model
        # [Batch_Size, Sequence_Length, Hidden_Dimension] -> [B, 4096]
        text_summary = text_emb[:, -1, :]

        # 3. Generate Extra Patches
        # Map to Vision Dim -> [B, Num_Patches, 1280]
        extra_patches = self.text_to_patch(text_summary).view(B, self.num_extra_patches, 1280)

        # Need [B * Tiles, Num_Patches, 1280]
        # [B, 1, N, D]
        extra_patches = extra_patches.unsqueeze(1)

        # Repeat for every tile -> [B, Tiles, N, D]
        # Duplicate the patches for every tile.
        extra_patches = extra_patches.repeat(1, num_tiles, 1, 1)

        # FLATTEN Batch and Tiles -> [B*Tiles, N, D]
        extra_patches = extra_patches.reshape(1, B * num_tiles, self.num_extra_patches, 1280)


        with torch.no_grad():
            # Convert raw pixels into patch embeddings
            x = self.vision_model.patch_embedding(imgs)

            # Flatten spatial dims (H*W) into Sequence length and transpose.
            # [Batch*Tiles, Seq_Len, 1280]
            x = x.flatten(2).transpose(1, 2)

            # Create a Learnable Class Token (CLS) initialized to zeros.
            cls_token = torch.zeros(x.shape[0], 1, x.shape[-1], device=x.device, dtype=x.dtype)

            # Append the CLS token to the sequence.
            x = torch.cat([cls_token, x], dim=1)

            # Apply Llama 3.2's special gated positional embeddings based on tile ID.
            x = self.vision_model.gated_positional_embedding(x, aspect_ratio_ids)

            # Apply Layer Normalization before the Transformer layers.
            x = self.vision_model.layernorm_pre(x)

        # Concat the extra patches
        x = torch.cat([extra_patches, x], dim=2)
        x = x.flatten(0, 1)


        # Run the Transformer Encoder Layers.
        x = self.vision_model.transformer(x)
        # Save all hidden states (used later for multi-stage feature aggregation).
        all_intermediate_hidden_states = x[1]
        x = x[0]


        with torch.no_grad():
            x = self.vision_model.layernorm_post(x)

        # Project vision features to language model dimension.
        vision_output = self.vision_model.global_transformer(x)

        # Get the indices of layers to extract features from (defined in model config). Llama extract special part of layer
        indices = self.model.config.vision_config.intermediate_layers_indices

        # Stack all layer outputs into one tensor.
        intermediate_hidden_states = torch.stack(all_intermediate_hidden_states, dim=-1)

        # Select only the specific layers required by the configuration.
        intermediate_hidden_states = intermediate_hidden_states[..., indices]

        # Combine intermediate features with the final vision output.
        output = torch.cat([intermediate_hidden_states, vision_output.last_hidden_state.unsqueeze(-1)], dim=-1)

        output = output.flatten(2)

        return output

In [12]:
import torch.nn.functional as F

def get_text_embedding(model, processor, text_list, device):
    """
    Encodes the 'Answer' text into a single vector using the LLM's embedding layer.
    We extract the last token embedding.
    """
    inputs = processor.tokenizer(
        text_list, return_tensors="pt", padding=True, truncation=True
    ).to(device)

    with torch.no_grad():
        embeddings = model.language_model.embed_tokens(inputs["input_ids"])

        # Extract the last token embedding
        last_embedding = embeddings[:, -1, :]

    return last_embedding

print("Starting Training...")
BATCH_SIZE = 1
GRAD_ACCUMULATION = 8
LEARNING_RATE = 2e-5
NUM_EPOCHS = 1
NUM_EXTRA_PATCHES = 10

bridge_model = TextToVisionBridge(model).to(model.device)
bridge_model.text_to_patch.train()
bridge_model.text_to_patch.to(torch.bfloat16)


optimizer = torch.optim.AdamW(bridge_model.text_to_patch.parameters(), lr=LEARNING_RATE)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=50
)
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    optimizer.zero_grad()

    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

    for step, (image, prompt, answer) in enumerate(train_pbar):

        prompt_text = f"{prompt} <|end_of_text|>"
        text_inputs = processor.tokenizer(prompt_text, return_tensors="pt").to(model.device)


        image_inputs = processor(images=image, return_tensors="pt")
        image_inputs = {k: v.to(model.device) for k, v in image_inputs.items()}

        if "pixel_values" in image_inputs:
            image_inputs["pixel_values"] = image_inputs["pixel_values"].to(torch.bfloat16)

        # Output shape: [B, Seq_Len, Vision_Dim]
        vision_outputs = bridge_model(
            pixel_values=image_inputs["pixel_values"],
            input_ids=text_inputs["input_ids"],
            aspect_ratio_ids=image_inputs.get("aspect_ratio_ids", None)
        )

        # Project to llm space
        projected_vision = model.model.multi_modal_projector(vision_outputs)

        # takes every single vision patch
        # Treat every vision patch as an independent embedding. Row is patches column is embedding dimension
        vision_vec = projected_vision.view(-1, 4096)

        target_vec = get_text_embedding(model, processor, [prompt_text], model.device)

        v_norm = F.normalize(vision_vec, p=2, dim=1)
        t_norm = F.normalize(target_vec, p=2, dim=1)

        similarity_map = torch.matmul(v_norm, t_norm.t()).squeeze()

        k = 1500

        # Find the scores of the best K patches
        # Instead of getting one patch or whole image mean we extract a region of patch and calculate it mean.
        # So it more like letting the model have a higher score when they pay more attention of the object text mention
        top_k_scores, top_k_indices = torch.topk(similarity_map, k)

        loss = 1 - top_k_scores.mean()

        # Backward
        loss = loss / GRAD_ACCUMULATION
        loss.backward()

        if (step + 1) % GRAD_ACCUMULATION == 0:
          optimizer.step()
          optimizer.zero_grad()

          # Update Scheduler based on current loss
          current_loss_val = loss.item() * GRAD_ACCUMULATION
          scheduler.step(current_loss_val)

          print(f"Epoch {epoch} | Step {step} | Loss: {current_loss_val:.4f}")

print("Training Complete.")

Starting Training...


Epoch 1:   1%|          | 8/1423 [00:07<13:19,  1.77it/s]

Epoch 0 | Step 7 | Loss: 0.9844


Epoch 1:   1%|          | 16/1423 [00:10<10:39,  2.20it/s]

Epoch 0 | Step 15 | Loss: 0.9883


Epoch 1:   2%|▏         | 24/1423 [00:14<10:24,  2.24it/s]

Epoch 0 | Step 23 | Loss: 0.9844


Epoch 1:   2%|▏         | 32/1423 [00:17<10:17,  2.25it/s]

Epoch 0 | Step 31 | Loss: 0.9844


Epoch 1:   3%|▎         | 40/1423 [00:21<10:07,  2.28it/s]

Epoch 0 | Step 39 | Loss: 0.9844


Epoch 1:   3%|▎         | 48/1423 [00:24<10:09,  2.26it/s]

Epoch 0 | Step 47 | Loss: 0.9805


Epoch 1:   4%|▍         | 56/1423 [00:28<10:18,  2.21it/s]

Epoch 0 | Step 55 | Loss: 0.9805


Epoch 1:   4%|▍         | 64/1423 [00:31<10:01,  2.26it/s]

Epoch 0 | Step 63 | Loss: 0.9844


Epoch 1:   5%|▌         | 72/1423 [00:35<10:07,  2.23it/s]

Epoch 0 | Step 71 | Loss: 0.9844


Epoch 1:   6%|▌         | 80/1423 [00:38<10:10,  2.20it/s]

Epoch 0 | Step 79 | Loss: 0.9766


Epoch 1:   6%|▌         | 88/1423 [00:42<09:54,  2.25it/s]

Epoch 0 | Step 87 | Loss: 0.9805


Epoch 1:   7%|▋         | 96/1423 [00:45<09:47,  2.26it/s]

Epoch 0 | Step 95 | Loss: 0.9766


Epoch 1:   7%|▋         | 104/1423 [00:49<09:42,  2.26it/s]

Epoch 0 | Step 103 | Loss: 0.9766


Epoch 1:   8%|▊         | 112/1423 [00:52<09:43,  2.25it/s]

Epoch 0 | Step 111 | Loss: 0.9805


Epoch 1:   8%|▊         | 120/1423 [00:56<09:40,  2.24it/s]

Epoch 0 | Step 119 | Loss: 0.9766


Epoch 1:   9%|▉         | 128/1423 [00:59<09:29,  2.27it/s]

Epoch 0 | Step 127 | Loss: 0.9727


Epoch 1:  10%|▉         | 136/1423 [01:02<09:34,  2.24it/s]

Epoch 0 | Step 135 | Loss: 0.9727


Epoch 1:  10%|█         | 144/1423 [01:06<09:30,  2.24it/s]

Epoch 0 | Step 143 | Loss: 0.9727


Epoch 1:  11%|█         | 152/1423 [01:09<09:35,  2.21it/s]

Epoch 0 | Step 151 | Loss: 0.9648


Epoch 1:  11%|█         | 160/1423 [01:13<09:19,  2.26it/s]

Epoch 0 | Step 159 | Loss: 0.9727


Epoch 1:  12%|█▏        | 168/1423 [01:16<09:44,  2.15it/s]

Epoch 0 | Step 167 | Loss: 0.9688


Epoch 1:  12%|█▏        | 176/1423 [01:20<09:16,  2.24it/s]

Epoch 0 | Step 175 | Loss: 0.9648


Epoch 1:  13%|█▎        | 184/1423 [01:23<09:10,  2.25it/s]

Epoch 0 | Step 183 | Loss: 0.9648


Epoch 1:  13%|█▎        | 192/1423 [01:27<09:09,  2.24it/s]

Epoch 0 | Step 191 | Loss: 0.9609


Epoch 1:  14%|█▍        | 200/1423 [01:30<09:02,  2.26it/s]

Epoch 0 | Step 199 | Loss: 0.9648


Epoch 1:  15%|█▍        | 208/1423 [01:34<08:58,  2.26it/s]

Epoch 0 | Step 207 | Loss: 0.9648


Epoch 1:  15%|█▌        | 216/1423 [01:37<08:53,  2.26it/s]

Epoch 0 | Step 215 | Loss: 0.9609


Epoch 1:  16%|█▌        | 224/1423 [01:41<08:51,  2.25it/s]

Epoch 0 | Step 223 | Loss: 0.9609


Epoch 1:  16%|█▋        | 232/1423 [01:44<08:46,  2.26it/s]

Epoch 0 | Step 231 | Loss: 0.9570


Epoch 1:  17%|█▋        | 240/1423 [01:48<08:45,  2.25it/s]

Epoch 0 | Step 239 | Loss: 0.9609


Epoch 1:  17%|█▋        | 248/1423 [01:51<08:37,  2.27it/s]

Epoch 0 | Step 247 | Loss: 0.9570


Epoch 1:  18%|█▊        | 256/1423 [01:55<08:36,  2.26it/s]

Epoch 0 | Step 255 | Loss: 0.9570


Epoch 1:  19%|█▊        | 264/1423 [01:58<08:28,  2.28it/s]

Epoch 0 | Step 263 | Loss: 0.9570


Epoch 1:  19%|█▉        | 272/1423 [02:01<08:28,  2.26it/s]

Epoch 0 | Step 271 | Loss: 0.9570


Epoch 1:  20%|█▉        | 280/1423 [02:05<08:27,  2.25it/s]

Epoch 0 | Step 279 | Loss: 0.9570


Epoch 1:  20%|██        | 288/1423 [02:08<08:22,  2.26it/s]

Epoch 0 | Step 287 | Loss: 0.9570


Epoch 1:  21%|██        | 296/1423 [02:12<08:17,  2.26it/s]

Epoch 0 | Step 295 | Loss: 0.9609


Epoch 1:  21%|██▏       | 304/1423 [02:15<08:25,  2.21it/s]

Epoch 0 | Step 303 | Loss: 0.9531


Epoch 1:  22%|██▏       | 312/1423 [02:19<08:27,  2.19it/s]

Epoch 0 | Step 311 | Loss: 0.9570


Epoch 1:  22%|██▏       | 320/1423 [02:22<08:19,  2.21it/s]

Epoch 0 | Step 319 | Loss: 0.9492


Epoch 1:  23%|██▎       | 328/1423 [02:26<08:09,  2.24it/s]

Epoch 0 | Step 327 | Loss: 0.9570


Epoch 1:  24%|██▎       | 336/1423 [02:29<07:59,  2.27it/s]

Epoch 0 | Step 335 | Loss: 0.9492


Epoch 1:  24%|██▍       | 344/1423 [02:33<07:58,  2.26it/s]

Epoch 0 | Step 343 | Loss: 0.9609


Epoch 1:  25%|██▍       | 352/1423 [02:36<07:54,  2.26it/s]

Epoch 0 | Step 351 | Loss: 0.9492


Epoch 1:  25%|██▌       | 360/1423 [02:40<07:51,  2.26it/s]

Epoch 0 | Step 359 | Loss: 0.9492


Epoch 1:  26%|██▌       | 368/1423 [02:43<07:46,  2.26it/s]

Epoch 0 | Step 367 | Loss: 0.9453


Epoch 1:  26%|██▋       | 376/1423 [02:47<07:43,  2.26it/s]

Epoch 0 | Step 375 | Loss: 0.9492


Epoch 1:  27%|██▋       | 384/1423 [02:50<07:35,  2.28it/s]

Epoch 0 | Step 383 | Loss: 0.9492


Epoch 1:  28%|██▊       | 392/1423 [02:53<07:35,  2.27it/s]

Epoch 0 | Step 391 | Loss: 0.9648


Epoch 1:  28%|██▊       | 400/1423 [02:57<07:29,  2.27it/s]

Epoch 0 | Step 399 | Loss: 0.9453


Epoch 1:  29%|██▊       | 408/1423 [03:00<07:27,  2.27it/s]

Epoch 0 | Step 407 | Loss: 0.9453


Epoch 1:  29%|██▉       | 416/1423 [03:04<07:23,  2.27it/s]

Epoch 0 | Step 415 | Loss: 0.9570


Epoch 1:  30%|██▉       | 424/1423 [03:07<07:28,  2.23it/s]

Epoch 0 | Step 423 | Loss: 0.9492


Epoch 1:  30%|███       | 432/1423 [03:11<07:17,  2.26it/s]

Epoch 0 | Step 431 | Loss: 0.9453


Epoch 1:  31%|███       | 440/1423 [03:14<07:10,  2.28it/s]

Epoch 0 | Step 439 | Loss: 0.9492


Epoch 1:  31%|███▏      | 448/1423 [03:18<07:08,  2.28it/s]

Epoch 0 | Step 447 | Loss: 0.9453


Epoch 1:  32%|███▏      | 456/1423 [03:21<07:06,  2.27it/s]

Epoch 0 | Step 455 | Loss: 0.9453


Epoch 1:  33%|███▎      | 464/1423 [03:24<07:00,  2.28it/s]

Epoch 0 | Step 463 | Loss: 0.9492


Epoch 1:  33%|███▎      | 472/1423 [03:28<06:55,  2.29it/s]

Epoch 0 | Step 471 | Loss: 0.9492


Epoch 1:  34%|███▎      | 480/1423 [03:31<06:56,  2.26it/s]

Epoch 0 | Step 479 | Loss: 0.9375


Epoch 1:  34%|███▍      | 488/1423 [03:35<06:51,  2.27it/s]

Epoch 0 | Step 487 | Loss: 0.9414


Epoch 1:  35%|███▍      | 496/1423 [03:38<06:50,  2.26it/s]

Epoch 0 | Step 495 | Loss: 0.9375


Epoch 1:  35%|███▌      | 504/1423 [03:42<06:45,  2.27it/s]

Epoch 0 | Step 503 | Loss: 0.9414


Epoch 1:  36%|███▌      | 512/1423 [03:45<06:52,  2.21it/s]

Epoch 0 | Step 511 | Loss: 0.9492


Epoch 1:  37%|███▋      | 520/1423 [03:49<06:44,  2.23it/s]

Epoch 0 | Step 519 | Loss: 0.9453


Epoch 1:  37%|███▋      | 528/1423 [03:52<06:41,  2.23it/s]

Epoch 0 | Step 527 | Loss: 0.9414


Epoch 1:  38%|███▊      | 536/1423 [03:56<06:36,  2.23it/s]

Epoch 0 | Step 535 | Loss: 0.9414


Epoch 1:  38%|███▊      | 544/1423 [03:59<06:37,  2.21it/s]

Epoch 0 | Step 543 | Loss: 0.9414


Epoch 1:  39%|███▉      | 552/1423 [04:03<06:27,  2.25it/s]

Epoch 0 | Step 551 | Loss: 0.9375


Epoch 1:  39%|███▉      | 560/1423 [04:06<06:24,  2.24it/s]

Epoch 0 | Step 559 | Loss: 0.9492


Epoch 1:  40%|███▉      | 568/1423 [04:11<06:47,  2.10it/s]

Epoch 0 | Step 567 | Loss: 0.9453


Epoch 1:  40%|████      | 576/1423 [04:14<06:22,  2.21it/s]

Epoch 0 | Step 575 | Loss: 0.9414


Epoch 1:  41%|████      | 584/1423 [04:19<06:54,  2.02it/s]

Epoch 0 | Step 583 | Loss: 0.9375


Epoch 1:  42%|████▏     | 592/1423 [04:22<06:18,  2.20it/s]

Epoch 0 | Step 591 | Loss: 0.9414


Epoch 1:  42%|████▏     | 600/1423 [04:26<06:02,  2.27it/s]

Epoch 0 | Step 599 | Loss: 0.9375


Epoch 1:  43%|████▎     | 608/1423 [04:29<05:58,  2.28it/s]

Epoch 0 | Step 607 | Loss: 0.9453


Epoch 1:  43%|████▎     | 616/1423 [04:32<05:53,  2.28it/s]

Epoch 0 | Step 615 | Loss: 0.9336


Epoch 1:  44%|████▍     | 624/1423 [04:37<06:18,  2.11it/s]

Epoch 0 | Step 623 | Loss: 0.9336


Epoch 1:  44%|████▍     | 632/1423 [04:40<05:48,  2.27it/s]

Epoch 0 | Step 631 | Loss: 0.9375


Epoch 1:  45%|████▍     | 640/1423 [04:44<05:45,  2.26it/s]

Epoch 0 | Step 639 | Loss: 0.9375


Epoch 1:  46%|████▌     | 648/1423 [04:48<06:08,  2.10it/s]

Epoch 0 | Step 647 | Loss: 0.9414


Epoch 1:  46%|████▌     | 656/1423 [04:51<05:43,  2.24it/s]

Epoch 0 | Step 655 | Loss: 0.9414


Epoch 1:  47%|████▋     | 664/1423 [04:55<05:36,  2.25it/s]

Epoch 0 | Step 663 | Loss: 0.9375


Epoch 1:  47%|████▋     | 672/1423 [04:58<05:33,  2.25it/s]

Epoch 0 | Step 671 | Loss: 0.9375


Epoch 1:  48%|████▊     | 680/1423 [05:02<05:31,  2.24it/s]

Epoch 0 | Step 679 | Loss: 0.9375


Epoch 1:  48%|████▊     | 688/1423 [05:05<05:23,  2.27it/s]

Epoch 0 | Step 687 | Loss: 0.9375


Epoch 1:  49%|████▉     | 696/1423 [05:09<05:20,  2.27it/s]

Epoch 0 | Step 695 | Loss: 0.9375


Epoch 1:  49%|████▉     | 704/1423 [05:12<05:17,  2.27it/s]

Epoch 0 | Step 703 | Loss: 0.9375


Epoch 1:  50%|█████     | 712/1423 [05:16<05:18,  2.24it/s]

Epoch 0 | Step 711 | Loss: 0.9375


Epoch 1:  51%|█████     | 720/1423 [05:19<05:10,  2.27it/s]

Epoch 0 | Step 719 | Loss: 0.9375


Epoch 1:  51%|█████     | 728/1423 [05:22<05:08,  2.25it/s]

Epoch 0 | Step 727 | Loss: 0.9414


Epoch 1:  52%|█████▏    | 736/1423 [05:26<05:04,  2.26it/s]

Epoch 0 | Step 735 | Loss: 0.9336


Epoch 1:  52%|█████▏    | 744/1423 [05:29<04:58,  2.28it/s]

Epoch 0 | Step 743 | Loss: 0.9336


Epoch 1:  53%|█████▎    | 752/1423 [05:33<04:55,  2.27it/s]

Epoch 0 | Step 751 | Loss: 0.9375


Epoch 1:  53%|█████▎    | 760/1423 [05:36<04:53,  2.26it/s]

Epoch 0 | Step 759 | Loss: 0.9453


Epoch 1:  54%|█████▍    | 768/1423 [05:40<04:50,  2.26it/s]

Epoch 0 | Step 767 | Loss: 0.9336


Epoch 1:  55%|█████▍    | 776/1423 [05:43<04:43,  2.28it/s]

Epoch 0 | Step 775 | Loss: 0.9336


Epoch 1:  55%|█████▌    | 784/1423 [05:47<04:43,  2.25it/s]

Epoch 0 | Step 783 | Loss: 0.9258


Epoch 1:  56%|█████▌    | 792/1423 [05:50<04:40,  2.25it/s]

Epoch 0 | Step 791 | Loss: 0.9336


Epoch 1:  56%|█████▌    | 800/1423 [05:54<04:39,  2.23it/s]

Epoch 0 | Step 799 | Loss: 0.9375


Epoch 1:  57%|█████▋    | 808/1423 [05:57<04:30,  2.28it/s]

Epoch 0 | Step 807 | Loss: 0.9453


Epoch 1:  57%|█████▋    | 816/1423 [06:00<04:26,  2.28it/s]

Epoch 0 | Step 815 | Loss: 0.9336


Epoch 1:  58%|█████▊    | 824/1423 [06:04<04:29,  2.22it/s]

Epoch 0 | Step 823 | Loss: 0.9375


Epoch 1:  58%|█████▊    | 832/1423 [06:07<04:22,  2.25it/s]

Epoch 0 | Step 831 | Loss: 0.9336


Epoch 1:  59%|█████▉    | 840/1423 [06:11<04:17,  2.27it/s]

Epoch 0 | Step 839 | Loss: 0.9297


Epoch 1:  60%|█████▉    | 848/1423 [06:14<04:15,  2.25it/s]

Epoch 0 | Step 847 | Loss: 0.9375


Epoch 1:  60%|██████    | 856/1423 [06:18<04:16,  2.21it/s]

Epoch 0 | Step 855 | Loss: 0.9297


Epoch 1:  61%|██████    | 864/1423 [06:21<04:06,  2.27it/s]

Epoch 0 | Step 863 | Loss: 0.9258


Epoch 1:  61%|██████▏   | 872/1423 [06:25<04:03,  2.26it/s]

Epoch 0 | Step 871 | Loss: 0.9336


Epoch 1:  62%|██████▏   | 880/1423 [06:28<04:02,  2.24it/s]

Epoch 0 | Step 879 | Loss: 0.9297


Epoch 1:  62%|██████▏   | 888/1423 [06:32<03:55,  2.27it/s]

Epoch 0 | Step 887 | Loss: 0.9336


Epoch 1:  63%|██████▎   | 896/1423 [06:35<03:52,  2.27it/s]

Epoch 0 | Step 895 | Loss: 0.9297


Epoch 1:  64%|██████▎   | 904/1423 [06:39<03:52,  2.24it/s]

Epoch 0 | Step 903 | Loss: 0.9336


Epoch 1:  64%|██████▍   | 912/1423 [06:42<03:44,  2.28it/s]

Epoch 0 | Step 911 | Loss: 0.9297


Epoch 1:  65%|██████▍   | 920/1423 [06:46<03:40,  2.28it/s]

Epoch 0 | Step 919 | Loss: 0.9297


Epoch 1:  65%|██████▌   | 928/1423 [06:49<03:37,  2.28it/s]

Epoch 0 | Step 927 | Loss: 0.9297


Epoch 1:  66%|██████▌   | 936/1423 [06:52<03:37,  2.24it/s]

Epoch 0 | Step 935 | Loss: 0.9336


Epoch 1:  66%|██████▋   | 944/1423 [06:56<03:33,  2.24it/s]

Epoch 0 | Step 943 | Loss: 0.9336


Epoch 1:  67%|██████▋   | 952/1423 [06:59<03:26,  2.28it/s]

Epoch 0 | Step 951 | Loss: 0.9336


Epoch 1:  67%|██████▋   | 960/1423 [07:03<03:24,  2.27it/s]

Epoch 0 | Step 959 | Loss: 0.9375


Epoch 1:  68%|██████▊   | 968/1423 [07:06<03:22,  2.25it/s]

Epoch 0 | Step 967 | Loss: 0.9297


Epoch 1:  69%|██████▊   | 976/1423 [07:10<03:16,  2.27it/s]

Epoch 0 | Step 975 | Loss: 0.9258


Epoch 1:  69%|██████▉   | 984/1423 [07:13<03:15,  2.25it/s]

Epoch 0 | Step 983 | Loss: 0.9297


Epoch 1:  70%|██████▉   | 992/1423 [07:17<03:12,  2.24it/s]

Epoch 0 | Step 991 | Loss: 0.9414


Epoch 1:  70%|███████   | 1000/1423 [07:20<03:06,  2.27it/s]

Epoch 0 | Step 999 | Loss: 0.9297


Epoch 1:  71%|███████   | 1008/1423 [07:24<03:03,  2.26it/s]

Epoch 0 | Step 1007 | Loss: 0.9297


Epoch 1:  71%|███████▏  | 1016/1423 [07:27<02:58,  2.27it/s]

Epoch 0 | Step 1015 | Loss: 0.9336


Epoch 1:  72%|███████▏  | 1024/1423 [07:30<02:57,  2.25it/s]

Epoch 0 | Step 1023 | Loss: 0.9180


Epoch 1:  73%|███████▎  | 1032/1423 [07:34<02:52,  2.26it/s]

Epoch 0 | Step 1031 | Loss: 0.9297


Epoch 1:  73%|███████▎  | 1040/1423 [07:37<02:50,  2.25it/s]

Epoch 0 | Step 1039 | Loss: 0.9297


Epoch 1:  74%|███████▎  | 1048/1423 [07:41<02:47,  2.24it/s]

Epoch 0 | Step 1047 | Loss: 0.9336


Epoch 1:  74%|███████▍  | 1056/1423 [07:44<02:43,  2.24it/s]

Epoch 0 | Step 1055 | Loss: 0.9297


Epoch 1:  75%|███████▍  | 1064/1423 [07:48<02:36,  2.29it/s]

Epoch 0 | Step 1063 | Loss: 0.9297


Epoch 1:  75%|███████▌  | 1072/1423 [07:51<02:35,  2.25it/s]

Epoch 0 | Step 1071 | Loss: 0.9219


Epoch 1:  76%|███████▌  | 1080/1423 [07:55<02:31,  2.26it/s]

Epoch 0 | Step 1079 | Loss: 0.9297


Epoch 1:  76%|███████▋  | 1088/1423 [07:58<02:27,  2.27it/s]

Epoch 0 | Step 1087 | Loss: 0.9297


Epoch 1:  77%|███████▋  | 1096/1423 [08:02<02:24,  2.27it/s]

Epoch 0 | Step 1095 | Loss: 0.9219


Epoch 1:  78%|███████▊  | 1104/1423 [08:05<02:21,  2.25it/s]

Epoch 0 | Step 1103 | Loss: 0.9297


Epoch 1:  78%|███████▊  | 1112/1423 [08:08<02:18,  2.24it/s]

Epoch 0 | Step 1111 | Loss: 0.9297


Epoch 1:  79%|███████▊  | 1120/1423 [08:12<02:13,  2.27it/s]

Epoch 0 | Step 1119 | Loss: 0.9297


Epoch 1:  79%|███████▉  | 1128/1423 [08:15<02:10,  2.25it/s]

Epoch 0 | Step 1127 | Loss: 0.9297


Epoch 1:  80%|███████▉  | 1136/1423 [08:19<02:09,  2.21it/s]

Epoch 0 | Step 1135 | Loss: 0.9297


Epoch 1:  80%|████████  | 1144/1423 [08:22<02:03,  2.27it/s]

Epoch 0 | Step 1143 | Loss: 0.9219


Epoch 1:  81%|████████  | 1152/1423 [08:26<01:59,  2.27it/s]

Epoch 0 | Step 1151 | Loss: 0.9297


Epoch 1:  82%|████████▏ | 1160/1423 [08:29<01:56,  2.27it/s]

Epoch 0 | Step 1159 | Loss: 0.9180


Epoch 1:  82%|████████▏ | 1168/1423 [08:33<01:52,  2.26it/s]

Epoch 0 | Step 1167 | Loss: 0.9219


Epoch 1:  83%|████████▎ | 1176/1423 [08:36<01:48,  2.28it/s]

Epoch 0 | Step 1175 | Loss: 0.9219


Epoch 1:  83%|████████▎ | 1184/1423 [08:39<01:45,  2.26it/s]

Epoch 0 | Step 1183 | Loss: 0.9180


Epoch 1:  84%|████████▍ | 1192/1423 [08:43<01:41,  2.27it/s]

Epoch 0 | Step 1191 | Loss: 0.9297


Epoch 1:  84%|████████▍ | 1200/1423 [08:46<01:37,  2.28it/s]

Epoch 0 | Step 1199 | Loss: 0.9414


Epoch 1:  85%|████████▍ | 1208/1423 [08:50<01:34,  2.27it/s]

Epoch 0 | Step 1207 | Loss: 0.9258


Epoch 1:  85%|████████▌ | 1216/1423 [08:53<01:31,  2.26it/s]

Epoch 0 | Step 1215 | Loss: 0.9258


Epoch 1:  86%|████████▌ | 1224/1423 [08:57<01:29,  2.23it/s]

Epoch 0 | Step 1223 | Loss: 0.9258


Epoch 1:  87%|████████▋ | 1232/1423 [09:00<01:25,  2.24it/s]

Epoch 0 | Step 1231 | Loss: 0.9219


Epoch 1:  87%|████████▋ | 1240/1423 [09:04<01:20,  2.27it/s]

Epoch 0 | Step 1239 | Loss: 0.9297


Epoch 1:  88%|████████▊ | 1248/1423 [09:07<01:17,  2.27it/s]

Epoch 0 | Step 1247 | Loss: 0.9219


Epoch 1:  88%|████████▊ | 1256/1423 [09:10<01:14,  2.25it/s]

Epoch 0 | Step 1255 | Loss: 0.9219


Epoch 1:  89%|████████▉ | 1264/1423 [09:15<01:20,  1.98it/s]

Epoch 0 | Step 1263 | Loss: 0.9297


Epoch 1:  89%|████████▉ | 1272/1423 [09:18<01:07,  2.24it/s]

Epoch 0 | Step 1271 | Loss: 0.9219


Epoch 1:  90%|████████▉ | 1280/1423 [09:22<01:04,  2.22it/s]

Epoch 0 | Step 1279 | Loss: 0.9297


Epoch 1:  91%|█████████ | 1288/1423 [09:25<00:59,  2.26it/s]

Epoch 0 | Step 1287 | Loss: 0.9258


Epoch 1:  91%|█████████ | 1296/1423 [09:29<00:55,  2.27it/s]

Epoch 0 | Step 1295 | Loss: 0.9375


Epoch 1:  92%|█████████▏| 1304/1423 [09:32<00:52,  2.25it/s]

Epoch 0 | Step 1303 | Loss: 0.9375


Epoch 1:  92%|█████████▏| 1312/1423 [09:35<00:48,  2.27it/s]

Epoch 0 | Step 1311 | Loss: 0.9258


Epoch 1:  93%|█████████▎| 1320/1423 [09:39<00:45,  2.26it/s]

Epoch 0 | Step 1319 | Loss: 0.9141


Epoch 1:  93%|█████████▎| 1328/1423 [09:42<00:41,  2.27it/s]

Epoch 0 | Step 1327 | Loss: 0.9219


Epoch 1:  94%|█████████▍| 1336/1423 [09:46<00:38,  2.26it/s]

Epoch 0 | Step 1335 | Loss: 0.9258


Epoch 1:  94%|█████████▍| 1344/1423 [09:49<00:34,  2.28it/s]

Epoch 0 | Step 1343 | Loss: 0.9219


Epoch 1:  95%|█████████▌| 1352/1423 [09:53<00:31,  2.27it/s]

Epoch 0 | Step 1351 | Loss: 0.9219


Epoch 1:  96%|█████████▌| 1360/1423 [09:56<00:27,  2.27it/s]

Epoch 0 | Step 1359 | Loss: 0.9297


Epoch 1:  96%|█████████▌| 1368/1423 [10:00<00:24,  2.26it/s]

Epoch 0 | Step 1367 | Loss: 0.9219


Epoch 1:  97%|█████████▋| 1376/1423 [10:03<00:21,  2.21it/s]

Epoch 0 | Step 1375 | Loss: 0.9180


Epoch 1:  97%|█████████▋| 1384/1423 [10:06<00:17,  2.25it/s]

Epoch 0 | Step 1383 | Loss: 0.9258


Epoch 1:  98%|█████████▊| 1392/1423 [10:10<00:13,  2.27it/s]

Epoch 0 | Step 1391 | Loss: 0.9258


Epoch 1:  98%|█████████▊| 1400/1423 [10:13<00:10,  2.25it/s]

Epoch 0 | Step 1399 | Loss: 0.9258


Epoch 1:  99%|█████████▉| 1408/1423 [10:17<00:06,  2.28it/s]

Epoch 0 | Step 1407 | Loss: 0.9219


Epoch 1: 100%|█████████▉| 1416/1423 [10:20<00:03,  2.30it/s]

Epoch 0 | Step 1415 | Loss: 0.9258


Epoch 1: 100%|██████████| 1423/1423 [10:24<00:00,  2.28it/s]

Training Complete.





In [16]:
example = dataset['public_test'][0]
image = example['image']
prompt = example['turns']['query'][0]

bridge_model.eval()
inputs = processor(image, prompt, return_tensors="pt").to(model.device)
inputs = {
    k: (v.to(torch.bfloat16) if torch.is_floating_point(v) else v)
    for k, v in inputs.items()
}

out = bridge_model(**inputs)
result = model.model.multi_modal_projector(out)

batch_size = 1
num_images = 1
num_tiles = 4

# Reshape: [4, 1611, 4096] -> [1, 1, 4, 1611, 4096]
cross_attention_states = result.view(batch_size, num_images, num_tiles, 1611, 4096)

cross_attention_mask = torch.ones(
    (batch_size, num_images, num_tiles, 1611),
    dtype=torch.bfloat16,
    device=result.device
)
cross_attention_mask = cross_attention_mask.view(batch_size, 1, 1, -1)

with torch.no_grad():
    outputs = model.language_model(
    input_ids=inputs['input_ids'],
    attention_mask=inputs['attention_mask'],
    cross_attention_states=cross_attention_states,
    cross_attention_mask=cross_attention_mask
)


input_ids = inputs['input_ids']
generated_text = ""

print("Generating...")

for i in range(150):
    outputs = model.language_model(
        input_ids=input_ids,
        cross_attention_states=cross_attention_states,
        cross_attention_mask=cross_attention_mask
    )

    logits = model.lm_head(outputs.last_hidden_state)

    # Get the next token ID
    next_token_logits = logits[:, -1, :]
    next_token_id = torch.argmax(next_token_logits, dim=-1).unsqueeze(-1)

    # 4. Check for EOS (End of Sentence)
    if next_token_id.item() == processor.tokenizer.eos_token_id:
        print("\n<EOS reached>")
        break

    input_ids = torch.cat([input_ids, next_token_id], dim=1)

    # Decode and print
    word = processor.tokenizer.decode(next_token_id[0])
    generated_text += word
    print(word, end="", flush=True)

Generating...
The vehicle's model is a 1969-1970 Chevrolet Camaro, and the engine is a 350. The 1969-1970 Chevrolet Camaro with a 350 engine has a top speed of approximately 120 mph. The 1969-1970 Chevrolet Camaro with a 350 engine has a top speed of approximately 120 mph. The 1969-1970 Chevrolet Camaro with a 350 engine has a top speed of approximately 120 mph. The 1969-1970 Chevrolet Camaro with a 350 engine has a top speed of approximately 120 mph. The 1969-1970 Chevrolet Camaro with a 350 engine has a top speed of approximately 120 mph. The