## CLIP (ViT-B/32) + Projection MLP + mT5-Small Decoder Pipeline

This section covers:
- CLIP ViT-B/32 image encoder (originally fully frozen; now unfrozen for fine-tuning)
- An MLP (trainable) that turns the image embedding into K prefix tokens
- mT5-small (decoder only, or the whole model if preferred) for caption/translation generation
- The projection MLP and the mT5 decoder parameters are trained.

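Strategy (prefix approach):
1. Image -> CLIP `encode_image` -> (B, 512) (or mean-pooled patch tokens when `use_clip_patch_tokens=True`)
2. MLP: 512 -> K * d_model, reshaped to (B, K, d_model) (d_model = 512 for mT5-small) -> LayerNorm
3. The prefix is fed to the mT5 encoder as `inputs_embeds` (optionally concatenated with the embeddings of extra text-prompt tokens)
4. The decoder produces the target text (teacher forcing, cross-entropy loss)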
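The following sketch walks through the tensor shapes for steps 1 to 3. Random tensors stand in for the real CLIP output and the mT5 embedding table. The sizes are assumptions: a 512-d CLIP embedding, d_model = 512 for mT5-small, K = 32 as in the config below, and a 5-token prompt.

```python
import torch
from torch import nn

# Illustrative sizes (assumed): batch B, CLIP dim, mT5-small d_model, K prefix tokens, T prompt tokens
B, D_CLIP, D_MODEL, K, T = 2, 512, 512, 32, 5

clip_emb = torch.randn(B, D_CLIP)  # step 1: stand-in for CLIP encode_image -> (B, 512)

# step 2: the MLP maps one vector to K * d_model values, reshaped into K prefix tokens
mlp = nn.Sequential(nn.Linear(D_CLIP, 1024), nn.GELU(), nn.Linear(1024, K * D_MODEL))
prefix = nn.LayerNorm(D_MODEL)(mlp(clip_emb).view(B, K, D_MODEL))  # (B, K, d_model)

# step 3: prepend the prefix to the prompt's token embeddings and extend the attention mask
text_emb = torch.randn(B, T, D_MODEL)           # stand-in for encoder.embed_tokens(prompt_ids)
text_mask = torch.ones(B, T, dtype=torch.long)  # stand-in for the tokenizer's attention_mask
inputs_embeds = torch.cat([prefix, text_emb], dim=1)
attention_mask = torch.cat([torch.ones(B, K, dtype=torch.long), text_mask], dim=1)

print(inputs_embeds.shape, attention_mask.shape)  # torch.Size([2, 37, 512]) torch.Size([2, 37])
```

These two tensors are what goes into `MT5ForConditionalGeneration` as `inputs_embeds` and `attention_mask`, with `labels` added during training. The notebook's `ProjectionMLP` additionally applies dropout after the LayerNorm.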

Selectable freezing options:
- `freeze_clip = True` (the originally required scenario; the config below sets it to `False`)
- `freeze_t5_encoder = True` keeps the encoder frozen so that only the decoder + projection are trained
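The following sketch shows the freezing logic on a toy stand-in for CLIP. The `visual.transformer.resblocks` layout mirrors OpenAI CLIP's vision tower. Everything else is hypothetical, including the helper `set_clip_trainable`.

Unlike the pipeline code, this variant keeps the text tower frozen even when `freeze_clip=False`. This assumes that only the image encoder is ever used.

```python
from torch import nn

class ToyTransformer(nn.Module):
    def __init__(self, n_blocks: int, width: int):
        super().__init__()
        self.resblocks = nn.ModuleList([nn.Linear(width, width) for _ in range(n_blocks)])

class ToyVisual(nn.Module):
    def __init__(self, n_blocks: int = 4, width: int = 8):
        super().__init__()
        self.conv1 = nn.Conv2d(3, width, kernel_size=2)
        self.transformer = ToyTransformer(n_blocks, width)

class ToyCLIP(nn.Module):
    def __init__(self):
        super().__init__()
        self.visual = ToyVisual()
        self.text_tower = nn.Linear(8, 8)  # stands in for CLIP's (unused) text encoder

def set_clip_trainable(clip_model: nn.Module, freeze_clip: bool = True, unfreeze_last_n: int = 0) -> None:
    """Hypothetical helper: freeze everything, then re-enable the requested vision parameters."""
    for p in clip_model.parameters():
        p.requires_grad = False
    if unfreeze_last_n > 0:  # partial fine-tuning: only the last N vision blocks
        for block in list(clip_model.visual.transformer.resblocks)[-unfreeze_last_n:]:
            for p in block.parameters():
                p.requires_grad = True
    elif not freeze_clip:    # full fine-tuning of the vision tower only
        for p in clip_model.visual.parameters():
            p.requires_grad = True

def n_trainable(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

clip_toy = ToyCLIP()
for kwargs in ({"freeze_clip": True}, {"freeze_clip": False}, {"unfreeze_last_n": 2}):
    set_clip_trainable(clip_toy, **kwargs)
    print(kwargs, n_trainable(clip_toy))
```

Freezing everything first and then re-enabling parameters means each call fully determines the trainable set. The result therefore does not depend on what a previous call left behind.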

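The code below contains a skeleton that builds the dataset from the Flickr8k JSON (tasviret8k_captions.json). Each caption of an image becomes its own (image, caption) sample; the code does not pick a single caption per image.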
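The loader expects a Karpathy-style split file. The miniature below is made up: the file names and captions are placeholders. It does use the same keys that `Flickr8kCaptions` reads (`images`, `filename`, `split`, `sentences`, `raw`) and the same flattening rule of one sample per caption.

```python
import json

# Hypothetical miniature of the caption file (placeholder names and captions)
raw = json.loads("""
{"images": [
  {"filename": "a.jpg", "split": "train",
   "sentences": [{"raw": "Bir köpek koşuyor."}, {"raw": "Çimlerde bir köpek."}]},
  {"filename": "b.jpg", "split": "val", "sentences": [{"raw": "İki çocuk oynuyor."}]}
]}
""")

def flatten(raw, split=None):
    """Turn the nested JSON into (image filename, caption) pairs, one per caption."""
    rows = raw["images"] if isinstance(raw, dict) and "images" in raw else raw
    pairs = []
    for row in rows:
        if split and row.get("split") != split:
            continue
        img = row.get("filename") or row.get("image") or row.get("img")
        for s in row.get("sentences") or []:
            if isinstance(s, dict) and "raw" in s:
                pairs.append((img, s["raw"]))
    return pairs

train_pairs = flatten(raw, "train")
print(train_pairs)
```

Because every caption is a separate sample, an image with two captions appears twice per epoch. `dataset_limit` counts captions, not images.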


In [1]:
# Rerun this cell at each session start

# Uninstall conflicting packages (Kaggle specific)
!pip uninstall -y bigframes cesium gcsfs

# Performance metrics
!pip install -r /kaggle/input/requirements8/requirements.txt

# To use nltk
import nltk; nltk.download('punkt'); nltk.download('wordnet'); nltk.download('omw-1.4')

# Download mT5-small and CLIP ViT-B/32
import clip, torch
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

device = "cuda" if torch.cuda.is_available() else "cpu"
model_name = "google/mt5-small"

tokenizer = MT5Tokenizer.from_pretrained(model_name)
mt5 = MT5ForConditionalGeneration.from_pretrained(model_name).to(device)
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

print("Setup Done!")

Found existing installation: bigframes 2.8.0
Uninstalling bigframes-2.8.0:
  Successfully uninstalled bigframes-2.8.0
Found existing installation: cesium 0.12.4
Uninstalling cesium-0.12.4:
  Successfully uninstalled cesium-0.12.4
Found existing installation: gcsfs 2025.3.2
Uninstalling gcsfs-2025.3.2:
  Successfully uninstalled gcsfs-2025.3.2
Collecting git+https://github.com/openai/CLIP.git (from -r /kaggle/input/requirements8/requirements.txt (line 20))
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-61x507zk
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-61x507zk
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... done
Collecting pycocoevalcap@ git+https://github.com/salaniz/pycocoevalcap.git (from -r /kaggle/input/requirements8/requirements.txt (line 23))
  Cloning https://github.com/salaniz/pycocoevalcap.git t

[nltk_data] Downloading package punkt to /usr/share/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to /usr/share/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /usr/share/nltk_data...
2025-09-08 12:07:47.910379: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1757333268.076287      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1757333268.132819      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered



The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.
You are using the default legacy behaviour of the <class 'transformers.models.mt5.tokenization_mt5.MT5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565



Setup Done!


In [2]:
# Unified configuration + optional wandb init
from kaggle_secrets import UserSecretsClient
import os

PROJECT_NAME = "bites-tr-image-captioning"
RUN_NAME = "clip_mt5_prefix_run"

ENABLE_WANDB = True

WANDB_API_KEY = UserSecretsClient().get_secret("WANDB_API_KEY")

base_config = {
    "model": "google/mt5-small",
    "clip_encoder": "ViT-B/32",  # backbone
    "prefix_tokens": 32,          # K visual prefix tokens (raised for stronger conditioning)
    "batch_size": 16,
    "lr": 1e-4,
    "epochs": 25,                  # allow a bit longer now that we fine-tune CLIP
    "dataset_limit": None,
    # --- CLIP fine-tuning controls ---
    "freeze_clip": False,          # False unfreezes all CLIP parameters (vision and the unused text tower)
    "unfreeze_clip_last_n": 0,     # if >0, unfreezes only the last N vision blocks (overrides freeze_clip); 0 = use freeze_clip
    "clip_lr_scale": 0.05,         # multiplier on lr for ALL CLIP params (kept low to avoid destroying the pretrained space)
    "use_clip_patch_tokens": True, # mean-pool patch tokens instead of using CLS (still one vector per image)
    # --- T5 freezing ---
    "freeze_t5_encoder": False,    # unfreeze encoder so it can adapt to the visual prefix distribution
    "freeze_t5_decoder": False,
    # --- Optimization ---
    "seed": 42,
    "weight_decay": 0.01,
    "grad_clip": 1.0,
    "warmup_steps": 500,           # linear warmup steps before cosine decay
    # --- Inference defaults ---
    "num_beams_infer": 4,
    "max_new_tokens_infer": 32,
    "src_max_len": 64,
    "tgt_max_len": 64,
    # --- Early stopping hyperparameters ---
    "early_stop_patience": 4,
    "early_stop_min_delta": 0.01,  # absolute drop in the monitored loss required to count as an improvement

    "use_amp": True,              # force AMP on CUDA for speed
    # Optional extras:
    "use_bf16": True,
    "enable_tf32": True,
    "finite_loss_skip": True,     # not read by the training cell below
    "save_every": 0               # not read by the training cell below
}

use_wandb = False
cfg = None
if ENABLE_WANDB:
    try:
        import wandb
        wandb.login(key=WANDB_API_KEY)
        run = wandb.init(project=PROJECT_NAME, name=RUN_NAME, config=base_config)
        cfg = wandb.config
        use_wandb = True
        print("[wandb] run initialized")
    except Exception as e:
        print("[wandb] disabled (init failed):", e)

if cfg is None:
    class _Cfg: pass
    cfg = _Cfg()
    for k, v in base_config.items():
        setattr(cfg, k, v)
    print("[INFO] Using local cfg (wandb off)")

import random, numpy as np, torch
random.seed(cfg.seed)
np.random.seed(cfg.seed)
torch.manual_seed(cfg.seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(cfg.seed)

# --- Enforce CUDA-only environment ---
assert torch.cuda.is_available(), "CUDA GPU is required but not detected. Please run in a CUDA-enabled environment."
device = torch.device('cuda')
print(f"[DEVICE] Using CUDA device: {torch.cuda.get_device_name(device)}")

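# Performance toggles
# Note: TF32 only affects Ampere (compute capability 8.0) and newer GPUs; on a P100 it is a no-op.
if getattr(cfg, 'enable_tf32', False):
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        print("[DEVICE] TF32 enabled.")
    except Exception as _e:
        print("[WARN] Could not enable TF32:", _e)

if getattr(cfg, 'use_bf16', False):
    # Native bf16 math also needs compute capability >= 8.0. Depending on the PyTorch version,
    # is_bf16_supported() may still return True on older GPUs (e.g. via emulation), which is slow.
    bf16_ok = torch.cuda.is_bf16_supported()
    print(f"[DEVICE] bfloat16 support: {bf16_ok}")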

print("Active config:")
for k, v in base_config.items():
    print(f"  {k}: {getattr(cfg, k)}")

wandb: No netrc file found, creating one.
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc
wandb: Currently logged in as: abdulkadirparlak (abdulkadirparlak-hacettepe-university) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


[wandb] run initialized
[DEVICE] Using CUDA device: Tesla P100-PCIE-16GB
[DEVICE] TF32 enabled.
[DEVICE] bfloat16 support: True
Active config:
  model: google/mt5-small
  clip_encoder: ViT-B/32
  prefix_tokens: 32
  batch_size: 16
  lr: 0.0001
  epochs: 25
  dataset_limit: None
  freeze_clip: False
  unfreeze_clip_last_n: 0
  clip_lr_scale: 0.05
  use_clip_patch_tokens: True
  freeze_t5_encoder: False
  freeze_t5_decoder: False
  seed: 42
  weight_decay: 0.01
  grad_clip: 1
  warmup_steps: 500
  num_beams_infer: 4
  max_new_tokens_infer: 32
  src_max_len: 64
  tgt_max_len: 64
  early_stop_patience: 4
  early_stop_min_delta: 0.01
  use_amp: True
  use_bf16: True
  enable_tf32: True
  finite_loss_skip: True
  save_every: 0
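With `warmup_steps = 500` and 25 epochs, the schedule used by the training cell can be checked by hand. That schedule is a linear warmup followed by cosine decay to zero, implemented with `LambdaLR`. The sketch below re-implements the same `lr_lambda` in plain Python.

One value is an assumption: `steps_per_epoch = ceil(12028 / 16) = 752`. It is based on the training-set size reported further down and the DataLoader default `drop_last=False`.

```python
import math

LR, WARMUP, EPOCHS = 1e-4, 500, 25
STEPS_PER_EPOCH = math.ceil(12028 / 16)  # assumed len(train_loader)
TOTAL = STEPS_PER_EPOCH * EPOCHS

def lr_lambda(step):
    """LR multiplier: linear warmup to 1.0, then cosine decay to 0.0 at TOTAL steps."""
    if WARMUP > 0 and step < WARMUP:
        return step / max(1, WARMUP)
    progress = (step - WARMUP) / max(1, TOTAL - WARMUP)
    progress = min(max(progress, 0.0), 1.0)
    return 0.5 * (1 + math.cos(math.pi * progress))

# scheduler.step() runs once per batch, so after epoch e the scheduler sits at step e * STEPS_PER_EPOCH
for epoch in (1, 2, 3, 10):
    step = epoch * STEPS_PER_EPOCH
    print(f"after epoch {epoch}: lr={LR * lr_lambda(step):.2e}")
```

The printed values (1.00e-04, 9.93e-05, 9.77e-05, 6.79e-05) agree with the `lr=` column of the training log.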


In [3]:
# Model, dataset and pipeline definitions (data + model init only)
import torch, os, json
from torch import nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import clip
from transformers import MT5ForConditionalGeneration, MT5Tokenizer

# (CUDA enforcement handled in config cell; assume cuda device later)

class ProjectionMLP(nn.Module):
    def __init__(self, in_dim, out_dim, prefix_tokens=8, hidden=1024, dropout=0.1):
        super().__init__()
        self.prefix_tokens = prefix_tokens
        self.fc1 = nn.Linear(in_dim, hidden)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden, out_dim * prefix_tokens)
        self.ln = nn.LayerNorm(out_dim)
        self.dropout = nn.Dropout(dropout)
    def forward(self, x):
        x = self.fc1(x); x = self.act(x); x = self.fc2(x)
        x = x.view(x.size(0), self.prefix_tokens, -1)
        x = self.ln(x)
        return self.dropout(x)

class CLIPmT5Pipeline(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.tokenizer = MT5Tokenizer.from_pretrained(cfg.model)
        self.model = MT5ForConditionalGeneration.from_pretrained(cfg.model)
        self.clip, _ = clip.load(cfg.clip_encoder, device='cpu')
        # CLIP freezing / unfreezing strategy
        if cfg.unfreeze_clip_last_n and cfg.unfreeze_clip_last_n > 0:
            # Freeze everything first
            for p in self.clip.parameters():
                p.requires_grad = False
            blocks = list(self.clip.visual.transformer.resblocks)
            for block in blocks[-cfg.unfreeze_clip_last_n:]:
                for p in block.parameters():
                    p.requires_grad = True
        else:
            # Respect freeze_clip flag (False means fully trainable)
            for p in self.clip.parameters():
                p.requires_grad = not cfg.freeze_clip
        # T5 freeze toggles
        if cfg.freeze_t5_encoder:
            for p in self.model.encoder.parameters():
                p.requires_grad = False
        if cfg.freeze_t5_decoder:
            for p in self.model.decoder.parameters():
                p.requires_grad = False
        self.prefix_tokens = cfg.prefix_tokens
        # Determine embedding dim (CLIP projection output); only run a dummy forward pass if needed
        if hasattr(self.clip.visual, 'output_dim'):
            embed_dim = self.clip.visual.output_dim
        else:
            with torch.no_grad():
                embed_dim = self.clip.encode_image(torch.randn(1, 3, 224, 224)).shape[-1]
        self.proj = ProjectionMLP(in_dim=embed_dim, out_dim=self.model.config.d_model, prefix_tokens=cfg.prefix_tokens)
        self._cached_sentinel_ids = None  # lazy cache

    def _encode_image_single(self, images: torch.Tensor):
        # (B, D) pooled embedding (already passed through ln_post + proj inside encode_image)
        pooled = self.clip.encode_image(images)
        return pooled

    def _encode_image_patch_tokens(self, images: torch.Tensor):
        """Return patch-averaged embedding projected into CLIP joint space.
        We manually replicate encode_image path but average patch tokens (excluding CLS),
        then apply ln_post and proj so final dim == visual.output_dim (e.g. 512) to match ProjectionMLP in_dim."""
        visual = self.clip.visual
        x = visual.conv1(images)                      # (B, width, grid, grid)
        x = x.reshape(x.shape[0], x.shape[1], -1)     # (B, width, patches)
        x = x.permute(0, 2, 1)                        # (B, patches, width)
        cls_tokens = visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_tokens, x], dim=1)         # prepend CLS
        x = x + visual.positional_embedding.to(x.dtype)
        x = visual.ln_pre(x)
        x = x.permute(1, 0, 2)                        # (sequence, B, width)
        for block in visual.transformer.resblocks:
            x = block(x)
        x = x.permute(1, 0, 2)                        # (B, sequence, width)
        patches = x[:, 1:, :]                         # drop CLS
        pooled = patches.mean(dim=1)                  # (B, width)
        if hasattr(visual, 'ln_post'):
            pooled = visual.ln_post(pooled)
        if hasattr(visual, 'proj') and visual.proj is not None:
            pooled = pooled @ visual.proj             # (B, output_dim)
        return pooled

    def forward(self, images, src_texts, tgt_texts):
        device = next(self.parameters()).device
        images = images.to(device)
        clip_emb = self._encode_image_patch_tokens(images) if self.cfg.use_clip_patch_tokens else self._encode_image_single(images)
        prefix_emb = self.proj(clip_emb)
        tok_src = self.tokenizer(list(src_texts), return_tensors='pt', padding=True, truncation=True, max_length=self.cfg.src_max_len).to(device)
        tok_tgt = self.tokenizer(list(tgt_texts), return_tensors='pt', padding=True, truncation=True, max_length=self.cfg.tgt_max_len).to(device)
        text_emb = self.model.encoder.embed_tokens(tok_src.input_ids)
        full_emb = torch.cat([prefix_emb, text_emb], dim=1)
        full_attn = torch.cat([
            torch.ones(prefix_emb.size(0), self.prefix_tokens, dtype=tok_src.attention_mask.dtype, device=device),
            tok_src.attention_mask
        ], dim=1)
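        # Replace padding ids with -100 so the cross-entropy loss ignores padded target positions
        labels = tok_tgt.input_ids.masked_fill(tok_tgt.attention_mask == 0, -100)
        return self.model(inputs_embeds=full_emb, attention_mask=full_attn, labels=labels)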

    def _prepare_prefix(self, images: torch.Tensor):
        images = images.to(next(self.parameters()).device)
        emb = self._encode_image_patch_tokens(images) if self.cfg.use_clip_patch_tokens else self._encode_image_single(images)
        return self.proj(emb)

    def _get_sentinel_bad_words(self, n=50):
        if self._cached_sentinel_ids is None:
            ids = [self.tokenizer(f'<extra_id_{i}>').input_ids[0] for i in range(n)]
            self._cached_sentinel_ids = [[i] for i in ids]
        return self._cached_sentinel_ids

    @torch.inference_mode()
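    def generate(self, image_paths=None, images=None, num_beams=None, max_new_tokens=None, prompt=" ", ban_sentinels=True, **gen_kwargs):
        # Default prompt " " matches the source text Flickr8kCaptions feeds during training;
        # another prompt (e.g. "Bu görüntüyü açıkla: ") was never seen by the model in training.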
        device = next(self.parameters()).device
        num_beams = num_beams or self.cfg.num_beams_infer
        max_new_tokens = max_new_tokens or self.cfg.max_new_tokens_infer
        if images is None:
            assert image_paths is not None, "Provide image_paths or images tensor"
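            # Cache the CLIP preprocessing transform instead of reloading CLIP on every call
            if getattr(self, '_preprocess', None) is None:
                self._preprocess = clip.load(self.cfg.clip_encoder, device='cpu')[1]
            preprocess = self._preprocess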
            pil_images = [Image.open(p).convert('RGB') for p in image_paths]
            images = torch.stack([preprocess(im) for im in pil_images])
        images = images.to(device)
        prefix_tokens = self._prepare_prefix(images)
        tok = self.tokenizer([prompt]*images.size(0), return_tensors='pt', padding=True, truncation=True, max_length=self.cfg.src_max_len).to(device)
        text_emb = self.model.encoder.embed_tokens(tok.input_ids)
        full_emb = torch.cat([prefix_tokens, text_emb], dim=1)
        full_attn = torch.cat([
            torch.ones(images.size(0), self.prefix_tokens, device=device, dtype=tok.attention_mask.dtype),
            tok.attention_mask
        ], dim=1)
        bad_words_ids = self._get_sentinel_bad_words() if ban_sentinels else None
        gen_ids = self.model.generate(
            inputs_embeds=full_emb,
            attention_mask=full_attn,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            bad_words_ids=bad_words_ids,
            **gen_kwargs
        )
        captions = [self.tokenizer.decode(g, skip_special_tokens=True).strip() for g in gen_ids]
        return [c if c else "<EMPTY>" for c in captions]

class Flickr8kCaptions(Dataset):
    def __init__(self, json_path, images_root, split=None, limit=None, clip_preprocess=None):
        self.images_root = images_root
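        # Explicit UTF-8 for the Turkish captions; the context manager closes the file
        with open(json_path, encoding='utf-8') as f:
            raw = json.load(f)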
        rows = raw['images'] if isinstance(raw, dict) and 'images' in raw else raw
        self.samples = []
        for row in rows:
            if not isinstance(row, dict): continue
            if split and row.get('split') != split: continue
            img = row.get('filename') or row.get('image') or row.get('img')
            sentences = row.get('sentences')
            if not img or not sentences: continue
            for s in sentences:
                if isinstance(s, dict) and 'raw' in s:
                    self.samples.append((img, s['raw']))
        if limit: self.samples = self.samples[:limit]
        # Fallback: load CLIP (via the global cfg) only to obtain its preprocessing transform
        self.transform = clip_preprocess or clip.load(cfg.clip_encoder, device='cpu')[1]
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx):
        img_name, cap = self.samples[idx]
        path = os.path.join(self.images_root, img_name)
        image = Image.open(path).convert('RGB')
        return self.transform(image), " ", cap
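Target padding should not contribute to the loss in the forward pass above. Hugging Face T5/mT5 models compute cross-entropy with `ignore_index=-100`, and they map `-100` back to the pad id when shifting labels into decoder inputs. The usual fix is therefore to overwrite padded label positions with `-100`.

The toy check below uses random logits and made-up token ids: pad id 0 as in mT5, and EOS id 1. It shows that the masked loss equals the mean over non-pad positions only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
PAD_ID, VOCAB = 0, 10

# Made-up target ids: row 0 is shorter and padded with PAD_ID after its EOS (id 1)
labels = torch.tensor([[5, 6, 1, 0, 0],
                       [7, 8, 9, 4, 1]])
attention_mask = (labels != PAD_ID).long()            # what the tokenizer would return
masked = labels.masked_fill(attention_mask == 0, -100)

logits = torch.randn(2, 5, VOCAB)
flat_logits = logits.view(-1, VOCAB)
loss_with_pad = F.cross_entropy(flat_logits, labels.view(-1))  # pads wrongly counted as targets
loss_masked = F.cross_entropy(flat_logits, masked.view(-1))    # ignore_index defaults to -100

# Reference: average only over the non-pad positions
keep = masked.view(-1) != -100
loss_manual = F.cross_entropy(flat_logits[keep], labels.view(-1)[keep])
print(masked.tolist())
print(torch.allclose(loss_masked, loss_manual))
```

Without the mask, the model also gets credit for predicting pad after EOS, which tends to make the reported loss look lower than the caption quality justifies.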

In [4]:
# Data Loading
json_path = '/kaggle/input/tasviret/flickr8k/tasviret8k_captions.json'
images_root = '/kaggle/input/tasviret/flickr8k/Images'
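# Reuse the CLIP preprocessing transform from the setup cell instead of reloading CLIP for each split
# (assumes cfg.clip_encoder matches the model loaded there)
train_dataset = Flickr8kCaptions(json_path, images_root, split='train', limit=cfg.dataset_limit, clip_preprocess=clip_preprocess)
val_dataset   = Flickr8kCaptions(json_path, images_root, split='val',   limit=None, clip_preprocess=clip_preprocess)
test_dataset  = Flickr8kCaptions(json_path, images_root, split='test',  limit=None, clip_preprocess=clip_preprocess)

# Worker processes decode images in parallel; pinned memory lets the non_blocking=True copies overlap compute
loader_kwargs = dict(batch_size=cfg.batch_size, num_workers=2, pin_memory=True)  # num_workers=2 is a guess for Kaggle CPUs
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader   = DataLoader(val_dataset,   shuffle=False, **loader_kwargs) if len(val_dataset) > 0 else None
test_loader  = DataLoader(test_dataset,  shuffle=False, **loader_kwargs) if len(test_dataset) > 0 else None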

model_mm = CLIPmT5Pipeline(cfg)
print(f"Train samples: {len(train_dataset)}  Val: {len(val_dataset)}  Test: {len(test_dataset)}")
print("Trainable params:", sum(p.numel() for p in model_mm.parameters() if p.requires_grad))

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.


Train samples: 12028  Val: 2006  Test: 2003
Trainable params: 468774017
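The trainable count above includes every parameter with `requires_grad=True`. With `freeze_clip=False` the whole CLIP model is unfrozen, including its text tower (`clip.transformer`, `clip.token_embedding`, ...), which this pipeline never calls. Those parameters receive no gradient, and `AdamW` skips parameters whose `.grad` is `None`.

The hypothetical helper `trainable_by_prefix` below groups counts by module-name prefix, which makes such cases easy to spot. It is shown on a toy module. On the real pipeline it would be called as `trainable_by_prefix(model_mm)`.

```python
from collections import defaultdict
from torch import nn

def trainable_by_prefix(model: nn.Module, depth: int = 2) -> dict:
    """Sum trainable parameter counts per dotted-name prefix (first `depth` name components)."""
    counts = defaultdict(int)
    for name, p in model.named_parameters():  # tied/shared parameters are listed only once
        if p.requires_grad:
            counts[".".join(name.split(".")[:depth])] += p.numel()
    return dict(counts)

# Toy stand-in mimicking the pipeline's layout (clip.visual / clip.transformer / proj)
toy = nn.ModuleDict({
    "clip": nn.ModuleDict({"visual": nn.Linear(4, 4), "transformer": nn.Linear(4, 4)}),
    "proj": nn.Linear(4, 2),
})
print(trainable_by_prefix(toy))

# Freeze the stand-in text tower and recount
for p in toy["clip"]["transformer"].parameters():
    p.requires_grad = False
print(trainable_by_prefix(toy))
```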


In [5]:
# Training cell (only training + validation)
import math, time, torch, warnings, os
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
warnings.filterwarnings("ignore", message=".*legacy behaviour.*MT5Tokenizer.*")

# CUDA-only enforcement (already asserted earlier, but double-check for isolation run)
assert torch.cuda.is_available(), "CUDA GPU required."
device = torch.device('cuda')
model_mm.to(device)

# Decide AMP dtype
amp_dtype = None
if getattr(cfg, 'use_amp', True):
    if getattr(cfg, 'use_bf16', False) and torch.cuda.is_bf16_supported():
        amp_dtype = torch.bfloat16
        print("[AMP] Using bfloat16 mixed precision")
    else:
        amp_dtype = torch.float16
        print("[AMP] Using float16 mixed precision")
else:
    print("[AMP] Disabled; using full float32")

# Parameter groups: separate CLIP vs non-CLIP for LR scaling
main_params = []
clip_params = []
for name, p in model_mm.named_parameters():
    if not p.requires_grad:
        continue
    if name.startswith('clip.'):
        clip_params.append(p)
    else:
        main_params.append(p)
param_groups = []
if main_params:
    param_groups.append({"params": main_params, "lr": cfg.lr})
if clip_params:
    scaled_lr = cfg.lr * getattr(cfg, 'clip_lr_scale', 0.05)
    param_groups.append({"params": clip_params, "lr": scaled_lr})
    print(f"[INFO] CLIP fine-tune params: {len(clip_params)} with lr={scaled_lr:.2e}")

optimizer = AdamW(param_groups, weight_decay=cfg.weight_decay)

steps_per_epoch = len(train_loader)
total_steps = steps_per_epoch * cfg.epochs

def lr_lambda(step):
    # linear warmup then cosine decay
    if cfg.warmup_steps > 0 and step < cfg.warmup_steps:
        return float(step) / float(max(1, cfg.warmup_steps))
    progress = (step - cfg.warmup_steps) / float(max(1, total_steps - cfg.warmup_steps))
    progress = min(max(progress, 0.0), 1.0)
    return 0.5 * (1 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)
# Loss scaling is only needed for float16; bfloat16 has float32's exponent range
scaler = torch.amp.GradScaler('cuda', enabled=amp_dtype == torch.float16)
best_val = float('inf')
CKPT_DIR = 'checkpoints'; os.makedirs(CKPT_DIR, exist_ok=True)

# Early stopping state
early_patience = getattr(cfg, 'early_stop_patience', None)
min_delta = getattr(cfg, 'early_stop_min_delta', 0.0)
_epochs_no_improve = 0
stopped_early = False

monitor_history = []  # (epoch, metric)

global_step = 0
for epoch in range(cfg.epochs):
    model_mm.train()
    train_sum = 0.0
    t0 = time.time()
    for step, batch in enumerate(train_loader, start=1):
        imgs, srcs, tgts = batch
        imgs = imgs.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        if amp_dtype is not None:
            with torch.amp.autocast('cuda', dtype=amp_dtype):
                out = model_mm(imgs, srcs, tgts)
                loss = out.loss
            scaler.scale(loss).backward()
            if cfg.grad_clip and cfg.grad_clip > 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model_mm.parameters(), cfg.grad_clip)
            scaler.step(optimizer)
            scaler.update()
        else:
            out = model_mm(imgs, srcs, tgts); loss = out.loss
            loss.backward()
            if cfg.grad_clip and cfg.grad_clip > 0:
                torch.nn.utils.clip_grad_norm_(model_mm.parameters(), cfg.grad_clip)
            optimizer.step()

        scheduler.step()
        train_sum += loss.item()
        global_step += 1
    train_epoch_loss = train_sum / max(1, len(train_loader))

    val_epoch_loss = None
    if val_loader:
        model_mm.eval(); v = 0.0
        with torch.no_grad():
            for batch in val_loader:
                imgs, srcs, tgts = batch
                imgs = imgs.to(device, non_blocking=True)
                if amp_dtype is not None:
                    with torch.amp.autocast('cuda', dtype=amp_dtype):
                        out = model_mm(imgs, srcs, tgts)
                else:
                    out = model_mm(imgs, srcs, tgts)
                v += out.loss.item()
        val_epoch_loss = v / max(1, len(val_loader))

    dt = time.time() - t0
    current_lr = scheduler.get_last_lr()[0]
    if val_epoch_loss is not None:
        print(f"Epoch {epoch+1}/{cfg.epochs} train_loss={train_epoch_loss:.4f} val_loss={val_epoch_loss:.4f} time={dt:.1f}s lr={current_lr:.2e}", flush=True)
        if use_wandb:
            wandb.log({"train/epoch_loss": train_epoch_loss, "val/epoch_loss": val_epoch_loss, "lr": current_lr}, step=epoch)
    else:
        print(f"Epoch {epoch+1}/{cfg.epochs} train_loss={train_epoch_loss:.4f} time={dt:.1f}s lr={current_lr:.2e}", flush=True)
        if use_wandb:
            wandb.log({"train/epoch_loss": train_epoch_loss, "lr": current_lr}, step=epoch)

    metric = val_epoch_loss if val_epoch_loss is not None else train_epoch_loss
    monitor_history.append((epoch, metric))

    improved = metric < (best_val - min_delta)
    if improved:
        best_val = metric
        _epochs_no_improve = 0
        torch.save({
            'model': model_mm.state_dict(),
            'cfg': {k: getattr(cfg, k) for k in base_config.keys()},
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'best_val': best_val,
        }, os.path.join(CKPT_DIR, 'best.pt'))
        print(f"  -> Saved checkpoint (metric {best_val:.4f})")
    else:
        _epochs_no_improve += 1

    if early_patience is not None and _epochs_no_improve >= early_patience:
        print(f"[Early Stop] No improvement (>{min_delta} delta) for {early_patience} consecutive epochs. Stopping at epoch {epoch+1}.")
        stopped_early = True
        break

print("Training finished. Best metric:", best_val, "(early stop)" if stopped_early else "")

[AMP] Using bfloat16 mixed precision
[INFO] CLIP fine-tune params: 302 with lr=5.00e-06


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch 1/25 train_loss=17.2407 val_loss=1.9684 time=423.6s lr=1.00e-04
  -> Saved checkpoint (metric 1.9684)
Epoch 2/25 train_loss=2.0191 val_loss=1.6931 time=385.3s lr=9.93e-05
  -> Saved checkpoint (metric 1.6931)
Epoch 3/25 train_loss=1.8182 val_loss=1.6248 time=393.7s lr=9.77e-05
  -> Saved checkpoint (metric 1.6248)
Epoch 4/25 train_loss=1.7422 val_loss=1.5870 time=384.6s lr=9.54e-05
  -> Saved checkpoint (metric 1.5870)
Epoch 5/25 train_loss=1.6813 val_loss=1.5650 time=384.4s lr=9.24e-05
  -> Saved checkpoint (metric 1.5650)
Epoch 6/25 train_loss=1.6351 val_loss=1.5442 time=382.9s lr=8.86e-05
  -> Saved checkpoint (metric 1.5442)
Epoch 7/25 train_loss=1.5986 val_loss=1.5256 time=381.9s lr=8.42e-05
  -> Saved checkpoint (metric 1.5256)
Epoch 8/25 train_loss=1.5640 val_loss=1.5167 time=380.4s lr=7.92e-05
Epoch 9/25 train_loss=1.5397 val_loss=1.5039 time=383.6s lr=7.37e-05
  -> Saved checkpoint (metric 1.5039)
Epoch 10/25 train_loss=1.5208 val_loss=1.4943 time=384.0s lr=6.79e-05
Epoc
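With `early_stop_min_delta = 0.01`, an epoch only counts as an improvement if the validation loss falls more than 0.01 below the best value so far. Epochs 8 and 10 above therefore lowered the loss but were not checkpointed.

Replaying the rule on the logged validation losses reproduces this save pattern. Only the first ten epochs are used, since the log is cut off. `replay_early_stopping` is a hypothetical helper, not part of the notebook.

```python
val_losses = [1.9684, 1.6931, 1.6248, 1.5870, 1.5650, 1.5442, 1.5256, 1.5167, 1.5039, 1.4943]

def replay_early_stopping(losses, min_delta=0.01, patience=4):
    """Return (epochs that would save a checkpoint, epoch at which training stops or None)."""
    best, no_improve, saved = float("inf"), 0, []
    for epoch, metric in enumerate(losses, start=1):
        if metric < best - min_delta:  # same comparison as the training cell
            best, no_improve = metric, 0
            saved.append(epoch)
        else:
            no_improve += 1
        if no_improve >= patience:
            return saved, epoch
    return saved, None

saved, stop_epoch = replay_early_stopping(val_losses)
print(saved, stop_epoch)  # [1, 2, 3, 4, 5, 6, 7, 9] None
```

The threshold is absolute while the loss keeps shrinking, so a smaller or relative `min_delta` may be worth trying. With `min_delta=0.0`, every one of these ten epochs would have produced a checkpoint.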

In [5]:
# === Model Export / Import Utilities ===
import os, json, torch, math
from datetime import datetime
from typing import Optional, Dict, Any, Tuple

# ------------------------------
# Export
# ------------------------------

def export_model(
    save_dir: str,
    model: torch.nn.Module,
    cfg_obj,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    epoch: Optional[int] = None,
    global_step: Optional[int] = None,
    best_val: Optional[float] = None,
    tag: str = "latest",
    use_safetensors: bool = False,
) -> str:
    """Export model checkpoint + (optional) optimizer/scheduler.

    Saves:
      - config.json (raw cfg values)
      - tokenizer/ (HF tokenizer)
      - clip_mt5_prefix_<tag>.pt (bundle) OR .safetensors (+ meta files)

    Bundle (.pt) contains:
      model_state, cfg, epoch, global_step, tag, export_time,
      optimizer_state?, scheduler_state?, best_val?, hyperparams.

    hyperparams dict added so retraining can auto‑rebuild optimizer/scheduler:
      { lr, clip_lr_scale, weight_decay, warmup_steps, grad_clip, use_amp, use_bf16 }
    """
    os.makedirs(save_dir, exist_ok=True)

    # Collect base config keys if available
    if 'base_config' in globals():
        cfg_dict = {k: getattr(cfg_obj, k) for k in base_config.keys() if hasattr(cfg_obj, k)}
    else:
        cfg_dict = {k: v for k, v in vars(cfg_obj).items() if not k.startswith('_')}

    # Persist config separately (human readable)
    with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(cfg_dict, f, ensure_ascii=False, indent=2)

    # Save tokenizer (best effort)
    try:
        model.tokenizer.save_pretrained(os.path.join(save_dir, 'tokenizer'))
    except Exception as e:
        print(f"[WARN] Tokenizer save failed: {e}")

    checkpoint_name = f"clip_mt5_prefix_{tag}"

    hyperparams = {
        'lr': cfg_dict.get('lr'),
        'clip_lr_scale': cfg_dict.get('clip_lr_scale'),
        'weight_decay': cfg_dict.get('weight_decay'),
        'warmup_steps': cfg_dict.get('warmup_steps'),
        'grad_clip': cfg_dict.get('grad_clip'),
        'use_amp': cfg_dict.get('use_amp'),
        'use_bf16': cfg_dict.get('use_bf16'),
        'batch_size': cfg_dict.get('batch_size'),
    }

    if use_safetensors:
        try:
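            from safetensors.torch import save_model
            # mT5 ties shared / encoder.embed_tokens / decoder.embed_tokens; save_file() raises on
            # tensors that share memory, while save_model() de-duplicates them. Load such files with
            # safetensors.torch.load_model (or load_state_dict(..., strict=False)).
            save_model(model, os.path.join(save_dir, checkpoint_name + '.safetensors'))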
            meta = {
                'cfg': cfg_dict,
                'epoch': epoch,
                'global_step': global_step,
                'export_time': datetime.utcnow().isoformat() + 'Z',
                'tag': tag,
                'best_val': best_val,
                'has_optimizer': optimizer is not None,
                'has_scheduler': scheduler is not None,
                'hyperparams': hyperparams,
            }
            if optimizer:
                torch.save(optimizer.state_dict(), os.path.join(save_dir, checkpoint_name + '.optimizer.pt'))
            if scheduler:
                torch.save(scheduler.state_dict(), os.path.join(save_dir, checkpoint_name + '.scheduler.pt'))
            with open(os.path.join(save_dir, 'meta.json'), 'w') as f:
                json.dump(meta, f, indent=2)
            print(f"[EXPORT] Weights -> {checkpoint_name}.safetensors (meta.json written)")
            return os.path.join(save_dir, checkpoint_name + '.safetensors')
        except ImportError:
            print('[WARN] safetensors not installed; falling back to .pt')

    # Standard .pt route
    bundle = {
        'model_state': model.state_dict(),
        'cfg': cfg_dict,
        'epoch': epoch,
        'global_step': global_step,
        'export_time': datetime.utcnow().isoformat() + 'Z',
        'tag': tag,
        'best_val': best_val,
        'hyperparams': hyperparams,
    }
    if optimizer:
        bundle['optimizer_state'] = optimizer.state_dict()
    if scheduler:
        bundle['scheduler_state'] = scheduler.state_dict()
    out_path = os.path.join(save_dir, checkpoint_name + '.pt')
    torch.save(bundle, out_path)
    print(f"[EXPORT] Saved checkpoint: {out_path}")
    return out_path

# ------------------------------
# Helper: build optimizer (main + clip groups)
# ------------------------------

def build_optimizer_from_hparams(model: torch.nn.Module, h: Optional[Dict[str, Any]]):
    h = h or {}  # hyperparams may be None for bundles exported without them
    lr = h.get('lr', 1e-4)
    clip_lr_scale = h.get('clip_lr_scale', 0.05) or 0.05  # 0/None falls back to 0.05
    weight_decay = h.get('weight_decay', 0.0)
    # Split trainable params: the CLIP backbone ('clip.' prefix) gets a scaled-down LR
    main_params, clip_params = [], []
    for n, p in model.named_parameters():
        if not p.requires_grad:
            continue
        (clip_params if n.startswith('clip.') else main_params).append(p)
    param_groups = []
    if main_params:
        param_groups.append({'params': main_params, 'lr': lr})
    if clip_params:
        param_groups.append({'params': clip_params, 'lr': lr * clip_lr_scale})
        print(f"[OPT] CLIP params: {len(clip_params)} lr={lr * clip_lr_scale:.2e}")
    return torch.optim.AdamW(param_groups, weight_decay=weight_decay)

# ------------------------------
# Helper: build cosine scheduler with warmup (same as training)
# ------------------------------

def build_scheduler_from_hparams(optimizer, h: Dict[str, Any], steps_per_epoch: int, total_epochs: int):
    warmup_steps = h.get('warmup_steps', 0) or 0
    total_steps = steps_per_epoch * total_epochs
    def lr_lambda(step):
        if warmup_steps > 0 and step < warmup_steps:
            return float(step) / float(max(1, warmup_steps))
        progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps)) if total_steps > warmup_steps else 1.0
        progress = min(max(progress, 0.0), 1.0)
        return 0.5 * (1 + math.cos(math.pi * progress))
    from torch.optim.lr_scheduler import LambdaLR
    return LambdaLR(optimizer, lr_lambda=lr_lambda)

# ------------------------------
# Load for fine-tuning: returns the model, saved hyperparams, and raw optimizer/scheduler state dicts
# ------------------------------

def load_model_for_finetune(
    load_dir: str,
    device: torch.device,
    checkpoint_tag: str = 'latest',
    resume_optimizer: bool = True,
    build_optimizer_fn=None,   # accepted but unused here; build from the returned states instead
    build_scheduler_fn=None,   # accepted but unused here
    override_cfg: Optional[Dict[str, Any]] = None,
    prefer_safetensors: bool = True
):
    # Load config
    with open(os.path.join(load_dir, 'config.json'), 'r', encoding='utf-8') as f:
        cfg_json = json.load(f)
    if override_cfg:
        cfg_json.update(override_cfg)

    class _Cfg: ...
    cfg_obj = _Cfg()
    for k, v in cfg_json.items():
        setattr(cfg_obj, k, v)

    model = CLIPmT5Pipeline(cfg_obj).to(device)

    ckpt_base = f"clip_mt5_prefix_{checkpoint_tag}"
    safepath = os.path.join(load_dir, ckpt_base + '.safetensors')
    ptpath = os.path.join(load_dir, ckpt_base + '.pt')

    optimizer_state = scheduler_state = None
    epoch = global_step = None
    hyperparams = None
    used_safetensors = False

    if prefer_safetensors and os.path.isfile(safepath):
        try:
            from safetensors.torch import load_file
            weights = load_file(safepath, device=str(device))  # load_file expects a device string
            model.load_state_dict(weights, strict=True)
            used_safetensors = True
            meta_path = os.path.join(load_dir, 'meta.json')
            meta = {}
            if os.path.isfile(meta_path):
                with open(meta_path, 'r') as f:
                    meta = json.load(f)
            epoch = meta.get('epoch'); global_step = meta.get('global_step')
            hyperparams = meta.get('hyperparams')
            if resume_optimizer and meta.get('has_optimizer'):
                opt_file = os.path.join(load_dir, ckpt_base + '.optimizer.pt')
                if os.path.isfile(opt_file):
                    optimizer_state = torch.load(opt_file, map_location='cpu')
            if resume_optimizer and meta.get('has_scheduler'):
                sch_file = os.path.join(load_dir, ckpt_base + '.scheduler.pt')
                if os.path.isfile(sch_file):
                    scheduler_state = torch.load(sch_file, map_location='cpu')
        except Exception as e:
            print(f"[WARN] safetensors load failed ({e}); falling back to .pt")
            used_safetensors = False

    if not used_safetensors:
        if not os.path.isfile(ptpath):
            raise FileNotFoundError(f"No checkpoint found at {ptpath}")
        bundle = torch.load(ptpath, map_location=device)
        model.load_state_dict(bundle['model_state'], strict=True)
        epoch = bundle.get('epoch'); global_step = bundle.get('global_step')
        hyperparams = bundle.get('hyperparams')
        if resume_optimizer:
            optimizer_state = bundle.get('optimizer_state')
            scheduler_state = bundle.get('scheduler_state')

    print(f"[IMPORT] Model loaded (epoch={epoch}, global_step={global_step})")
    return model, cfg_obj, hyperparams, optimizer_state, scheduler_state, epoch, global_step

# ------------------------------
# Simple one-shot retrain loader
# ------------------------------

def load_model_for_retrain(
    checkpoint_path: str,
    device: torch.device,
    new_lr: Optional[float] = None,
    new_clip_lr_scale: Optional[float] = None,
    reset_optimizer: bool = True,
    freeze_clip: Optional[bool] = None,
    override_cfg: Optional[Dict[str, Any]] = None,
):
    """Load a .pt bundle and prepare model + (fresh) optimizer hyperparams for retraining.

    Returns (model, cfg_obj, optimizer_hparams, epoch_loaded, global_step_loaded)
    Caller should then: optimizer = build_optimizer_from_hparams(model, optimizer_hparams)
    and build a new scheduler with build_scheduler_from_hparams once steps_per_epoch known.
    """
    assert os.path.isfile(checkpoint_path), f"Checkpoint not found: {checkpoint_path}"
    bundle = torch.load(checkpoint_path, map_location=device)
    if 'model_state' not in bundle:
        raise ValueError('Not an export_model bundle (.pt)')
    cfg_json = dict(bundle['cfg'])
    if override_cfg:
        cfg_json.update(override_cfg)

    class _Cfg: ...
    cfg_obj = _Cfg()
    for k, v in cfg_json.items():
        setattr(cfg_obj, k, v)
    if freeze_clip is not None:
        cfg_obj.freeze_clip = freeze_clip

    model = CLIPmT5Pipeline(cfg_obj).to(device)
    model.load_state_dict(bundle['model_state'], strict=True)

    # Copy the stored hyperparams (may be absent or None in older bundles)
    base_h = dict(bundle.get('hyperparams') or {})
    # Override learning rates if requested
    if new_lr is not None:
        base_h['lr'] = new_lr
    if new_clip_lr_scale is not None:
        base_h['clip_lr_scale'] = new_clip_lr_scale

    # If CLIP is newly frozen, clip_lr_scale is kept but has no effect (no trainable 'clip.' params)
    epoch_loaded = bundle.get('epoch')
    global_step_loaded = bundle.get('global_step')
    print(f"[RETRAIN] Loaded weights (epoch={epoch_loaded}). Preparing fresh optimizer hyperparams.")
    return model, cfg_obj, base_h, epoch_loaded, global_step_loaded

print('[READY] Enhanced export/import + retrain helpers available.')

[READY] Enhanced export/import + retrain helpers available.
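`build_scheduler_from_hparams` wraps a linear-warmup plus cosine-decay multiplier in `LambdaLR`. The stdlib-only sketch below reproduces the same `lr_lambda` arithmetic so the multiplier can be checked at a few steps without building a model or optimizer; `cosine_warmup_factor` is a hypothetical name and the step counts are illustrative.

```python
import math

def cosine_warmup_factor(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier: linear warmup to 1.0, then cosine decay to 0.0 (mirrors lr_lambda above)."""
    if warmup_steps > 0 and step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    if total_steps > warmup_steps:
        progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps))
    else:
        progress = 1.0
    progress = min(max(progress, 0.0), 1.0)  # clamp: steps past the end stay at 0.0
    return 0.5 * (1 + math.cos(math.pi * progress))

# Illustrative: 100 warmup steps out of 1000 total
for s in (0, 50, 100, 550, 1000, 1200):
    print(s, round(cosine_warmup_factor(s, 100, 1000), 4))
```

The factor multiplies each param group's own base LR, so the CLIP group keeps its `clip_lr_scale` ratio throughout the schedule.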


In [7]:
# Export the model
export_model("/kaggle/working/exports/run1", model_mm, cfg, optimizer, scheduler, epoch=epoch, global_step=global_step, tag="epoch_last")

[EXPORT] Saved checkpoint: /kaggle/working/exports/run1/clip_mt5_prefix_epoch_last.pt


'/kaggle/working/exports/run1/clip_mt5_prefix_epoch_last.pt'
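Both export routes derive file names from the tag: the `.pt` route writes a single bundle, while the safetensors route writes the weights plus optional `.optimizer.pt` / `.scheduler.pt` files and one `meta.json` per directory. The hypothetical helper below (not part of the notebook) lists the paths the loaders look for, assuming the `clip_mt5_prefix_{tag}` naming seen in the printed export path and in `load_model_for_finetune`.

```python
import os
from typing import Dict

def expected_checkpoint_files(save_dir: str, tag: str) -> Dict[str, str]:
    """Hypothetical helper: paths used by export_model / load_model_for_finetune for one tag."""
    base = os.path.join(save_dir, f"clip_mt5_prefix_{tag}")
    return {
        'pt_bundle': base + '.pt',              # .pt route: weights + cfg + optional states
        'safetensors': base + '.safetensors',   # safetensors route: weights only
        'optimizer': base + '.optimizer.pt',    # safetensors route, if an optimizer was passed
        'scheduler': base + '.scheduler.pt',    # safetensors route, if a scheduler was passed
        # Shared by every tag in save_dir: the latest safetensors export overwrites it
        'meta': os.path.join(save_dir, 'meta.json'),
    }

paths = expected_checkpoint_files('/kaggle/working/exports/run1', 'epoch_last')
for kind, path in paths.items():
    print(f"{kind:12s} {path}  exists={os.path.isfile(path)}")
```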

In [1]:
# --- Validation Subset Metrics (pycocoevalcap) Helper ---
# Use to compute BLEU1-4, METEOR, ROUGE-L, CIDEr, SPICE on a fixed subset (default 200) of validation images each epoch.
# Call inside training/resume loops after computing (train/val) losses.

def compute_val_subset_metrics(
    model,
    val_dataset,
    images_root: str,
    device,
    amp_dtype=None,
    sample_size: int = 200,
    seed: int = 42,
    batch_size: int = 8,
    verbose: bool = False,
):
    import random, time, os, torch
    from collections import defaultdict
    if val_dataset is None or len(val_dataset) == 0:
        return {}

    # Build refs map (filename -> list[captions]) based on dataset.samples (filename, caption)
    refs_map = defaultdict(list)
    for (fname, cap) in val_dataset.samples:
        c = cap.strip()
        if c:
            refs_map[fname].append(c)
    unique_files = list(refs_map.keys())
    if not unique_files:
        return {}
    random.Random(seed).shuffle(unique_files)
    subset_files = unique_files[: min(sample_size, len(unique_files))]

    image_paths = [os.path.join(images_root, f) for f in subset_files]

    hyps, refs = [], []
    was_training = model.training  # called mid-training: restore the mode afterwards
    model.eval()
    t0 = time.time()

    def _gen(paths):
        # Generation only: no autograd graph needed
        with torch.no_grad():
            if amp_dtype is not None:
                with torch.amp.autocast('cuda', dtype=amp_dtype):
                    return model.generate(image_paths=paths, ban_sentinels=True)
            return model.generate(image_paths=paths, ban_sentinels=True)

    try:
        for i in range(0, len(image_paths), batch_size):
            batch = image_paths[i:i+batch_size]
            outs = _gen(batch)
            for p, pred in zip(batch, outs):
                fname = os.path.basename(p)
                hyps.append(pred if pred else "<EMPTY>")
                refs.append(refs_map.get(fname, [" "]))
    finally:
        model.train(was_training)
    gts = {i: refs[i] for i in range(len(refs))}
    res = {i: [hyps[i]] for i in range(len(hyps))}

    from pycocoevalcap.bleu.bleu import Bleu
    from pycocoevalcap.meteor.meteor import Meteor
    from pycocoevalcap.rouge.rouge import Rouge
    from pycocoevalcap.cider.cider import Cider
    try:
        from pycocoevalcap.spice.spice import Spice
    except Exception:
        Spice = None

    metrics = {}
    try:
        bleu_scores, _ = Bleu(4).compute_score(gts, res)
        metrics['bleu1'] = float(bleu_scores[0]); metrics['bleu2'] = float(bleu_scores[1])
        metrics['bleu3'] = float(bleu_scores[2]); metrics['bleu4'] = float(bleu_scores[3])
        metrics['bleu'] = float(bleu_scores[3])
    except Exception as e:
        if verbose: print('[VAL_SUBSET] BLEU failed:', e)
    try:
        meteor_score, _ = Meteor().compute_score(gts, res)
        metrics['meteor'] = float(meteor_score)
    except Exception as e:
        if verbose: print('[VAL_SUBSET] METEOR failed:', e)
    try:
        rouge_score, _ = Rouge().compute_score(gts, res)
        metrics['rougeL'] = float(rouge_score)
    except Exception as e:
        if verbose: print('[VAL_SUBSET] ROUGE-L failed:', e)
    try:
        cider_score, _ = Cider().compute_score(gts, res)
        metrics['cider'] = float(cider_score)
    except Exception as e:
        if verbose: print('[VAL_SUBSET] CIDEr failed:', e)
    if Spice:
        try:
            spice_score, _ = Spice().compute_score(gts, res)
            metrics['spice'] = float(spice_score)
        except Exception as e:
            if verbose: print('[VAL_SUBSET] SPICE failed:', e)

    metrics['subset_size'] = len(subset_files)
    metrics['eval_time_s'] = round(time.time() - t0, 2)
    if verbose:
        print('[VAL_SUBSET]', ', '.join(f"{k}={v:.3f}" for k,v in metrics.items() if isinstance(v,(int,float))))
    return metrics
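`compute_val_subset_metrics` shuffles the unique validation filenames with a private `random.Random(seed)`, so every epoch scores the same images and the metric curves stay comparable across epochs. A stdlib sketch of just that selection step; `select_subset` is a hypothetical name and the filenames are made up.

```python
import random
from collections import defaultdict
from typing import Dict, List, Tuple

def select_subset(samples: List[Tuple[str, str]], sample_size: int, seed: int = 42):
    """Group captions by filename, then pick a seed-stable subset of unique images."""
    refs_map: Dict[str, List[str]] = defaultdict(list)
    for fname, cap in samples:
        if cap.strip():
            refs_map[fname].append(cap.strip())
    files = list(refs_map.keys())          # insertion order: first appearance in samples
    random.Random(seed).shuffle(files)     # private RNG: global random state is untouched
    subset = files[: min(sample_size, len(files))]
    return subset, {f: refs_map[f] for f in subset}

# Made-up samples: three captions per image, as in Flickr8k-style data
samples = [(f"img_{i % 7}.jpg", f"caption {i}") for i in range(21)]
subset_a, refs_a = select_subset(samples, sample_size=3)
subset_b, _ = select_subset(samples, sample_size=3)
print(subset_a, subset_a == subset_b)
```

The stability relies on `samples` keeping the same order between calls, which holds as long as the validation dataset is not rebuilt with a different ordering.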

In [5]:
# === Direct .pt Checkpoint Loader (file-based) ===
import torch, os, json
from typing import Optional, Dict, Any, Tuple

def load_model_from_checkpoint_file(
    checkpoint_path: str,
    device: torch.device,
    resume_optimizer: bool = False,
    build_optimizer_fn=None,
    build_scheduler_fn=None,
    override_cfg: Optional[Dict[str, Any]] = None,
):
    """Load model (and optional optimizer/scheduler) directly from a single .pt file.

    Expects the .pt bundle produced by export_model (contains 'model_state' and 'cfg').

    Returns: (model, cfg_obj, optimizer, scheduler, epoch, global_step)
    """
    assert os.path.isfile(checkpoint_path), f"Checkpoint not found: {checkpoint_path}"
    bundle = torch.load(checkpoint_path, map_location=device)
    if 'model_state' not in bundle or 'cfg' not in bundle:
        raise ValueError("Checkpoint missing required keys 'model_state' or 'cfg'")

    cfg_json = dict(bundle['cfg'])
    if override_cfg:
        cfg_json.update(override_cfg)

    class _Cfg: ...
    cfg_obj = _Cfg()
    for k, v in cfg_json.items():
        setattr(cfg_obj, k, v)

    model = CLIPmT5Pipeline(cfg_obj).to(device)
    model.load_state_dict(bundle['model_state'], strict=True)

    epoch = bundle.get('epoch')
    global_step = bundle.get('global_step')

    optimizer = None
    scheduler = None
    if resume_optimizer and 'optimizer_state' in bundle:
        if build_optimizer_fn is None:
            print('[WARN] Optimizer state present but build_optimizer_fn not provided; skipping optimizer restore.')
        else:
            optimizer = build_optimizer_fn(cfg_obj, model)
            try:
                optimizer.load_state_dict(bundle['optimizer_state'])
                print('[IMPORT] Optimizer state restored.')
            except Exception as e:
                print(f'[WARN] Failed to load optimizer state: {e}')
    if resume_optimizer and optimizer and 'scheduler_state' in bundle and build_scheduler_fn:
        scheduler = build_scheduler_fn(cfg_obj, optimizer)
        try:
            scheduler.load_state_dict(bundle['scheduler_state'])
            print('[IMPORT] Scheduler state restored.')
        except Exception as e:
            print(f'[WARN] Failed to load scheduler state: {e}')

    print(f"[IMPORT] Loaded model from file: {checkpoint_path}")
    return model, cfg_obj, optimizer, scheduler, epoch, global_step

# One-line usage example (inference only):
# model_mm, cfg_loaded, _, _, epoch_loaded, step_loaded = load_model_from_checkpoint_file('/kaggle/input/15_epochs_tasviret/pytorch/default/1/clip_mt5_prefix_epoch_last.pt', device=torch.device('cuda'), resume_optimizer=False)
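The loaders rebuild the config by copying JSON keys onto an empty `_Cfg` instance with `setattr`. The standard-library `types.SimpleNamespace` gives an equivalent attribute object in one call; a sketch of that alternative with the same `override_cfg` merge, assuming `CLIPmT5Pipeline` only reads attributes from the config. `cfg_from_dict` and the keys shown are illustrative, not the notebook's real config.

```python
from types import SimpleNamespace
from typing import Any, Dict, Optional

def cfg_from_dict(cfg_json: Dict[str, Any], override_cfg: Optional[Dict[str, Any]] = None) -> SimpleNamespace:
    merged = dict(cfg_json)          # copy so the checkpoint's dict is not mutated
    if override_cfg:
        merged.update(override_cfg)  # overrides win, as in load_model_from_checkpoint_file
    return SimpleNamespace(**merged)

# Illustrative keys only
cfg_demo = cfg_from_dict({'prefix_len': 10, 'freeze_clip': True}, {'freeze_clip': False})
print(cfg_demo.prefix_len, cfg_demo.freeze_clip, getattr(cfg_demo, 'use_amp', True))
```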


In [6]:
# Load the model
model_mm, cfg_loaded, _, _, epoch_loaded, step_loaded = load_model_from_checkpoint_file('/kaggle/input/25_epochs_tasviret/pytorch/default/1/clip_mt5_prefix_epoch_last.pt', device=torch.device('cuda'), resume_optimizer=False)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'T5Tokenizer'. 
The class this function is called from is 'MT5Tokenizer'.


[IMPORT] Loaded model from file: /kaggle/input/25_epochs_tasviret/pytorch/default/1/clip_mt5_prefix_epoch_last.pt


In [7]:
# Testing / Inference cell (evaluate test set + pycocoevalcap metrics only)
import torch, time, os, json, math
from typing import List
from PIL import Image

assert torch.cuda.is_available(), "CUDA GPU required for inference."
model_mm.eval()
device = torch.device('cuda')
model_mm.to(device)

# Determine AMP dtype consistent with training (prefer the config loaded with the checkpoint)
cfg_eval = globals().get('cfg_loaded') or cfg
amp_dtype = None
if getattr(cfg_eval, 'use_amp', True):
    if getattr(cfg_eval, 'use_bf16', False) and torch.cuda.is_bf16_supported():
        amp_dtype = torch.bfloat16
    else:
        amp_dtype = torch.float16

# -------------------- Test Loss --------------------
test_loss = None
if test_loader:
    t0 = time.time(); total=0.0
    with torch.no_grad():
        for batch in test_loader:
            imgs, srcs, tgts = batch
            imgs = imgs.to(device, non_blocking=True)
            if amp_dtype is not None:
                with torch.amp.autocast('cuda', dtype=amp_dtype):
                    out = model_mm(imgs, srcs, tgts)
            else:
                out = model_mm(imgs, srcs, tgts)
            total += out.loss.item()
    test_loss = total / max(1,len(test_loader))
    print(f"Test loss={test_loss:.4f} time={(time.time()-t0):.1f}s")
else:
    print("No test split available.")

############################################
# Caption Quality Metrics (BLEU1-4, METEOR, CIDEr, ROUGE-L, SPICE) + W&B logging
# All via pycocoevalcap (no evaluate dependency)
############################################
metrics = {}
try:
    import nltk
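    # NLTK resources live in different subdirectories (tokenizers/ vs corpora/)
    nltk_resources = {'punkt': 'tokenizers/punkt', 'wordnet': 'corpora/wordnet', 'omw-1.4': 'corpora/omw-1.4'}
    for pkg, res_path in nltk_resources.items():
        try:
            nltk.data.find(res_path)
        except LookupError:
            try:
                nltk.download(pkg, quiet=True)
            except Exception:
                pass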

    # Build reference sets (all captions per test image) from original JSON
    with open(json_path) as f:
        data_json = json.load(f)
    img_entries = data_json['images'] if isinstance(data_json, dict) else data_json

    # Gather unique test image filenames present in test_dataset
    test_image_names = sorted({s[0] for s in test_dataset.samples})
    test_name_set = set(test_image_names)  # O(1) membership for the JSON scan below
    refs_map = {}
    for entry in img_entries:
        if not isinstance(entry, dict):
            continue
        fname = entry.get('filename')
        if fname in test_name_set:
            all_caps = []
            for s in entry.get('sentences', []):
                if isinstance(s, dict) and 'raw' in s:
                    cap = s['raw'].strip()
                    if cap:
                        all_caps.append(cap)
            if all_caps:
                refs_map[fname] = all_caps

    # Generate hypotheses (batched for speed)
    hyps = []
    refs = []
    BATCH_GEN = 12  # moderate batch size; adjust as GPU allows
    image_paths = [os.path.join(images_root, fname) for fname in test_image_names]

    def _generate_batch(paths):
        if amp_dtype is not None:
            with torch.amp.autocast('cuda', dtype=amp_dtype):
                return model_mm.generate(image_paths=paths, ban_sentinels=True)
        return model_mm.generate(image_paths=paths, ban_sentinels=True)

    gen_start = time.time()
    for i in range(0, len(image_paths), BATCH_GEN):
        batch_paths = image_paths[i:i+BATCH_GEN]
        caps = _generate_batch(batch_paths)
        for p, pred in zip(batch_paths, caps):
            fname = os.path.basename(p)
            pred = pred if pred else "<EMPTY>"
            hyps.append(pred)
            # list of refs (avoid empties)
            cur_refs = refs_map.get(fname, [" "])
            refs.append([r for r in cur_refs if r.strip()] or [" "])
    print(f"[GEN] Generated {len(hyps)} captions in {time.time()-gen_start:.1f}s")

    # Build COCO-style dicts for pycocoevalcap
    gts = {i: refs[i] for i in range(len(refs))}
    res = {i: [hyps[i]] for i in range(len(hyps))}

    # ---- pycocoevalcap metrics ----
    from pycocoevalcap.bleu.bleu import Bleu
    from pycocoevalcap.meteor.meteor import Meteor
    from pycocoevalcap.rouge.rouge import Rouge
    from pycocoevalcap.cider.cider import Cider
    from pycocoevalcap.spice.spice import Spice

    metrics = {}
    # BLEU (returns list of 4 scores)
    try:
        bleu_scorer = Bleu(4)
        bleu_scores, _ = bleu_scorer.compute_score(gts, res)
        metrics['bleu1'] = float(bleu_scores[0])
        metrics['bleu2'] = float(bleu_scores[1])
        metrics['bleu3'] = float(bleu_scores[2])
        metrics['bleu4'] = float(bleu_scores[3])
        metrics['bleu'] = float(bleu_scores[3])  # treat BLEU-4 as aggregate
    except Exception as e:
        print('[WARN] BLEU failed:', e)

    try:
        meteor_scorer = Meteor()
        meteor_score, _ = meteor_scorer.compute_score(gts, res)
        metrics['meteor'] = float(meteor_score)
    except Exception as e:
        print('[WARN] METEOR failed:', e)

    try:
        rouge_scorer = Rouge()
        rouge_score, _ = rouge_scorer.compute_score(gts, res)
        metrics['rougeL'] = float(rouge_score)
    except Exception as e:
        print('[WARN] ROUGE-L failed:', e)

    try:
        cider_scorer = Cider()
        cider_score, _ = cider_scorer.compute_score(gts, res)
        metrics['cider'] = float(cider_score)
    except Exception as e:
        print('[WARN] CIDEr failed:', e)

    try:
        spice_scorer = Spice()
        spice_score, spice_scores = spice_scorer.compute_score(gts, res)
        metrics['spice'] = float(spice_score)
    except Exception as e:
        print('[WARN] SPICE failed (Java required?):', e)

    if test_loss is not None:
        metrics['test_loss'] = test_loss

    print("=== Caption Metrics (pycocoevalcap) ===")
    for k,v in metrics.items():
        print(f"{k}: {v:.4f}" if isinstance(v,(int,float)) else f"{k}: {v}")

    if globals().get('use_wandb', False):
        wandb.log({f"eval/{k}": v for k,v in metrics.items()}, commit=True)
        print("[wandb] Logged caption metrics.")

except Exception as e:
    print("[ERROR] Metric computation failed:", e)

# Print captions for the first SAMPLE_PRINTS unique test images (with repetition controls)
default_gen_kwargs = dict(
    repetition_penalty=1.2,
    no_repeat_ngram_size=3,
)
SAMPLE_PRINTS = 5
if test_loader and SAMPLE_PRINTS > 0:
    printed_imgs = set()
    for s in test_dataset.samples:
        fname, gt_caption = s[0], s[1]
        img_path = os.path.join(images_root, fname)
        if img_path in printed_imgs:
            continue  # several samples share one image; show each image once
        if amp_dtype is not None:
            with torch.amp.autocast('cuda', dtype=amp_dtype):
                caps = model_mm.generate(image_paths=[img_path], **default_gen_kwargs)
        else:
            caps = model_mm.generate(image_paths=[img_path], **default_gen_kwargs)
        print('GT:', gt_caption)
        print('CAP:', caps[0])
        print('---------------')
        printed_imgs.add(img_path)
        if len(printed_imgs) >= SAMPLE_PRINTS:
            break

# Optional: stochastic sampling helper for more diverse outputs
def generate_captions(image_paths: List[str], mode='beam', **kwargs):
    if mode == 'sample':
        sample_defaults = dict(temperature=0.8, top_p=0.9, do_sample=True, repetition_penalty=1.15, no_repeat_ngram_size=3)
        for k, v in sample_defaults.items():
            kwargs.setdefault(k, v)
    else:  # beam
        beam_defaults = dict(repetition_penalty=1.2, no_repeat_ngram_size=3)
        for k, v in beam_defaults.items():
            kwargs.setdefault(k, v)
    if amp_dtype is not None:
        with torch.amp.autocast('cuda', dtype=amp_dtype):
            return model_mm.generate(image_paths=image_paths, **kwargs)
    return model_mm.generate(image_paths=image_paths, **kwargs)

print("[Ready] generate_captions(['path/to/img.jpg'], mode='beam')")

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Test loss=1.4448 time=32.9s
[GEN] Generated 1000 captions in 382.1s
{'testlen': 7001, 'reflen': 7497, 'guess': [7001, 6001, 5001, 4001], 'correct': [1203, 154, 20, 2]}
ratio: 0.9338402027476411
Downloading stanford-corenlp-3.6.0 for SPICE ...
Progress: 384.5M / 384.5M (100.0%)
Extracting stanford-corenlp-3.6.0 ...
Done.


Parsing reference captions
Initiating Stanford parsing pipeline
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator tokenize
[main] INFO edu.stanford.nlp.pipeline.TokenizerAnnotator - TokenizerAnnotator: No tokenizer type provided. Defaulting to PTBTokenizer.
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator ssplit
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator parse
[main] INFO edu.stanford.nlp.parser.common.ParserGrammar - Loading parser from serialized file edu/stanford/nlp/models/lexparser/englishPCFG.ser.gz ... 
done [0.7 sec].
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator lemma
[main] INFO edu.stanford.nlp.pipeline.StanfordCoreNLP - Adding annotator ner
Loading classifier from edu/stanford/nlp/models/ner/english.all.3class.distsim.crf.ser.gz ... done [1.3 sec].
Loading classifier from edu/stanford/nlp/models/ner/english.muc.7class.distsim.crf.ser.gz ... done [0.7 sec].
Loading classif

SPICE evaluation took: 52.13 s
=== Caption Metrics (pycocoevalcap) ===
bleu1: 0.1601
bleu2: 0.0619
bleu3: 0.0242
bleu4: 0.0090
bleu: 0.0090
meteor: 0.0883
rougeL: 0.1482
cider: 0.0643
spice: 0.0267
test_loss: 1.4448
[wandb] Logged caption metrics.
GT: İki kahverengi köpek kar üstünde kavga ediyor.
CAP: Kırmızı tişörtlü bir adam yeşil bir topu havada yakalamış.
---------------
GT: Havuzda sahibine doğru yüzen küçük bir köpek.
CAP: İki köpek çimlerin üzerinde koşuyor.
---------------
GT: Bir sokak festivalinde bir kadın ve bir erkek sambacı dans ediyor, kadın tüy şapka giymiş, arkada bir adam fotoğraf çekiyor.
CAP: Kırmızı tişörtlü bir çocuk yeşillikler içerisinde koşuyor.
---------------
GT: İki kadınla bir adam tahta masa sandalyede bir şemsiyenin altında oturmuş içeceklerini içiyor.
CAP: Kırmızı tişörtlü bir çocuk yeşillikler içerisinde koşuyor.
---------------
GT: Maç sırasında odaklanmış belli bir noktaya bakmakta olan bir Amerikan futbol oyuncusu.
CAP: Kırmızı tişörtlü bir çocuk ye
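The sample outputs above repeat "Kırmızı tişörtlü bir çocuk yeşillikler içerisinde koşuyor." for several unrelated test images. One quick way to quantify this kind of output collapse is the share of distinct predictions plus the most frequent one. A stdlib sketch, assuming it is fed the `hyps` list built in the metrics cell; `caption_diversity` is a hypothetical name and the demo list is illustrative.

```python
from collections import Counter
from typing import List, Tuple

def caption_diversity(hyps: List[str]) -> Tuple[float, str, int]:
    """Return (distinct ratio, most common caption, its count) for a list of predictions."""
    if not hyps:
        return 0.0, "", 0
    counts = Counter(h.strip() for h in hyps)
    top_caption, top_count = counts.most_common(1)[0]
    return len(counts) / len(hyps), top_caption, top_count

# Illustrative list; in the notebook, pass the `hyps` list from the test-metrics cell
demo = ["a dog runs", "a dog runs", "a man jumps", "a dog runs"]
ratio, top, n = caption_diversity(demo)
print(f"distinct={ratio:.2f} top={top!r} x{n}")
```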

In [None]:
# Single Image Inference with Optional pycocoevalcap Metrics (BLEU1-4, METEOR, ROUGE-L, CIDEr, SPICE)
from typing import Optional, Dict, Any, List
from PIL import Image  # used by predict() for display; also imported in the test cell
import matplotlib.pyplot as plt
import os, json, torch

# Cache scorers globally to avoid repeated Java/METEOR/Spice init costs
_PYCOCO_SCORERS = {}


def _lazy_load_scorers():
    global _PYCOCO_SCORERS
    from pycocoevalcap.bleu.bleu import Bleu
    from pycocoevalcap.meteor.meteor import Meteor
    from pycocoevalcap.rouge.rouge import Rouge
    from pycocoevalcap.cider.cider import Cider
    try:
        from pycocoevalcap.spice.spice import Spice
    except Exception:
        Spice = None
    if 'bleu' not in _PYCOCO_SCORERS:
        _PYCOCO_SCORERS['bleu'] = Bleu(4)
    if 'meteor' not in _PYCOCO_SCORERS:
        _PYCOCO_SCORERS['meteor'] = Meteor()
    if 'rouge' not in _PYCOCO_SCORERS:
        _PYCOCO_SCORERS['rouge'] = Rouge()
    if 'cider' not in _PYCOCO_SCORERS:
        _PYCOCO_SCORERS['cider'] = Cider()
    if Spice and 'spice' not in _PYCOCO_SCORERS:
        _PYCOCO_SCORERS['spice'] = Spice()
    return _PYCOCO_SCORERS


def _compute_single_caption_metrics(pred: str, refs: List[str]) -> Dict[str, float]:
    refs_clean = [r.strip() for r in refs if r and r.strip()]
    if not refs_clean:
        return {}
    scorers = _lazy_load_scorers()
    gts = {0: refs_clean}
    res = {0: [pred]}
    out = {}
    # BLEU
    try:
        bleu_scores, _ = scorers['bleu'].compute_score(gts, res)
        out['bleu1'] = float(bleu_scores[0])
        out['bleu2'] = float(bleu_scores[1])
        out['bleu3'] = float(bleu_scores[2])
        out['bleu4'] = float(bleu_scores[3])
        out['bleu'] = float(bleu_scores[3])
    except Exception as e:
        print('[WARN] BLEU failed:', e)
    # METEOR
    try:
        meteor_score, _ = scorers['meteor'].compute_score(gts, res)
        out['meteor'] = float(meteor_score)
    except Exception as e:
        print('[WARN] METEOR failed:', e)
    # ROUGE-L
    try:
        rouge_score, _ = scorers['rouge'].compute_score(gts, res)
        out['rougeL'] = float(rouge_score)
    except Exception as e:
        print('[WARN] ROUGE-L failed:', e)
    # CIDEr (IDF comes from the scored set, so for a single image the score is always 0)
    try:
        cider_score, _ = scorers['cider'].compute_score(gts, res)
        out['cider'] = float(cider_score)
    except Exception as e:
        print('[WARN] CIDEr failed:', e)
    # SPICE (may require Java)
    if 'spice' in scorers:
        try:
            spice_score, spice_scores = scorers['spice'].compute_score(gts, res)
            out['spice'] = float(spice_score)
        except Exception as e:
            print('[WARN] SPICE failed:', e)
    return out


def predict(
    image_path: str,
    prompt: str = "Bu fotoğrafı açıkla: ",
    mode: str = "beam",              # 'beam' or 'sample'
    max_refs: Optional[int] = 5,
    show_image: bool = True,
    show_refs: bool = True,
    gen_kwargs: Optional[Dict[str, Any]] = None,
    json_file: Optional[str] = None,
    ban_sentinels: bool = True,
    compute_metrics: bool = True,
    print_metrics: bool = True,
) -> Dict[str, Any]:
    """Generate a caption and optionally compute pycocoevalcap metrics.

    Returns dict with: image_path, prediction, references (list), metrics (dict), mode.
    """
    assert os.path.isfile(image_path), f"Image not found: {image_path}"
    jf = json_file or json_path  # use global json_path defined earlier

    # Collect reference captions
    refs: List[str] = []
    try:
        with open(jf) as f:
            data = json.load(f)
        entries = data['images'] if isinstance(data, dict) and 'images' in data else data
        target_name = os.path.basename(image_path)
        for e in entries:
            if not isinstance(e, dict):
                continue
            if e.get('filename') == target_name:
                for s in e.get('sentences', []):
                    if isinstance(s, dict) and 'raw' in s:
                        cap = s['raw'].strip()
                        if cap:
                            refs.append(cap)
                break
    except Exception as e:
        print(f"[WARN] Could not parse references ({e})")

    if max_refs is not None:
        refs = refs[:max_refs]

    # Defaults for decoding
    gen_kwargs = (gen_kwargs or {}).copy()
    if mode == "sample":
        defaults = dict(temperature=0.8, top_p=0.9, do_sample=True,
                        repetition_penalty=1.15, no_repeat_ngram_size=3)
    else:
        defaults = dict(repetition_penalty=1.2, no_repeat_ngram_size=3,
                        num_beams=getattr(cfg, "num_beams_infer", 4))
    for k, v in defaults.items():
        gen_kwargs.setdefault(k, v)

    local_amp_dtype = globals().get("amp_dtype", None)
    model_mm.eval()
    if local_amp_dtype is not None:
        with torch.amp.autocast('cuda', dtype=local_amp_dtype):
            pred = model_mm.generate(
                image_paths=[image_path],
                prompt=prompt,
                ban_sentinels=ban_sentinels,
                **gen_kwargs
            )[0]
    else:
        pred = model_mm.generate(
            image_paths=[image_path],
            prompt=prompt,
            ban_sentinels=ban_sentinels,
            **gen_kwargs
        )[0]

    metrics: Dict[str, float] = {}
    if compute_metrics and refs:
        try:
            metrics = _compute_single_caption_metrics(pred if pred else "<EMPTY>", refs)
        except Exception as e:
            print(f"[WARN] Metric computation failed: {e}")

    # Visualization
    if show_image:
        try:
            img = Image.open(image_path).convert("RGB")
            plt.figure(figsize=(5,5))
            plt.imshow(img); plt.axis("off")
            plt.title("Image")
            plt.show()
        except Exception as e:
            print(f"[WARN] Could not display image ({e})")

    if show_refs:
        if refs:
            print("References:")
            for i, r in enumerate(refs, 1):
                print(f"  {i}. {r}")
        else:
            print("(No references found)")

    print("Prediction:")
    print(" ", pred)

    if print_metrics and metrics:
        print("\nMetrics (pycocoevalcap):")
        for k, v in metrics.items():
            print(f"  {k}: {v:.4f}" if isinstance(v,(int,float)) else f"  {k}: {v}")

    return {
        "image_path": image_path,
        "prediction": pred,
        "references": refs,
        "metrics": metrics,
        "mode": mode,
    }

# Example:
# result = predict("/kaggle/input/tasviret/flickr8k/Images/1032460886_4a598ed535.jpg", mode="beam")
# print(result)
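`_compute_single_caption_metrics` scores one image at a time, and CIDEr is not informative in that setting. pycocoevalcap's `CiderScorer` weights each n-gram by `log(N) - log(max(1, df))`, where `N` is the number of images being scored and `df` counts how many of their reference sets contain the n-gram. With `N = 1`, `df` can only be 0 or 1, so every weight is `log(1) - log(1) = 0` and the score is 0 regardless of the caption. A stdlib sketch of only that IDF weight (not the full CIDEr-D computation); the function names are hypothetical.

```python
import math
from typing import List

def cider_idf_weight(num_images: int, doc_freq: int) -> float:
    """IDF term used by pycocoevalcap's CIDEr: log(N) - log(max(1, df))."""
    return math.log(float(num_images)) - math.log(max(1.0, doc_freq))

def idf_weights_for_corpus(num_images: int, doc_freqs: List[int]) -> List[float]:
    return [cider_idf_weight(num_images, df) for df in doc_freqs]

# One image scored alone: every weight is zero
print(idf_weights_for_corpus(1, [0, 1]))  # [0.0, 0.0]
# 1000 test images: rare n-grams get large weights, ubiquitous ones get ~0
print([round(w, 3) for w in idf_weights_for_corpus(1000, [1, 10, 1000])])
```

Corpus-level CIDEr, as computed in the test cell over 1000 images, is the meaningful figure; for a single image, BLEU, METEOR and ROUGE-L are more useful, keeping in mind that unsmoothed BLEU-4 on one short sentence is often near 0.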

In [None]:
predict("/kaggle/input/tasviret/flickr8k/Images/1032460886_4a598ed535.jpg", mode='beam')
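`predict` finds references by matching the image's basename against each entry's `filename` field and collecting the non-empty `sentences[*].raw` strings, accepting either a top-level `{'images': [...]}` dict or a bare list. A stdlib sketch of that lookup on an in-memory example; `collect_refs` is a hypothetical name and the entry is made up (the real file is the `json_path` used above).

```python
import os
from typing import Any, List, Optional

def collect_refs(data: Any, image_path: str, max_refs: Optional[int] = 5) -> List[str]:
    """Return up to max_refs non-empty 'raw' captions for the image's basename."""
    entries = data['images'] if isinstance(data, dict) and 'images' in data else data
    target = os.path.basename(image_path)
    for e in entries:
        if isinstance(e, dict) and e.get('filename') == target:
            caps = [s['raw'].strip() for s in e.get('sentences', [])
                    if isinstance(s, dict) and 'raw' in s and s['raw'].strip()]
            return caps[:max_refs]  # slicing with None keeps all
    return []  # no entry for this filename

demo_json = {'images': [{'filename': 'a.jpg', 'sentences': [{'raw': ' Bir köpek koşuyor. '}, {'raw': ''}]}]}
print(collect_refs(demo_json, '/some/dir/a.jpg'))
print(collect_refs(demo_json, 'missing.jpg'))
```

Because matching uses only the basename, two images with the same filename in different folders would share references.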

In [None]:
# === Example: Retrain from exported checkpoint on same or new dataset ===
# checkpoint_path = '/kaggle/input/your_export/clip_mt5_prefix_latest.pt'
# from torch.utils.data import DataLoader
# device = torch.device('cuda')
# model_re, cfg_re, hparams, epoch_loaded, step_loaded = load_model_for_retrain(
#     checkpoint_path,
#     device=device,
#     new_lr=5e-5,              # optionally override LR
#     new_clip_lr_scale=0.02,   # optionally override CLIP LR scale
#     freeze_clip=None,         # or True to freeze CLIP now
# )
# optimizer_re = build_optimizer_from_hparams(model_re, hparams)
# steps_per_epoch = len(train_loader)  # after you rebuild train_loader for (same or new) dataset
# scheduler_re = build_scheduler_from_hparams(optimizer_re, hparams, steps_per_epoch, total_epochs=5)
# # Proceed with standard training loop using model_re / optimizer_re / scheduler_re
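`load_model_for_retrain` overrides `lr` and `clip_lr_scale` in the stored hyperparameter dict before handing it to `build_optimizer_from_hparams`. A standalone sketch of that override rule (only non-`None` values replace stored ones) that also tolerates a missing dict and leaves the input untouched; `merge_retrain_hparams` is a hypothetical name and the numbers are illustrative.

```python
from typing import Any, Dict, Optional

def merge_retrain_hparams(stored: Optional[Dict[str, Any]],
                          new_lr: Optional[float] = None,
                          new_clip_lr_scale: Optional[float] = None) -> Dict[str, Any]:
    h = dict(stored or {})  # tolerate missing/None hyperparams; never mutate the input
    if new_lr is not None:
        h['lr'] = new_lr
    if new_clip_lr_scale is not None:
        h['clip_lr_scale'] = new_clip_lr_scale
    return h

stored = {'lr': 1e-4, 'clip_lr_scale': 0.05, 'weight_decay': 0.01}
h = merge_retrain_hparams(stored, new_lr=5e-5, new_clip_lr_scale=0.02)
# Effective LRs that build_optimizer_from_hparams would assign to the main and CLIP groups
print(h['lr'], h['lr'] * h['clip_lr_scale'])
```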

In [None]:
# ==== Resume / Continue Training From Exported or Training Checkpoint (.pt) ====
# Improvements:
#  - Ensures 'checkpoints' directory exists before saving.
#  - Handles tokenizer class mismatch notice (mt5 vs t5) by reloading MT5 tokenizer explicitly.
#  - Initializes best_val robustly (evaluates validation if inf).
#  - Safe scheduler positioning.
#  - Adds optional immediate validation before continuing to set baseline.

RESUME_PATH = '/kaggle/input/15_epochs_tasviret/pytorch/default/1/clip_mt5_prefix_epoch_last.pt'  # change as needed
EXTRA_EPOCHS = 5          # number of extra epochs to train
RUN_INITIAL_VAL = True    # run a validation pass right after load to set best_val if missing
SAVE_NAME = 'best_resumed.pt'

import os, math, time, torch
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR

os.makedirs('checkpoints', exist_ok=True)

if not os.path.isfile(RESUME_PATH):
    print(f"[RESUME] Path not found: {RESUME_PATH} -> Skip.")
else:
    print(f"[RESUME] Loading: {RESUME_PATH}")
    bundle = torch.load(RESUME_PATH, map_location='cpu')

    has_export_format = 'model_state' in bundle
    has_train_ckpt_format = 'model' in bundle
    if not (has_export_format or has_train_ckpt_format):
        raise ValueError('Unrecognized checkpoint format.')

    cfg_dict = bundle.get('cfg') or {}
    class _Cfg: ...
    cfg_resume = _Cfg()
    for k, v in base_config.items():
        setattr(cfg_resume, k, cfg_dict.get(k, v))

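    # Rebuild the pipeline from the current class definition; the checkpoint stores no tokenizer,
    # so the tokenizer type (MT5Tokenizer) comes from CLIPmT5Pipeline, not from the original run.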
    model_resume = CLIPmT5Pipeline(cfg_resume).to(device)

    # Both formats store the same metadata; only the weights key differs.
    weights_key = 'model_state' if has_export_format else 'model'
    model_resume.load_state_dict(bundle[weights_key], strict=True)
    loaded_epoch = bundle.get('epoch', -1)
    global_step_loaded = bundle.get('global_step', None)
    best_val_loaded = bundle.get('best_val', float('inf'))

    print(f"[RESUME] Weights loaded (epoch={loaded_epoch}, best_val={best_val_loaded})")

    main_params, clip_params = [], []
    for n, p in model_resume.named_parameters():
        if not p.requires_grad: continue
        (clip_params if n.startswith('clip.') else main_params).append(p)
    param_groups = []
    if main_params: param_groups.append({'params': main_params, 'lr': cfg_resume.lr})
    if clip_params: param_groups.append({'params': clip_params, 'lr': cfg_resume.lr * getattr(cfg_resume, 'clip_lr_scale', 0.05)})
    optimizer_resume = AdamW(param_groups, weight_decay=cfg_resume.weight_decay)

    steps_per_epoch = len(train_loader)
    total_planned_epochs = loaded_epoch + 1 + EXTRA_EPOCHS
    total_steps = steps_per_epoch * total_planned_epochs

    def lr_lambda(step):
        if cfg_resume.warmup_steps > 0 and step < cfg_resume.warmup_steps:
            return float(step) / float(max(1, cfg_resume.warmup_steps))
        progress = (step - cfg_resume.warmup_steps) / float(max(1, total_steps - cfg_resume.warmup_steps))
        progress = min(max(progress, 0.0), 1.0)
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler_resume = LambdaLR(optimizer_resume, lr_lambda=lr_lambda)

    opt_key = 'optimizer_state' if has_export_format else 'optimizer'
    sch_key = 'scheduler_state' if has_export_format else 'scheduler'
    if opt_key in bundle:
        try:
            optimizer_resume.load_state_dict(bundle[opt_key])
            print('[RESUME] Optimizer state restored.')
        except Exception as e:
            print('[WARN] Optimizer state load failed:', e)
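    sched_restored = False
    if sch_key in bundle:
        try:
            scheduler_resume.load_state_dict(bundle[sch_key])
            sched_restored = True
            print('[RESUME] Scheduler state restored.')
        except Exception as e:
            print('[WARN] Scheduler state load failed:', e)

    # A freshly built LambdaLR already has last_epoch == 0 (its constructor takes one step),
    # so checking `last_epoch < 0` never detects a missing state. Fast-forward explicitly,
    # preferring the saved global_step over the epoch-based estimate.
    if not sched_restored and loaded_epoch >= 0:
        completed_steps = (global_step_loaded if global_step_loaded is not None
                           else (loaded_epoch + 1) * steps_per_epoch)
        scheduler_resume.last_epoch = completed_steps
        # Setting last_epoch does not change the optimizer's LR; apply the matching value now.
        for group, base_lr in zip(optimizer_resume.param_groups, scheduler_resume.base_lrs):
            group['lr'] = base_lr * lr_lambda(completed_steps)
        print(f"[RESUME] Scheduler last_epoch set to {scheduler_resume.last_epoch}")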

    amp_dtype_resume = None
    if getattr(cfg_resume, 'use_amp', True) and torch.cuda.is_available():
        if getattr(cfg_resume, 'use_bf16', False) and torch.cuda.is_bf16_supported():
            amp_dtype_resume = torch.bfloat16
        else:
            amp_dtype_resume = torch.float16
    # Loss scaling is only needed for float16; bfloat16 shares float32's exponent range.
    scaler_resume = torch.amp.GradScaler('cuda', enabled=amp_dtype_resume == torch.float16)

    start_epoch = loaded_epoch + 1
    end_epoch = loaded_epoch + EXTRA_EPOCHS

    best_val = best_val_loaded

    # Optional immediate validation: runs only when RUN_INITIAL_VAL is True and the checkpoint
    # has no usable best_val (inf or NaN), so the first resumed epoch has a real baseline.
    if RUN_INITIAL_VAL and (math.isinf(best_val) or math.isnan(best_val)):
        if val_loader:
            model_resume.eval(); v=0.0
            with torch.no_grad():
                for batch in val_loader:
                    imgs, srcs, tgts = batch
                    imgs = imgs.to(device, non_blocking=True)
                    if amp_dtype_resume is not None:
                        with torch.amp.autocast('cuda', dtype=amp_dtype_resume):
                            out = model_resume(imgs, srcs, tgts)
                    else:
                        out = model_resume(imgs, srcs, tgts)
                    v += out.loss.item()
            best_val = v / max(1, len(val_loader))
            print(f"[RESUME] Initial validation baseline best_val set to {best_val:.4f}")
        else:
            best_val = float('inf')

    early_patience = getattr(cfg_resume, 'early_stop_patience', None)
    min_delta = getattr(cfg_resume, 'early_stop_min_delta', 0.0)
    _epochs_no_improve = 0

    global_step_resume = global_step_loaded if global_step_loaded is not None else (start_epoch * steps_per_epoch)

    print(f"[RESUME] Continue training for {EXTRA_EPOCHS} more epochs: {start_epoch} -> {end_epoch}")

    for epoch in range(start_epoch, end_epoch + 1):
        model_resume.train()
        sum_loss = 0.0
        t0 = time.time()
        for step, batch in enumerate(train_loader, start=1):
            imgs, srcs, tgts = batch
            imgs = imgs.to(device, non_blocking=True)
            optimizer_resume.zero_grad(set_to_none=True)
            if amp_dtype_resume is not None:
                with torch.amp.autocast('cuda', dtype=amp_dtype_resume):
                    out = model_resume(imgs, srcs, tgts)
                    loss = out.loss
                scaler_resume.scale(loss).backward()
                if cfg_resume.grad_clip and cfg_resume.grad_clip > 0:
                    scaler_resume.unscale_(optimizer_resume)
                    torch.nn.utils.clip_grad_norm_(model_resume.parameters(), cfg_resume.grad_clip)
                scaler_resume.step(optimizer_resume)
                scaler_resume.update()
            else:
                out = model_resume(imgs, srcs, tgts); loss = out.loss
                loss.backward()
                if cfg_resume.grad_clip and cfg_resume.grad_clip > 0:
                    torch.nn.utils.clip_grad_norm_(model_resume.parameters(), cfg_resume.grad_clip)
                optimizer_resume.step()
            scheduler_resume.step()
            sum_loss += loss.item()
            global_step_resume += 1
        train_loss = sum_loss / max(1, len(train_loader))

        val_loss = None
        if val_loader:
            model_resume.eval(); v = 0.0
            with torch.no_grad():
                for batch in val_loader:
                    imgs, srcs, tgts = batch
                    imgs = imgs.to(device, non_blocking=True)
                    if amp_dtype_resume is not None:
                        with torch.amp.autocast('cuda', dtype=amp_dtype_resume):
                            out = model_resume(imgs, srcs, tgts)
                    else:
                        out = model_resume(imgs, srcs, tgts)
                    v += out.loss.item()
            val_loss = v / max(1, len(val_loader))

        dt = time.time() - t0
        lr_cur = scheduler_resume.get_last_lr()[0]
        if val_loss is not None:
            print(f"[RESUME] Epoch {epoch} train={train_loss:.4f} val={val_loss:.4f} time={dt:.1f}s lr={lr_cur:.2e}")
        else:
            print(f"[RESUME] Epoch {epoch} train={train_loss:.4f} time={dt:.1f}s lr={lr_cur:.2e}")

        metric = val_loss if val_loss is not None else train_loss
        improved = metric < (best_val - min_delta)
        if improved:
            best_val = metric
            _epochs_no_improve = 0
            save_obj = {
                'model': model_resume.state_dict(),
                'cfg': {k: getattr(cfg_resume, k) for k in base_config.keys()},
                'epoch': epoch,
                'optimizer': optimizer_resume.state_dict(),
                'scheduler': scheduler_resume.state_dict(),
                'best_val': best_val,
                'global_step': global_step_resume,
            }
            torch.save(save_obj, os.path.join('checkpoints', SAVE_NAME))
            print(f"   -> [RESUME] Saved {SAVE_NAME} (metric={best_val:.4f})")
        else:
            _epochs_no_improve += 1
            if early_patience is not None and _epochs_no_improve >= early_patience:
                print(f"[Early Stop - Resume] No improvement for {early_patience} epochs.")
                break

    print('[RESUME] Training extension finished. Final best_val=', best_val)
    model_mm = model_resume
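The resume cell reuses the linear-warmup + cosine multiplier, but recomputes `total_steps` for `loaded_epoch + 1 + EXTRA_EPOCHS` epochs. The LR therefore does not restart from the near-zero tail of the original run; it picks up at the matching point of the longer curve. Where that point is depends on the restored step count: a freshly constructed `LambdaLR` already reports `last_epoch == 0` (its constructor takes one initial step), so the position has to come from the checkpoint's `global_step` or from `(loaded_epoch + 1) * steps_per_epoch`. The pure-Python sketch below reproduces the multiplier and that arithmetic; `lr_multiplier`, `resume_position` and the step counts are hypothetical, chosen only for illustration.

```python
import math


def lr_multiplier(step, warmup_steps, total_steps):
    """Linear warmup to 1.0, then cosine decay to 0.0 at total_steps (same formula as lr_lambda)."""
    if warmup_steps > 0 and step < warmup_steps:
        return step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(max(progress, 0.0), 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


def resume_position(loaded_epoch, steps_per_epoch, global_step=None):
    """Optimizer steps already taken; prefer the saved global_step when it exists."""
    if global_step is not None:
        return global_step
    return (loaded_epoch + 1) * steps_per_epoch


# Hypothetical run: 15 finished epochs (loaded_epoch=14), 100 steps/epoch, 5 extra epochs.
steps_per_epoch, loaded_epoch, extra = 100, 14, 5
total = steps_per_epoch * (loaded_epoch + 1 + extra)    # 2000 planned steps
start = resume_position(loaded_epoch, steps_per_epoch)  # 1500 steps already done
print(round(lr_multiplier(start, warmup_steps=0, total_steps=total), 4))  # → 0.1464
```

So with these numbers the extension starts at about 15% of the base LR and decays to zero over the 5 extra epochs, rather than continuing at the ~0 LR where the original 15-epoch schedule ended.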