In [1]:
# We start simple.

# Train on pairs:

# (Image A, Image B) → "The first image shows an airplane, the second shows a motorcycle."

# (Image B, Image A) → "The first image shows a motorcycle, the second shows an airplane."

# (A) → one sentence

# (B) → one sentence

In [None]:
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
from PIL import Image

device = "cuda"
torch.set_float32_matmul_precision("high")

In [None]:
# =====================================================
# 1️⃣ Load Vision Encoder (Frozen)
# =====================================================

vision_name = "google/siglip-base-patch16-224"
processor = AutoProcessor.from_pretrained(vision_name, use_fast=True)

vision_model = AutoModel.from_pretrained(
    vision_name
).vision_model.to(device, dtype=torch.bfloat16)

vision_model.eval()
for p in vision_model.parameters():
    p.requires_grad = False

In [None]:
# =====================================================
# 2️⃣ Load LLM
# =====================================================

llm_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.bfloat16
).to(device)

llm_hidden = llm.config.hidden_size

`torch_dtype` is deprecated! Use `dtype` instead!


In [5]:
# =====================================================
# 3️⃣ Projection Layer
# =====================================================

projector = nn.Linear(768, llm_hidden, bias=False).to(device, dtype=torch.bfloat16)

In [None]:
# =====================================================
# Images
# =====================================================

images = {
    "airplane": "images/airplane.png",
    "motorcycle": "images/motorcycle.png",
    "person": "images/person-umbrella.png",
    "kitchen": "images/kitchen.png"
}

In [7]:
def encode_image(path):
    image = Image.open(path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = vision_model(**inputs)
        return outputs.last_hidden_state.detach()

In [8]:
# Precompute patches
patches = {k: encode_image(v) for k, v in images.items()}

In [9]:
# =====================================================
# Training Pairs
# =====================================================

train_data = [

    # -------------------------------------------------
    # Two-image pairs
    # -------------------------------------------------

    ("airplane", "motorcycle",
     "The first image shows an airplane. The second image shows a motorcycle."),

    ("motorcycle", "airplane",
     "The first image shows a motorcycle. The second image shows an airplane."),

    ("person", "kitchen",
     "The first image shows a person with an umbrella. The second image shows a kitchen."),

    ("kitchen", "person",
     "The first image shows a kitchen. The second image shows a person with an umbrella."),

    # -------------------------------------------------
    # Single-image entries
    # -------------------------------------------------

    ("airplane", None,
     "The image shows an airplane."),

    ("motorcycle", None,
     "The image shows a motorcycle."),

    ("person", None,
     "The image shows a person with an umbrella."),

    ("kitchen", None,
     "The image shows a kitchen.")
]

In [10]:
dataset = []

for img1, img2, caption in train_data:

    patch1 = patches[img1]

    if img2 is not None:
        patch2 = patches[img2]
        combined_patch = torch.cat([patch1, patch2], dim=1)
    else:
        combined_patch = patch1  # single image

    combined_patch = combined_patch.to(torch.bfloat16)

    tokens = tokenizer(caption, return_tensors="pt").to(device)

    dataset.append((combined_patch, tokens, caption))

In [11]:
# =====================================================
# Optimizer
# =====================================================

optimizer = torch.optim.AdamW(
    list(projector.parameters()) + list(llm.parameters()),
    lr=1e-4
)

In [12]:
# =====================================================
# Training Loop
# =====================================================
llm.train()
projector.train()

for step in range(1001):

    total_loss = 0

    for patch_raw, tokens, _ in dataset:

        patch_tokens = projector(patch_raw)

        text_embeds = llm.get_input_embeddings()(tokens.input_ids)

        inputs_embeds = torch.cat([patch_tokens, text_embeds], dim=1)

        visual_attention = torch.ones(
            (1, patch_tokens.size(1)),
            device=device,
            dtype=tokens.attention_mask.dtype
        )

        full_attention = torch.cat(
            [visual_attention, tokens.attention_mask],
            dim=1
        )

        visual_label_pad = torch.full(
            (1, patch_tokens.size(1)),
            -100,
            device=device
        )

        full_labels = torch.cat(
            [visual_label_pad, tokens.input_ids],
            dim=1
        )

        outputs = llm(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention,
            labels=full_labels
        )

        loss = outputs.loss
        loss.backward()
        total_loss += loss.item()

    optimizer.step()
    optimizer.zero_grad()

    if step % 200 == 0:
        print(f"Step {step} | Avg Loss: {total_loss/len(dataset):.6f}")


Step 0 | Avg Loss: 11.963358
Step 200 | Avg Loss: 0.000398
Step 400 | Avg Loss: 0.000232
Step 600 | Avg Loss: 0.000177
Step 800 | Avg Loss: 0.000152
Step 1000 | Avg Loss: 0.000138


In [13]:
# =====================================================
# Evaluation
# =====================================================

llm.eval()
projector.eval()

print("\n=== TEST RESULTS ===")

for patch_raw, _, caption in dataset:

    with torch.no_grad():
        patch_tokens = projector(patch_raw)

        generated = llm.generate(
            inputs_embeds=patch_tokens,
            attention_mask=torch.ones(
                (1, patch_tokens.size(1)),
                device=device
            ),
            max_new_tokens=30,
            do_sample=False
        )

    output = tokenizer.decode(generated[0], skip_special_tokens=True)

    print("Target  :", caption)
    print("Output  :", output)
    print("--------")

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



=== TEST RESULTS ===
Target  : The first image shows an airplane. The second image shows a motorcycle.
Output  : The first image shows an airplane. The second image shows a motorcycle. The second image shows a motorcycle. The second image shows a motorcycle. The second
--------
Target  : The first image shows a motorcycle. The second image shows an airplane.
Output  : The first image shows a motorcycle. The second image shows an airplane. The second image shows a motorcycle. The second image shows an airplane. The second
--------
Target  : The first image shows a person with an umbrella. The second image shows a kitchen.
Output  : The first image shows a person with an umbrella. The second image shows a kitchen. The second image shows a person with an umbrella. The second image
--------
Target  : The first image shows a kitchen. The second image shows a person with an umbrella.
Output  : The first image shows a kitchen. The second image shows a person with an umbrella. The second imag

In [15]:
# =====================================================
# MULTI-IMAGE BINDING DIAGNOSTICS (UPDATED)
# =====================================================

print("\n=== MULTI-IMAGE BINDING TEST ===")

llm.eval()
projector.eval()

for img1, img2, caption in train_data:

    print("\n---------------------------------")
    print(f"Entry: ({img1}, {img2})")

    patch1 = patches[img1]

    # -------------------------------------------------
    # CASE 1: Two-image entry
    # -------------------------------------------------
    if img2 is not None:

        patch2 = patches[img2]

        combined = torch.cat([patch1, patch2], dim=1).to(torch.bfloat16)

        visual_attention = torch.ones(
            (1, combined.size(1)),
            device=device
        )

        with torch.no_grad():
            projected = projector(combined)

            out_normal = llm.generate(
                inputs_embeds=projected,
                attention_mask=visual_attention,
                max_new_tokens=30,
                do_sample=False
            )

        print("NORMAL:")
        print(tokenizer.decode(out_normal[0], skip_special_tokens=True))

        # -------------------------
        # ZERO SECOND IMAGE
        # -------------------------
        zero_second = torch.cat(
            [patch1, torch.zeros_like(patch2)],
            dim=1
        ).to(torch.bfloat16)

        with torch.no_grad():
            projected_zero_second = projector(zero_second)

            out_zero_second = llm.generate(
                inputs_embeds=projected_zero_second,
                attention_mask=visual_attention,
                max_new_tokens=30,
                do_sample=False
            )

        print("\nZERO SECOND IMAGE:")
        print(tokenizer.decode(out_zero_second[0], skip_special_tokens=True))

        # -------------------------
        # ZERO FIRST IMAGE
        # -------------------------
        zero_first = torch.cat(
            [torch.zeros_like(patch1), patch2],
            dim=1
        ).to(torch.bfloat16)

        with torch.no_grad():
            projected_zero_first = projector(zero_first)

            out_zero_first = llm.generate(
                inputs_embeds=projected_zero_first,
                attention_mask=visual_attention,
                max_new_tokens=30,
                do_sample=False
            )

        print("\nZERO FIRST IMAGE:")
        print(tokenizer.decode(out_zero_first[0], skip_special_tokens=True))

    # -------------------------------------------------
    # CASE 2: Single-image entry
    # -------------------------------------------------
    else:

        combined = patch1.to(torch.bfloat16)

        visual_attention = torch.ones(
            (1, combined.size(1)),
            device=device
        )

        with torch.no_grad():
            projected = projector(combined)

            out_normal = llm.generate(
                inputs_embeds=projected,
                attention_mask=visual_attention,
                max_new_tokens=30,
                do_sample=False
            )

        print("SINGLE IMAGE OUTPUT:")
        print(tokenizer.decode(out_normal[0], skip_special_tokens=True))


=== MULTI-IMAGE BINDING TEST ===

---------------------------------
Entry: (airplane, motorcycle)
NORMAL:
The first image shows an airplane. The second image shows a motorcycle. The second image shows a motorcycle. The second image shows a motorcycle. The second

ZERO SECOND IMAGE:
. The first image shows a motorcycle. The second image shows an airplane. The first image shows a motorcycle. The second image shows a person with an

ZERO FIRST IMAGE:
The image shows a motorcycle. The second image shows a motorcycle. The second image shows an airplane. The second image shows a motorcycle. The second image

---------------------------------
Entry: (motorcycle, airplane)
NORMAL:
The first image shows a motorcycle. The second image shows an airplane. The second image shows a motorcycle. The second image shows an airplane. The second

ZERO SECOND IMAGE:
. The second image shows a motorcycle. The second image shows an airplane. The second image shows a motorcycle. The second image shows a pers