In [None]:
# ============================================================
#  MVSA-Single multimodal experiment (Baseline vs Rigid-Givens)
#  - Automatic download from Kaggle
#  - Binary sentiment: 0 = negative, 1 = positive
#  - Image + text (tweet)
#  - Comparison: baseline vs rigid rotation (few Givens) + translation
#  - K-fold CV, paired t-test, plots in ./results/
# ============================================================

# -------- 0) Dependencies (Colab) --------
!pip install -q kaggle transformers

import os, random, subprocess
import zipfile
from dataclasses import dataclass

import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
import torch.nn.functional as F
from torchvision import models, transforms

from transformers import AutoTokenizer, AutoModel
from sklearn.model_selection import KFold
from sklearn.metrics import f1_score
from scipy.stats import ttest_rel
import matplotlib.pyplot as plt

from google.colab import files
# from kaggle.api.kaggle_api_extended import KaggleApi  # not used directly, but imports the package


# ========================= CONFIG =========================

@dataclass
class Config:
    """Paths and hyper-parameters for the MVSA-Single Baseline vs Rigid-Givens experiment."""

    # Dataset root; the code expects `<mvsa_root>/data/` and `<mvsa_root>/labelResultAll.txt`.
    mvsa_root: str = "/content/data/MVSA_Single"

    # Number of pairs to randomly subsample; None keeps every usable pair.
    num_samples: int | None = None
    # Images are resized to (img_size, img_size) before normalization.
    img_size: int = 224
    # Token sequences are padded/truncated to this length.
    max_text_len: int = 64

    # Also sets the in-batch contrastive set: each anchor has batch_size - 1 negatives.
    batch_size: int = 3
    num_epochs: int = 12
    # Passed to sklearn KFold as n_splits.
    num_folds: int = 2

    lr: float = 1e-3
    weight_decay: float = 1e-4
    seed: int = 42
    # Evaluated once, when the class is defined; override explicitly to force a device.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    # Temperature dividing the cosine-similarity logits in the NT-Xent term.
    temperature: float = 0.07
    # Weights of the four terms summed in multimodal_loss (classification, NT-Xent, L1, cosine).
    lambda_cls: float = 1.0
    lambda_ntxent: float = 1.0
    lambda_l1: float = 1.0
    lambda_cos: float = 1.0

    # When True, the ResNet and text-encoder parameters get requires_grad=False.
    freeze_backbones: bool = True
    # Hugging Face model id used for both the tokenizer and the text encoder.
    text_model_name: str = "prajjwal1/bert-tiny"

    # Requested number of Givens rotations (RigidGivensTransform caps it at dim // 2).
    num_givens: int = 16


# ========================= UTILS =========================

def set_seed(seed: int = 42):
    """Seed Python's `random`, NumPy, PyTorch CPU and all CUDA generators with *seed*."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


def ensure_kaggle_api():
    """Make sure ~/.kaggle/kaggle.json exists, prompting for a Colab upload if it does not.

    Raises:
        RuntimeError: if the user's upload does not include a file named kaggle.json.
    """
    config_dir = os.path.expanduser("~/.kaggle")
    credentials = os.path.join(config_dir, "kaggle.json")

    if os.path.exists(credentials):
        print("Found ~/.kaggle/kaggle.json")
        return

    print("~/.kaggle/kaggle.json not found.")
    print("Please upload your kaggle.json (Kaggle account → Legacy API Key).")
    received = files.upload()
    if "kaggle.json" not in received:
        raise RuntimeError("kaggle.json was not uploaded. Run this cell again and select the file.")

    # files.upload() writes the file into the current working directory; move it into place
    # and restrict permissions to the owner.
    os.makedirs(config_dir, exist_ok=True)
    os.replace("kaggle.json", credentials)
    os.chmod(credentials, 0o600)
    print("kaggle.json copied to ~/.kaggle/kaggle.json")


def _mvsa_layout_ok(root: str) -> bool:
    """Return True if *root* contains the expected MVSA-Single layout (data/ + labelResultAll.txt)."""
    return (
        os.path.isdir(root)
        and os.path.isdir(os.path.join(root, "data"))
        and os.path.isfile(os.path.join(root, "labelResultAll.txt"))
    )


def download_mvsa_single_if_needed(cfg: Config):
    """Download and unzip MVSA-Single from Kaggle if it is not already present.

    The archive is saved to and extracted into the parent directory of ``cfg.mvsa_root``.
    If extraction produces a differently named ``*MVSA_Single*`` folder, it is renamed
    to ``cfg.mvsa_root``.

    Raises:
        subprocess.CalledProcessError: if the ``kaggle`` CLI download fails.
        RuntimeError: if no zip is found after download, or the final layout is wrong.
    """
    if _mvsa_layout_ok(cfg.mvsa_root):
        print(f"MVSA-Single found at: {cfg.mvsa_root}")
        return

    print("Downloading MVSA-Single from Kaggle (sayan3270/mvsa-single)...")
    ensure_kaggle_api()
    # normpath first: with a trailing slash, dirname() would return mvsa_root itself.
    data_root = os.path.dirname(os.path.normpath(cfg.mvsa_root))
    os.makedirs(data_root, exist_ok=True)

    cmd_download = [
        "kaggle", "datasets", "download",
        "-d", "sayan3270/mvsa-single",
        "-p", data_root,
        "-q",
    ]
    subprocess.run(cmd_download, check=True)

    zip_path = os.path.join(data_root, "mvsa-single.zip")
    if not os.path.isfile(zip_path):
        # Sorted so the fallback choice is deterministic across runs.
        zips = sorted(f for f in os.listdir(data_root) if f.lower().endswith(".zip"))
        if not zips:
            raise RuntimeError("Could not find the MVSA-Single zip after download.")
        zip_path = os.path.join(data_root, zips[0])

    print(f"Unzipping {zip_path} ...")
    # zipfile overwrites existing members; the `unzip` binary (without -o) would stop at
    # an interactive "replace?" prompt if an earlier partial extraction left files behind.
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(data_root)

    if not os.path.isdir(cfg.mvsa_root):
        candidates = sorted(
            os.path.join(data_root, d) for d in os.listdir(data_root)
            if os.path.isdir(os.path.join(data_root, d)) and "MVSA_Single" in d
        )
        if candidates:
            os.rename(candidates[0], cfg.mvsa_root)

    if not _mvsa_layout_ok(cfg.mvsa_root):
        raise RuntimeError(
            "MVSA_Single structure is not as expected. "
            f"Check the contents of {data_root}."
        )

    print(f"MVSA-Single ready at: {cfg.mvsa_root}")


# ========================= DATASET =========================

class MVSASingleDataset(Dataset):
    """
    MVSA-Single (Kaggle) → binary classification:
      0 = negative, 1 = positive
    Neutral / ambiguous examples are removed.

    Reads ``<mvsa_root>/labelResultAll.txt`` (tab-separated; the first column is the
    sample id, the last column is the annotation). It keeps only ids for which both
    ``<mvsa_root>/data/<id>.jpg`` and ``<mvsa_root>/data/<id>.txt`` exist.

    Each item is a dict with ``pixel_values`` (transformed image), ``input_ids`` and
    ``attention_mask`` (padded/truncated to ``max_text_len``), ``labels`` (float32
    tensor of shape [1]), plus the raw ``text`` and ``image_id``.
    """

    @staticmethod
    def _derive_sentiment(annotation):
        """Collapse a raw annotation string (e.g. "positive,neutral") into one label.

        Returns "positive" or "negative" when exactly one polarity is mentioned
        (with or without "neutral"), "neutral" when only "neutral" is mentioned, and
        None when both polarities appear or no known label appears.
        """
        s_low = str(annotation).lower()
        has_pos = "positive" in s_low
        has_neg = "negative" in s_low
        if has_pos and has_neg:
            return None  # contradictory annotators → ambiguous
        if has_pos:
            return "positive"
        if has_neg:
            return "negative"
        if "neutral" in s_low:
            return "neutral"
        return None

    def _paths(self, img_id):
        """Return the (image, text) file paths for sample *img_id*."""
        return (
            os.path.join(self.data_dir, f"{img_id}.jpg"),
            os.path.join(self.data_dir, f"{img_id}.txt"),
        )

    def __init__(self, mvsa_root, tokenizer, transform=None,
                 max_text_len=64, num_samples=3000):
        super().__init__()
        self.mvsa_root = mvsa_root
        self.data_dir = os.path.join(mvsa_root, "data")
        self.tokenizer = tokenizer
        self.max_text_len = max_text_len

        if transform is None:
            transform = transforms.Compose([
                transforms.Resize((224, 224)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225]),
            ])
        self.transform = transform

        label_path = os.path.join(mvsa_root, "labelResultAll.txt")
        print(f"Loading labels from {label_path} ...")
        df = pd.read_csv(label_path, sep="\t", header=0)
        print(f"labelResultAll.txt: {df.shape[0]} rows, {df.shape[1]} columns.")

        if df.shape[1] < 2:
            raise ValueError(
                f"{label_path} must have at least 2 tab-separated columns "
                f"(id, annotation); found {df.shape[1]}."
            )
        # Keep the id (first) and annotation (last) columns. .copy() detaches the slice
        # so the column assignments below do not trigger SettingWithCopyWarning.
        df = df.iloc[:, [0, -1]].copy()
        df.columns = ["image_id", "annotation"]

        df["image_id"] = df["image_id"].astype(str).str.strip()
        ann = df["annotation"].astype(str).str.strip().str.lower()

        if set(ann.unique()).issubset({"positive", "negative", "neutral"}):
            df["sentiment"] = ann
        else:
            print("Annotations are not direct labels; deriving sentiment (positive/negative/neutral).")
            df["sentiment"] = ann.apply(self._derive_sentiment)

        invalid = df["sentiment"].isna().sum()
        if invalid > 0:
            print(f"Dropping {invalid} rows with ambiguous annotations.")
            df = df.dropna(subset=["sentiment"]).reset_index(drop=True)

        mask = df["sentiment"].isin(["positive", "negative"])
        dropped_neu = (~mask).sum()
        if dropped_neu > 0:
            print(f"Dropping {dropped_neu} 'neutral' examples.")
        df = df[mask].reset_index(drop=True)

        df["label"] = (df["sentiment"] == "positive").astype(np.float32)

        # A plain list comprehension (instead of df.apply(axis=1)) also works on an empty frame.
        present = np.array(
            [all(os.path.isfile(p) for p in self._paths(i)) for i in df["image_id"]],
            dtype=bool,
        )
        missing = int((~present).sum())
        if missing > 0:
            print(f"Dropping {missing} rows missing .jpg or .txt.")
            df = df[present].reset_index(drop=True)

        print("Label distribution (before subsampling):")
        print(
            df["label"].value_counts().rename({
                0.0: "negative (0)",
                1.0: "positive (1)"
            })
        )

        if num_samples is not None and num_samples < len(df):
            df = df.sample(num_samples, random_state=42).reset_index(drop=True)
            print(f"Randomly taking {num_samples} pairs.")
        else:
            print(f"Using all {len(df)} available pairs.")

        self.df = df
        print(f"Final dataset: {len(self.df)} binary pairs.")

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_id = row["image_id"]
        label = row["label"]
        img_path, txt_path = self._paths(img_id)

        # Context manager releases the file handle as soon as the pixels are loaded.
        with Image.open(img_path) as img:
            image = self.transform(img.convert("RGB"))

        with open(txt_path, "r", encoding="utf-8", errors="ignore") as f:
            text = f.read().strip()

        encoded = self.tokenizer(
            text,
            padding="max_length",
            truncation=True,
            max_length=self.max_text_len,
            return_tensors="pt",
        )

        return {
            "pixel_values": image,
            "input_ids": encoded["input_ids"].squeeze(0),
            "attention_mask": encoded["attention_mask"].squeeze(0),
            "labels": torch.tensor([label], dtype=torch.float32),
            "text": text,
            "image_id": img_id,
        }


# ========================= RIGID GIVENS TRANSFORM =========================

class RigidGivensTransform(nn.Module):
    """
    Rigid transform in R^dim:
      x' = R x + t
    where R is a product of a small number of 2D Givens rotations
    (planes (i_k, j_k) with learnable angles) and t is a learnable translation.

    The planes are drawn from a single random permutation (seeded by ``seed``), so no
    coordinate belongs to more than one plane. The rotations therefore commute and are
    applied all at once. Angles and translation start at zero, so the module starts as
    the identity map.
    """
    def __init__(self, dim: int = 128, num_givens: int = 16, seed: int = 0):
        super().__init__()
        self.dim = dim

        # Disjoint planes need two distinct coordinates each → at most dim // 2 rotations.
        max_pairs = dim // 2
        self.num_givens = min(num_givens, max_pairs)

        g = torch.Generator()
        g.manual_seed(seed)
        perm = torch.randperm(dim, generator=g)

        idx_i = perm[0:2 * self.num_givens:2]
        idx_j = perm[1:2 * self.num_givens:2]

        self.register_buffer("idx_i", idx_i)
        self.register_buffer("idx_j", idx_j)

        self.theta = nn.Parameter(torch.zeros(self.num_givens))
        self.translation = nn.Parameter(torch.zeros(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: [B, dim]
        Rotate each plane (i_k, j_k) by theta_k, then add the translation.

        The input is not modified: rotations are written into a clone. Writing into
        ``x`` directly would clobber the caller's tensor (e.g. an encoder's hidden
        state) and fail under autograd when ``x`` is a leaf that requires grad.
        """
        c = torch.cos(self.theta)  # [K]
        s = torch.sin(self.theta)  # [K]

        xi = x[:, self.idx_i]  # [B, K] (advanced indexing returns a copy)
        xj = x[:, self.idx_j]

        y = x.clone()
        y[:, self.idx_i] = c * xi - s * xj
        y[:, self.idx_j] = s * xi + c * xj

        return y + self.translation


# ========================= MODEL =========================

class MultimodalNet(nn.Module):
    """Image and text towers mapped to 128-d, followed by a projection and classifier shared by both.

    Image: ImageNet-pretrained ResNet-18 (final fc removed) → Linear(512, 128).
    Text: Hugging Face encoder, [CLS] hidden state → Identity (if hidden size is 128)
    or Linear(hidden, 128), optionally followed by a RigidGivensTransform.
    ``forward`` returns projections and one logit per modality.
    """
    def __init__(self,
                 text_model_name: str = "prajjwal1/bert-tiny",
                 use_affine_text: bool = False,
                 num_givens: int = 16):
        super().__init__()
        self.use_affine_text = use_affine_text

        # Image tower: every ResNet-18 child except the final fc → 512-d pooled features.
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        trunk = list(resnet.children())[:-1]
        self.cnn = nn.Sequential(*trunk)
        self.img_fc = nn.Linear(512, 128)

        # Text tower: skip the adapter when the encoder is already 128-d.
        self.text_model = AutoModel.from_pretrained(text_model_name)
        text_dim = self.text_model.config.hidden_size
        self.text_head = nn.Identity() if text_dim == 128 else nn.Linear(text_dim, 128)

        if use_affine_text:
            self.affine_text = RigidGivensTransform(dim=128, num_givens=num_givens)

        # Shared by both modalities, so image and text land in the same space.
        self.proj = nn.Sequential(nn.Linear(128, 128), nn.ReLU())
        self.classifier = nn.Linear(128, 1)

    def encode_image(self, pixel_values):
        """Map images to 128-d embeddings."""
        pooled = self.cnn(pixel_values)
        return self.img_fc(pooled.flatten(1))

    def encode_text(self, input_ids, attention_mask):
        """Map token ids to 128-d embeddings using the first ([CLS]) position."""
        hidden = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        return self.text_head(hidden[:, 0, :])

    def forward(self, pixel_values, input_ids, attention_mask):
        image_vec = self.encode_image(pixel_values)
        text_vec = self.encode_text(input_ids, attention_mask)

        if self.use_affine_text:
            text_vec = self.affine_text(text_vec)

        projected = [self.proj(v) for v in (image_vec, text_vec)]
        logits = [self.classifier(p) for p in projected]

        return {
            "img_proj": projected[0],
            "txt_proj": projected[1],
            "img_logits": logits[0],
            "txt_logits": logits[1],
        }


# ========================= LOSSES =========================

# Shared binary cross-entropy on raw logits (mean over the batch) for both classifier heads.
bce = nn.BCEWithLogitsLoss(reduction="mean")

def ntxent_loss(z_i, z_t, temperature=0.07):
    """Symmetric cross-modal contrastive (NT-Xent) loss.

    Rows of *z_i* and *z_t* are L2-normalized. Their cosine similarities, divided by
    *temperature*, are scored with cross-entropy in both directions (image→text and
    text→image). The matching row index is the positive. Returns the mean of the two
    directions.
    """
    img_unit = F.normalize(z_i, dim=1)
    txt_unit = F.normalize(z_t, dim=1)
    sim = torch.matmul(img_unit, txt_unit.t()) / temperature

    # Pair k is the positive for row k; every other column is an in-batch negative.
    positives = torch.arange(sim.shape[0], device=sim.device)
    img_to_txt = F.cross_entropy(sim, positives)
    txt_to_img = F.cross_entropy(sim.t(), positives)
    return (img_to_txt + txt_to_img) / 2.0


def multimodal_loss(outputs, labels, cfg: Config):
    """Weighted sum of classification and image/text alignment losses.

    Terms:
      - cls: BCE of the image head plus BCE of the text head against *labels*;
      - ntx: ``ntxent_loss`` between the two projections;
      - l1:  mean absolute difference between the projections;
      - cos: 1 - mean cosine similarity between the projections.
    Each term is scaled by the matching ``cfg.lambda_*`` weight.

    Returns:
        (total loss tensor, dict of the total and each term as Python floats)
    """
    z_img = outputs["img_proj"]
    z_txt = outputs["txt_proj"]

    loss_cls = bce(outputs["img_logits"], labels) + bce(outputs["txt_logits"], labels)
    loss_ntx = ntxent_loss(z_img, z_txt, temperature=cfg.temperature)
    loss_l1 = (z_img - z_txt).abs().mean()
    loss_cos = 1.0 - F.cosine_similarity(z_img, z_txt, dim=1).mean()

    total = (cfg.lambda_cls * loss_cls
             + cfg.lambda_ntxent * loss_ntx
             + cfg.lambda_l1 * loss_l1
             + cfg.lambda_cos * loss_cos)

    parts = {
        "loss_total": total,
        "loss_cls": loss_cls,
        "loss_ntx": loss_ntx,
        "loss_l1": loss_l1,
        "loss_cos": loss_cos,
    }
    return total, {name: value.item() for name, value in parts.items()}


# ========================= TRAIN / EVAL =========================

def freeze_backbones(model: MultimodalNet):
    """Turn off gradients for every parameter of ``model.cnn`` and ``model.text_model``.

    The heads (img_fc, text_head, proj, classifier, optional affine_text) stay trainable.
    """
    for backbone in (model.cnn, model.text_model):
        backbone.requires_grad_(False)


def get_optimizer(model: nn.Module, cfg: Config):
    """Return an Adam optimizer over only the parameters that currently require grad."""
    trainable = list(filter(lambda p: p.requires_grad, model.parameters()))
    return torch.optim.Adam(trainable, lr=cfg.lr, weight_decay=cfg.weight_decay)


def train_one_epoch(model, dataloader, optimizer, cfg: Config, epoch, model_name=""):
    """Train *model* for one pass over *dataloader*.

    Logs the running mean loss every 40 steps.

    Returns:
        Mean per-batch training loss for the epoch (0.0 for an empty loader).
    """
    model.train()
    if cfg.freeze_backbones:
        # model.train() also flips the frozen backbones into training mode. There, the
        # ResNet BatchNorm layers keep overwriting their running statistics from tiny
        # batches and text-encoder dropout perturbs features, although the weights never
        # change. Keep frozen backbones in eval mode so their features stay fixed.
        for attr in ("cnn", "text_model"):
            backbone = getattr(model, attr, None)
            if backbone is not None:
                backbone.eval()

    total_loss = 0.0
    num_batches = len(dataloader)
    print(f"[{model_name}] Epoch {epoch+1}/{cfg.num_epochs}")

    for step, batch in enumerate(dataloader, start=1):
        pixel_values = batch["pixel_values"].to(cfg.device)
        input_ids = batch["input_ids"].to(cfg.device)
        attention_mask = batch["attention_mask"].to(cfg.device)
        labels = batch["labels"].to(cfg.device)

        optimizer.zero_grad()
        outputs = model(pixel_values, input_ids, attention_mask)
        loss, _ = multimodal_loss(outputs, labels, cfg)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if step % 40 == 0:
            pct = 100 * step / num_batches
            print(
                f"[{model_name}] Step {step}/{num_batches} "
                f"({pct:.1f}%) | Loss: {total_loss/step:.4f}"
            )

    # max(..., 1) avoids ZeroDivisionError when the loader yields no batches.
    avg = total_loss / max(num_batches, 1)
    print(f"[{model_name}] Epoch {epoch+1} finished — mean loss: {avg:.4f}")
    return avg


@torch.no_grad()
def evaluate(model, dataloader, cfg: Config):
    """Compute validation F1 for the image head, the text head, and their average.

    Each head's logits go through a sigmoid and are thresholded at 0.5 (strictly greater).

    Returns:
        dict with keys "f1_img", "f1_txt" and "f1_mean".
    """
    model.eval()
    head_outputs = {"f1_img": "img_logits", "f1_txt": "txt_logits"}
    label_chunks = []
    logit_chunks = {metric: [] for metric in head_outputs}

    for batch in dataloader:
        pixel_values, input_ids, attention_mask, labels = (
            batch[name].to(cfg.device)
            for name in ("pixel_values", "input_ids", "attention_mask", "labels")
        )
        outputs = model(pixel_values, input_ids, attention_mask)
        label_chunks.append(labels.cpu().numpy())
        for metric, out_key in head_outputs.items():
            logit_chunks[metric].append(outputs[out_key].cpu().numpy())

    y_true = np.concatenate(label_chunks, axis=0).flatten()
    scores = {}
    for metric, chunks in logit_chunks.items():
        probs = torch.sigmoid(torch.tensor(np.concatenate(chunks))).numpy().flatten()
        scores[metric] = f1_score(y_true, (probs > 0.5).astype(int))
    scores["f1_mean"] = 0.5 * (scores["f1_img"] + scores["f1_txt"])
    return scores


# ========================= MAIN EXPERIMENT =========================

def _make_loader(dataset, indices, cfg: Config, shuffle: bool):
    """Build a DataLoader over ``Subset(dataset, indices)`` with the experiment's batch settings."""
    return DataLoader(
        Subset(dataset, indices),
        batch_size=cfg.batch_size,
        shuffle=shuffle,
        num_workers=2,
        # Pinned host memory only helps when batches are copied to a CUDA device.
        pin_memory=str(cfg.device).startswith("cuda"),
    )


def _train_and_evaluate(cfg: Config, train_loader, val_loader, fold: int,
                        use_affine_text: bool, model_name: str, report_name: str):
    """Train one MultimodalNet variant on a fold; return (per-epoch losses, val metrics)."""
    # Re-seed identically before each variant. Both models of a fold are then built
    # from the same global RNG state: identical init of the shared layers, since
    # RigidGivensTransform uses its own generator and zero-initialized parameters.
    # They also see the same batch order, which is what the paired t-test assumes.
    set_seed(cfg.seed + fold)
    model = MultimodalNet(
        text_model_name=cfg.text_model_name,
        use_affine_text=use_affine_text,
        num_givens=cfg.num_givens,
    ).to(cfg.device)
    if cfg.freeze_backbones:
        freeze_backbones(model)
    optimizer = get_optimizer(model, cfg)

    loss_history = []
    for epoch in range(cfg.num_epochs):
        loss_history.append(
            train_one_epoch(model, train_loader, optimizer, cfg, epoch, model_name)
        )
    metrics = evaluate(model, val_loader, cfg)
    print(
        f"{report_name} FOLD {fold} — "
        f"F1_img={metrics['f1_img']:.4f}, "
        f"F1_txt={metrics['f1_txt']:.4f}, "
        f"F1_mean={metrics['f1_mean']:.4f}"
    )
    return loss_history, metrics


def _summarize(cfg: Config, baseline_scores, affine_scores):
    """Print the CV summary and paired t-test, and save score arrays plus a text report to ./results/."""
    diffs = affine_scores - baseline_scores

    print("="*80)
    print("CROSS-VALIDATION SUMMARY")
    print("="*80)
    for i, (b, a) in enumerate(zip(baseline_scores, affine_scores), start=1):
        print(f"Fold {i}:  Baseline F1={b:.4f} | Rigid-Givens F1={a:.4f} | Δ={a-b:.4f}")

    mean_base, std_base = baseline_scores.mean(), baseline_scores.std(ddof=1)
    mean_aff, std_aff = affine_scores.mean(), affine_scores.std(ddof=1)

    print("\nMean ± std F1 over folds:")
    print(f"  Baseline     : {mean_base:.4f} ± {std_base:.4f}")
    print(f"  Rigid-Givens : {mean_aff:.4f} ± {std_aff:.4f}")
    print(f"  Improvement (Rigid-Givens - Baseline): {diffs.mean():.4f}")

    # NOTE: with very few folds (default 2) this test has almost no power; if every
    # per-fold difference is identical, t and p come out as nan.
    t_stat, p_value = ttest_rel(affine_scores, baseline_scores)
    print("\nPaired t-test (Rigid-Givens vs Baseline, F1_mean)")
    print(f"  t = {t_stat:.4f},  p = {p_value:.6f}")
    if p_value < 0.05:
        print("   Difference is statistically significant at α = 0.05.")
    else:
        print("   Difference is NOT statistically significant at α = 0.05.")

    np.save("results/baseline_f1.npy", baseline_scores)
    np.save("results/rigid_givens_f1.npy", affine_scores)

    summary_path = os.path.join("results", "summary_mvsa_rigid_givens.txt")
    with open(summary_path, "w") as f:
        f.write("MVSA-Single MULTIMODAL EXPERIMENT (binary pos/neg)\n")
        f.write("Baseline vs Rigid-Givens (few Givens rotations + translation)\n")
        f.write("===============================================================\n\n")
        for k, v in cfg.__dict__.items():
            f.write(f"{k}: {v}\n")
        f.write("\nF1 per fold (mean of img+txt):\n")
        for i, (b, a) in enumerate(zip(baseline_scores, affine_scores), start=1):
            f.write(f"  Fold {i}: baseline={b:.4f}, rigid-givens={a:.4f}, delta={a-b:.4f}\n")
        f.write(f"\nBaseline mean ± std: {mean_base:.4f} ± {std_base:.4f}\n")
        f.write(f"Rigid-Givens mean ± std: {mean_aff:.4f} ± {std_aff:.4f}\n")
        f.write(f"Delta (rigid-givens - baseline): {diffs.mean():.4f}\n")
        f.write(f"Paired t-test: t = {t_stat:.4f}, p = {p_value:.6f}\n")

    print(f"\nSummary saved to {summary_path}")


def _save_plots(baseline_scores, affine_scores, all_history):
    """Save the per-fold F1 bar chart, the F1 boxplot and per-fold loss curves to ./results/."""
    folds = np.arange(1, len(baseline_scores) + 1)

    plt.figure(figsize=(6, 4))
    bar_width = 0.35
    plt.bar(folds - bar_width/2, baseline_scores, bar_width, label="Baseline")
    plt.bar(folds + bar_width/2, affine_scores, bar_width, label="Rigid-Givens")
    plt.xticks(folds)
    plt.xlabel("Fold")
    plt.ylabel("F1 (mean img+txt)")
    plt.title("F1 per fold: Baseline vs Rigid-Givens (MVSA-Single)")
    plt.legend()
    plt.grid(axis="y", alpha=0.3)
    plt.savefig("results/f1_per_fold.png", dpi=150, bbox_inches="tight")
    plt.close()

    plt.figure(figsize=(5, 4))
    # Tick labels are set via xticks: boxplot's `labels=` kwarg is deprecated since
    # Matplotlib 3.9 (renamed `tick_labels`); xticks works on every version.
    plt.boxplot([baseline_scores, affine_scores])
    plt.xticks([1, 2], ["Baseline", "Rigid-Givens"])
    plt.ylabel("F1 (mean img+txt)")
    plt.title("F1 distribution across folds (MVSA-Single)")
    plt.grid(axis="y", alpha=0.3)
    plt.savefig("results/f1_boxplot.png", dpi=150, bbox_inches="tight")
    plt.close()

    for hist in all_history:
        fd = hist["fold"]
        plt.figure(figsize=(6, 4))
        plt.plot(hist["loss_base"], label="Baseline")
        plt.plot(hist["loss_aff"], label="Rigid-Givens")
        plt.xlabel("Epoch")
        plt.ylabel("Training loss")
        plt.title(f"Training loss — Fold {fd}")
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.savefig(f"results/loss_fold{fd}.png", dpi=150, bbox_inches="tight")
        plt.close()


def run_experiment(cfg: Config):
    """Run K-fold CV comparing the baseline model with the Rigid-Givens text-transform model.

    Downloads the data if needed. For each fold, it trains and evaluates both variants
    on the same split. It then prints and saves a summary with a paired t-test on the
    per-fold mean F1, and writes plots to ./results/.
    """
    set_seed(cfg.seed)
    os.makedirs("results", exist_ok=True)

    download_mvsa_single_if_needed(cfg)

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(cfg.text_model_name)

    transform = transforms.Compose([
        transforms.Resize((cfg.img_size, cfg.img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    print("Loading MVSA-Single dataset...")
    dataset = MVSASingleDataset(
        mvsa_root=cfg.mvsa_root,
        tokenizer=tokenizer,
        transform=transform,
        max_text_len=cfg.max_text_len,
        num_samples=cfg.num_samples,
    )
    print(f"Final dataset: {len(dataset)} samples (binary: 0=neg, 1=pos).")

    indices = np.arange(len(dataset))
    kf = KFold(n_splits=cfg.num_folds, shuffle=True, random_state=cfg.seed)

    baseline_scores = []
    affine_scores = []
    all_history = []

    for fold, (train_idx, val_idx) in enumerate(kf.split(indices), start=1):
        print("="*80)
        print(f"FOLD {fold}/{cfg.num_folds}")
        print("="*80)

        train_loader = _make_loader(dataset, train_idx, cfg, shuffle=True)
        val_loader = _make_loader(dataset, val_idx, cfg, shuffle=False)

        # ---------- BASELINE ----------
        print("Training BASELINE model (no rigid transform)...")
        loss_history_base, metrics_base = _train_and_evaluate(
            cfg, train_loader, val_loader, fold,
            use_affine_text=False, model_name="Baseline", report_name="BASELINE",
        )
        baseline_scores.append(metrics_base["f1_mean"])

        # ---------- RIGID-GIVENS ----------
        print("Training RIGID-GIVENS model (learnable text rotation+translation)...")
        loss_history_aff, metrics_aff = _train_and_evaluate(
            cfg, train_loader, val_loader, fold,
            use_affine_text=True, model_name="Rigid-Givens", report_name="RIGID-GIVENS",
        )
        affine_scores.append(metrics_aff["f1_mean"])

        all_history.append({
            "fold": fold,
            "loss_base": loss_history_base,
            "loss_aff": loss_history_aff,
            "f1_base": metrics_base,
            "f1_aff": metrics_aff
        })

    baseline_scores = np.array(baseline_scores)
    affine_scores = np.array(affine_scores)

    _summarize(cfg, baseline_scores, affine_scores)
    _save_plots(baseline_scores, affine_scores, all_history)

    print("\nExperiment finished. Metrics, plots and arrays are saved in ./results/")


# ========================= ENTRY POINT =========================

# Guarded so that importing this code (e.g. after exporting the notebook to a .py module)
# does not start a download and a full training run. In a notebook cell __name__ is
# "__main__", so the experiment still runs as before.
if __name__ == "__main__":
    cfg = Config()
    print("CONFIG:")
    for k, v in cfg.__dict__.items():
        print(f"  {k}: {v}")

    run_experiment(cfg)
