## CLIP (ViT-B/32 veya ViT-L/14) + Projection MLP + mT5-Small Decoder Pipeline

Bu bölüm: 
- CLIP image encoder (config'de ViT-B/32; ViT-L/14 de seçilebilir), tamamen dondurulmuş (freeze)
- Görsel embedding -> K adet prefix token üreten MLP (öğrenilir)
- mT5-small (sadece decoder veya istersen tamamı) caption/çeviri üretimi
- Projection MLP ve mT5 decoder parametreleri eğitilecek.

Strateji (prefix approach):
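1. Image -> CLIP encode_image -> (B, D_clip)  (ViT-B/32 için D_clip=512, ViT-L/14 için D_clip=768)
2. MLP: D_clip -> K * d_model, reshape -> (B, K, d_model) -> LayerNorm  (mT5-small için d_model=512)
3. Bu prefix, mT5 encoder'a inputs_embeds olarak verilir (isteğe bağlı ek metin prompt token'larıyla concat edilerek)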
4. Decoder hedef yazıyı üretir (teacher forcing, cross-entropy)

Seçilebilir dondurma opsiyonları:
- freeze_clip = True (bu senaryoda zorunlu)
- freeze_t5_encoder = True bırakıp sadece decoder + projection eğitilebilir

Aşağıdaki kod, Flickr8k JSON dosyasından (tasviret8k_captions.json) dataset oluşturma iskeletini içerir; her görselin her caption'ı ayrı bir (görsel, caption) örneği olarak eklenir.


In [None]:
# Unified configuration + optional wandb init
import os
from types import SimpleNamespace

PROJECT_NAME = "bites-tr-image-captioning"
RUN_NAME = "clip_mt5_prefix_run"

# Set to False to run fully offline; a failed wandb import/init also falls back
# to the local config below.
ENABLE_WANDB = True

base_config = {
    "model": "google/mt5-small",
    "clip_encoder": "ViT-B/32",  # switched from ViT-L/14 to ViT-B/32 for speed
    "prefix_tokens": 8,
    "batch_size": 4,
    "lr": 1e-4,
    "epochs": 30,
    "dataset_limit": 1000,
    "freeze_clip": True,
    "freeze_t5_encoder": True,
    "freeze_t5_decoder": False,
    "seed": 42,
    "weight_decay": 0.0,
    "grad_clip": 1.0,
    "warmup_steps": 0,
    "num_beams_infer": 4,
    "max_new_tokens_infer": 32,
    "src_max_len": 64,
    "tgt_max_len": 64,
}

use_wandb = False
run = None  # wandb run handle; stays None when wandb is not used
cfg = None
if ENABLE_WANDB:
    try:
        import wandb
        run = wandb.init(project=PROJECT_NAME, name=RUN_NAME, config=base_config)
        cfg = wandb.config  # attribute-style access to the run's config
        use_wandb = True
        print("[wandb] run initialized")
    except Exception as e:
        # Deliberate best-effort: missing package, no login, network issues, ...
        print("[wandb] disabled (init failed):", e)

if cfg is None:
    # Same attribute-style access as wandb.config, so later cells can use cfg.<key>.
    cfg = SimpleNamespace(**base_config)
    print("[INFO] Using local cfg (wandb off)")

import random, numpy as np, torch
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(cfg.seed)

print("Active config:")
for k, v in base_config.items():
    print(f"  {k}: {getattr(cfg, k)}")

[34m[1mwandb[0m: Currently logged in as: [33mabdulkadirparlak[0m ([33mabdulkadirparlak-hacettepe-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[wandb] run initialized
Active config:
  model: google/mt5-small
  clip_encoder: ViT-B/32
  prefix_tokens: 8
  batch_size: 4
  lr: 0.0001
  epochs: 20
  dataset_limit: 1000
  freeze_clip: True
  freeze_t5_encoder: True
  freeze_t5_decoder: False
  seed: 42
  weight_decay: 0
  grad_clip: 1
  warmup_steps: 0
  num_beams_infer: 4
  max_new_tokens_infer: 32
  src_max_len: 64
  tgt_max_len: 64


In [None]:
# Model, dataset and pipeline definitions (data + model init only)
import torch, os, json
from torch import nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import clip
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

class ProjectionMLP(nn.Module):
    """Expand one image embedding into ``prefix_tokens`` prefix vectors.

    Layers: Linear(in_dim -> hidden) -> GELU -> Linear(hidden -> prefix_tokens * out_dim).
    The result is reshaped to (batch, prefix_tokens, out_dim), layer-normalised
    over the last axis and passed through dropout.
    """

    def __init__(self, in_dim, out_dim, prefix_tokens=8, hidden=1024, dropout=0.1):
        super().__init__()
        self.prefix_tokens = prefix_tokens
        # Creation order is kept fixed: it determines parameter init and state_dict keys.
        self.fc1 = nn.Linear(in_dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, out_dim * prefix_tokens)
        self.ln = nn.LayerNorm(out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Map x of shape (batch, in_dim) to (batch, prefix_tokens, out_dim)."""
        flat = self.fc2(self.act(self.fc1(x)))
        batch_size = flat.size(0)
        tokens = flat.view(batch_size, self.prefix_tokens, -1)
        normed = self.ln(tokens)
        return self.dropout(normed)

class CLIPmT5Pipeline(nn.Module):
    """CLIP image encoder + projection prefix + mT5 seq2seq captioner.

    Each image is encoded by CLIP and projected to ``cfg.prefix_tokens``
    pseudo-token embeddings by ``ProjectionMLP``. These are prepended to the
    embedded text prompt, and the sequence is fed to mT5 via ``inputs_embeds``.

    Config fields read: model, clip_encoder, freeze_clip, freeze_t5_encoder,
    freeze_t5_decoder, prefix_tokens, src_max_len, tgt_max_len,
    num_beams_infer, max_new_tokens_infer.

    NOTE(review): ``freeze_t5_encoder`` freezes ``model.encoder.embed_tokens``.
    In HF mT5 this is presumably the input embedding shared with the decoder,
    so the decoder's input embeddings get frozen too. Confirm this is intended.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tokenizer = MT5Tokenizer.from_pretrained(cfg.model)
        self.model = MT5ForConditionalGeneration.from_pretrained(cfg.model)
        # Keep CLIP's preprocess transform so generate() does not have to call
        # clip.load() again (that reloads the full CLIP weights).
        self.clip, self.preprocess = clip.load(cfg.clip_encoder, device='cpu')
        for p in self.clip.parameters():
            p.requires_grad = not cfg.freeze_clip
        if cfg.freeze_t5_encoder:
            for p in self.model.encoder.parameters():
                p.requires_grad = False
        if cfg.freeze_t5_decoder:
            for p in self.model.decoder.parameters():
                p.requires_grad = False
        self.prefix_tokens = cfg.prefix_tokens
        # Determine CLIP embedding dim dynamically (ViT-B/32 = 512, ViT-L/14 = 768)
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            embed_dim = self.clip.visual.output_dim if hasattr(self.clip.visual, 'output_dim') else self.clip.encode_image(dummy).shape[-1]
        self.proj = ProjectionMLP(in_dim=embed_dim, out_dim=self.model.config.d_model, prefix_tokens=cfg.prefix_tokens)
        self._cached_sentinel_ids = None  # lazy cache: {n: bad_words_ids}

    def train(self, mode: bool = True):
        """Set train/eval mode, but keep a frozen CLIP encoder in eval mode.

        Without this, ``model.train()`` also switches the frozen CLIP into
        training mode. Normalisation layers (e.g. BatchNorm in CLIP's ResNet
        variants) would then keep updating their running statistics.
        """
        super().train(mode)
        if self.cfg.freeze_clip:
            self.clip.eval()
        return self

    @staticmethod
    def _mask_pad_labels(input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Return a copy of ``input_ids`` with padded positions replaced by -100.

        -100 is the ignore_index of the cross-entropy loss in HF seq2seq models.
        With this mask, padding no longer contributes to the loss.
        """
        return input_ids.masked_fill(attention_mask == 0, -100)

    def _encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """CLIP-encode ``images`` and project them to (B, K, d_model) prefix embeddings.

        Gradients flow into CLIP only if ``cfg.freeze_clip`` is False AND grad
        mode is already enabled by the caller. Validation under ``torch.no_grad()``
        therefore stays gradient-free.
        """
        with torch.set_grad_enabled(torch.is_grad_enabled() and not self.cfg.freeze_clip):
            clip_emb = self.clip.encode_image(images)
        return self.proj(clip_emb)

    def _build_encoder_inputs(self, prefix_emb: torch.Tensor, src_texts):
        """Embed ``src_texts`` and prepend ``prefix_emb``.

        Returns:
            (inputs_embeds, attention_mask). The mask is all ones for the K
            prefix positions, followed by the tokenizer's mask for the text.
        """
        device = next(self.parameters()).device
        tok_src = self.tokenizer(list(src_texts), return_tensors='pt', padding=True, truncation=True, max_length=self.cfg.src_max_len).to(device)
        text_emb = self.model.encoder.embed_tokens(tok_src.input_ids)
        full_emb = torch.cat([prefix_emb, text_emb], dim=1)
        prefix_mask = torch.ones(prefix_emb.size(0), self.prefix_tokens, dtype=tok_src.attention_mask.dtype, device=device)
        full_attn = torch.cat([prefix_mask, tok_src.attention_mask], dim=1)
        return full_emb, full_attn

    # Training forward (teacher forcing)
    def forward(self, images, src_texts, tgt_texts):
        """Teacher-forced pass; returns the HF seq2seq output (``.loss`` is set).

        Args:
            images: CLIP-preprocessed image batch. Expected on the model device
                (the training loop moves it there).
            src_texts: per-sample text prompts appended after the visual prefix.
            tgt_texts: per-sample target captions.
        """
        device = next(self.parameters()).device
        prefix_emb = self._encode_images(images)
        full_emb, full_attn = self._build_encoder_inputs(prefix_emb, src_texts)
        tok_tgt = self.tokenizer(list(tgt_texts), return_tensors='pt', padding=True, truncation=True, max_length=self.cfg.tgt_max_len).to(device)
        # Padding must not be learned as a target; HF shifts -100 back to pad
        # when it builds decoder inputs from the labels.
        labels = self._mask_pad_labels(tok_tgt.input_ids, tok_tgt.attention_mask)
        return self.model(inputs_embeds=full_emb, attention_mask=full_attn, labels=labels)

    # --- Inference helpers ---
    def _prepare_prefix(self, images: torch.Tensor):
        """Return (B, K, d) prefix embeddings; CLIP always runs without grad here."""
        with torch.no_grad():
            emb = self.clip.encode_image(images)
        return self.proj(emb)  # (B, K, d)

    def _get_sentinel_bad_words(self, n=50):
        """Return ``bad_words_ids`` ([[id], ...]) for ``<extra_id_0>`` .. ``<extra_id_{n-1}>``.

        IDs are looked up directly with ``convert_tokens_to_ids`` and cached per ``n``.
        """
        if self._cached_sentinel_ids is None:
            self._cached_sentinel_ids = {}
        if n not in self._cached_sentinel_ids:
            ids = self.tokenizer.convert_tokens_to_ids([f'<extra_id_{i}>' for i in range(n)])
            self._cached_sentinel_ids[n] = [[i] for i in ids]
        return self._cached_sentinel_ids[n]

    @torch.inference_mode()
    def generate(self, image_paths=None, images=None, num_beams=None, max_new_tokens=None, prompt=" ", ban_sentinels=True):
        """High-level caption generation.

        Args:
            image_paths: list of image file paths (used if ``images`` is None).
            images: preprocessed tensor (B,C,H,W) matching CLIP preprocess.
            num_beams: beam search width (default ``cfg.num_beams_infer``).
            max_new_tokens: decoding length (default ``cfg.max_new_tokens_infer``).
            prompt: textual prompt appended after the prefix (same for all images).
            ban_sentinels: whether to block <extra_id_*> placeholders.
        Returns:
            list[str] captions; empty decodes are replaced by "<EMPTY>".
        """
        device = next(self.parameters()).device
        num_beams = num_beams or self.cfg.num_beams_infer
        max_new_tokens = max_new_tokens or self.cfg.max_new_tokens_infer
        if images is None:
            assert image_paths is not None, "Provide image_paths or images tensor"
            tensors = []
            for path in image_paths:
                with Image.open(path) as im:  # close the file handle promptly
                    tensors.append(self.preprocess(im.convert('RGB')))
            images = torch.stack(tensors)
        images = images.to(device)
        prefix_emb = self._prepare_prefix(images)  # (B,K,d)
        full_emb, full_attn = self._build_encoder_inputs(prefix_emb, [prompt] * images.size(0))
        bad_words_ids = self._get_sentinel_bad_words() if ban_sentinels else None
        gen_ids = self.model.generate(
            inputs_embeds=full_emb,
            attention_mask=full_attn,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            bad_words_ids=bad_words_ids
        )
        captions = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in gen_ids]
        return [c if c else "<EMPTY>" for c in captions]

class Flickr8kCaptions(Dataset):
    """(image, prompt, caption) samples read from a Flickr8k-style caption JSON.

    Every caption of every matching image becomes its own sample, so an image
    with several captions appears several times. ``__getitem__`` returns
    ``(transformed_image, " ", caption)``; the blank string is the text prompt.

    Args:
        json_path: JSON file containing either ``{"images": [...]}`` or a bare
            list of rows. A row is used only if it is a dict with an image name
            (``filename`` / ``image`` / ``img``) and a non-empty ``sentences``
            list. Only sentence dicts with a ``raw`` key are kept.
        images_root: directory that image names are joined onto.
        split: keep only rows whose ``split`` equals this value (falsy = all).
        limit: keep at most this many samples (falsy = no limit).
        clip_preprocess: image transform. If not given, CLIP's preprocess for
            the global ``cfg.clip_encoder`` is fetched via ``clip.load``.
    """

    def __init__(self, json_path, images_root, split=None, limit=None, clip_preprocess=None):
        self.images_root = images_root
        # Captions may contain non-ASCII (e.g. Turkish) text: decode explicitly
        # as UTF-8 instead of the locale default, and close the file handle.
        with open(json_path, encoding='utf-8') as f:
            raw = json.load(f)
        rows = raw['images'] if isinstance(raw, dict) and 'images' in raw else raw
        self.samples = []
        for row in rows:
            if not isinstance(row, dict):
                continue
            if split and row.get('split') != split:
                continue
            img = row.get('filename') or row.get('image') or row.get('img')
            sentences = row.get('sentences')
            if not img or not sentences:
                continue
            self.samples.extend(
                (img, s['raw']) for s in sentences if isinstance(s, dict) and 'raw' in s
            )
        if limit:
            self.samples = self.samples[:limit]
        self.transform = clip_preprocess or clip.load(cfg.clip_encoder, device='cpu')[1]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_name, cap = self.samples[idx]
        path = os.path.join(self.images_root, img_name)
        # The context manager closes the file; convert() returns a loaded copy.
        with Image.open(path) as im:
            image = im.convert('RGB')
        return self.transform(image), " ", cap

json_path = 'data/flickr8k/tasviret8k_captions.json'
images_root = 'data/flickr8k/Images'
VAL_LIMIT = 50   # small held-out subsets keep per-epoch validation cheap
TEST_LIMIT = 30

# Fetch CLIP's preprocess transform once and share it. Otherwise each dataset
# calls clip.load() (a full weight load) just for the transform. Indexing [1]
# drops the loaded model so it is not kept alive in a notebook global.
clip_preprocess = clip.load(cfg.clip_encoder, device='cpu')[1]

train_dataset = Flickr8kCaptions(json_path, images_root, split='train', limit=cfg.dataset_limit, clip_preprocess=clip_preprocess)
val_dataset   = Flickr8kCaptions(json_path, images_root, split='val',   limit=VAL_LIMIT, clip_preprocess=clip_preprocess)
test_dataset  = Flickr8kCaptions(json_path, images_root, split='test',  limit=TEST_LIMIT, clip_preprocess=clip_preprocess)

train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
# Empty splits get no loader; later cells check for None.
val_loader   = DataLoader(val_dataset,   batch_size=cfg.batch_size, shuffle=False) if len(val_dataset) > 0 else None
test_loader  = DataLoader(test_dataset,  batch_size=cfg.batch_size, shuffle=False) if len(test_dataset) > 0 else None

model_mm = CLIPmT5Pipeline(cfg)
print(f"Train samples: {len(train_dataset)}  Val: {len(val_dataset)}  Test: {len(test_dataset)}")
print("Trainable params:", sum(p.numel() for p in model_mm.parameters() if p.requires_grad))

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.
You are using the default legacy behaviour of the <class 'transformers.models.mt5.tokenization_mt5.MT5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Train samples: 1000  Val: 50  Test: 30
Trainable params: 157960896


In [None]:
# Training cell (only training + validation)
import math, time, torch, warnings, os
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
warnings.filterwarnings("ignore", message=".*legacy behaviour.*MT5Tokenizer.*")

use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'mps' if torch.backends.mps.is_available() else 'cpu')
model_mm.to(device)
params = [p for p in model_mm.parameters() if p.requires_grad]
optimizer = AdamW(params, lr=cfg.lr, weight_decay=cfg.weight_decay)

steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * cfg.epochs


def lr_lambda(step):
    """LR multiplier: linear warmup over cfg.warmup_steps, then cosine decay from 1 to 0 at total_steps."""
    if cfg.warmup_steps > 0 and step < cfg.warmup_steps:
        return float(step) / float(max(1, cfg.warmup_steps))
    progress = (step - cfg.warmup_steps) / float(max(1, total_steps - cfg.warmup_steps))
    return 0.5 * (1 + math.cos(math.pi * progress))


def _evaluate_loss(loader):
    """Mean per-batch loss of model_mm over ``loader`` (eval mode, no gradients)."""
    model_mm.eval()
    total = 0.0
    with torch.no_grad():
        for imgs, srcs, tgts in loader:
            total += model_mm(imgs.to(device), srcs, tgts).loss.item()
    return total / max(1, len(loader))


scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
scaler = torch.amp.GradScaler('cuda') if use_cuda else None
best_val = float('inf')
CKPT_DIR = 'checkpoints'
os.makedirs(CKPT_DIR, exist_ok=True)
# Resolved config values (may include wandb overrides), not just the defaults.
effective_cfg = {k: getattr(cfg, k) for k in base_config}

global_step = 0
for epoch in range(cfg.epochs):
    model_mm.train()
    train_sum = 0.0
    t0 = time.time()
    for imgs, srcs, tgts in train_loader:
        imgs = imgs.to(device)
        if use_cuda:
            with torch.amp.autocast('cuda'):
                loss = model_mm(imgs, srcs, tgts).loss
            scaler.scale(loss).backward()
            if cfg.grad_clip and cfg.grad_clip > 0:
                scaler.unscale_(optimizer)  # clip the real (unscaled) gradients
                torch.nn.utils.clip_grad_norm_(params, cfg.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss = model_mm(imgs, srcs, tgts).loss
            loss.backward()
            if cfg.grad_clip and cfg.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(params, cfg.grad_clip)
            optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        # NOTE: if GradScaler skipped the step (inf grads), PyTorch may warn
        # that scheduler.step() ran before optimizer.step(); this is harmless here.
        scheduler.step()
        train_sum += loss.item()
        global_step += 1
    train_epoch_loss = train_sum / max(1, len(train_loader))

    val_epoch_loss = _evaluate_loss(val_loader) if val_loader is not None else None

    dt = time.time() - t0
    lr_now = scheduler.get_last_lr()[0]
    val_part = f" val_loss={val_epoch_loss:.4f}" if val_epoch_loss is not None else ""
    print(f"Epoch {epoch+1}/{cfg.epochs} train_loss={train_epoch_loss:.4f}{val_part} time={dt:.1f}s lr={lr_now:.2e}", flush=True)
    if use_wandb:
        log_data = {"train/epoch_loss": train_epoch_loss, "lr": lr_now}
        if val_epoch_loss is not None:
            log_data["val/epoch_loss"] = val_epoch_loss
        wandb.log(log_data, step=epoch)

    # Model selection: validation loss if available, otherwise training loss.
    metric = val_epoch_loss if val_epoch_loss is not None else train_epoch_loss
    if metric < best_val:
        best_val = metric
        torch.save({
            'model': model_mm.state_dict(),
            'cfg': effective_cfg,
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_val': best_val,
        }, os.path.join(CKPT_DIR, 'best.pt'))
        print(f"  -> Saved checkpoint (metric {best_val:.4f})")
print("Training finished. Best metric:", best_val)

In [None]:
# Testing / Inference cell (evaluate test set + simple generation API)
import torch, time, os
from typing import List
from PIL import Image

model_mm.eval()
device = next(model_mm.parameters()).device

test_loss = None
if test_loader is not None:
    t0 = time.time()
    total = 0.0
    with torch.no_grad():
        for imgs, srcs, tgts in test_loader:
            total += model_mm(imgs.to(device), srcs, tgts).loss.item()
    test_loss = total / max(1, len(test_loader))
    print(f"Test loss={test_loss:.4f} time={(time.time()-t0):.1f}s")
else:
    print("No test split available.")

# Show generated captions for the first SAMPLE_PRINTS distinct test images.
SAMPLE_PRINTS = 5
if test_loader is not None and SAMPLE_PRINTS > 0:
    shown = 0
    printed_imgs = set()
    for img_name, gt_caption in test_dataset.samples:
        # Build the path from the dataset's own root so it always matches the data.
        img_path = os.path.join(test_dataset.images_root, img_name)
        if img_path in printed_imgs:  # avoid duplicates if multiple captions per image
            continue
        # Samples are visited in order, so gt_caption is this image's first caption.
        caps = model_mm.generate(image_paths=[img_path])
        print('GT:', gt_caption)
        print('CAP:', caps[0])
        print('---------------')
        printed_imgs.add(img_path)
        shown += 1
        if shown >= SAMPLE_PRINTS:
            break


def generate_captions(image_paths: List[str], **kwargs):
    """Caption image files with model_mm; kwargs are forwarded to CLIPmT5Pipeline.generate."""
    return model_mm.generate(image_paths=image_paths, **kwargs)


print("[Ready] generate_captions(['path/to/img.jpg'])")