In [None]:
import torch
import torch.nn as nn
import numpy as np
from datasets import load_dataset
from tqdm import tqdm
import re

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    CLIPVisionModel,
    CLIPImageProcessor,
)

# Select the compute device: CUDA when a GPU is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Device:", device)

# Load the CLIP image preprocessor and the matching vision tower.
clip_name = "openai/clip-vit-base-patch32"
clip_processor = CLIPImageProcessor.from_pretrained(clip_name)
clip_vision = CLIPVisionModel.from_pretrained(clip_name)
clip_vision = clip_vision.to(device)

# The vision tower is only used as a fixed feature extractor, so none of its
# parameters should receive gradients.
clip_vision.requires_grad_(False)

# Width of the CLIP features that the projector will consume.
vision_width = clip_vision.config.hidden_size
print("CLIP vision loaded. hidden_size =", vision_width)

# Load the TinyLlama chat model (fp16 weights) together with its tokenizer.
llama_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(llama_name)
# NOTE: newer transformers releases print "`torch_dtype` is deprecated! Use
# `dtype` instead!" for this call. The old keyword is kept so the cell still
# runs on releases that predate `dtype`; switch once the minimum version allows.
llama = AutoModelForCausalLM.from_pretrained(
    llama_name,
    torch_dtype=torch.float16,
).to(device)

# The language model is frozen: only the projector is trained.
llama.requires_grad_(False)

# A frozen backbone belongs in eval mode so that train-only layers (dropout,
# if the config enables any) stay deterministic. Gradients still flow through
# it to the projector because eval mode does not disable autograd.
llama.eval()
lm_hidden_size = llama.config.hidden_size
print("TinyLLaMA loaded. hidden_size =", lm_hidden_size)

class Projector(nn.Module):
    """Two-layer MLP that maps vision features into the LM embedding space.

    Args:
        in_dim: Size of the incoming feature vector.
        out_dim: Size of the produced embedding (also the hidden width).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        # Linear -> GELU -> Linear; the hidden width equals the output width.
        layers = [
            nn.Linear(in_dim, out_dim),
            nn.GELU(),
            nn.Linear(out_dim, out_dim),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` along its last dimension."""
        projected = self.mlp(x)
        return projected

# Trainable bridge from CLIP feature width to the LM hidden width.
projector = Projector(in_dim=vision_width, out_dim=lm_hidden_size)
projector = projector.to(device)
print("Projector loaded on", device)

def normalize(s):
    """Canonicalise an answer string for metric computation.

    Lower-cases the text, turns every run of characters outside ``a-z0-9``
    into a single space, and collapses/strips whitespace.
    """
    lowered = s.lower().strip()
    alnum_only = re.sub(r"[^a-z0-9]+", " ", lowered)
    tokens = alnum_only.split()
    return " ".join(tokens)

def f1(pred, gt):
    """Token-level F1 between a prediction and a ground-truth answer.

    Both strings are normalised with ``normalize`` and split on whitespace.
    The overlap is a multiset intersection, so a token repeated in both
    strings is counted as many times as it appears in both. Counting the
    overlap with plain sets while dividing by list lengths would score
    identical answers that contain a repeated word below 1.0.

    Args:
        pred: Predicted answer text.
        gt: Reference answer text.

    Returns:
        F1 in ``[0.0, 1.0]``. If either side has no tokens, returns 1.0 when
        both are empty and 0.0 otherwise.
    """
    # Local import so this function does not depend on the file's import block.
    from collections import Counter

    pred_tokens = normalize(pred).split()
    gt_tokens = normalize(gt).split()
    if not pred_tokens or not gt_tokens:
        return float(pred_tokens == gt_tokens)

    # Counter & Counter keeps the minimum count per token (multiset overlap).
    overlap = sum((Counter(pred_tokens) & Counter(gt_tokens)).values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(pred_tokens)
    recall = overlap / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)

def em(pred, gt):
    """Exact match: 1.0 when the normalised strings are equal, else 0.0."""
    is_match = normalize(pred) == normalize(gt)
    return 1.0 if is_match else 0.0

# Load the train and test splits of the dataset.
train_ds = load_dataset("flaviagiammarino/path-vqa", split="train")
test_ds  = load_dataset("flaviagiammarino/path-vqa", split="test")

# Optional caps for quick experiments; None (or 0) keeps the full split.
MAX_TRAIN_SAMPLES = None
MAX_TEST_SAMPLES  = None

# Clamp each cap to the split size so that a cap larger than the split keeps
# the whole split instead of asking select() for out-of-range indices.
if MAX_TRAIN_SAMPLES:
    train_ds = train_ds.select(range(min(MAX_TRAIN_SAMPLES, len(train_ds))))
if MAX_TEST_SAMPLES:
    test_ds = test_ds.select(range(min(MAX_TEST_SAMPLES, len(test_ds))))

print("Train set:", len(train_ds), " Test set:", len(test_ds))

def build_inputs_and_labels(question, answer):
    """Tokenise one QA pair and build labels that supervise only the answer.

    The text is ``"Question: {question}\\nAnswer: {answer}"`` followed by the
    tokenizer's EOS token (when it defines one and has not already appended
    it). Supervising EOS teaches the model to stop after the answer, so
    generation can end before ``max_new_tokens`` is reached. It also
    guarantees at least one supervised position when ``answer`` is empty.

    Labels are ``-100`` (ignored by the LM loss) on every prompt token and on
    one extra leading slot reserved for the vision embedding that the caller
    prepends to the text embeddings.

    Args:
        question: Question text.
        answer: Ground-truth answer text.

    Returns:
        Tuple ``(input_ids_full, labels_padded, L_full)`` where
        ``input_ids_full`` has shape ``(1, L_full)`` and ``labels_padded`` has
        shape ``(1, L_full + 1)``, both on ``device``.
    """
    prompt = f"Question: {question}\nAnswer:"
    full   = f"{prompt} {answer}"

    input_ids_full = tokenizer(full, return_tensors="pt").input_ids.to(device)
    # Only the prompt length is needed, so its ids stay on the CPU.
    L_prompt = tokenizer(prompt, return_tensors="pt").input_ids.size(1)

    # Terminate the target with EOS unless the tokenizer already did.
    eos_id = tokenizer.eos_token_id
    if eos_id is not None and input_ids_full[0, -1].item() != eos_id:
        eos = torch.tensor(
            [[eos_id]], dtype=input_ids_full.dtype, device=input_ids_full.device
        )
        input_ids_full = torch.cat([input_ids_full, eos], dim=1)

    L_full = input_ids_full.size(1)

    # Start fully masked; slot 0 belongs to the vision token, slots
    # 1..L_prompt to the prompt. Everything after that is supervised.
    labels_padded = torch.full(
        (1, L_full + 1),
        -100,
        dtype=torch.long,
        device=device,
    )
    labels_padded[:, 1 + L_prompt:] = input_ids_full[:, L_prompt:]

    return input_ids_full, labels_padded, L_full

# Optimisation hyper-parameters.
LR = 1e-4
NUM_EPOCHS = 15

# Only the projector's parameters are handed to the optimizer.
optimizer = torch.optim.AdamW(params=projector.parameters(), lr=LR)

def evaluate_model():
    """Score the projector + language model on ``test_ds`` with greedy decoding.

    For every sample the image is encoded with CLIP, projected into the LM
    embedding space and prepended as a single extra token in front of the
    embedded ``"Question: ...\\nAnswer:"`` prompt. Up to 32 new tokens are
    generated without sampling, and the first line of the decoded text is
    compared with the reference answer.

    Side effects:
        Puts ``projector`` and ``llama`` into eval mode and leaves them there.

    Returns:
        Tuple ``(mean_f1, mean_em)`` of Python floats; ``(0.0, 0.0)`` when
        ``test_ds`` yields no samples.
    """
    projector.eval()
    llama.eval()

    # Public accessor for the LM's token-embedding layer.
    embed_tokens = llama.get_input_embeddings()

    all_f1, all_em = [], []

    for sample in tqdm(test_ds, desc="Evaluating", leave=False):
        question = sample["question"]
        answer_gt = sample["answer"]
        img = sample["image"]       # PIL Image

        with torch.no_grad():
            clip_inputs = clip_processor(
                images=img,
                return_tensors="pt"
            ).to(device)

            # Pooled CLIP feature for the single image.
            vision_feat = clip_vision(**clip_inputs).pooler_output

            prompt = f"Question: {question}\nAnswer:"
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
            text_emb = embed_tokens(input_ids)

            # Match the LM's embedding dtype instead of hard-coding fp16 so the
            # concatenation stays valid if the LM is loaded in another dtype.
            vision_embed = projector(vision_feat).to(text_emb.dtype)

            # The vision embedding becomes one extra token ahead of the prompt.
            inputs_embeds = torch.cat(
                [vision_embed.unsqueeze(1), text_emb],
                dim=1
            )

            output = llama.generate(
                inputs_embeds=inputs_embeds,
                max_new_tokens=32,
                do_sample=False,
            )

            pred = tokenizer.decode(output[0], skip_special_tokens=True)
            # Keep only the first line after the last "Answer:" marker, in case
            # the model goes on to write further text.
            pred = pred.split("Answer:")[-1].strip().split("\n")[0].strip()

        all_f1.append(f1(pred, answer_gt))
        all_em.append(em(pred, answer_gt))

    # np.mean([]) would return NaN with a RuntimeWarning.
    if not all_f1:
        return 0.0, 0.0

    mean_f1 = float(np.mean(all_f1))
    mean_em = float(np.mean(all_em))
    return mean_f1, mean_em

# Train the projector only: CLIP and the language model stay frozen. One
# sample per optimisation step; metrics on test_ds are printed after each epoch.
for epoch in range(NUM_EPOCHS):
    projector.train()
    # The LM is frozen, so it stays in eval mode; gradients still flow through
    # it to the projector because eval mode does not disable autograd.
    llama.eval()

    # Public accessor for the LM's token-embedding layer.
    embed_tokens = llama.get_input_embeddings()

    total_loss = 0.0

    for sample in tqdm(train_ds, desc=f"Epoch {epoch+1} / {NUM_EPOCHS}"):
        question = sample["question"]
        answer_gt = sample["answer"]
        img = sample["image"]

        clip_inputs = clip_processor(
            images=img,
            return_tensors="pt"
        ).to(device)

        # Frozen vision encoder: no graph needed for this part.
        with torch.no_grad():
            vision_feat = clip_vision(**clip_inputs).pooler_output

        input_ids, labels_padded, L_full = build_inputs_and_labels(
            question, answer_gt
        )

        with torch.no_grad():
            text_emb = embed_tokens(input_ids)

        # Projector output is cast to the LM's embedding dtype (rather than a
        # hard-coded fp16) so the concatenation below is dtype-consistent.
        vision_embed = projector(vision_feat).to(text_emb.dtype)

        # One vision token followed by the text tokens; labels_padded already
        # carries the matching extra leading (ignored) position.
        inputs_embeds = torch.cat(
            [vision_embed.unsqueeze(1), text_emb],
            dim=1
        )

        outputs = llama(
            inputs_embeds=inputs_embeds,
            labels=labels_padded,
            use_cache=False,  # no key/value cache is needed for a training step
        )
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # max(..., 1) avoids a ZeroDivisionError on an empty training set.
    avg_loss = total_loss / max(len(train_ds), 1)
    print(f"\n[Epoch {epoch+1}] train loss = {avg_loss:.4f}")

    val_f1, val_em = evaluate_model()
    print(f"[Epoch {epoch+1}] val F1 = {val_f1:.4f}, val EM = {val_em:.4f}\n")

print("Training finished.")


Device: cuda
CLIP vision loaded. hidden_size = 768


`torch_dtype` is deprecated! Use `dtype` instead!


TinyLLaMA loaded. hidden_size = 2048
Projector loaded on cuda
Train set: 19654  Test set: 6719


Epoch 1 / 15: 100%|██████████| 19654/19654 [16:42<00:00, 19.61it/s]



[Epoch 1] train loss = 1.1457


                                                               

[Epoch 1] val F1 = 0.0523, val EM = 0.0000



Epoch 2 / 15: 100%|██████████| 19654/19654 [16:57<00:00, 19.32it/s]



[Epoch 2] train loss = 1.0456


                                                               

[Epoch 2] val F1 = 0.3245, val EM = 0.2722



Epoch 3 / 15: 100%|██████████| 19654/19654 [16:52<00:00, 19.42it/s]



[Epoch 3] train loss = 0.9545


Epoch 12 / 15: 100%|██████████| 19654/19654 [16:03<00:00, 20.39it/s]



[Epoch 12] train loss = 0.5515


Evaluating:  35%|███▍      | 2322/6719 [13:40<25:29,  2.88it/s]