In [None]:
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
from PIL import Image

# =====================================================
# Setup
# =====================================================

device = "cuda"
torch.set_float32_matmul_precision("high")

In [None]:
# =====================================================
# 1️⃣ Load Vision Encoder (Frozen)
# =====================================================

vision_name = "google/siglip-base-patch16-224"
processor = AutoProcessor.from_pretrained(vision_name)

vision_model = AutoModel.from_pretrained(
    vision_name
).vision_model.to(device, dtype=torch.bfloat16)

vision_model.eval()
for p in vision_model.parameters():
    p.requires_grad = False

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [None]:
# =====================================================
# 2️⃣ Load LLM
# =====================================================

llm_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.bfloat16
).to(device)

hidden = llm.config.hidden_size

`torch_dtype` is deprecated! Use `dtype` instead!


In [4]:
# =====================================================
# 3️⃣ Projector + Image Separator
# =====================================================

projector = nn.Linear(768, hidden, bias=False).to(device, dtype=torch.bfloat16)

image_sep = nn.Parameter(
    torch.randn(1, 1, hidden, dtype=torch.bfloat16, device=device)
)

In [None]:
# =====================================================
# 4️⃣ Load Images
# =====================================================

images = {
    "airplane": "images/airplane.png",
    "motorcycle": "images/motorcycle.png",
    "person": "images/person-umbrella.png",
    "kitchen": "images/kitchen.png"
}

In [6]:
def encode_image(path):
    image = Image.open(path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = vision_model(**inputs)
        return outputs.last_hidden_state.detach()

patches = {k: encode_image(v) for k, v in images.items()}

In [7]:
# =====================================================
# 5️⃣ Training Data
# =====================================================

train_data = [

    # Two-image entries
    ("airplane", "motorcycle",
     "The first image shows an airplane. The second image shows a motorcycle."),

    ("motorcycle", "airplane",
     "The first image shows a motorcycle. The second image shows an airplane."),

    ("person", "kitchen",
     "The first image shows a person with an umbrella. The second image shows a kitchen."),

    ("kitchen", "person",
     "The first image shows a kitchen. The second image shows a person with an umbrella."),

    # Single-image entries
    ("airplane", None, "The image shows an airplane."),
    ("motorcycle", None, "The image shows a motorcycle."),
    ("person", None, "The image shows a person with an umbrella."),
    ("kitchen", None, "The image shows a kitchen.")
]

In [8]:
dataset = []

for img1, img2, caption in train_data:

    if img2 is not None:
        dataset.append(((patches[img1], patches[img2]), caption))
    else:
        dataset.append((patches[img1], caption))

In [9]:
# =====================================================
# 6️⃣ Optimizer
# =====================================================

optimizer = torch.optim.AdamW(
    list(projector.parameters()) +
    list(llm.parameters()) +
    [image_sep],
    lr=1e-4
)

In [10]:
# =====================================================
# 7️⃣ Training Loop
# =====================================================

llm.train()
projector.train()

for step in range(1001):

    total_loss = 0

    for visual_data, caption in dataset:

        tokens = tokenizer(caption, return_tensors="pt").to(device)

        # -------------------------------------------------
        # Build visual tokens
        # -------------------------------------------------

        if isinstance(visual_data, tuple):

            patch1, patch2 = visual_data

            patch1 = patch1.to(torch.bfloat16)
            patch2 = patch2.to(torch.bfloat16)

            patch1_proj = projector(patch1)
            patch2_proj = projector(patch2)

            sep = image_sep.expand(1, 1, -1)

            visual_tokens = torch.cat(
                [patch1_proj, sep, patch2_proj],
                dim=1
            )

        else:
            patch1 = visual_data.to(torch.bfloat16)
            visual_tokens = projector(patch1)

        # -------------------------------------------------
        # Text embeddings
        # -------------------------------------------------

        text_embeds = llm.get_input_embeddings()(tokens.input_ids)

        inputs_embeds = torch.cat([visual_tokens, text_embeds], dim=1)

        visual_attention = torch.ones(
            (1, visual_tokens.size(1)),
            device=device,
            dtype=tokens.attention_mask.dtype
        )

        full_attention = torch.cat(
            [visual_attention, tokens.attention_mask],
            dim=1
        )

        visual_label_pad = torch.full(
            (1, visual_tokens.size(1)),
            -100,
            device=device
        )

        full_labels = torch.cat(
            [visual_label_pad, tokens.input_ids],
            dim=1
        )

        outputs = llm(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention,
            labels=full_labels
        )

        loss = outputs.loss
        loss.backward()
        total_loss += loss.item()

    optimizer.step()
    optimizer.zero_grad()

    if step % 200 == 0:
        print(f"Step {step} | Avg Loss: {total_loss/len(dataset):.6f}")

Step 0 | Avg Loss: 10.557027
Step 200 | Avg Loss: 0.000286
Step 400 | Avg Loss: 0.000189
Step 600 | Avg Loss: 0.000141
Step 800 | Avg Loss: 0.000113
Step 1000 | Avg Loss: 0.000082


In [11]:
# =====================================================
# 8️⃣ Evaluation
# =====================================================

llm.eval()
projector.eval()

print("\n=== TEST RESULTS ===")

for visual_data, caption in dataset:

    if isinstance(visual_data, tuple):
        patch1, patch2 = visual_data
        patch1_proj = projector(patch1.to(torch.bfloat16))
        patch2_proj = projector(patch2.to(torch.bfloat16))
        sep = image_sep.expand(1, 1, -1)
        visual_tokens = torch.cat([patch1_proj, sep, patch2_proj], dim=1)
    else:
        visual_tokens = projector(visual_data.to(torch.bfloat16))

    with torch.no_grad():
        generated = llm.generate(
            inputs_embeds=visual_tokens,
            attention_mask=torch.ones(
                (1, visual_tokens.size(1)),
                device=device
            ),
            max_new_tokens=30,
            do_sample=False,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id
        )

    output = tokenizer.decode(generated[0], skip_special_tokens=True).split(tokenizer.eos_token)[0]

    print("Target :", caption)
    print("Output :", output)
    print("--------")

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



=== TEST RESULTS ===
Target : The first image shows an airplane. The second image shows a motorcycle.
Output : The first image shows an airplane. The second image shows a motorcycle. The second image shows an airplane. The third image shows a motorcycle. The second
--------
Target : The first image shows a motorcycle. The second image shows an airplane.
Output : The first image shows a motorcycle. The second image shows an airplane. The second image shows a motorcycle. The third image shows an airplane. The second
--------
Target : The first image shows a person with an umbrella. The second image shows a kitchen.
Output : The first image shows a person with an umbrella. The second image shows a kitchen. The second image shows a person with an umbrella. The first image
--------
Target : The first image shows a kitchen. The second image shows a person with an umbrella.
Output : The first image shows a kitchen. The second image shows a person with an umbrella. The second image shows a ki

In [12]:
# =====================================================
# 9️⃣ Multi-Image Binding Diagnostics
# =====================================================

print("\n=== MULTI-IMAGE BINDING TEST ===")

for img1, img2, caption in train_data:

    print("\n---------------------------------")
    print(f"Entry: ({img1}, {img2})")

    patch1 = patches[img1]

    if img2 is not None:

        patch2 = patches[img2]

        patch1_proj = projector(patch1.to(torch.bfloat16))
        patch2_proj = projector(patch2.to(torch.bfloat16))
        sep = image_sep.expand(1, 1, -1)

        visual_tokens = torch.cat([patch1_proj, sep, patch2_proj], dim=1)

        visual_attention = torch.ones(
            (1, visual_tokens.size(1)),
            device=device
        )

        # NORMAL
        with torch.no_grad():
            out = llm.generate(
                inputs_embeds=visual_tokens,
                attention_mask=visual_attention,
                max_new_tokens=30,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id
            )

        print("NORMAL:")
        print(tokenizer.decode(out[0], skip_special_tokens=True).split(tokenizer.eos_token)[0])

        # ZERO SECOND
        zero_second_proj = projector(patch1.to(torch.bfloat16))
        visual_zero_second = zero_second_proj

        with torch.no_grad():
            out_zero = llm.generate(
                inputs_embeds=visual_zero_second,
                attention_mask=torch.ones(
                    (1, visual_zero_second.size(1)),
                    device=device
                ),
                max_new_tokens=30,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id
            )

        print("\nZERO SECOND IMAGE:")
        print(tokenizer.decode(out_zero[0], skip_special_tokens=True).split(tokenizer.eos_token)[0])

    else:
        patch1_proj = projector(patch1.to(torch.bfloat16))

        with torch.no_grad():
            out = llm.generate(
                inputs_embeds=patch1_proj,
                attention_mask=torch.ones(
                    (1, patch1_proj.size(1)),
                    device=device
                ),
                max_new_tokens=30,
                do_sample=False,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id
            )

        print("SINGLE IMAGE OUTPUT:")
        print(tokenizer.decode(out[0], skip_special_tokens=True).split(tokenizer.eos_token)[0])


=== MULTI-IMAGE BINDING TEST ===

---------------------------------
Entry: (airplane, motorcycle)
NORMAL:
The first image shows an airplane. The second image shows a motorcycle. The second image shows an airplane. The third image shows a motorcycle. The second

ZERO SECOND IMAGE:
The image shows an airplane. The second image shows a motorcycle. The first image shows an airplane. The second image shows a motorcycle. The third image

---------------------------------
Entry: (motorcycle, airplane)
NORMAL:
The first image shows a motorcycle. The second image shows an airplane. The second image shows a motorcycle. The third image shows an airplane. The second

ZERO SECOND IMAGE:
The image shows a motorcycle. The second image shows an airplane. The third image shows a motorcycle. The second image shows a motorcycle. The third image

---------------------------------
Entry: (person, kitchen)
NORMAL:
The first image shows a person with an umbrella. The second image shows a kitchen. The second