<a href="https://colab.research.google.com/github/alierenc/di725-transformers-and-attention-based-deep-networks-term-project/blob/main/Phase%20III/3.1.%20SigLIP-T5-Decoder%20Custom%20VLM%20-%20Image%20Captioning%20Fine-tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

from huggingface_hub import login

# Read the Hugging Face access token from the HF_TOKEN environment variable
# rather than hard-coding it in the notebook source (which is easy to commit by
# accident). If the variable is unset or blank, None is passed so that `login`
# falls back to its own interactive prompt instead of receiving a blank token.
hf_token = os.environ.get("HF_TOKEN", "").strip() or None
login(token=hf_token)

In [None]:
# Weights & Biases: authenticate, then open the run for this notebook
import wandb

wandb.login()

# Project and run name under which this notebook is tracked
run_config = {
    "project": "term-project-vision-language-model",
    "name": "siglip-t5decoder",
}
wandb.init(**run_config)

In [None]:
# Upgrade `datasets` and `bitsandbytes` in the notebook environment.
# NOTE(review): this cell comes after the cell that imports wandb, and later
# cells import `peft` and `transformers` without installing them here —
# presumably the runtime already provides those; confirm.
!pip install -U datasets
!pip install bitsandbytes --upgrade

In [None]:
from datasets import load_dataset, DatasetDict

# Load the full RISCM dataset; everything is published under one "train" split
ds = load_dataset('caglarmert/full_riscm')
full = ds["train"]

# Carve three contiguous, non-overlapping index ranges out of that split:
#   test  -> [0, 3150)
#   val   -> [3150, 6300)
#   train -> [6300, len(full))
split_bounds = {
    "test": (0, 3150),
    "val": (3150, 6300),
    "train": (6300, len(full)),
}
test_ds, val_ds, train_ds = (
    full.select(range(start, stop)) for start, stop in split_bounds.values()
)

# Re-bind `ds` to a DatasetDict holding the three subsets
ds = DatasetDict({
    "test": test_ds,
    "val": val_ds,
    "train": train_ds,
})

In [None]:
# Display the first row of the test subset to inspect its fields.
ds["test"][0]

In [None]:
import torch
import torch.nn as nn
from transformers.modeling_outputs import BaseModelOutput

class CustomVLM(nn.Module):
    """Vision-language wrapper: a vision model feeding a language model.

    The vision model's output is averaged, passed through a linear projection
    and handed to the language model as a pre-computed encoder output.
    """

    def __init__(self, vision_model, language_model, vision_hidden_size, language_hidden_size):
        """Store both sub-models and build the projection between their widths.

        Args:
            vision_model: module called as ``vision_model(pixel_values=...)``
                whose result exposes ``last_hidden_state``.
            language_model: module called with ``input_ids``, ``attention_mask``,
                ``encoder_outputs`` and ``labels`` keyword arguments.
            vision_hidden_size: input width of the projection layer.
            language_hidden_size: output width of the projection layer.
        """
        super().__init__()
        self.vision_model = vision_model
        self.language_model = language_model
        self.vision_proj = nn.Linear(
            in_features=vision_hidden_size,
            out_features=language_hidden_size,
        )

    def forward(self, image, input_ids=None, attention_mask=None, labels=None):
        """Encode ``image`` and run the language model on top of it.

        Args:
            image: passed to the vision model as ``pixel_values``.
            input_ids: forwarded unchanged to the language model.
            attention_mask: forwarded unchanged to the language model.
            labels: forwarded unchanged to the language model (may be None).

        Returns:
            Whatever the language model returns.
        """
        # Vision features, averaged over dim 1 (annotated as [B, N, D] -> [B, D]
        # by the original author).
        token_features = self.vision_model(pixel_values=image).last_hidden_state
        pooled_features = token_features.mean(dim=1)

        # Project to the language model width and re-insert a length-1 sequence
        # axis so the result looks like an encoder output.
        image_context = self.vision_proj(pooled_features).unsqueeze(1)

        # NOTE(review): attention_mask describes the prompt tokens while the
        # supplied encoder output has sequence length 1, and input_ids is passed
        # alongside a pre-computed encoder output — confirm the language model
        # uses both as intended.
        return self.language_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_outputs=BaseModelOutput(last_hidden_state=image_context),
            labels=labels,
        )


In [None]:
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    SiglipVisionModel,
    AutoImageProcessor,
    BitsAndBytesConfig
)
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model
import bitsandbytes
import torch

# Define device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# QLoRA quantization config: 4-bit NF4 weights, bfloat16 compute dtype
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

# Load full T5 model (needed for decoder + lm_head)
language_model = T5ForConditionalGeneration.from_pretrained(
    "t5-base",
    quantization_config=bnb_config,
    device_map="auto"
)

# Prepare for QLoRA
language_model = prepare_model_for_kbit_training(language_model)

# Define LoRA config (on full model — affects both encoder & decoder if needed)
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q", "v"],  # extend to 'k', 'o', etc. for more aggressive tuning
    lora_dropout=0.1,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)

# Apply LoRA
language_model = get_peft_model(language_model, lora_config)

# Load tokenizer
# The tokenizer's own pad token is kept as-is. The training loop replaces every
# label id equal to tokenizer.pad_token_id with -100; if pad were aliased to EOS,
# each caption's closing EOS would be masked out of the loss as well and the
# model would get no signal for when to stop generating.
tokenizer = T5Tokenizer.from_pretrained("t5-base")

# Load SigLIP vision encoder and freeze all of its parameters
vision_model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
vision_model.requires_grad_(False)

# Load image processor
image_processor = AutoImageProcessor.from_pretrained("google/siglip-base-patch16-224")

# Determine hidden sizes (used to size the projection layer in CustomVLM)
vision_hidden_size = vision_model.config.hidden_size     # e.g. 768
language_hidden_size = language_model.config.d_model     # 768 for t5-base

# Initialize CustomVLM with full language model
model = CustomVLM(
    vision_model=vision_model,
    language_model=language_model,
    vision_hidden_size=vision_hidden_size,
    language_hidden_size=language_hidden_size
)

# Move to device
model = model.to(device)


In [None]:
from transformers.modeling_outputs import BaseModelOutput

# Sanity check: can the assembled model produce captions at all?
# Set model to evaluation mode
model.eval()
max_new_tokens = 30
eos_token_id = tokenizer.eos_token_id

for i in range(10):
    print(f"Generating caption for sample {i + 1}")

    # Preprocess image
    image = ds["test"][i]["image"]
    pixel_values = image_processor(image, return_tensors="pt")["pixel_values"].to(device)

    with torch.no_grad():
        # Same image encoding steps as CustomVLM.forward: average the vision
        # features over dim 1, project, and add a length-1 sequence axis
        vision_output = model.vision_model(pixel_values=pixel_values).last_hidden_state
        vision_embedding = vision_output.mean(dim=1)
        encoder_hidden_states = model.vision_proj(vision_embedding).unsqueeze(1)

        # Wrap as BaseModelOutput
        encoder_outputs = BaseModelOutput(last_hidden_state=encoder_hidden_states)

        # Prompt setup
        # NOTE(review): encoder_outputs is supplied directly, so the language
        # model's own encoder presumably never processes these prompt ids —
        # confirm the prompt actually influences generation.
        prompt = "caption en"
        tokenized = tokenizer(prompt, return_tensors="pt").to(device)
        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]

        # Remove <pad> if it appears as the first token
        if input_ids[0, 0] == tokenizer.pad_token_id:
            input_ids = input_ids[:, 1:]
            attention_mask = attention_mask[:, 1:]

        # Generate (greedy decoding, since do_sample=False)
        generated_ids = model.language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            max_new_tokens=max_new_tokens,
            eos_token_id=eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False
        )

        # Decode only after generate() has produced ids for this image, then
        # strip the prompt text and any literal "<pad>" in a single chain so
        # neither clean-up step overwrites the other.
        decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        caption = decoded.replace(prompt, "").replace("<pad>", "").strip()
        print("Generated caption:", repr(caption))
        print()


In [None]:
# Report the device of the first parameter of each sub-module.
print("Language model device:", next(model.language_model.parameters()).device)
print("Vision model device:", next(model.vision_model.parameters()).device)
print("Vision projection layer device:", next(model.vision_proj.parameters()).device)

In [None]:
def count_parameters(module):
    """Tally parameter element counts for ``module``.

    Args:
        module: object exposing ``parameters()``; each yielded parameter must
            provide ``numel()`` and a ``requires_grad`` flag.

    Returns:
        dict with "Total" (elements across all parameters) and "Trainable"
        (elements of parameters whose ``requires_grad`` is true).
    """
    total = 0
    trainable = 0
    for param in module.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return {"Total": total, "Trainable": trainable}

print("Printing the total number of parameters and the number of trainable parameters:")

# Each component is counted right before its line is printed
vision_counts = count_parameters(model.vision_model)
print("Vision Encoder (SigLIP):", vision_counts)

projection_counts = count_parameters(model.vision_proj)
print("Vision Projection Layer:", projection_counts)

language_counts = count_parameters(model.language_model)
print("Language Model (full T5):", language_counts)

# Decoder stack alone, reached through the language model's base_model
decoder_counts = count_parameters(model.language_model.base_model.decoder)
print("T5 Decoder (only):", decoder_counts)



In [None]:
import wandb
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from torch.nn import functional as F

# Batch builder for the captioning DataLoaders
def collate_fn(batch):
    """Turn a list of dataset rows into one batch of tensors.

    Args:
        batch: list of rows, each providing an "image" and a "caption_3" field.

    Returns:
        dict with
          "pixel_values":   stacked per-image outputs of ``image_processor``,
          "input_ids":      token ids of the fixed prompt "caption en",
          "attention_mask": attention mask belonging to that prompt,
          "labels":         token ids of the captions (padding not yet masked).
    """
    image_tensors = []
    captions = []
    for example in batch:
        processed = image_processor(example["image"], return_tensors="pt")
        # Drop the leading axis of size 1 so the images can be stacked below
        image_tensors.append(processed["pixel_values"].squeeze(0))
        captions.append(example["caption_3"])

    # Prompt and captions are tokenized with the same settings
    tokenizer_kwargs = dict(padding=True, return_tensors="pt", truncation=True, max_length=512)
    prompt_encoding = tokenizer(["caption en"] * len(captions), **tokenizer_kwargs)
    caption_encoding = tokenizer(captions, **tokenizer_kwargs)

    return {
        "pixel_values": torch.stack(image_tensors),
        "input_ids": prompt_encoding["input_ids"],
        "attention_mask": prompt_encoding["attention_mask"],
        "labels": caption_encoding["input_ids"],  # T5 will shift internally
    }

# DataLoaders: same batch size and collate function, only training is shuffled
loader_kwargs = dict(batch_size=256, collate_fn=collate_fn)
train_loader = DataLoader(ds["train"], shuffle=True, **loader_kwargs)
val_loader = DataLoader(ds["val"], shuffle=False, **loader_kwargs)

# Optimizer
optimizer = AdamW(model.parameters(), lr=5e-5)

# Training Loop
num_epochs = 3


def _forward_batch(batch):
    """Move one collated batch to `device`, mask label padding, run the model.

    Returns:
        (outputs, batch_size): the model output and the number of samples
        in the batch (first dimension of input_ids).
    """
    pixel_values = batch["pixel_values"].to(device)
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    # Replace <pad> tokens in labels with -100 so they are ignored in loss
    labels[labels == tokenizer.pad_token_id] = -100

    outputs = model(
        image=pixel_values,
        input_ids=input_ids,
        attention_mask=attention_mask,
        labels=labels,
    )
    return outputs, input_ids.size(0)


for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    total_loss = 0.0
    total_samples = 0

    progress = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")
    for step, batch in enumerate(progress):
        optimizer.zero_grad()

        outputs, batch_size = _forward_batch(batch)

        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Weight each batch loss by its sample count for the epoch average
        total_loss += loss.item() * batch_size
        total_samples += batch_size

        # wandb.log({
        #     "train/loss": loss.item(),
        #     "train/step": epoch * len(train_loader) + step
        # })

    avg_train_loss = total_loss / total_samples
    print(f"Epoch {epoch+1} completed. Average Train Loss: {avg_train_loss:.4f}")
    # wandb.log({"train/avg_epoch_loss": avg_train_loss, "epoch": epoch + 1})

    # ---- validation pass (no gradients, no optimizer step) ----
    model.eval()
    val_loss = 0.0
    val_samples = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            outputs, batch_size = _forward_batch(batch)
            val_loss += outputs.loss.item() * batch_size
            val_samples += batch_size

    avg_val_loss = val_loss / val_samples
    print(f"Average Validation Loss: {avg_val_loss:.4f}")
    # wandb.log({"val/loss": avg_val_loss, "epoch": epoch + 1})



In [None]:
# Generate captions with the fine-tuned model and keep them for the metric cells
# Set model to evaluation mode
model.eval()
max_new_tokens = 30
eos_token_id = tokenizer.eos_token_id
predictions = []

for i in range(10):
    print(f"Generating caption for sample {i + 1}")

    # Preprocess image
    image = ds["test"][i]["image"]
    pixel_values = image_processor(image, return_tensors="pt")["pixel_values"].to(device)

    with torch.no_grad():
        # Same image encoding steps as CustomVLM.forward: average the vision
        # features over dim 1, project, and add a length-1 sequence axis
        vision_output = model.vision_model(pixel_values=pixel_values).last_hidden_state
        vision_embedding = vision_output.mean(dim=1)
        encoder_hidden_states = model.vision_proj(vision_embedding).unsqueeze(1)

        # Wrap as BaseModelOutput
        encoder_outputs = BaseModelOutput(last_hidden_state=encoder_hidden_states)

        # Prompt setup
        # NOTE(review): encoder_outputs is supplied directly, so the language
        # model's own encoder presumably never processes these prompt ids —
        # confirm the prompt actually influences generation.
        prompt = "caption en"
        tokenized = tokenizer(prompt, return_tensors="pt").to(device)
        input_ids = tokenized["input_ids"]
        attention_mask = tokenized["attention_mask"]

        # Remove <pad> if it appears as the first token
        if input_ids[0, 0] == tokenizer.pad_token_id:
            input_ids = input_ids[:, 1:]
            attention_mask = attention_mask[:, 1:]

        # Generate (greedy decoding, since do_sample=False)
        generated_ids = model.language_model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_outputs=encoder_outputs,
            max_new_tokens=max_new_tokens,
            eos_token_id=eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            do_sample=False
        )

        # Decode only the ids produced for this image (never a previous
        # iteration's), then strip the prompt text and any literal "<pad>" in a
        # single chain so neither clean-up step overwrites the other.
        decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
        caption = decoded.replace(prompt, "").replace("<pad>", "").strip()
        predictions.append(caption)
        print("Generated caption:", repr(caption))
        print()

In [None]:
# Collect the reference captions (caption_1 ... caption_5) of every test sample.
# all_references[k] is the list of five reference captions of sample k.
all_references = []
for sample in ds["test"]:
    # The row is read once and reused for all five caption fields, instead of
    # indexing the dataset again for every field (presumably each row access
    # also loads the image — confirm).
    reference_per_sample = []
    for j in range(1, 6):
        reference = sample[f"caption_{j}"]
        reference_per_sample.append(reference)
        print(f"The reference caption_{j}:")
        print(repr(reference))

    print()
    all_references.append(reference_per_sample)

In [None]:
# Check the format of the reference captions: print the caption lists of the
# first five samples
print(all_references[:5])

In [None]:
# Check the format of the predicted captions: print the first five predictions.
# NOTE(review): an earlier comment said each sample starts with a new line; the
# captions are stripped before being appended to `predictions`, so that
# presumably no longer applies — confirm.
print(predictions[:5])

In [None]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, corpus_bleu, SmoothingFunction

# Fetch the NLTK resources requested by this notebook
for resource in ('punkt_tab', 'punkt'):
    nltk.download(resource)


def _lower_tokens(text):
    """Lower-case `text` and split it with nltk.word_tokenize."""
    return nltk.word_tokenize(text.lower())


# Tokenize references and predictions:
# tokenized_refs[k] is a list of token lists (one per reference of sample k),
# tokenized_hyps[k] is the token list of prediction k.
tokenized_refs = [[_lower_tokens(ref) for ref in refs] for refs in all_references]
tokenized_hyps = [_lower_tokens(pred) for pred in predictions]

In [None]:
# Display the tokenized reference captions of the first sample.
tokenized_refs[0]

In [None]:
# Sentence-level BLEU-2
# Every reference is scored against the hypothesis on its own; the best of
# those scores is reported for the sample.
smooth = SmoothingFunction().method1
bleu2_weights = (1/2, 1/2)
for i, (refs_per_sample, hyp_tok) in enumerate(zip(tokenized_refs, tokenized_hyps)):
    max_score = max(
        sentence_bleu(
            [refs_tok],
            hyp_tok,
            weights=bleu2_weights,
            smoothing_function=smooth,
        )
        for refs_tok in refs_per_sample
    )
    print(f"Example {i+1:2d} BLEU-2: {max_score*100:.2f}")

In [None]:
# Corpus-level BLEU-2
# corpus_bleu expects list-of-list-of-tokens refs, and list-of-tokens hyps,
# with exactly one reference group per hypothesis. References are collected
# for the whole test split while predictions may cover only its first samples,
# so the references are cut to the number of hypotheses (both lists are in
# test-set order, so position k refers to the same sample in each).
num_hyps = len(tokenized_hyps)
corpus_score = corpus_bleu(
    tokenized_refs[:num_hyps],
    tokenized_hyps,
    weights=(1/2, 1/2),
    smoothing_function=smooth
)
print(f"\nCorpus BLEU-2: {corpus_score*100:.2f}")

In [None]:
# Sentence-level BLEU-3
# Every reference is scored against the hypothesis on its own; the best of
# those scores is reported for the sample.
smooth = SmoothingFunction().method1
bleu3_weights = (1/3, 1/3, 1/3)
for i, (refs_per_sample, hyp_tok) in enumerate(zip(tokenized_refs, tokenized_hyps)):
    max_score = max(
        sentence_bleu(
            [refs_tok],
            hyp_tok,
            weights=bleu3_weights,
            smoothing_function=smooth,
        )
        for refs_tok in refs_per_sample
    )
    print(f"Example {i+1:2d} BLEU-3: {max_score*100:.2f}")

In [None]:
# Corpus-level BLEU-3
# corpus_bleu expects list-of-list-of-tokens refs, and list-of-tokens hyps,
# with exactly one reference group per hypothesis. References are collected
# for the whole test split while predictions may cover only its first samples,
# so the references are cut to the number of hypotheses (both lists are in
# test-set order, so position k refers to the same sample in each).
num_hyps = len(tokenized_hyps)
corpus_score = corpus_bleu(
    tokenized_refs[:num_hyps],
    tokenized_hyps,
    weights=(1/3, 1/3, 1/3),
    smoothing_function=smooth
)
print(f"\nCorpus BLEU-3: {corpus_score*100:.2f}")

In [None]:
# Sentence-level BLEU-4
# Every reference is scored against the hypothesis on its own; the best of
# those scores is reported for the sample.
smooth = SmoothingFunction().method1
bleu4_weights = (1/4, 1/4, 1/4, 1/4)
for i, (refs_per_sample, hyp_tok) in enumerate(zip(tokenized_refs, tokenized_hyps)):
    max_score = max(
        sentence_bleu(
            [refs_tok],
            hyp_tok,
            weights=bleu4_weights,
            smoothing_function=smooth,
        )
        for refs_tok in refs_per_sample
    )
    print(f"Example {i+1:2d} BLEU-4: {max_score*100:.2f}")

In [None]:
# Corpus-level BLEU-4
# corpus_bleu expects list-of-list-of-tokens refs, and list-of-tokens hyps,
# with exactly one reference group per hypothesis. References are collected
# for the whole test split while predictions may cover only its first samples,
# so the references are cut to the number of hypotheses (both lists are in
# test-set order, so position k refers to the same sample in each).
num_hyps = len(tokenized_hyps)
corpus_score = corpus_bleu(
    tokenized_refs[:num_hyps],
    tokenized_hyps,
    weights=(1/4, 1/4, 1/4, 1/4),
    smoothing_function=smooth
)
print(f"\nCorpus BLEU-4: {corpus_score*100:.2f}")

In [None]:
# Go on to calculate ROUGE scores
# NOTE(review): the ROUGE cells in this notebook use a hand-written `rouge_n`
# helper and no import of `rouge_score` is visible, so this install may be
# unnecessary — confirm.
!pip install rouge-score

In [None]:
import nltk
from collections import Counter

# Download the NLTK 'punkt' resource without progress output (quiet=True);
# presumably needed by the nltk.word_tokenize calls in rouge_n — confirm.
nltk.download('punkt', quiet=True)

def rouge_n(ref: str, hyp: str, n: int = 4):
    """Compute n-gram overlap recall, precision and F1 between two strings.

    Both strings are lower-cased and tokenized with ``nltk.word_tokenize``.
    The overlap is the clipped n-gram match count: for each n-gram, the
    smaller of its counts in the reference and in the hypothesis.

    Args:
        ref: reference text.
        hyp: hypothesis text.
        n: n-gram order.

    Returns:
        Tuple ``(recall, precision, f1)``. Recall divides the overlap by the
        number of reference n-grams, precision by the number of hypothesis
        n-grams; an empty n-gram list counts as 1 in the denominator.
    """
    ref_counts = Counter(nltk.ngrams(nltk.word_tokenize(ref.lower()), n))
    hyp_counts = Counter(nltk.ngrams(nltk.word_tokenize(hyp.lower()), n))

    # Counter intersection keeps min(count_ref, count_hyp) per shared n-gram
    overlap = sum((ref_counts & hyp_counts).values())

    ref_total = sum(ref_counts.values())
    hyp_total = sum(hyp_counts.values())
    recall = overlap / max(ref_total, 1)
    precision = overlap / max(hyp_total, 1)

    # The small epsilon keeps the division defined when both terms are zero
    f1 = 2 * recall * precision / (recall + precision + 1e-8)
    return (recall, precision, f1)

# Compute ROUGE-2
# For each prediction, score it against every reference and keep the metrics
# of the reference with the highest F1 (the earliest one on ties).
all_recalls, all_precisions, all_f1s = [], [], []
for refs, pred in zip(all_references, predictions):
    # One (recall, precision, f1) triple per reference caption
    triples = [rouge_n(ref, pred, n=2) for ref in refs]
    max_index = max(range(len(triples)), key=lambda k: triples[k][2])
    best_recall, best_precision, best_f1 = triples[max_index]

    all_recalls.append(best_recall)
    all_precisions.append(best_precision)
    all_f1s.append(best_f1)

    print(f"REF:  {refs[max_index]!r}")
    print(f"HYP:  {pred!r}")
    print(f"   ROUGE-2 Recall:    {best_recall * 100:.2f}%")
    print(f"   ROUGE-2 Precision: {best_precision * 100:.2f}%")
    print(f"   ROUGE-2 F1:        {best_f1 * 100:.2f}%\n")

In [None]:
# Report overall averages: arithmetic mean of the per-sample ROUGE-2 values
avg_r, avg_p, avg_f = (
    sum(values) / len(values)
    for values in (all_recalls, all_precisions, all_f1s)
)
print("=== AVERAGE ROUGE-2 METRICS ===")
print(f"Recall:    {avg_r*100:.2f}")
print(f"Precision: {avg_p*100:.2f}")
print(f"F1:        {avg_f*100:.2f}")

In [None]:
# Compute ROUGE-3
# For each prediction, score it against every reference and keep the metrics
# of the reference with the highest F1 (the earliest one on ties).
all_recalls, all_precisions, all_f1s = [], [], []
for refs, pred in zip(all_references, predictions):
    # One (recall, precision, f1) triple per reference caption
    triples = [rouge_n(ref, pred, n=3) for ref in refs]
    max_index = max(range(len(triples)), key=lambda k: triples[k][2])
    best_recall, best_precision, best_f1 = triples[max_index]

    all_recalls.append(best_recall)
    all_precisions.append(best_precision)
    all_f1s.append(best_f1)

    print(f"REF:  {refs[max_index]!r}")
    print(f"HYP:  {pred!r}")
    print(f"   ROUGE-3 Recall:    {best_recall * 100:.2f}%")
    print(f"   ROUGE-3 Precision: {best_precision * 100:.2f}%")
    print(f"   ROUGE-3 F1:        {best_f1 * 100:.2f}%\n")

In [None]:
# Report overall averages: arithmetic mean of the per-sample ROUGE-3 values
avg_r, avg_p, avg_f = (
    sum(values) / len(values)
    for values in (all_recalls, all_precisions, all_f1s)
)
print("=== AVERAGE ROUGE-3 METRICS ===")
print(f"Recall:    {avg_r*100:.2f}")
print(f"Precision: {avg_p*100:.2f}")
print(f"F1:        {avg_f*100:.2f}")

In [None]:
# Compute ROUGE-4
# For each prediction, score it against every reference and keep the metrics
# of the reference with the highest F1 (the earliest one on ties).
all_recalls, all_precisions, all_f1s = [], [], []
for refs, pred in zip(all_references, predictions):
    # One (recall, precision, f1) triple per reference caption
    triples = [rouge_n(ref, pred, n=4) for ref in refs]
    max_index = max(range(len(triples)), key=lambda k: triples[k][2])
    best_recall, best_precision, best_f1 = triples[max_index]

    all_recalls.append(best_recall)
    all_precisions.append(best_precision)
    all_f1s.append(best_f1)

    print(f"REF:  {refs[max_index]!r}")
    print(f"HYP:  {pred!r}")
    print(f"   ROUGE-4 Recall:    {best_recall * 100:.2f}%")
    print(f"   ROUGE-4 Precision: {best_precision * 100:.2f}%")
    print(f"   ROUGE-4 F1:        {best_f1 * 100:.2f}%\n")

In [None]:
# Report overall averages: arithmetic mean of the per-sample ROUGE-4 values
avg_r, avg_p, avg_f = (
    sum(values) / len(values)
    for values in (all_recalls, all_precisions, all_f1s)
)
print("=== AVERAGE ROUGE-4 METRICS ===")
print(f"Recall:    {avg_r*100:.2f}")
print(f"Precision: {avg_p*100:.2f}")
print(f"F1:        {avg_f*100:.2f}")

In [None]:
import os

import json

# Target folder for all saved artifacts
save_dir = "/content/drive/MyDrive/DI725 - Transformers and Attention-based Deep Networks/Term Project/siglip-t5-custom_vlm_finetuned"
os.makedirs(save_dir, exist_ok=True)

# Tokenizer and language model both go to the root of save_dir.
# NOTE(review): the earlier comment described this as a merged, clean T5, but
# no merge step is visible in this notebook; language_model is the object
# returned by get_peft_model, so presumably only adapter weights are written
# here — confirm before relying on "language_model_path" below.
tokenizer.save_pretrained(save_dir)
model.language_model.save_pretrained(save_dir)

# Vision encoder and its image processor share one sub-folder
vision_dir = f"{save_dir}/vision_encoder"
model.vision_model.save_pretrained(vision_dir)
image_processor.save_pretrained(vision_dir)

# Projection layer weights as a plain state dict
projection = model.vision_proj
torch.save(projection.state_dict(), f"{save_dir}/vision_proj.pt")

# Small JSON file describing where each piece lives (paths are relative to
# save_dir) and the projection sizes needed to rebuild the wrapper
config = {
    "vision_encoder_path": "vision_encoder",
    "language_model_path": ".",
    "vision_proj_path": "vision_proj.pt",
    "vision_hidden_size": projection.in_features,
    "language_hidden_size": projection.out_features,
}
config_path = os.path.join(save_dir, "custom_vlm_config.json")
with open(config_path, "w") as f:
    json.dump(config, f, indent=2)


In [None]:
# Final cell: prints a marker once execution reaches the end of the notebook.
print("DONE")