## Turkish Image Captioning System

### Environment Setup

In [None]:
# Rerun this cell at each session start

# Uninstall conflicting packages (Kaggle specific)
!pip uninstall -y bigframes cesium gcsfs

# Performance metrics
!pip install -r /kaggle/input/requirements/requirements.txt

# To use nltk
import nltk; nltk.download('punkt'); nltk.download('wordnet'); nltk.download('omw-1.4')

# Download mT5-small and load CLIP ViT-B/32
import clip, torch
from transformers import MT5ForConditionalGeneration, AutoTokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "google/mt5-small"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
mt5 = MT5ForConditionalGeneration.from_pretrained(model_name).to(device)
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

print("Setup Done!")

### Model Configuration

In [None]:
# Unified configuration and environment setup
from kaggle_secrets import UserSecretsClient
import os

PROJECT_NAME = "XXX"  # e.g., "clip_prefix_captioning"
RUN_NAME = "XXX"      # e.g., "exp1"

ENABLE_WANDB = True  # Set False to skip W&B entirely

# Read the secret only when W&B is enabled. With ENABLE_WANDB=False the notebook
# then runs without a WANDB_API_KEY secret attached.
WANDB_API_KEY = UserSecretsClient().get_secret("WANDB_API_KEY") if ENABLE_WANDB else None

base_config = {
    "model": "google/mt5-small",
    "clip_encoder": "ViT-B/32",   # backbone
    "prefix_tokens": 32,          # number of image-prefix embeddings prepended to the mT5 encoder input
    "batch_size": 16,             # reduced to mitigate OOM
    "grad_accum_steps": 2,        # accumulate to simulate larger effective batch
    "enable_t5_gradient_checkpointing": True,  # reduce memory
    "lr": 1e-4,
    "epochs": 50,                 # upper bound; BLEU-based early stopping may end training sooner
    "dataset_limit": None,        # cap on training (image, caption) samples; None = all
    # --- CLIP fine-tuning controls ---
    "freeze_clip": False,          # False = CLIP fully trainable (ignored when unfreeze_clip_last_n > 0)
    "unfreeze_clip_last_n": 0,     # if >0: freeze CLIP, then unfreeze only the last N vision blocks
    "clip_lr_scale": 0.05,         # CLIP params train with lr * clip_lr_scale
    "use_clip_patch_tokens": True, # condition on the mean of CLIP patch tokens instead of encode_image()
    # --- T5 freezing ---
    "freeze_t5_encoder": False,    # unfreeze encoder so it can adapt
    "freeze_t5_decoder": False,
    # --- Optimization ---
    "seed": 42,
    "weight_decay": 0.01,
    "grad_clip": 1.0,
    "warmup_steps": 500,           # linear LR warmup (optimizer steps), then constant LR
    # --- Inference defaults ---
    "num_beams_infer": 4,
    "max_new_tokens_infer": 32,
    "src_max_len": 64,
    "tgt_max_len": 64,
    "use_amp": True,
    # Early stop
    "early_stop_patience": 5,
    "early_stop_min_delta": 0.001,  # NOTE(review): not read by the visible training loop
    # Optional extras:
    "use_bf16": True,
    "enable_tf32": True,
    "finite_loss_skip": True,       # NOTE(review): not read by the visible training loop
    "save_every": 0                 # NOTE(review): not read by the visible training loop
}

# Global W&B flag; the (optional) W&B init cell flips it to True on success
use_wandb = False

# Always create a local cfg object here; W&B init (next cell) can sync/override
class _Cfg: ...
cfg = _Cfg()
for k, v in base_config.items():
    setattr(cfg, k, v)
print("[INFO] Local config object created.")
if ENABLE_WANDB:
    print("[INFO] Run the next 'W&B Init' cell to enable Weights & Biases tracking.")
else:
    print("[INFO] W&B disabled (ENABLE_WANDB=False). Using local config only.")

import random, numpy as np, torch
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(cfg.seed)

# --- Enforce CUDA-only environment ---
assert torch.cuda.is_available(), "CUDA GPU is required but not detected. Please run in a CUDA-enabled environment."
device = torch.device('cuda')
print(f"[DEVICE] Using CUDA device: {torch.cuda.get_device_name(device)}")

# Performance toggles
if getattr(cfg, 'enable_tf32', False):
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print("[DEVICE] TF32 enabled.")
    except Exception as _e:
        print("[WARN] Could not enable TF32:", _e)

if getattr(cfg, 'use_bf16', False):
    bf16_ok = torch.cuda.is_bf16_supported()
    print(f"[DEVICE] bfloat16 support: {bf16_ok}")

print("Active config (local):")
for k, v in base_config.items():
    print(f"  {k}: {getattr(cfg, k)}")

### Model and Dataset Definition

In [None]:
# Model, dataset and pipeline definitions (data + model init only)
import torch, os, json
from torch import nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import clip
from transformers import MT5ForConditionalGeneration, AutoTokenizer

# (CUDA enforcement handled in config cell; assume cuda device later)

class ProjectionMLP(nn.Module):
    """Expand one image embedding into a short sequence of prefix embeddings.

    Pipeline: Linear(in_dim -> hidden) -> GELU -> Linear(hidden -> out_dim * prefix_tokens)
    -> reshape to (batch, prefix_tokens, out_dim) -> LayerNorm(out_dim) -> Dropout.
    Assumes ``x`` is (batch, in_dim).
    """

    def __init__(self, in_dim, out_dim, prefix_tokens=8, hidden=1024, dropout=0.1):
        super().__init__()
        self.prefix_tokens = prefix_tokens
        # Submodules are created in this exact order so state_dict keys and the
        # seeded weight initialisation stay the same.
        self.fc1 = nn.Linear(in_dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, out_dim * prefix_tokens)
        self.ln = nn.LayerNorm(out_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        batch = x.size(0)
        flat = self.fc2(self.act(self.fc1(x)))
        tokens = flat.view(batch, self.prefix_tokens, -1)
        return self.dropout(self.ln(tokens))

class CLIPmT5Pipeline(nn.Module):
    """CLIP image encoder + ``ProjectionMLP`` prefix + mT5 seq2seq captioner.

    A CLIP image embedding is projected into ``cfg.prefix_tokens`` vectors of size
    ``d_model``. These are prepended to the embedded source text before the mT5
    encoder. ``forward`` returns the Hugging Face seq2seq output (with ``.loss``).
    ``generate`` returns decoded caption strings.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tokenizer = AutoTokenizer.from_pretrained(cfg.model, use_fast=False)
        self.model = MT5ForConditionalGeneration.from_pretrained(cfg.model)
        # Apply the config flag rather than silently ignoring it: recompute
        # activations during backward to save memory.
        if getattr(cfg, 'enable_t5_gradient_checkpointing', False):
            self.model.gradient_checkpointing_enable()
        # Keep CLIP's preprocess transform so generate() can reuse it. Otherwise it
        # reloads the whole CLIP checkpoint from disk on every call.
        self.clip, self._clip_preprocess = clip.load(cfg.clip_encoder, device='cpu')
        # CLIP freezing / unfreezing strategy
        if cfg.unfreeze_clip_last_n and cfg.unfreeze_clip_last_n > 0:
            # Freeze everything first, then unfreeze only the last N vision blocks
            for p in self.clip.parameters():
                p.requires_grad = False
            blocks = list(self.clip.visual.transformer.resblocks)
            for block in blocks[-cfg.unfreeze_clip_last_n:]:
                for p in block.parameters():
                    p.requires_grad = True
        else:
            # Respect freeze_clip flag (False means fully trainable)
            for p in self.clip.parameters():
                p.requires_grad = not cfg.freeze_clip
        # T5 freeze toggles
        if cfg.freeze_t5_encoder:
            for p in self.model.encoder.parameters():
                p.requires_grad = False
        if cfg.freeze_t5_decoder:
            for p in self.model.decoder.parameters():
                p.requires_grad = False
        self.prefix_tokens = cfg.prefix_tokens
        # Determine embedding dim dynamically (CLIP projection output).
        # The dummy tensor is drawn unconditionally, as before, so the global RNG
        # state (and therefore the seeded ProjectionMLP init below) is unchanged.
        with torch.no_grad():
            dummy = torch.randn(1, 3, 224, 224)
            embed_dim = self.clip.visual.output_dim if hasattr(self.clip.visual, 'output_dim') else self.clip.encode_image(dummy).shape[-1]
        self.proj = ProjectionMLP(in_dim=embed_dim, out_dim=self.model.config.d_model, prefix_tokens=cfg.prefix_tokens)
        self._cached_sentinel_ids = None  # lazy cache

    @staticmethod
    def _mask_pad_labels(input_ids, attention_mask, ignore_index=-100):
        """Return a copy of ``input_ids`` with padded positions set to ``ignore_index``.

        Hugging Face seq2seq models drop label id -100 from the loss. Without this
        masking, the pad tokens of shorter captions in a batch are trained on as targets.
        """
        return input_ids.masked_fill(attention_mask == 0, ignore_index)

    def _encode_image_single(self, images: torch.Tensor):
        # (B, D) pooled embedding (already passed through ln_post + proj inside encode_image)
        pooled = self.clip.encode_image(images)
        return pooled

    def _encode_image_patch_tokens(self, images: torch.Tensor):
        """Return patch-averaged embedding projected into CLIP joint space.
        We manually replicate encode_image path but average patch tokens (excluding CLS),
        then apply ln_post and proj so final dim == visual.output_dim (e.g. 512) to match ProjectionMLP in_dim."""
        visual = self.clip.visual
        x = visual.conv1(images)                      # (B, width, grid, grid)
        x = x.reshape(x.shape[0], x.shape[1], -1)     # (B, width, patches)
        x = x.permute(0, 2, 1)                        # (B, patches, width)
        cls_tokens = visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_tokens, x], dim=1)         # prepend CLS
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)
        x = x.permute(1, 0, 2)                        # (sequence, B, width)
        for block in visual.transformer.resblocks:
            x = block(x)
        x = x.permute(1, 0, 2)                        # (B, sequence, width)
        patches = x[:, 1:, :]                         # drop CLS
        pooled = patches.mean(dim=1)                  # (B, width)
        if hasattr(visual, 'ln_post'):
            pooled = visual.ln_post(pooled)
        if hasattr(visual, 'proj') and visual.proj is not None:
            pooled = pooled @ visual.proj             # (B, output_dim)
        return pooled

    def _build_encoder_inputs(self, prefix_emb, tok):
        """Prepend ``prefix_emb`` to the embedded ``tok.input_ids``, and prepend an
        all-ones block of width ``self.prefix_tokens`` to ``tok.attention_mask``."""
        text_emb = self.model.encoder.embed_tokens(tok.input_ids)
        full_emb = torch.cat([prefix_emb, text_emb], dim=1)
        prefix_mask = torch.ones(
            prefix_emb.size(0), self.prefix_tokens,
            dtype=tok.attention_mask.dtype, device=tok.attention_mask.device,
        )
        full_attn = torch.cat([prefix_mask, tok.attention_mask], dim=1)
        return full_emb, full_attn

    def forward(self, images, src_texts, tgt_texts):
        """Teacher-forced pass; returns the mT5 output whose ``.loss`` is the caption CE."""
        device = next(self.parameters()).device
        prefix_emb = self._prepare_prefix(images)
        tok_src = self.tokenizer(list(src_texts), return_tensors='pt', padding=True, truncation=True, max_length=self.cfg.src_max_len).to(device)
        tok_tgt = self.tokenizer(list(tgt_texts), return_tensors='pt', padding=True, truncation=True, max_length=self.cfg.tgt_max_len).to(device)
        labels = self._mask_pad_labels(tok_tgt.input_ids, tok_tgt.attention_mask)
        full_emb, full_attn = self._build_encoder_inputs(prefix_emb, tok_src)
        return self.model(inputs_embeds=full_emb, attention_mask=full_attn, labels=labels)

    def _prepare_prefix(self, images: torch.Tensor):
        """Move ``images`` to the model device, encode with CLIP, project to prefix embeddings."""
        images = images.to(next(self.parameters()).device)
        emb = self._encode_image_patch_tokens(images) if self.cfg.use_clip_patch_tokens else self._encode_image_single(images)
        return self.proj(emb)

    def _get_sentinel_bad_words(self, n=50):
        """Token-id lists for ``<extra_id_0..n-1>`` (cached), used as ``bad_words_ids``."""
        if self._cached_sentinel_ids is None:
            ids = [self.tokenizer(f'<extra_id_{i}>').input_ids[0] for i in range(n)]
            self._cached_sentinel_ids = [[i] for i in ids]
        return self._cached_sentinel_ids

    @torch.inference_mode()
    def generate(self, image_paths=None, images=None, num_beams=None, max_new_tokens=None, prompt="Bu görüntüyü açıkla: ", ban_sentinels=True, **gen_kwargs):
        """Caption images given either file paths or a preprocessed image tensor.

        Returns one string per image; an empty decode is replaced by ``"<EMPTY>"``.
        NOTE(review): Flickr8kCaptions feeds a single space as the source text during
        training, while this default prompt is a Turkish instruction. The encoder
        therefore sees different text at train and inference time; confirm this is intended.
        """
        device = next(self.parameters()).device
        num_beams = num_beams or self.cfg.num_beams_infer
        max_new_tokens = max_new_tokens or self.cfg.max_new_tokens_infer
        if images is None:
            assert image_paths is not None, "Provide image_paths or images tensor"
            preprocess = getattr(self, '_clip_preprocess', None)
            if preprocess is None:
                preprocess = clip.load(self.cfg.clip_encoder, device='cpu')[1]
            pil_images = []
            for p in image_paths:
                # Close each file handle promptly; convert() returns an independent copy.
                with Image.open(p) as im:
                    pil_images.append(im.convert('RGB'))
            images = torch.stack([preprocess(im) for im in pil_images])
        images = images.to(device)
        prefix_tokens = self._prepare_prefix(images)
        tok = self.tokenizer([prompt]*images.size(0), return_tensors='pt', padding=True, truncation=True, max_length=self.cfg.src_max_len).to(device)
        full_emb, full_attn = self._build_encoder_inputs(prefix_tokens, tok)
        bad_words_ids = self._get_sentinel_bad_words() if ban_sentinels else None
        gen_ids = self.model.generate(
            inputs_embeds=full_emb,
            attention_mask=full_attn,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            bad_words_ids=bad_words_ids,
            **gen_kwargs
        )
        captions = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in gen_ids]
        return [c if c else "<EMPTY>" for c in captions]

class Flickr8kCaptions(Dataset):
    """Flat (image, caption) dataset read from a captions JSON.

    The JSON is either a dict with an ``'images'`` list or the list itself. Each
    entry needs a file name (``filename``/``image``/``img``), an optional ``split``
    and ``sentences`` whose items carry the caption under ``'raw'``. Every caption
    becomes its own sample, so an image with five captions appears five times.
    ``__getitem__`` returns ``(image_tensor, " ", caption)``; the source text is a
    single space.
    """
    def __init__(self, json_path, images_root, split=None, limit=None, clip_preprocess=None):
        self.images_root = images_root
        # Captions are Turkish: decode explicitly as UTF-8 and close the file handle.
        with open(json_path, encoding='utf-8') as fh:
            raw = json.load(fh)
        rows = raw['images'] if isinstance(raw, dict) and 'images' in raw else raw
        self.samples = []
        for row in rows:
            if not isinstance(row, dict): continue
            if split and row.get('split') != split: continue
            img = row.get('filename') or row.get('image') or row.get('img')
            sentences = row.get('sentences')
            if not img or not sentences: continue
            for s in sentences:
                if isinstance(s, dict) and 'raw' in s:
                    self.samples.append((img, s['raw']))
        if limit: self.samples = self.samples[:limit]
        if clip_preprocess is None:
            # Fallback reads the notebook-global `cfg` and loads the full CLIP model
            # just for its transform; pass `clip_preprocess` to avoid that cost.
            clip_preprocess = clip.load(cfg.clip_encoder, device='cpu')[1]
        self.transform = clip_preprocess

    def __len__(self): return len(self.samples)

    def __getitem__(self, idx):
        img_name, cap = self.samples[idx]
        path = os.path.join(self.images_root, img_name)
        # convert() yields an independent image, so the file can be closed right away.
        with Image.open(path) as im:
            image = im.convert('RGB')
        return self.transform(image), " ", cap

### Weights & Biases (W&B) Initialization

In [None]:
# (Optional) Weights & Biases initialization.
# On success, `cfg` becomes the W&B config object and `use_wandb` is set to True;
# any failure leaves W&B disabled instead of stopping the notebook.
if not ('ENABLE_WANDB' in globals() and ENABLE_WANDB):
    print('[wandb] Skipped (ENABLE_WANDB is False).')
else:
    try:
        import wandb
        wandb.login(key=WANDB_API_KEY)
        run = wandb.init(project=PROJECT_NAME, name=RUN_NAME, config=base_config, reinit=True)
        cfg = wandb.config  # attribute access now goes through the W&B config
        try:
            cfg_dict = dict(base_config)
        except Exception:
            cfg_dict = {key: getattr(cfg, key) for key in base_config.keys() if hasattr(cfg, key)}
        # Record the config three ways: run config, namespaced summary, one step-0 log.
        wandb.config.update(cfg_dict, allow_val_change=True)
        wandb.summary.update({f"cfg/{key}": val for key, val in cfg_dict.items()})
        wandb.log({"cfg": cfg_dict}, step=0)
        use_wandb = True
        print('[wandb] run initialized and config logged.')
    except Exception as e:
        use_wandb = False
        print('[wandb] disabled (init failed):', e)

### Data Loading

In [None]:
# Data Loading
json_path = '/kaggle/input/tasviret/flickr8k/tasviret8k_captions.json'
images_root = '/kaggle/input/tasviret/flickr8k/Images'

# Load CLIP's preprocess transform once and share it across the three splits.
# Otherwise each Flickr8kCaptions instance loads the full CLIP checkpoint just to
# obtain the transform.
data_preprocess = clip.load(cfg.clip_encoder, device='cpu')[1]
train_dataset = Flickr8kCaptions(json_path, images_root, split='train', limit=cfg.dataset_limit, clip_preprocess=data_preprocess)
val_dataset   = Flickr8kCaptions(json_path, images_root, split='val',   limit=None, clip_preprocess=data_preprocess)
test_dataset  = Flickr8kCaptions(json_path, images_root, split='test',  limit=None, clip_preprocess=data_preprocess)

train_loader = DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True)
val_loader   = DataLoader(val_dataset,   batch_size=cfg.batch_size, shuffle=False) if len(val_dataset)>0 else None
test_loader  = DataLoader(test_dataset,  batch_size=cfg.batch_size, shuffle=False) if len(test_dataset)>0 else None

model_mm = CLIPmT5Pipeline(cfg)
print(f"Train samples: {len(train_dataset)}  Val: {len(val_dataset)}  Test: {len(test_dataset)}")
print("Clip ViT-B/32 params:", sum(p.numel() for p in model_mm.clip.parameters() if p.requires_grad))
print("Projection params:", sum(p.numel() for p in model_mm.proj.parameters() if p.requires_grad))
print("mt5-small params:", sum(p.numel() for p in model_mm.model.parameters() if p.requires_grad))
print("Total trainable params:", sum(p.numel() for p in model_mm.parameters() if p.requires_grad))

### Training

In [None]:
# Training cell: loss + BLEU1 early stopping + grad accum + checkpoints + history lists (fixed for this pipeline)
import math, time, torch, warnings, os
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

# Silence the MT5Tokenizer "legacy behaviour" warning.
warnings.filterwarnings("ignore", message=".*legacy behaviour.*MT5Tokenizer.*")

# Training requires a GPU; place the pipeline on it.
assert torch.cuda.is_available(), "CUDA GPU required."
device = torch.device("cuda")
model_mm = model_mm.to(device)

# Helper: compute BLEU on validation set using existing generation API
from collections import defaultdict

def _compute_val_bleu(model, val_dataset, images_root, batch_size=16, amp_dtype=None):
    if val_dataset is None or len(val_dataset) == 0:
        return None, None
    refs_map = defaultdict(list)
    for fname, cap in val_dataset.samples:
        c = str(cap).strip()
        if c:
            refs_map[fname].append(c)
    if not refs_map:
        return None, None
    image_files = list(refs_map.keys())
    image_paths = [os.path.join(images_root, f) for f in image_files]

    def _generate_batch(paths):
        if amp_dtype is not None:
            with torch.amp.autocast('cuda', dtype=amp_dtype):
                return model.generate(image_paths=paths, ban_sentinels=True)
        return model.generate(image_paths=paths, ban_sentinels=True)

    hyps, refs = [], []
    model.eval()
    with torch.no_grad():
        for i in range(0, len(image_paths), batch_size):
            batch_paths = image_paths[i:i+batch_size]
            outs = _generate_batch(batch_paths)
            for pth, pred in zip(batch_paths, outs):
                fname = os.path.basename(pth)
                hyps.append(pred if pred else "<EMPTY>")
                refs.append(refs_map.get(fname, [" "]))

    gts = {i: refs[i] for i in range(len(refs))}
    res = {i: [hyps[i]] for i in range(len(hyps))}
    try:
        from pycocoevalcap.bleu.bleu import Bleu
        bleu_scores, _ = Bleu(4).compute_score(gts, res)
        return float(bleu_scores[0]), float(bleu_scores[3])
    except Exception as e:
        print('[WARN] BLEU failed:', e)
        return None, None


def train_model_with_bleu(
    model_mm,
    train_loader,
    val_loader,
    tokenizer,  # kept for signature compatibility; not used directly here
    cfg,
    num_epochs=None,
    grad_accum_steps=None,
    min_epochs_before_bleu=10,
    bleu_patience=5,
):
    """Train with gradient accumulation, per-epoch val loss, BLEU checkpointing and early stopping.

    Args:
        model_mm: pipeline whose ``forward(images, src_texts, tgt_texts)`` returns an
            object with ``.loss``, and which exposes ``generate`` for BLEU evaluation.
        train_loader: yields ``(images, src_texts, tgt_texts)`` batches.
        val_loader: same batch format, or ``None``/empty. In that case val loss is
            ``nan`` and BLEU is skipped.
        tokenizer: unused; kept for backward compatibility.
        cfg: attribute-style config. Reads lr, weight_decay, warmup_steps, epochs,
            grad_accum_steps, grad_clip, use_amp, use_bf16, clip_lr_scale.
        num_epochs / grad_accum_steps: overrides for the cfg values.
        min_epochs_before_bleu: 0-based epoch index from which BLEU is computed.
        bleu_patience: BLEU evaluations without BLEU-1 improvement before stopping.

    Side effects: writes ``checkpoints/best.pt`` (best BLEU-1) and
    ``checkpoints/last.pt`` (every epoch), and logs to W&B when ``use_wandb`` is set.

    Returns:
        ``(best_bleu, best_epoch, train_losses, val_losses, epochs_list)``.
    """
    import contextlib

    # Resolve config values from cfg object
    lr = getattr(cfg, 'lr')
    weight_decay = getattr(cfg, 'weight_decay', 0.01)
    warmup_steps = getattr(cfg, 'warmup_steps', 0)
    epochs = num_epochs if num_epochs is not None else getattr(cfg, 'epochs', 30)
    grad_accum = grad_accum_steps if grad_accum_steps is not None else getattr(cfg, 'grad_accum_steps', 1)
    grad_accum = max(1, int(grad_accum))
    grad_clip = getattr(cfg, 'grad_clip', 0.0) or 0.0

    # Train on whatever device the model already lives on (no reliance on a global).
    device = next(model_mm.parameters()).device

    # Optional AMP dtype: bf16 when requested and supported, otherwise fp16
    amp_dtype = None
    if getattr(cfg, 'use_amp', True):
        if getattr(cfg, 'use_bf16', False) and torch.cuda.is_bf16_supported():
            amp_dtype = torch.bfloat16
        else:
            amp_dtype = torch.float16

    # fp16 gradients can underflow without loss scaling; bf16/fp32 do not need it.
    # A disabled scaler is a pass-through.
    use_scaler = amp_dtype == torch.float16
    try:
        scaler = torch.amp.GradScaler('cuda', enabled=use_scaler)
    except (AttributeError, TypeError):  # older torch without torch.amp.GradScaler
        scaler = torch.cuda.amp.GradScaler(enabled=use_scaler)

    def _autocast():
        # Mixed precision only when an AMP dtype was selected above.
        if amp_dtype is None:
            return contextlib.nullcontext()
        return torch.amp.autocast('cuda', dtype=amp_dtype)

    # BLEU references/images come from the validation loader itself rather than
    # from notebook globals.
    val_ds = getattr(val_loader, 'dataset', None) if val_loader is not None else None
    val_images_root = getattr(val_ds, 'images_root', None)
    has_val = val_loader is not None and len(val_loader) > 0

    model_mm.train()

    # Optimizer: CLIP params get a scaled-down learning rate
    main_params, clip_params = [], []
    for name, p in model_mm.named_parameters():
        if not p.requires_grad:
            continue
        (clip_params if name.startswith('clip.') else main_params).append(p)
    param_groups = []
    if main_params:
        param_groups.append({"params": main_params, "lr": lr})
    if clip_params:
        scaled_lr = lr * getattr(cfg, 'clip_lr_scale', 0.05)
        param_groups.append({"params": clip_params, "lr": scaled_lr})
        print(f"[INFO] CLIP fine-tune params: {len(clip_params)} with lr={scaled_lr:.2e}")

    optimizer = AdamW(param_groups, weight_decay=weight_decay)

    # Linear warmup to the base LR over `warmup_steps` optimizer steps, then constant
    scheduler = LambdaLR(
        optimizer,
        lr_lambda=lambda step: min((step + 1) / float(max(1, warmup_steps)), 1.0) if warmup_steps > 0 else 1.0,
    )

    os.makedirs('checkpoints', exist_ok=True)
    best_path = os.path.join('checkpoints', 'best.pt')
    last_path = os.path.join('checkpoints', 'last.pt')

    best_bleu = -1.0
    best_epoch = -1
    no_improve = 0
    global_step = 0  # counts micro-batches

    train_losses, val_losses, epochs_list = [], [], []

    for epoch in range(epochs):
        t0 = time.time()
        total_loss = 0.0
        n_batches = len(train_loader)

        # === Training loop ===
        optimizer.zero_grad(set_to_none=True)
        for step, (imgs, srcs, tgts) in enumerate(train_loader):
            imgs = imgs.to(device, non_blocking=True)
            with _autocast():
                outputs = model_mm(imgs, srcs, tgts)
            batch_loss = outputs.loss
            scaler.scale(batch_loss / grad_accum).backward()

            # Step on every grad_accum-th micro-batch AND on the last one, so the
            # tail of the epoch is not discarded by the next epoch's zero_grad().
            if (step + 1) % grad_accum == 0 or (step + 1) == n_batches:
                if grad_clip > 0:
                    scaler.unscale_(optimizer)  # clip the true (unscaled) gradients
                    torch.nn.utils.clip_grad_norm_(model_mm.parameters(), grad_clip)
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad(set_to_none=True)

            # Accumulate the un-divided loss so train and val losses share one scale.
            total_loss += batch_loss.item()
            global_step += 1

        avg_train_loss = total_loss / max(1, n_batches)

        # === Validation loss ===
        model_mm.eval()
        avg_val_loss = float('nan')
        if has_val:
            total_val_loss = 0.0
            with torch.no_grad():
                for imgs, srcs, tgts in val_loader:
                    imgs = imgs.to(device, non_blocking=True)
                    with _autocast():
                        outputs = model_mm(imgs, srcs, tgts)
                    total_val_loss += outputs.loss.item()
            avg_val_loss = total_val_loss / len(val_loader)

        # Save losses to lists
        train_losses.append(avg_train_loss)
        val_losses.append(avg_val_loss)
        epochs_list.append(epoch + 1)

        # === BLEU evaluation ===
        bleu1 = bleu4 = None
        if epoch >= min_epochs_before_bleu:
            try:
                bleu1, bleu4 = _compute_val_bleu(model_mm, val_ds, val_images_root, batch_size=16, amp_dtype=amp_dtype)
            except Exception as e:
                print('[WARN] BLEU eval failed:', e)

            # Check improvement
            if bleu1 is not None and bleu1 > best_bleu:
                best_bleu = bleu1
                best_epoch = epoch
                no_improve = 0
                torch.save(
                    {
                        'model': model_mm.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'epoch': epoch,
                        'global_step': global_step,
                        'best_bleu': best_bleu,
                        'cfg': {k: getattr(cfg, k) for k in base_config.keys()},
                    },
                    best_path,
                )
                print(f"[Epoch {epoch}] New best BLEU-1={bleu1:.4f} saved at {best_path}")
            elif bleu1 is not None:
                no_improve += 1

        # Save rolling checkpoint
        torch.save(
            {
                'model': model_mm.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'epoch': epoch,
                'global_step': global_step,
                'cfg': {k: getattr(cfg, k) for k in base_config.keys()},
            },
            last_path,
        )

        # Log (guard W&B)
        time_elapsed = time.time() - t0
        try:
            if 'use_wandb' in globals() and use_wandb:
                log_dict = {
                    'train/epoch_loss': avg_train_loss,
                    'val/epoch_loss': avg_val_loss,
                    'lr': scheduler.get_last_lr()[0],
                }
                if epoch >= min_epochs_before_bleu and bleu1 is not None:
                    log_dict['val/bleu1'] = bleu1
                    if bleu4 is not None:
                        log_dict['val/bleu4'] = bleu4
                wandb.log(log_dict, step=epoch)
        except Exception as e:
            print('[wandb] log skipped:', e)

        # Print summary (include BLEU after activation)
        if epoch >= min_epochs_before_bleu and bleu1 is not None:
            print(f"[Epoch {epoch}] Train Loss={avg_train_loss:.4f} | Val Loss={avg_val_loss:.4f} | BLEU-1={bleu1:.4f} | Time={time_elapsed:.1f}s")
        else:
            print(f"[Epoch {epoch}] Train Loss={avg_train_loss:.4f} | Val Loss={avg_val_loss:.4f} | Time={time_elapsed:.1f}s")

        # Early stopping
        if epoch >= min_epochs_before_bleu and no_improve >= bleu_patience:
            print(f"⏹ Early stopping at epoch {epoch}. Best BLEU-1={best_bleu:.4f} at epoch {best_epoch}")
            break

        model_mm.train()

    return best_bleu, best_epoch, train_losses, val_losses, epochs_list

In [None]:
# Early-stopping patience: an explicit cfg.bleu_patience wins. Otherwise fall back
# to the configured early_stop_patience (the key base_config defines), then to 5.
best_bleu, best_epoch, train_losses, val_losses, epochs_list = train_model_with_bleu(
    model_mm,
    train_loader,
    val_loader,
    model_mm.tokenizer,
    cfg,
    num_epochs=cfg.epochs,
    grad_accum_steps=cfg.grad_accum_steps,
    min_epochs_before_bleu=8,  # 0-based epoch index at which BLEU evaluation starts
    bleu_patience=getattr(cfg, 'bleu_patience', getattr(cfg, 'early_stop_patience', 5)),
)

In [None]:
# Plot the training and validation loss curves on shared axes
import matplotlib.pyplot as plt

if not ('epochs_list' in globals() and len(epochs_list) == len(train_losses) == len(val_losses) and len(epochs_list) > 0):
    print('No loss history found. Run the training cell first to populate train_losses/val_losses/epochs_list.')
else:
    fig, ax = plt.subplots(figsize=(7, 4))
    ax.plot(epochs_list, train_losses, label='Train Loss', marker='o')
    ax.plot(epochs_list, val_losses, label='Val Loss', marker='s')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training and Validation Loss')
    ax.grid(True, alpha=0.3)
    ax.legend()
    fig.tight_layout()
    plt.show()

#### Model Loader Function (from checkpoint)

In [None]:
# === Direct .pt Checkpoint Loader ===
import torch, os
from typing import Optional, Dict, Any, Tuple

def load_model_from_checkpoint_file(
    checkpoint_path: str,
    device: torch.device,
    resume_optimizer: bool = False,
    build_optimizer_fn=None,
    build_scheduler_fn=None,
    override_cfg: Optional[Dict[str, Any]] = None,
):
    """Load model (and optional optimizer/scheduler) directly from a single .pt file.

    Works with both:
    - { 'model_state', 'optimizer_state', 'scheduler_state', 'cfg', ... }
    - { 'model', 'optimizer', 'scheduler', 'cfg', ... }

    Args:
        checkpoint_path: path to the ``.pt`` bundle.
        device: device for ``map_location`` and for the rebuilt model.
        resume_optimizer: also rebuild/restore optimizer and scheduler state.
        build_optimizer_fn: ``(cfg_obj, model) -> optimizer``; required to restore optimizer state.
        build_scheduler_fn: ``(cfg_obj, optimizer) -> scheduler``; required to restore scheduler state.
        override_cfg: keys that replace values from the checkpoint's stored cfg.

    Returns: (model, cfg_obj, optimizer, scheduler, epoch, global_step)
        The model is in eval mode unless ``resume_optimizer`` is True.

    Raises:
        AssertionError: the file does not exist.
        ValueError: the bundle lacks model weights or ``'cfg'``.
    """
    assert os.path.isfile(checkpoint_path), f"Checkpoint not found: {checkpoint_path}"
    # torch.load unpickles the file: only load checkpoints you produced/trust.
    bundle = torch.load(checkpoint_path, map_location=device)

    # --- Model state ---
    if 'model_state' in bundle:
        model_state = bundle['model_state']
    elif 'model' in bundle:
        model_state = bundle['model']
    else:
        raise ValueError("Checkpoint missing model weights (expected 'model_state' or 'model')")

    # --- Config ---
    if 'cfg' not in bundle:
        raise ValueError("Checkpoint missing 'cfg'")
    cfg_json = dict(bundle['cfg'])
    if override_cfg:
        cfg_json.update(override_cfg)

    class _Cfg: ...
    cfg_obj = _Cfg()
    for k, v in cfg_json.items():
        setattr(cfg_obj, k, v)

    # --- Build model and load weights ---
    model = CLIPmT5Pipeline(cfg_obj).to(device)
    model.load_state_dict(model_state, strict=True)
    if not resume_optimizer:
        # nn.Module starts in train mode; without this, ProjectionMLP dropout stays
        # active when the returned model is used directly for inference.
        model.eval()

    # --- Metadata ---
    epoch = bundle.get('epoch')
    global_step = bundle.get('global_step')

    # --- Optimizer and scheduler (optional) ---
    optimizer, scheduler = None, None
    if resume_optimizer:
        # Optimizer
        opt_state = bundle.get('optimizer_state') or bundle.get('optimizer')
        if opt_state is not None and build_optimizer_fn is not None:
            optimizer = build_optimizer_fn(cfg_obj, model)
            try:
                optimizer.load_state_dict(opt_state)
                print('[IMPORT] Optimizer state restored.')
            except Exception as e:
                print(f'[WARN] Failed to load optimizer state: {e}')
        elif opt_state is not None:
            print('[WARN] Optimizer state present but build_optimizer_fn not provided; skipping restore.')

        # Scheduler (needs a rebuilt optimizer to attach to)
        sch_state = bundle.get('scheduler_state') or bundle.get('scheduler')
        if sch_state is not None and optimizer and build_scheduler_fn is not None:
            scheduler = build_scheduler_fn(cfg_obj, optimizer)
            try:
                scheduler.load_state_dict(sch_state)
                print('[IMPORT] Scheduler state restored.')
            except Exception as e:
                print(f'[WARN] Failed to load scheduler state: {e}')

    print(f"[IMPORT] Loaded model from file: {checkpoint_path}")
    return model, cfg_obj, optimizer, scheduler, epoch, global_step


In [None]:
# Load the best checkpoint for inference (optimizer/scheduler state is not needed)
model_mm, cfg_loaded, _, _, epoch_loaded, step_loaded = load_model_from_checkpoint_file(
    '/kaggle/working/checkpoints/best.pt',
    device=torch.device('cuda'),
    resume_optimizer=False,
)

### Test / Inference

In [None]:
# Test/Inference on the test_dataset
import os, random, torch
from collections import defaultdict
from pycocoevalcap.bleu.bleu import Bleu
from pycocoevalcap.cider.cider import Cider
from pycocoevalcap.meteor.meteor import Meteor
from pycocoevalcap.rouge.rouge import Rouge
from PIL import Image
import matplotlib.pyplot as plt

def evaluate_test_set(model, test_dataset, images_root, batch_size=16, top_k_examples=5):
    """Caption every unique test image once, print corpus metrics and a few examples.

    Args:
        model: object with ``eval()`` and ``generate(image_paths=..., ban_sentinels=...)``
            returning one caption per path.
        test_dataset: object with ``samples`` (list of ``(file_name, caption)``) and
            ``__len__``, or ``None``.
        images_root: directory that the dataset file names are relative to.
        batch_size: images per ``generate`` call (minimum 1).
        top_k_examples: number of random images shown with references and prediction.

    Returns:
        dict with keys 'bleu1'..'bleu4', 'meteor', 'rougeL', 'cider' (0 for a scorer
        that failed), or ``None`` when there is nothing to evaluate.
    """
    if test_dataset is None or len(test_dataset) == 0:
        print("No test dataset available.")
        return None

    # Build references map (file name -> non-empty reference captions)
    refs_map = defaultdict(list)
    for fname, cap in test_dataset.samples:
        c = str(cap).strip()
        if c:
            refs_map[fname].append(c)

    image_files = list(refs_map.keys())
    if not image_files:
        print("No non-empty reference captions in the test dataset.")
        return None

    print(f"Evaluating on {len(image_files)} test images…")
    model.eval()
    batch_size = max(1, int(batch_size))

    def _generate_batch(paths):
        # NOTE: looks up `amp_dtype` as a notebook global. The training function keeps
        # its amp_dtype local, so this is normally None (full precision).
        local_amp = globals().get('amp_dtype', None)
        if local_amp is not None:
            with torch.amp.autocast('cuda', dtype=local_amp):
                return model.generate(image_paths=paths, ban_sentinels=True)
        return model.generate(image_paths=paths, ban_sentinels=True)

    # Generate predictions once; the qualitative examples below reuse them.
    hyps, refs, pred_by_file = [], [], {}
    with torch.no_grad():
        for start in range(0, len(image_files), batch_size):
            batch_files = image_files[start:start + batch_size]
            batch_paths = [os.path.join(images_root, f) for f in batch_files]
            outs = _generate_batch(batch_paths)
            # Key by the dataset's own file name, not the basename of the joined path,
            # so names containing sub-directories still find their references.
            for fname, pred in zip(batch_files, outs):
                hyp = pred if pred else "<EMPTY>"
                hyps.append(hyp)
                refs.append(refs_map[fname])
                pred_by_file[fname] = hyp

    # Prepare for COCO eval
    gts = {i: r for i, r in enumerate(refs)}
    res = {i: [h] for i, h in enumerate(hyps)}

    # Compute BLEU, METEOR, ROUGE-L, CIDEr (each best-effort)
    try:
        bleu_scores, _ = Bleu(4).compute_score(gts, res)
    except Exception as e:
        print("[WARN] BLEU failed:", e)
        bleu_scores = [0, 0, 0, 0]
    try:
        meteor_score, _ = Meteor().compute_score(gts, res)
    except Exception as e:
        print("[WARN] METEOR failed:", e)
        meteor_score = 0
    try:
        rouge_score, _ = Rouge().compute_score(gts, res)
    except Exception as e:
        print("[WARN] ROUGE-L failed:", e)
        rouge_score = 0
    try:
        cider_score, _ = Cider().compute_score(gts, res)
    except Exception as e:
        print("[WARN] CIDEr failed:", e)
        cider_score = 0

    # Print metrics
    print(f"\nTest Set Metrics:\nBLEU-1: {bleu_scores[0]:.4f}  BLEU-2: {bleu_scores[1]:.4f}  BLEU-3: {bleu_scores[2]:.4f}  BLEU-4: {bleu_scores[3]:.4f}")
    print(f"METEOR: {meteor_score:.4f}  ROUGE-L: {rouge_score:.4f}  CIDEr: {cider_score:.4f}")

    metrics = {
        'bleu1': float(bleu_scores[0]),
        'bleu2': float(bleu_scores[1]),
        'bleu3': float(bleu_scores[2]),
        'bleu4': float(bleu_scores[3]),
        'meteor': float(meteor_score),
        'rougeL': float(rouge_score),
        'cider': float(cider_score),
    }

    # Show qualitative examples (reusing the predictions computed above)
    k = min(top_k_examples, len(image_files))
    if k > 0:
        print(f"\nQualitative examples (random {k}):")
        for fname in random.sample(image_files, k):
            pth = os.path.join(images_root, fname)
            pred = pred_by_file.get(fname, "")
            gt_refs = refs_map.get(fname, [])
            print("\nImage:", fname)
            if gt_refs:
                print("Ground truths (up to 3):")
                for r in gt_refs[:3]:
                    print(" -", r)
            else:
                print("(No references found)")
            print("Prediction:")
            print(" -", pred if pred else "<EMPTY>")
            try:
                with Image.open(pth) as im:
                    img = im.convert("RGB")
                fig, ax = plt.subplots(figsize=(4, 4))
                ax.imshow(img)
                ax.axis("off")
                ax.set_title(fname)
                plt.show()
                plt.close(fig)  # avoid accumulating open figures across examples
            except Exception as e:
                print(f"[WARN] Could not display image: {e}")

    return metrics

In [None]:
test_metrics = evaluate_test_set(model_mm, test_dataset, images_root)

In [None]:
# Single Image Inference with Metrics (BLEU1-4, METEOR, ROUGE-L, CIDEr, SPICE)
from typing import Optional, Dict, Any, List
import matplotlib.pyplot as plt
import os, json, torch
from PIL import Image

_PYCOCO_SCORERS = {}          # metric name -> cached pycocoevalcap scorer instance
_PYCOCO_SCORER_FAILURES = {}  # metric name -> error message from a failed import/construction


def _lazy_load_scorers():
    """Construct (once) and return the cached pycocoevalcap scorers.

    Keys: 'bleu' (BLEU up to 4-grams), 'meteor', 'rouge' (ROUGE-L), 'cider'
    and, when ``pycocoevalcap.spice`` can be imported, 'spice'.

    Each scorer is imported and constructed independently. If one cannot be
    built, a warning is printed once and that key is simply absent from the
    returned dict, so the remaining metrics stay usable. Failed scorers are
    remembered in ``_PYCOCO_SCORER_FAILURES`` and not retried until that cache
    is cleared (e.g. by re-running this cell).

    Returns:
        The shared ``_PYCOCO_SCORERS`` dict.
    """
    def _make_bleu():
        from pycocoevalcap.bleu.bleu import Bleu
        return Bleu(4)

    def _make_meteor():
        from pycocoevalcap.meteor.meteor import Meteor
        return Meteor()

    def _make_rouge():
        from pycocoevalcap.rouge.rouge import Rouge
        return Rouge()

    def _make_cider():
        from pycocoevalcap.cider.cider import Cider
        return Cider()

    def _make_spice():
        from pycocoevalcap.spice.spice import Spice
        return Spice()

    factories = (
        ("bleu", _make_bleu),
        ("meteor", _make_meteor),
        ("rouge", _make_rouge),
        ("cider", _make_cider),
        ("spice", _make_spice),
    )
    for name, factory in factories:
        # Skip scorers that are already cached or already known to be broken.
        if name in _PYCOCO_SCORERS or name in _PYCOCO_SCORER_FAILURES:
            continue
        try:
            _PYCOCO_SCORERS[name] = factory()
        except Exception as e:
            _PYCOCO_SCORER_FAILURES[name] = str(e)
            print(f"[WARN] Could not initialise {name} scorer: {e}")
    return _PYCOCO_SCORERS

def _compute_single_caption_metrics(pred: str, refs: List[str]) -> Dict[str, float]:
    """Score one predicted caption against its reference captions.

    Args:
        pred: the generated caption (``predict`` passes "<EMPTY>" for an empty one).
        refs: reference captions; ``None``/blank entries are dropped and the
            rest are stripped.

    Returns:
        Dict with any of 'bleu1'..'bleu4', 'bleu' (alias of 'bleu4'),
        'meteor', 'rougeL', 'cider', 'spice'. A metric whose scorer is
        unavailable or raises is omitted (a warning is printed). Returns an
        empty dict when no usable reference remains.

    Note:
        CIDEr weights n-grams by an IDF computed over the whole evaluated
        corpus. With a one-image corpus that IDF is log(1/1) = 0 for every
        reference n-gram, so the score is identically 0 whatever the caption.
        'cider' is therefore reported as NaN here. Use corpus-level
        evaluation (e.g. ``evaluate_test_set``) for a meaningful CIDEr.
    """
    refs_clean = [r.strip() for r in refs if r and r.strip()]
    if not refs_clean:
        return {}
    scorers = _lazy_load_scorers()
    # pycocoevalcap input format: image id -> list of captions (one for the candidate).
    gts = {0: refs_clean}
    res = {0: [pred]}

    def _score(key: str, label: str):
        """Run scorer ``key`` on (gts, res); return its corpus score, or None after a warning."""
        scorer = scorers.get(key)
        if scorer is None:
            if key != 'spice':  # SPICE is optional; its absence is not worth a warning
                print(f'[WARN] {label} failed: scorer unavailable')
            return None
        try:
            score, _ = scorer.compute_score(gts, res)
            return score
        except Exception as e:
            print(f'[WARN] {label} failed:', e)
            return None

    out: Dict[str, float] = {}

    bleu_scores = _score('bleu', 'BLEU')
    if bleu_scores is not None:
        for n in range(4):
            out[f'bleu{n + 1}'] = float(bleu_scores[n])
        out['bleu'] = float(bleu_scores[3])  # alias of 'bleu4'

    meteor_score = _score('meteor', 'METEOR')
    if meteor_score is not None:
        out['meteor'] = float(meteor_score)

    rouge_score = _score('rouge', 'ROUGE-L')
    if rouge_score is not None:
        out['rougeL'] = float(rouge_score)

    # Single-image CIDEr is always 0 (zero IDF weights, see docstring); report NaN
    # rather than a number that looks like a real (terrible) score.
    out['cider'] = float('nan')

    spice_score = _score('spice', 'SPICE')
    if spice_score is not None:
        out['spice'] = float(spice_score)
    return out

def _predict_reference_captions(image_path: str, json_file: Optional[str]) -> List[str]:
    """Collect the reference captions stored for ``image_path`` in a captions JSON.

    The JSON may be a list of image entries or a dict with an ``'images'``
    list. An entry matches when its ``'filename'`` equals the basename of
    ``image_path``. Its ``'sentences'`` items contribute their stripped,
    non-empty ``'raw'`` text. Any read/parse problem prints a warning and
    returns whatever was collected before the error (normally nothing).
    """
    refs: List[str] = []
    if not json_file:
        print("[WARN] Could not parse references (no captions JSON file given)")
        return refs
    try:
        # JSON is UTF-8 by spec; be explicit so Turkish characters decode on any locale.
        with open(json_file, encoding="utf-8") as f:
            data = json.load(f)
        entries = data['images'] if isinstance(data, dict) and 'images' in data else data
        target_name = os.path.basename(image_path)
        for e in entries:
            if not isinstance(e, dict) or e.get('filename') != target_name:
                continue
            for s in e.get('sentences', []):
                if isinstance(s, dict) and 'raw' in s:
                    cap = s['raw'].strip()
                    if cap:
                        refs.append(cap)
            break
    except Exception as e:
        print(f"[WARN] Could not parse references ({e})")
    return refs


def predict(
    image_path: str,
    prompt: str = "Bu görüntüyü açıkla: ",
    mode: str = "beam",              # 'beam' or 'sample'
    max_refs: Optional[int] = 5,
    show_image: bool = True,
    show_refs: bool = True,
    gen_kwargs: Optional[Dict[str, Any]] = None,
    json_file: Optional[str] = None,
    ban_sentinels: bool = True,
    compute_metrics: bool = True,
    print_metrics: bool = True,
) -> Dict[str, Any]:
    """Caption a single image with ``model_mm`` and optionally score it against references.

    Args:
        image_path: path to the image; must be an existing file.
        prompt: text prompt forwarded to ``model_mm.generate``.
        mode: 'sample' uses nucleus-sampling defaults; any other value uses
            beam-search defaults (beam width from ``cfg.num_beams_infer`` if
            ``cfg`` exists, else 4).
        max_refs: keep at most this many reference captions (None keeps all).
        show_image: display the image with matplotlib.
        show_refs: print the reference captions.
        gen_kwargs: extra generation kwargs; they override the mode defaults.
        json_file: captions JSON to read references from; defaults to the
            notebook-level ``json_path`` when that global is defined.
        ban_sentinels: forwarded to ``model_mm.generate``.
        compute_metrics: score the prediction (only when references were found).
        print_metrics: print the computed scores.

    Returns:
        dict with keys 'image_path', 'prediction', 'references', 'metrics', 'mode'.

    Raises:
        AssertionError: if ``image_path`` is not an existing file.
    """
    import contextlib

    assert os.path.isfile(image_path), f"Image not found: {image_path}"

    # Globals are looked up defensively so a missing `json_path`/`cfg` degrades
    # gracefully instead of raising NameError.
    refs = _predict_reference_captions(image_path, json_file or globals().get("json_path"))
    if max_refs is not None:
        refs = refs[:max_refs]

    if mode == "sample":
        defaults = dict(temperature=0.8, top_p=0.9, do_sample=True,
                        repetition_penalty=1.15, no_repeat_ngram_size=3)
    else:
        defaults = dict(repetition_penalty=1.2, no_repeat_ngram_size=3,
                        num_beams=getattr(globals().get("cfg"), "num_beams_infer", 4))
    # Caller-supplied kwargs win over defaults; the caller's dict is not mutated.
    gen_kwargs = {**defaults, **(gen_kwargs or {})}

    local_amp_dtype = globals().get("amp_dtype", None)
    # CUDA autocast is meaningless (and only warns) without a GPU.
    if local_amp_dtype is not None and torch.cuda.is_available():
        amp_ctx = torch.amp.autocast('cuda', dtype=local_amp_dtype)
    else:
        amp_ctx = contextlib.nullcontext()

    # Switch to eval for inference, but restore the previous mode afterwards so
    # calling predict() mid-training does not silently disable dropout etc.
    was_training = bool(getattr(model_mm, "training", False))
    model_mm.eval()
    try:
        with torch.no_grad(), amp_ctx:
            pred = model_mm.generate(
                image_paths=[image_path],
                prompt=prompt,
                ban_sentinels=ban_sentinels,
                **gen_kwargs
            )[0]
    finally:
        if was_training:
            model_mm.train()

    metrics: Dict[str, float] = {}
    if compute_metrics and refs:
        try:
            metrics = _compute_single_caption_metrics(pred if pred else "<EMPTY>", refs)
        except Exception as e:
            print(f"[WARN] Metric computation failed: {e}")

    if show_image:
        try:
            img = Image.open(image_path).convert("RGB")
            plt.figure(figsize=(5,5))
            plt.imshow(img); plt.axis("off")
            plt.title("Image")
            plt.show()
        except Exception as e:
            print(f"[WARN] Could not display image ({e})")
    if show_refs:
        if refs:
            print("References:")
            for i, r in enumerate(refs, 1):
                print(f"  {i}. {r}")
        else:
            print("(No references found)")
    print("Prediction:")
    print(" ", pred)
    if print_metrics and metrics:
        print("\nMetrics:")
        for k, v in metrics.items():
            print(f"  {k}: {v:.4f}" if isinstance(v,(int,float)) else f"  {k}: {v}")
    return {
        "image_path": image_path,
        "prediction": pred,
        "references": refs,
        "metrics": metrics,
        "mode": mode,
    }

In [None]:
# Caption one sample image with beam search; `result` holds the prediction, references and metrics.
result = predict(
    image_path="/kaggle/input/tasviret/flickr8k/Images/1032460886_4a598ed535.jpg",
    mode='beam',
)