<a href="https://colab.research.google.com/github/Saivamsee24/Multimodal-Machine-Translation-Leveraging-Images-for-Enhanced-Language-Understanding/blob/main/Multimodal_Machine_Translation_Leveraging_Images_for_Enhanced_Language_Understanding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import subprocess

gpu_info = subprocess.check_output("nvidia-smi -L", shell=True).decode()
print("GPU Info ‚Üí", gpu_info)

if "A100" not in gpu_info and "H100" not in gpu_info:
    print("‚ö†Ô∏è WARNING: You did NOT receive an A100/H100.")
    print("Training will be 10‚Äì20√ó slower.")
    print("Please restart runtime until you get an A100/H100.")
else:
    print("‚úÖ Great! You received a premium GPU.")


GPU Info ‚Üí GPU 0: NVIDIA L4 (UUID: GPU-118ded44-0181-16e2-816b-c2bdd3b7d899)

Training will be 10‚Äì20√ó slower.
Please restart runtime until you get an A100/H100.


In [None]:
!pip install evaluate
!pip install transformers
!pip install sacrebleu
!pip install sentencepiece
!pip install -U transformers accelerate datasets evaluate sentencepiece
!pip install wandb
!pip install rouge_score
!pip install bert_score


Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m84.1/84.1 kB[0m [31m9.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: evaluate
Successfully installed evaluate-0.4.6
Collecting sacrebleu
  Downloading sacrebleu-2.5.1-py3-none-any.whl.metadata (51 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m51.8/51.8 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting portalocker (from sacrebleu)
  Downloading portalocker-3.2.0-py3-none-any.whl.metadata (8.7 kB)
Collecting colorama (from sacrebleu)
  Downloading colorama-0.4.6-py2.py3-none-any.whl.metadata (17 kB)
Downloading sacrebleu-2.5.1-py3-none-any.whl (104 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚î

In [None]:
#Individual Model Training

In [None]:
# ==============================================================
# üåç MULTIMODAL TRANSLATION + SIGLIP + LORA + FUSION vs TEXT-ONLY
#  - Multi30K (data/task1/raw + image_splits)
#  - SigLIP vision encoder (google/siglip-base-patch16-224)
#  - mBART-50 text model with LoRA on attention (q_proj, v_proj)
#  - Better fusion: Transformer-based fusion over [IMG + TEXT]
#  - Also trains text-only mBART+LoRA baseline for comparison
# ==============================================================

import os
import json
from pathlib import Path
from typing import List, Tuple
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageFile
from tqdm import tqdm
import evaluate
import warnings

warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True

# ------------------ HF + PEFT imports ------------------
try:
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
        get_linear_schedule_with_warmup,
    )
    from peft import LoraConfig, get_peft_model, TaskType
except ImportError as e:
    print(f"‚ùå Import error: {e}")
    os.system("pip install -q transformers peft accelerate")
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
        get_linear_schedule_with_warmup,
    )
    from peft import LoraConfig, get_peft_model, TaskType

# ------------------ DEVICE ------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True

# ============================================================== #
# CONFIG
# ============================================================== #

@dataclass
class Config:
    # Paths
    data_root: str = "/content/multi30k-dataset"
    image_dir: str = "flickr30k-images"
    save_dir: str = "/content/multimodal_translation_models_siglip_lora_fusion"

    # Training
    max_length: int = 64
    batch_size: int = 2          # small for VRAM safety
    learning_rate: float = 3e-5
    num_epochs: int = 6
    patience: int = 3
    min_delta: float = 0.5       # BLEU improvement threshold to reset patience
    use_amp: bool = True

    # Data limits
    max_train_samples: int = 15000
    max_val_samples: int = 1000

    # Optim
    warmup_steps: int = 100
    max_grad_norm: float = 1.0

    # Vision model (SigLIP)
    vision_model_name: str = "google/siglip-base-patch16-224"

    # LoRA
    use_lora: bool = True
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.1
    lora_targets: List[str] = None

    # Language directions
    directions: List[Tuple[str, str]] = None

    def __post_init__(self):
        if self.lora_targets is None:
            # Attn projections in mBART encoder+decoder
            self.lora_targets = ["q_proj", "v_proj"]
        if self.directions is None:
            self.directions = [
                ("en", "de"),
                ("en", "fr"),
                ("de", "en"),
                ("de", "fr"),
                ("fr", "en"),
                ("fr", "de"),
            ]

config = Config()

LANG_CODES = {"en": "en_XX", "de": "de_DE", "fr": "fr_XX"}

# Global metric (avoid re-loading each epoch)
sacrebleu_metric = evaluate.load("sacrebleu")

# ============================================================== #
# OPTIONAL: DRIVE MOUNT + SYMLINK (if you want)
# ============================================================== #

def mount_and_link_dataset():
    """
    Mounts Google Drive and links /content/multi30k-dataset to your folder.
    Safe: no directory scanning, just checks existence.
    """
    try:
        from google.colab import drive
    except Exception:
        print("‚ÑπÔ∏è Not running in Colab / no google.colab, skipping mount.")
        return config.data_root

    print("üîó Mounting Google Drive...")
    drive.mount("/content/drive")

    candidate_paths = [
        "/content/drive/MyDrive/multi30k-dataset",
        "/content/drive/MyDrive/dataset/multi30k-dataset",
        "/content/drive/MyDrive/Colab Notebooks/multi30k-dataset",
    ]

    dataset_path = None
    for p in candidate_paths:
        if os.path.exists(p):
            dataset_path = p
            print(f"‚úÖ Found dataset at: {p}")
            break

    if dataset_path is None:
        print("‚ùå Multi30K dataset not found in default locations. Using existing:", config.data_root)
        return config.data_root

    if os.path.islink("/content/multi30k-dataset") or os.path.exists("/content/multi30k-dataset"):
        os.system("rm -rf /content/multi30k-dataset")

    os.symlink(dataset_path, "/content/multi30k-dataset")
    print("üîó Symlink created ‚Üí /content/multi30k-dataset")
    return "/content/multi30k-dataset"

# ============================================================== #
# IMAGE LOADER (NO DIR LISTING)
# ============================================================== #

def safe_load_image(image_id: str, root: Path) -> Image.Image:
    """
    Loads one image by ID without listing directories.
    Multi30K image IDs in image_splits are usually like "1234567890.jpg" or "1234567890".
    We try: id, id.jpg, id.jpeg, id.png.
    """
    base = image_id.strip()
    for ext in [".jpg", ".jpeg", ".png"]:
        if base.endswith(ext):
            base = base[: -len(ext)]
            break

    candidates = [
        f"{base}.jpg",
        f"{base}.jpeg",
        f"{base}.png",
        base,
    ]

    for name in candidates:
        fp = root / name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except Exception:
                pass

    # Fallback: dummy gray image
    return Image.new("RGB", (224, 224), (128, 128, 128))

# ============================================================== #
# LORA HELPER
# ============================================================== #

def apply_lora_to_mbart(mbart: MBartForConditionalGeneration) -> MBartForConditionalGeneration:
    """
    Wraps mBART with LoRA on attention projections.
    """
    if not config.use_lora:
        print("‚ÑπÔ∏è LoRA disabled; training full mBART (heavier).")
        return mbart

    lora_cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_targets,
    )
    peft_model = get_peft_model(mbart, lora_cfg)
    print("‚úÖ LoRA applied to mBART (targets:", config.lora_targets, ")")
    peft_model.print_trainable_parameters()
    return peft_model

# ============================================================== #
# FUSION BLOCK (BETTER THAN PLAIN CONCAT)
# ============================================================== #

class FusionBlock(nn.Module):
    """
    Simple Transformer-based fusion over [IMG_TOKEN + TEXT_TOKENS].
    Lets the image token attend to text and vice versa.
    """
    def __init__(self, d_model: int, nhead: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=1)

    def forward(self, img_embed: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
        """
        img_embed: [B,1,d_model]
        text_embed: [B,L,d_model]
        returns fused: [B,1+L,d_model]
        """
        x = torch.cat([img_embed, text_embed], dim=1)  # [B,1+L,d]
        x = self.encoder(x)                            # fuse via self-attention
        return x

# ============================================================== #
# MULTIMODAL MODEL (SIGLIP + MBART + LORA + FUSION)
# ============================================================== #

class MultiModalModel(nn.Module):
    def __init__(self):
        super().__init__()

        # SigLIP vision encoder (vision-only)
        print(f"üîÑ Loading SigLIP vision model: {config.vision_model_name}")
        self.vision = SiglipVisionModel.from_pretrained(config.vision_model_name)

        # Freeze SigLIP to save memory & compute
        for p in self.vision.parameters():
            p.requires_grad = False

        # SigLIP vision hidden size
        vision_dim = self.vision.config.hidden_size
        print("üìê SigLIP vision hidden size:", vision_dim)

        # mBART-50 text model
        print("üîÑ Loading mBART-50 many-to-many...")
        base_mbart = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )

        # Apply LoRA on mBART
        self.mbart = apply_lora_to_mbart(base_mbart)

        # Shared text embeddings (LoRA-safe)
        self.text_emb = self.mbart.get_input_embeddings()

        # Project SigLIP CLS ‚Üí mBART hidden size
        self.proj = nn.Linear(vision_dim, self.mbart.config.d_model)

        # Fusion block
        self.fusion = FusionBlock(d_model=self.mbart.config.d_model, nhead=8, dim_ff=2048, dropout=0.1)

    def forward(self, input_ids, attention_mask, pixel_values, labels=None):
        batch_size = input_ids.size(0)

        # 1) SigLIP image features (CLS token)
        with torch.no_grad():
            vision_outputs = self.vision(pixel_values=pixel_values)
            img_feat = vision_outputs.last_hidden_state[:, 0, :]   # [B, hidden_dim]

        img_embed = self.proj(img_feat).unsqueeze(1)               # [B,1,d_model]

        # 2) Text embeddings from mBART shared embedding matrix
        text_embed = self.text_emb(input_ids)                      # [B,L,d_model]

        # 3) Transformer-based fusion
        fused = self.fusion(img_embed, text_embed)                 # [B,1+L,d_model]

        # 4) Attention mask (add image token)
        fused_mask = torch.cat(
            [torch.ones((batch_size, 1), device=device), attention_mask],
            dim=1,
        )

        # 5) mBART forward using inputs_embeds
        outputs = self.mbart(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            labels=labels,
            return_dict=True,
        )
        return outputs

    def generate(self, input_ids, attention_mask, pixel_values, tokenizer):
        batch_size = input_ids.size(0)

        with torch.no_grad():
            vision_outputs = self.vision(pixel_values=pixel_values)
            img_feat = vision_outputs.last_hidden_state[:, 0, :]

        img_embed = self.proj(img_feat).unsqueeze(1)
        text_embed = self.text_emb(input_ids)

        fused = self.fusion(img_embed, text_embed)
        fused_mask = torch.cat(
            [torch.ones((batch_size, 1), device=device), attention_mask],
            dim=1,
        )

        gen_ids = self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            max_length=config.max_length,
            num_beams=3,
            early_stopping=True,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen_ids

# ============================================================== #
# TEXT-ONLY MODEL (MBART + LORA)
# ============================================================== #

class TextOnlyModel(nn.Module):
    def __init__(self):
        super().__init__()
        print("üîÑ Loading text-only mBART-50 many-to-many...")
        base_mbart = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base_mbart)

    def forward(self, input_ids, attention_mask, labels=None):
        return self.mbart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    def generate(self, input_ids, attention_mask, tokenizer):
        gen_ids = self.mbart.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=config.max_length,
            num_beams=3,
            early_stopping=True,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen_ids

# ============================================================== #
# DATASETS
# ============================================================== #

class MultiModalDataset(Dataset):
    def __init__(self, image_ids, src, tgt, tokenizer, image_processor, img_root):
        self.ids = image_ids
        self.src = src
        self.tgt = tgt
        self.tok = tokenizer
        self.img_proc = image_processor
        self.img_root = img_root

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        src = self.src[idx]
        tgt = self.tgt[idx]

        enc = self.tok(
            src,
            max_length=config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        with self.tok.as_target_tokenizer():
            dec = self.tok(
                tgt,
                max_length=config.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

        labels = dec["input_ids"].squeeze()
        labels[labels == self.tok.pad_token_id] = -100

        img = safe_load_image(img_id, self.img_root)
        pv = self.img_proc(images=img, return_tensors="pt")["pixel_values"].squeeze()

        return {
            "input_ids": enc["input_ids"].squeeze(),
            "attention_mask": enc["attention_mask"].squeeze(),
            "labels": labels,
            "pixel_values": pv,
            "target_text": tgt,
        }

class TextOnlyDataset(Dataset):
    def __init__(self, src, tgt, tokenizer):
        self.src = src
        self.tgt = tgt
        self.tok = tokenizer

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        src = self.src[idx]
        tgt = self.tgt[idx]

        enc = self.tok(
            src,
            max_length=config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        with self.tok.as_target_tokenizer():
            dec = self.tok(
                tgt,
                max_length=config.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

        labels = dec["input_ids"].squeeze()
        labels[labels == self.tok.pad_token_id] = -100

        return {
            "input_ids": enc["input_ids"].squeeze(),
            "attention_mask": enc["attention_mask"].squeeze(),
            "labels": labels,
            "target_text": tgt,
        }

# ============================================================== #
# LOAD SPLIT
# ============================================================== #

def load_split(root, split, src_lang, tgt_lang, limit):
    """
    Expects:
      root/data/task1/raw/{split}/{split}.{lang}
      root/data/task1/image_splits/{split}.txt
    """
    root = Path(root)
    raw = root / "data" / "task1" / "raw" / split
    id_file = root / "data" / "task1" / "image_splits" / f"{split}.txt"

    src_file = raw / f"{split}.{src_lang}"
    tgt_file = raw / f"{split}.{tgt_lang}"

    print(f"üîé Checking files for {split} {src_lang}‚Üí{tgt_lang}")
    print("   ", src_file)
    print("   ", tgt_file)
    print("   ", id_file)

    if not src_file.exists() or not tgt_file.exists() or not id_file.exists():
        print(f"‚ùå Missing one or more files for {split} ({src_lang}‚Üí{tgt_lang})")
        return [], [], []

    ids = [l.strip() for l in open(id_file, encoding="utf-8") if l.strip()]
    src = [l.strip() for l in open(src_file, encoding="utf-8") if l.strip()]
    tgt = [l.strip() for l in open(tgt_file, encoding="utf-8") if l.strip()]

    n = min(len(ids), len(src), len(tgt), limit)
    print(f"‚úÖ Loaded {n} samples ({split}: {src_lang}‚Üí{tgt_lang})")
    return ids[:n], src[:n], tgt[:n]

# ============================================================== #
# TRAINING + EVAL HELPERS
# ============================================================== #

def compute_bleu_multimodal(model, loader, tokenizer):
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            pv = batch["pixel_values"].to(device)
            tgt = batch["target_text"]

            gen_ids = model.generate(ids, mask, pv, tokenizer)
            decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            preds.extend(decoded)
            refs.extend([[t] for t in tgt])

    bleu = sacrebleu_metric.compute(predictions=preds, references=refs)["score"]
    print(f"   üîµ Multimodal BLEU: {bleu:.2f}")
    return bleu

def compute_bleu_text(model, loader, tokenizer):
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            tgt = batch["target_text"]

            gen_ids = model.generate(ids, mask, tokenizer)
            decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            preds.extend(decoded)
            refs.extend([[t] for t in tgt])

    bleu = sacrebleu_metric.compute(predictions=preds, references=refs)["score"]
    print(f"   üîµ Text-only BLEU: {bleu:.2f}")
    return bleu

# ------------------ TRAIN MULTIMODAL ------------------ #

def train_multimodal_model(src_lang, tgt_lang, tokenizer, train_ds, val_ds):
    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size)

    model = MultiModalModel().to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=config.learning_rate)

    total_steps = len(train_loader) * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=max(total_steps, 1),
    )

    scaler = torch.cuda.amp.GradScaler() if config.use_amp and device.type == "cuda" else None
    best_bleu = 0.0
    no_improve = 0

    for epoch in range(1, config.num_epochs + 1):
        print(f"\nüìç [MULTIMODAL] Epoch {epoch}/{config.num_epochs} ‚Äî {src_lang}‚Üí{tgt_lang}")
        model.train()
        total_loss = 0.0

        loop = tqdm(train_loader, desc=f"[MM Train {src_lang}->{tgt_lang}]")
        for batch in loop:
            opt.zero_grad()

            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            lbl = batch["labels"].to(device)
            pv = batch["pixel_values"].to(device)

            try:
                if scaler:
                    with torch.cuda.amp.autocast():
                        out = model(ids, mask, pv, labels=lbl)
                        loss = out.loss
                    scaler.scale(loss).backward()
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    scaler.step(opt)
                    scaler.update()
                else:
                    out = model(ids, mask, pv, labels=lbl)
                    loss = out.loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    opt.step()
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print("‚ö†Ô∏è CUDA OOM on this batch, skipping.")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise e

            scheduler.step()
            total_loss += float(loss)
            loop.set_postfix(loss=float(loss))

        if len(train_loader) == 0:
            print("‚ö†Ô∏è No batches in train_loader (multimodal)")
            break

        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"   üîª Multimodal avg train loss: {avg_loss:.4f}")

        print("   üîç Evaluating multimodal on validation...")
        bleu = compute_bleu_multimodal(model, val_loader, tokenizer)

        improved = bleu > best_bleu + config.min_delta
        if improved:
            best_bleu = bleu
            no_improve = 0
            save_dir = Path(config.save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
            save_path = save_dir / f"siglip_fusion_lora_{src_lang}_{tgt_lang}_mm_best.pt"
            torch.save(model.state_dict(), save_path)
            print(f"   üíæ Saved best MULTIMODAL model ‚Üí {save_path}")
        else:
            no_improve += 1

        if no_improve >= config.patience:
            print(f"üõë Early stopping MULTIMODAL {src_lang}‚Üí{tgt_lang} at epoch {epoch}")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"‚úÖ Finished MULTIMODAL training {src_lang}‚Üí{tgt_lang} | Best BLEU: {best_bleu:.2f}")
    return best_bleu

# ------------------ TRAIN TEXT-ONLY ------------------ #

def train_text_model(src_lang, tgt_lang, tokenizer, train_ds, val_ds):
    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size)

    model = TextOnlyModel().to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=config.learning_rate)

    total_steps = len(train_loader) * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=max(total_steps, 1),
    )

    scaler = torch.cuda.amp.GradScaler() if config.use_amp and device.type == "cuda" else None
    best_bleu = 0.0
    no_improve = 0

    for epoch in range(1, config.num_epochs + 1):
        print(f"\nüìç [TEXT-ONLY] Epoch {epoch}/{config.num_epochs} ‚Äî {src_lang}‚Üí{tgt_lang}")
        model.train()
        total_loss = 0.0

        loop = tqdm(train_loader, desc=f"[TXT Train {src_lang}->{tgt_lang}]")
        for batch in loop:
            opt.zero_grad()

            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            lbl = batch["labels"].to(device)

            try:
                if scaler:
                    with torch.cuda.amp.autocast():
                        out = model(ids, mask, labels=lbl)
                        loss = out.loss
                    scaler.scale(loss).backward()
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    scaler.step(opt)
                    scaler.update()
                else:
                    out = model(ids, mask, labels=lbl)
                    loss = out.loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    opt.step()
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print("‚ö†Ô∏è CUDA OOM on this batch, skipping.")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise e

            scheduler.step()
            total_loss += float(loss)
            loop.set_postfix(loss=float(loss))

        if len(train_loader) == 0:
            print("‚ö†Ô∏è No batches in train_loader (text-only)")
            break

        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"   üîª Text-only avg train loss: {avg_loss:.4f}")

        print("   üîç Evaluating TEXT-ONLY on validation...")
        bleu = compute_bleu_text(model, val_loader, tokenizer)

        improved = bleu > best_bleu + config.min_delta
        if improved:
            best_bleu = bleu
            no_improve = 0
            save_dir = Path(config.save_dir)
            save_dir.mkdir(parents=True, exist_ok=True)
            save_path = save_dir / f"mbart_lora_{src_lang}_{tgt_lang}_text_best.pt"
            torch.save(model.state_dict(), save_path)
            print(f"   üíæ Saved best TEXT-ONLY model ‚Üí {save_path}")
        else:
            no_improve += 1

        if no_improve >= config.patience:
            print(f"üõë Early stopping TEXT-ONLY {src_lang}‚Üí{tgt_lang} at epoch {epoch}")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"‚úÖ Finished TEXT-ONLY training {src_lang}‚Üí{tgt_lang} | Best BLEU: {best_bleu:.2f}")
    return best_bleu

# ============================================================== #
# MAIN
# ============================================================== #

def main():
    # Optional: remap data_root via Drive symlink
    if os.path.exists("/content/drive"):
        config.data_root = mount_and_link_dataset()

    os.makedirs(config.save_dir, exist_ok=True)

    print("üîÑ Loading MBart tokenizer & SigLIP processor...")
    tokenizer = MBart50TokenizerFast.from_pretrained(
        "facebook/mbart-large-50-many-to-many-mmt"
    )
    image_processor = SiglipProcessor.from_pretrained(config.vision_model_name)

    # Save config
    cfg_path = Path(config.save_dir) / "config_siglip_fusion_lora.json"
    with open(cfg_path, "w") as f:
        json.dump(asdict(config), f, indent=2)
    print(f"üíæ Config saved at: {cfg_path}")

    results_multimodal = {}
    results_textonly = {}

    for src, tgt in config.directions:
        print("\n======================================================================")
        print(f"üèÅ LANGUAGE PAIR: {src.upper()} ‚Üí {tgt.upper()}")
        print("======================================================================")

        tokenizer.src_lang = LANG_CODES[src]
        tokenizer.tgt_lang = LANG_CODES[tgt]

        train_ids, train_src, train_tgt = load_split(
            config.data_root, "train", src, tgt, config.max_train_samples
        )
        val_ids, val_src, val_tgt = load_split(
            config.data_root, "val", src, tgt, config.max_val_samples
        )

        if len(train_ids) == 0:
            print(f"‚ö†Ô∏è Skipping {src}‚Üí{tgt} (no data)")
            continue

        img_root = Path(config.data_root) / config.image_dir

        # Datasets
        train_mm = MultiModalDataset(
            train_ids, train_src, train_tgt,
            tokenizer, image_processor, img_root
        )
        val_mm = MultiModalDataset(
            val_ids, val_src, val_tgt,
            tokenizer, image_processor, img_root
        )

        train_txt = TextOnlyDataset(train_src, train_tgt, tokenizer)
        val_txt = TextOnlyDataset(val_src, val_tgt, tokenizer)

        # ----- Train MULTIMODAL -----
        mm_bleu = train_multimodal_model(src, tgt, tokenizer, train_mm, val_mm)
        results_multimodal[f"{src}_{tgt}"] = mm_bleu

        # ----- Train TEXT-ONLY -----
        txt_bleu = train_text_model(src, tgt, tokenizer, train_txt, val_txt)
        results_textonly[f"{src}_{tgt}"] = txt_bleu

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\nüìä FINAL BLEU SCORES (MULTIMODAL):")
    for k, v in results_multimodal.items():
        print(f"  {k}: {v:.2f}")

    print("\nüìä FINAL BLEU SCORES (TEXT-ONLY):")
    for k, v in results_textonly.items():
        print(f"  {k}: {v:.2f}")

if __name__ == "__main__":
    main()


Using device: cuda
üîó Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
‚úÖ Found dataset at: /content/drive/MyDrive/dataset/multi30k-dataset
üîó Symlink created ‚Üí /content/multi30k-dataset
üîÑ Loading MBart tokenizer & SigLIP processor...
üíæ Config saved at: /content/multimodal_translation_models_siglip_lora_fusion/config_siglip_fusion_lora.json

üèÅ LANGUAGE PAIR: EN ‚Üí DE
üîé Checking files for train en‚Üíde
    /content/multi30k-dataset/data/task1/raw/train/train.en
    /content/multi30k-dataset/data/task1/raw/train/train.de
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: en‚Üíde)
üîé Checking files for val en‚Üíde
    /content/multi30k-dataset/data/task1/raw/val/val.en
    /content/multi30k-dataset/data/task1/raw/val/val.de
    /content/multi30k-dataset/data/task1/image_splits/val.txt
‚úÖ Loaded 1000 samples (val: en‚Ü

[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [3:22:38<00:00,  1.62s/it, loss=0.7]


   üîª Multimodal avg train loss: 0.9370
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 40.50
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî en‚Üíde


[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:52<00:00,  6.29it/s, loss=0.762]


   üîª Multimodal avg train loss: 0.8312
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 41.21
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî en‚Üíde


[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:52<00:00,  6.29it/s, loss=1]


   üîª Multimodal avg train loss: 0.7771
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 42.42
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt

üìç [MULTIMODAL] Epoch 4/6 ‚Äî en‚Üíde


[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:53<00:00,  6.28it/s, loss=1.4]


   üîª Multimodal avg train loss: 0.7317
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 42.24

üìç [MULTIMODAL] Epoch 5/6 ‚Äî en‚Üíde


[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:55<00:00,  6.27it/s, loss=0.207]


   üîª Multimodal avg train loss: 0.6974
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 42.21

üìç [MULTIMODAL] Epoch 6/6 ‚Äî en‚Üíde


[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:01<00:00,  6.24it/s, loss=0.551]


   üîª Multimodal avg train loss: 0.6717
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 42.29
üõë Early stopping MULTIMODAL en‚Üíde at epoch 6
‚úÖ Finished MULTIMODAL training en‚Üíde | Best BLEU: 42.42
üîÑ Loading text-only mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç [TEXT-ONLY] Epoch 1/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:50<00:00,  8.42it/s, loss=0.62]


   üîª Text-only avg train loss: 0.9554
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 39.27
   üíæ Saved best TEXT-ONLY model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_de_text_best.pt

üìç [TEXT-ONLY] Epoch 2/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:51<00:00,  8.41it/s, loss=1.41]


   üîª Text-only avg train loss: 0.8746
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 39.82
   üíæ Saved best TEXT-ONLY model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_de_text_best.pt

üìç [TEXT-ONLY] Epoch 3/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:49<00:00,  8.43it/s, loss=1.18]


   üîª Text-only avg train loss: 0.8433
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 40.31

üìç [TEXT-ONLY] Epoch 4/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:51<00:00,  8.42it/s, loss=0.784]


   üîª Text-only avg train loss: 0.8187
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 40.86
   üíæ Saved best TEXT-ONLY model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_de_text_best.pt

üìç [TEXT-ONLY] Epoch 5/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:47<00:00,  8.45it/s, loss=1.2]


   üîª Text-only avg train loss: 0.8049
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 40.74

üìç [TEXT-ONLY] Epoch 6/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:47<00:00,  8.45it/s, loss=0.534]


   üîª Text-only avg train loss: 0.7961
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 40.99
‚úÖ Finished TEXT-ONLY training en‚Üíde | Best BLEU: 40.86

üèÅ LANGUAGE PAIR: EN ‚Üí FR
üîé Checking files for train en‚Üífr
    /content/multi30k-dataset/data/task1/raw/train/train.en
    /content/multi30k-dataset/data/task1/raw/train/train.fr
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: en‚Üífr)
üîé Checking files for val en‚Üífr
    /content/multi30k-dataset/data/task1/raw/val/val.en
    /content/multi30k-dataset/data/task1/raw/val/val.fr
    /content/multi30k-dataset/data/task1/image_splits/val.txt
‚úÖ Loaded 1000 samples (val: en‚Üífr)
üîÑ Loading SigLIP vision model: google/siglip-base-patch16-224
üìê SigLIP vision hidden size: 768
üîÑ Loading mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç

[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:59<00:00,  6.25it/s, loss=0.675]


   üîª Multimodal avg train loss: 0.8600
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 51.84
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî en‚Üífr


[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:09<00:00,  6.20it/s, loss=1.16]


   üîª Multimodal avg train loss: 0.7128
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 53.28
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî en‚Üífr


[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:20<00:00,  6.14it/s, loss=0.328]


   üîª Multimodal avg train loss: 0.6471
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 55.12
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 4/6 ‚Äî en‚Üífr


[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:18<00:00,  6.15it/s, loss=0.393]


   üîª Multimodal avg train loss: 0.6008
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 55.60

üìç [MULTIMODAL] Epoch 5/6 ‚Äî en‚Üífr


[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:20<00:00,  6.14it/s, loss=0.409]


   üîª Multimodal avg train loss: 0.5661
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 56.18
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 6/6 ‚Äî en‚Üífr


[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:20<00:00,  6.15it/s, loss=0.254]


   üîª Multimodal avg train loss: 0.5416
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 56.59
‚úÖ Finished MULTIMODAL training en‚Üífr | Best BLEU: 56.18
üîÑ Loading text-only mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç [TEXT-ONLY] Epoch 1/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [15:01<00:00,  8.32it/s, loss=1.15]


   üîª Text-only avg train loss: 0.9008
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 48.44
   üíæ Saved best TEXT-ONLY model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 2/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [15:02<00:00,  8.31it/s, loss=0.896]


   üîª Text-only avg train loss: 0.7781
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 50.03
   üíæ Saved best TEXT-ONLY model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 3/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [15:04<00:00,  8.30it/s, loss=0.97]


   üîª Text-only avg train loss: 0.7376
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 51.69
   üíæ Saved best TEXT-ONLY model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 4/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [15:00<00:00,  8.33it/s, loss=0.66]


   üîª Text-only avg train loss: 0.7110
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 52.38
   üíæ Saved best TEXT-ONLY model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 5/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [15:02<00:00,  8.31it/s, loss=0.679]


   üîª Text-only avg train loss: 0.6925
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 52.84

üìç [TEXT-ONLY] Epoch 6/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:59<00:00,  8.34it/s, loss=0.194]


   üîª Text-only avg train loss: 0.6820
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 52.86
‚úÖ Finished TEXT-ONLY training en‚Üífr | Best BLEU: 52.38

üèÅ LANGUAGE PAIR: DE ‚Üí EN
üîé Checking files for train de‚Üíen
    /content/multi30k-dataset/data/task1/raw/train/train.de
    /content/multi30k-dataset/data/task1/raw/train/train.en
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: de‚Üíen)
üîé Checking files for val de‚Üíen
    /content/multi30k-dataset/data/task1/raw/val/val.de
    /content/multi30k-dataset/data/task1/raw/val/val.en
    /content/multi30k-dataset/data/task1/image_splits/val.txt
‚úÖ Loaded 1000 samples (val: de‚Üíen)
üîÑ Loading SigLIP vision model: google/siglip-base-patch16-224
üìê SigLIP vision hidden size: 768
üîÑ Loading mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç

[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:12<00:00,  6.18it/s, loss=0.609]


   üîª Multimodal avg train loss: 0.8808
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 44.66
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî de‚Üíen


[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:11<00:00,  6.19it/s, loss=1.06]


   üîª Multimodal avg train loss: 0.7947
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 46.41
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî de‚Üíen


[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:07<00:00,  6.21it/s, loss=0.716]


   üîª Multimodal avg train loss: 0.7447
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 46.24

üìç [MULTIMODAL] Epoch 4/6 ‚Äî de‚Üíen


[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:59<00:00,  6.25it/s, loss=0.99]


   üîª Multimodal avg train loss: 0.7030
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 46.45

üìç [MULTIMODAL] Epoch 5/6 ‚Äî de‚Üíen


[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:54<00:00,  6.28it/s, loss=0.674]


   üîª Multimodal avg train loss: 0.6704
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 46.90
üõë Early stopping MULTIMODAL de‚Üíen at epoch 5
‚úÖ Finished MULTIMODAL training de‚Üíen | Best BLEU: 46.41
üîÑ Loading text-only mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç [TEXT-ONLY] Epoch 1/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [15:02<00:00,  8.31it/s, loss=0.363]


   üîª Text-only avg train loss: 0.8824
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 44.43
   üíæ Saved best TEXT-ONLY model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_en_text_best.pt

üìç [TEXT-ONLY] Epoch 2/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:45<00:00,  8.47it/s, loss=1.8]


   üîª Text-only avg train loss: 0.8264
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 45.25
   üíæ Saved best TEXT-ONLY model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_en_text_best.pt

üìç [TEXT-ONLY] Epoch 3/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [15:00<00:00,  8.33it/s, loss=0.391]


   üîª Text-only avg train loss: 0.8019
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 45.64

üìç [TEXT-ONLY] Epoch 4/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:50<00:00,  8.42it/s, loss=0.466]


   üîª Text-only avg train loss: 0.7843
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 45.78
   üíæ Saved best TEXT-ONLY model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_en_text_best.pt

üìç [TEXT-ONLY] Epoch 5/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:44<00:00,  8.48it/s, loss=1.07]


   üîª Text-only avg train loss: 0.7741
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 45.76

üìç [TEXT-ONLY] Epoch 6/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:56<00:00,  8.37it/s, loss=0.269]


   üîª Text-only avg train loss: 0.7661
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 46.09
‚úÖ Finished TEXT-ONLY training de‚Üíen | Best BLEU: 45.78

üèÅ LANGUAGE PAIR: DE ‚Üí FR
üîé Checking files for train de‚Üífr
    /content/multi30k-dataset/data/task1/raw/train/train.de
    /content/multi30k-dataset/data/task1/raw/train/train.fr
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: de‚Üífr)
üîé Checking files for val de‚Üífr
    /content/multi30k-dataset/data/task1/raw/val/val.de
    /content/multi30k-dataset/data/task1/raw/val/val.fr
    /content/multi30k-dataset/data/task1/image_splits/val.txt
‚úÖ Loaded 1000 samples (val: de‚Üífr)
üîÑ Loading SigLIP vision model: google/siglip-base-patch16-224
üìê SigLIP vision hidden size: 768
üîÑ Loading mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç

[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:00<00:00,  6.25it/s, loss=1.04]


   üîª Multimodal avg train loss: 1.2171
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 35.63
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî de‚Üífr


[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:57<00:00,  6.26it/s, loss=0.875]


   üîª Multimodal avg train loss: 1.0412
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 39.27
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî de‚Üífr


[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:59<00:00,  6.25it/s, loss=1.04]


   üîª Multimodal avg train loss: 0.9676
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 40.63
   üíæ Saved best MULTIMODAL model ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 4/6 ‚Äî de‚Üífr


[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:02<00:00,  6.24it/s, loss=0.404]


   üîª Multimodal avg train loss: 0.9151
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 41.02

üìç [MULTIMODAL] Epoch 5/6 ‚Äî de‚Üífr


[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:01<00:00,  6.24it/s, loss=0.484]


   üîª Multimodal avg train loss: 0.8727
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 41.01

üìç [MULTIMODAL] Epoch 6/6 ‚Äî de‚Üífr


[MM Train de->fr]:  68%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä   | 5083/7500 [13:32<06:34,  6.13it/s, loss=0.662]

In [None]:
# ==============================================================
# üåç MULTIMODAL TRANSLATION + SIGLIP + LORA + FUSION vs TEXT-ONLY
#  - Multi30K (data/task1/raw + image_splits)
#  - SigLIP vision encoder (google/siglip-base-patch16-224)
#  - mBART-50 text model with LoRA on attention (q_proj, v_proj)
#  - Better fusion: Transformer-based fusion over [IMG + TEXT]
#  - Also trains text-only mBART+LoRA baseline for comparison
#  - Saves models in BOTH Colab and Google Drive
# ==============================================================

import os
import json
from pathlib import Path
from typing import List, Tuple
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageFile
from tqdm import tqdm
import evaluate
import warnings

warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True

# ------------------ HF + PEFT imports ------------------
try:
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
        get_linear_schedule_with_warmup,
    )
    from peft import LoraConfig, get_peft_model, TaskType
except ImportError as e:
    print(f"‚ùå Import error: {e}")
    os.system("pip install -q transformers peft accelerate")
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
        get_linear_schedule_with_warmup,
    )
    from peft import LoraConfig, get_peft_model, TaskType

# ------------------ DEVICE ------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True

# ============================================================== #
# CONFIG
# ============================================================== #

@dataclass
class Config:
    # Paths
    data_root: str = "/content/multi30k-dataset"
    image_dir: str = "flickr30k-images"

    # Local Colab save dir
    save_dir: str = "/content/multimodal_translation_models_siglip_lora_fusion"
    # Drive save dir (we'll ensure it exists if Drive is mounted)
    drive_save_dir: str = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"

    # Training
    max_length: int = 64
    batch_size: int = 2          # small for VRAM safety
    learning_rate: float = 3e-5
    num_epochs: int = 6
    patience: int = 3
    min_delta: float = 0.5       # BLEU improvement threshold to reset patience
    use_amp: bool = True

    # Data limits
    max_train_samples: int = 15000
    max_val_samples: int = 1000

    # Optim
    warmup_steps: int = 100
    max_grad_norm: float = 1.0

    # Vision model (SigLIP)
    vision_model_name: str = "google/siglip-base-patch16-224"

    # LoRA
    use_lora: bool = True
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.1
    lora_targets: List[str] = None

    # Language directions
    directions: List[Tuple[str, str]] = None

    def __post_init__(self):
        if self.lora_targets is None:
            # Attn projections in mBART encoder+decoder
            self.lora_targets = ["q_proj", "v_proj"]
        if self.directions is None:
            self.directions = [
                ("en", "de"),
                ("en", "fr"),
                ("de", "en"),
                ("de", "fr"),
                ("fr", "en"),
                ("fr", "de"),
            ]

config = Config()

LANG_CODES = {"en": "en_XX", "de": "de_DE", "fr": "fr_XX"}

# Global metric (avoid re-loading each epoch)
sacrebleu_metric = evaluate.load("sacrebleu")

# ============================================================== #
# OPTIONAL: DRIVE MOUNT + SYMLINK (if you want)
# ============================================================== #

def mount_and_link_dataset():
    """
    Mounts Google Drive and links /content/multi30k-dataset to your folder.
    Safe: no directory scanning, just checks existence.
    """
    try:
        from google.colab import drive
    except Exception:
        print("‚ÑπÔ∏è Not running in Colab / no google.colab, skipping mount.")
        return config.data_root

    print("üîó Mounting Google Drive...")
    drive.mount("/content/drive")

    candidate_paths = [
        "/content/drive/MyDrive/multi30k-dataset",
        "/content/drive/MyDrive/dataset/multi30k-dataset",
        "/content/drive/MyDrive/Colab Notebooks/multi30k-dataset",
    ]

    dataset_path = None
    for p in candidate_paths:
        if os.path.exists(p):
            dataset_path = p
            print(f"‚úÖ Found dataset at: {p}")
            break

    if dataset_path is None:
        print("‚ùå Multi30K dataset not found in default locations. Using existing:", config.data_root)
        return config.data_root

    if os.path.islink("/content/multi30k-dataset") or os.path.exists("/content/multi30k-dataset"):
        os.system("rm -rf /content/multi30k-dataset")

    os.symlink(dataset_path, "/content/multi30k-dataset")
    print("üîó Symlink created ‚Üí /content/multi30k-dataset")
    return "/content/multi30k-dataset"

# ============================================================== #
# IMAGE LOADER (NO DIR LISTING)
# ============================================================== #

def safe_load_image(image_id: str, root: Path) -> Image.Image:
    """
    Loads one image by ID without listing directories.
    Multi30K image IDs in image_splits are usually like "1234567890.jpg" or "1234567890".
    We try: id, id.jpg, id.jpeg, id.png.
    """
    base = image_id.strip()
    for ext in [".jpg", ".jpeg", ".png"]:
        if base.endswith(ext):
            base = base[: -len(ext)]
            break

    candidates = [
        f"{base}.jpg",
        f"{base}.jpeg",
        f"{base}.png",
        base,
    ]

    for name in candidates:
        fp = root / name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except Exception:
                pass

    # Fallback: dummy gray image
    return Image.new("RGB", (224, 224), (128, 128, 128))

# ============================================================== #
# LORA HELPER
# ============================================================== #

def apply_lora_to_mbart(mbart: MBartForConditionalGeneration) -> MBartForConditionalGeneration:
    """
    Wraps mBART with LoRA on attention projections.
    """
    if not config.use_lora:
        print("‚ÑπÔ∏è LoRA disabled; training full mBART (heavier).")
        return mbart

    lora_cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_targets,
    )
    peft_model = get_peft_model(mbart, lora_cfg)
    print("‚úÖ LoRA applied to mBART (targets:", config.lora_targets, ")")
    peft_model.print_trainable_parameters()
    return peft_model

# ============================================================== #
# FUSION BLOCK (BETTER THAN PLAIN CONCAT)
# ============================================================== #

class FusionBlock(nn.Module):
    """
    Simple Transformer-based fusion over [IMG_TOKEN + TEXT_TOKENS].
    Lets the image token attend to text and vice versa.
    """
    def __init__(self, d_model: int, nhead: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=1)

    def forward(self, img_embed: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
        """
        img_embed: [B,1,d_model]
        text_embed: [B,L,d_model]
        returns fused: [B,1+L,d_model]
        """
        x = torch.cat([img_embed, text_embed], dim=1)  # [B,1+L,d]
        x = self.encoder(x)                            # fuse via self-attention
        return x

# ============================================================== #
# MULTIMODAL MODEL (SIGLIP + MBART + LORA + FUSION)
# ============================================================== #

class MultiModalModel(nn.Module):
    def __init__(self):
        super().__init__()

        # SigLIP vision encoder (vision-only)
        print(f"üîÑ Loading SigLIP vision model: {config.vision_model_name}")
        self.vision = SiglipVisionModel.from_pretrained(config.vision_model_name)

        # Freeze SigLIP to save memory & compute
        for p in self.vision.parameters():
            p.requires_grad = False

        # SigLIP vision hidden size
        vision_dim = self.vision.config.hidden_size
        print("üìê SigLIP vision hidden size:", vision_dim)

        # mBART-50 text model
        print("üîÑ Loading mBART-50 many-to-many...")
        base_mbart = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )

        # Apply LoRA on mBART
        self.mbart = apply_lora_to_mbart(base_mbart)

        # Shared text embeddings (LoRA-safe)
        self.text_emb = self.mbart.get_input_embeddings()

        # Project SigLIP CLS ‚Üí mBART hidden size
        self.proj = nn.Linear(vision_dim, self.mbart.config.d_model)

        # Fusion block
        self.fusion = FusionBlock(d_model=self.mbart.config.d_model, nhead=8, dim_ff=2048, dropout=0.1)

    def forward(self, input_ids, attention_mask, pixel_values, labels=None):
        batch_size = input_ids.size(0)

        # 1) SigLIP image features (CLS token)
        with torch.no_grad():
            vision_outputs = self.vision(pixel_values=pixel_values)
            img_feat = vision_outputs.last_hidden_state[:, 0, :]   # [B, hidden_dim]

        img_embed = self.proj(img_feat).unsqueeze(1)               # [B,1,d_model]

        # 2) Text embeddings from mBART shared embedding matrix
        text_embed = self.text_emb(input_ids)                      # [B,L,d_model]

        # 3) Transformer-based fusion
        fused = self.fusion(img_embed, text_embed)                 # [B,1+L,d_model]

        # 4) Attention mask (add image token)
        fused_mask = torch.cat(
            [torch.ones((batch_size, 1), device=device), attention_mask],
            dim=1,
        )

        # 5) mBART forward using inputs_embeds
        outputs = self.mbart(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            labels=labels,
            return_dict=True,
        )
        return outputs

    def generate(self, input_ids, attention_mask, pixel_values, tokenizer):
        batch_size = input_ids.size(0)

        with torch.no_grad():
            vision_outputs = self.vision(pixel_values=pixel_values)
            img_feat = vision_outputs.last_hidden_state[:, 0, :]

        img_embed = self.proj(img_feat).unsqueeze(1)
        text_embed = self.text_emb(input_ids)

        fused = self.fusion(img_embed, text_embed)
        fused_mask = torch.cat(
            [torch.ones((batch_size, 1), device=device), attention_mask],
            dim=1,
        )

        gen_ids = self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            max_length=config.max_length,
            num_beams=3,
            early_stopping=True,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen_ids

# ============================================================== #
# TEXT-ONLY MODEL (MBART + LORA)
# ============================================================== #

class TextOnlyModel(nn.Module):
    def __init__(self):
        super().__init__()
        print("üîÑ Loading text-only mBART-50 many-to-many...")
        base_mbart = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base_mbart)

    def forward(self, input_ids, attention_mask, labels=None):
        return self.mbart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    def generate(self, input_ids, attention_mask, tokenizer):
        gen_ids = self.mbart.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=config.max_length,
            num_beams=3,
            early_stopping=True,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen_ids

# ============================================================== #
# DATASETS
# ============================================================== #

class MultiModalDataset(Dataset):
    def __init__(self, image_ids, src, tgt, tokenizer, image_processor, img_root):
        self.ids = image_ids
        self.src = src
        self.tgt = tgt
        self.tok = tokenizer
        self.img_proc = image_processor
        self.img_root = img_root

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        src = self.src[idx]
        tgt = self.tgt[idx]

        enc = self.tok(
            src,
            max_length=config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        with self.tok.as_target_tokenizer():
            dec = self.tok(
                tgt,
                max_length=config.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

        labels = dec["input_ids"].squeeze()
        labels[labels == self.tok.pad_token_id] = -100

        img = safe_load_image(img_id, self.img_root)
        pv = self.img_proc(images=img, return_tensors="pt")["pixel_values"].squeeze()

        return {
            "input_ids": enc["input_ids"].squeeze(),
            "attention_mask": enc["attention_mask"].squeeze(),
            "labels": labels,
            "pixel_values": pv,
            "target_text": tgt,
        }

class TextOnlyDataset(Dataset):
    def __init__(self, src, tgt, tokenizer):
        self.src = src
        self.tgt = tgt
        self.tok = tokenizer

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        src = self.src[idx]
        tgt = self.tgt[idx]

        enc = self.tok(
            src,
            max_length=config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        with self.tok.as_target_tokenizer():
            dec = self.tok(
                tgt,
                max_length=config.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

        labels = dec["input_ids"].squeeze()
        labels[labels == self.tok.pad_token_id] = -100

        return {
            "input_ids": enc["input_ids"].squeeze(),
            "attention_mask": enc["attention_mask"].squeeze(),
            "labels": labels,
            "target_text": tgt,
        }

# ============================================================== #
# LOAD SPLIT
# ============================================================== #

def load_split(root, split, src_lang, tgt_lang, limit):
    """
    Expects:
      root/data/task1/raw/{split}/{split}.{lang}
      root/data/task1/image_splits/{split}.txt
    """
    root = Path(root)
    raw = root / "data" / "task1" / "raw" / split
    id_file = root / "data" / "task1" / "image_splits" / f"{split}.txt"

    src_file = raw / f"{split}.{src_lang}"
    tgt_file = raw / f"{split}.{tgt_lang}"

    print(f"üîé Checking files for {split} {src_lang}‚Üí{tgt_lang}")
    print("   ", src_file)
    print("   ", tgt_file)
    print("   ", id_file)

    if not src_file.exists() or not tgt_file.exists() or not id_file.exists():
        print(f"‚ùå Missing one or more files for {split} ({src_lang}‚Üí{tgt_lang})")
        return [], [], []

    ids = [l.strip() for l in open(id_file, encoding="utf-8") if l.strip()]
    src = [l.strip() for l in open(src_file, encoding="utf-8") if l.strip()]
    tgt = [l.strip() for l in open(tgt_file, encoding="utf-8") if l.strip()]

    n = min(len(ids), len(src), len(tgt), limit)
    print(f"‚úÖ Loaded {n} samples ({split}: {src_lang}‚Üí{tgt_lang})")
    return ids[:n], src[:n], tgt[:n]

# ============================================================== #
# TRAINING + EVAL HELPERS
# ============================================================== #

def compute_bleu_multimodal(model, loader, tokenizer):
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            pv = batch["pixel_values"].to(device)
            tgt = batch["target_text"]

            gen_ids = model.generate(ids, mask, pv, tokenizer)
            decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            preds.extend(decoded)
            refs.extend([[t] for t in tgt])

    bleu = sacrebleu_metric.compute(predictions=preds, references=refs)["score"]
    print(f"   üîµ Multimodal BLEU: {bleu:.2f}")
    return bleu

def compute_bleu_text(model, loader, tokenizer):
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            tgt = batch["target_text"]

            gen_ids = model.generate(ids, mask, tokenizer)
            decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            preds.extend(decoded)
            refs.extend([[t] for t in tgt])

    bleu = sacrebleu_metric.compute(predictions=preds, references=refs)["score"]
    print(f"   üîµ Text-only BLEU: {bleu:.2f}")
    return bleu

# ------------------ TRAIN MULTIMODAL ------------------ #

def train_multimodal_model(src_lang, tgt_lang, tokenizer, train_ds, val_ds,
                           local_save_dir: Path, drive_save_dir: Path | None):
    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size)

    model = MultiModalModel().to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=config.learning_rate)

    total_steps = len(train_loader) * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=max(total_steps, 1),
    )

    scaler = torch.cuda.amp.GradScaler() if config.use_amp and device.type == "cuda" else None
    best_bleu = 0.0
    no_improve = 0

    for epoch in range(1, config.num_epochs + 1):
        print(f"\nüìç [MULTIMODAL] Epoch {epoch}/{config.num_epochs} ‚Äî {src_lang}‚Üí{tgt_lang}")
        model.train()
        total_loss = 0.0

        loop = tqdm(train_loader, desc=f"[MM Train {src_lang}->{tgt_lang}]")
        for batch in loop:
            opt.zero_grad()

            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            lbl = batch["labels"].to(device)
            pv = batch["pixel_values"].to(device)

            try:
                if scaler:
                    with torch.cuda.amp.autocast():
                        out = model(ids, mask, pv, labels=lbl)
                        loss = out.loss
                    scaler.scale(loss).backward()
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    scaler.step(opt)
                    scaler.update()
                else:
                    out = model(ids, mask, pv, labels=lbl)
                    loss = out.loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    opt.step()
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print("‚ö†Ô∏è CUDA OOM on this batch, skipping.")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise e

            scheduler.step()
            total_loss += float(loss)
            loop.set_postfix(loss=float(loss))

        if len(train_loader) == 0:
            print("‚ö†Ô∏è No batches in train_loader (multimodal)")
            break

        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"   üîª Multimodal avg train loss: {avg_loss:.4f}")

        print("   üîç Evaluating multimodal on validation...")
        bleu = compute_bleu_multimodal(model, val_loader, tokenizer)

        improved = bleu > best_bleu + config.min_delta
        if improved:
            best_bleu = bleu
            no_improve = 0

            filename = f"siglip_fusion_lora_{src_lang}_{tgt_lang}_mm_best.pt"
            local_path = local_save_dir / filename
            torch.save(model.state_dict(), local_path)
            print(f"   üíæ Saved best MULTIMODAL model (local) ‚Üí {local_path}")

            if drive_save_dir is not None:
                drive_path = drive_save_dir / filename
                torch.save(model.state_dict(), drive_path)
                print(f"   üíæ Saved best MULTIMODAL model (drive) ‚Üí {drive_path}")

        else:
            no_improve += 1

        if no_improve >= config.patience:
            print(f"üõë Early stopping MULTIMODAL {src_lang}‚Üí{tgt_lang} at epoch {epoch}")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"‚úÖ Finished MULTIMODAL training {src_lang}‚Üí{tgt_lang} | Best BLEU: {best_bleu:.2f}")
    return best_bleu

# ------------------ TRAIN TEXT-ONLY ------------------ #

def train_text_model(src_lang, tgt_lang, tokenizer, train_ds, val_ds,
                     local_save_dir: Path, drive_save_dir: Path | None):
    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size)

    model = TextOnlyModel().to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=config.learning_rate)

    total_steps = len(train_loader) * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=max(total_steps, 1),
    )

    scaler = torch.cuda.amp.GradScaler() if config.use_amp and device.type == "cuda" else None
    best_bleu = 0.0
    no_improve = 0

    for epoch in range(1, config.num_epochs + 1):
        print(f"\nüìç [TEXT-ONLY] Epoch {epoch}/{config.num_epochs} ‚Äî {src_lang}‚Üí{tgt_lang}")
        model.train()
        total_loss = 0.0

        loop = tqdm(train_loader, desc=f"[TXT Train {src_lang}->{tgt_lang}]")
        for batch in loop:
            opt.zero_grad()

            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            lbl = batch["labels"].to(device)

            try:
                if scaler:
                    with torch.cuda.amp.autocast():
                        out = model(ids, mask, labels=lbl)
                        loss = out.loss
                    scaler.scale(loss).backward()
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    scaler.step(opt)
                    scaler.update()
                else:
                    out = model(ids, mask, labels=lbl)
                    loss = out.loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    opt.step()
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print("‚ö†Ô∏è CUDA OOM on this batch, skipping.")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise e

            scheduler.step()
            total_loss += float(loss)
            loop.set_postfix(loss=float(loss))

        if len(train_loader) == 0:
            print("‚ö†Ô∏è No batches in train_loader (text-only)")
            break

        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"   üîª Text-only avg train loss: {avg_loss:.4f}")

        print("   üîç Evaluating TEXT-ONLY on validation...")
        bleu = compute_bleu_text(model, val_loader, tokenizer)

        improved = bleu > best_bleu + config.min_delta
        if improved:
            best_bleu = bleu
            no_improve = 0

            filename = f"mbart_lora_{src_lang}_{tgt_lang}_text_best.pt"
            local_path = local_save_dir / filename
            torch.save(model.state_dict(), local_path)
            print(f"   üíæ Saved best TEXT-ONLY model (local) ‚Üí {local_path}")

            if drive_save_dir is not None:
                drive_path = drive_save_dir / filename
                torch.save(model.state_dict(), drive_path)
                print(f"   üíæ Saved best TEXT-ONLY model (drive) ‚Üí {drive_path}")

        else:
            no_improve += 1

        if no_improve >= config.patience:
            print(f"üõë Early stopping TEXT-ONLY {src_lang}‚Üí{tgt_lang} at epoch {epoch}")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"‚úÖ Finished TEXT-ONLY training {src_lang}‚Üí{tgt_lang} | Best BLEU: {best_bleu:.2f}")
    return best_bleu

# ============================================================== #
# MAIN
# ============================================================== #

def main():
    # Optional: remap data_root via Drive symlink
    if os.path.exists("/content/drive"):
        config.data_root = mount_and_link_dataset()

    # Local save dir (Colab)
    local_save_dir = Path(config.save_dir)
    local_save_dir.mkdir(parents=True, exist_ok=True)

    # Drive save dir (if Drive is mounted)
    drive_save_dir = None
    drive_root = Path("/content/drive/MyDrive")
    if drive_root.exists():
        drive_save_dir = Path(config.drive_save_dir)
        drive_save_dir.mkdir(parents=True, exist_ok=True)
        print(f"üíæ Drive save dir: {drive_save_dir}")
    else:
        print("‚ö†Ô∏è Drive not mounted or /content/drive/MyDrive missing; will only save locally.")

    print("üîÑ Loading MBart tokenizer & SigLIP processor...")
    tokenizer = MBart50TokenizerFast.from_pretrained(
        "facebook/mbart-large-50-many-to-many-mmt"
    )
    image_processor = SiglipProcessor.from_pretrained(config.vision_model_name)

    # Save config (local + drive)
    cfg = asdict(config)
    cfg_path_local = local_save_dir / "config_siglip_fusion_lora.json"
    with open(cfg_path_local, "w") as f:
        json.dump(cfg, f, indent=2)
    print(f"üíæ Config saved (local) at: {cfg_path_local}")

    if drive_save_dir is not None:
        cfg_path_drive = drive_save_dir / "config_siglip_fusion_lora.json"
        with open(cfg_path_drive, "w") as f:
            json.dump(cfg, f, indent=2)
        print(f"üíæ Config saved (drive) at: {cfg_path_drive}")

    results_multimodal = {}
    results_textonly = {}

    for src, tgt in config.directions:
        print("\n======================================================================")
        print(f"üèÅ LANGUAGE PAIR: {src.upper()} ‚Üí {tgt.upper()}")
        print("======================================================================")

        tokenizer.src_lang = LANG_CODES[src]
        tokenizer.tgt_lang = LANG_CODES[tgt]

        train_ids, train_src, train_tgt = load_split(
            config.data_root, "train", src, tgt, config.max_train_samples
        )
        val_ids, val_src, val_tgt = load_split(
            config.data_root, "val", src, tgt, config.max_val_samples
        )

        if len(train_ids) == 0:
            print(f"‚ö†Ô∏è Skipping {src}‚Üí{tgt} (no data)")
            continue

        img_root = Path(config.data_root) / config.image_dir

        # Datasets
        train_mm = MultiModalDataset(
            train_ids, train_src, train_tgt,
            tokenizer, image_processor, img_root
        )
        val_mm = MultiModalDataset(
            val_ids, val_src, val_tgt,
            tokenizer, image_processor, img_root
        )

        train_txt = TextOnlyDataset(train_src, train_tgt, tokenizer)
        val_txt = TextOnlyDataset(val_src, val_tgt, tokenizer)

        # ----- Train MULTIMODAL -----
        mm_bleu = train_multimodal_model(
            src, tgt, tokenizer, train_mm, val_mm,
            local_save_dir=local_save_dir,
            drive_save_dir=drive_save_dir,
        )
        results_multimodal[f"{src}_{tgt}"] = mm_bleu

        # ----- Train TEXT-ONLY -----
        txt_bleu = train_text_model(
            src, tgt, tokenizer, train_txt, val_txt,
            local_save_dir=local_save_dir,
            drive_save_dir=drive_save_dir,
        )
        results_textonly[f"{src}_{tgt}"] = txt_bleu

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\nüìä FINAL BLEU SCORES (MULTIMODAL):")
    for k, v in results_multimodal.items():
        print(f"  {k}: {v:.2f}")

    print("\nüìä FINAL BLEU SCORES (TEXT-ONLY):")
    for k, v in results_textonly.items():
        print(f"  {k}: {v:.2f}")

if __name__ == "__main__":
    main()


Using device: cuda
üîó Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
‚úÖ Found dataset at: /content/drive/MyDrive/dataset/multi30k-dataset
üîó Symlink created ‚Üí /content/multi30k-dataset
üíæ Drive save dir: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion
üîÑ Loading MBart tokenizer & SigLIP processor...
üíæ Config saved (local) at: /content/multimodal_translation_models_siglip_lora_fusion/config_siglip_fusion_lora.json
üíæ Config saved (drive) at: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/config_siglip_fusion_lora.json

üèÅ LANGUAGE PAIR: EN ‚Üí DE
üîé Checking files for train en‚Üíde
    /content/multi30k-dataset/data/task1/raw/train/train.en
    /content/multi30k-dataset/data/task1/raw/train/train.de
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: en‚Üíde)
üîé Check

[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [3:33:36<00:00,  1.71s/it, loss=1.92]


   üîª Multimodal avg train loss: 0.9342
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 40.08
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî en‚Üíde


[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:50<00:00,  6.30it/s, loss=0.53]


   üîª Multimodal avg train loss: 0.8313
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 40.59
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî en‚Üíde


[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:12<00:00,  6.18it/s, loss=0.963]


   üîª Multimodal avg train loss: 0.7708
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 41.38
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt

üìç [MULTIMODAL] Epoch 4/6 ‚Äî en‚Üíde


[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:13<00:00,  6.18it/s, loss=0.562]


   üîª Multimodal avg train loss: 0.7269
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 42.00
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt

üìç [MULTIMODAL] Epoch 5/6 ‚Äî en‚Üíde


[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:08<00:00,  6.21it/s, loss=0.424]


   üîª Multimodal avg train loss: 0.6927
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 42.84
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt

üìç [MULTIMODAL] Epoch 6/6 ‚Äî en‚Üíde


[MM Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:14<00:00,  6.18it/s, loss=0.531]


   üîª Multimodal avg train loss: 0.6671
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 42.92
‚úÖ Finished MULTIMODAL training en‚Üíde | Best BLEU: 42.84
üîÑ Loading text-only mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç [TEXT-ONLY] Epoch 1/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:43<00:00,  8.49it/s, loss=0.674]


   üîª Text-only avg train loss: 0.9582
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 39.66
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_de_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_de_text_best.pt

üìç [TEXT-ONLY] Epoch 2/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:48<00:00,  8.44it/s, loss=0.429]


   üîª Text-only avg train loss: 0.8744
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 40.11

üìç [TEXT-ONLY] Epoch 3/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:49<00:00,  8.43it/s, loss=0.868]


   üîª Text-only avg train loss: 0.8423
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 40.24
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_de_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_de_text_best.pt

üìç [TEXT-ONLY] Epoch 4/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:47<00:00,  8.45it/s, loss=0.823]


   üîª Text-only avg train loss: 0.8189
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 40.61

üìç [TEXT-ONLY] Epoch 5/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:47<00:00,  8.45it/s, loss=0.434]


   üîª Text-only avg train loss: 0.8048
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 40.90
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_de_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_de_text_best.pt

üìç [TEXT-ONLY] Epoch 6/6 ‚Äî en‚Üíde


[TXT Train en->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:50<00:00,  8.42it/s, loss=0.583]


   üîª Text-only avg train loss: 0.7954
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 40.75
‚úÖ Finished TEXT-ONLY training en‚Üíde | Best BLEU: 40.90

üèÅ LANGUAGE PAIR: EN ‚Üí FR
üîé Checking files for train en‚Üífr
    /content/multi30k-dataset/data/task1/raw/train/train.en
    /content/multi30k-dataset/data/task1/raw/train/train.fr
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: en‚Üífr)
üîé Checking files for val en‚Üífr
    /content/multi30k-dataset/data/task1/raw/val/val.en
    /content/multi30k-dataset/data/task1/raw/val/val.fr
    /content/multi30k-dataset/data/task1/image_splits/val.txt
‚úÖ Loaded 1000 samples (val: en‚Üífr)
üîÑ Loading SigLIP vision model: google/siglip-base-patch16-224
üìê SigLIP vision hidden size: 768
üîÑ Loading mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç

[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:06<00:00,  6.22it/s, loss=0.273]


   üîª Multimodal avg train loss: 0.8575
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 51.56
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî en‚Üífr


[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:07<00:00,  6.21it/s, loss=0.352]


   üîª Multimodal avg train loss: 0.7151
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 53.22
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî en‚Üífr


[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:28<00:00,  6.11it/s, loss=0.23]


   üîª Multimodal avg train loss: 0.6495
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 55.73
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 4/6 ‚Äî en‚Üífr


[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:27<00:00,  6.11it/s, loss=0.46]


   üîª Multimodal avg train loss: 0.6041
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 55.87

üìç [MULTIMODAL] Epoch 5/6 ‚Äî en‚Üífr


[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:06<00:00,  6.22it/s, loss=0.128]


   üîª Multimodal avg train loss: 0.5695
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 56.03

üìç [MULTIMODAL] Epoch 6/6 ‚Äî en‚Üífr


[MM Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:07<00:00,  6.21it/s, loss=1.19]


   üîª Multimodal avg train loss: 0.5434
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 56.68
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt
‚úÖ Finished MULTIMODAL training en‚Üífr | Best BLEU: 56.68
üîÑ Loading text-only mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç [TEXT-ONLY] Epoch 1/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:53<00:00,  8.39it/s, loss=0.861]


   üîª Text-only avg train loss: 0.9013
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 48.05
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 2/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:56<00:00,  8.36it/s, loss=0.442]


   üîª Text-only avg train loss: 0.7792
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 50.77
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 3/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:55<00:00,  8.37it/s, loss=0.598]


   üîª Text-only avg train loss: 0.7350
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 51.70
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 4/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:55<00:00,  8.37it/s, loss=0.728]


   üîª Text-only avg train loss: 0.7100
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 51.99

üìç [TEXT-ONLY] Epoch 5/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:54<00:00,  8.38it/s, loss=0.81]


   üîª Text-only avg train loss: 0.6938
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 52.85
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 6/6 ‚Äî en‚Üífr


[TXT Train en->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:55<00:00,  8.37it/s, loss=0.141]


   üîª Text-only avg train loss: 0.6816
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 52.80
‚úÖ Finished TEXT-ONLY training en‚Üífr | Best BLEU: 52.85

üèÅ LANGUAGE PAIR: DE ‚Üí EN
üîé Checking files for train de‚Üíen
    /content/multi30k-dataset/data/task1/raw/train/train.de
    /content/multi30k-dataset/data/task1/raw/train/train.en
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: de‚Üíen)
üîé Checking files for val de‚Üíen
    /content/multi30k-dataset/data/task1/raw/val/val.de
    /content/multi30k-dataset/data/task1/raw/val/val.en
    /content/multi30k-dataset/data/task1/image_splits/val.txt
‚úÖ Loaded 1000 samples (val: de‚Üíen)
üîÑ Loading SigLIP vision model: google/siglip-base-patch16-224
üìê SigLIP vision hidden size: 768
üîÑ Loading mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç

[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:09<00:00,  6.20it/s, loss=0.692]


   üîª Multimodal avg train loss: 0.8817
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 44.57
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî de‚Üíen


[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:09<00:00,  6.20it/s, loss=0.437]


   üîª Multimodal avg train loss: 0.7957
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 45.61
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî de‚Üíen


[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:33<00:00,  6.08it/s, loss=1.56]


   üîª Multimodal avg train loss: 0.7478
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 45.92

üìç [MULTIMODAL] Epoch 4/6 ‚Äî de‚Üíen


[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:05<00:00,  6.22it/s, loss=0.832]


   üîª Multimodal avg train loss: 0.7068
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 46.79
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt

üìç [MULTIMODAL] Epoch 5/6 ‚Äî de‚Üíen


[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:28<00:00,  6.10it/s, loss=0.599]


   üîª Multimodal avg train loss: 0.6727
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 47.40
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt

üìç [MULTIMODAL] Epoch 6/6 ‚Äî de‚Üíen


[MM Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:27<00:00,  6.11it/s, loss=0.927]


   üîª Multimodal avg train loss: 0.6466
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 47.24
‚úÖ Finished MULTIMODAL training de‚Üíen | Best BLEU: 47.40
üîÑ Loading text-only mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç [TEXT-ONLY] Epoch 1/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:52<00:00,  8.40it/s, loss=0.751]


   üîª Text-only avg train loss: 0.8841
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 44.24
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_en_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_en_text_best.pt

üìç [TEXT-ONLY] Epoch 2/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:53<00:00,  8.40it/s, loss=0.925]


   üîª Text-only avg train loss: 0.8269
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 45.08
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_en_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_en_text_best.pt

üìç [TEXT-ONLY] Epoch 3/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:52<00:00,  8.40it/s, loss=0.518]


   üîª Text-only avg train loss: 0.8013
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 45.40

üìç [TEXT-ONLY] Epoch 4/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:50<00:00,  8.42it/s, loss=0.541]


   üîª Text-only avg train loss: 0.7846
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 46.02
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_en_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_en_text_best.pt

üìç [TEXT-ONLY] Epoch 5/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:53<00:00,  8.39it/s, loss=0.582]


   üîª Text-only avg train loss: 0.7737
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 46.17

üìç [TEXT-ONLY] Epoch 6/6 ‚Äî de‚Üíen


[TXT Train de->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:48<00:00,  8.44it/s, loss=0.679]


   üîª Text-only avg train loss: 0.7655
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 46.47
‚úÖ Finished TEXT-ONLY training de‚Üíen | Best BLEU: 46.02

üèÅ LANGUAGE PAIR: DE ‚Üí FR
üîé Checking files for train de‚Üífr
    /content/multi30k-dataset/data/task1/raw/train/train.de
    /content/multi30k-dataset/data/task1/raw/train/train.fr
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: de‚Üífr)
üîé Checking files for val de‚Üífr
    /content/multi30k-dataset/data/task1/raw/val/val.de
    /content/multi30k-dataset/data/task1/raw/val/val.fr
    /content/multi30k-dataset/data/task1/image_splits/val.txt
‚úÖ Loaded 1000 samples (val: de‚Üífr)
üîÑ Loading SigLIP vision model: google/siglip-base-patch16-224
üìê SigLIP vision hidden size: 768
üîÑ Loading mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç

[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:57<00:00,  6.26it/s, loss=0.487]


   üîª Multimodal avg train loss: 1.2159
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 36.00
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî de‚Üífr


[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:03<00:00,  6.23it/s, loss=1.06]


   üîª Multimodal avg train loss: 1.0409
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 38.56
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî de‚Üífr


[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:18<00:00,  6.16it/s, loss=0.581]


   üîª Multimodal avg train loss: 0.9680
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 39.85
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 4/6 ‚Äî de‚Üífr


[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:59<00:00,  6.25it/s, loss=1.05]


   üîª Multimodal avg train loss: 0.9124
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 40.49
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt

üìç [MULTIMODAL] Epoch 5/6 ‚Äî de‚Üífr


[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:49<00:00,  6.30it/s, loss=0.789]


   üîª Multimodal avg train loss: 0.8721
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 40.96

üìç [MULTIMODAL] Epoch 6/6 ‚Äî de‚Üífr


[MM Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:28<00:00,  6.42it/s, loss=0.631]


   üîª Multimodal avg train loss: 0.8444
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 40.98
‚úÖ Finished MULTIMODAL training de‚Üífr | Best BLEU: 40.49
üîÑ Loading text-only mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç [TEXT-ONLY] Epoch 1/6 ‚Äî de‚Üífr


[TXT Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:32<00:00,  8.60it/s, loss=0.967]


   üîª Text-only avg train loss: 1.2708
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 33.55
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_fr_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 2/6 ‚Äî de‚Üífr


[TXT Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:55<00:00,  8.37it/s, loss=0.394]


   üîª Text-only avg train loss: 1.1103
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 36.01
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_fr_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 3/6 ‚Äî de‚Üífr


[TXT Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:53<00:00,  8.39it/s, loss=0.833]


   üîª Text-only avg train loss: 1.0596
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 37.34
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_fr_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 4/6 ‚Äî de‚Üífr


[TXT Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:54<00:00,  8.38it/s, loss=1.34]


   üîª Text-only avg train loss: 1.0287
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 38.19
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_fr_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_fr_text_best.pt

üìç [TEXT-ONLY] Epoch 5/6 ‚Äî de‚Üífr


[TXT Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:57<00:00,  8.36it/s, loss=0.754]


   üîª Text-only avg train loss: 1.0093
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 38.51

üìç [TEXT-ONLY] Epoch 6/6 ‚Äî de‚Üífr


[TXT Train de->fr]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:57<00:00,  8.36it/s, loss=1.19]


   üîª Text-only avg train loss: 0.9989
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 38.37
‚úÖ Finished TEXT-ONLY training de‚Üífr | Best BLEU: 38.19

üèÅ LANGUAGE PAIR: FR ‚Üí EN
üîé Checking files for train fr‚Üíen
    /content/multi30k-dataset/data/task1/raw/train/train.fr
    /content/multi30k-dataset/data/task1/raw/train/train.en
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: fr‚Üíen)
üîé Checking files for val fr‚Üíen
    /content/multi30k-dataset/data/task1/raw/val/val.fr
    /content/multi30k-dataset/data/task1/raw/val/val.en
    /content/multi30k-dataset/data/task1/image_splits/val.txt
‚úÖ Loaded 1000 samples (val: fr‚Üíen)
üîÑ Loading SigLIP vision model: google/siglip-base-patch16-224
üìê SigLIP vision hidden size: 768
üîÑ Loading mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç

[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:03<00:00,  6.23it/s, loss=0.803]


   üîª Multimodal avg train loss: 0.8398
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 50.35
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî fr‚Üíen


[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:06<00:00,  6.22it/s, loss=0.186]


   üîª Multimodal avg train loss: 0.7411
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 51.65
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî fr‚Üíen


[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:28<00:00,  6.11it/s, loss=0.442]


   üîª Multimodal avg train loss: 0.6874
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 52.87
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt

üìç [MULTIMODAL] Epoch 4/6 ‚Äî fr‚Üíen


[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:29<00:00,  6.10it/s, loss=0.711]


   üîª Multimodal avg train loss: 0.6475
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 53.90
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt

üìç [MULTIMODAL] Epoch 5/6 ‚Äî fr‚Üíen


[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:25<00:00,  6.12it/s, loss=0.889]


   üîª Multimodal avg train loss: 0.6167
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 53.98

üìç [MULTIMODAL] Epoch 6/6 ‚Äî fr‚Üíen


[MM Train fr->en]:  47%|‚ñà‚ñà‚ñà‚ñà‚ñã     | 3498/7500 [09:20<10:56,  6.10it/s, loss=0.289]

In [None]:
# ==============================================================
# üåç MULTIMODAL TRANSLATION + SIGLIP + LORA + FUSION vs TEXT-ONLY
#  - Multi30K (data/task1/raw + image_splits)
#  - SigLIP vision encoder (google/siglip-base-patch16-224)
#  - mBART-50 text model with LoRA on attention (q_proj, v_proj)
#  - Better fusion: Transformer-based fusion over [IMG + TEXT]
#  - Also trains text-only mBART+LoRA baseline for comparison
#  - Saves models in BOTH Colab and Google Drive
# ==============================================================

import os
import json
from pathlib import Path
from typing import List, Tuple
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageFile
from tqdm import tqdm
import evaluate
import warnings

warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True

# ------------------ HF + PEFT imports ------------------
try:
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
        get_linear_schedule_with_warmup,
    )
    from peft import LoraConfig, get_peft_model, TaskType
except ImportError as e:
    print(f"‚ùå Import error: {e}")
    os.system("pip install -q transformers peft accelerate")
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
        get_linear_schedule_with_warmup,
    )
    from peft import LoraConfig, get_peft_model, TaskType

# ------------------ DEVICE ------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True

# ============================================================== #
# CONFIG
# ============================================================== #

@dataclass
class Config:
    # Paths
    data_root: str = "/content/multi30k-dataset"
    image_dir: str = "flickr30k-images"

    # Local Colab save dir
    save_dir: str = "/content/multimodal_translation_models_siglip_lora_fusion"
    # Drive save dir (we'll ensure it exists if Drive is mounted)
    drive_save_dir: str = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"

    # Training
    max_length: int = 64
    batch_size: int = 2          # small for VRAM safety
    learning_rate: float = 3e-5
    num_epochs: int = 6
    patience: int = 3
    min_delta: float = 0.5       # BLEU improvement threshold to reset patience
    use_amp: bool = True

    # Data limits
    max_train_samples: int = 15000
    max_val_samples: int = 1000

    # Optim
    warmup_steps: int = 100
    max_grad_norm: float = 1.0

    # Vision model (SigLIP)
    vision_model_name: str = "google/siglip-base-patch16-224"

    # LoRA
    use_lora: bool = True
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.1
    lora_targets: List[str] = None

    # Language directions
    directions: List[Tuple[str, str]] = None

    def __post_init__(self):
        if self.lora_targets is None:
            # Attn projections in mBART encoder+decoder
            self.lora_targets = ["q_proj", "v_proj"]
        if self.directions is None:
            self.directions = [
                ("fr", "en"),
                ("fr", "de"),
            ]

config = Config()

LANG_CODES = {"en": "en_XX", "de": "de_DE", "fr": "fr_XX"}

# Global metric (avoid re-loading each epoch)
sacrebleu_metric = evaluate.load("sacrebleu")

# ============================================================== #
# OPTIONAL: DRIVE MOUNT + SYMLINK (if you want)
# ============================================================== #

def mount_and_link_dataset():
    """
    Mounts Google Drive and links /content/multi30k-dataset to your folder.
    Safe: no directory scanning, just checks existence.
    """
    try:
        from google.colab import drive
    except Exception:
        print("‚ÑπÔ∏è Not running in Colab / no google.colab, skipping mount.")
        return config.data_root

    print("üîó Mounting Google Drive...")
    drive.mount("/content/drive")

    candidate_paths = [
        "/content/drive/MyDrive/multi30k-dataset",
        "/content/drive/MyDrive/dataset/multi30k-dataset",
        "/content/drive/MyDrive/Colab Notebooks/multi30k-dataset",
    ]

    dataset_path = None
    for p in candidate_paths:
        if os.path.exists(p):
            dataset_path = p
            print(f"‚úÖ Found dataset at: {p}")
            break

    if dataset_path is None:
        print("‚ùå Multi30K dataset not found in default locations. Using existing:", config.data_root)
        return config.data_root

    if os.path.islink("/content/multi30k-dataset") or os.path.exists("/content/multi30k-dataset"):
        os.system("rm -rf /content/multi30k-dataset")

    os.symlink(dataset_path, "/content/multi30k-dataset")
    print("üîó Symlink created ‚Üí /content/multi30k-dataset")
    return "/content/multi30k-dataset"

# ============================================================== #
# IMAGE LOADER (NO DIR LISTING)
# ============================================================== #

def safe_load_image(image_id: str, root: Path) -> Image.Image:
    """
    Loads one image by ID without listing directories.
    Multi30K image IDs in image_splits are usually like "1234567890.jpg" or "1234567890".
    We try: id, id.jpg, id.jpeg, id.png.
    """
    base = image_id.strip()
    for ext in [".jpg", ".jpeg", ".png"]:
        if base.endswith(ext):
            base = base[: -len(ext)]
            break

    candidates = [
        f"{base}.jpg",
        f"{base}.jpeg",
        f"{base}.png",
        base,
    ]

    for name in candidates:
        fp = root / name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except Exception:
                pass

    # Fallback: dummy gray image
    return Image.new("RGB", (224, 224), (128, 128, 128))

# ============================================================== #
# LORA HELPER
# ============================================================== #

def apply_lora_to_mbart(mbart: MBartForConditionalGeneration) -> MBartForConditionalGeneration:
    """
    Wraps mBART with LoRA on attention projections.
    """
    if not config.use_lora:
        print("‚ÑπÔ∏è LoRA disabled; training full mBART (heavier).")
        return mbart

    lora_cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_targets,
    )
    peft_model = get_peft_model(mbart, lora_cfg)
    print("‚úÖ LoRA applied to mBART (targets:", config.lora_targets, ")")
    peft_model.print_trainable_parameters()
    return peft_model

# ============================================================== #
# FUSION BLOCK (BETTER THAN PLAIN CONCAT)
# ============================================================== #

class FusionBlock(nn.Module):
    """
    Simple Transformer-based fusion over [IMG_TOKEN + TEXT_TOKENS].
    Lets the image token attend to text and vice versa.
    """
    def __init__(self, d_model: int, nhead: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=1)

    def forward(self, img_embed: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
        """
        img_embed: [B,1,d_model]
        text_embed: [B,L,d_model]
        returns fused: [B,1+L,d_model]
        """
        x = torch.cat([img_embed, text_embed], dim=1)  # [B,1+L,d]
        x = self.encoder(x)                            # fuse via self-attention
        return x

# ============================================================== #
# MULTIMODAL MODEL (SIGLIP + MBART + LORA + FUSION)
# ============================================================== #

class MultiModalModel(nn.Module):
    def __init__(self):
        super().__init__()

        # SigLIP vision encoder (vision-only)
        print(f"üîÑ Loading SigLIP vision model: {config.vision_model_name}")
        self.vision = SiglipVisionModel.from_pretrained(config.vision_model_name)

        # Freeze SigLIP to save memory & compute
        for p in self.vision.parameters():
            p.requires_grad = False

        # SigLIP vision hidden size
        vision_dim = self.vision.config.hidden_size
        print("üìê SigLIP vision hidden size:", vision_dim)

        # mBART-50 text model
        print("üîÑ Loading mBART-50 many-to-many...")
        base_mbart = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )

        # Apply LoRA on mBART
        self.mbart = apply_lora_to_mbart(base_mbart)

        # Shared text embeddings (LoRA-safe)
        self.text_emb = self.mbart.get_input_embeddings()

        # Project SigLIP CLS ‚Üí mBART hidden size
        self.proj = nn.Linear(vision_dim, self.mbart.config.d_model)

        # Fusion block
        self.fusion = FusionBlock(d_model=self.mbart.config.d_model, nhead=8, dim_ff=2048, dropout=0.1)

    def forward(self, input_ids, attention_mask, pixel_values, labels=None):
        batch_size = input_ids.size(0)

        # 1) SigLIP image features (CLS token)
        with torch.no_grad():
            vision_outputs = self.vision(pixel_values=pixel_values)
            img_feat = vision_outputs.last_hidden_state[:, 0, :]   # [B, hidden_dim]

        img_embed = self.proj(img_feat).unsqueeze(1)               # [B,1,d_model]

        # 2) Text embeddings from mBART shared embedding matrix
        text_embed = self.text_emb(input_ids)                      # [B,L,d_model]

        # 3) Transformer-based fusion
        fused = self.fusion(img_embed, text_embed)                 # [B,1+L,d_model]

        # 4) Attention mask (add image token)
        fused_mask = torch.cat(
            [torch.ones((batch_size, 1), device=device), attention_mask],
            dim=1,
        )

        # 5) mBART forward using inputs_embeds
        outputs = self.mbart(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            labels=labels,
            return_dict=True,
        )
        return outputs

    def generate(self, input_ids, attention_mask, pixel_values, tokenizer):
        batch_size = input_ids.size(0)

        with torch.no_grad():
            vision_outputs = self.vision(pixel_values=pixel_values)
            img_feat = vision_outputs.last_hidden_state[:, 0, :]

        img_embed = self.proj(img_feat).unsqueeze(1)
        text_embed = self.text_emb(input_ids)

        fused = self.fusion(img_embed, text_embed)
        fused_mask = torch.cat(
            [torch.ones((batch_size, 1), device=device), attention_mask],
            dim=1,
        )

        gen_ids = self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            max_length=config.max_length,
            num_beams=3,
            early_stopping=True,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen_ids

# ============================================================== #
# TEXT-ONLY MODEL (MBART + LORA)
# ============================================================== #

class TextOnlyModel(nn.Module):
    def __init__(self):
        super().__init__()
        print("üîÑ Loading text-only mBART-50 many-to-many...")
        base_mbart = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base_mbart)

    def forward(self, input_ids, attention_mask, labels=None):
        return self.mbart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    def generate(self, input_ids, attention_mask, tokenizer):
        gen_ids = self.mbart.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=config.max_length,
            num_beams=3,
            early_stopping=True,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen_ids

# ============================================================== #
# DATASETS
# ============================================================== #

class MultiModalDataset(Dataset):
    def __init__(self, image_ids, src, tgt, tokenizer, image_processor, img_root):
        self.ids = image_ids
        self.src = src
        self.tgt = tgt
        self.tok = tokenizer
        self.img_proc = image_processor
        self.img_root = img_root

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        img_id = self.ids[idx]
        src = self.src[idx]
        tgt = self.tgt[idx]

        enc = self.tok(
            src,
            max_length=config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        with self.tok.as_target_tokenizer():
            dec = self.tok(
                tgt,
                max_length=config.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

        labels = dec["input_ids"].squeeze()
        labels[labels == self.tok.pad_token_id] = -100

        img = safe_load_image(img_id, self.img_root)
        pv = self.img_proc(images=img, return_tensors="pt")["pixel_values"].squeeze()

        return {
            "input_ids": enc["input_ids"].squeeze(),
            "attention_mask": enc["attention_mask"].squeeze(),
            "labels": labels,
            "pixel_values": pv,
            "target_text": tgt,
        }

class TextOnlyDataset(Dataset):
    def __init__(self, src, tgt, tokenizer):
        self.src = src
        self.tgt = tgt
        self.tok = tokenizer

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        src = self.src[idx]
        tgt = self.tgt[idx]

        enc = self.tok(
            src,
            max_length=config.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        with self.tok.as_target_tokenizer():
            dec = self.tok(
                tgt,
                max_length=config.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )

        labels = dec["input_ids"].squeeze()
        labels[labels == self.tok.pad_token_id] = -100

        return {
            "input_ids": enc["input_ids"].squeeze(),
            "attention_mask": enc["attention_mask"].squeeze(),
            "labels": labels,
            "target_text": tgt,
        }

# ============================================================== #
# LOAD SPLIT
# ============================================================== #

def load_split(root, split, src_lang, tgt_lang, limit):
    """
    Expects:
      root/data/task1/raw/{split}/{split}.{lang}
      root/data/task1/image_splits/{split}.txt
    """
    root = Path(root)
    raw = root / "data" / "task1" / "raw" / split
    id_file = root / "data" / "task1" / "image_splits" / f"{split}.txt"

    src_file = raw / f"{split}.{src_lang}"
    tgt_file = raw / f"{split}.{tgt_lang}"

    print(f"üîé Checking files for {split} {src_lang}‚Üí{tgt_lang}")
    print("   ", src_file)
    print("   ", tgt_file)
    print("   ", id_file)

    if not src_file.exists() or not tgt_file.exists() or not id_file.exists():
        print(f"‚ùå Missing one or more files for {split} ({src_lang}‚Üí{tgt_lang})")
        return [], [], []

    ids = [l.strip() for l in open(id_file, encoding="utf-8") if l.strip()]
    src = [l.strip() for l in open(src_file, encoding="utf-8") if l.strip()]
    tgt = [l.strip() for l in open(tgt_file, encoding="utf-8") if l.strip()]

    n = min(len(ids), len(src), len(tgt), limit)
    print(f"‚úÖ Loaded {n} samples ({split}: {src_lang}‚Üí{tgt_lang})")
    return ids[:n], src[:n], tgt[:n]

# ============================================================== #
# TRAINING + EVAL HELPERS
# ============================================================== #

def compute_bleu_multimodal(model, loader, tokenizer):
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            pv = batch["pixel_values"].to(device)
            tgt = batch["target_text"]

            gen_ids = model.generate(ids, mask, pv, tokenizer)
            decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            preds.extend(decoded)
            refs.extend([[t] for t in tgt])

    bleu = sacrebleu_metric.compute(predictions=preds, references=refs)["score"]
    print(f"   üîµ Multimodal BLEU: {bleu:.2f}")
    return bleu

def compute_bleu_text(model, loader, tokenizer):
    model.eval()
    preds, refs = [], []
    with torch.no_grad():
        for batch in loader:
            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            tgt = batch["target_text"]

            gen_ids = model.generate(ids, mask, tokenizer)
            decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True) # write the translations in file and

            preds.extend(decoded)
            refs.extend([[t] for t in tgt])

    bleu = sacrebleu_metric.compute(predictions=preds, references=refs)["score"]
    print(f"   üîµ Text-only BLEU: {bleu:.2f}")
    return bleu

# ------------------ TRAIN MULTIMODAL ------------------ #

def train_multimodal_model(src_lang, tgt_lang, tokenizer, train_ds, val_ds,
                           local_save_dir: Path, drive_save_dir: Path | None):
    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size)

    model = MultiModalModel().to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=config.learning_rate)

    total_steps = len(train_loader) * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=max(total_steps, 1),
    )

    scaler = torch.cuda.amp.GradScaler() if config.use_amp and device.type == "cuda" else None
    best_bleu = 0.0
    no_improve = 0

    for epoch in range(1, config.num_epochs + 1):
        print(f"\nüìç [MULTIMODAL] Epoch {epoch}/{config.num_epochs} ‚Äî {src_lang}‚Üí{tgt_lang}")
        model.train()
        total_loss = 0.0

        loop = tqdm(train_loader, desc=f"[MM Train {src_lang}->{tgt_lang}]")
        for batch in loop:
            opt.zero_grad()

            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            lbl = batch["labels"].to(device)
            pv = batch["pixel_values"].to(device)

            try:
                if scaler:
                    with torch.cuda.amp.autocast():
                        out = model(ids, mask, pv, labels=lbl)
                        loss = out.loss
                    scaler.scale(loss).backward()
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    scaler.step(opt)
                    scaler.update()
                else:
                    out = model(ids, mask, pv, labels=lbl)
                    loss = out.loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    opt.step()
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print("‚ö†Ô∏è CUDA OOM on this batch, skipping.")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise e

            scheduler.step()
            total_loss += float(loss)
            loop.set_postfix(loss=float(loss))

        if len(train_loader) == 0:
            print("‚ö†Ô∏è No batches in train_loader (multimodal)")
            break

        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"   üîª Multimodal avg train loss: {avg_loss:.4f}")

        print("   üîç Evaluating multimodal on validation...")
        bleu = compute_bleu_multimodal(model, val_loader, tokenizer)

        improved = bleu > best_bleu + config.min_delta
        if improved:
            best_bleu = bleu
            no_improve = 0

            filename = f"siglip_fusion_lora_{src_lang}_{tgt_lang}_mm_best.pt"
            local_path = local_save_dir / filename
            torch.save(model.state_dict(), local_path)
            print(f"   üíæ Saved best MULTIMODAL model (local) ‚Üí {local_path}")

            if drive_save_dir is not None:
                drive_path = drive_save_dir / filename
                torch.save(model.state_dict(), drive_path)
                print(f"   üíæ Saved best MULTIMODAL model (drive) ‚Üí {drive_path}")

        else:
            no_improve += 1

        if no_improve >= config.patience:
            print(f"üõë Early stopping MULTIMODAL {src_lang}‚Üí{tgt_lang} at epoch {epoch}")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"‚úÖ Finished MULTIMODAL training {src_lang}‚Üí{tgt_lang} | Best BLEU: {best_bleu:.2f}")
    return best_bleu

# ------------------ TRAIN TEXT-ONLY ------------------ #

def train_text_model(src_lang, tgt_lang, tokenizer, train_ds, val_ds,
                     local_save_dir: Path, drive_save_dir: Path | None):
    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size)

    model = TextOnlyModel().to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=config.learning_rate)

    total_steps = len(train_loader) * config.num_epochs
    scheduler = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=max(total_steps, 1),
    )

    scaler = torch.cuda.amp.GradScaler() if config.use_amp and device.type == "cuda" else None
    best_bleu = 0.0
    no_improve = 0

    for epoch in range(1, config.num_epochs + 1):
        print(f"\nüìç [TEXT-ONLY] Epoch {epoch}/{config.num_epochs} ‚Äî {src_lang}‚Üí{tgt_lang}")
        model.train()
        total_loss = 0.0

        loop = tqdm(train_loader, desc=f"[TXT Train {src_lang}->{tgt_lang}]")
        for batch in loop:
            opt.zero_grad()

            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            lbl = batch["labels"].to(device)

            try:
                if scaler:
                    with torch.cuda.amp.autocast():
                        out = model(ids, mask, labels=lbl)
                        loss = out.loss
                    scaler.scale(loss).backward()
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    scaler.step(opt)
                    scaler.update()
                else:
                    out = model(ids, mask, labels=lbl)
                    loss = out.loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    opt.step()
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print("‚ö†Ô∏è CUDA OOM on this batch, skipping.")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise e

            scheduler.step()
            total_loss += float(loss)
            loop.set_postfix(loss=float(loss))

        if len(train_loader) == 0:
            print("‚ö†Ô∏è No batches in train_loader (text-only)")
            break

        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"   üîª Text-only avg train loss: {avg_loss:.4f}")

        print("   üîç Evaluating TEXT-ONLY on validation...")
        bleu = compute_bleu_text(model, val_loader, tokenizer)

        improved = bleu > best_bleu + config.min_delta
        if improved:
            best_bleu = bleu
            no_improve = 0

            filename = f"mbart_lora_{src_lang}_{tgt_lang}_text_best.pt"
            local_path = local_save_dir / filename
            torch.save(model.state_dict(), local_path)
            print(f"   üíæ Saved best TEXT-ONLY model (local) ‚Üí {local_path}")

            if drive_save_dir is not None:
                drive_path = drive_save_dir / filename
                torch.save(model.state_dict(), drive_path)
                print(f"   üíæ Saved best TEXT-ONLY model (drive) ‚Üí {drive_path}")

        else:
            no_improve += 1

        if no_improve >= config.patience:
            print(f"üõë Early stopping TEXT-ONLY {src_lang}‚Üí{tgt_lang} at epoch {epoch}")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"‚úÖ Finished TEXT-ONLY training {src_lang}‚Üí{tgt_lang} | Best BLEU: {best_bleu:.2f}")
    return best_bleu

# ============================================================== #
# MAIN
# ============================================================== #

def main():
    # Optional: remap data_root via Drive symlink
    if os.path.exists("/content/drive"):
        config.data_root = mount_and_link_dataset()

    # Local save dir (Colab)
    local_save_dir = Path(config.save_dir)
    local_save_dir.mkdir(parents=True, exist_ok=True)

    # Drive save dir (if Drive is mounted)
    drive_save_dir = None
    drive_root = Path("/content/drive/MyDrive")
    if drive_root.exists():
        drive_save_dir = Path(config.drive_save_dir)
        drive_save_dir.mkdir(parents=True, exist_ok=True)
        print(f"üíæ Drive save dir: {drive_save_dir}")
    else:
        print("‚ö†Ô∏è Drive not mounted or /content/drive/MyDrive missing; will only save locally.")

    print("üîÑ Loading MBart tokenizer & SigLIP processor...")
    tokenizer = MBart50TokenizerFast.from_pretrained(
        "facebook/mbart-large-50-many-to-many-mmt"
    )
    image_processor = SiglipProcessor.from_pretrained(config.vision_model_name)

    # Save config (local + drive)
    cfg = asdict(config)
    cfg_path_local = local_save_dir / "config_siglip_fusion_lora.json"
    with open(cfg_path_local, "w") as f:
        json.dump(cfg, f, indent=2)
    print(f"üíæ Config saved (local) at: {cfg_path_local}")

    if drive_save_dir is not None:
        cfg_path_drive = drive_save_dir / "config_siglip_fusion_lora.json"
        with open(cfg_path_drive, "w") as f:
            json.dump(cfg, f, indent=2)
        print(f"üíæ Config saved (drive) at: {cfg_path_drive}")

    results_multimodal = {}
    results_textonly = {}

    for src, tgt in config.directions:
        print("\n======================================================================")
        print(f"üèÅ LANGUAGE PAIR: {src.upper()} ‚Üí {tgt.upper()}")
        print("======================================================================")

        tokenizer.src_lang = LANG_CODES[src]
        tokenizer.tgt_lang = LANG_CODES[tgt]

        train_ids, train_src, train_tgt = load_split(
            config.data_root, "train", src, tgt, config.max_train_samples
        )
        val_ids, val_src, val_tgt = load_split(
            config.data_root, "val", src, tgt, config.max_val_samples
        )

        if len(train_ids) == 0:
            print(f"‚ö†Ô∏è Skipping {src}‚Üí{tgt} (no data)")
            continue

        img_root = Path(config.data_root) / config.image_dir

        # Datasets
        train_mm = MultiModalDataset(
            train_ids, train_src, train_tgt,
            tokenizer, image_processor, img_root
        )
        val_mm = MultiModalDataset(
            val_ids, val_src, val_tgt,
            tokenizer, image_processor, img_root
        )

        train_txt = TextOnlyDataset(train_src, train_tgt, tokenizer)
        val_txt = TextOnlyDataset(val_src, val_tgt, tokenizer)

        # ----- Train MULTIMODAL -----
        mm_bleu = train_multimodal_model(
            src, tgt, tokenizer, train_mm, val_mm,
            local_save_dir=local_save_dir,
            drive_save_dir=drive_save_dir,
        )
        results_multimodal[f"{src}_{tgt}"] = mm_bleu

        # ----- Train TEXT-ONLY -----
        txt_bleu = train_text_model(
            src, tgt, tokenizer, train_txt, val_txt,
            local_save_dir=local_save_dir,
            drive_save_dir=drive_save_dir,
        )
        results_textonly[f"{src}_{tgt}"] = txt_bleu

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print("\nüìä FINAL BLEU SCORES (MULTIMODAL):")
    for k, v in results_multimodal.items():
        print(f"  {k}: {v:.2f}")

    print("\nüìä FINAL BLEU SCORES (TEXT-ONLY):")
    for k, v in results_textonly.items():
        print(f"  {k}: {v:.2f}")

if __name__ == "__main__":
    main()


Using device: cuda
üîó Mounting Google Drive...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
‚úÖ Found dataset at: /content/drive/MyDrive/dataset/multi30k-dataset
üîó Symlink created ‚Üí /content/multi30k-dataset
üíæ Drive save dir: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion
üîÑ Loading MBart tokenizer & SigLIP processor...


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


üíæ Config saved (local) at: /content/multimodal_translation_models_siglip_lora_fusion/config_siglip_fusion_lora.json
üíæ Config saved (drive) at: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/config_siglip_fusion_lora.json

üèÅ LANGUAGE PAIR: FR ‚Üí EN
üîé Checking files for train fr‚Üíen
    /content/multi30k-dataset/data/task1/raw/train/train.fr
    /content/multi30k-dataset/data/task1/raw/train/train.en
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: fr‚Üíen)
üîé Checking files for val fr‚Üíen
    /content/multi30k-dataset/data/task1/raw/val/val.fr
    /content/multi30k-dataset/data/task1/raw/val/val.en
    /content/multi30k-dataset/data/task1/image_splits/val.txt
‚úÖ Loaded 1000 samples (val: fr‚Üíen)
üîÑ Loading SigLIP vision model: google/siglip-base-patch16-224
üìê SigLIP vision hidden size: 768
üîÑ Loading mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
traina

[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [2:00:04<00:00,  1.04it/s, loss=1.16]


   üîª Multimodal avg train loss: 0.8408
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 50.29
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî fr‚Üíen


[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:14<00:00,  6.18it/s, loss=0.808]


   üîª Multimodal avg train loss: 0.7377
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 51.69
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî fr‚Üíen


[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:48<00:00,  6.31it/s, loss=0.859]


   üîª Multimodal avg train loss: 0.6863
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 52.63
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt

üìç [MULTIMODAL] Epoch 4/6 ‚Äî fr‚Üíen


[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:33<00:00,  6.39it/s, loss=0.391]


   üîª Multimodal avg train loss: 0.6450
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 52.88

üìç [MULTIMODAL] Epoch 5/6 ‚Äî fr‚Üíen


[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:11<00:00,  6.19it/s, loss=0.35]


   üîª Multimodal avg train loss: 0.6112
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 52.98

üìç [MULTIMODAL] Epoch 6/6 ‚Äî fr‚Üíen


[MM Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:00<00:00,  6.57it/s, loss=0.817]


   üîª Multimodal avg train loss: 0.5871
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 53.24
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt
‚úÖ Finished MULTIMODAL training fr‚Üíen | Best BLEU: 53.24
üîÑ Loading text-only mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç [TEXT-ONLY] Epoch 1/6 ‚Äî fr‚Üíen


[TXT Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:23<00:00,  8.69it/s, loss=0.542]


   üîª Text-only avg train loss: 0.8561
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 48.67
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_en_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_en_text_best.pt

üìç [TEXT-ONLY] Epoch 2/6 ‚Äî fr‚Üíen


[TXT Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:25<00:00,  8.67it/s, loss=0.891]


   üîª Text-only avg train loss: 0.7879
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 49.92
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_en_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_en_text_best.pt

üìç [TEXT-ONLY] Epoch 3/6 ‚Äî fr‚Üíen


[TXT Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:18<00:00,  8.73it/s, loss=0.469]


   üîª Text-only avg train loss: 0.7583
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 50.62
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_en_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_en_text_best.pt

üìç [TEXT-ONLY] Epoch 4/6 ‚Äî fr‚Üíen


[TXT Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:11<00:00,  8.81it/s, loss=0.733]


   üîª Text-only avg train loss: 0.7401
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 51.07

üìç [TEXT-ONLY] Epoch 5/6 ‚Äî fr‚Üíen


[TXT Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:31<00:00,  8.61it/s, loss=1.12]


   üîª Text-only avg train loss: 0.7253
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 51.75
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_en_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_en_text_best.pt

üìç [TEXT-ONLY] Epoch 6/6 ‚Äî fr‚Üíen


[TXT Train fr->en]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:27<00:00,  8.65it/s, loss=0.335]


   üîª Text-only avg train loss: 0.7179
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 51.73
‚úÖ Finished TEXT-ONLY training fr‚Üíen | Best BLEU: 51.75

üèÅ LANGUAGE PAIR: FR ‚Üí DE
üîé Checking files for train fr‚Üíde
    /content/multi30k-dataset/data/task1/raw/train/train.fr
    /content/multi30k-dataset/data/task1/raw/train/train.de
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: fr‚Üíde)
üîé Checking files for val fr‚Üíde
    /content/multi30k-dataset/data/task1/raw/val/val.fr
    /content/multi30k-dataset/data/task1/raw/val/val.de
    /content/multi30k-dataset/data/task1/image_splits/val.txt
‚úÖ Loaded 1000 samples (val: fr‚Üíde)
üîÑ Loading SigLIP vision model: google/siglip-base-patch16-224
üìê SigLIP vision hidden size: 768
üîÑ Loading mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç

[MM Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:48<00:00,  6.31it/s, loss=1.44]


   üîª Multimodal avg train loss: 1.2667
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 31.35
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_de_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_de_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî fr‚Üíde


[MM Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:24<00:00,  6.44it/s, loss=1.21]


   üîª Multimodal avg train loss: 1.1166
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 34.03
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_de_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_de_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî fr‚Üíde


[MM Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [20:01<00:00,  6.24it/s, loss=1.09]


   üîª Multimodal avg train loss: 1.0440
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 36.08
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_de_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_de_mm_best.pt

üìç [MULTIMODAL] Epoch 4/6 ‚Äî fr‚Üíde


[MM Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:52<00:00,  6.29it/s, loss=0.416]


   üîª Multimodal avg train loss: 0.9897
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 35.82

üìç [MULTIMODAL] Epoch 5/6 ‚Äî fr‚Üíde


[MM Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:23<00:00,  6.45it/s, loss=0.848]


   üîª Multimodal avg train loss: 0.9461
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 36.46

üìç [MULTIMODAL] Epoch 6/6 ‚Äî fr‚Üíde


[MM Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [19:26<00:00,  6.43it/s, loss=0.52]


   üîª Multimodal avg train loss: 0.9192
   üîç Evaluating multimodal on validation...
   üîµ Multimodal BLEU: 36.23
üõë Early stopping MULTIMODAL fr‚Üíde at epoch 6
‚úÖ Finished MULTIMODAL training fr‚Üíde | Best BLEU: 36.08
üîÑ Loading text-only mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç [TEXT-ONLY] Epoch 1/6 ‚Äî fr‚Üíde


[TXT Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:11<00:00,  8.81it/s, loss=1.89]


   üîª Text-only avg train loss: 1.3255
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 29.45
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_de_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_de_text_best.pt

üìç [TEXT-ONLY] Epoch 2/6 ‚Äî fr‚Üíde


[TXT Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:05<00:00,  8.87it/s, loss=1.31]


   üîª Text-only avg train loss: 1.1915
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 31.27
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_de_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_de_text_best.pt

üìç [TEXT-ONLY] Epoch 3/6 ‚Äî fr‚Üíde


[TXT Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:04<00:00,  8.88it/s, loss=2.24]


   üîª Text-only avg train loss: 1.1483
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 32.29
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_de_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_de_text_best.pt

üìç [TEXT-ONLY] Epoch 4/6 ‚Äî fr‚Üíde


[TXT Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:16<00:00,  8.75it/s, loss=0.809]


   üîª Text-only avg train loss: 1.1225
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 33.13
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_de_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_de_text_best.pt

üìç [TEXT-ONLY] Epoch 5/6 ‚Äî fr‚Üíde


[TXT Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:10<00:00,  8.82it/s, loss=1.77]


   üîª Text-only avg train loss: 1.1009
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 33.12

üìç [TEXT-ONLY] Epoch 6/6 ‚Äî fr‚Üíde


[TXT Train fr->de]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [14:10<00:00,  8.82it/s, loss=0.967]


   üîª Text-only avg train loss: 1.0904
   üîç Evaluating TEXT-ONLY on validation...
   üîµ Text-only BLEU: 33.21
‚úÖ Finished TEXT-ONLY training fr‚Üíde | Best BLEU: 33.13

üìä FINAL BLEU SCORES (MULTIMODAL):
  fr_en: 53.24
  fr_de: 36.08

üìä FINAL BLEU SCORES (TEXT-ONLY):
  fr_en: 51.75
  fr_de: 33.13


In [None]:
# ================================================================
# üåç EVAL: SIGLIP + MBART + LORA (MULTIMODAL vs TEXT-ONLY)
#  - Loads models from Drive
#  - Uses test_2017_flickr split from Multi30K
#  - Saves merged JSON per pair (Option B)
#  - Saves readable TXT per pair
#  - Saves BLEU
# ================================================================

import os
import json
from pathlib import Path
from typing import List, Tuple

import torch
import torch.nn as nn
from PIL import Image, ImageFile
import evaluate
from tqdm import tqdm

ImageFile.LOAD_TRUNCATED_IMAGES = True

# ------------------ HF + PEFT imports ------------------
try:
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
    )
    from peft import LoraConfig, get_peft_model, TaskType
except:
    !pip install -q transformers peft accelerate
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
    )
    from peft import LoraConfig, get_peft_model, TaskType

# ------------------ Mount Drive ------------------
from google.colab import drive
drive.mount("/content/drive")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ------------------ Paths ------------------
MODEL_DIR = Path("/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion")
TEST_DIR  = Path("/content/drive/MyDrive/dataset/multi30k-dataset/data/task1/raw/test_2016_flickr")
SPLIT_FILE = Path("/content/drive/MyDrive/dataset/multi30k-dataset/data/task1/image_splits/test_2016_flickr.txt")
IMG_ROOT = Path("/content/drive/MyDrive/dataset/multi30k-dataset/flickr30k-images")
OUT_DIR  = MODEL_DIR / "test2016_predictions"
OUT_DIR.mkdir(parents=True, exist_ok=True)

print("MODEL_DIR:", MODEL_DIR)
print("TEST_DIR :", TEST_DIR)
print("SPLIT_FILE:", SPLIT_FILE)
print("IMG_ROOT:", IMG_ROOT)
print("OUT_DIR:", OUT_DIR)

# ------------------ Load saved training config ------------------
cfg_path = MODEL_DIR / "config_siglip_fusion_lora.json"
import types
config_dict = json.load(open(cfg_path))
config = types.SimpleNamespace(**config_dict)

config.directions = [
    ("en", "de"),
    ("en", "fr"),
    ("de", "en"),
    ("de", "fr"),
    ("fr", "en"),
    ("fr", "de"),
]

LANG_CODES = {"en": "en_XX", "de": "de_DE", "fr": "fr_XX"}

# ------------------ Safe image loader ------------------
def safe_load_image(image_id: str, root: Path):
    base = image_id.strip().replace("\n", "")
    for ext in [".jpg", ".jpeg", ".png"]:
        if base.endswith(ext):
            base = base[:-len(ext)]
            break
    for name in [base+".jpg", base+".jpeg", base+".png"]:
        fp = root / name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except:
                pass
    return Image.new("RGB", (224,224), (128,128,128))

# ------------------ LoRA helper ------------------
def apply_lora_to_mbart(mbart):
    lora_cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_targets,
    )
    return get_peft_model(mbart, lora_cfg)

# ------------------ Fusion block ------------------
class FusionBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8, dim_feedforward=2048,
            dropout=0.1, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_embed, text_embed):
        x = torch.cat([img_embed, text_embed], dim=1)
        return self.encoder(x)

# ------------------ MultiModal Model ------------------
class MultiModalModel(nn.Module):
    def __init__(self):
        super().__init__()
        print("Loading SigLIP:", config.vision_model_name)
        self.vision = SiglipVisionModel.from_pretrained(config.vision_model_name)
        for p in self.vision.parameters():
            p.requires_grad = False

        print("Loading MBART...")
        base = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
        self.mbart = apply_lora_to_mbart(base)
        self.text_emb = self.mbart.get_input_embeddings()

        vision_dim = self.vision.config.hidden_size
        self.proj = nn.Linear(vision_dim, self.mbart.config.d_model)
        self.fusion = FusionBlock(self.mbart.config.d_model)

    def generate(self, input_ids, mask, pixel_values, tokenizer):
        with torch.no_grad():
            vis = self.vision(pixel_values=pixel_values).last_hidden_state[:,0,:]
        img = self.proj(vis).unsqueeze(1)
        txt = self.text_emb(input_ids)
        fused = self.fusion(img, txt)
        fused_mask = torch.cat([torch.ones((input_ids.size(0),1),device=device), mask], dim=1)

        return self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            max_length=config.max_length,
            num_beams=3,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
        )

# ------------------ Text-only Model ------------------
class TextOnlyModel(nn.Module):
    def __init__(self):
        super().__init__()
        base = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
        self.mbart = apply_lora_to_mbart(base)

    def generate(self, input_ids, mask, tokenizer):
        return self.mbart.generate(
            input_ids=input_ids,
            attention_mask=mask,
            max_length=config.max_length,
            num_beams=3,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
        )

# ------------------ Load test files ------------------
def load_test(src, tgt):
    ids = open(SPLIT_FILE).read().splitlines()
    src_txt = open(TEST_DIR / f"test_2016_flickr.{src}").read().splitlines()
    tgt_txt = open(TEST_DIR / f"test_2016_flickr.{tgt}").read().splitlines()
    n = min(len(ids), len(src_txt), len(tgt_txt))
    return ids[:n], src_txt[:n], tgt_txt[:n]

# ------------------ Evaluate one pair ------------------
def evaluate_pair(src, tgt):
    print(f"\n===== Evaluating {src} ‚Üí {tgt} =====")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ids, src_txt, tgt_txt = load_test(src, tgt)

    # Load models
    mm_path = MODEL_DIR / f"siglip_fusion_lora_{src}_{tgt}_mm_best.pt"
    txt_path = MODEL_DIR / f"mbart_lora_{src}_{tgt}_text_best.pt"

    print("üìå Using MM:", mm_path)
    print("üìå Using TXT:", txt_path)

    mm = MultiModalModel().to(device)
    txt = TextOnlyModel().to(device)

    mm.load_state_dict(torch.load(mm_path, map_location=device))
    txt.load_state_dict(torch.load(txt_path, map_location=device))

    mm.eval()
    txt.eval()

    results = []
    refs = [[t] for t in tgt_txt]
    preds_mm, preds_txt = [], []

    for i in tqdm(range(len(ids))):
        enc = tokenizer(src_txt[i], max_length=config.max_length, truncation=True,
                        padding="max_length", return_tensors="pt")
        input_ids = enc["input_ids"].to(device)
        mask = enc["attention_mask"].to(device)

        img = safe_load_image(ids[i], IMG_ROOT)
        pixel = image_processor(images=img, return_tensors="pt")["pixel_values"].to(device)

        with torch.no_grad():
            mm_out = mm.generate(input_ids, mask, pixel, tokenizer)
            txt_out = txt.generate(input_ids, mask, tokenizer)

        mm_pred = tokenizer.decode(mm_out[0], skip_special_tokens=True)
        txt_pred = tokenizer.decode(txt_out[0], skip_special_tokens=True)

        preds_mm.append(mm_pred)
        preds_txt.append(txt_pred)

        results.append({
            "image_id": ids[i],
            "source": src_txt[i],
            "target": tgt_txt[i],
            "prediction_multimodal": mm_pred,
            "prediction_textonly": txt_pred
        })

    # Save merged JSON (Option B)
    json_path = OUT_DIR / f"predictions_{src}_{tgt}.json"
    json.dump(results, open(json_path, "w"), indent=2, ensure_ascii=False)
    print("üíæ JSON saved:", json_path)

    # Save clean TXT
    txt_path_out = OUT_DIR / f"predictions_{src}_{tgt}.txt"
    with open(txt_path_out, "w") as f:
        for r in results:
            f.write(f"{r['image_id']}\n")
            f.write(f"SRC : {r['source']}\n")
            f.write(f"TRG : {r['target']}\n")
            f.write(f"MM  : {r['prediction_multimodal']}\n")
            f.write(f"TXT : {r['prediction_textonly']}\n")
            f.write("-"*60 + "\n")
    print("üíæ TXT saved:", txt_path_out)

    # BLEU
    sacre = evaluate.load("sacrebleu")
    bleu_mm = sacre.compute(predictions=preds_mm, references=refs)["score"]
    bleu_txt = sacre.compute(predictions=preds_txt, references=refs)["score"]

    print(f"BLEU MM  = {bleu_mm:.2f}")
    print(f"BLEU TXT = {bleu_txt:.2f}")

    return bleu_mm, bleu_txt

# ------------------ Run all ------------------
all_scores = {}
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
image_processor = SiglipProcessor.from_pretrained(config.vision_model_name)

for src, tgt in config.directions:
    mm_bleu, txt_bleu = evaluate_pair(src, tgt)
    all_scores[f"{src}->{tgt}"] = {"mm": mm_bleu, "txt": txt_bleu}

json.dump(all_scores, open(OUT_DIR/"bleu_scores.json","w"), indent=2)
print("Saved BLEU summary.")
print(all_scores)


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
MODEL_DIR: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion
TEST_DIR : /content/drive/MyDrive/dataset/multi30k-dataset/data/task1/raw/test_2016_flickr
SPLIT_FILE: /content/drive/MyDrive/dataset/multi30k-dataset/data/task1/image_splits/test_2016_flickr.txt
IMG_ROOT: /content/drive/MyDrive/dataset/multi30k-dataset/flickr30k-images
OUT_DIR: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions

===== Evaluating en ‚Üí de =====
üìå Using MM: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt
üìå Using TXT: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_de_text_best.pt
Loading SigLIP: google/siglip-base-patch16-224
Loading MBART...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [16:08<00:00,  1.03it/s]


üíæ JSON saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_en_de.json
üíæ TXT saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_en_de.txt
BLEU MM  = 40.38
BLEU TXT = 38.55

===== Evaluating en ‚Üí fr =====
üìå Using MM: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt
üìå Using TXT: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_en_fr_text_best.pt
Loading SigLIP: google/siglip-base-patch16-224
Loading MBART...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [17:08<00:00,  1.03s/it]


üíæ JSON saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_en_fr.json
üíæ TXT saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_en_fr.txt
BLEU MM  = 57.29
BLEU TXT = 54.02

===== Evaluating de ‚Üí en =====
üìå Using MM: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt
üìå Using TXT: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_en_text_best.pt
Loading SigLIP: google/siglip-base-patch16-224
Loading MBART...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [14:20<00:00,  1.16it/s]


üíæ JSON saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_de_en.json
üíæ TXT saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_de_en.txt
BLEU MM  = 46.17
BLEU TXT = 45.04

===== Evaluating de ‚Üí fr =====
üìå Using MM: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_fr_mm_best.pt
üìå Using TXT: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_de_fr_text_best.pt
Loading SigLIP: google/siglip-base-patch16-224
Loading MBART...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [16:55<00:00,  1.02s/it]


üíæ JSON saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_de_fr.json
üíæ TXT saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_de_fr.txt
BLEU MM  = 39.51
BLEU TXT = 37.01

===== Evaluating fr ‚Üí en =====
üìå Using MM: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt
üìå Using TXT: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_en_text_best.pt
Loading SigLIP: google/siglip-base-patch16-224
Loading MBART...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [14:37<00:00,  1.14it/s]


üíæ JSON saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_fr_en.json
üíæ TXT saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_fr_en.txt
BLEU MM  = 54.34
BLEU TXT = 51.68

===== Evaluating fr ‚Üí de =====
üìå Using MM: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_de_mm_best.pt
üìå Using TXT: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_fr_de_text_best.pt
Loading SigLIP: google/siglip-base-patch16-224
Loading MBART...


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [15:58<00:00,  1.04it/s]


üíæ JSON saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_fr_de.json
üíæ TXT saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_predictions/predictions_fr_de.txt
BLEU MM  = 33.87
BLEU TXT = 32.17
Saved BLEU summary.
{'en->de': {'mm': 40.38295475881235, 'txt': 38.55443359709484}, 'en->fr': {'mm': 57.29245656795444, 'txt': 54.021818840298465}, 'de->en': {'mm': 46.170166716423005, 'txt': 45.040496619688085}, 'de->fr': {'mm': 39.51088858195834, 'txt': 37.00733032947563}, 'fr->en': {'mm': 54.3362125592121, 'txt': 51.67768653929201}, 'fr->de': {'mm': 33.86687639457443, 'txt': 32.17351310984113}}


In [None]:
## Traning the models with all language sets (Unified Model Trining)

In [None]:
# Training again with same setup

In [None]:
# ==============================================================
# üåç MULTIMODAL TRANSLATION (SIGLIP + MBART + LORA FUSION)
# Optimized for A100
# ==============================================================

import os
import json
import shutil
from pathlib import Path
from typing import List, Tuple, Dict, Any
from dataclasses import dataclass, asdict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageFile
import warnings
from tqdm import tqdm
import evaluate

warnings.filterwarnings("ignore")
ImageFile.LOAD_TRUNCATED_IMAGES = True

# --------------------------------------------------------------
# üöÄ A100 performance boost
# --------------------------------------------------------------
torch.set_float32_matmul_precision("high")

# --------------------------------------------------------------
# DEVICE
# --------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)
if device.type == "cuda":
    torch.backends.cudnn.benchmark = True

# --------------------------------------------------------------
# METRICS
# --------------------------------------------------------------
sacrebleu_metric = evaluate.load("sacrebleu")

# --------------------------------------------------------------
# HF + PEFT imports
# --------------------------------------------------------------
try:
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
        get_linear_schedule_with_warmup,
    )
    from peft import LoraConfig, get_peft_model, TaskType
except:
    os.system("pip install -q transformers peft accelerate evaluate")
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
        get_linear_schedule_with_warmup,
    )
    from peft import LoraConfig, get_peft_model, TaskType

# --------------------------------------------------------------
# CONFIG (Optimized for A100)
# --------------------------------------------------------------
@dataclass
class Config:
    # Paths
    data_root: str = "/content/multi30k-dataset-local"
    image_dir: str = "flickr30k-images"

    save_dir: str = "/content/multimodal_translation_models_siglip_lora_fusion"
    drive_save_dir: str = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"

    # Training
    max_length: int = 64
    batch_size: int = 32
    learning_rate: float = 3e-5
    num_epochs: int = 6
    patience: int = 3
    min_delta: float = 0.5
    use_amp: bool = True

    # Data size per direction
    max_train_samples: int = 15000
    max_val_samples: int = 200

    # Optimization
    warmup_steps: int = 100
    max_grad_norm: float = 1.0

    # Dataloader (A100 optimized)
    num_workers: int = 4
    pin_memory: bool = True

    # Vision + LoRA
    vision_model_name: str = "google/siglip-base-patch16-224"
    use_lora: bool = True
    lora_r: int = 8
    lora_alpha: int = 16
    lora_dropout: float = 0.1
    lora_targets: List[str] = None

    directions: List[Tuple[str, str]] = None

    def __post_init__(self):
        if self.lora_targets is None:
            self.lora_targets = ["q_proj", "v_proj"]
        if self.directions is None:
            self.directions = [
                ("en", "de"),
                ("en", "fr"),
                ("de", "en"),
                ("de", "fr"),
                ("fr", "en"),
                ("fr", "de"),
            ]

config = Config()
LANG_CODES = {"en": "en_XX", "de": "de_DE", "fr": "fr_XX"}

# --------------------------------------------------------------
# üöÄ Copy dataset from Google Drive ‚Üí LOCAL SSD (/content)
#     HUGE SPEEDUP (20√ó faster images)
# --------------------------------------------------------------
def copy_dataset_to_local():
    drive_dataset = Path("/content/drive/MyDrive/dataset/multi30k-dataset")
    local_dataset = Path("/content/multi30k-dataset-local")

    if drive_dataset.exists() and not local_dataset.exists():
        print("üìÇ Copying dataset from Drive ‚Üí /content (one-time)...")
        shutil.copytree(drive_dataset, local_dataset)
        print("‚úÖ Copy complete.")
    else:
        print("‚ÑπÔ∏è Local dataset already exists or Drive missing.")

    config.data_root = str(local_dataset)
    print("üìå Using LOCAL dataset:", config.data_root)
    return config.data_root


Using device: cuda


In [None]:
# ==============================================================
# IMAGE LOADER (NO DIR LISTING, FAST)
# ==============================================================

from PIL import Image

def safe_load_image(image_id: str, root: Path) -> Image.Image:
    """
    Loads one image by ID without listing directories.
    Multi30K image IDs in image_splits are usually like "1234567890.jpg" or "1234567890".
    We try: id, id.jpg, id.jpeg, id.png.
    """
    base = image_id.strip()
    for ext in [".jpg", ".jpeg", ".png"]:
        if base.endswith(ext):
            base = base[: -len(ext)]
            break

    candidates = [
        f"{base}.jpg",
        f"{base}.jpeg",
        f"{base}.png",
        base,
    ]

    for name in candidates:
        fp = root / name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except Exception:
                pass

    # Fallback: dummy gray image (should almost never happen)
    return Image.new("RGB", (224, 224), (128, 128, 128))


# ==============================================================
# LORA HELPER
# ==============================================================

def apply_lora_to_mbart(mbart: MBartForConditionalGeneration) -> MBartForConditionalGeneration:
    """
    Wraps mBART with LoRA on attention projections.
    """
    if not config.use_lora:
        print("‚ÑπÔ∏è LoRA disabled; training full mBART (heavier).")
        return mbart

    lora_cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_targets,
    )
    peft_model = get_peft_model(mbart, lora_cfg)
    print("‚úÖ LoRA applied to mBART (targets:", config.lora_targets, ")")
    peft_model.print_trainable_parameters()
    return peft_model


# ==============================================================
# FUSION BLOCK
# ==============================================================

class FusionBlock(nn.Module):
    """
    Transformer-based fusion over [IMG_TOKEN + TEXT_TOKENS].
    Lets the image token attend to text and vice versa.
    """
    def __init__(self, d_model: int, nhead: int = 8, dim_ff: int = 2048, dropout: float = 0.1):
        super().__init__()
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_ff,
            dropout=dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=1)

    def forward(self, img_embed: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
        """
        img_embed: [B,1,d_model]
        text_embed: [B,L,d_model]
        returns fused: [B,1+L,d_model]
        """
        x = torch.cat([img_embed, text_embed], dim=1)  # [B,1+L,d]
        x = self.encoder(x)                            # fuse via self-attention
        return x


# ==============================================================
# MULTIMODAL MODEL (SIGLIP + MBART + LORA + FUSION)
# ==============================================================

class MultiModalModel(nn.Module):
    def __init__(self):
        super().__init__()

        # SigLIP vision encoder (vision-only)
        print(f"üîÑ Loading SigLIP vision model: {config.vision_model_name}")
        self.vision = SiglipVisionModel.from_pretrained(config.vision_model_name)

        # Freeze SigLIP to save memory & compute
        for p in self.vision.parameters():
            p.requires_grad = False

        # SigLIP vision hidden size
        vision_dim = self.vision.config.hidden_size
        print("üìê SigLIP vision hidden size:", vision_dim)

        # mBART-50 text model
        print("üîÑ Loading mBART-50 many-to-many...")
        base_mbart = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )

        # Apply LoRA on mBART
        self.mbart = apply_lora_to_mbart(base_mbart)

        # Shared text embeddings (LoRA-safe)
        self.text_emb = self.mbart.get_input_embeddings()

        # Project SigLIP CLS ‚Üí mBART hidden size
        self.proj = nn.Linear(vision_dim, self.mbart.config.d_model)

        # Fusion block
        self.fusion = FusionBlock(
            d_model=self.mbart.config.d_model,
            nhead=8,
            dim_ff=2048,
            dropout=0.1,
        )

    def forward(self, input_ids, attention_mask, pixel_values, labels=None):
        batch_size = input_ids.size(0)

        # 1) SigLIP image features (CLS token)
        with torch.no_grad():  # vision backbone is frozen
            vision_outputs = self.vision(pixel_values=pixel_values)
            img_feat = vision_outputs.last_hidden_state[:, 0, :]   # [B, hidden_dim]

        img_embed = self.proj(img_feat).unsqueeze(1)               # [B,1,d_model]

        # 2) Text embeddings from mBART shared embedding matrix
        text_embed = self.text_emb(input_ids)                      # [B,L,d_model]

        # 3) Transformer-based fusion
        fused = self.fusion(img_embed, text_embed)                 # [B,1+L,d_model]

        # 4) Attention mask (add image token)
        fused_mask = torch.cat(
            [torch.ones((batch_size, 1), device=input_ids.device), attention_mask],
            dim=1,
        )

        # 5) mBART forward using inputs_embeds
        outputs = self.mbart(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            labels=labels,
            return_dict=True,
        )
        return outputs

    def generate(self, input_ids, attention_mask, pixel_values, tokenizer,
                 max_length: int | None = None, num_beams: int = 5):
        """
        Generation wrapper used during BLEU evaluation.
        Assumes tokenizer.src_lang / tokenizer.tgt_lang already set.
        """
        if max_length is None:
            max_length = config.max_length

        batch_size = input_ids.size(0)

        with torch.no_grad():
            vision_outputs = self.vision(pixel_values=pixel_values)
            img_feat = vision_outputs.last_hidden_state[:, 0, :]

        img_embed = self.proj(img_feat).unsqueeze(1)
        text_embed = self.text_emb(input_ids)

        fused = self.fusion(img_embed, text_embed)
        fused_mask = torch.cat(
            [torch.ones((batch_size, 1), device=input_ids.device), attention_mask],
            dim=1,
        )

        gen_ids = self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen_ids


# ==============================================================
# TEXT-ONLY MODEL (MBART + LORA)
# ==============================================================

class TextOnlyModel(nn.Module):
    def __init__(self):
        super().__init__()
        print("üîÑ Loading text-only mBART-50 many-to-many...")
        base_mbart = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base_mbart)

    def forward(self, input_ids, attention_mask, labels=None):
        return self.mbart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    def generate(self, input_ids, attention_mask, tokenizer,
                 max_length: int | None = None, num_beams: int = 5):
        """
        Generation wrapper used during BLEU evaluation.
        Assumes tokenizer.src_lang / tokenizer.tgt_lang already set.
        """
        if max_length is None:
            max_length = config.max_length

        gen_ids = self.mbart.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen_ids


In [None]:
# ==============================================================
# DATASETS (MIXED MULTILINGUAL)
# ==============================================================

class MixedMultiModalDataset(Dataset):
    """
    Single dataset that contains samples from ALL directions.
    Each sample has its own src_lang / tgt_lang.
    """
    def __init__(self, samples: List[Dict[str, Any]],
                 tokenizer: MBart50TokenizerFast,
                 image_processor: SiglipProcessor,
                 img_root: Path):
        self.samples = samples
        self.tok = tokenizer
        self.img_proc = image_processor
        self.img_root = img_root

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        img_id = s["img_id"]
        src = s["src"]
        tgt = s["tgt"]
        src_lang = s["src_lang"]
        tgt_lang = s["tgt_lang"]

        # Set tokenizer languages for this example
        self.tok.src_lang = LANG_CODES[src_lang]
        self.tok.tgt_lang = LANG_CODES[tgt_lang]

        # Encode source
        enc = self.tok(
            src, max_length=config.max_length,
            padding="max_length", truncation=True,
            return_tensors="pt"
        )

        # Encode target
        with self.tok.as_target_tokenizer():
            dec = self.tok(
                tgt, max_length=config.max_length,
                padding="max_length", truncation=True,
                return_tensors="pt"
            )

        labels = dec["input_ids"].squeeze(0)
        labels[labels == self.tok.pad_token_id] = -100

        # FAST‚ÄîLoad image from LOCAL SSD
        img = safe_load_image(img_id, self.img_root)
        pv = self.img_proc(images=img, return_tensors="pt")["pixel_values"].squeeze(0)

        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": labels,
            "pixel_values": pv,
            "target_text": tgt,
            "direction": f"{src_lang}->{tgt_lang}",
        }


class MixedTextOnlyDataset(Dataset):
    """
    Text-only dataset containing samples from ALL directions.
    """
    def __init__(self, samples: List[Dict[str, Any]],
                 tokenizer: MBart50TokenizerFast):
        self.samples = samples
        self.tok = tokenizer

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        s = self.samples[idx]
        src = s["src"]
        tgt = s["tgt"]
        src_lang = s["src_lang"]
        tgt_lang = s["tgt_lang"]

        self.tok.src_lang = LANG_CODES[src_lang]
        self.tok.tgt_lang = LANG_CODES[tgt_lang]

        enc = self.tok(
            src, max_length=config.max_length,
            padding="max_length", truncation=True,
            return_tensors="pt"
        )
        with self.tok.as_target_tokenizer():
            dec = self.tok(
                tgt, max_length=config.max_length,
                padding="max_length", truncation=True,
                return_tensors="pt"
            )

        labels = dec["input_ids"].squeeze(0)
        labels[labels == self.tok.pad_token_id] = -100

        return {
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "labels": labels,
            "target_text": tgt,
            "direction": f"{src_lang}->{tgt_lang}",
        }


# ==============================================================
# LOAD SPLITTED DATA FROM local dataset
# ==============================================================

def load_split(root, split, src_lang, tgt_lang, limit):
    """
    Loads data from: root/data/task1/raw/{split}/{split}.{lang}
    And image IDs from root/data/task1/image_splits/{split}.txt
    """
    root = Path(root)
    raw = root / "data" / "task1" / "raw" / split
    id_file = root / "data" / "task1" / "image_splits" / f"{split}.txt"

    src_file = raw / f"{split}.{src_lang}"
    tgt_file = raw / f"{split}.{tgt_lang}"

    print(f"üîé Checking files for {split} {src_lang}‚Üí{tgt_lang}")
    print("   ", src_file)
    print("   ", tgt_file)
    print("   ", id_file)

    if not src_file.exists() or not tgt_file.exists() or not id_file.exists():
        print(f"‚ùå Missing one or more files for {split} ({src_lang}‚Üí{tgt_lang})")
        return [], [], []

    ids = [l.strip() for l in open(id_file, encoding="utf-8") if l.strip()]
    src = [l.strip() for l in open(src_file, encoding="utf-8") if l.strip()]
    tgt = [l.strip() for l in open(tgt_file, encoding="utf-8") if l.strip()]

    n = min(len(ids), len(src), len(tgt), limit)
    print(f"‚úÖ Loaded {n} samples ({split}: {src_lang}‚Üí{tgt_lang})")
    return ids[:n], src[:n], tgt[:n]


# ==============================================================
# FAST BLEU FOR MULTIMODAL
# ==============================================================

def compute_bleu_multimodal(model: nn.Module,
                            dataset: Dataset,
                            tokenizer: MBart50TokenizerFast):
    model.eval()

    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True,
    )

    preds_all, refs_all = [], []
    preds_by_dir = {}
    refs_by_dir = {}

    with torch.no_grad():
        for batch in tqdm(loader, desc="[MM BLEU]", leave=False):
            direction = batch["direction"][0]
            src_lang, tgt_lang = direction.split("->")

            tokenizer.src_lang = LANG_CODES[src_lang]
            tokenizer.tgt_lang = LANG_CODES[tgt_lang]

            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            pv = batch["pixel_values"].to(device)
            tgt_texts = batch["target_text"]

            gen_ids = model.generate(ids, mask, pv, tokenizer)
            decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            pred = decoded[0]
            ref = tgt_texts[0]

            preds_all.append(pred)
            refs_all.append([ref])

            preds_by_dir.setdefault(direction, []).append(pred)
            refs_by_dir.setdefault(direction, []).append([ref])

    overall_bleu = sacrebleu_metric.compute(predictions=preds_all,
                                            references=refs_all)["score"]

    bleu_by_dir = {}
    print("\n   üìä Multimodal BLEU by direction:")
    for direction, preds in preds_by_dir.items():
        refs = refs_by_dir[direction]
        score = sacrebleu_metric.compute(predictions=preds,
                                         references=refs)["score"]
        bleu_by_dir[direction] = score
        print(f"     ‚Ä¢ {direction}: {score:.2f}")

    print(f"   üîµ Multimodal OVERALL BLEU: {overall_bleu:.2f}")
    return overall_bleu, bleu_by_dir


# ==============================================================
# FAST BLEU FOR TEXT-ONLY
# ==============================================================

def compute_bleu_text(model: nn.Module,
                      dataset: Dataset,
                      tokenizer: MBart50TokenizerFast):
    model.eval()

    loader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2,
        pin_memory=True
    )

    preds_all, refs_all = [], []
    preds_by_dir = {}
    refs_by_dir = {}

    with torch.no_grad():
        for batch in tqdm(loader, desc="[TXT BLEU]", leave=False):
            direction = batch["direction"][0]
            src_lang, tgt_lang = direction.split("->")

            tokenizer.src_lang = LANG_CODES[src_lang]
            tokenizer.tgt_lang = LANG_CODES[tgt_lang]

            ids = batch["input_ids"].to(device)
            mask = batch["attention_mask"].to(device)
            tgt_texts = batch["target_text"]

            gen_ids = model.generate(ids, mask, tokenizer)
            decoded = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            pred = decoded[0]
            ref = tgt_texts[0]

            preds_all.append(pred)
            refs_all.append([ref])

            preds_by_dir.setdefault(direction, []).append(pred)
            refs_by_dir.setdefault(direction, []).append([ref])

    overall_bleu = sacrebleu_metric.compute(predictions=preds_all,
                                            references=refs_all)["score"]

    bleu_by_dir = {}
    print("\n   üìä Text-only BLEU by direction:")
    for direction, preds in preds_by_dir.items():
        refs = refs_by_dir[direction]
        score = sacrebleu_metric.compute(predictions=preds,
                                         references=refs)["score"]
        bleu_by_dir[direction] = score
        print(f"     ‚Ä¢ {direction}: {score:.2f}")

    print(f"   üîµ Text-only OVERALL BLEU: {overall_bleu:.2f}")
    return overall_bleu, bleu_by_dir


In [None]:
# ==============================================================
# TRAINING LOOPS (A100-OPTIMIZED, BLEU EARLY STOPPING)
# ==============================================================

def make_train_loader(dataset: Dataset) -> DataLoader:
    """
    Fast DataLoader tuned for A100 on Colab Pro+.
    """
    return DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,          # A100 can handle this
        pin_memory=True,        # faster host‚ÜíGPU transfers
        persistent_workers=True # keeps workers alive between epochs
    )


def train_multimodal_model(
    tokenizer: MBart50TokenizerFast,
    train_ds: Dataset,
    val_ds: Dataset,
    local_save_dir: Path,
    drive_save_dir: Path | None,
):
    # ---------- DataLoader ----------
    train_loader = make_train_loader(train_ds)

    # ---------- Model / Optim / Sched ----------
    model = MultiModalModel().to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=config.learning_rate)

    total_steps = max(len(train_loader) * config.num_epochs, 1)
    scheduler = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=total_steps,
    )

    scaler = torch.cuda.amp.GradScaler() if config.use_amp and device.type == "cuda" else None

    best_bleu = 0.0
    no_improve = 0

    # ---------- TRAIN LOOP ----------
    for epoch in range(1, config.num_epochs + 1):
        print(f"\nüìç [MULTIMODAL] Epoch {epoch}/{config.num_epochs} ‚Äî ALL 6 DIRECTIONS")
        model.train()
        total_loss = 0.0

        loop = tqdm(train_loader, desc="[MM Train MIXED]", mininterval=1.0)
        for batch in loop:
            opt.zero_grad(set_to_none=True)

            ids = batch["input_ids"].to(device, non_blocking=True)
            mask = batch["attention_mask"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)
            pv = batch["pixel_values"].to(device, non_blocking=True)

            try:
                if scaler:
                    with torch.cuda.amp.autocast():
                        out = model(ids, mask, pv, labels=lbl)
                        loss = out.loss
                    scaler.scale(loss).backward()
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    scaler.step(opt)
                    scaler.update()
                else:
                    out = model(ids, mask, pv, labels=lbl)
                    loss = out.loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    opt.step()
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print("‚ö†Ô∏è CUDA OOM on this batch, skipping.")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise e

            scheduler.step()
            total_loss += float(loss)
            loop.set_postfix(loss=float(loss))

        if len(train_loader) == 0:
            print("‚ö†Ô∏è No batches in train_loader (multimodal)")
            break

        avg_train_loss = total_loss / len(train_loader)
        print(f"   üîª Multimodal avg TRAIN loss: {avg_train_loss:.4f}")

        # ---------- VALIDATION BLEU + EARLY STOP ----------
        overall_bleu, _ = compute_bleu_multimodal(model, val_ds, tokenizer)

        improved = overall_bleu > best_bleu + config.min_delta
        if improved:
            best_bleu = overall_bleu
            no_improve = 0

            filename = "siglip_fusion_lora_all6_mm_best.pt"
            local_path = local_save_dir / filename
            torch.save(model.state_dict(), local_path)
            print(f"   üíæ Saved best MULTIMODAL model (local) ‚Üí {local_path}")

            if drive_save_dir is not None:
                drive_path = drive_save_dir / filename
                torch.save(model.state_dict(), drive_path)
                print(f"   üíæ Saved best MULTIMODAL model (drive) ‚Üí {drive_path}")
        else:
            no_improve += 1
            print(f"   ‚è∏ No BLEU improvement. patience={no_improve}/{config.patience}")

        if no_improve >= config.patience:
            print(f"üõë Early stopping MULTIMODAL training at epoch {epoch}")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"‚úÖ Finished MULTIMODAL training | Best OVERALL BLEU: {best_bleu:.2f}")
    return best_bleu


def train_text_model(
    tokenizer: MBart50TokenizerFast,
    train_ds: Dataset,
    val_ds: Dataset,
    local_save_dir: Path,
    drive_save_dir: Path | None,
):
    # ---------- DataLoader ----------
    train_loader = make_train_loader(train_ds)

    # ---------- Model / Optim / Sched ----------
    model = TextOnlyModel().to(device)
    params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(params, lr=config.learning_rate)

    total_steps = max(len(train_loader) * config.num_epochs, 1)
    scheduler = get_linear_schedule_with_warmup(
        opt,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=total_steps,
    )

    scaler = torch.cuda.amp.GradScaler() if config.use_amp and device.type == "cuda" else None

    best_bleu = 0.0
    no_improve = 0

    # ---------- TRAIN LOOP ----------
    for epoch in range(1, config.num_epochs + 1):
        print(f"\nüìç [TEXT-ONLY] Epoch {epoch}/{config.num_epochs} ‚Äî ALL 6 DIRECTIONS")
        model.train()
        total_loss = 0.0

        loop = tqdm(train_loader, desc="[TXT Train MIXED]", mininterval=1.0)
        for batch in loop:
            opt.zero_grad(set_to_none=True)

            ids = batch["input_ids"].to(device, non_blocking=True)
            mask = batch["attention_mask"].to(device, non_blocking=True)
            lbl = batch["labels"].to(device, non_blocking=True)

            try:
                if scaler:
                    with torch.cuda.amp.autocast():
                        out = model(ids, mask, labels=lbl)
                        loss = out.loss
                    scaler.scale(loss).backward()
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    scaler.step(opt)
                    scaler.update()
                else:
                    out = model(ids, mask, labels=lbl)
                    loss = out.loss
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, config.max_grad_norm)
                    opt.step()
            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    print("‚ö†Ô∏è CUDA OOM on this batch, skipping.")
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    continue
                else:
                    raise e

            scheduler.step()
            total_loss += float(loss)
            loop.set_postfix(loss=float(loss))

        if len(train_loader) == 0:
            print("‚ö†Ô∏è No batches in train_loader (text-only)")
            break

        avg_train_loss = total_loss / len(train_loader)
        print(f"   üîª Text-only avg TRAIN loss: {avg_train_loss:.4f}")

        # ---------- VALIDATION BLEU + EARLY STOP ----------
        overall_bleu, _ = compute_bleu_text(model, val_ds, tokenizer)

        improved = overall_bleu > best_bleu + config.min_delta
        if improved:
            best_bleu = overall_bleu
            no_improve = 0

            filename = "mbart_lora_all6_text_best.pt"
            local_path = local_save_dir / filename
            torch.save(model.state_dict(), local_path)
            print(f"   üíæ Saved best TEXT-ONLY model (local) ‚Üí {local_path}")

            if drive_save_dir is not None:
                drive_path = drive_save_dir / filename
                torch.save(model.state_dict(), drive_path)
                print(f"   üíæ Saved best TEXT-ONLY model (drive) ‚Üí {drive_path}")
        else:
            no_improve += 1
            print(f"   ‚è∏ No BLEU improvement. patience={no_improve}/{config.patience}")

        if no_improve >= config.patience:
            print(f"üõë Early stopping TEXT-ONLY training at epoch {epoch}")
            break

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"‚úÖ Finished TEXT-ONLY training | Best OVERALL BLEU: {best_bleu:.2f}")
    return best_bleu


In [None]:
# ==============================================================
# MAIN (FINAL A100-OPTIMIZED VERSION)
# ==============================================================

def check_gpu():
    """Prints GPU details and warns if NOT A100/H100."""
    import subprocess
    try:
        gpu_info = subprocess.check_output("nvidia-smi -L", shell=True).decode()
        print("GPU Info ‚Üí", gpu_info)
        if "A100" not in gpu_info and "H100" not in gpu_info:
            print("‚ö†Ô∏è WARNING: You did NOT receive an A100/H100")
            print("Training will be 10‚Äì20√ó slower. Restart runtime.")
        else:
            print("‚úÖ Great! Premium GPU detected.")
    except Exception as e:
        print("‚ö†Ô∏è Could not check GPU:", e)


def main():
    # -------------------------------
    # GPU Check
    # -------------------------------
    check_gpu()
    print("Using device:", device)

    # -------------------------------
    # DATASET FIX ‚Üí Use local folder if already present
    # -------------------------------
    local_dataset = "/content/multi30k-dataset"
    drive_dataset = "/content/drive/MyDrive/dataset/multi30k-dataset"
    local_fast_dataset = "/content/multi30k-dataset-local"  # Used only if needed

    if os.path.exists(local_dataset):
        print("üìÅ Local dataset found ‚Üí using /content/multi30k-dataset")
        config.data_root = local_dataset

    elif os.path.exists(drive_dataset):
        # Copy ONCE to high-speed local storage
        print("üìÇ Copying dataset from Drive ‚Üí /content (fast SSD)...")
        shutil.copytree(drive_dataset, local_fast_dataset)
        print("‚úÖ Dataset copy complete.")
        config.data_root = local_fast_dataset

    else:
        print("‚ùå No dataset found! Please upload multi30k-dataset.")
        return

    print("üìå Using dataset:", config.data_root)

    # -------------------------------
    # Create save dirs
    # -------------------------------
    local_save_dir = Path(config.save_dir)
    local_save_dir.mkdir(parents=True, exist_ok=True)

    drive_save_dir = None
    drive_root = Path("/content/drive/MyDrive")
    if drive_root.exists():
        drive_save_dir = Path(config.drive_save_dir)
        drive_save_dir.mkdir(parents=True, exist_ok=True)
        print(f"üíæ Drive save dir: {drive_save_dir}")
    else:
        print("‚ö†Ô∏è Google Drive missing ‚Äî will only save locally.")

    # -------------------------------
    # Load tokenizer & SigLIP processor
    # -------------------------------
    print("üîÑ Loading MBart tokenizer & SigLIP processor...")
    tokenizer = MBart50TokenizerFast.from_pretrained(
        "facebook/mbart-large-50-many-to-many-mmt"
    )
    image_processor = SiglipProcessor.from_pretrained(config.vision_model_name)

    # -------------------------------
    # Save config (local + drive)
    # -------------------------------
    cfg = asdict(config)
    cfg_path_local = local_save_dir / "config_siglip_fusion_lora_all6.json"
    with open(cfg_path_local, "w") as f:
        json.dump(cfg, f, indent=2)
    print(f"üíæ Config saved (local) at: {cfg_path_local}")

    if drive_save_dir is not None:
        cfg_path_drive = drive_save_dir / "config_siglip_fusion_lora_all6.json"
        with open(cfg_path_drive, "w") as f:
            json.dump(cfg, f, indent=2)
        print(f"üíæ Config saved (drive) at: {cfg_path_drive}")

    # -------------------------------
    # Load ALL 6 directions
    # -------------------------------
    train_samples = []
    val_samples = []

    for src, tgt in config.directions:
        print("\n" + "="*70)
        print(f"üèÅ LOADING DATA FOR: {src.upper()} ‚Üí {tgt.upper()}")
        print("="*70)

        train_ids, train_src, train_tgt = load_split(
            config.data_root, "train", src, tgt, config.max_train_samples
        )
        val_ids, val_src, val_tgt = load_split(
            config.data_root, "val", src, tgt, config.max_val_samples
        )

        for img_id, s_txt, t_txt in zip(train_ids, train_src, train_tgt):
            train_samples.append({
                "img_id": img_id,
                "src": s_txt,
                "tgt": t_txt,
                "src_lang": src,
                "tgt_lang": tgt,
            })

        for img_id, s_txt, t_txt in zip(val_ids, val_src, val_tgt):
            val_samples.append({
                "img_id": img_id,
                "src": s_txt,
                "tgt": t_txt,
                "src_lang": src,
                "tgt_lang": tgt,
            })

    print(f"\nüì¶ TOTAL train samples (ALL directions): {len(train_samples)}")
    print(f"üì¶ TOTAL val samples   (ALL directions): {len(val_samples)}")

    if len(train_samples) == 0 or len(val_samples) == 0:
        print("‚ùå No data loaded. Check dataset path.")
        return

    img_root = Path(config.data_root) / config.image_dir

    # -------------------------------
    # Build datasets
    # -------------------------------
    print("üóÇ Building PyTorch datasets...")
    train_mm = MixedMultiModalDataset(train_samples, tokenizer, image_processor, img_root)
    val_mm   = MixedMultiModalDataset(val_samples, tokenizer, image_processor, img_root)

    train_txt = MixedTextOnlyDataset(train_samples, tokenizer)
    val_txt   = MixedTextOnlyDataset(val_samples, tokenizer)

    # -------------------------------
    # Train MULTIMODAL
    # -------------------------------
    print("\nüöÄ Starting MULTIMODAL training...")
    best_mm_bleu = train_multimodal_model(
        tokenizer,
        train_mm,
        val_mm,
        local_save_dir,
        drive_save_dir,
    )

    # -------------------------------
    # Train TEXT-ONLY
    # -------------------------------
    print("\nüöÄ Starting TEXT-ONLY training...")
    best_txt_bleu = train_text_model(
        tokenizer,
        train_txt,
        val_txt,
        local_save_dir,
        drive_save_dir,
    )

    # -------------------------------
    # Final summary
    # -------------------------------
    print("\nüìä FINAL BEST BLEU SCORES")
    print(f"   üåà Multimodal (ALL dirs): {best_mm_bleu:.2f}")
    print(f"   ‚ú® Text-only (ALL dirs):  {best_txt_bleu:.2f}")
    print("\nüéâ Training Completed Successfully!")


# ==============================================================
# ENTRY POINT
# ==============================================================

if __name__ == "__main__":
    main()


GPU Info ‚Üí GPU 0: NVIDIA A100-SXM4-80GB (UUID: GPU-dd22b5a6-5b92-1ecb-39d1-b97b837b1dcf)

‚úÖ Great! Premium GPU detected.
Using device: cuda
üìÅ Local dataset found ‚Üí using /content/multi30k-dataset
üìå Using dataset: /content/multi30k-dataset
üíæ Drive save dir: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion
üîÑ Loading MBart tokenizer & SigLIP processor...
üíæ Config saved (local) at: /content/multimodal_translation_models_siglip_lora_fusion/config_siglip_fusion_lora_all6.json
üíæ Config saved (drive) at: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/config_siglip_fusion_lora_all6.json

üèÅ LOADING DATA FOR: EN ‚Üí DE
üîé Checking files for train en‚Üíde
    /content/multi30k-dataset/data/task1/raw/train/train.en
    /content/multi30k-dataset/data/task1/raw/train/train.de
    /content/multi30k-dataset/data/task1/image_splits/train.txt
‚úÖ Loaded 15000 samples (train: en‚Üíde)
üîé Checking files for val en‚Üíde
    /cont

[MM Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [51:11<00:00,  1.09s/it, loss=1.01]


   üîª Multimodal avg TRAIN loss: 1.1299





   üìä Multimodal BLEU by direction:
     ‚Ä¢ en->de: 39.41
     ‚Ä¢ en->fr: 46.90
     ‚Ä¢ de->en: 42.53
     ‚Ä¢ de->fr: 33.34
     ‚Ä¢ fr->en: 47.03
     ‚Ä¢ fr->de: 31.17
   üîµ Multimodal OVERALL BLEU: 40.46
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_all6_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_all6_mm_best.pt

üìç [MULTIMODAL] Epoch 2/6 ‚Äî ALL 6 DIRECTIONS


[MM Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [07:52<00:00,  5.96it/s, loss=0.927]


   üîª Multimodal avg TRAIN loss: 1.0040





   üìä Multimodal BLEU by direction:
     ‚Ä¢ en->de: 40.27
     ‚Ä¢ en->fr: 48.75
     ‚Ä¢ de->en: 43.24
     ‚Ä¢ de->fr: 36.53
     ‚Ä¢ fr->en: 47.96
     ‚Ä¢ fr->de: 32.21
   üîµ Multimodal OVERALL BLEU: 41.83
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_all6_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_all6_mm_best.pt

üìç [MULTIMODAL] Epoch 3/6 ‚Äî ALL 6 DIRECTIONS


[MM Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [08:00<00:00,  5.85it/s, loss=1.08]


   üîª Multimodal avg TRAIN loss: 0.9613





   üìä Multimodal BLEU by direction:
     ‚Ä¢ en->de: 40.75
     ‚Ä¢ en->fr: 51.01
     ‚Ä¢ de->en: 43.56
     ‚Ä¢ de->fr: 37.52
     ‚Ä¢ fr->en: 50.36
     ‚Ä¢ fr->de: 33.49
   üîµ Multimodal OVERALL BLEU: 43.11
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_all6_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_all6_mm_best.pt

üìç [MULTIMODAL] Epoch 4/6 ‚Äî ALL 6 DIRECTIONS


[MM Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [07:58<00:00,  5.88it/s, loss=0.746]


   üîª Multimodal avg TRAIN loss: 0.9339





   üìä Multimodal BLEU by direction:
     ‚Ä¢ en->de: 40.93
     ‚Ä¢ en->fr: 51.79
     ‚Ä¢ de->en: 43.78
     ‚Ä¢ de->fr: 38.27
     ‚Ä¢ fr->en: 50.71
     ‚Ä¢ fr->de: 33.51
   üîµ Multimodal OVERALL BLEU: 43.65
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_all6_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_all6_mm_best.pt

üìç [MULTIMODAL] Epoch 5/6 ‚Äî ALL 6 DIRECTIONS


[MM Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [07:58<00:00,  5.88it/s, loss=0.732]


   üîª Multimodal avg TRAIN loss: 0.9157





   üìä Multimodal BLEU by direction:
     ‚Ä¢ en->de: 41.01
     ‚Ä¢ en->fr: 52.03
     ‚Ä¢ de->en: 44.97
     ‚Ä¢ de->fr: 39.68
     ‚Ä¢ fr->en: 50.93
     ‚Ä¢ fr->de: 33.81
   üîµ Multimodal OVERALL BLEU: 44.22
   üíæ Saved best MULTIMODAL model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_all6_mm_best.pt
   üíæ Saved best MULTIMODAL model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_all6_mm_best.pt

üìç [MULTIMODAL] Epoch 6/6 ‚Äî ALL 6 DIRECTIONS


[MM Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [07:57<00:00,  5.89it/s, loss=0.985]


   üîª Multimodal avg TRAIN loss: 0.9046





   üìä Multimodal BLEU by direction:
     ‚Ä¢ en->de: 40.48
     ‚Ä¢ en->fr: 52.81
     ‚Ä¢ de->en: 44.72
     ‚Ä¢ de->fr: 39.91
     ‚Ä¢ fr->en: 50.28
     ‚Ä¢ fr->de: 34.42
   üîµ Multimodal OVERALL BLEU: 44.25
   ‚è∏ No BLEU improvement. patience=1/3
‚úÖ Finished MULTIMODAL training | Best OVERALL BLEU: 44.22

üöÄ Starting TEXT-ONLY training...
üîÑ Loading text-only mBART-50 many-to-many...
‚úÖ LoRA applied to mBART (targets: ['q_proj', 'v_proj'] )
trainable params: 1,179,648 || all params: 612,059,136 || trainable%: 0.1927

üìç [TEXT-ONLY] Epoch 1/6 ‚Äî ALL 6 DIRECTIONS


[TXT Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [05:52<00:00,  7.99it/s, loss=1.35]


   üîª Text-only avg TRAIN loss: 1.1721





   üìä Text-only BLEU by direction:
     ‚Ä¢ en->de: 37.92
     ‚Ä¢ en->fr: 45.61
     ‚Ä¢ de->en: 41.47
     ‚Ä¢ de->fr: 30.87
     ‚Ä¢ fr->en: 45.44
     ‚Ä¢ fr->de: 28.68
   üîµ Text-only OVERALL BLEU: 38.67
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_all6_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_all6_text_best.pt

üìç [TEXT-ONLY] Epoch 2/6 ‚Äî ALL 6 DIRECTIONS


[TXT Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [05:58<00:00,  7.85it/s, loss=1.15]


   üîª Text-only avg TRAIN loss: 1.0541





   üìä Text-only BLEU by direction:
     ‚Ä¢ en->de: 38.99
     ‚Ä¢ en->fr: 46.60
     ‚Ä¢ de->en: 41.84
     ‚Ä¢ de->fr: 33.11
     ‚Ä¢ fr->en: 46.13
     ‚Ä¢ fr->de: 30.38
   üîµ Text-only OVERALL BLEU: 39.91
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_all6_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_all6_text_best.pt

üìç [TEXT-ONLY] Epoch 3/6 ‚Äî ALL 6 DIRECTIONS


[TXT Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [05:53<00:00,  7.97it/s, loss=1.02]


   üîª Text-only avg TRAIN loss: 1.0238





   üìä Text-only BLEU by direction:
     ‚Ä¢ en->de: 39.09
     ‚Ä¢ en->fr: 48.03
     ‚Ä¢ de->en: 42.64
     ‚Ä¢ de->fr: 34.69
     ‚Ä¢ fr->en: 46.58
     ‚Ä¢ fr->de: 31.22
   üîµ Text-only OVERALL BLEU: 40.80
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_all6_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_all6_text_best.pt

üìç [TEXT-ONLY] Epoch 4/6 ‚Äî ALL 6 DIRECTIONS


[TXT Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [05:57<00:00,  7.88it/s, loss=1.04]


   üîª Text-only avg TRAIN loss: 1.0067





   üìä Text-only BLEU by direction:
     ‚Ä¢ en->de: 40.03
     ‚Ä¢ en->fr: 48.66
     ‚Ä¢ de->en: 42.89
     ‚Ä¢ de->fr: 34.79
     ‚Ä¢ fr->en: 46.91
     ‚Ä¢ fr->de: 30.67
   üîµ Text-only OVERALL BLEU: 41.06
   ‚è∏ No BLEU improvement. patience=1/3

üìç [TEXT-ONLY] Epoch 5/6 ‚Äî ALL 6 DIRECTIONS


[TXT Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [05:54<00:00,  7.94it/s, loss=1.3]


   üîª Text-only avg TRAIN loss: 0.9959





   üìä Text-only BLEU by direction:
     ‚Ä¢ en->de: 39.74
     ‚Ä¢ en->fr: 49.67
     ‚Ä¢ de->en: 42.80
     ‚Ä¢ de->fr: 35.24
     ‚Ä¢ fr->en: 46.75
     ‚Ä¢ fr->de: 31.83
   üîµ Text-only OVERALL BLEU: 41.45
   üíæ Saved best TEXT-ONLY model (local) ‚Üí /content/multimodal_translation_models_siglip_lora_fusion/mbart_lora_all6_text_best.pt
   üíæ Saved best TEXT-ONLY model (drive) ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/mbart_lora_all6_text_best.pt

üìç [TEXT-ONLY] Epoch 6/6 ‚Äî ALL 6 DIRECTIONS


[TXT Train MIXED]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2813/2813 [05:59<00:00,  7.82it/s, loss=1.05]


   üîª Text-only avg TRAIN loss: 0.9905





   üìä Text-only BLEU by direction:
     ‚Ä¢ en->de: 40.10
     ‚Ä¢ en->fr: 49.97
     ‚Ä¢ de->en: 42.53
     ‚Ä¢ de->fr: 35.20
     ‚Ä¢ fr->en: 46.97
     ‚Ä¢ fr->de: 31.78
   üîµ Text-only OVERALL BLEU: 41.55
   ‚è∏ No BLEU improvement. patience=1/3
‚úÖ Finished TEXT-ONLY training | Best OVERALL BLEU: 41.45

üìä FINAL BEST BLEU SCORES
   üåà Multimodal (ALL dirs): 44.22
   ‚ú® Text-only (ALL dirs):  41.45

üéâ Training Completed Successfully!


In [None]:
#testing with test data

In [None]:
# ================================================================
# üåç FINAL EVALUATION SCRIPT (ALL 6 DIRECTIONS)
# - Multimodal vs Text-only
# - Loads all6_best models
# - Evaluates test_2016_flickr (or 2017)
# - Saves BLEU + JSON + TXT
# ================================================================

import os
import json
from pathlib import Path
from typing import List, Tuple
import torch
import torch.nn as nn
from PIL import Image, ImageFile
import evaluate
from tqdm import tqdm

ImageFile.LOAD_TRUNCATED_IMAGES = True

# ------------------ HF + PEFT imports ------------------
try:
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
    )
    from peft import LoraConfig, get_peft_model, TaskType
except:
    %pip install -q transformers peft accelerate
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
    )
    from peft import LoraConfig, get_peft_model, TaskType

# ------------------ MOUNT DRIVE ------------------
from google.colab import drive
drive.mount("/content/drive")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ================================================================
# üîß PATHS
# ================================================================
ROOT = Path("/content/drive/MyDrive")

MODEL_DIR = ROOT / "multimodal_translation_models_siglip_lora_fusion"
DATASET = ROOT / "dataset/multi30k-dataset"
TEST_DIR = DATASET / "data/task1/raw/test_2016_flickr"
SPLIT_FILE = DATASET / "data/task1/image_splits/test_2016_flickr.txt"
IMG_ROOT = DATASET / "flickr30k-images"

OUT_DIR = MODEL_DIR / "test2016_all6_eval"
OUT_DIR.mkdir(parents=True, exist_ok=True)

print("MODEL_DIR:", MODEL_DIR)
print("TEST_DIR:", TEST_DIR)
print("OUT_DIR:", OUT_DIR)

# ================================================================
# LOAD CONFIG
# ================================================================
cfg_path = MODEL_DIR / "config_siglip_fusion_lora_all6.json"
config_dict = json.load(open(cfg_path))
import types
config = types.SimpleNamespace(**config_dict)

LANG_CODES = {"en": "en_XX", "de": "de_DE", "fr": "fr_XX"}

DIRECTIONS = [
    ("en", "de"),
    ("en", "fr"),
    ("de", "en"),
    ("de", "fr"),
    ("fr", "en"),
    ("fr", "de"),
]

# ================================================================
# SAFE IMAGE LOADER
# ================================================================
def safe_load_image(image_id: str, root: Path):
    base = image_id.strip()
    for ext in [".jpg", ".jpeg", ".png"]:
        if base.endswith(ext):
            base = base[:-len(ext)]
            break
    for name in [base+".jpg", base+".jpeg", base+".png"]:
        fp = root / name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except:
                pass
    return Image.new("RGB", (224,224), (128,128,128))

# ================================================================
# LORA APPLIER
# ================================================================
def apply_lora_to_mbart(mbart):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_targets,
    )
    return get_peft_model(mbart, cfg)

# ================================================================
# FUSION MODEL
# ================================================================
class FusionBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8, dim_feedforward=2048,
            dropout=0.1, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_emb, txt_emb):
        x = torch.cat([img_emb, txt_emb], dim=1)
        return self.encoder(x)

class MultiModalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = SiglipVisionModel.from_pretrained(config.vision_model_name)
        for p in self.vision.parameters(): p.requires_grad = False

        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base)
        self.text_emb = self.mbart.get_input_embeddings()

        vis_dim = self.vision.config.hidden_size
        self.proj = nn.Linear(vis_dim, self.mbart.config.d_model)

        self.fusion = FusionBlock(self.mbart.config.d_model)

    def generate(self, input_ids, mask, pixel, tok):
        with torch.no_grad():
            vis = self.vision(pixel_values=pixel).last_hidden_state[:,0,:]
        img = self.proj(vis).unsqueeze(1)
        txt = self.text_emb(input_ids)
        fused = self.fusion(img, txt)
        fused_mask = torch.cat([torch.ones((input_ids.size(0),1),device=device), mask], dim=1)

        return self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            max_length=config.max_length,
            num_beams=4,
            forced_bos_token_id=tok.lang_code_to_id[tok.tgt_lang]
        )

# ================================================================
# TEXT ONLY MODEL
# ================================================================
class TextOnlyModel(nn.Module):
    def __init__(self):
        super().__init__()
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base)

    def generate(self, input_ids, mask, tok):
        return self.mbart.generate(
            input_ids=input_ids,
            attention_mask=mask,
            max_length=config.max_length,
            num_beams=4,
            forced_bos_token_id=tok.lang_code_to_id[tok.tgt_lang]
        )

# ================================================================
# LOAD TEST SPLIT
# ================================================================
def load_test(src, tgt):
    ids = open(SPLIT_FILE).read().splitlines()
    src_txt = open(TEST_DIR / f"test_2016_flickr.{src}").read().splitlines()
    tgt_txt = open(TEST_DIR / f"test_2016_flickr.{tgt}").read().splitlines()
    n = min(len(ids), len(src_txt), len(tgt_txt))
    return ids[:n], src_txt[:n], tgt_txt[:n]

# ================================================================
# EVALUATE ONE DIRECTION
# ================================================================
def evaluate_direction(src, tgt):
    print(f"\n============================")
    print(f"   EVAL: {src} ‚Üí {tgt}")
    print(f"============================")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ids, src_txt, tgt_txt = load_test(src, tgt)
    refs = [[t] for t in tgt_txt]

    # ---- LOAD MODELS ----
    mm = MultiModalModel().to(device)
    txt = TextOnlyModel().to(device)

    mm.load_state_dict(torch.load(MODEL_DIR/"siglip_fusion_lora_all6_mm_best.pt"))
    txt.load_state_dict(torch.load(MODEL_DIR/"mbart_lora_all6_text_best.pt"))

    mm.eval()
    txt.eval()

    preds_mm, preds_txt = [], []
    results = []

    for i in tqdm(range(len(ids))):
        enc = tokenizer(src_txt[i], max_length=config.max_length,
                        padding="max_length", truncation=True, return_tensors="pt")

        ids_t = enc["input_ids"].to(device)
        mask_t = enc["attention_mask"].to(device)

        img = safe_load_image(ids[i], IMG_ROOT)
        pixel = image_processor(images=img, return_tensors="pt")["pixel_values"].to(device)

        with torch.no_grad():
            mm_out = mm.generate(ids_t, mask_t, pixel, tokenizer)
            txt_out = txt.generate(ids_t, mask_t, tokenizer)

        mm_pred = tokenizer.decode(mm_out[0], skip_special_tokens=True)
        txt_pred = tokenizer.decode(txt_out[0], skip_special_tokens=True)

        preds_mm.append(mm_pred)
        preds_txt.append(txt_pred)

        results.append({
            "image_id": ids[i],
            "source": src_txt[i],
            "target": tgt_txt[i],
            "mm": mm_pred,
            "txt": txt_pred,
        })

    # ---- SAVE JSON + TXT ----
    json_path = OUT_DIR / f"pred_{src}_{tgt}.json"
    txt_path = OUT_DIR / f"pred_{src}_{tgt}.txt"

    json.dump(results, open(json_path,"w"), indent=2, ensure_ascii=False)

    with open(txt_path,"w") as f:
        for r in results:
            f.write(f"{r['image_id']}\nSRC: {r['source']}\nTRG: {r['target']}\nMM:  {r['mm']}\nTXT: {r['txt']}\n")
            f.write("-"*50+"\n")

    print("Saved:", json_path)
    print("Saved:", txt_path)

    sacre = evaluate.load("sacrebleu")
    bleu_mm  = sacre.compute(predictions=preds_mm,  references=refs)["score"]
    bleu_txt = sacre.compute(predictions=preds_txt, references=refs)["score"]

    print(f"BLEU (MM ): {bleu_mm:.2f}")
    print(f"BLEU (TXT): {bleu_txt:.2f}")

    return bleu_mm, bleu_txt

# ================================================================
# MAIN LOOP
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
image_processor = SiglipProcessor.from_pretrained(config.vision_model_name)

all_scores = {}

for src, tgt in DIRECTIONS:
    mm, txt = evaluate_direction(src, tgt)
    all_scores[f"{src}->{tgt}"] = {"mm": mm, "txt": txt}

json.dump(all_scores, open(OUT_DIR/"bleu_scores_all6.json","w"), indent=2)
print("\n==============================")
print("FINAL BLEU SCORES (ALL 6)")
print("==============================")
print(json.dumps(all_scores, indent=2))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
MODEL_DIR: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion
TEST_DIR: /content/drive/MyDrive/dataset/multi30k-dataset/data/task1/raw/test_2016_flickr
OUT_DIR: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


preprocessor_config.json:   0%|          | 0.00/368 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/711 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/798k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]


   EVAL: en ‚Üí de


config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/813M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [29:11<00:00,  1.75s/it]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_en_de.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_en_de.txt


Downloading builder script: 0.00B [00:00, ?B/s]

BLEU (MM ): 37.74
BLEU (TXT): 37.18

   EVAL: en ‚Üí fr


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [17:41<00:00,  1.06s/it]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_en_fr.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_en_fr.txt
BLEU (MM ): 50.71
BLEU (TXT): 47.77

   EVAL: de ‚Üí en


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [14:48<00:00,  1.13it/s]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_de_en.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_de_en.txt
BLEU (MM ): 44.38
BLEU (TXT): 43.73

   EVAL: de ‚Üí fr


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [17:24<00:00,  1.04s/it]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_de_fr.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_de_fr.txt
BLEU (MM ): 37.67
BLEU (TXT): 34.54

   EVAL: fr ‚Üí en


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [15:02<00:00,  1.11it/s]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_fr_en.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_fr_en.txt
BLEU (MM ): 52.08
BLEU (TXT): 48.28

   EVAL: fr ‚Üí de


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [16:36<00:00,  1.00it/s]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_fr_de.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_fr_de.txt
BLEU (MM ): 32.82
BLEU (TXT): 30.40

FINAL BLEU SCORES (ALL 6)
{
  "en->de": {
    "mm": 37.737588759500774,
    "txt": 37.184791873340004
  },
  "en->fr": {
    "mm": 50.71343796951792,
    "txt": 47.770452343421255
  },
  "de->en": {
    "mm": 44.381857532012006,
    "txt": 43.73382710635367
  },
  "de->fr": {
    "mm": 37.6747671364791,
    "txt": 34.544032013223585
  },
  "fr->en": {
    "mm": 52.083662830704604,
    "txt": 48.28086343390184
  },
  "fr->de": {
    "mm": 32.82178041558421,
    "txt": 30.40316470775001
  }
}


In [None]:
# ================================================================
# üåç FINAL EVALUATION SCRIPT (ALL 6 DIRECTIONS)
# - Multimodal vs Text-only
# - Loads all6_best models
# - Evaluates test_2016_flickr (or 2017)
# - Saves BLEU + JSON + TXT
# ================================================================

import os
import json
from pathlib import Path
from typing import List, Tuple
import torch
import torch.nn as nn
from PIL import Image, ImageFile
import evaluate
from tqdm import tqdm

ImageFile.LOAD_TRUNCATED_IMAGES = True

# ------------------ HF + PEFT imports ------------------
try:
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
    )
    from peft import LoraConfig, get_peft_model, TaskType
except:
    %pip install -q transformers peft accelerate
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
    )
    from peft import LoraConfig, get_peft_model, TaskType

# ------------------ MOUNT DRIVE ------------------
from google.colab import drive
drive.mount("/content/drive")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ================================================================
# üîß PATHS
# ================================================================
ROOT = Path("/content/drive/MyDrive")

MODEL_DIR = ROOT / "multimodal_translation_models_siglip_lora_fusion"
DATASET = ROOT / "dataset/multi30k-dataset"
TEST_DIR = DATASET / "data/task1/raw/test_2017_flickr"
SPLIT_FILE = DATASET / "data/task1/image_splits/test_2017_flickr.txt"
IMG_ROOT = DATASET / "flickr30k-images"

OUT_DIR = MODEL_DIR / "test2016_all6_eval"
OUT_DIR.mkdir(parents=True, exist_ok=True)

print("MODEL_DIR:", MODEL_DIR)
print("TEST_DIR:", TEST_DIR)
print("OUT_DIR:", OUT_DIR)

# ================================================================
# LOAD CONFIG
# ================================================================
cfg_path = MODEL_DIR / "config_siglip_fusion_lora_all6.json"
config_dict = json.load(open(cfg_path))
import types
config = types.SimpleNamespace(**config_dict)

LANG_CODES = {"en": "en_XX", "de": "de_DE", "fr": "fr_XX"}

DIRECTIONS = [
    ("en", "de"),
    ("en", "fr"),
    ("de", "en"),
    ("de", "fr"),
    ("fr", "en"),
    ("fr", "de"),
]

# ================================================================
# SAFE IMAGE LOADER
# ================================================================
def safe_load_image(image_id: str, root: Path):
    base = image_id.strip()
    for ext in [".jpg", ".jpeg", ".png"]:
        if base.endswith(ext):
            base = base[:-len(ext)]
            break
    for name in [base+".jpg", base+".jpeg", base+".png"]:
        fp = root / name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except:
                pass
    return Image.new("RGB", (224,224), (128,128,128))

# ================================================================
# LORA APPLIER
# ================================================================
def apply_lora_to_mbart(mbart):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_targets,
    )
    return get_peft_model(mbart, cfg)

# ================================================================
# FUSION MODEL
# ================================================================
class FusionBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8, dim_feedforward=2048,
            dropout=0.1, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_emb, txt_emb):
        x = torch.cat([img_emb, txt_emb], dim=1)
        return self.encoder(x)

class MultiModalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = SiglipVisionModel.from_pretrained(config.vision_model_name)
        for p in self.vision.parameters(): p.requires_grad = False

        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base)
        self.text_emb = self.mbart.get_input_embeddings()

        vis_dim = self.vision.config.hidden_size
        self.proj = nn.Linear(vis_dim, self.mbart.config.d_model)

        self.fusion = FusionBlock(self.mbart.config.d_model)

    def generate(self, input_ids, mask, pixel, tok):
        with torch.no_grad():
            vis = self.vision(pixel_values=pixel).last_hidden_state[:,0,:]
        img = self.proj(vis).unsqueeze(1)
        txt = self.text_emb(input_ids)
        fused = self.fusion(img, txt)
        fused_mask = torch.cat([torch.ones((input_ids.size(0),1),device=device), mask], dim=1)

        return self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            max_length=config.max_length,
            num_beams=4,
            forced_bos_token_id=tok.lang_code_to_id[tok.tgt_lang]
        )

# ================================================================
# TEXT ONLY MODEL
# ================================================================
class TextOnlyModel(nn.Module):
    def __init__(self):
        super().__init__()
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base)

    def generate(self, input_ids, mask, tok):
        return self.mbart.generate(
            input_ids=input_ids,
            attention_mask=mask,
            max_length=config.max_length,
            num_beams=4,
            forced_bos_token_id=tok.lang_code_to_id[tok.tgt_lang]
        )

# ================================================================
# LOAD TEST SPLIT
# ================================================================
def load_test(src, tgt):
    ids = open(SPLIT_FILE).read().splitlines()
    src_txt = open(TEST_DIR / f"test_2017_flickr.{src}").read().splitlines()
    tgt_txt = open(TEST_DIR / f"test_2017_flickr.{tgt}").read().splitlines()
    n = min(len(ids), len(src_txt), len(tgt_txt))
    return ids[:n], src_txt[:n], tgt_txt[:n]

# ================================================================
# EVALUATE ONE DIRECTION
# ================================================================
def evaluate_direction(src, tgt):
    print(f"\n============================")
    print(f"   EVAL: {src} ‚Üí {tgt}")
    print(f"============================")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ids, src_txt, tgt_txt = load_test(src, tgt)
    refs = [[t] for t in tgt_txt]

    # ---- LOAD MODELS ----
    mm = MultiModalModel().to(device)
    txt = TextOnlyModel().to(device)

    mm.load_state_dict(torch.load(MODEL_DIR/"siglip_fusion_lora_all6_mm_best.pt"))
    txt.load_state_dict(torch.load(MODEL_DIR/"mbart_lora_all6_text_best.pt"))

    mm.eval()
    txt.eval()

    preds_mm, preds_txt = [], []
    results = []

    for i in tqdm(range(len(ids))):
        enc = tokenizer(src_txt[i], max_length=config.max_length,
                        padding="max_length", truncation=True, return_tensors="pt")

        ids_t = enc["input_ids"].to(device)
        mask_t = enc["attention_mask"].to(device)

        img = safe_load_image(ids[i], IMG_ROOT)
        pixel = image_processor(images=img, return_tensors="pt")["pixel_values"].to(device)

        with torch.no_grad():
            mm_out = mm.generate(ids_t, mask_t, pixel, tokenizer)
            txt_out = txt.generate(ids_t, mask_t, tokenizer)

        mm_pred = tokenizer.decode(mm_out[0], skip_special_tokens=True)
        txt_pred = tokenizer.decode(txt_out[0], skip_special_tokens=True)

        preds_mm.append(mm_pred)
        preds_txt.append(txt_pred)

        results.append({
            "image_id": ids[i],
            "source": src_txt[i],
            "target": tgt_txt[i],
            "mm": mm_pred,
            "txt": txt_pred,
        })

    # ---- SAVE JSON + TXT ----
    json_path = OUT_DIR / f"pred_{src}_{tgt}.json"
    txt_path = OUT_DIR / f"pred_{src}_{tgt}.txt"

    json.dump(results, open(json_path,"w"), indent=2, ensure_ascii=False)

    with open(txt_path,"w") as f:
        for r in results:
            f.write(f"{r['image_id']}\nSRC: {r['source']}\nTRG: {r['target']}\nMM:  {r['mm']}\nTXT: {r['txt']}\n")
            f.write("-"*50+"\n")

    print("Saved:", json_path)
    print("Saved:", txt_path)

    sacre = evaluate.load("sacrebleu")
    bleu_mm  = sacre.compute(predictions=preds_mm,  references=refs)["score"]
    bleu_txt = sacre.compute(predictions=preds_txt, references=refs)["score"]

    print(f"BLEU (MM ): {bleu_mm:.2f}")
    print(f"BLEU (TXT): {bleu_txt:.2f}")

    return bleu_mm, bleu_txt

# ================================================================
# MAIN LOOP
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
image_processor = SiglipProcessor.from_pretrained(config.vision_model_name)

all_scores = {}

for src, tgt in DIRECTIONS:
    mm, txt = evaluate_direction(src, tgt)
    all_scores[f"{src}->{tgt}"] = {"mm": mm, "txt": txt}

json.dump(all_scores, open(OUT_DIR/"bleu_scores_all6.json","w"), indent=2)
print("\n==============================")
print("FINAL BLEU SCORES (ALL 6)")
print("==============================")
print(json.dumps(all_scores, indent=2))


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Using device: cuda
MODEL_DIR: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion
TEST_DIR: /content/drive/MyDrive/dataset/multi30k-dataset/data/task1/raw/test_2017_flickr
OUT_DIR: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval

   EVAL: en ‚Üí de


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [15:01<00:00,  1.11it/s]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_en_de.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_en_de.txt
BLEU (MM ): 35.61
BLEU (TXT): 34.50

   EVAL: en ‚Üí fr


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [16:19<00:00,  1.02it/s]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_en_fr.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_en_fr.txt
BLEU (MM ): 47.25
BLEU (TXT): 44.47

   EVAL: de ‚Üí en


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [13:42<00:00,  1.22it/s]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_de_en.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_de_en.txt
BLEU (MM ): 44.63
BLEU (TXT): 43.97

   EVAL: de ‚Üí fr


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [15:59<00:00,  1.04it/s]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_de_fr.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_de_fr.txt
BLEU (MM ): 34.07
BLEU (TXT): 31.50

   EVAL: fr ‚Üí en


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [13:50<00:00,  1.20it/s]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_fr_en.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_fr_en.txt
BLEU (MM ): 49.47
BLEU (TXT): 48.20

   EVAL: fr ‚Üí de


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [15:17<00:00,  1.09it/s]


Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_fr_de.json
Saved: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/test2016_all6_eval/pred_fr_de.txt
BLEU (MM ): 28.40
BLEU (TXT): 26.32

FINAL BLEU SCORES (ALL 6)
{
  "en->de": {
    "mm": 35.61013176172457,
    "txt": 34.500881811447705
  },
  "en->fr": {
    "mm": 47.24741638028822,
    "txt": 44.47445665424448
  },
  "de->en": {
    "mm": 44.63137357351859,
    "txt": 43.97010664008739
  },
  "de->fr": {
    "mm": 34.066932741835714,
    "txt": 31.504558856623845
  },
  "fr->en": {
    "mm": 49.46800114316033,
    "txt": 48.197099758381796
  },
  "fr->de": {
    "mm": 28.40282792786313,
    "txt": 26.319047874492725
  }
}


In [None]:
## Testing the models with e-commerce dataset (Test)

In [None]:
# ================================================================
# üîç EVALUATE PRETRAINED EN‚ÜíDE MODELS ON E-COMMERCE DATA (NO TRAINING)
# ================================================================

import os
import json
from pathlib import Path
import types
import pandas as pd
from typing import Any, Dict, List

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from tqdm import tqdm
import evaluate

# ------------------ HF + PEFT ------------------
from transformers import (
    MBart50TokenizerFast, MBartForConditionalGeneration,
    SiglipVisionModel, SiglipProcessor
)
from peft import LoraConfig, get_peft_model, TaskType

# BLEU metric
sacrebleu = evaluate.load("sacrebleu")

# ------------------ COLAB DRIVE ------------------
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ================================================================
# PATHS
# ================================================================
# PRETRAINED MODEL DIRECTORY (shared drive actual path)
MODEL_DIR = Path(
    "/content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/"
    "multimodal_translation_models_siglip_lora_fusion"
)

print("MODEL_DIR:", MODEL_DIR)

# E-commerce dataset
ECOMM_TSV = (
    "/content/drive/MyDrive/Dataset/ImageGuidedTranslationDataset-main/dataset/listingtitle-image-mappings/"
    "listingtitles_with_matched_images.en-de.tsv"
)
ECOMM_IMG_DIR = (
    "/content/drive/MyDrive/Dataset/ImageGuidedTranslationDataset-main/dataset/images"
)

# Output directory
EVAL_DIR = MODEL_DIR / "ecomm_eval_en_de"
EVAL_DIR.mkdir(parents=True, exist_ok=True)
print("Saving eval to:", EVAL_DIR)

# ================================================================
# LOAD TRAINING CONFIG (LoRA hyperparams)
# ================================================================
cfg_path = MODEL_DIR / "config_siglip_fusion_lora_all6.json"
config_dict = json.load(open(cfg_path))
config = types.SimpleNamespace(**config_dict)

MAX_LEN = config.max_length
vision_model_name = config.vision_model_name
LANG_CODES = {"en": "en_XX", "de": "de_DE"}

# ================================================================
# SAFE IMAGE LOADER
# ================================================================
def safe_load_image(filename: Any):
    filename = str(filename)
    for split in ["train", "val", "test"]:
        fp = Path(ECOMM_IMG_DIR) / split / filename
        if fp.exists():
            try: return Image.open(fp).convert("RGB")
            except: pass
    return Image.new("RGB", (224, 224), (128, 128, 128))

# ================================================================
# LORA / FUSION MODEL DEFINITIONS
# ================================================================
class FusionBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8, dim_feedforward=2048,
            dropout=0.1, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_embed, text_embed):
        x = torch.cat([img_embed, text_embed], dim=1)
        return self.encoder(x)

def apply_lora_to_mbart(mbart):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora_r, lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_targets
    )
    return get_peft_model(mbart, cfg)

class MultiModalModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = SiglipVisionModel.from_pretrained(vision_model_name)
        for p in self.vision.parameters():
            p.requires_grad = False

        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base)
        self.text_emb = self.mbart.get_input_embeddings()

        self.proj = nn.Linear(
            self.vision.config.hidden_size,
            self.mbart.config.d_model
        )
        self.fusion = FusionBlock(self.mbart.config.d_model)

    def generate(self, input_ids, mask, pixel_values, tokenizer):
        with torch.no_grad():
            vis = self.vision(pixel_values=pixel_values).last_hidden_state[:,0,:]

        img_embed = self.proj(vis).unsqueeze(1)
        txt_embed = self.text_emb(input_ids)

        fused = self.fusion(img_embed, txt_embed)
        fused_mask = torch.cat(
            [torch.ones((input_ids.size(0),1), device=device), mask], dim=1
        )

        return self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            num_beams=5,
            max_length=MAX_LEN,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
        )

class TextOnlyModel(nn.Module):
    def __init__(self):
        super().__init__()
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base)

    def generate(self, input_ids, mask, tokenizer):
        return self.mbart.generate(
            input_ids=input_ids,
            attention_mask=mask,
            num_beams=5,
            max_length=MAX_LEN,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
        )

# ================================================================
# LOAD TOKENIZER + SIGLIP PROCESSOR
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
image_processor = SiglipProcessor.from_pretrained(vision_model_name)

tokenizer.src_lang = LANG_CODES["en"]
tokenizer.tgt_lang = LANG_CODES["de"]

# ================================================================
# LOAD ONLY VAL SPLIT OF E-COMMERCE
# ================================================================
df = pd.read_csv(ECOMM_TSV, sep="\t")
df = df[df["set_name"].str.lower().isin(["val", "validation", "valid"])]
df = df.reset_index(drop=True)

print("VAL rows:", len(df))

# ================================================================
# LOAD MODELS
# ================================================================
mm_ckpt  = MODEL_DIR / "siglip_fusion_lora_en_de_mm_best.pt"
txt_ckpt = MODEL_DIR / "mbart_lora_en_de_text_best.pt"

print("Loading MM:", mm_ckpt)
print("Loading TXT:", txt_ckpt)

mm_model = MultiModalModel().to(device)
txt_model = TextOnlyModel().to(device)

mm_model.load_state_dict(torch.load(mm_ckpt, map_location=device))
txt_model.load_state_dict(torch.load(txt_ckpt, map_location=device))

mm_model.eval()
txt_model.eval()

# ================================================================
# RUN EVALUATION
# ================================================================
preds_mm = []
preds_txt = []
refs = []
srcs = []
imgs = []

for i in tqdm(range(len(df)), desc="Evaluating EN‚ÜíDE"):
    row = df.iloc[i]
    src = str(row["source"])
    tgt = str(row["target"])
    img_file = row["image_file"]

    refs.append(tgt)
    srcs.append(src)
    imgs.append(img_file)

    enc = tokenizer(src, padding="max_length", truncation=True,
                    max_length=MAX_LEN, return_tensors="pt").to(device)

    # ---- multimodal ----
    img = safe_load_image(img_file)
    pixel = image_processor(images=[img], return_tensors="pt")["pixel_values"].to(device)
    gen_mm = mm_model.generate(enc["input_ids"], enc["attention_mask"], pixel, tokenizer)
    preds_mm.append(tokenizer.decode(gen_mm[0], skip_special_tokens=True))

    # ---- text-only ----
    gen_txt = txt_model.generate(enc["input_ids"], enc["attention_mask"], tokenizer)
    preds_txt.append(tokenizer.decode(gen_txt[0], skip_special_tokens=True))

# ================================================================
# COMPUTE BLEU
# ================================================================
bleu_mm = sacrebleu.compute(predictions=preds_mm,  references=[refs])["score"]
bleu_txt = sacrebleu.compute(predictions=preds_txt, references=[refs])["score"]

print("\n==============================")
print("‚≠ê EN ‚Üí DE FINAL BLEU SCORES")
print("==============================")
print(f"Multimodal: {bleu_mm:.2f}")
print(f"Text-only: {bleu_txt:.2f}")

# ================================================================
# SAVE RESULTS
# ================================================================
pd.DataFrame({
    "src": srcs,
    "gold": refs,
    "mm_pred": preds_mm,
    "txt_pred": preds_txt,
    "image_file": imgs
}).to_csv(EVAL_DIR / "preds_en_de.tsv", sep="\t", index=False)

with open(EVAL_DIR / "bleu_en_de.txt", "w") as f:
    f.write(f"MM:  {bleu_mm:.4f}\nTXT: {bleu_txt:.4f}\n")

print("Saved predictions and BLEU scores in:", EVAL_DIR)


Mounted at /content/drive
Using device: cuda
MODEL_DIR: /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion
Saving eval to: /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_eval_en_de


FileNotFoundError: [Errno 2] No such file or directory: '/content/drive/MyDrive/Dataset/ImageGuidedTranslationDataset-main/dataset/listingtitle-image-mappings/listingtitles_with_matched_images.en-de.tsv'

In [None]:
#Translation

In [None]:
# ================================================================
# Multimodal EN+DE+Image ‚Üí FR translation and TSV augmentation
# - Robust to broken/missing Google Drive image files
# ================================================================

!pip install -q openai pillow pandas tqdm

from openai import OpenAI
import base64
import pandas as pd
from tqdm import tqdm
from PIL import Image, ImageFile
from pathlib import Path
import torch

import os

os.environ["OPENAI_API_KEY"] = ""



ImageFile.LOAD_TRUNCATED_IMAGES = True

client = OpenAI()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ------------------------------------------------
# 1. Helper: encode image as base64
# ------------------------------------------------
def encode_image(image_path: str) -> str:
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


# ------------------------------------------------
# 2. Multimodal translator EN+DE+Image ‚Üí FR
# ------------------------------------------------
def multimodal_translate_to_french(en_caption: str, de_caption: str, image_path: str) -> str:
    """
    Uses GPT-4o multimodal to produce a single French caption
    based on EN + DE + product image.
    """
    img_b64 = encode_image(image_path)

    system_prompt = """
You are a multimodal product translation system.

You receive:
- A product image
- An English caption
- A German caption

Your job:
- Understand the image (category, color, materials, brand, attributes)
- Merge meaning from both English + German text
- Correct mistakes using the image
- Output ONLY one final French translation
- No explanation, no analysis ‚Äî only the translated French text.
"""

    user_prompt = f"""
English caption: {en_caption}
German caption: {de_caption}
"""

    resp = client.responses.create(
        model="gpt-4o",
        input=[
            {
                "role": "system",
                "content": [{"type": "input_text", "text": system_prompt}]
            },
            {
                "role": "user",
                "content": [
                    {"type": "input_text", "text": user_prompt},
                    {
                        "type": "input_image",
                        "image_url": f"data:image/jpeg;base64,{img_b64}"
                    }
                ]
            }
        ]
    )

    return resp.output_text.strip()


# ------------------------------------------------
# 3. Robust image finder (no crashing on I/O errors)
# ------------------------------------------------
def find_or_placeholder_image(images_root: str, image_filename: str, placeholder_path: Path) -> Path:
    """
    Try to locate image_filename in train/val/test under images_root.
    If anything fails (I/O error, missing file, etc.), return placeholder.
    """
    if image_filename is None or str(image_filename).lower() == "nan":
        return placeholder_path

    image_filename = str(image_filename).strip()
    base = Path(images_root)

    for split in ["train", "val", "test"]:
        candidate = base / split / image_filename
        try:
            # Even .is_file() may raise OSError on Google Drive ‚Üí catch it
            if candidate.is_file():
                return candidate
        except OSError:
            # Skip problematic path and continue searching
            continue

    # If we reach here, no usable image was found
    return placeholder_path


# ------------------------------------------------
# 4. Main function: add French column to TSV
# ------------------------------------------------
def add_french_to_tsv(tsv_path: str, images_root: str, limit: int | None = None) -> str:
    """
    - Loads existing TSV (EN/DE/image)
    - Adds 'french' column generated by GPT-4o using EN+DE+Image
    - Saves new TSV in same folder with *_with_french.tsv suffix
    """
    tsv_path = Path(tsv_path)
    print("Loading TSV:", tsv_path)

    df = pd.read_csv(tsv_path, sep="\t")

    if limit is not None:
        df = df.head(limit)

    # Prepare a persistent placeholder image
    placeholder = Path("placeholder_grey_224.jpg")
    if not placeholder.exists():
        Image.new("RGB", (224, 224), (128, 128, 128)).save(placeholder)

    french_captions: list[str] = []

    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Translating"):
        en_caption = str(row["source"])
        de_caption = str(row["target"])
        image_file = row.get("image_file", None)

        # Find or fallback to placeholder (no I/O crash)
        img_path = find_or_placeholder_image(images_root, image_file, placeholder)

        try:
            fr_caption = multimodal_translate_to_french(
                en_caption=en_caption,
                de_caption=de_caption,
                image_path=str(img_path),
            )
        except Exception as e:
            # In case of any model/API error, fall back to EN or DE
            print(f"\n[WARN] Row {idx} translation failed: {e}")
            fr_caption = en_caption  # or f"{en_caption} / {de_caption}"

        french_captions.append(fr_caption)

    df["french"] = french_captions

    out_path = tsv_path.with_name(tsv_path.stem + "_with_french.tsv")
    df.to_csv(out_path, sep="\t", index=False)

    print("\n‚úÖ Saved updated TSV with French translations at:")
    print(str(out_path))

    return str(out_path)


# ------------------------------------------------
# 5. Run on your paths
# ------------------------------------------------
TSV = "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/listingtitle-image-mappings/listingtitles_with_matched_images.en-de.tsv"
IMAGES = "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/images"

# Test on a small subset first
out_test = add_french_to_tsv(TSV, IMAGES, limit=7500)
print("Test file:", out_test)

# When happy, run on full dataset (commented out for now)
# out_full = add_french_to_tsv(TSV, IMAGES, limit=None)
# print("Full file:", out_full)


Using device: cuda
Loading TSV: /content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/listingtitle-image-mappings/listingtitles_with_matched_images.en-de.tsv


Translating: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [11:33:12<00:00,  5.55s/it]


‚úÖ Saved updated TSV with French translations at:
/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/listingtitle-image-mappings/listingtitles_with_matched_images.en-de_with_french.tsv
Test file: /content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/listingtitle-image-mappings/listingtitles_with_matched_images.en-de_with_french.tsv





In [None]:
#French Training

In [None]:
# ================================================================
# PART 1 ‚Äî IMPORTS, CONFIG, PATHS, UTILITIES
# ================================================================

import os
import json
from pathlib import Path
from typing import Any, Dict, List
import shutil

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from tqdm import tqdm

# ------------------ Install Dependencies if Missing ------------------
try:
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
    )
    from peft import LoraConfig, get_peft_model, TaskType
    import evaluate
except:
    !pip install -q transformers peft accelerate sentencepiece evaluate
    from transformers import (
        MBart50TokenizerFast,
        MBartForConditionalGeneration,
        SiglipVisionModel,
        SiglipProcessor,
    )
    from peft import LoraConfig, get_peft_model, TaskType
    import evaluate

sacrebleu = evaluate.load("sacrebleu")

# ------------------ MOUNT DRIVE ------------------
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ------------------ IMPORTANT PATHS ------------------
BASE_DIR = Path("/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion")

ECOMM_TSV = (
    "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/listingtitle-image-mappings/"
    "listingtitles_with_matched_images.en-de_with_french.tsv"
)

IMG_ROOT = "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/images"

OUT_DIR = BASE_DIR / "ecomm_finetuned"
OUT_DIR.mkdir(exist_ok=True, parents=True)

EVAL_DIR = OUT_DIR / "evals"
EVAL_DIR.mkdir(exist_ok=True)

# ------------------ TRAINING CONFIG ------------------
BATCH_SIZE  = 8
MAX_LEN     = 128
LR          = 2e-4
EPOCHS      = 6

MAX_TRAIN_SAMPLES = 15000
MAX_VAL_SAMPLES   = 2000

LANG_CODES = {
    "en": "en_XX",
    "de": "de_DE",
    "fr": "fr_XX"
}

VISION_MODEL_NAME = "google/siglip-base-patch16-224"   # Multi30K-compatible

# ------------------ SAFE IMAGE LOADING ------------------
def safe_load_image(image_name):
    if not isinstance(image_name, str):
        return Image.new("RGB", (224,224), (128,128,128))

    image_name = image_name.strip()

    for split in ["train", "val", "test"]:
        fp = Path(IMG_ROOT) / split / image_name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except:
                return Image.new("RGB", (224,224), (128,128,128))

    return Image.new("RGB", (224,224), (128,128,128))

# ------------------ LORA CONFIG ------------------
def apply_lora(mbart):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )
    return get_peft_model(mbart, cfg)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Mounted at /content/drive
Using device: cuda


In [None]:
# ================================================================
# PART 2 ‚Äî MODELS, DATASET, COLLATE, TRAINING LOOP
# ================================================================

# ------------------ FUSION BLOCK ------------------
class FusionBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_emb, txt_emb):
        x = torch.cat([img_emb, txt_emb], dim=1)
        return self.encoder(x)

# ------------------ MULTIMODAL MODEL ------------------
class MultiModalModel(nn.Module):
    def __init__(self):
        super().__init__()
        print("Loading SigLIP-Base (224)...")
        self.vision = SiglipVisionModel.from_pretrained(VISION_MODEL_NAME)
        for p in self.vision.parameters():
            p.requires_grad = False

        print("Loading MBART50 + LORA...")
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora(base)
        self.text_emb = self.mbart.get_input_embeddings()

        self.proj = nn.Linear(768, self.mbart.config.d_model)
        self.fusion = FusionBlock(self.mbart.config.d_model)

    def forward(self, input_ids, attn_mask, pixel_values, labels):
        with torch.no_grad():
            vis = self.vision(pixel_values=pixel_values).last_hidden_state[:,0,:]

        img_emb = self.proj(vis).unsqueeze(1)
        txt_emb = self.text_emb(input_ids)

        fused = self.fusion(img_emb, txt_emb)

        fused_mask = torch.cat(
            [torch.ones((input_ids.size(0),1), device=device), attn_mask],
            dim=1,
        )

        out = self.mbart(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            labels=labels,
        )
        return out

    def generate(self, input_ids, attn_mask, pixel_values, tokenizer):
        with torch.no_grad():
            vis = self.vision(pixel_values=pixel_values).last_hidden_state[:,0,:]

        img_emb = self.proj(vis).unsqueeze(1)
        txt_emb = self.text_emb(input_ids)
        fused = self.fusion(img_emb, txt_emb)

        fused_mask = torch.cat(
            [torch.ones((input_ids.size(0),1), device=device), attn_mask],
            dim=1,
        )

        gen = self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            num_beams=5,
            max_length=MAX_LEN,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen

# ------------------ TEXT-ONLY MODEL ------------------
class TextOnlyModel(nn.Module):
    def __init__(self):
        super().__init__()
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora(base)

    def forward(self, input_ids, attn_mask, labels):
        return self.mbart(
            input_ids=input_ids,
            attention_mask=attn_mask,
            labels=labels,
        )

    def generate(self, input_ids, attn_mask, tokenizer):
        return self.mbart.generate(
            input_ids=input_ids,
            attention_mask=attn_mask,
            num_beams=5,
            max_length=MAX_LEN,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )

class EcommDataset(Dataset):
    def __init__(self, split, src_lang, tgt_lang, limit=None):
        df = pd.read_csv(ECOMM_TSV, sep="\t")
        df["set_name"] = df["set_name"].str.lower()

        # Filter rows that have FR data available
        df = df[df["target"].notna() & df["source"].notna()]

        # If NO VAL exists ‚Üí auto split
        has_val = any(df["set_name"].isin(["val", "valid", "validation"]))

        if not has_val:
            # Auto split 90% train, 10% val
            df = df[df["set_name"] == "train"].reset_index(drop=True)
            n = len(df)
            split_index = int(n * 0.9)

            train_df = df.iloc[:split_index].reset_index(drop=True)
            val_df = df.iloc[split_index:].reset_index(drop=True)

            if split == "train":
                df = train_df
            elif split == "val":
                df = val_df
            else:
                df = val_df  # fallback: test = val
        else:
            # Standard behavior
            if split == "train":
                df = df[df["set_name"] == "train"]
            elif split == "val":
                df = df[df["set_name"].isin(["val", "valid", "validation"])]
            else:
                df = df[df["set_name"] == "test"]

        if limit is not None and len(df) > limit:
            df = df.sample(n=limit, random_state=42)

        self.df = df.reset_index(drop=True)
        self.src = src_lang
        self.tgt = tgt_lang

        print(f"[{src_lang}->{tgt_lang}] {split} samples:", len(self.df))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        en = str(row["source"])
        fr = str(row["target"])

        if self.src == "en":
            return {"src": en, "tgt": fr, "img": row["image_file"]}
        else:
            return {"src": fr, "tgt": en, "img": row["image_file"]}


# ------------------ COLLATE FUNCTION ------------------
def make_collate(tokenizer, processor):
    def fn(batch):
        src = [b["src"] for b in batch]
        tgt = [b["tgt"] for b in batch]

        enc_src = tokenizer(
            src, padding="max_length", truncation=True, max_length=MAX_LEN, return_tensors="pt"
        )
        with tokenizer.as_target_tokenizer():
            enc_tgt = tokenizer(
                tgt, padding="max_length", truncation=True, max_length=MAX_LEN, return_tensors="pt"
            )

        labels = enc_tgt["input_ids"]
        labels[labels == tokenizer.pad_token_id] = -100

        imgs = [safe_load_image(b["img"]) for b in batch]
        pixel_values = processor(images=imgs, return_tensors="pt")["pixel_values"]

        return {
            "input_ids": enc_src["input_ids"].to(device),
            "attention_mask": enc_src["attention_mask"].to(device),
            "labels": labels.to(device),
            "pixel_values": pixel_values.to(device),
        }
    return fn

# ------------------ TRAINING LOOP ------------------
def train_one(src, tgt, model_type, tokenizer, processor):
    print(f"\n==== TRAINING {model_type.upper()} {src}->{tgt} ====")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    train_ds = EcommDataset("train", src, tgt, MAX_TRAIN_SAMPLES)
    val_ds   = EcommDataset("val", src, tgt, MAX_VAL_SAMPLES)

    loader = DataLoader(
        train_ds, batch_size=BATCH_SIZE, shuffle=True,
        collate_fn=make_collate(tokenizer, processor),
    )
    vloader = DataLoader(
        val_ds, batch_size=BATCH_SIZE, shuffle=False,
        collate_fn=make_collate(tokenizer, processor),
    )

    if model_type == "mm":
        model = MultiModalModel().to(device)
        ckpt = BASE_DIR / f"siglip_fusion_lora_{src}_{tgt}_mm_best.pt"
    else:
        model = TextOnlyModel().to(device)
        ckpt = BASE_DIR / f"mbart_lora_{src}_{tgt}_text_best.pt"

    print("Loading pretrained checkpoint:", ckpt)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad], lr=LR
    )

    best = float("inf")
    out_path = OUT_DIR / f"ecomm_{src}_{tgt}_{model_type}.pt"

    for ep in range(1, EPOCHS+1):
        print(f"\nEpoch {ep}/{EPOCHS}")

        model.train()
        total = 0
        for batch in tqdm(loader):
            optimizer.zero_grad()

            if model_type == "mm":
                out = model(
                    batch["input_ids"], batch["attention_mask"],
                    batch["pixel_values"], batch["labels"]
                )
            else:
                out = model(
                    batch["input_ids"], batch["attention_mask"], batch["labels"]
                )

            loss = out.loss
            loss.backward()
            optimizer.step()
            total += loss.item()

        print("Train loss:", total / len(loader))

        # ---- VALIDATION ----
        model.eval()
        vloss = 0
        with torch.no_grad():
            for batch in vloader:
                if model_type == "mm":
                    out = model(
                        batch["input_ids"], batch["attention_mask"],
                        batch["pixel_values"], batch["labels"]
                    )
                else:
                    out = model(
                        batch["input_ids"], batch["attention_mask"], batch["labels"]
                    )
                vloss += out.loss.item()

        vloss /= len(vloader)
        print("Val loss:", vloss)

        if vloss < best:
            best = vloss
            torch.save(model.state_dict(), out_path)
            print("Saved best:", out_path)


In [1]:
def evaluate_one(src, tgt, model_type, tokenizer, processor, max_test=1000):
    print(f"\n====== EVALUATING {model_type.upper()} {src}->{tgt} =======")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ds = EcommDataset("val", src, tgt, MAX_VAL_SAMPLES)

    # Limit test size
    if len(ds) > max_test:
        ds = [ds[i] for i in range(max_test)]
    print(f"Eval samples: {len(ds)}")

    # Load model
    if model_type == "mm":
        model = MultiModalModel().to(device)
        ckpt = CKPT_DIR / f"ecomm_{src}_{tgt}_mm.pt"
    else:
        model = TextOnlyModel().to(device)
        ckpt = CKPT_DIR / f"ecomm_{src}_{tgt}_txt.pt"

    print("Loading:", ckpt)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    predictions = []
    references = []

    for sample in tqdm(ds):
        src_txt = sample["src"]
        tgt_txt = sample["tgt"]
        references.append(tgt_txt)

        enc = tokenizer(
            src_txt, padding="max_length", truncation=True,
            max_length=MAX_LEN, return_tensors="pt"
        ).to(device)

        if model_type == "mm":
            img = safe_load_image(sample["img"])
            pixel = processor(images=[img], return_tensors="pt")["pixel_values"].to(device)
            gen = model.generate(enc["input_ids"], enc["attention_mask"], pixel, tokenizer)
        else:
            gen = model.generate(enc["input_ids"], enc["attention_mask"], tokenizer)

        pred = tokenizer.decode(gen[0], skip_special_tokens=True)
        predictions.append(pred)

    # üî• FIXED BLEU CALCULATION ‚Äî no nested list
    bleu = sacrebleu.compute(predictions=predictions, references=references)["score"]

    print("BLEU:", bleu)
    return bleu
# ================================================================
# RUN TRAINING + EVALUATION
# ================================================================

tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
processor = SiglipProcessor.from_pretrained(VISION_MODEL_NAME)

# --------------- TRAIN ALL PAIRS ---------------
# EN ‚Üî DE
#train_one("en", "de", "mm", tokenizer, processor)
#train_one("en", "de", "txt", tokenizer, processor)
#train_one("de", "en", "mm", tokenizer, processor)
#train_one("de", "en", "txt", tokenizer, processor)

# EN ‚Üî FR
train_one("en", "fr", "mm", tokenizer, processor)
train_one("en", "fr", "txt", tokenizer, processor)
train_one("fr", "en", "mm", tokenizer, processor)
train_one("fr", "en", "txt", tokenizer, processor)

# --------------- EVALUATE ALL PAIRS ---------------
#evaluate_one("en", "de", "mm", tokenizer, processor)
#evaluate_one("en", "de", "txt", tokenizer, processor)
#evaluate_one("de", "en", "mm", tokenizer, processor)
#evaluate_one("de", "en", "txt", tokenizer, processor)

#evaluate_one("en", "fr", "mm", tokenizer, processor)
#evaluate_one("en", "fr", "txt", tokenizer, processor)
#evaluate_one("fr", "en", "mm", tokenizer, processor)
#evaluate_one("fr", "en", "txt", tokenizer, processor)

print("\nüéâ ALL TRAINING + EVALUATION COMPLETED SUCCESSFULLY!")


NameError: name 'MBart50TokenizerFast' is not defined

In [None]:
# Training and testing

In [None]:
import os

folder = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
print("\n".join(os.listdir(folder)))


siglip_fusion_lora_en_de_mm_best.pt
mbart_lora_en_de_text_best.pt
siglip_fusion_lora_en_fr_mm_best.pt
mbart_lora_en_fr_text_best.pt
siglip_fusion_lora_de_en_mm_best.pt
mbart_lora_de_en_text_best.pt
siglip_fusion_lora_de_fr_mm_best.pt
mbart_lora_de_fr_text_best.pt
config_siglip_fusion_lora.json
siglip_fusion_lora_fr_en_mm_best.pt
mbart_lora_fr_en_text_best.pt
siglip_fusion_lora_fr_de_mm_best.pt
mbart_lora_fr_de_text_best.pt
test2016_predictions
test2017_predictions
config_siglip_fusion_lora_all6.json
ecomm_finetuned
siglip_fusion_lora_all6_mm_best.pt
mbart_lora_all6_text_best.pt
test2016_all6_eval
ecomm_eval_en_de


In [None]:
# ================================================================
# FINAL TRAINING + EVALUATION SCRIPT (SIGLIP FUSION + LORA)
# Supports: en‚Üíde, de‚Üíen, en‚Üífr, fr‚Üíen
# Dataset columns:
# ['project_name','set_name','image_id','image_file','source','target','french']
# ================================================================

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from pathlib import Path
from tqdm import tqdm
import evaluate

from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
from transformers import SiglipVisionModel, SiglipProcessor
from peft import LoraConfig, get_peft_model, TaskType

# ================================================================
# DEVICE
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

# ================================================================
# BLEU
# ================================================================
sacrebleu = evaluate.load("sacrebleu")

# ================================================================
# PATHS
# ================================================================
BASE = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
MODEL_DIR = Path(BASE)
OUT_DIR = MODEL_DIR / "ecomm_finetuned"
OUT_DIR.mkdir(exist_ok=True)

TSV_FILE = "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/listingtitle-image-mappings/listingtitles_with_matched_images.en-de_with_french.tsv"
IMG_DIR = Path("/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/images")

# Pretrained base checkpoints
PRETRAINED = {
    "en_de": MODEL_DIR / "siglip_fusion_lora_en_de_mm_best.pt",
    "de_en": MODEL_DIR / "siglip_fusion_lora_de_en_mm_best.pt",
    "en_fr": MODEL_DIR / "siglip_fusion_lora_en_fr_mm_best.pt",
    "fr_en": MODEL_DIR / "siglip_fusion_lora_fr_en_mm_best.pt"
}

LANG_CODES = {"en": "en_XX", "de": "de_DE", "fr": "fr_XX"}

# ================================================================
# HYPERPARAMETERS
# ================================================================
MAX_LEN = 64
BATCH = 2
LR = 2e-4
EPOCHS = 8     # adjust if needed

# ================================================================
# SAFE IMAGE LOAD
# ================================================================
def safe_load(img_name):
    if not isinstance(img_name, str) or img_name.strip() == "":
        return Image.new("RGB", (224,224), (128,128,128))

    img_name = img_name.strip()

    for split in ["train","val","test"]:
        fp = IMG_DIR / split / img_name
        try:
            if fp.exists():
                return Image.open(fp).convert("RGB")
        except:
            return Image.new("RGB", (224,224), (128,128,128))

    return Image.new("RGB", (224,224), (128,128,128))

# ================================================================
# DATASET
# ================================================================
class ECommerceDataset(Dataset):
    def __init__(self, df, src_lang, tgt_lang):
        self.df = df.reset_index(drop=True)
        self.src = src_lang
        self.tgt = tgt_lang

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        en = str(row["source"]).strip()
        de = str(row["target"]).strip()
        fr = str(row["french"]).strip()
        img = row["image_file"]

        # SOURCE selection
        if self.src == "en":   src_text = en
        elif self.src == "de": src_text = de
        else:                  src_text = fr

        # TARGET selection
        if self.tgt == "en":   tgt_text = en
        elif self.tgt == "de": tgt_text = de
        else:                  tgt_text = fr

        # Clean None / empty
        if src_text is None or src_text == "nan": src_text = ""
        if tgt_text is None or tgt_text == "nan": tgt_text = ""

        return {"src": src_text, "tgt": tgt_text, "img": img}

# ================================================================
# CREATE TRAIN/TEST SPLITS
# ================================================================
def load_split(src, tgt, total_samples):

    df = pd.read_csv(TSV_FILE, sep="\t")

    df = df[df["set_name"].str.lower().isin(["train","test"])]
    df = df[["source","target","french","image_file","set_name"]]

    # Remove bad text rows
    df = df.dropna(subset=["source","target","french"]).reset_index(drop=True)

    # EN‚ÜîDE = 15000 max ‚Üí your dataset = 7500 ‚Üí we use all
    train_df = df[df["set_name"]=="train"]
    test_df  = df[df["set_name"]=="test"]

    # Train split limit
    train_df = train_df.sample(min(len(train_df), int(total_samples*0.8)), random_state=42)
    test_df = test_df.sample(min(len(test_df), int(total_samples*0.2)), random_state=42)

    print(f"Train={len(train_df)} Test={len(test_df)} for {src}->{tgt}")

    return train_df, test_df

# ================================================================
# SIGLIP FUSION MODEL (LoRA)
# ================================================================
def apply_lora(m):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=["q_proj","v_proj"]
    )
    return get_peft_model(m, cfg)

class FusionBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8,
            dim_feedforward=2048, dropout=0.1,
            batch_first=True
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_emb, txt_emb):
        return self.enc(torch.cat([img_emb, txt_emb], dim=1))

class SiglipFusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        for p in self.vision.parameters(): p.requires_grad = False

        base = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
        self.mbart = apply_lora(base)
        self.txt_emb = self.mbart.get_input_embeddings()

        self.proj = nn.Linear(768, self.mbart.config.d_model)
        self.fusion = FusionBlock(self.mbart.config.d_model)

    def forward(self, ids, mask, pixel, labels):
        with torch.no_grad():
            v = self.vision(pixel).last_hidden_state[:,0,:]

        img_e = self.proj(v).unsqueeze(1)
        txt_e = self.txt_emb(ids)

        fused = self.fusion(img_e, txt_e)
        fused_mask = torch.cat([torch.ones((ids.size(0),1), device=device), mask], dim=1)

        return self.mbart(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            labels=labels
        )

# ================================================================
# TOKENIZER + PROCESSOR
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")

# ================================================================
# COLLATE
# ================================================================
def collate(batch):
    src = [b["src"] for b in batch]
    tgt = [b["tgt"] for b in batch]
    imgs = [safe_load(b["img"]) for b in batch]

    # Tokenize source
    enc_src = tokenizer(
        src, padding="max_length", truncation=True,
        max_length=MAX_LEN, return_tensors="pt"
    )

    # Tokenize target
    with tokenizer.as_target_tokenizer():
        enc_tgt = tokenizer(
            tgt, padding="max_length", truncation=True,
            max_length=MAX_LEN, return_tensors="pt"
        )

    labels = enc_tgt["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100

    pixel = processor(images=imgs, return_tensors="pt")["pixel_values"]

    return {
        "ids": enc_src["input_ids"].to(device),
        "mask": enc_src["attention_mask"].to(device),
        "labels": labels.to(device),
        "pixel": pixel.to(device)
    }

# ================================================================
# TRAINING LOOP
# ================================================================
def train_direction(src, tgt, key, total_samples):

    print("\n===================================================")
    print(f"üî• TRAINING {src} ‚Üí {tgt} on {total_samples} samples")
    print("===================================================")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    train_df, test_df = load_split(src, tgt, total_samples)

    train_loader = DataLoader(ECommerceDataset(train_df, src, tgt),
                              batch_size=BATCH, shuffle=True, collate_fn=collate)
    test_ds = ECommerceDataset(test_df, src, tgt)

    model = SiglipFusionModel().to(device)
    print("Loading pretrained checkpoint:", PRETRAINED[key])
    model.load_state_dict(torch.load(PRETRAINED[key], map_location=device), strict=False)

    # Freeze all except LoRA weights
    for n, p in model.named_parameters():
        p.requires_grad = ("lora" in n)

    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)

    best = 999999
    save_path = OUT_DIR / f"ecomm_{src}_{tgt}_mm.pt"

    for ep in range(1, EPOCHS+1):
        model.train()
        total = 0

        for batch in tqdm(train_loader):
            optim.zero_grad()
            out = model(batch["ids"], batch["mask"], batch["pixel"], batch["labels"])
            loss = out.loss
            loss.backward()
            optim.step()
            total += loss.item()

        ep_loss = total / len(train_loader)
        print(f"Epoch {ep} Loss = {ep_loss:.4f}")

        if ep_loss < best:
            best = ep_loss
            torch.save(model.state_dict(), save_path)
            print("Saved best model:", save_path)

    return test_ds

# ================================================================
# EVALUATION
# ================================================================
def evaluate_direction(src, tgt, test_ds):

    print("\n====== Evaluating", src, "‚Üí", tgt, "======")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ckpt = OUT_DIR / f"ecomm_{src}_{tgt}_mm.pt"

    model = SiglipFusionModel().to(device)
    print("Loading finetuned model:", ckpt)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    preds, refs = [], []

    for sample in tqdm(test_ds):
        src_text = sample["src"]
        tgt_text = sample["tgt"]
        refs.append(tgt_text)

        enc = tokenizer(
            src_text, truncation=True, padding="max_length",
            max_length=MAX_LEN, return_tensors="pt"
        ).to(device)

        img = safe_load(sample["img"])
        pixel = processor(images=[img], return_tensors="pt")["pixel_values"].to(device)

        with torch.no_grad():
            v = model.vision(pixel).last_hidden_state[:,0,:]
            img_e = model.proj(v).unsqueeze(1)
            txt_e = model.txt_emb(enc["input_ids"])
            fused = model.fusion(img_e, txt_e)
            fused_mask = torch.cat([torch.ones((1,1), device=device), enc["attention_mask"]], dim=1)

            gen = model.mbart.generate(
                inputs_embeds=fused,
                attention_mask=fused_mask,
                num_beams=5,
                max_length=MAX_LEN,
                forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
            )

        pred = tokenizer.decode(gen[0], skip_special_tokens=True)
        preds.append(pred)

    bleu = sacrebleu.compute(predictions=preds, references=[[r] for r in refs])["score"]
    print("BLEU:", bleu)
    return bleu

# ================================================================
# RUN TRAINING FOR ALL 4 DIRECTIONS
# ================================================================
# Dataset size = 7500 ‚Üí use full
SAMPLES = 7500

test_en_de = train_direction("en", "de", "en_de", SAMPLES)
#evaluate_direction("en", "de", test_en_de)

test_de_en = train_direction("de", "en", "de_en", SAMPLES)
#evaluate_direction("de", "en", test_de_en)

test_en_fr = train_direction("en", "fr", "en_fr", SAMPLES)
#evaluate_direction("en", "fr", test_en_fr)

test_fr_en = train_direction("fr", "en", "fr_en", SAMPLES)
#evaluate_direction("fr", "en", test_fr_en)


Using: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


preprocessor_config.json:   0%|          | 0.00/368 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/711 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/798k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]


üî• TRAINING en ‚Üí de on 7500 samples
Train=6000 Test=0 for en->de


config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/813M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

Loading pretrained checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [1:01:23<00:00,  1.23s/it]


Epoch 1 Loss = 1.5987
Saved best model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:44<00:00,  7.41it/s]


Epoch 2 Loss = 1.4186
Saved best model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:39<00:00,  7.52it/s]


Epoch 3 Loss = 1.3123
Saved best model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:15<00:00,  7.99it/s]


Epoch 4 Loss = 1.2201
Saved best model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:37<00:00,  7.55it/s]


Epoch 5 Loss = 1.1424
Saved best model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [07:38<00:00,  6.54it/s]


Epoch 6 Loss = 1.0699
Saved best model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:08<00:00,  8.13it/s]


Epoch 7 Loss = 1.0074
Saved best model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:14<00:00,  8.01it/s]


Epoch 8 Loss = 0.9501
Saved best model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt

Loading finetuned model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


0it [00:00, ?it/s]


IndexError: list index out of range

In [None]:
import pandas as pd
df = pd.read_csv(TSV_FILE, sep="\t")
print(df.columns.tolist())


['project_name', 'set_name', 'image_id', 'image_file', 'source', 'target', 'french']


In [None]:
# ================================================================
# FINAL TRAINING + EVALUATION SCRIPT (SIGLIP FUSION + LORA)
# Supports: en‚Üíde, de‚Üíen, en‚Üífr, fr‚Üíen
# Dataset columns:
# ['project_name','set_name','image_id','image_file','source','target','french']
# ================================================================

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from pathlib import Path
from tqdm import tqdm
import evaluate

from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
from transformers import SiglipVisionModel, SiglipProcessor
from peft import LoraConfig, get_peft_model, TaskType

# ================================================================
# DEVICE
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

# ================================================================
# BLEU
# ================================================================
sacrebleu = evaluate.load("sacrebleu")

# ================================================================
# PATHS
# ================================================================
BASE = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
MODEL_DIR = Path(BASE)
OUT_DIR = MODEL_DIR / "ecomm_finetuned"
OUT_DIR.mkdir(exist_ok=True)

TSV_FILE = "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/listingtitle-image-mappings/listingtitles_with_matched_images.en-de_with_french.tsv"
IMG_DIR = Path("/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/images")

# Pretrained base checkpoints
PRETRAINED = {
    "en_de": MODEL_DIR / "siglip_fusion_lora_en_de_mm_best.pt",
    "de_en": MODEL_DIR / "siglip_fusion_lora_de_en_mm_best.pt",
    "en_fr": MODEL_DIR / "siglip_fusion_lora_en_fr_mm_best.pt",
    "fr_en": MODEL_DIR / "siglip_fusion_lora_fr_en_mm_best.pt"
}

LANG_CODES = {"en": "en_XX", "de": "de_DE", "fr": "fr_XX"}

# ================================================================
# HYPERPARAMETERS
# ================================================================
MAX_LEN = 64
BATCH = 2
LR = 2e-4
EPOCHS = 8     # adjust if needed

# ================================================================
# SAFE IMAGE LOAD
# ================================================================
def safe_load(img_name):
    if not isinstance(img_name, str) or img_name.strip() == "":
        return Image.new("RGB", (224,224), (128,128,128))

    img_name = img_name.strip()

    for split in ["train","val","test"]:
        fp = IMG_DIR / split / img_name
        try:
            if fp.exists():
                return Image.open(fp).convert("RGB")
        except:
            return Image.new("RGB", (224,224), (128,128,128))

    return Image.new("RGB", (224,224), (128,128,128))

# ================================================================
# DATASET
# ================================================================
class ECommerceDataset(Dataset):
    def __init__(self, df, src_lang, tgt_lang):
        self.df = df.reset_index(drop=True)
        self.src = src_lang
        self.tgt = tgt_lang

    def __len__(self): return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        en = str(row["source"]).strip()
        de = str(row["target"]).strip()
        fr = str(row["french"]).strip()
        img = row["image_file"]

        # SOURCE selection
        if self.src == "en":   src_text = en
        elif self.src == "de": src_text = de
        else:                  src_text = fr

        # TARGET selection
        if self.tgt == "en":   tgt_text = en
        elif self.tgt == "de": tgt_text = de
        else:                  tgt_text = fr

        # Clean None / empty
        if src_text is None or src_text == "nan": src_text = ""
        if tgt_text is None or tgt_text == "nan": tgt_text = ""

        return {"src": src_text, "tgt": tgt_text, "img": img}

# ================================================================
# CREATE TRAIN/TEST SPLITS
# ================================================================
def load_split(src, tgt, total_samples):

    df = pd.read_csv(TSV_FILE, sep="\t")

    df = df[df["set_name"].str.lower().isin(["train","test"])]
    df = df[["source","target","french","image_file","set_name"]]

    # Remove bad text rows
    df = df.dropna(subset=["source","target","french"]).reset_index(drop=True)

    # EN‚ÜîDE = 15000 max ‚Üí your dataset = 7500 ‚Üí we use all
    train_df = df[df["set_name"]=="train"]
    test_df  = df[df["set_name"]=="test"]

    # Train split limit
    train_df = train_df.sample(min(len(train_df), int(total_samples*0.8)), random_state=42)
    test_df = test_df.sample(min(len(test_df), int(total_samples*0.2)), random_state=42)

    print(f"Train={len(train_df)} Test={len(test_df)} for {src}->{tgt}")

    return train_df, test_df

# ================================================================
# SIGLIP FUSION MODEL (LoRA)
# ================================================================
def apply_lora(m):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=["q_proj","v_proj"]
    )
    return get_peft_model(m, cfg)

class FusionBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8,
            dim_feedforward=2048, dropout=0.1,
            batch_first=True
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_emb, txt_emb):
        return self.enc(torch.cat([img_emb, txt_emb], dim=1))

class SiglipFusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        for p in self.vision.parameters(): p.requires_grad = False

        base = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
        self.mbart = apply_lora(base)
        self.txt_emb = self.mbart.get_input_embeddings()

        self.proj = nn.Linear(768, self.mbart.config.d_model)
        self.fusion = FusionBlock(self.mbart.config.d_model)

    def forward(self, ids, mask, pixel, labels):
        with torch.no_grad():
            v = self.vision(pixel).last_hidden_state[:,0,:]

        img_e = self.proj(v).unsqueeze(1)
        txt_e = self.txt_emb(ids)

        fused = self.fusion(img_e, txt_e)
        fused_mask = torch.cat([torch.ones((ids.size(0),1), device=device), mask], dim=1)

        return self.mbart(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            labels=labels
        )

# ================================================================
# TOKENIZER + PROCESSOR
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")
processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")

# ================================================================
# COLLATE
# ================================================================
def collate(batch):
    src = [b["src"] for b in batch]
    tgt = [b["tgt"] for b in batch]
    imgs = [safe_load(b["img"]) for b in batch]

    # Tokenize source
    enc_src = tokenizer(
        src, padding="max_length", truncation=True,
        max_length=MAX_LEN, return_tensors="pt"
    )

    # Tokenize target
    with tokenizer.as_target_tokenizer():
        enc_tgt = tokenizer(
            tgt, padding="max_length", truncation=True,
            max_length=MAX_LEN, return_tensors="pt"
        )

    labels = enc_tgt["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100

    pixel = processor(images=imgs, return_tensors="pt")["pixel_values"]

    return {
        "ids": enc_src["input_ids"].to(device),
        "mask": enc_src["attention_mask"].to(device),
        "labels": labels.to(device),
        "pixel": pixel.to(device)
    }

# ================================================================
# TRAINING LOOP
# ================================================================
def train_direction(src, tgt, key, total_samples):

    print("\n===================================================")
    print(f"üî• TRAINING {src} ‚Üí {tgt} on {total_samples} samples")
    print("===================================================")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    train_df, test_df = load_split(src, tgt, total_samples)

    train_loader = DataLoader(ECommerceDataset(train_df, src, tgt),
                              batch_size=BATCH, shuffle=True, collate_fn=collate)
    test_ds = ECommerceDataset(test_df, src, tgt)

    model = SiglipFusionModel().to(device)
    print("Loading pretrained checkpoint:", PRETRAINED[key])
    model.load_state_dict(torch.load(PRETRAINED[key], map_location=device), strict=False)

    # Freeze all except LoRA weights
    for n, p in model.named_parameters():
        p.requires_grad = ("lora" in n)

    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)

    best = 999999
    save_path = OUT_DIR / f"ecomm_{src}_{tgt}_mm.pt"

    for ep in range(1, EPOCHS+1):
        model.train()
        total = 0

        for batch in tqdm(train_loader):
            optim.zero_grad()
            out = model(batch["ids"], batch["mask"], batch["pixel"], batch["labels"])
            loss = out.loss
            loss.backward()
            optim.step()
            total += loss.item()

        ep_loss = total / len(train_loader)
        print(f"Epoch {ep} Loss = {ep_loss:.4f}")

        if ep_loss < best:
            best = ep_loss
            torch.save(model.state_dict(), save_path)
            print("Saved best model:", save_path)

    return test_ds

# ================================================================
# EVALUATION
# ================================================================
def evaluate_direction(src, tgt, test_ds):

    print("\n====== Evaluating", src, "‚Üí", tgt, "======")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ckpt = OUT_DIR / f"ecomm_{src}_{tgt}_mm.pt"

    model = SiglipFusionModel().to(device)
    print("Loading finetuned model:", ckpt)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    preds, refs = [], []

    for sample in tqdm(test_ds):
        src_text = sample["src"]
        tgt_text = sample["tgt"]
        refs.append(tgt_text)

        enc = tokenizer(
            src_text, truncation=True, padding="max_length",
            max_length=MAX_LEN, return_tensors="pt"
        ).to(device)

        img = safe_load(sample["img"])
        pixel = processor(images=[img], return_tensors="pt")["pixel_values"].to(device)

        with torch.no_grad():
            v = model.vision(pixel).last_hidden_state[:,0,:]
            img_e = model.proj(v).unsqueeze(1)
            txt_e = model.txt_emb(enc["input_ids"])
            fused = model.fusion(img_e, txt_e)
            fused_mask = torch.cat([torch.ones((1,1), device=device), enc["attention_mask"]], dim=1)

            gen = model.mbart.generate(
                inputs_embeds=fused,
                attention_mask=fused_mask,
                num_beams=5,
                max_length=MAX_LEN,
                forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
            )

        pred = tokenizer.decode(gen[0], skip_special_tokens=True)
        preds.append(pred)

    bleu = sacrebleu.compute(predictions=preds, references=[[r] for r in refs])["score"]
    print("BLEU:", bleu)
    return bleu

# ================================================================
# RUN TRAINING FOR ALL 4 DIRECTIONS
# ================================================================
# Dataset size = 7500 ‚Üí use full
SAMPLES = 7500

test_en_de = train_direction("en", "de", "en_de", SAMPLES)
#evaluate_direction("en", "de", test_en_de)

test_de_en = train_direction("de", "en", "de_en", SAMPLES)
#evaluate_direction("de", "en", test_de_en)

test_en_fr = train_direction("en", "fr", "en_fr", SAMPLES)
#evaluate_direction("en", "fr", test_en_fr)

test_fr_en = train_direction("fr", "en", "fr_en", SAMPLES)
#evaluate_direction("fr", "en", test_fr_en)


Using: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


preprocessor_config.json:   0%|          | 0.00/368 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/711 [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/798k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]


üî• TRAINING en ‚Üí de on 7500 samples
Train=6000 Test=0 for en->de


config.json:   0%|          | 0.00/432 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/813M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

Loading pretrained checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_de_mm_best.pt


 11%|‚ñà‚ñè        | 344/3000 [13:28<1:13:35,  1.66s/it]

In [None]:
# ================================================================
# üîÅ FINAL FULL TRAINING + EVALUATION SCRIPT FOR EN‚ÜîFR
# ================================================================

import os
import json
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from tqdm import tqdm
import evaluate

from transformers import (
    MBart50TokenizerFast,
    MBartForConditionalGeneration,
    SiglipVisionModel,
    SiglipProcessor,
)
from peft import LoraConfig, get_peft_model, TaskType

# ================================================================
# CONFIG
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

sacrebleu = evaluate.load("sacrebleu")

MAX_LEN = 64
BATCH = 2
LR = 2e-4
EPOCHS = 10

LANG_CODES = {"en": "en_XX", "fr": "fr_XX"}

# ================================================================
# PATHS
# ================================================================
BASE = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
MODEL_DIR = Path(BASE)

# Output folder (special folder for French finetuning)
OUT_DIR = MODEL_DIR / "ecomm_french_finetuned"
OUT_DIR.mkdir(exist_ok=True)

# Dataset files
TSV_FILE = (
    "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/"
    "listingtitle-image-mappings/listingtitles_with_matched_images.en-de_with_french.tsv"
)

IMG_DIR = Path(
    "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/images"
)

# Pretrained models provided by you
PRETRAINED = {
    "en_fr": MODEL_DIR / "siglip_fusion_lora_en_fr_mm_best.pt",
    "fr_en": MODEL_DIR / "siglip_fusion_lora_fr_en_mm_best.pt",
}

# ================================================================
# SAFE IMAGE LOADER
# ================================================================
def safe_load(img_name):
    if not isinstance(img_name, str):
        return Image.new("RGB", (224,224), (128,128,128))

    img_name = img_name.strip()

    for split in ["train", "val", "test"]:
        fp = IMG_DIR / split / img_name
        try:
            if fp.exists():
                return Image.open(fp).convert("RGB")
        except:
            return Image.new("RGB", (224,224), (128,128,128))

    return Image.new("RGB", (224,224), (128,128,128))

# ================================================================
# DATASET
# ================================================================
class FrenchDataset(Dataset):
    def __init__(self, df, src, tgt):
        self.df = df.reset_index(drop=True)
        self.src = src
        self.tgt = tgt

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        en = str(row["source"])
        de = str(row["target"])
        fr = str(row["french"])
        img = row["image_file"]

        if self.src == "en":
            src_text = en
        else:
            src_text = fr

        if self.tgt == "fr":
            tgt_text = fr
        else:
            tgt_text = en

        return {"src": src_text, "tgt": tgt_text, "img": img}

# ================================================================
# CREATE TRAIN/TEST SPLITS (AUTO-SPLIT 6000 / 1500)
# ================================================================
def load_french_splits(src, tgt):
    df = pd.read_csv(TSV_FILE, sep="\t")

    df = df[df["set_name"].str.lower().isin(["train", "test"])]
    df = df.dropna(subset=["source", "french"]).reset_index(drop=True)

    train_df = df[df["set_name"] == "train"]
    test_df = df[df["set_name"] == "test"]

    # enforce 6000/1500 rule
    train_df = train_df.sample(min(len(train_df), 6000), random_state=42)
    test_df  = test_df.sample(min(len(test_df), 1500), random_state=42)

    print(f"{src} ‚Üí {tgt}: Train={len(train_df)}, Test={len(test_df)}")
    return train_df, test_df

# ================================================================
# MODEL DEFINITIONS
# ================================================================
def apply_lora(model):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )
    return get_peft_model(model, cfg)

class FusionBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=8,
            dim_feedforward=2048, dropout=0.1,
            batch_first=True
        )
        self.enc = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_emb, txt_emb):
        return self.enc(torch.cat([img_emb, txt_emb], dim=1))

class SiglipFusionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = SiglipVisionModel.from_pretrained(
            "google/siglip-base-patch16-224"
        )
        for p in self.vision.parameters():
            p.requires_grad = False

        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora(base)
        self.txt_emb = self.mbart.get_input_embeddings()

        self.proj = nn.Linear(768, self.mbart.config.d_model)
        self.fusion = FusionBlock(self.mbart.config.d_model)

    def forward(self, ids, mask, pixel, labels):
        with torch.no_grad():
            v = self.vision(pixel_values=pixel).last_hidden_state[:,0,:]

        img_e = self.proj(v).unsqueeze(1)
        txt_e = self.txt_emb(ids)

        fused = self.fusion(img_e, txt_e)
        fused_mask = torch.cat(
            [torch.ones((ids.size(0),1), device=device), mask],
            dim=1
        )

        return self.mbart(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            labels=labels
        )

# ================================================================
# TOKENIZER + COLLATE
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")

def collate(batch):
    src = [b["src"] for b in batch]
    tgt = [b["tgt"] for b in batch]
    imgs = [safe_load(b["img"]) for b in batch]

    enc_s = tokenizer(
        src,
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
        return_tensors="pt"
    )

    with tokenizer.as_target_tokenizer():
        enc_t = tokenizer(
            tgt,
            truncation=True,
            padding="max_length",
            max_length=MAX_LEN,
            return_tensors="pt"
        )

    labels = enc_t["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100

    pixel = processor(images=imgs, return_tensors="pt")["pixel_values"]

    return {
        "ids": enc_s["input_ids"].to(device),
        "mask": enc_s["attention_mask"].to(device),
        "labels": labels.to(device),
        "pixel": pixel.to(device)
    }

# ================================================================
# TRAINING LOOP
# ================================================================
def train_french(src, tgt, key):
    print(f"\nüî• Training {src} ‚Üí {tgt}")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    train_df, test_df = load_french_splits(src, tgt)

    train_loader = DataLoader(
        FrenchDataset(train_df, src, tgt),
        batch_size=BATCH,
        shuffle=True,
        collate_fn=collate
    )
    test_ds = FrenchDataset(test_df, src, tgt)

    model = SiglipFusionModel().to(device)

    print("Loading pretrained:", PRETRAINED[key])
    model.load_state_dict(torch.load(PRETRAINED[key], map_location=device), strict=False)

    # Freeze everything except LoRA
    for name, p in model.named_parameters():
        p.requires_grad = ("lora" in name)

    optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)

    best = 9999
    save_path = OUT_DIR / f"ecomm_{src}_{tgt}.pt"

    for ep in range(1, EPOCHS+1):
        model.train()
        total = 0

        for batch in tqdm(train_loader):
            optim.zero_grad()
            out = model(batch["ids"], batch["mask"], batch["pixel"], batch["labels"])
            loss = out.loss
            loss.backward()
            optim.step()
            total += loss.item()

        print(f"Epoch {ep} Loss = {total/len(train_loader):.4f}")

        if total < best:
            best = total
            torch.save(model.state_dict(), save_path)
            print("Saved best:", save_path)

    return test_ds

# ================================================================
# EVALUATION (SAVES BLEU + PREDICTIONS)
# ================================================================
def evaluate_french(src, tgt, test_ds):
    print(f"\nüîç Evaluating {src} ‚Üí {tgt}")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    model = SiglipFusionModel().to(device)
    ckpt = OUT_DIR / f"ecomm_{src}_{tgt}.pt"
    print("Loading:", ckpt)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    preds, refs, sources, images = [], [], [], []

    for sample in tqdm(test_ds):
        src_text = sample["src"]
        tgt_text = sample["tgt"]
        img_name = sample["img"]

        sources.append(src_text)
        refs.append(tgt_text)
        images.append(img_name)

        enc = tokenizer(
            src_text,
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt"
        ).to(device)

        img = safe_load(img_name)
        pixel = processor(images=[img], return_tensors="pt")["pixel_values"].to(device)

        with torch.no_grad():
            v = model.vision(pixel).last_hidden_state[:,0,:]
            img_e = model.proj(v).unsqueeze(1)
            txt_e = model.txt_emb(enc["input_ids"])

            fused = model.fusion(img_e, txt_e)
            fused_mask = torch.cat(
                [torch.ones((1,1),device=device), enc["attention_mask"]],
                dim=1
            )

            gen = model.mbart.generate(
                inputs_embeds=fused,
                attention_mask=fused_mask,
                max_length=MAX_LEN,
                num_beams=5,
                forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
            )

        preds.append(tokenizer.decode(gen[0], skip_special_tokens=True))

    # BLEU score
    bleu = sacrebleu.compute(predictions=preds, references=[refs])["score"]
    print(f"‚≠ê BLEU ({src} ‚Üí {tgt}) = {bleu:.4f}")

    # Save BLEU
    bleu_file = OUT_DIR / f"bleu_{src}_{tgt}.txt"
    with open(bleu_file, "w") as f:
        f.write(f"BLEU = {bleu:.4f}\n")

    # Save predictions
    pred_file = OUT_DIR / f"preds_{src}_{tgt}.tsv"
    pd.DataFrame({
        "source": sources,
        "reference": refs,
        "prediction": preds,
        "image": images
    }).to_csv(pred_file, sep="\t", index=False)

    print("Saved:", bleu_file)
    print("Saved:", pred_file)

    return bleu

# ================================================================
# RUN TRAINING
# ================================================================
test_en_fr = train_french("en", "fr", "en_fr")
test_fr_en = train_french("fr", "en", "fr_en")

# ================================================================
# RUN EVALUATION
# ================================================================
#evaluate_french("en", "fr", test_en_fr)
#evaluate_french("fr", "en", test_fr_en)


Using device: cuda

üî• Training en ‚Üí fr
en ‚Üí fr: Train=6000, Test=0
Loading pretrained: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_en_fr_mm_best.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:05<00:00,  8.22it/s]


Epoch 1 Loss = 1.8187
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:26<00:00,  7.76it/s]


Epoch 2 Loss = 1.5648
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [07:07<00:00,  7.02it/s]


Epoch 3 Loss = 1.4165
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:27<00:00,  7.74it/s]


Epoch 4 Loss = 1.2953
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:49<00:00,  7.33it/s]


Epoch 5 Loss = 1.2225
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:46<00:00,  7.38it/s]


Epoch 6 Loss = 1.1116
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:28<00:00,  7.73it/s]


Epoch 7 Loss = 1.0424
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:27<00:00,  7.75it/s]


Epoch 8 Loss = 0.9805
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:29<00:00,  7.71it/s]


Epoch 9 Loss = 0.9218
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:27<00:00,  7.74it/s]


Epoch 10 Loss = 0.8727
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt

üî• Training fr ‚Üí en
fr ‚Üí en: Train=6000, Test=0
Loading pretrained: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_fr_en_mm_best.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:06<00:00,  8.19it/s]


Epoch 1 Loss = 2.0215
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:51<00:00,  7.29it/s]


Epoch 2 Loss = 1.8208
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:27<00:00,  7.74it/s]


Epoch 3 Loss = 1.6903
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:26<00:00,  7.75it/s]


Epoch 4 Loss = 1.5824
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:49<00:00,  7.33it/s]


Epoch 5 Loss = 1.4867
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:26<00:00,  7.75it/s]


Epoch 6 Loss = 1.4055
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:50<00:00,  7.30it/s]


Epoch 7 Loss = 1.3354
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:28<00:00,  7.72it/s]


Epoch 8 Loss = 1.2661
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:35<00:00,  7.59it/s]


Epoch 9 Loss = 1.2136
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3000/3000 [06:31<00:00,  7.67it/s]


Epoch 10 Loss = 1.1580
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt

üîç Evaluating en ‚Üí fr
Loading: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


0it [00:00, ?it/s]


IndexError: list index out of range

In [None]:
# ================================================================
# üîÅ CONTINUED TRAINING + EVALUATION ON E-COMMERCE DATA (IMAGE + TEXT)
#    - Directions: en‚Üíde, de‚Üíen
#    - Models:
#        * Multimodal (SigLIP + mBART + LoRA + Fusion)
#        * Text-only (mBART + LoRA)
#    - Uses existing SigLIP + MBART + LoRA checkpoints from Drive
#    - Uses at most 15,000 train samples per direction (per split)
#    - Computes BLEU on validation split and saves results to Drive
# ================================================================

import os
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import pandas as pd
from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

from tqdm import tqdm

# ================================================================
# Install / import HF + PEFT + evaluate
# ================================================================
try:
    from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
    from transformers import SiglipVisionModel, SiglipProcessor
    from peft import LoraConfig, get_peft_model, TaskType
    import evaluate
except Exception:
    !pip install -q transformers peft accelerate sentencepiece evaluate
    from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
    from transformers import SiglipVisionModel, SiglipProcessor
    from peft import LoraConfig, get_peft_model, TaskType
    import evaluate

# BLEU metric
sacrebleu = evaluate.load("sacrebleu")

# ================================================================
# Mount Drive
# ================================================================
from google.colab import drive
drive.mount("/content/drive", force_remount=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

if device.type == "cuda":
    torch.backends.cudnn.benchmark = True

# ================================================================
# PATHS ‚Äì ADAPTED TO YOUR DRIVE STRUCTURE
# ================================================================
# 1) Model folder ‚Äì REAL path behind shared Drive shortcut
MODEL_DIR = Path(
    "/content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/"
    "multimodal_translation_models_siglip_lora_fusion"
)

print("üìå MODEL_DIR:", MODEL_DIR)
print("üìÇ Contents of MODEL_DIR:")
!ls -lh "/content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion"

# 2) Dataset TSV + images (E-commerce dataset)
ECOMM_TSV = (
    "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/"
    "listingtitle-image-mappings/listingtitles_with_matched_images.en-de.tsv"
)

ECOMM_IMG_DIR = (
    "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/images"
)  # contains train/val/test subfolders

# Where to save new finetuned models
OUT_MODEL_DIR = MODEL_DIR / "ecomm_finetuned"
OUT_MODEL_DIR.mkdir(parents=True, exist_ok=True)
print("üìå OUT_MODEL_DIR:", OUT_MODEL_DIR)

# Where to save evaluation outputs
EVAL_DIR = OUT_MODEL_DIR / "evals"
EVAL_DIR.mkdir(parents=True, exist_ok=True)
print("üìå EVAL_DIR:", EVAL_DIR)

# ================================================================
# LOAD CONFIG FROM PREVIOUS TRAINING (for LoRA + hyperparams)
# ================================================================
cfg_path = MODEL_DIR / "config_siglip_fusion_lora.json"
print("Loading config from:", cfg_path)

with open(cfg_path, "r") as f:
    config_dict = json.load(f)

import types
config = types.SimpleNamespace(**config_dict)

# Fallbacks if any field missing
BATCH_SIZE = getattr(config, "batch_size", 24)
MAX_LEN = getattr(config, "max_length", 64)
LR = getattr(config, "lr", 2e-4)
EPOCHS = getattr(config, "num_epochs", 3)
vision_model_name = getattr(config, "vision_model_name", "google/siglip-base-patch16-224")

# Limit samples (per split)
MAX_TRAIN_SAMPLES = 15000
MAX_VAL_SAMPLES   = 2000   # you can increase if you want

LANG_CODES = {"en": "en_XX", "de": "de_DE"}

print("‚úîÔ∏è Training hyperparams:")
print("   batch_size =", BATCH_SIZE)
print("   max_len    =", MAX_LEN)
print("   lr         =", LR)
print("   epochs     =", EPOCHS)
print("   vision     =", vision_model_name)
print("   train_max  =", MAX_TRAIN_SAMPLES)
print("   val_max    =", MAX_VAL_SAMPLES)

# ================================================================
# SAFE IMAGE LOADER ‚Äì searches train/val/test
# ================================================================
def safe_load_image(filename: Any) -> Image.Image:
    if filename is None or str(filename).lower() == "nan":
        return Image.new("RGB", (224, 224), (128, 128, 128))

    filename = str(filename).strip()

    # try each split folder
    for split in ["train", "val", "test"]:
        fp = Path(ECOMM_IMG_DIR) / split / filename
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except Exception:
                pass

    # last-resort fallback
    return Image.new("RGB", (224, 224), (128, 128, 128))

# ================================================================
# LoRA helpers
# ================================================================
def apply_lora_to_mbart(mbart: MBartForConditionalGeneration):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout,
        target_modules=config.lora_targets,
    )
    return get_peft_model(mbart, cfg)

def freeze_all_except_lora(mbart: MBartForConditionalGeneration):
    for name, p in mbart.named_parameters():
        if "lora" not in name:
            p.requires_grad = False

# ================================================================
# Models
# ================================================================
class FusionBlock(nn.Module):
    def __init__(self, d_model: int):
        super().__init__()
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_embed, text_embed):
        # img_embed: [B, 1, d_model]
        # text_embed: [B, T, d_model]
        x = torch.cat([img_embed, text_embed], dim=1)
        return self.encoder(x)

class MultiModalModel(nn.Module):
    """
    SigLIP vision encoder + MBART with LoRA + fusion.
    """

    def __init__(self):
        super().__init__()
        print("Loading SigLIP:", vision_model_name)
        self.vision = SiglipVisionModel.from_pretrained(vision_model_name)
        for p in self.vision.parameters():
            p.requires_grad = False  # freeze vision

        print("Loading MBART base...")
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base)
        self.text_emb = self.mbart.get_input_embeddings()

        vision_dim = self.vision.config.hidden_size
        self.proj = nn.Linear(vision_dim, self.mbart.config.d_model)
        self.fusion = FusionBlock(self.mbart.config.d_model)

    def forward(self, input_ids, attention_mask, pixel_values, labels):
        # 1) Vision forward (frozen)
        with torch.no_grad():
            vis = self.vision(pixel_values=pixel_values).last_hidden_state[:, 0, :]

        img_embed = self.proj(vis).unsqueeze(1)   # [B, 1, d_model]
        txt_embed = self.text_emb(input_ids)      # [B, T, d_model]

        fused = self.fusion(img_embed, txt_embed) # [B, 1+T, d_model]
        fused_mask = torch.cat(
            [torch.ones((input_ids.size(0), 1), device=device), attention_mask],
            dim=1,
        )

        outputs = self.mbart(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            labels=labels,
        )
        return outputs

    def generate(self, input_ids, attention_mask, pixel_values, tokenizer, max_length=None, num_beams=5):
        if max_length is None:
            max_length = MAX_LEN

        with torch.no_grad():
            vis = self.vision(pixel_values=pixel_values).last_hidden_state[:, 0, :]

        img_embed = self.proj(vis).unsqueeze(1)
        txt_embed = self.text_emb(input_ids)

        fused = self.fusion(img_embed, txt_embed)
        fused_mask = torch.cat(
            [torch.ones((input_ids.size(0), 1), device=device), attention_mask],
            dim=1,
        )

        gen_ids = self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            num_beams=num_beams,
            max_length=max_length,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen_ids

class TextOnlyModel(nn.Module):
    """
    MBART with LoRA (no images).
    """

    def __init__(self):
        super().__init__()
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora_to_mbart(base)

    def forward(self, input_ids, attention_mask, labels):
        return self.mbart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

    def generate(self, input_ids, attention_mask, tokenizer, max_length=None, num_beams=5):
        if max_length is None:
            max_length = MAX_LEN
        gen_ids = self.mbart.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_beams=num_beams,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )
        return gen_ids

# ================================================================
# Dataset (TSV with set_name, source, target, image_file)
# ================================================================
class EcommDataset(Dataset):
    def __init__(self, csv_path: str, split: str, src_lang: str, tgt_lang: str,
                 max_samples: int | None = None):
        assert split in ["train", "val", "test"]

        if csv_path.endswith(".tsv") or csv_path.endswith(".txt"):
            df = pd.read_csv(csv_path, sep="\t")
        else:
            df = pd.read_csv(csv_path)

        df["set_name"] = df["set_name"].str.lower()

        if split == "val":
            mask = df["set_name"].isin(["val", "valid", "validation"])
        else:
            mask = df["set_name"] == split

        df = df[mask].reset_index(drop=True)

        if max_samples is not None and len(df) > max_samples:
            df = df.sample(n=max_samples, random_state=42).reset_index(drop=True)

        self.df = df
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        print(f"[{src_lang}->{tgt_lang}] {split} rows:", len(self.df))

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        row = self.df.iloc[idx]

        en = str(row["source"])
        de = str(row["target"])
        img_file = row["image_file"]

        if self.src_lang == "en" and self.tgt_lang == "de":
            src_text = en
            tgt_text = de
        elif self.src_lang == "de" and self.tgt_lang == "en":
            src_text = de
            tgt_text = en
        else:
            # just in case; but we only use en<->de here
            src_text = en
            tgt_text = de

        return {
            "src": src_text,
            "tgt": tgt_text,
            "img": img_file,
        }

# ================================================================
# Collate
# ================================================================
def make_collate_fn(tokenizer, image_processor, max_length: int):
    def collate_fn(batch):
        src_texts = [b["src"] for b in batch]
        tgt_texts = [b["tgt"] for b in batch]
        img_files = [b["img"] for b in batch]

        # tokenize source
        enc_src = tokenizer(
            src_texts,
            padding="max_length",
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

        # tokenize target
        with tokenizer.as_target_tokenizer():
            enc_tgt = tokenizer(
                tgt_texts,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            )

        labels = enc_tgt["input_ids"]
        labels[labels == tokenizer.pad_token_id] = -100

        # images
        imgs = [safe_load_image(fn) for fn in img_files]
        pixel_values = image_processor(images=imgs, return_tensors="pt")["pixel_values"]

        batch_out = {
            "input_ids": enc_src["input_ids"].to(device),
            "attention_mask": enc_src["attention_mask"].to(device),
            "labels": labels.to(device),
            "pixel_values": pixel_values.to(device),
        }
        return batch_out

    return collate_fn

# ================================================================
# TRAINING LOOP
# ================================================================
def train_model(src_lang: str, tgt_lang: str, model_type: str,
                tokenizer, processor,
                train_limit=15000, val_limit=2000):
    """
    model_type: "mm" (multimodal) or "txt" (text-only)
    """

    print("\n" + "=" * 70)
    print(f"üöÄ TRAINING {model_type.upper()} MODEL FOR {src_lang} ‚Üí {tgt_lang}")
    print("=" * 70)

    tokenizer.src_lang = LANG_CODES[src_lang]
    tokenizer.tgt_lang = LANG_CODES[tgt_lang]

    # ================================================================
    # LOAD DATASET ‚Äî FIXED PARAMETER NAME (max_samples instead of limit)
    # ================================================================
    train_ds = EcommDataset(ECOMM_TSV, "train", src_lang, tgt_lang, max_samples=train_limit)
    val_ds   = EcommDataset(ECOMM_TSV, "val",   src_lang, tgt_lang, max_samples=val_limit)


    collate_fn = make_collate_fn(tokenizer, image_processor, MAX_LEN)

    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        collate_fn=collate_fn,
        num_workers=0,
        pin_memory=False,
    )

    # model + base checkpoint (pre-ecomm)
    if model_type == "mm":
        model = MultiModalModel().to(device)
        base_ckpt = MODEL_DIR / f"siglip_fusion_lora_{src_lang}_{tgt_lang}_mm_best.pt"
    else:
        model = TextOnlyModel().to(device)
        base_ckpt = MODEL_DIR / f"mbart_lora_{src_lang}_{tgt_lang}_text_best.pt"

    print("üì• Loading base checkpoint:", base_ckpt)
    model.load_state_dict(torch.load(base_ckpt, map_location=device))

    # freeze all but LoRA
    freeze_all_except_lora(model.mbart)

    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=LR,
        weight_decay=0.01,
    )

    best_val = float("inf")
    out_ckpt = OUT_MODEL_DIR / f"ecomm_{src_lang}_{tgt_lang}_{model_type}.pt"

    for epoch in range(1, EPOCHS + 1):
        print(f"\nEpoch {epoch}/{EPOCHS}")

        # ---- train ----
        model.train()
        running = 0.0
        n_steps = 0

        for batch in tqdm(train_loader, desc="Train"):
            optimizer.zero_grad()

            if model_type == "mm":
                out = model(
                    batch["input_ids"],
                    batch["attention_mask"],
                    batch["pixel_values"],
                    batch["labels"],
                )
            else:
                out = model(
                    batch["input_ids"],
                    batch["attention_mask"],
                    batch["labels"],
                )

            loss = out.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            running += loss.item()
            n_steps += 1

        train_loss = running / max(1, n_steps)

        # ---- val ----
        model.eval()
        val_running = 0.0
        val_steps = 0

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Val"):
                if model_type == "mm":
                    out = model(
                        batch["input_ids"],
                        batch["attention_mask"],
                        batch["pixel_values"],
                        batch["labels"],
                    )
                else:
                    out = model(
                        batch["input_ids"],
                        batch["attention_mask"],
                        batch["labels"],
                    )

                val_running += out.loss.item()
                val_steps += 1

        val_loss = val_running / max(1, val_steps)
        print(f"Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f}")

        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), out_ckpt)
            print(f"  ‚úÖ New best; saved to {out_ckpt}")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    print(f"‚úÖ Done {model_type.upper()} {src_lang}‚Üí{tgt_lang}; best val={best_val:.4f}")
    return val_ds  # return val dataset for evaluation later if needed

# ================================================================
# EVALUATION (BLEU ON VAL)
# ================================================================
def evaluate_model(src_lang: str, tgt_lang: str, model_type: str,
                   tokenizer, image_processor):
    """
    Evaluate given ecomm model on VAL split using sacreBLEU.
    """
    assert model_type in ["mm", "txt"]

    print("\n" + "=" * 60)
    print(f"üîç Evaluating {model_type.upper()} model for {src_lang} ‚Üí {tgt_lang}")
    print("=" * 60)

    tokenizer.src_lang = LANG_CODES[src_lang]
    tokenizer.tgt_lang = LANG_CODES[tgt_lang]

    # load val data (we eval on same MAX_VAL_SAMPLES subset)
    val_ds = EcommDataset(ECOMM_TSV, "val", src_lang, tgt_lang, max_samples=MAX_VAL_SAMPLES)

    # load model from ecomm_finetuned
    if model_type == "mm":
        model = MultiModalModel().to(device)
        ckpt = OUT_MODEL_DIR / f"ecomm_{src_lang}_{tgt_lang}_mm.pt"
    else:
        model = TextOnlyModel().to(device)
        ckpt = OUT_MODEL_DIR / f"ecomm_{src_lang}_{tgt_lang}_txt.pt"

    print("üì• Loading finetuned checkpoint:", ckpt)
    model.load_state_dict(torch.load(ckpt, map_location=device))
    model.eval()

    preds = []
    golds = []
    srcs  = []
    imgs  = []

    for i in tqdm(range(len(val_ds)), desc="Eval"):
        sample = val_ds[i]
        src = sample["src"]
        tgt = sample["tgt"]
        img_file = sample["img"]

        srcs.append(src)
        golds.append(tgt)
        imgs.append(img_file)

        # tokenize source
        enc = tokenizer(
            src,
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        ).to(device)

        if model_type == "mm":
            # image
            img = safe_load_image(img_file)
            pixel = image_processor(images=[img], return_tensors="pt")["pixel_values"].to(device)

            gen_ids = model.generate(
                enc["input_ids"],
                enc["attention_mask"],
                pixel,
                tokenizer,
                max_length=MAX_LEN,
                num_beams=5,
            )
        else:
            gen_ids = model.generate(
                enc["input_ids"],
                enc["attention_mask"],
                tokenizer,
                max_length=MAX_LEN,
                num_beams=5,
            )

        pred_text = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
        preds.append(pred_text)

    # BLEU
    references = [[g for g in golds]]  # sacrebleu expects list of reference lists
    bleu = sacrebleu.compute(predictions=preds, references=[golds])["score"]

    # save predictions
    out_tsv = EVAL_DIR / f"preds_{model_type}_{src_lang}_{tgt_lang}.tsv"
    pd.DataFrame({
        "src": srcs,
        "gold": golds,
        "pred": preds,
        "image_file": imgs,
    }).to_csv(out_tsv, sep="\t", index=False)

    out_bleu = EVAL_DIR / f"bleu_{model_type}_{src_lang}_{tgt_lang}.txt"
    with open(out_bleu, "w") as f:
        f.write(f"BLEU = {bleu:.4f}\n")

    print(f"‚≠ê BLEU ({model_type.upper()} {src_lang}‚Üí{tgt_lang}) = {bleu:.4f}")
    print(f"üìÅ Predictions saved to: {out_tsv}")
    print(f"üìÅ BLEU score saved to: {out_bleu}")

    return bleu

# ================================================================
# RUN TRAINING + EVAL FOR EN‚ÜîDE
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
image_processor = SiglipProcessor.from_pretrained(vision_model_name)

# -------- TRAIN --------
# EN ‚Üí DE
train_model("en", "de", "mm",  tokenizer, image_processor)
#train_model("en", "de", "txt", tokenizer, image_processor)

# DE ‚Üí EN
train_model("de", "en", "mm",  tokenizer, image_processor)
#train_model("de", "en", "txt", tokenizer, image_processor)

# -------- EVAL (BLEU on VAL) --------
results = {}
#results["en_de_mm"]  = evaluate_model("en", "de", "mm",  tokenizer, image_processor)
#results["en_de_txt"] = evaluate_model("en", "de", "txt", tokenizer, image_processor)
#results["de_en_mm"]  = evaluate_model("de", "en", "mm",  tokenizer, image_processor)
#results["de_en_txt"] = evaluate_model("de", "en", "txt", tokenizer, image_processor)

# Save combined BLEU summary
summary_path = EVAL_DIR / "BLEU_summary_ecomm_en_de.json"
with open(summary_path, "w") as f:
    json.dump(results, f, indent=2)

print("\n===================================================")
print("üéâ All fine-tuning & evaluation runs completed.")
print("üìå BLEU summary saved at:", summary_path)
print("===================================================")


Mounted at /content/drive
Using device: cuda
üìå MODEL_DIR: /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion
üìÇ Contents of MODEL_DIR:
total 35G
-rw------- 1 root root  962 Dec  3 00:47 config_siglip_fusion_lora_all6.json
-rw------- 1 root root  776 Nov 20 00:52 config_siglip_fusion_lora.json
drwx------ 2 root root 4.0K Dec  3 17:50 ecomm_eval_en_de
drwx------ 2 root root 4.0K Dec  3 01:46 ecomm_finetuned
drwx------ 2 root root 4.0K Dec  7 09:02 ecomm_finetuned_fixed
drwx------ 2 root root 4.0K Dec  8 21:35 ecomm_french_finetuned
-rw------- 1 root root 2.3G Dec  3 04:36 mbart_lora_all6_text_best.pt
-rw------- 1 root root 2.3G Nov 19 17:04 mbart_lora_de_en_text_best.pt
-rw------- 1 root root 2.3G Nov 19 21:37 mbart_lora_de_fr_text_best.pt
-rw------- 1 root root 2.3G Nov 19 08:19 mbart_lora_en_de_text_best.pt
-rw------- 1 root root 2.3G Nov 19 12:56 mbart_lora_en_fr_text_best.pt
-rw------- 1 root root 2.3G Nov 20

Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [1:34:11<00:00,  1.33it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [12:25<00:00,  1.34it/s]


Train loss: 1.5769 | Val loss: 1.4035
  ‚úÖ New best; saved to /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt

Epoch 2/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [17:03<00:00,  7.33it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:13<00:00, 13.67it/s]


Train loss: 1.4200 | Val loss: 1.3532
  ‚úÖ New best; saved to /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt

Epoch 3/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [17:03<00:00,  7.33it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:12<00:00, 13.70it/s]


Train loss: 1.3115 | Val loss: 1.2850
  ‚úÖ New best; saved to /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt

Epoch 4/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [17:02<00:00,  7.33it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:12<00:00, 13.71it/s]


Train loss: 1.2313 | Val loss: 1.3019

Epoch 5/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [16:39<00:00,  7.50it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:13<00:00, 13.68it/s]


Train loss: 1.1848 | Val loss: 1.2396
  ‚úÖ New best; saved to /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt

Epoch 6/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [17:01<00:00,  7.34it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:13<00:00, 13.66it/s]


Train loss: 1.1384 | Val loss: 1.2460
‚úÖ Done MM en‚Üíde; best val=1.2396

üöÄ TRAINING MM MODEL FOR de ‚Üí en
[de->en] train rows: 15000
[de->en] val rows: 2000
Loading SigLIP: google/siglip-base-patch16-224
Loading MBART base...
üì• Loading base checkpoint: /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/siglip_fusion_lora_de_en_mm_best.pt

Epoch 1/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [16:39<00:00,  7.50it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:13<00:00, 13.67it/s]


Train loss: 1.4294 | Val loss: 1.2780
  ‚úÖ New best; saved to /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_mm.pt

Epoch 2/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [17:01<00:00,  7.34it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:13<00:00, 13.69it/s]


Train loss: 1.2967 | Val loss: 1.2336
  ‚úÖ New best; saved to /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_mm.pt

Epoch 3/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [16:46<00:00,  7.45it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:11<00:00, 13.96it/s]


Train loss: 1.2047 | Val loss: 1.1781
  ‚úÖ New best; saved to /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_mm.pt

Epoch 4/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [16:57<00:00,  7.37it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:11<00:00, 13.99it/s]


Train loss: 1.1826 | Val loss: 1.1622
  ‚úÖ New best; saved to /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_mm.pt

Epoch 5/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [16:58<00:00,  7.37it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:11<00:00, 13.95it/s]


Train loss: 1.0817 | Val loss: 1.1250
  ‚úÖ New best; saved to /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_mm.pt

Epoch 6/6


Train: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 7500/7500 [16:39<00:00,  7.50it/s]
Val: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [01:11<00:00, 14.00it/s]

Train loss: 1.0206 | Val loss: 1.1287
‚úÖ Done MM de‚Üíen; best val=1.1250

üéâ All fine-tuning & evaluation runs completed.
üìå BLEU summary saved at: /content/drive/.shortcut-targets-by-id/1GcIeOxxtd-cnipwAaf8rdRqrjBuQeOWP/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/evals/BLEU_summary_ecomm_en_de.json





In [None]:
# ================================================================
# FINAL ERROR-PROOF EVALUATION SCRIPT (1500 TEST SAMPLES)
# SigLIP + mBART + LoRA Fusion ‚Äî Evaluation Only
# ================================================================

import torch
import torch.nn as nn
import pandas as pd
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import evaluate

from transformers import (
    MBart50TokenizerFast, MBartForConditionalGeneration,
    SiglipVisionModel, SiglipProcessor
)
from peft import LoraConfig, get_peft_model, TaskType

# ================================================================
# DEVICE
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

sacrebleu = evaluate.load("sacrebleu")

# ================================================================
# PATHS
# ================================================================
BASE = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
OUT_DIR = Path(BASE) / "ecomm_french_finetuned"

TSV_FILE = (
    "/content/drive/MyDrive/dataset/"
    "ImageGuidedTranslationDataset-main/dataset/"
    "listingtitle-image-mappings/listingtitles_with_matched_images.en-de_with_french.tsv"
)

IMG_DIR = Path(
    "/content/drive/MyDrive/dataset/"
    "ImageGuidedTranslationDataset-main/dataset/images"
)

LANG_CODES = {"en": "en_XX", "fr": "fr_XX"}


# ================================================================
# SAFE IMAGE LOADER
# ================================================================
def safe_load(img_name):
    if not isinstance(img_name, str):
        return Image.new("RGB", (224,224), (128,128,128))

    for split in ["train", "val", "test"]:
        fp = IMG_DIR / split / img_name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except:
                pass

    return Image.new("RGB", (224,224), (128,128,128))


# ================================================================
# LOAD EXACT 1500 TEST SAMPLES (FROM FIRST 7500 ONLY)
# ================================================================
def load_french_test_set():
    df = pd.read_csv(TSV_FILE, sep="\t")
    print("\nColumns in dataset:", df.columns.tolist())

    if "french" not in df.columns:
        raise ValueError("‚ùå ERROR: 'french' column NOT found in dataset!")

    df = df.dropna(subset=["source", "french"]).reset_index(drop=True)

    # First 7500 rows only
    df = df.iloc[:7500].reset_index(drop=True)
    print("Usable rows (first 7500):", len(df))

    # Shuffle
    df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)

    # Last 1500 = test
    test_df = df.tail(1500).reset_index(drop=True)
    print("Loaded test size:", len(test_df))

    return test_df


# ================================================================
# MODEL DEFINITIONS
# ================================================================
def apply_lora(model):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"]
    )
    return get_peft_model(model, cfg)


class FusionBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=8,
                dim_feedforward=2048,
                dropout=0.1,
                batch_first=True
            ),
            num_layers=1
        )

    def forward(self, img_emb, txt_emb):
        x = torch.cat([img_emb, txt_emb], dim=1)
        return self.enc(x)


class SiglipFusionModel(nn.Module):
    def __init__(self):
        super().__init__()

        # Vision encoder
        self.vision = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        for p in self.vision.parameters():
            p.requires_grad = False

        # mBART + LoRA
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora(base)
        self.txt_emb = self.mbart.get_input_embeddings()

        self.proj = nn.Linear(768, self.mbart.config.d_model)
        self.fusion = FusionBlock(self.mbart.config.d_model)


# ================================================================
# EVALUATION FUNCTION (FIXED BLEU)
# ================================================================
def evaluate_french(src, tgt, test_df):

    print(f"\nüîç Evaluating {src} ‚Üí {tgt} on {len(test_df)} samples")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ckpt = OUT_DIR / f"ecomm_{src}_{tgt}.pt"
    print("Loading finetuned model:", ckpt)

    model = SiglipFusionModel().to(device)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    preds, refs, srcs, imgs = [], [], [], []

    for idx in tqdm(range(len(test_df))):
        row = test_df.iloc[idx]

        en = str(row["source"])
        fr = str(row["french"])
        img_file = row["image_file"]

        # Direction
        src_text = en if src == "en" else fr
        tgt_text = fr if tgt == "fr" else en

        refs.append(str(tgt_text))
        srcs.append(src_text)
        imgs.append(img_file)

        # Tokenize
        enc = tokenizer(
            src_text,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        ).to(device)

        # Process image
        img = safe_load(img_file)
        pixel = processor(images=[img], return_tensors="pt")["pixel_values"].to(device)

        with torch.no_grad():

            v = model.vision(pixel).last_hidden_state[:, 0, :]
            img_e = model.proj(v).unsqueeze(1)

            txt_e = model.txt_emb(enc["input_ids"])

            fused = model.fusion(img_e, txt_e)
            fused_mask = torch.cat(
                [torch.ones((1, 1), device=device), enc["attention_mask"]],
                dim=1
            )

            gen = model.mbart.generate(
                inputs_embeds=fused,
                attention_mask=fused_mask,
                num_beams=5,
                max_length=64,
                forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
            )

        preds.append(tokenizer.decode(gen[0], skip_special_tokens=True))

    # ============================================================
    # DEBUG COUNTS
    # ============================================================
    print("\nüîé DEBUG ‚Äî Prediction & Reference Counts")
    print("Pred count:", len(preds))
    print("Ref count: ", len(refs))

    # ============================================================
    # ‚≠ê FIXED BLEU ‚Äî WRAP REFS PROPERLY
    # ============================================================
    wrapped_refs = [[r] for r in refs]   # <-- THE FIX

    bleu = sacrebleu.compute(
        predictions=preds,
        references=wrapped_refs
    )["score"]

    print(f"\n‚≠ê BLEU ({src} ‚Üí {tgt}) = {bleu}")

    # Save results
    out_file = OUT_DIR / f"preds_{src}_{tgt}_1500.tsv"
    pd.DataFrame({
        "source": srcs,
        "gold": refs,
        "pred": preds,
        "image_file": imgs
    }).to_csv(out_file, sep="\t", index=False)

    print("Saved predictions ‚Üí", out_file)
    return bleu


# ================================================================
# RUN EVALUATION
# ================================================================
processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)

test_df = load_french_test_set()

evaluate_french("en", "fr", test_df)
evaluate_french("fr", "en", test_df)


Using: cuda

Columns in dataset: ['project_name', 'set_name', 'image_id', 'image_file', 'source', 'target', 'french']
Usable rows (first 7500): 7500
Loaded test size: 1500

üîç Evaluating en ‚Üí fr on 1500 samples
Loading finetuned model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [15:46<00:00,  1.58it/s]



üîé DEBUG ‚Äî Prediction & Reference Counts
Pred count: 1500
Ref count:  1500

‚≠ê BLEU (en ‚Üí fr) = 32.086293476649935
Saved predictions ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/preds_en_fr_1500.tsv

üîç Evaluating fr ‚Üí en on 1500 samples
Loading finetuned model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [13:50<00:00,  1.81it/s]


üîé DEBUG ‚Äî Prediction & Reference Counts
Pred count: 1500
Ref count:  1500

‚≠ê BLEU (fr ‚Üí en) = 31.07808077842777
Saved predictions ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/preds_fr_en_1500.tsv





31.07808077842777

In [None]:
# ================================================================
# FINAL ERROR-PROOF EVALUATION SCRIPT (1500 TEST SAMPLES)
# SigLIP + mBART + LoRA Fusion ‚Äî Evaluation Only
# ================================================================

import torch
import torch.nn as nn
import pandas as pd
from PIL import Image
from pathlib import Path
from tqdm import tqdm
import evaluate

from transformers import (
    MBart50TokenizerFast, MBartForConditionalGeneration,
    SiglipVisionModel, SiglipProcessor
)
from peft import LoraConfig, get_peft_model, TaskType

# ================================================================
# DEVICE
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

sacrebleu = evaluate.load("sacrebleu")

# ================================================================
# PATHS
# ================================================================
BASE = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
OUT_DIR = Path(BASE) / "ecomm_french_finetuned"

TSV_FILE = (
    "/content/drive/MyDrive/dataset/"
    "ImageGuidedTranslationDataset-main/dataset/"
    "listingtitle-image-mappings/listingtitles_with_matched_images.en-de_with_french.tsv"
)

IMG_DIR = Path(
    "/content/drive/MyDrive/dataset/"
    "ImageGuidedTranslationDataset-main/dataset/images"
)

LANG_CODES = {"en": "en_XX", "fr": "fr_XX"}


# ================================================================
# SAFE IMAGE LOADER
# ================================================================
def safe_load(img_name):
    if not isinstance(img_name, str):
        return Image.new("RGB", (224,224), (128,128,128))

    for split in ["train", "val", "test"]:
        fp = IMG_DIR / split / img_name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except:
                pass

    return Image.new("RGB", (224,224), (128,128,128))


# ================================================================
# LOAD EXACT 1500 TEST SAMPLES (FROM FIRST 7500 ONLY)
# ================================================================
def load_french_test_set():
    df = pd.read_csv(TSV_FILE, sep="\t")
    print("\nColumns in dataset:", df.columns.tolist())

    if "french" not in df.columns:
        raise ValueError("‚ùå ERROR: 'french' column NOT found in dataset!")

    df = df.dropna(subset=["source", "french"]).reset_index(drop=True)

    # First 7500 rows only
    df = df.iloc[:7500].reset_index(drop=True)
    print("Usable rows (first 7500):", len(df))

    # Shuffle
    df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)

    # Last 1500 = test
    test_df = df.head(1500).reset_index(drop=True)
    print("Loaded test size:", len(test_df))

    return test_df


# ================================================================
# MODEL DEFINITIONS
# ================================================================
def apply_lora(model):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=8, lora_alpha=16, lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"]
    )
    return get_peft_model(model, cfg)


class FusionBlock(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=8,
                dim_feedforward=2048,
                dropout=0.1,
                batch_first=True
            ),
            num_layers=1
        )

    def forward(self, img_emb, txt_emb):
        x = torch.cat([img_emb, txt_emb], dim=1)
        return self.enc(x)


class SiglipFusionModel(nn.Module):
    def __init__(self):
        super().__init__()

        # Vision encoder
        self.vision = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224")
        for p in self.vision.parameters():
            p.requires_grad = False

        # mBART + LoRA
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora(base)
        self.txt_emb = self.mbart.get_input_embeddings()

        self.proj = nn.Linear(768, self.mbart.config.d_model)
        self.fusion = FusionBlock(self.mbart.config.d_model)


# ================================================================
# EVALUATION FUNCTION (FIXED BLEU)
# ================================================================
def evaluate_french(src, tgt, test_df):

    print(f"\nüîç Evaluating {src} ‚Üí {tgt} on {len(test_df)} samples")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ckpt = OUT_DIR / f"ecomm_{src}_{tgt}.pt"
    print("Loading finetuned model:", ckpt)

    model = SiglipFusionModel().to(device)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    preds, refs, srcs, imgs = [], [], [], []

    for idx in tqdm(range(len(test_df))):
        row = test_df.iloc[idx]

        en = str(row["source"])
        fr = str(row["french"])
        img_file = row["image_file"]

        # Direction
        src_text = en if src == "en" else fr
        tgt_text = fr if tgt == "fr" else en

        refs.append(str(tgt_text))
        srcs.append(src_text)
        imgs.append(img_file)

        # Tokenize
        enc = tokenizer(
            src_text,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        ).to(device)

        # Process image
        img = safe_load(img_file)
        pixel = processor(images=[img], return_tensors="pt")["pixel_values"].to(device)

        with torch.no_grad():

            v = model.vision(pixel).last_hidden_state[:, 0, :]
            img_e = model.proj(v).unsqueeze(1)

            txt_e = model.txt_emb(enc["input_ids"])

            fused = model.fusion(img_e, txt_e)
            fused_mask = torch.cat(
                [torch.ones((1, 1), device=device), enc["attention_mask"]],
                dim=1
            )

            gen = model.mbart.generate(
                inputs_embeds=fused,
                attention_mask=fused_mask,
                num_beams=5,
                max_length=64,
                forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
            )

        preds.append(tokenizer.decode(gen[0], skip_special_tokens=True))

    # ============================================================
    # DEBUG COUNTS
    # ============================================================
    print("\nüîé DEBUG ‚Äî Prediction & Reference Counts")
    print("Pred count:", len(preds))
    print("Ref count: ", len(refs))

    # ============================================================
    # ‚≠ê FIXED BLEU ‚Äî WRAP REFS PROPERLY
    # ============================================================
    wrapped_refs = [[r] for r in refs]   # <-- THE FIX

    bleu = sacrebleu.compute(
        predictions=preds,
        references=wrapped_refs
    )["score"]

    print(f"\n‚≠ê BLEU ({src} ‚Üí {tgt}) = {bleu}")

    # Save results
    out_file = OUT_DIR / f"preds_{src}_{tgt}_1500.tsv"
    pd.DataFrame({
        "source": srcs,
        "gold": refs,
        "pred": preds,
        "image_file": imgs
    }).to_csv(out_file, sep="\t", index=False)

    print("Saved predictions ‚Üí", out_file)
    return bleu


# ================================================================
# RUN EVALUATION
# ================================================================
processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-224")
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)

test_df = load_french_test_set()

evaluate_french("en", "fr", test_df)
evaluate_french("fr", "en", test_df)


Using: cuda

Columns in dataset: ['project_name', 'set_name', 'image_id', 'image_file', 'source', 'target', 'french']
Usable rows (first 7500): 7500
Loaded test size: 1500

üîç Evaluating en ‚Üí fr on 1500 samples
Loading finetuned model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [15:41<00:00,  1.59it/s]



üîé DEBUG ‚Äî Prediction & Reference Counts
Pred count: 1500
Ref count:  1500

‚≠ê BLEU (en ‚Üí fr) = 49.444148843245
Saved predictions ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/preds_en_fr_1500.tsv

üîç Evaluating fr ‚Üí en on 1500 samples
Loading finetuned model: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/ecomm_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [13:46<00:00,  1.81it/s]


üîé DEBUG ‚Äî Prediction & Reference Counts
Pred count: 1500
Ref count:  1500

‚≠ê BLEU (fr ‚Üí en) = 42.289458052779175
Saved predictions ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_finetuned/preds_fr_en_1500.tsv





42.289458052779175


Evaluation samples loaded: 300
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 300/300 [04:25<00:00,  1.13it/s]


Final Pred Count: 300
Final Ref Count : 300
‚≠ê BLEU SCORE (en ‚Üí de, mm): 20.31750199975862

Evaluation samples loaded: 300
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_txt.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 300/300 [02:56<00:00,  1.70it/s]


Final Pred Count: 300
Final Ref Count : 300
‚≠ê BLEU SCORE (en ‚Üí de, txt): 20.19394399919782

Evaluation samples loaded: 300
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 300/300 [02:50<00:00,  1.76it/s]


Final Pred Count: 300
Final Ref Count : 300
‚≠ê BLEU SCORE (de ‚Üí en, mm): 21.885141269601355

Evaluation samples loaded: 300
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_txt.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 300/300 [02:41<00:00,  1.86it/s]

Final Pred Count: 300
Final Ref Count : 300
‚≠ê BLEU SCORE (de ‚Üí en, txt): 22.135581100657067





In [None]:
# ================================================================
# FIXED EVALUATION SCRIPT ‚Äî BLEU MISMATCH SOLVED FOREVER (EN <-> DE)
# ================================================================

import torch
from transformers import MBart50TokenizerFast, SiglipProcessor
from transformers import MBartForConditionalGeneration, SiglipVisionModel
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
import evaluate
import pandas as pd
from PIL import Image
from pathlib import Path

# ================================================================
# CONSTANTS
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sacrebleu = evaluate.load("sacrebleu")

LANG_CODES = {"en": "en_XX", "de": "de_DE"}

MAX_LEN = 64
TEST_LIMIT = 1500  # <= change this to 1500 for full test

VISION_MODEL_NAME = "google/siglip-base-patch16-224"

BASE_DIR = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
CKPT_DIR = Path(BASE_DIR) / "ecomm_finetuned"

ECOMM_TSV = (
    "/content/drive/MyDrive/dataset/"
    "ImageGuidedTranslationDataset-main/dataset/"
    "listingtitle-image-mappings/listingtitles_with_matched_images.en-de.tsv"
)

IMG_ROOT = "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/images"


# ================================================================
# SAFE IMAGE LOADING
# ================================================================
def safe_load_image(img_name):
    """Load image safely; return fallback image if missing."""
    if not isinstance(img_name, str) or img_name.strip() == "":
        return Image.new("RGB", (224,224), (128,128,128))

    img_name = img_name.strip()

    for split in ["train", "val", "test"]:
        fp = Path(IMG_ROOT) / split / img_name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except:
                return Image.new("RGB", (224,224), (128,128,128))

    return Image.new("RGB", (224,224), (128,128,128))

# ================================================================
# DATASET CLASS (CORRECTED: source = ENGLISH, target = GERMAN)
# ================================================================
class EcommDataset:
    """
    Dataset wrapper:
      English  = `source`
      German   = `target`
    """

    def __init__(self, src, tgt, limit=TEST_LIMIT):
        df = pd.read_csv(ECOMM_TSV, sep="\t")

        # Keep validation split
        df = df[df["set_name"].str.lower().isin(["val", "valid", "validation"])]
        df = df.reset_index(drop=True)

        if len(df) > limit:
            df = df.iloc[:limit]

        self.df = df
        self.src = src
        self.tgt = tgt

        print(f"Evaluation samples loaded: {len(df)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]

        en = str(row["source"])   # English
        de = str(row["target"])   # German

        # Choose direction
        src_text = en if self.src == "en" else de
        tgt_text = de if self.tgt == "de" else en

        return {
            "src": src_text,
            "tgt": tgt_text,
            "img": row["image_file"]
        }

# ================================================================
# MODEL DEFINITIONS
# ================================================================
def apply_lora(mbart):
    cfg = LoraConfig(
        task_type="SEQ_2_SEQ_LM",
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"]
    )
    return get_peft_model(mbart, cfg)


class FusionBlock(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        layer = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True,
        )
        self.encoder = torch.nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_emb, txt_emb):
        return self.encoder(torch.cat([img_emb, txt_emb], dim=1))


class MultiModalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = SiglipVisionModel.from_pretrained(VISION_MODEL_NAME)
        for p in self.vision.parameters():
            p.requires_grad = False

        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora(base)
        self.text_emb = self.mbart.get_input_embeddings()

        self.proj = torch.nn.Linear(768, self.mbart.config.d_model)
        self.fusion = FusionBlock(self.mbart.config.d_model)

    def generate(self, input_ids, mask, pixel_values, tokenizer):
        with torch.no_grad():
            vis = self.vision(pixel_values=pixel_values).last_hidden_state[:, 0, :]

        img_emb = self.proj(vis).unsqueeze(1)
        txt_emb = self.text_emb(input_ids)

        fused = self.fusion(img_emb, txt_emb)
        fused_mask = torch.cat(
            [torch.ones((input_ids.size(0), 1), device=device), mask],
            dim=1,
        )

        return self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            num_beams=5,
            max_length=MAX_LEN,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )


class TextOnlyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora(base)

    def generate(self, input_ids, mask, tokenizer):
        return self.mbart.generate(
            input_ids=input_ids,
            attention_mask=mask,
            num_beams=5,
            max_length=MAX_LEN,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )


# ================================================================
# EVALUATION (BLEU FIXED)
# ================================================================
def evaluate_one(src, tgt, model_type, tokenizer, processor):
    print(f"\n====== EVALUATING {model_type.upper()} {src}->{tgt} ======")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ds = EcommDataset(src, tgt, limit=TEST_LIMIT)

    # Load model + checkpoint
    if model_type == "mm":
        model = MultiModalModel().to(device)
        ckpt = CKPT_DIR / f"ecomm_{src}_{tgt}_mm.pt"
    else:
        model = TextOnlyModel().to(device)
        ckpt = CKPT_DIR / f"ecomm_{src}_{tgt}_txt.pt"

    print("Loading checkpoint:", ckpt)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    preds, refs = [], []

    for i in tqdm(range(len(ds))):
        sample = ds[i]

        tgt_clean = str(sample["tgt"]).strip().replace("\n", " ")
        refs.append(tgt_clean)

        enc = tokenizer(
            sample["src"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        ).to(device)

        if model_type == "mm":
            img = safe_load_image(sample["img"])
            pixel = processor(images=[img], return_tensors="pt")[
                "pixel_values"
            ].to(device)
            out = model.generate(enc["input_ids"], enc["attention_mask"], pixel, tokenizer)
        else:
            out = model.generate(enc["input_ids"], enc["attention_mask"], tokenizer)

        pred = tokenizer.decode(out[0], skip_special_tokens=True)
        preds.append(pred.strip())

    # CLEAN OUTPUT FOR BLEU
    preds = [p.strip() for p in preds]
    refs = [r.strip() for r in refs]

    references = [[r] for r in refs]

    print("Final Pred Count:", len(preds))
    print("Final Ref Count :", len(references))

    bleu = sacrebleu.compute(predictions=preds, references=references)["score"]

    print(f"‚≠ê BLEU SCORE ({src} ‚Üí {tgt}, {model_type}):", bleu)


# ================================================================
# RUN EVALUATION
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
processor = SiglipProcessor.from_pretrained(VISION_MODEL_NAME)

# Multimodal + Text-only, both directions
evaluate_one("en", "de", "mm", tokenizer, processor)
evaluate_one("en", "de", "txt", tokenizer, processor)

evaluate_one("de", "en", "mm", tokenizer, processor)
evaluate_one("de", "en", "txt", tokenizer, processor)



Evaluation samples loaded: 300
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 300/300 [03:03<00:00,  1.64it/s]


Final Pred Count: 300
Final Ref Count : 300
‚≠ê BLEU SCORE (en ‚Üí de, mm): 38.317652582244285

Evaluation samples loaded: 300
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_txt.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 300/300 [02:53<00:00,  1.73it/s]


Final Pred Count: 300
Final Ref Count : 300
‚≠ê BLEU SCORE (en ‚Üí de, txt): 35.821675907249606

Evaluation samples loaded: 300
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 300/300 [02:49<00:00,  1.77it/s]


Final Pred Count: 300
Final Ref Count : 300
‚≠ê BLEU SCORE (de ‚Üí en, mm): 45.94036035099213

Evaluation samples loaded: 300
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_txt.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 300/300 [02:41<00:00,  1.86it/s]

Final Pred Count: 300
Final Ref Count : 300
‚≠ê BLEU SCORE (de ‚Üí en, txt): 45.153433337371354





In [None]:
# ================================================================
# FIXED EVALUATION SCRIPT ‚Äî BLEU MISMATCH SOLVED FOREVER (EN <-> DE)
# ================================================================

import torch
from transformers import MBart50TokenizerFast, SiglipProcessor
from transformers import MBartForConditionalGeneration, SiglipVisionModel
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
import evaluate
import pandas as pd
from PIL import Image
from pathlib import Path

# ================================================================
# CONSTANTS
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sacrebleu = evaluate.load("sacrebleu")

LANG_CODES = {"en": "en_XX", "de": "de_DE"}

MAX_LEN = 64
TEST_LIMIT = 1500  # <= change this to 1500 for full test

VISION_MODEL_NAME = "google/siglip-base-patch16-224"

BASE_DIR = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
CKPT_DIR = Path(BASE_DIR) / "ecomm_finetuned"

ECOMM_TSV = (
    "/content/drive/MyDrive/dataset/"
    "ImageGuidedTranslationDataset-main/dataset/"
    "listingtitle-image-mappings/listingtitles_with_matched_images.en-de.tsv"
)

IMG_ROOT = "/content/drive/MyDrive/dataset/ImageGuidedTranslationDataset-main/dataset/images"


# ================================================================
# SAFE IMAGE LOADING
# ================================================================
def safe_load_image(img_name):
    """Load image safely; return fallback image if missing."""
    if not isinstance(img_name, str) or img_name.strip() == "":
        return Image.new("RGB", (224,224), (128,128,128))

    img_name = img_name.strip()

    for split in ["train", "val", "test"]:
        fp = Path(IMG_ROOT) / split / img_name
        if fp.exists():
            try:
                return Image.open(fp).convert("RGB")
            except:
                return Image.new("RGB", (224,224), (128,128,128))

    return Image.new("RGB", (224,224), (128,128,128))

# ================================================================
# DATASET CLASS (CORRECTED: source = ENGLISH, target = GERMAN)
# ================================================================
class EcommDataset:
    """
    Dataset wrapper:
      English  = `source`
      German   = `target`
    """

    def __init__(self, src, tgt, limit=TEST_LIMIT):
        df = pd.read_csv(ECOMM_TSV, sep="\t")

        # Keep validation split
        df = df[df["set_name"].str.lower().isin(["val", "valid", "validation"])]
        df = df.reset_index(drop=True)

        if len(df) > limit:
            df = df.iloc[:limit]

        self.df = df
        self.src = src
        self.tgt = tgt

        print(f"Evaluation samples loaded: {len(df)}")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, i):
        row = self.df.iloc[i]

        en = str(row["source"])   # English
        de = str(row["target"])   # German

        # Choose direction
        src_text = en if self.src == "en" else de
        tgt_text = de if self.tgt == "de" else en

        return {
            "src": src_text,
            "tgt": tgt_text,
            "img": row["image_file"]
        }

# ================================================================
# MODEL DEFINITIONS
# ================================================================
def apply_lora(mbart):
    cfg = LoraConfig(
        task_type="SEQ_2_SEQ_LM",
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"]
    )
    return get_peft_model(mbart, cfg)


class FusionBlock(torch.nn.Module):
    def __init__(self, d_model):
        super().__init__()
        layer = torch.nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=8,
            dim_feedforward=2048,
            dropout=0.1,
            batch_first=True,
        )
        self.encoder = torch.nn.TransformerEncoder(layer, num_layers=1)

    def forward(self, img_emb, txt_emb):
        return self.encoder(torch.cat([img_emb, txt_emb], dim=1))


class MultiModalModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.vision = SiglipVisionModel.from_pretrained(VISION_MODEL_NAME)
        for p in self.vision.parameters():
            p.requires_grad = False

        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora(base)
        self.text_emb = self.mbart.get_input_embeddings()

        self.proj = torch.nn.Linear(768, self.mbart.config.d_model)
        self.fusion = FusionBlock(self.mbart.config.d_model)

    def generate(self, input_ids, mask, pixel_values, tokenizer):
        with torch.no_grad():
            vis = self.vision(pixel_values=pixel_values).last_hidden_state[:, 0, :]

        img_emb = self.proj(vis).unsqueeze(1)
        txt_emb = self.text_emb(input_ids)

        fused = self.fusion(img_emb, txt_emb)
        fused_mask = torch.cat(
            [torch.ones((input_ids.size(0), 1), device=device), mask],
            dim=1,
        )

        return self.mbart.generate(
            inputs_embeds=fused,
            attention_mask=fused_mask,
            num_beams=5,
            max_length=MAX_LEN,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )


class TextOnlyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora(base)

    def generate(self, input_ids, mask, tokenizer):
        return self.mbart.generate(
            input_ids=input_ids,
            attention_mask=mask,
            num_beams=5,
            max_length=MAX_LEN,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang],
        )


# ================================================================
# EVALUATION (BLEU FIXED)
# ================================================================
def evaluate_one(src, tgt, model_type, tokenizer, processor):
    print(f"\n====== EVALUATING {model_type.upper()} {src}->{tgt} ======")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ds = EcommDataset(src, tgt, limit=TEST_LIMIT)

    # Load model + checkpoint
    if model_type == "mm":
        model = MultiModalModel().to(device)
        ckpt = CKPT_DIR / f"ecomm_{src}_{tgt}_mm.pt"
    else:
        model = TextOnlyModel().to(device)
        ckpt = CKPT_DIR / f"ecomm_{src}_{tgt}_txt.pt"

    print("Loading checkpoint:", ckpt)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    preds, refs = [], []

    for i in tqdm(range(len(ds))):
        sample = ds[i]

        tgt_clean = str(sample["tgt"]).strip().replace("\n", " ")
        refs.append(tgt_clean)

        enc = tokenizer(
            sample["src"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt",
        ).to(device)

        if model_type == "mm":
            img = safe_load_image(sample["img"])
            pixel = processor(images=[img], return_tensors="pt")[
                "pixel_values"
            ].to(device)
            out = model.generate(enc["input_ids"], enc["attention_mask"], pixel, tokenizer)
        else:
            out = model.generate(enc["input_ids"], enc["attention_mask"], tokenizer)

        pred = tokenizer.decode(out[0], skip_special_tokens=True)
        preds.append(pred.strip())

    # CLEAN OUTPUT FOR BLEU
    preds = [p.strip() for p in preds]
    refs = [r.strip() for r in refs]

    references = [[r] for r in refs]

    print("Final Pred Count:", len(preds))
    print("Final Ref Count :", len(references))

    bleu = sacrebleu.compute(predictions=preds, references=references)["score"]

    print(f"‚≠ê BLEU SCORE ({src} ‚Üí {tgt}, {model_type}):", bleu)


# ================================================================
# RUN EVALUATION
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)
processor = SiglipProcessor.from_pretrained(VISION_MODEL_NAME)

# Multimodal + Text-only, both directions
evaluate_one("en", "de", "mm", tokenizer, processor)
evaluate_one("en", "de", "txt", tokenizer, processor)

evaluate_one("de", "en", "mm", tokenizer, processor)
evaluate_one("de", "en", "txt", tokenizer, processor)



Evaluation samples loaded: 1500
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [20:43<00:00,  1.21it/s]


Final Pred Count: 1500
Final Ref Count : 1500
‚≠ê BLEU SCORE (en ‚Üí de, mm): 37.32377458033243

Evaluation samples loaded: 1500
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_de_txt.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [14:37<00:00,  1.71it/s]


Final Pred Count: 1500
Final Ref Count : 1500
‚≠ê BLEU SCORE (en ‚Üí de, txt): 35.98922332352579

Evaluation samples loaded: 1500
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_mm.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [14:20<00:00,  1.74it/s]


Final Pred Count: 1500
Final Ref Count : 1500
‚≠ê BLEU SCORE (de ‚Üí en, mm): 45.38777071440199

Evaluation samples loaded: 1500
Loading checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_de_en_txt.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [13:35<00:00,  1.84it/s]

Final Pred Count: 1500
Final Ref Count : 1500
‚≠ê BLEU SCORE (de ‚Üí en, txt): 44.2543981919643





In [None]:
# ================================================================
# FINAL ERROR-PROOF TEXT-ONLY EVALUATION SCRIPT (1500 TEST SAMPLES)
# mBART + LoRA ‚Äî TEXT ONLY ‚Äî NO IMAGES
# ================================================================

import torch
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import evaluate

from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
from peft import LoraConfig, get_peft_model, TaskType
from PIL import Image

# ================================================================
# DEVICE
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

sacrebleu = evaluate.load("sacrebleu")

# ================================================================
# PATHS
# ================================================================
BASE = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
OUT_DIR = Path(BASE) / "ecomm_french_finetuned"

TSV_FILE = (
    "/content/drive/MyDrive/dataset/"
    "ImageGuidedTranslationDataset-main/dataset/"
    "listingtitle-image-mappings/listingtitles_with_matched_images.en-de_with_french.tsv"
)

LANG_CODES = {"en": "en_XX", "fr": "fr_XX"}


# ================================================================
# LOAD EXACT 1500 TEST SAMPLES FROM FIRST 7500
# ================================================================
def load_french_test_set():
    df = pd.read_csv(TSV_FILE, sep="\t")
    print("\nColumns in dataset:", df.columns.tolist())

    if "french" not in df.columns:
        raise ValueError("‚ùå ERROR: 'french' column NOT found!")

    df = df.dropna(subset=["source", "french"]).reset_index(drop=True)

    # First 7500 rows only
    df = df.iloc[:7500].reset_index(drop=True)
    print("Usable rows (first 7500):", len(df))

    # Shuffle
    df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)

    # Last 1500 = test
    test_df = df.tail(1500).reset_index(drop=True)
    print("Loaded test size:", len(test_df))
    return test_df


# ================================================================
# TEXT-ONLY MODEL LOADER
# ================================================================
def apply_lora(model):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"]
    )
    return get_peft_model(model, cfg)


class MBartTextOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.model = apply_lora(base)

    def generate_text(self, input_ids, attention_mask, tokenizer):
        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=5,
            max_length=64,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
        )


# ================================================================
# EVALUATION FUNCTION ‚Äî TEXT ONLY
# ================================================================
def evaluate_french_text(src, tgt, test_df):

    print(f"\nüîç Evaluating TEXT-ONLY {src} ‚Üí {tgt} on {len(test_df)} samples")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ckpt = OUT_DIR / f"ecomm_{src}_{tgt}_txt.pt"
    print("Loading TEXT-ONLY checkpoint:", ckpt)

    model = MBartTextOnly().to(device)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    preds, refs, srcs = [], [], []

    for idx in tqdm(range(len(test_df))):
        row = test_df.iloc[idx]

        en = str(row["source"])
        fr = str(row["french"])

        # Direction
        src_text = en if src == "en" else fr
        tgt_text = fr if tgt == "fr" else en

        refs.append(tgt_text)
        srcs.append(src_text)

        # Tokenize
        enc = tokenizer(
            src_text,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            gen = model.generate_text(
                enc["input_ids"], enc["attention_mask"], tokenizer
            )

        pred = tokenizer.decode(gen[0], skip_special_tokens=True)
        preds.append(pred)

    print("\nüîé DEBUG ‚Äî Prediction & Reference Counts")
    print("Pred count:", len(preds))
    print("Ref count: ", len(refs))

    # BLEU requires list[list[str]]
    wrapped_refs = [[r] for r in refs]

    bleu = sacrebleu.compute(
        predictions=preds,
        references=wrapped_refs
    )["score"]

    print(f"\n‚≠ê TEXT-ONLY BLEU ({src} ‚Üí {tgt}) = {bleu}")

    # Save predictions
    out_file = OUT_DIR / f"preds_textonly_{src}_{tgt}_1500.tsv"
    pd.DataFrame({
        "source": srcs,
        "gold": refs,
        "pred": preds,
    }).to_csv(out_file, sep="\t", index=False)

    print("Saved TEXT-ONLY predictions ‚Üí", out_file)
    return bleu


# ================================================================
# RUN EVALUATION
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

test_df = load_french_test_set()

evaluate_french_text("en", "fr", test_df)
evaluate_french_text("fr", "en", test_df)


Using: cuda


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script: 0.00B [00:00, ?B/s]

tokenizer_config.json:   0%|          | 0.00/529 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/649 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]


Columns in dataset: ['project_name', 'set_name', 'image_id', 'image_file', 'source', 'target', 'french']
Usable rows (first 7500): 7500
Loaded test size: 1500

üîç Evaluating TEXT-ONLY en ‚Üí fr on 1500 samples
Loading TEXT-ONLY checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_en_fr_txt.pt


model.safetensors:   0%|          | 0.00/2.44G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/261 [00:00<?, ?B/s]

100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [14:43<00:00,  1.70it/s]



üîé DEBUG ‚Äî Prediction & Reference Counts
Pred count: 1500
Ref count:  1500

‚≠ê TEXT-ONLY BLEU (en ‚Üí fr) = 17.79124277319568
Saved TEXT-ONLY predictions ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/preds_textonly_en_fr_1500.tsv

üîç Evaluating TEXT-ONLY fr ‚Üí en on 1500 samples
Loading TEXT-ONLY checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/ecomm_fr_en_txt.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [13:15<00:00,  1.89it/s]


üîé DEBUG ‚Äî Prediction & Reference Counts
Pred count: 1500
Ref count:  1500

‚≠ê TEXT-ONLY BLEU (fr ‚Üí en) = 20.96915203789267
Saved TEXT-ONLY predictions ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_finetuned/preds_textonly_fr_en_1500.tsv





20.96915203789267

In [None]:
# ================================================================
# üîÅ FINAL FULL TRAINING + EVALUATION SCRIPT FOR EN‚ÜîFR (TEXT-ONLY)
# ================================================================

import os
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import pandas as pd
from tqdm import tqdm
import evaluate

from transformers import (
    MBart50TokenizerFast,
    MBartForConditionalGeneration
)
from peft import LoraConfig, get_peft_model, TaskType

# ================================================================
# CONFIG
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

sacrebleu = evaluate.load("sacrebleu")

MAX_LEN = 64
BATCH = 8
LR = 2e-4
EPOCHS = 6

LANG_CODES = {"en": "en_XX", "fr": "fr_XX"}

# ================================================================
# PATHS
# ================================================================
BASE = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
MODEL_DIR = Path(BASE)
OUT_DIR = MODEL_DIR / "ecomm_french_text_only"
OUT_DIR.mkdir(exist_ok=True)

TSV_FILE = (
    "/content/drive/MyDrive/dataset/"
    "ImageGuidedTranslationDataset-main/dataset/"
    "listingtitle-image-mappings/listingtitles_with_matched_images.en-de_with_french.tsv"
)

# ================================================================
# DATASET
# ================================================================
class TextDataset(Dataset):
    def __init__(self, df, src, tgt):
        self.df = df.reset_index(drop=True)
        self.src = src
        self.tgt = tgt

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        en = str(row["source"])
        fr = str(row["french"])

        src_text = en if self.src == "en" else fr
        tgt_text = fr if self.tgt == "fr" else en

        return {"src": src_text, "tgt": tgt_text}

# ================================================================
# CREATE TRAIN/TEST SPLITS
# ================================================================
def load_french_splits(src, tgt):
    df = pd.read_csv(TSV_FILE, sep="\t")

    df = df[df["set_name"].str.lower().isin(["train", "test"])]
    df = df.dropna(subset=["source", "french"]).reset_index(drop=True)

    train_df = df[df["set_name"] == "train"]
    test_df  = df[df["set_name"] == "test"]

    train_df = train_df.sample(min(6000, len(train_df)), random_state=42)
    test_df  = test_df.sample(min(1500, len(test_df)), random_state=42)

    print(f"{src} ‚Üí {tgt}: Train={len(train_df)}, Test={len(test_df)}")
    return train_df, test_df

# ================================================================
# MODEL DEFINITION (TEXT ONLY)
# ================================================================
def apply_lora(model):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"],
    )
    return get_peft_model(model, cfg)

class TextOnlyMT(nn.Module):
    """ mBART + LoRA only (no images) """
    def __init__(self):
        super().__init__()
        self.mbart = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.mbart = apply_lora(self.mbart)

    def forward(self, ids, mask, labels):
        return self.mbart(
            input_ids=ids,
            attention_mask=mask,
            labels=labels
        )

# ================================================================
# TOKENIZER + COLLATE
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained(
    "facebook/mbart-large-50-many-to-many-mmt"
)

def collate(batch):
    src = [b["src"] for b in batch]
    tgt = [b["tgt"] for b in batch]

    enc_s = tokenizer(
        src, truncation=True, padding="max_length",
        max_length=MAX_LEN, return_tensors="pt"
    )

    with tokenizer.as_target_tokenizer():
        enc_t = tokenizer(
            tgt, truncation=True, padding="max_length",
            max_length=MAX_LEN, return_tensors="pt"
        )

    labels = enc_t["input_ids"]
    labels[labels == tokenizer.pad_token_id] = -100

    return {
        "ids": enc_s["input_ids"].to(device),
        "mask": enc_s["attention_mask"].to(device),
        "labels": labels.to(device)
    }

# ================================================================
# TRAINING LOOP
# ================================================================
def train_text_model(src, tgt):
    print(f"\nüî• Training TEXT-ONLY {src} ‚Üí {tgt}")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    train_df, test_df = load_french_splits(src, tgt)

    train_loader = DataLoader(
        TextDataset(train_df, src, tgt),
        batch_size=BATCH,
        shuffle=True,
        collate_fn=collate
    )
    test_ds = TextDataset(test_df, src, tgt)

    model = TextOnlyMT().to(device)
    optim = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR
    )

    best = 9999
    save_path = OUT_DIR / f"text_{src}_{tgt}.pt"

    for ep in range(1, EPOCHS + 1):
        model.train()
        total = 0

        for batch in tqdm(train_loader):
            optim.zero_grad()
            out = model(batch["ids"], batch["mask"], batch["labels"])
            loss = out.loss
            loss.backward()
            optim.step()
            total += loss.item()

        print(f"Epoch {ep} Loss = {total/len(train_loader):.4f}")

        if total < best:
            best = total
            torch.save(model.state_dict(), save_path)
            print("Saved best:", save_path)

    return test_ds

# ================================================================
# EVALUATION
# ================================================================
def evaluate_text(src, tgt, test_ds):
    print(f"\nüîç Evaluating TEXT-ONLY {src} ‚Üí {tgt}")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    model = TextOnlyMT().to(device)
    ckpt = OUT_DIR / f"text_{src}_{tgt}.pt"
    print("Loading:", ckpt)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    preds, refs, sources = [], [], []

    for sample in tqdm(test_ds):
        src_text = sample["src"]
        tgt_text = sample["tgt"]

        sources.append(src_text)
        refs.append(tgt_text)

        enc = tokenizer(
            src_text,
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            gen = model.mbart.generate(
                input_ids=enc["input_ids"],
                attention_mask=enc["attention_mask"],
                max_length=MAX_LEN,
                num_beams=5,
                forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
            )

        preds.append(tokenizer.decode(gen[0], skip_special_tokens=True))

    bleu = sacrebleu.compute(predictions=preds, references=[refs])["score"]
    print(f"‚≠ê BLEU ({src} ‚Üí {tgt}) = {bleu:.4f}")

    pred_file = OUT_DIR / f"preds_text_{src}_{tgt}.tsv"
    pd.DataFrame({
        "source": sources,
        "reference": refs,
        "prediction": preds,
    }).to_csv(pred_file, sep="\t", index=False)

    print("Saved:", pred_file)
    return bleu

# ================================================================
# RUN TRAINING + EVALUATION
# ================================================================
test_en_fr = train_text_model("en", "fr")
test_fr_en = train_text_model("fr", "en")

#evaluate_text("en", "fr", test_en_fr)
#evaluate_text("fr", "en", test_fr_en)


Using: cuda

üî• Training TEXT-ONLY en ‚Üí fr
en ‚Üí fr: Train=6000, Test=0


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.36it/s]


Epoch 1 Loss = 1.8541
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.39it/s]


Epoch 2 Loss = 1.6293
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.37it/s]


Epoch 3 Loss = 1.5011
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.37it/s]


Epoch 4 Loss = 1.4104
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.37it/s]


Epoch 5 Loss = 1.3316
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.37it/s]


Epoch 6 Loss = 1.2616
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_en_fr.pt

üî• Training TEXT-ONLY fr ‚Üí en
fr ‚Üí en: Train=6000, Test=0


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.38it/s]


Epoch 1 Loss = 2.0173
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.37it/s]


Epoch 2 Loss = 1.8523
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.37it/s]


Epoch 3 Loss = 1.7547
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.37it/s]


Epoch 4 Loss = 1.6695
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.38it/s]


Epoch 5 Loss = 1.5947
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 750/750 [02:19<00:00,  5.38it/s]


Epoch 6 Loss = 1.5291
Saved best: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_fr_en.pt


In [None]:
# ================================================================
# FINAL ERROR-PROOF TEXT-ONLY EVALUATION SCRIPT (1500 TEST SAMPLES)
# mBART + LoRA ‚Äî TEXT ONLY ‚Äî NO IMAGES
# ================================================================

import torch
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import evaluate

from transformers import MBart50TokenizerFast, MBartForConditionalGeneration
from peft import LoraConfig, get_peft_model, TaskType
from PIL import Image

# ================================================================
# DEVICE
# ================================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using:", device)

sacrebleu = evaluate.load("sacrebleu")

# ================================================================
# PATHS
# ================================================================
BASE = "/content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion"
OUT_DIR = Path(BASE) / "ecomm_french_text_only"

TSV_FILE = (
    "/content/drive/MyDrive/dataset/"
    "ImageGuidedTranslationDataset-main/dataset/"
    "listingtitle-image-mappings/listingtitles_with_matched_images.en-de_with_french.tsv"
)

LANG_CODES = {"en": "en_XX", "fr": "fr_XX"}


# ================================================================
# LOAD EXACT 1500 TEST SAMPLES FROM FIRST 7500
# ================================================================
def load_french_test_set():
    df = pd.read_csv(TSV_FILE, sep="\t")
    print("\nColumns in dataset:", df.columns.tolist())

    if "french" not in df.columns:
        raise ValueError("‚ùå ERROR: 'french' column NOT found!")

    df = df.dropna(subset=["source", "french"]).reset_index(drop=True)

    # First 7500 rows only
    df = df.iloc[:7500].reset_index(drop=True)
    print("Usable rows (first 7500):", len(df))

    # Shuffle
    df = df.sample(frac=1.0, random_state=42).reset_index(drop=True)

    # Last 1500 = test
    test_df = df.tail(1500).reset_index(drop=True)
    print("Loaded test size:", len(test_df))
    return test_df


# ================================================================
# TEXT-ONLY MODEL LOADER
# ================================================================
def apply_lora(model):
    cfg = LoraConfig(
        task_type=TaskType.SEQ_2_SEQ_LM,
        r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        target_modules=["q_proj", "v_proj"]
    )
    return get_peft_model(model, cfg)


class MBartTextOnly(torch.nn.Module):
    def __init__(self):
        super().__init__()
        base = MBartForConditionalGeneration.from_pretrained(
            "facebook/mbart-large-50-many-to-many-mmt"
        )
        self.model = apply_lora(base)

    def generate_text(self, input_ids, attention_mask, tokenizer):
        return self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=5,
            max_length=64,
            forced_bos_token_id=tokenizer.lang_code_to_id[tokenizer.tgt_lang]
        )


# ================================================================
# EVALUATION FUNCTION ‚Äî TEXT ONLY
# ================================================================
def evaluate_french_text(src, tgt, test_df):

    print(f"\nüîç Evaluating TEXT-ONLY {src} ‚Üí {tgt} on {len(test_df)} samples")

    tokenizer.src_lang = LANG_CODES[src]
    tokenizer.tgt_lang = LANG_CODES[tgt]

    ckpt = OUT_DIR / f"text_{src}_{tgt}.pt"
    print("Loading TEXT-ONLY checkpoint:", ckpt)

    model = MBartTextOnly().to(device)
    model.load_state_dict(torch.load(ckpt, map_location=device), strict=False)
    model.eval()

    preds, refs, srcs = [], [], []

    for idx in tqdm(range(len(test_df))):
        row = test_df.iloc[idx]

        en = str(row["source"])
        fr = str(row["french"])

        # Direction
        src_text = en if src == "en" else fr
        tgt_text = fr if tgt == "fr" else en

        refs.append(tgt_text)
        srcs.append(src_text)

        # Tokenize
        enc = tokenizer(
            src_text,
            padding="max_length",
            truncation=True,
            max_length=64,
            return_tensors="pt"
        ).to(device)

        with torch.no_grad():
            gen = model.generate_text(
                enc["input_ids"], enc["attention_mask"], tokenizer
            )

        pred = tokenizer.decode(gen[0], skip_special_tokens=True)
        preds.append(pred)

    print("\nüîé DEBUG ‚Äî Prediction & Reference Counts")
    print("Pred count:", len(preds))
    print("Ref count: ", len(refs))

    # BLEU requires list[list[str]]
    wrapped_refs = [[r] for r in refs]

    bleu = sacrebleu.compute(
        predictions=preds,
        references=wrapped_refs
    )["score"]

    print(f"\n‚≠ê TEXT-ONLY BLEU ({src} ‚Üí {tgt}) = {bleu}")

    # Save predictions
    out_file = OUT_DIR / f"preds_textonly_{src}_{tgt}_1500.tsv"
    pd.DataFrame({
        "source": srcs,
        "gold": refs,
        "pred": preds,
    }).to_csv(out_file, sep="\t", index=False)

    print("Saved TEXT-ONLY predictions ‚Üí", out_file)
    return bleu


# ================================================================
# RUN EVALUATION
# ================================================================
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50-many-to-many-mmt")

test_df = load_french_test_set()

evaluate_french_text("en", "fr", test_df)
evaluate_french_text("fr", "en", test_df)


Using: cuda

Columns in dataset: ['project_name', 'set_name', 'image_id', 'image_file', 'source', 'target', 'french']
Usable rows (first 7500): 7500
Loaded test size: 1500

üîç Evaluating TEXT-ONLY en ‚Üí fr on 1500 samples
Loading TEXT-ONLY checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_en_fr.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [14:28<00:00,  1.73it/s]



üîé DEBUG ‚Äî Prediction & Reference Counts
Pred count: 1500
Ref count:  1500

‚≠ê TEXT-ONLY BLEU (en ‚Üí fr) = 17.79124277319568
Saved TEXT-ONLY predictions ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/preds_textonly_en_fr_1500.tsv

üîç Evaluating TEXT-ONLY fr ‚Üí en on 1500 samples
Loading TEXT-ONLY checkpoint: /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/text_fr_en.pt


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1500/1500 [13:08<00:00,  1.90it/s]


üîé DEBUG ‚Äî Prediction & Reference Counts
Pred count: 1500
Ref count:  1500

‚≠ê TEXT-ONLY BLEU (fr ‚Üí en) = 20.96915203789267
Saved TEXT-ONLY predictions ‚Üí /content/drive/MyDrive/multimodal_translation_models_siglip_lora_fusion/ecomm_french_text_only/preds_textonly_fr_en_1500.tsv





20.96915203789267