In [None]:
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
from PIL import Image

device = "cuda"
torch.set_float32_matmul_precision("high")

In [None]:
# =====================================================
# 1Ô∏è‚É£ Load Vision Encoder (Frozen)
# =====================================================

vision_name = "google/siglip-base-patch16-224"
processor = AutoProcessor.from_pretrained(vision_name, use_fast=True)

vision_model = AutoModel.from_pretrained(
    vision_name
).vision_model.to(device, dtype=torch.bfloat16)

vision_model.eval()
for p in vision_model.parameters():
    p.requires_grad = False

In [None]:
# =====================================================
# 2Ô∏è‚É£ Load LLM
# =====================================================

llm_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.bfloat16
).to(device)

llm_hidden = llm.config.hidden_size

`torch_dtype` is deprecated! Use `dtype` instead!


In [5]:
# =====================================================
# 3Ô∏è‚É£ Projection Layer
# =====================================================

projector = nn.Linear(768, llm_hidden, bias=False).to(device, dtype=torch.bfloat16)

In [None]:
# =====================================================
# 4Ô∏è‚É£ Load Two Images + Captions
# =====================================================

image_A = Image.open("images/airplane.png").convert("RGB")
caption_A = "A large passenger airplane flying through the air."

image_B = Image.open("images/motorcycle.png").convert("RGB")
caption_B = "Riding a motorcycle down a street."

In [6]:
def encode_image(image):
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = vision_model(**inputs)
        return outputs.last_hidden_state.detach()

In [8]:
patch_A = encode_image(image_A)
patch_B = encode_image(image_B)

In [9]:
# =====================================================
# 5Ô∏è‚É£ Tokenize Captions
# =====================================================

tokens_A = tokenizer(caption_A, return_tensors="pt").to(device)
tokens_B = tokenizer(caption_B, return_tensors="pt").to(device)

In [10]:
# =====================================================
# 6Ô∏è‚É£ Optimizer
# =====================================================

optimizer = torch.optim.AdamW(
    list(projector.parameters()) + list(llm.parameters()),
    lr=1e-4
)

In [11]:
# =====================================================
# 7Ô∏è‚É£ Training Loop
# =====================================================

llm.train()
projector.train()

for step in range(800):

    total_loss = 0

    for patch_raw, tokens in [(patch_A, tokens_A), (patch_B, tokens_B)]:

        patch_tokens = projector(patch_raw)

        text_embeds = llm.get_input_embeddings()(tokens.input_ids)

        inputs_embeds = torch.cat([patch_tokens, text_embeds], dim=1)

        visual_attention = torch.ones(
            (1, patch_tokens.size(1)),
            device=device,
            dtype=tokens.attention_mask.dtype
        )

        full_attention = torch.cat(
            [visual_attention, tokens.attention_mask],
            dim=1
        )

        visual_label_pad = torch.full(
            (1, patch_tokens.size(1)),
            -100,
            device=device
        )

        full_labels = torch.cat(
            [visual_label_pad, tokens.input_ids],
            dim=1
        )

        outputs = llm(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention,
            labels=full_labels
        )

        loss = outputs.loss
        loss.backward()

        total_loss += loss.item()

    optimizer.step()
    optimizer.zero_grad()

    if step % 100 == 0:
        print(f"Step {step} | Loss: {total_loss/2:.6f}")

Step 0 | Loss: 10.959513
Step 100 | Loss: 0.000117
Step 200 | Loss: 0.000099
Step 300 | Loss: 0.000085
Step 400 | Loss: 0.000078
Step 500 | Loss: 0.000072
Step 600 | Loss: 0.000067
Step 700 | Loss: 0.000065


In [16]:
# =====================================================
# 8Ô∏è‚É£ Evaluation Function
# =====================================================

def generate_caption(patch_raw):

    llm.eval()
    projector.eval()

    with torch.no_grad():
        patch_tokens = projector(patch_raw)

        visual_attention = torch.ones(
            (1, patch_tokens.size(1)),
            device=device
        )

        generated = llm.generate(
            inputs_embeds=patch_tokens,
            attention_mask=visual_attention,
            max_new_tokens=10,
            do_sample=False
        )

    return tokenizer.decode(generated[0], skip_special_tokens=True)

In [17]:
# =====================================================
# 9Ô∏è‚É£ Test Both Images
# =====================================================

print("\n=== TEST RESULTS ===")
print("Image A ‚Üí", generate_caption(patch_A))
print("Image B ‚Üí", generate_caption(patch_B))


=== TEST RESULTS ===
Image A ‚Üí A large passenger airplane flying through the air.A
Image B ‚Üí Riding a motorcycle down a street.Riding


In [18]:
# =====================================================
# üîü Zero Vision Test
# =====================================================

print("\n=== ZERO VISION TEST ===")

zero_tokens = torch.zeros_like(projector(patch_A))

with torch.no_grad():
    generated_zero = llm.generate(
        inputs_embeds=zero_tokens,
        attention_mask=torch.ones(
            (1, zero_tokens.size(1)),
            device=device
        ),
        max_new_tokens=10,
        do_sample=False
    )

print("Zero vision ‚Üí", tokenizer.decode(generated_zero[0], skip_special_tokens=True))


=== ZERO VISION TEST ===
Zero vision ‚Üí !iding a motorcycle down a street. A street
