## CLIP (ViT-L/14) + Projection MLP + mT5-Small Decoder Pipeline

This section covers:
- CLIP ViT-L/14 image encoder (fully frozen)
- A learnable MLP that maps the image embedding to K prefix tokens
- mT5-small (decoder only, or the whole model if preferred) for caption/translation generation
- The projection MLP and the mT5 decoder parameters are the ones being trained.

Strategy (prefix approach):
1. Image -> CLIP encode_image -> (B,768) for ViT-L/14 (the code reads this width from `visual.output_dim`)
2. MLP: 768 -> (K * d_model), reshape -> (B,K,512) -> LayerNorm
3. Pass this prefix to the mT5 encoder as inputs_embeds (concatenated with the embedded text prompt tokens)
4. The decoder generates the target text (teacher forcing, cross-entropy)
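
Steps 1 to 3 above can be traced on random tensors without loading CLIP or mT5. The sketch below is only a shape check: `clip_dim = 768` assumes the ViT-L/14 embedding width, `T` is a hypothetical prompt length, and the random tensors stand in for real CLIP features and token embeddings.

```python
import torch
import torch.nn as nn

B, K, d_model = 2, 16, 512   # batch size, prefix length, mT5-small width
clip_dim = 768               # assumed CLIP ViT-L/14 embedding width
T = 7                        # hypothetical prompt length in tokens

feats = torch.randn(B, clip_dim)  # stand-in for the CLIP encode_image output

# Same layer layout as ProjectionMLP, written inline for the shape check.
proj = nn.Sequential(
    nn.Linear(clip_dim, 1024),
    nn.GELU(),
    nn.Linear(1024, K * d_model),
)
prefix = nn.LayerNorm(d_model)(proj(feats).view(B, K, d_model))  # (B, K, d_model)

text_embeds = torch.randn(B, T, d_model)        # stand-in for embedded prompt tokens
text_mask = torch.ones(B, T, dtype=torch.long)  # no padding in this toy batch

# The prefix goes first, so the mask needs K extra ones per row.
inputs_embeds = torch.cat([prefix, text_embeds], dim=1)  # (B, K + T, d_model)
attention_mask = torch.cat(
    [torch.ones(B, K, dtype=torch.long), text_mask], dim=1
)  # (B, K + T)

print(prefix.shape, inputs_embeds.shape, attention_mask.shape)
```

The encoder sees K + T positions per sample, which is why the same extended mask has to be passed again when decoding from the precomputed encoder outputs.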

Selectable freezing options:
- freeze_clip = True (required in this scenario)
- freeze_t5_encoder = True can be kept, so that only the decoder + projection are trained
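
The freezing options come down to toggling `requires_grad` per submodule. A minimal sketch on a toy model follows; `set_trainable`, `count_trainable`, and the `toy` module are hypothetical names used only for illustration, not part of the pipeline.

```python
import torch.nn as nn

def set_trainable(module: nn.Module, trainable: bool) -> None:
    """Enable or disable gradients for every parameter of a submodule."""
    for p in module.parameters():
        p.requires_grad = trainable

def count_trainable(module: nn.Module) -> int:
    """Number of scalar parameters that will receive gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)

# Toy stand-ins for the three parts of the pipeline.
toy = nn.ModuleDict({
    "encoder": nn.Linear(8, 8),   # 72 parameters
    "decoder": nn.Linear(8, 8),   # 72 parameters
    "proj": nn.Linear(4, 8),      # 40 parameters
})

set_trainable(toy["encoder"], False)  # analogous to freeze_t5_encoder=True
print(count_trainable(toy))
```

In the Hugging Face T5 implementation the encoder and decoder reference the same token-embedding table, so freezing the encoder's parameters also freezes that table for the decoder.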

The code below is a skeleton that builds a dataset from the Flickr8k JSON (tasviret8k_captions.json); every (image, caption) pair becomes one sample.
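
The JSON layout the loader expects can be illustrated on a tiny in-memory example. The `sample` content and the `extract_pairs` helper are hypothetical; only the key names (`images`, `filename`, `sentences`, `raw`, `tokens`) follow the structure the dataset class looks for.

```python
import json

def extract_pairs(data):
    """Return (filename, caption) pairs; fall back to joined tokens if 'raw' is missing."""
    pairs = []
    for entry in data.get("images", []):
        fn = entry.get("filename")
        for sent in entry.get("sentences", []):
            caption = sent.get("raw") or " ".join(sent.get("tokens", []))
            if fn and caption:
                pairs.append((fn, caption))
    return pairs

# Hypothetical miniature file content, for illustration only.
sample = json.loads("""
{"images": [
  {"filename": "a.jpg", "sentences": [{"raw": "Bir köpek koşuyor."},
                                       {"tokens": ["bir", "kedi"]}]},
  {"filename": "b.jpg", "sentences": []}
]}
""")

pairs = extract_pairs(sample)
print(pairs)
```

An image with several captions contributes several samples, and an image without captions contributes none.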


In [1]:
import json, random, math
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import clip
from transformers import MT5ForConditionalGeneration, T5Tokenizer
from PIL import Image
from pathlib import Path

class ProjectionMLP(nn.Module):
    """Maps one CLIP image embedding to `prefix_tokens` vectors of width `d_model`."""
    def __init__(self, in_dim=512, d_model=512, prefix_tokens=16, hidden=1024, dropout=0.05):
        super().__init__()
        self.prefix_tokens = prefix_tokens
        self.d_model = d_model
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, prefix_tokens * d_model)
        )
        self.ln = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        b = x.size(0)
        out = self.net(x).view(b, self.prefix_tokens, self.d_model)
        out = self.ln(out)
        out = self.dropout(out)
        return out

class CLIPmT5Pipeline(nn.Module):
    def __init__(self, mt5_name="google/mt5-small", clip_name="ViT-L/14", prefix_tokens=16,
                 freeze_clip=True, freeze_t5_encoder=True, freeze_t5_decoder=False, device=None):
        super().__init__()
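        # Prefer CUDA so the AMP branch in the training cell matches the model's device.
        self.device_ref = device or (
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )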
        self.tokenizer = T5Tokenizer.from_pretrained(mt5_name)
        self.t5 = MT5ForConditionalGeneration.from_pretrained(mt5_name)
        if freeze_t5_encoder:
            for p in self.t5.encoder.parameters():
                p.requires_grad = False
        if freeze_t5_decoder:
            for p in self.t5.decoder.parameters():
                p.requires_grad = False
            for p in self.t5.lm_head.parameters():
                p.requires_grad = False
        self.clip_model, self.clip_preprocess = clip.load(clip_name, device="cpu")
        if freeze_clip:
            for p in self.clip_model.parameters():
                p.requires_grad = False
        # ViT-L/14 produces 768-dim image embeddings; use that as the fallback width.
        clip_dim = self.clip_model.visual.output_dim if hasattr(self.clip_model.visual, 'output_dim') else 768
        d_model = self.t5.config.d_model
        assert d_model == 512, "Expected d_model == 512 for mt5-small!"
        self.proj = ProjectionMLP(in_dim=clip_dim, d_model=d_model, prefix_tokens=prefix_tokens)
        self.to(self.device_ref)
    @property
    def device(self):
        return self.device_ref
    def encode_image(self, pil_images):
        imgs = torch.stack([self.clip_preprocess(im) for im in pil_images]).to(self.device)
        with torch.no_grad():  # CLIP is frozen, so no autograd graph is needed here
            feats = self.clip_model.encode_image(imgs).float()
        feats = feats / feats.norm(dim=-1, keepdim=True)  # L2-normalize each embedding
        return feats
    def forward(self, images, source_texts, target_texts=None, max_new_tokens=64, num_beams=1):
        feats = self.encode_image(images)
        prefix = self.proj(feats)
        enc_tokens = self.tokenizer(list(source_texts), padding=True, return_tensors="pt")
        input_ids = enc_tokens.input_ids.to(self.device)
        attn_mask = enc_tokens.attention_mask.to(self.device)
        text_embeds = self.t5.encoder.embed_tokens(input_ids)
        inputs_embeds = torch.cat([prefix, text_embeds], dim=1)
        prefix_mask = torch.ones(prefix.size()[:2], dtype=attn_mask.dtype, device=self.device)
        encoder_attention_mask = torch.cat([prefix_mask, attn_mask], dim=1)
        encoder_outputs = self.t5.encoder(inputs_embeds=inputs_embeds, attention_mask=encoder_attention_mask, return_dict=True)
        if target_texts is not None:
            dec_tokens = self.tokenizer(list(target_texts), padding=True, return_tensors="pt")
            labels = dec_tokens.input_ids.to(self.device)
            labels[labels == self.tokenizer.pad_token_id] = -100
            out = self.t5(encoder_outputs=encoder_outputs, attention_mask=encoder_attention_mask, labels=labels, return_dict=True)
            return out
        gen = self.t5.generate(encoder_outputs=encoder_outputs, attention_mask=encoder_attention_mask,
                               max_new_tokens=max_new_tokens, num_beams=num_beams)
        decoded = self.tokenizer.batch_decode(gen, skip_special_tokens=True)
        return decoded

class Flickr8kCaptions(Dataset):
    def __init__(self, json_path, images_root, instruction_prefix="describe image in Turkish:", limit=None, seed=42):
        # Read as UTF-8 explicitly; the captions contain Turkish characters.
        data = json.loads(Path(json_path).read_text(encoding="utf-8"))
        items = []
        # Expected structure: {"images": [{"filename":..., "sentences": [{"raw":..., "tokens":[...]}, ...]}, ...]}
        if isinstance(data, dict) and isinstance(data.get("images"), list):
            for im_entry in data["images"]:
                fn = im_entry.get("filename")
                sents = im_entry.get("sentences", [])
                for s in sents:
                    caption = s.get("raw") or " ".join(s.get("tokens", []))
                    if fn and caption:
                        items.append((fn, caption))
        else:
            # Fallback to previous (filename -> list) style if needed
            for k, v in (data.items() if isinstance(data, dict) else []):
                if isinstance(v, list):
                    for c in v:
                        items.append((k, c))
        random.Random(seed).shuffle(items)
        if limit:
            items = items[:limit]
        self.items = items
        self.base = Path(images_root)
        self.prefix = instruction_prefix
    def __len__(self):
        return len(self.items)
    def __getitem__(self, idx):
        img_name, caption = self.items[idx]
        img_path = self.base / Path(img_name).name
        image = Image.open(img_path).convert("RGB")
        return image, self.prefix, caption

def collate(batch):
    images, sources, targets = zip(*batch)
    return list(images), list(sources), list(targets)

json_path = "data/flickr8k/tasviret8k_captions.json"
images_root = "data/flickr8k/Images"
try:
    dataset = Flickr8kCaptions(json_path, images_root, limit=3000)
    print("Loaded samples:", len(dataset))
    if len(dataset):
        print("Sample item:", dataset.items[0])
    loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate)
except Exception as e:
    print("Dataset init error (check JSON format):", e)
    loader = None

model_mm = CLIPmT5Pipeline(prefix_tokens=16, freeze_clip=True, freeze_t5_encoder=True, freeze_t5_decoder=False)
print("Trainable params:")
trainable = sum(p.numel() for p in model_mm.parameters() if p.requires_grad)
print(trainable, "parameters")

Loaded samples: 3000
Sample item: ('498404951_527adba7b8.jpg', 'Beyaz renkli bir köpek otların arasında bir suyun içinde.')



Trainable params:
162421440 parameters
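
The printed total can be cross-checked by hand. The sketch below is plain arithmetic under assumed mT5-small configuration values (vocabulary 250112, `d_ff` 1024, 6 heads of width 64, 8 decoder layers, 32 relative-position buckets, gated-GELU feed-forward, untied `lm_head`) and a 768-dim CLIP embedding; none of these values are printed by the notebook itself.

```python
# Assumed mT5-small configuration values (not shown in the notebook output).
vocab, d_model, d_ff, d_kv, heads, layers, buckets = 250112, 512, 1024, 64, 6, 8, 32
# ProjectionMLP sizes: assumed CLIP width, hidden width, number of prefix tokens.
clip_dim, hidden, K = 768, 1024, 16

inner = heads * d_kv                      # attention inner width
attn = 4 * d_model * inner                # q, k, v, o projections, no biases
ffn = 3 * d_model * d_ff                  # gated GELU: wi_0, wi_1, wo
per_layer = 2 * attn + ffn + 3 * d_model  # self-attn, cross-attn, FFN, three layer norms

# Decoder blocks + final layer norm + relative position bias table.
decoder = layers * per_layer + d_model + buckets * heads
lm_head = vocab * d_model                 # output projection, assumed untied
proj = (clip_dim * hidden + hidden) + (hidden * K * d_model + K * d_model) + 2 * d_model

total = decoder + lm_head + proj
print(decoder, lm_head, proj, total)
```

Under these assumptions the three parts add up to the 162421440 reported above, and most of the trainable weights sit in `lm_head` rather than in the decoder blocks or the projection MLP.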


In [2]:
from torch.optim import AdamW
from contextlib import nullcontext
from torch import amp

max_epochs = 2
lr = 2e-4
use_cuda = torch.cuda.is_available()
use_amp = use_cuda  # Only enable AMP on CUDA
# torch.amp.GradScaler takes the device as its first argument (`device`), not `device_type`.
scaler = amp.GradScaler('cuda', enabled=use_amp) if use_amp else None

optim_params = [p for p in model_mm.parameters() if p.requires_grad]
optimizer = AdamW(optim_params, lr=lr)

if loader is not None:
    model_mm.train()
    for epoch in range(max_epochs):
        total_loss = 0.0
        for step, (imgs, srcs, tgts) in enumerate(loader):
            optimizer.zero_grad()
            ctx = amp.autocast(device_type='cuda', dtype=torch.float16) if use_amp else nullcontext()
            with ctx:
                out = model_mm(imgs, srcs, tgts)
                loss = out.loss
            if use_amp:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
            if (step + 1) % 5 == 0:
                print(f"Epoch {epoch+1} Step {step+1} Loss {loss.item():.4f}")
        print(f"Epoch {epoch+1} Mean Loss: {total_loss/len(loader):.4f}")
else:
    print("Loader not initialized; check dataset parsing.")

Epoch 1 Step 5 Loss 21.6134
Epoch 1 Step 10 Loss 19.7147
Epoch 1 Step 15 Loss 24.0285
Epoch 1 Step 20 Loss 15.8134
Epoch 1 Step 25 Loss 21.7727
Epoch 1 Step 30 Loss 17.6961
Epoch 1 Step 35 Loss 19.3310
Epoch 1 Step 40 Loss 20.1627
Epoch 1 Step 45 Loss 18.5610
Epoch 1 Step 50 Loss 19.1264
Epoch 1 Step 55 Loss 23.5003
Epoch 1 Step 60 Loss 21.5944
Epoch 1 Step 65 Loss 17.9103
Epoch 1 Step 70 Loss 16.1339
Epoch 1 Step 75 Loss 22.5555
Epoch 1 Step 80 Loss 13.6943
Epoch 1 Step 85 Loss 21.5188
Epoch 1 Step 90 Loss 16.4548
Epoch 1 Step 95 Loss 13.1756
Epoch 1 Step 100 Loss 16.3613
Epoch 1 Step 105 Loss 13.4107
Epoch 1 Step 110 Loss 14.7152
Epoch 1 Step 115 Loss 13.5813
Epoch 1 Step 120 Loss 16.3904
Epoch 1 Step 125 Loss 15.1450
Epoch 1 Step 130 Loss 14.4007
Epoch 1 Step 135 Loss 14.2127
Epoch 1 Step 140 Loss 15.1400
Epoch 1 Step 145 Loss 16.2860
Epoch 1 Step 150 Loss 12.9744
Epoch 1 Step 155 Loss 14.1738
Epoch 1 Step 160 Loss 17.0998
Epoch 1 Step 165 Loss 12.6047
Epoch 1 Step 170 Loss 16.0285
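
A useful reference point for these numbers is the cross-entropy of a model that guesses uniformly over the vocabulary, which is ln(V). The sketch assumes an mT5-small vocabulary of 250112 entries.

```python
import math

vocab_size = 250112                  # assumed mT5-small vocabulary size
uniform_loss = math.log(vocab_size)  # cross-entropy of a uniform guess, in nats
print(f"{uniform_loss:.2f}")
```

This comes out at roughly 12.43, so the losses logged above (about 12.6 to 24) are still at or above the uniform-guess level, which suggests the model has not yet learned a usable mapping after these steps.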


In [3]:
# Inference / quick generation test
if loader is not None:
    model_mm.eval()
    sample_imgs, sample_srcs, sample_tgts = next(iter(loader))
    with torch.no_grad():
        preds = model_mm(sample_imgs, sample_srcs, target_texts=None, max_new_tokens=32, num_beams=4)
    for i in range(len(preds)):
        print(f"SRC: {sample_srcs[i][:50]}")
        print(f"PRED: {preds[i]}")
        print(f"GOLD: {sample_tgts[i]}")
        print("----")
else:
    print("No loader for inference.")

SRC: describe image in Turkish:
PRED: brali rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada rada
GOLD: Bir grup insan deniz aracının içinden selam veriyor.
----
SRC: describe image in Turkish:
PRED: nkomstātā
GOLD: Bir takım israilli insan kendi milletlerini eleştirenlere karşı bir eylemdeler.
----
SRC: describe image in Turkish:
PRED: 㲟 bir... bir.kuyu. köpkuyu.kuyu.kuyu.kuyu bir. köpkuyu. köpkuyu. köpkuyu. köpkuyu. köp
GOLD: Ağzında tuttuğu renkli top ile çimlerin üzerinde koşan kahverengi küçük bir köpek.
----
SRC: describe image in Turkish:
PRED: resize
GOLD: Göl kenarına kurulmuş bir tribünde bir adam oturmuş, gitar çalıyor.
----
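
The looping outputs above can be quantified with a simple distinct-token ratio. This is a rough diagnostic sketch, not a standard captioning metric; `distinct_ratio` is a hypothetical helper and `looping` is a shortened stand-in for the first prediction.

```python
def distinct_ratio(text: str, n: int = 1) -> float:
    """Share of unique whitespace-delimited n-grams; values near 0 indicate looping."""
    tokens = text.split()
    ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
    return len(set(ngrams)) / len(ngrams) if ngrams else 0.0

looping = "brali " + "rada " * 31   # stand-in for the degenerate first prediction
gold = "Bir grup insan deniz aracının içinden selam veriyor."

print(round(distinct_ratio(looping), 4), distinct_ratio(gold))
```

Decoding options such as `no_repeat_ngram_size` or `repetition_penalty` in `generate` may suppress such loops, but they are unlikely to help much while the training loss stays at the level logged earlier.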
