In [None]:
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
from PIL import Image

device = "cuda"
torch.set_float32_matmul_precision("high")

In [None]:
# ----------------------------------
# Load Vision Encoder (Frozen)
# ----------------------------------

vision_name = "google/siglip-base-patch16-224"
processor = AutoProcessor.from_pretrained(vision_name, use_fast=True)
vision_model = AutoModel.from_pretrained(vision_name).vision_model.to(device)
vision_model.eval()

for p in vision_model.parameters():
    p.requires_grad = False

In [None]:
# ----------------------------------
# Load LLM
# ----------------------------------

llm_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(
    llm_name,
).to(device)

llm_hidden = llm.config.hidden_size

print("LLM Hidden State: ", llm_hidden)

LLM Hidden State:  1024


In [4]:
# ----------------------------------
# Projection Layer (Trainable)
# ----------------------------------

projector = nn.Linear(768, llm_hidden, bias=False).to(device, dtype=torch.bfloat16)

In [None]:
# ----------------------------------
# Single Image + Caption
# ----------------------------------

image = Image.open("dogs-playing-in-grassy-field.jpg").convert("RGB")
caption = "Two dogs playing."

In [6]:
inputs = processor(images=image, return_tensors="pt").to(device)

In [7]:
with torch.no_grad():
    vision_outputs = vision_model(**inputs)
    patch_tokens_raw = vision_outputs.last_hidden_state # (1, 196, 768)

print("Patch Tokens Size: ", patch_tokens_raw.shape)

Patch Tokens Size:  torch.Size([1, 196, 768])


In [8]:
# ----------------------------------
# Text Prompt
# ----------------------------------

prompt = f"USER: <image>\nDescribe the image.\nASSISTANT: {caption}"

tokens = tokenizer(
    prompt,
    return_tensors="pt"
).to(device)

input_ids = tokens.input_ids
attention_mask = tokens.attention_mask

In [9]:
# ----------------------------------
# Prepare Labels (IMPORTANT)
# ----------------------------------

labels = input_ids.clone()

# Mask everything except assistant answer
assistant_start = prompt.index("ASSISTANT:")
assistant_tokens = tokenizer(
    prompt[:assistant_start],
    return_tensors="pt"
).input_ids.shape[1]

labels[:, :assistant_tokens] = -100

In [10]:
# ----------------------------------
# Training Setup
# ----------------------------------

optimizer = torch.optim.AdamW(
    list(projector.parameters()) + list(llm.parameters()),
    lr=1e-4
)

llm.train()
projector.train()

Linear(in_features=768, out_features=1024, bias=False)

In [11]:
# ----------------------------------
# Training Loop
# ----------------------------------

for step in range(500):
    
    patch_tokens = patch_tokens_raw.to(torch.bfloat16)
    patch_tokens = projector(patch_tokens)

    text_embeds = llm.get_input_embeddings()(input_ids)

    inputs_embeds = torch.cat([patch_tokens, text_embeds], dim=1)

    visual_attention = torch.ones(
        (1, patch_tokens.size(1)),
        device=device
    )

    full_attention = torch.cat(
        [visual_attention, attention_mask],
        dim=1
    )

    # Pad labels for visual tokens
    visual_label_pad = torch.full(
        (1, patch_tokens.size(1)),
        -100,
        device=device
    )

    full_labels = torch.cat(
        [visual_label_pad, labels],
        dim=1
    )

    outputs = llm(
        inputs_embeds=inputs_embeds,
        attention_mask=full_attention,
        labels=full_labels
    )

    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if step % 50 == 0:
        print(f"Step {step} | Loss: {loss.item()}")

Step 0 | Loss: 14.728104591369629
Step 50 | Loss: 3.994830331066623e-05
Step 100 | Loss: 1.877509566838853e-05
Step 150 | Loss: 1.3455543012241833e-05
Step 200 | Loss: 1.0385981113358866e-05
Step 250 | Loss: 8.329663614858873e-06
Step 300 | Loss: 6.8693752837134525e-06
Step 350 | Loss: 5.7816068874672055e-06
Step 400 | Loss: 4.962053935742006e-06
Step 450 | Loss: 4.261707545083482e-06


In [14]:
llm.eval()

with torch.no_grad():
    generated = llm.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=full_attention,
        max_new_tokens=50
    )

print(tokenizer.decode(generated[0], skip_special_tokens=True))

 Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs


In [15]:
# Correct generation test

inference_prompt = "USER: <image>\nDescribe the image.\nASSISTANT:"

tokens = tokenizer(
    inference_prompt,
    return_tensors="pt"
).to(device)

input_ids = tokens.input_ids
attention_mask = tokens.attention_mask

In [16]:
text_embeds = llm.get_input_embeddings()(input_ids)

patch_tokens = patch_tokens_raw.to(torch.bfloat16)
patch_tokens = projector(patch_tokens)

inputs_embeds = torch.cat([patch_tokens, text_embeds], dim=1)

visual_attention = torch.ones(
    (1, patch_tokens.size(1)),
    device=device
)

full_attention = torch.cat(
    [visual_attention, attention_mask],
    dim=1
)

In [17]:
generated = llm.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=full_attention,
    max_new_tokens=30,
    do_sample=False
)

print(tokenizer.decode(generated[0], skip_special_tokens=True))

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


 Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs


In [18]:
# testing vision ablation

zero_visual = torch.zeros_like(patch_tokens)

inputs_embeds = torch.cat([zero_visual, text_embeds], dim=1)

generated = llm.generate(
    inputs_embeds=inputs_embeds,
    attention_mask=full_attention,
    max_new_tokens=30,
    do_sample=False
)

print(tokenizer.decode(generated[0], skip_special_tokens=True))

 Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs playing. Two dogs
