In [None]:
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
from PIL import Image
import os

device = "cuda"
torch.set_float32_matmul_precision("high")

In [None]:
# =====================================================
# 1️⃣ Load Vision Encoder (Frozen, BF16)
# =====================================================

vision_name = "google/siglip-base-patch16-224"
processor = AutoProcessor.from_pretrained(vision_name)
vision_model = AutoModel.from_pretrained(vision_name).vision_model.to(device, dtype=torch.bfloat16)
vision_model.eval()

for p in vision_model.parameters():
    p.requires_grad = False

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [None]:
# =====================================================
# 2️⃣ Load LLM (BF16)
# =====================================================

llm_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.bfloat16
).to(device)

llm_hidden = llm.config.hidden_size

`torch_dtype` is deprecated! Use `dtype` instead!


In [4]:
# =====================================================
# 3️⃣ Projection Layer (Trainable)
# =====================================================

projector = nn.Linear(768, llm_hidden, bias=False).to(device, dtype=torch.bfloat16)

In [None]:
# ----------------------------------
# Single Image + Caption
# ----------------------------------

image = Image.open("dogs-playing-in-grassy-field.jpg").convert("RGB")
caption = "Two dogs playing."

In [6]:
# Encode image once (frozen)
inputs = processor(images=image, return_tensors="pt").to(device)

with torch.no_grad():
    vision_outputs = vision_model(**inputs)
    patch_tokens_raw = vision_outputs.last_hidden_state.detach()  # (1,196,768)

In [7]:
# =====================================================
# 5️⃣ Tokenize Caption ONLY
# =====================================================

tokens = tokenizer(
    caption,
    return_tensors="pt"
).to(device)

input_ids = tokens.input_ids
attention_mask = tokens.attention_mask

# We want to predict ALL tokens
labels = input_ids.clone()

In [8]:
# =====================================================
# 6️⃣ Optimizer
# =====================================================

optimizer = torch.optim.AdamW(
    list(projector.parameters()) + list(llm.parameters()),
    lr=1e-4
)

llm.train()
projector.train()

Linear(in_features=768, out_features=1024, bias=False)

In [None]:
# =====================================================
# 7️⃣ Training Loop
# =====================================================

for step in range(500):

    # Project vision tokens inside loop
    patch_tokens = projector(patch_tokens_raw)

    # Text embeddings
    text_embeds = llm.get_input_embeddings()(input_ids)

    # Concatenate vision + text
    inputs_embeds = torch.cat([patch_tokens, text_embeds], dim=1)

    # Attention mask
    visual_attention = torch.ones(
        (1, patch_tokens.size(1)),
        device=device,
        dtype=attention_mask.dtype
    )

    full_attention = torch.cat(
        [visual_attention, attention_mask],
        dim=1
    )

    # Mask visual token positions in loss
    visual_label_pad = torch.full(
        (1, patch_tokens.size(1)),
        -100,
        device=device
    )

    full_labels = torch.cat(
        [visual_label_pad, labels],
        dim=1
    )

    outputs = llm(
        inputs_embeds=inputs_embeds,
        attention_mask=full_attention,
        labels=full_labels
    )

    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if step % 50 == 0:
        print(f"Step {step} | Loss: {loss.item()}")

print("Outputs Logits Shape: ", outputs.logits.shape) # (batch, seq_len, vocab_size) -> (1, 200 (196 visual tokens + 4 text), vocab_size)

Step 0 | Loss: 14.136856079101562
Step 50 | Loss: 0.00236580241471529
Step 100 | Loss: 0.0013496836181730032
Step 150 | Loss: 0.0009765183785930276
Step 200 | Loss: 0.0007501288782805204
Step 250 | Loss: 0.0005781014915555716
Step 300 | Loss: 0.0004980136873200536
Step 350 | Loss: 0.00045549171045422554
Step 400 | Loss: 0.0004082410014234483
Step 450 | Loss: 0.00036625508801080287
Outputs Logits Shape:  torch.Size([1, 200, 151936])


In [10]:
# =====================================================
# 8️⃣ Inference Test (With Vision)
# =====================================================

llm.eval()
projector.eval()

with torch.no_grad():

    patch_tokens = projector(patch_tokens_raw)

    # No text prompt at all
    inputs_embeds = patch_tokens

    visual_attention = torch.ones(
        (1, patch_tokens.size(1)),
        device=device
    )

    generated = llm.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=visual_attention,
        max_new_tokens=10,
        do_sample=False
    )

print("\n=== WITH VISION ===")
print(tokenizer.decode(generated[0], skip_special_tokens=True))

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



=== WITH VISION ===
Two dogs playing...... playing


In [11]:
# =====================================================
# 9️⃣ Ablation Test (Zero Vision)
# =====================================================

with torch.no_grad():

    zero_visual = torch.zeros_like(patch_tokens)

    generated_zero = llm.generate(
        inputs_embeds=zero_visual,
        attention_mask=visual_attention,
        max_new_tokens=10,
        do_sample=False
    )

print("\n=== ZERO VISION ===")
print(tokenizer.decode(generated_zero[0], skip_special_tokens=True))


=== ZERO VISION ===
!lll. 2. 3.


In [None]:
# =====================================================
# Inference Test (With Vision) - Different Image
# =====================================================

llm.eval()
projector.eval()

# =====================================================
# Test Images Folder
# =====================================================

image_folder = "images"  # put multiple images here
image_paths = [os.path.join(image_folder, f) for f in os.listdir(image_folder)]

# =====================================================
# Evaluation Loop
# =====================================================

for img_path in image_paths:

    print("\n======================================")
    print("Image:", img_path.split("/")[-1])

    image = Image.open(img_path).convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        vision_outputs = vision_model(**inputs)
        patch_tokens_raw = vision_outputs.last_hidden_state

        patch_tokens = projector(patch_tokens_raw)

        visual_attention = torch.ones(
            (1, patch_tokens.size(1)),
            device=device
        )

    # -------------------------
    # WITH VISION
    # -------------------------
    with torch.no_grad():

        generated = llm.generate(
            inputs_embeds=patch_tokens,
            attention_mask=visual_attention,
            max_new_tokens=10,
            do_sample=False
        )

    output_with_vision = tokenizer.decode(
        generated[0],
        skip_special_tokens=True
    )

    # -------------------------
    # ZERO VISION
    # -------------------------
    with torch.no_grad():

        zero_visual = torch.zeros_like(patch_tokens)

        generated_zero = llm.generate(
            inputs_embeds=zero_visual,
            attention_mask=visual_attention,
            max_new_tokens=10,
            do_sample=False
        )

    output_zero = tokenizer.decode(
        generated_zero[0],
        skip_special_tokens=True
    )

    # -------------------------
    # Print Results
    # -------------------------

    print("WITH VISION :", output_with_vision)
    print("ZERO VISION :", output_zero)


Image: dogs-playing-in-grassy-field.jpg
WITH VISION : Two dogs playing...... playing
ZERO VISION : !lll. 2. 3.

Image: dog_and_girl.jpeg
WITH VISION : Two dogs playing.......
ZERO VISION : !lll. 2. 3.

Image: sample.png
WITH VISION : Two dogs playing.. playing....
ZERO VISION : !lll. 2. 3.

Image: dog.png
WITH VISION : Two dogs playing.......
ZERO VISION : !lll. 2. 3.

Image: gd-dog.jpg
WITH VISION : Two dogs playing.......
ZERO VISION : !lll. 2. 3.


In [17]:
# measure grounding metrics
# we compute

# Δ = || logits_with_vision − logits_without_vision ||

def measure_vision_influence(patch_tokens):
    with torch.no_grad():

        # Random baseline instead of zero
        random_visual = torch.randn_like(patch_tokens) * 0.01

        logits_v = llm(inputs_embeds=patch_tokens).logits[:, -1, :]
        logits_r = llm(inputs_embeds=random_visual).logits[:, -1, :]

        diff = torch.norm(logits_v - logits_r).item()
        base = torch.norm(logits_r).item()

        relative = diff / (base + 1e-8)

    print(f"Vision influence (L2 diff): {diff:.4f}")
    print(f"Baseline norm: {base:.4f}")
    print(f"Relative influence: {relative:.4f}")

In [None]:
# =====================================================
# Inference Test (With Vision) - Different Image
# =====================================================

llm.eval()
projector.eval()

# =====================================================
# Test Images Folder
# =====================================================

image_folder = "images"  # put multiple images here
image_paths = [os.path.join(image_folder, f) for f in os.listdir(image_folder)]

# =====================================================
# Evaluation Loop
# =====================================================

for img_path in image_paths:

    print("\n======================================")
    print("Image:", img_path.split("/")[-1])

    image = Image.open(img_path).convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        vision_outputs = vision_model(**inputs)
        patch_tokens_raw = vision_outputs.last_hidden_state

        patch_tokens = projector(patch_tokens_raw)

        visual_attention = torch.ones(
            (1, patch_tokens.size(1)),
            device=device
        )

        measure_vision_influence(patch_tokens)

    # -------------------------
    # WITH VISION
    # -------------------------
    with torch.no_grad():

        generated = llm.generate(
            inputs_embeds=patch_tokens,
            attention_mask=visual_attention,
            max_new_tokens=10,
            do_sample=False
        )

    output_with_vision = tokenizer.decode(
        generated[0],
        skip_special_tokens=True
    )

    # -------------------------
    # ZERO VISION
    # -------------------------
    with torch.no_grad():

        zero_visual = torch.zeros_like(patch_tokens)

        generated_zero = llm.generate(
            inputs_embeds=zero_visual,
            attention_mask=visual_attention,
            max_new_tokens=10,
            do_sample=False
        )

    output_zero = tokenizer.decode(
        generated_zero[0],
        skip_special_tokens=True
    )

    # -------------------------
    # Print Results
    # -------------------------

    print("WITH VISION :", output_with_vision)
    print("ZERO VISION :", output_zero)


Image: dogs-playing-in-grassy-field.jpg
Vision influence (L2 diff): 2528.0000
Baseline norm: 868.0000
Relative influence: 2.9124
WITH VISION : Two dogs playing...... playing
ZERO VISION : !lll. 2. 3.

Image: dog_and_girl.jpeg
Vision influence (L2 diff): 2240.0000
Baseline norm: 732.0000
Relative influence: 3.0601
WITH VISION : Two dogs playing.......
ZERO VISION : !lll. 2. 3.

Image: sample.png
Vision influence (L2 diff): 2160.0000
Baseline norm: 696.0000
Relative influence: 3.1034
WITH VISION : Two dogs playing.. playing....
ZERO VISION : !lll. 2. 3.

Image: dog.png
Vision influence (L2 diff): 2144.0000
Baseline norm: 884.0000
Relative influence: 2.4253
WITH VISION : Two dogs playing.......
ZERO VISION : !lll. 2. 3.

Image: gd-dog.jpg
Vision influence (L2 diff): 2512.0000
Baseline norm: 784.0000
Relative influence: 3.2041
WITH VISION : Two dogs playing.......
ZERO VISION : !lll. 2. 3.
