# Kaggle Training Entry Point

This notebook prepares the workspace, installs dependencies, rewrites the config with Kaggle paths, and launches training. Before running it, attach the Kaggle datasets that host:

- the repository snapshot (read-only under `/kaggle/input/...`)
- pretrained encoders (`artifacts/audio-encoder`, `artifacts/roberta-text-encoder`)
- IEMOCAP raw data (`IEMOCAP_full_release`) and `iemocap_manifest.jsonl`

Update the path constants in the workspace setup cell (`ARTIFACTS_DATASET`, `IEMOCAP_ROOT`, `MANIFEST_PATH`) if your dataset names differ.
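
To find the right values for the dataset paths, it can help to list what is actually mounted. A minimal sketch, assuming attached datasets appear as directories under `/kaggle/input`; the helper `list_input_dirs` and its `max_depth` parameter are illustrative and not part of the repository:

```python
from pathlib import Path


def list_input_dirs(root: str = "/kaggle/input", max_depth: int = 2) -> list:
    """Return directories under `root`, at most `max_depth` levels deep, as relative paths."""
    base = Path(root)
    found = []

    def walk(directory: Path, depth: int) -> None:
        # Stop descending once the depth limit is exceeded.
        if depth > max_depth:
            return
        for child in sorted(directory.iterdir()):
            if child.is_dir():
                found.append(child.relative_to(base).as_posix())
                walk(child, depth + 1)

    if base.is_dir():
        walk(base, 1)
    return found


# Prints nothing when the root does not exist (for example, outside Kaggle).
for entry in list_input_dirs():
    print(entry)
```

The depth limit keeps the listing fast on large datasets such as IEMOCAP, since the walk never descends into the audio folders.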

In [None]:
!rm -rf /kaggle/working/ser-conformer-gat-xai
!git clone https://github.com/SpeedyLabX/ser-conformer-gat-xai.git \
    /kaggle/working/ser-conformer-gat-xai

In [None]:
from pathlib import Path
import shutil

WORK_DIR = Path("/kaggle/working/ser-conformer-gat-xai")
ARTIFACTS_DATASET = Path("/kaggle/input/text-audio-encoders/pytorch/default/1/artifacts")
IEMOCAP_ROOT = Path("/kaggle/input/iemocapfullrelease/IEMOCAP_full_release")
MANIFEST_PATH = Path("/kaggle/input/iemocap-manifest-jsonl/iemocap_manifest.jsonl")

assert WORK_DIR.exists(), "Repository clone missing"
assert ARTIFACTS_DATASET.exists(), "Encoder dataset path incorrect"
assert IEMOCAP_ROOT.exists(), "IEMOCAP dataset path incorrect"
assert MANIFEST_PATH.exists(), "Manifest dataset path incorrect"

# Mirror artifacts into the writable repo tree
shutil.copytree(ARTIFACTS_DATASET, WORK_DIR / "artifacts", dirs_exist_ok=True)
print("Workspace ready at", WORK_DIR)

In [None]:
!pip install --quiet transformers soundfile scikit-learn pyyaml tqdm matplotlib networkx

In [None]:
%cd /kaggle/working/ser-conformer-gat-xai
!pip install -q -e .

## Adjusting Training Hyperparameters

All configuration is driven by YAML. By default, the notebook loads `configs/iemocap.yaml`, which itself extends `configs/base.yaml`. Only the keys you override here will change; everything else falls back to `base.yaml`.

Parameters you can override and what they control:

- `trainer`
  - `epochs`: maximum epochs.
  - `batch_size`: mini-batch size.
  - `lr`: AdamW learning rate.
  - `weight_decay`: AdamW weight decay coefficient.
  - `grad_clip`: gradient clipping threshold (`null` disables it).
  - `patience`: early-stopping patience in epochs.
  - `use_tqdm`: enable per-batch progress bars for train/eval.
- `model.audio`
  - `checkpoint`: path to the pretrained audio encoder bundle.
  - `freeze`: keep the audio encoder frozen (`True`) or fine-tune (`False`).
- `model.text`
  - `checkpoint`: path to the text encoder weights.
  - `proj_dim`: projection size fed into fusion.
  - `freeze`: freeze text backbone or fine-tune it.
- `model.fusion`
  - `gat_heads`: number of GAT attention heads.
  - `gat_layers`: number of stacked GAT layers.
  - `gat_hidden`: hidden dim per head (after concatenating heads this is the fusion output dim).
- `model.loss`
  - `type`: `focal` or `cross_entropy`.
  - `gamma`: focal loss gamma.
  - `class_weights`: class weighting list (or `null`).
- `data`
  - `split`: random split ratios (train/val/test).
  - `session_split`: session-based split if you want fixed test sessions.
  - `max_text_len`: tokenizer max length.
  - `max_audio_frames`: cap on mel frames.
  - `num_workers`: DataLoader workers (on Kaggle we keep 0).
- Misc
  - `metrics`: which metrics to log (`wa`, `ua`, `f1_macro`, `cm`, ...).
  - `artifacts_dir`, `log_dir`: where checkpoints and histories are written.
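
For orientation on `model.loss`: the standard focal loss scales each sample's cross-entropy term by `(1 - p_t) ** gamma`, where `p_t` is the predicted probability of the true class. With `gamma = 0` it reduces to plain cross-entropy, and larger values down-weight samples that are already classified well. A minimal pure-Python sketch of that formula for a single sample follows; `focal_loss` here is illustrative and is not the repository's implementation, and the logits are made up:

```python
import math


def focal_loss(logits: list, target: int, gamma: float = 2.0) -> float:
    """Focal loss for one sample: -(1 - p_t)**gamma * log(p_t)."""
    # Numerically stable softmax probability of the target class.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    p_t = exps[target] / sum(exps)
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


logits = [2.0, 0.5, -1.0, 0.0]  # made-up scores for four classes
print("cross-entropy (gamma=0):", focal_loss(logits, target=0, gamma=0.0))
print("focal (gamma=2):", focal_loss(logits, target=0, gamma=2.0))
```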

Example overrides (copy the values you need into `hyperparam_overrides` in the config cell):

```python
hyperparam_overrides = {
    "trainer": {
        "epochs": 80,
        "batch_size": 8,
        "lr": 2e-4,
        "use_tqdm": True,
    },
    "model": {
        "fusion": {"gat_heads": 6, "gat_hidden": 384},
        "audio": {"freeze": False},
        "text": {"proj_dim": 256, "freeze": False},
        "loss": {"type": "cross_entropy"},
    },
    "data": {"split": {"train": 0.8, "val": 0.1, "test": 0.1}},
}
```

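Apart from the Kaggle-specific settings that the config cell always rewrites (`data.root`, `data.manifest`, `data.num_workers`, the two encoder `checkpoint` paths, and `artifacts_dir`) and a fallback `trainer.batch_size` of 8 when none is set, only the keys listed in `hyperparam_overrides` are modified. Overrides are applied last, so they also take precedence over those Kaggle settings. Everything else stays as defined in `configs/iemocap.yaml` (and, via the `extends` mechanism, in `configs/base.yaml`).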
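
The merge behaviour can be checked in isolation. The sketch below repeats the `deep_update` helper from the config cell and applies it to made-up values: nested dicts are merged key by key, while non-dict values such as lists are replaced wholesale.

```python
def deep_update(base: dict, overrides: dict) -> dict:
    out = dict(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_update(out[key], value)
        else:
            out[key] = value
    return out


# Illustrative values only; they are not the repository's defaults.
base = {
    "trainer": {"epochs": 50, "lr": 1e-4},
    "model": {"loss": {"type": "focal", "class_weights": [1.0, 2.0]}},
}
overrides = {
    "trainer": {"epochs": 80},
    "model": {"loss": {"class_weights": [1.0, 1.0, 1.0]}},
}
merged = deep_update(base, overrides)

print(merged["trainer"])        # epochs overridden, lr kept
print(merged["model"]["loss"])  # type kept, class_weights list replaced
print(base["trainer"])          # the input dict is not mutated
```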

In [None]:
import yaml
from pathlib import Path

# ------------------------------------------------------------------
# Optional: tweak hyperparameters before writing the Kaggle config.
# Uncomment and edit the block below to override defaults inherited
# from configs/base.yaml. Only keys you set here will override the
# base values. The rest remain unchanged.
# ------------------------------------------------------------------
hyperparam_overrides = {
    "trainer": {
        "use_tqdm": True,
        # "epochs": 80,
        # "batch_size": 8,
        # "lr": 2e-4,
        # "weight_decay": 1e-5,
        # "patience": 7,
    },
    # "model": {
    #     "fusion": {
    #         "gat_hidden": 384,
    #         "gat_heads": 6,
    #         "gat_layers": 2,
    #     },
    #     "audio": {"freeze": False},
    #     "text": {"proj_dim": 256, "freeze": False},
    #     "loss": {"type": "cross_entropy"},
    # },
    # "data": {
    #     "split": {"train": 0.8, "val": 0.1, "test": 0.1},
    #     "max_text_len": 160,
    #     "max_audio_frames": 600,
    #     "num_workers": 2,
    # },
}

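def deep_update(base: dict, overrides: dict) -> dict:
    """Recursively merge `overrides` into a copy of `base`; override values win."""
    out = dict(base)
    for key, value in overrides.items():
        # Merge nested dicts key by key; replace any other value wholesale.
        if isinstance(value, dict) and isinstance(out.get(key), dict):
            out[key] = deep_update(out[key], value)
        else:
            out[key] = value
    return out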

WORK_DIR = Path("/kaggle/working/ser-conformer-gat-xai")
cfg_path = WORK_DIR / "configs" / "iemocap.yaml"
cfg = yaml.safe_load(cfg_path.read_text())

cfg.setdefault("data", {})
cfg["data"]["root"] = str(IEMOCAP_ROOT)
cfg["data"]["manifest"] = str(MANIFEST_PATH)
cfg["data"]["num_workers"] = 0  # safer on Kaggle
cfg.setdefault("model", {})
cfg["model"].setdefault("audio", {})
cfg["model"]["audio"]["checkpoint"] = str(WORK_DIR / "artifacts" / "audio-encoder" / "conformer_encoder.pkl")
cfg["model"].setdefault("text", {})
cfg["model"]["text"]["checkpoint"] = str(WORK_DIR / "artifacts" / "roberta-text-encoder")
cfg.setdefault("trainer", {})
cfg["trainer"].setdefault("batch_size", 8)  # keep the configured value, else fall back to 8
cfg["artifacts_dir"] = str(WORK_DIR / "artifacts")
cfg = deep_update(cfg, hyperparam_overrides)

resolved_cfg = WORK_DIR / "configs" / "iemocap_kaggle.yaml"
resolved_cfg.write_text(yaml.safe_dump(cfg, sort_keys=False))
print("Resolved config written to", resolved_cfg)

In [None]:
import os
import subprocess
import sys

WORK_DIR = "/kaggle/working/ser-conformer-gat-xai"
os.chdir(WORK_DIR)
if "src" not in sys.path:
    sys.path.append("src")

# Run with the kernel's own interpreter so the subprocess uses the same environment.
result = subprocess.run([
    sys.executable,
    "-m",
    "src.cli.train",
    "--config",
    "configs/iemocap_kaggle.yaml",
    "--use_tqdm",
    "--dry-run",
], check=True)
print("Dry run return code:", result.returncode)

In [None]:
import os
import subprocess
import sys

WORK_DIR = "/kaggle/working/ser-conformer-gat-xai"
os.chdir(WORK_DIR)
if "src" not in sys.path:
    sys.path.append("src")

# Run with the kernel's own interpreter so the subprocess uses the same environment.
subprocess.run([
    sys.executable,
    "-m",
    "src.cli.train",
    "--config",
    "configs/iemocap_kaggle.yaml",
    "--use_tqdm",
], check=True)

## Evaluate the saved checkpoint

The cell below picks the last run directory in name-sorted order (the most recent run, provided run directory names sort chronologically), restores `best_model.pt`, and reports validation/test metrics. Set `RUN_DIR` manually to evaluate a specific run.
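
The evaluation cell sorts run directories by name, which only selects the newest run when the names sort chronologically. If that is not guaranteed, selecting by modification time is one alternative. A minimal sketch; `latest_run_dir` is a hypothetical helper, not part of the repository:

```python
from pathlib import Path
from typing import Optional


def latest_run_dir(run_root: Path) -> Optional[Path]:
    """Return the most recently modified subdirectory of `run_root`, or None if there is none."""
    run_dirs = [p for p in run_root.iterdir() if p.is_dir()]
    if not run_dirs:
        return None
    # A directory's mtime changes whenever entries are added to or removed from it.
    return max(run_dirs, key=lambda p: p.stat().st_mtime)
```

To use it, replace the `RUN_DIR = run_dirs[-1]` assignment with `RUN_DIR = latest_run_dir(run_root)`. Note that writing `evaluation_metrics.json` into an older run updates that directory's modification time, so it would then count as the latest.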

In [None]:
import json
import numpy as np
import torch
import yaml
from pathlib import Path
from sklearn.metrics import confusion_matrix
from transformers import AutoTokenizer

from serxai.data.datamodule import DataModule
from serxai.data.labels import LABELS
from serxai.models.multimodal import MultimodalSERModel
from serxai.utils import metrics as metrics_mod
from serxai.utils.seed import set_seed

WORK_DIR = Path("/kaggle/working/ser-conformer-gat-xai")
cfg_path = WORK_DIR / "configs" / "iemocap_kaggle.yaml"
cfg = yaml.safe_load(cfg_path.read_text())

run_root = (WORK_DIR / cfg.get("log_dir", "experiments/runs")).resolve()
run_dirs = sorted([p for p in run_root.iterdir() if p.is_dir()])
assert run_dirs, f"No run directories found under {run_root}"
RUN_DIR = run_dirs[-1]
print("Evaluating run:", RUN_DIR)

ckpt_path = RUN_DIR / "best_model.pt"
assert ckpt_path.exists(), "best_model.pt missing; finish training first"

seed = int(cfg.get("seed", 42))
set_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(cfg["model"]["text"]["checkpoint"], local_files_only=True)

dm = DataModule(
    manifest_path=str(cfg["data"]["manifest"]),
    tokenizer=tokenizer,
    dataset_root=cfg["data"].get("root"),
    batch_size=int(cfg["trainer"].get("batch_size", 8)),
    num_workers=0,
    split=cfg["data"].get("split"),
    seed=seed,
    session_split=cfg["data"].get("session_split"),
    max_text_length=int(cfg["data"].get("max_text_len", 128)),
    max_audio_frames=cfg["data"].get("max_audio_frames"),
)
dm.setup()
val_loader = dm.val_dataloader()
test_loader = dm.test_dataloader()

fusion_cfg = cfg["model"].get("fusion", {})
text_cfg = cfg["model"].get("text", {})
model = MultimodalSERModel(
    audio_checkpoint=cfg["model"]["audio"]["checkpoint"],
    text_backbone=cfg["model"]["text"]["checkpoint"],
    text_proj_dim=int(text_cfg.get("proj_dim", 128)),
    fusion_hidden=int(fusion_cfg.get("gat_hidden", 256)),
    fusion_heads=int(fusion_cfg.get("gat_heads", 4)),
    fusion_layers=int(fusion_cfg.get("gat_layers", 2)),
    num_classes=len(LABELS),
    freeze_audio=bool(cfg["model"]["audio"].get("freeze", True)),
    freeze_text=bool(text_cfg.get("freeze", True)),
).to(device)
state = torch.load(ckpt_path, map_location=device)
model.load_state_dict(state)
model.eval()

# The reported loss is plain cross-entropy, whatever loss type was used for training.
criterion = torch.nn.CrossEntropyLoss()

def run_eval(loader):
    preds, targets = [], []
    total_loss, total_samples = 0.0, 0
    with torch.no_grad():
        for batch in loader:
            labels = batch["labels"].to(device)
            batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}
            outputs = model(batch)
            logits = outputs["logits"]
            loss = criterion(logits, labels)
            total_loss += loss.item() * labels.size(0)
            total_samples += labels.size(0)
            preds.append(logits.argmax(dim=-1).cpu().numpy())
            targets.append(labels.cpu().numpy())
    preds_np = np.concatenate(preds)
    targets_np = np.concatenate(targets)
    metrics = {
        "loss": total_loss / max(1, total_samples),
        "wa": metrics_mod.wa(preds_np, targets_np),
        "ua": metrics_mod.ua(preds_np, targets_np),
        "f1_macro": metrics_mod.f1_macro(preds_np, targets_np),
        # Pass all class indices so the matrix keeps a fixed shape even if a class is absent.
        "confusion_matrix": confusion_matrix(
            targets_np, preds_np, labels=list(range(len(LABELS)))
        ).tolist(),
    }
    return metrics

val_metrics = run_eval(val_loader)
test_metrics = run_eval(test_loader)

print("Validation metrics:", json.dumps(val_metrics, indent=2))
print("Test metrics:", json.dumps(test_metrics, indent=2))

with (RUN_DIR / "evaluation_metrics.json").open("w") as fh:
    json.dump({"val": val_metrics, "test": test_metrics}, fh, indent=2)
print("Saved evaluation metrics to", RUN_DIR / "evaluation_metrics.json")