In [None]:
import torch
import torch.nn as nn
from transformers import AutoProcessor, AutoModel, AutoTokenizer, AutoModelForCausalLM
from PIL import Image

device = "cuda"
torch.set_float32_matmul_precision("high")

In [None]:
# =====================================================
# 1️⃣ Load Vision Encoder (Frozen)
# =====================================================

vision_name = "google/siglip-base-patch16-224"
processor = AutoProcessor.from_pretrained(vision_name, use_fast=True)

vision_model = AutoModel.from_pretrained(
    vision_name
).vision_model.to(device, dtype=torch.bfloat16)

vision_model.eval()
for p in vision_model.parameters():
    p.requires_grad = False

In [None]:
# =====================================================
# 2️⃣ Load LLM
# =====================================================

llm_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(llm_name)
tokenizer.pad_token = tokenizer.eos_token

llm = AutoModelForCausalLM.from_pretrained(
    llm_name,
    torch_dtype=torch.bfloat16
).to(device)

llm_hidden = llm.config.hidden_size

`torch_dtype` is deprecated! Use `dtype` instead!


In [4]:
# =====================================================
# 3️⃣ Projection Layer
# =====================================================

projector = nn.Linear(768, llm_hidden, bias=False).to(device, dtype=torch.bfloat16)

In [None]:
# =====================================================
# 4️⃣ Load 5 Images + Captions
# =====================================================

data = [
    ("images/airplane.png",
     "A large passenger airplane flying through the air."),

    ("images/motorcycle.png",
     "Riding a motorcycle down a street."),

    ("images/gd-dog.jpg",
     "A dog standing in a grassy field."),

    ("images/kitchen.png",
     "A kitchen stove, sink, and counter with stuff on it."),

    ("images/person-umbrella.png",
     "A person walking in the rain while holding an umbrella.")
]

In [6]:
def encode_image(path):
    image = Image.open(path).convert("RGB")
    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = vision_model(**inputs)
        return outputs.last_hidden_state.detach()

In [7]:
# Precompute visual tokens
dataset = []
for img_path, caption in data:
    patch = encode_image(img_path)
    tokens = tokenizer(caption, return_tensors="pt").to(device)
    dataset.append((patch, tokens, caption))

In [8]:
# =====================================================
# 5️⃣ Optimizer
# =====================================================

optimizer = torch.optim.AdamW(
    list(projector.parameters()) + list(llm.parameters()),
    lr=1e-4
)

In [9]:
# =====================================================
# 6️⃣ Training Loop
# =====================================================

llm.train()
projector.train()

for step in range(1200):

    total_loss = 0

    for patch_raw, tokens, _ in dataset:

        patch_tokens = projector(patch_raw)

        text_embeds = llm.get_input_embeddings()(tokens.input_ids)

        inputs_embeds = torch.cat([patch_tokens, text_embeds], dim=1)

        visual_attention = torch.ones(
            (1, patch_tokens.size(1)),
            device=device,
            dtype=tokens.attention_mask.dtype
        )

        full_attention = torch.cat(
            [visual_attention, tokens.attention_mask],
            dim=1
        )

        visual_label_pad = torch.full(
            (1, patch_tokens.size(1)),
            -100,
            device=device
        )

        full_labels = torch.cat(
            [visual_label_pad, tokens.input_ids],
            dim=1
        )

        outputs = llm(
            inputs_embeds=inputs_embeds,
            attention_mask=full_attention,
            labels=full_labels
        )

        loss = outputs.loss
        loss.backward()

        total_loss += loss.item()

    optimizer.step()
    optimizer.zero_grad()

    if step % 200 == 0:
        print(f"Step {step} | Avg Loss: {total_loss/len(dataset):.6f}")


Step 0 | Avg Loss: 8.819773
Step 200 | Avg Loss: 0.000247
Step 400 | Avg Loss: 0.000174
Step 600 | Avg Loss: 0.000145
Step 800 | Avg Loss: 0.000135
Step 1000 | Avg Loss: 0.000126


In [10]:
# =====================================================
# 7️⃣ Evaluation
# =====================================================

llm.eval()
projector.eval()

print("\n=== TEST RESULTS ===")

for patch_raw, _, caption in dataset:

    with torch.no_grad():
        patch_tokens = projector(patch_raw)

        generated = llm.generate(
            inputs_embeds=patch_tokens,
            attention_mask=torch.ones(
                (1, patch_tokens.size(1)),
                device=device
            ),
            max_new_tokens=15,
            do_sample=False
        )

    output = tokenizer.decode(generated[0], skip_special_tokens=True)

    print("Target  :", caption)
    print("Output  :", output)
    print("--------")

The following generation flags are not valid and may be ignored: ['temperature', 'top_p', 'top_k']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



=== TEST RESULTS ===
Target  : A large passenger airplane flying through the air.
Output  : A large passenger airplane flying through the air. air. air. air.
--------
Target  : Riding a motorcycle down a street.
Output  : Riding a motorcycle down a street.iding a street.iding a motorcycle
--------
Target  : A dog standing in a grassy field.
Output  : A dog standing in a grassy field. the the in theyy
--------
Target  : A kitchen stove, sink, and counter with stuff on it.
Output  : A kitchen stove, sink, and counter with stuff on it. and down
--------
Target  : A person walking in the rain while holding an umbrella.
Output  : A person walking in the rain while holding an umbrella. the the the the
--------


In [17]:
# =====================================================
# ZERO VISION TEST
# =====================================================

print("\n=== ZERO VISION TEST ===")

zero_tokens = torch.zeros_like(projector(dataset[0][0]))

with torch.no_grad():
    generated_zero = llm.generate(
        inputs_embeds=zero_tokens,
        attention_mask=torch.ones(
            (1, zero_tokens.size(1)),
            device=device
        ),
        max_new_tokens=15,
        do_sample=False
    )

print("Original Caption →", dataset[0][2])
print("Zero vision →", tokenizer.decode(generated_zero[0], skip_special_tokens=True))


=== ZERO VISION TEST ===
Original Caption → A large passenger airplane flying through the air.
Zero vision → ! * 1.5. The air. The air. The air
