In [None]:
# !pip install facenet-pytorch==2.*
# !pip install torch torchvision torchaudio==2.* librosa==0.* opencv-python==4.* tqdm einops matplotlib
# !pip install --upgrade --force-reinstall numpy scikit-learn --quiet
# !pip install --force-reinstall "numpy<2.1" "numba>=0.59" --quiet

In [None]:

import os, random, csv
from pathlib import Path
from typing import List, Tuple, Dict, Any, Optional
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import warnings
warnings.filterwarnings("ignore")

try:
    import torchvision
except Exception as e:
    torchvision = None
    print("torchvision not available:", e)

try:
    import librosa
except Exception as e:
    librosa = None
    print("librosa not available:", e)

try:
    import torchaudio
except Exception as e:
    torchaudio = None
    print("torchaudio not available:", e)

try:
    import cv2
except Exception as e:
    cv2 = None
    print("opencv-python not available:", e)

try:
    from facenet_pytorch import MTCNN
except Exception as e:
    MTCNN = None
    print("facenet-pytorch not available:", e)

from dataclasses import dataclass, field
from tqdm import tqdm

# Pick the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Let cuDNN autotune convolution algorithms (pays off when input shapes stay constant).
torch.backends.cudnn.benchmark = True
print("Device:", DEVICE)

def set_seed(seed: int = 42):
    """Seed Python's `random`, NumPy's global RNG, and PyTorch (CPU and all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

def count_params(model: nn.Module) -> int:
    """Return the number of scalar parameters in `model` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def pad_sequence_1d(seqs: List[torch.Tensor], pad_value: float = 0.0) -> torch.Tensor:
    """Right-pad tensors along their first (time) dimension and stack them into one batch tensor.

    Args:
        seqs: non-empty list of tensors shaped (T_i, *trailing); all must share the same trailing
            dimensions. The output dtype is taken from ``seqs[0]``.
        pad_value: value written into the padded positions.

    Returns:
        Tensor of shape (len(seqs), max_i T_i, *trailing), allocated on the default device.

    Raises:
        ValueError: if ``seqs`` is empty.
    """
    if len(seqs) == 0:
        raise ValueError("pad_sequence_1d() requires at least one sequence")
    max_t = max(s.shape[0] for s in seqs)
    # Works for 1-D, 2-D and higher-rank inputs alike: only dim 0 varies between sequences.
    trailing = tuple(seqs[0].shape[1:])
    out = torch.full((len(seqs), max_t, *trailing), pad_value, dtype=seqs[0].dtype)
    for i, s in enumerate(seqs):
        out[i, : s.shape[0]] = s
    return out

def lengths_to_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.Tensor:
    """Boolean (B, max_len) mask that is True at positions < lengths[b].

    When ``max_len`` is None the width is the largest value in ``lengths``.
    """
    width = int(lengths.max().item()) if max_len is None else max_len
    positions = torch.arange(width, device=lengths.device)
    return positions[None, :] < lengths[:, None]


In [None]:

@dataclass
class Config:
    """Every tunable setting of the RAVDESS/MELD audio-visual emotion pipeline.

    Field order defines the positional constructor signature, so it is kept stable.
    """

    # ---- dataset locations ----
    # RAVDESS root; its sub-folders (Video_*/Audio_*) are searched for clips.
    ravdess_root: str = "/kaggle/input/ravdess-dataset"
    # MELD.Raw directory holding the *_sent_emo.csv files and the per-split video folders.
    meld_root: str = "/kaggle/input/meld-dataset/MELD-RAW/MELD.Raw"
    workdir: str = "./artifacts_ravdess_meld_splitdirs"

    # ---- audio front-end ----
    sample_rate: int = 16000           # Hz
    audio_duration_s: float = 2.45     # seconds kept per clip
    audio_trim_leading_s: float = 0.5  # seconds skipped at the start of each clip
    n_mfcc: int = 13
    n_fft: int = 512
    hop_length: int = 160              # samples between MFCC frames
    win_length: int = 400              # samples per analysis window

    # ---- video front-end ----
    frames_per_clip: int = 16
    image_size: int = 112
    face_crop: bool = True

    # ---- shared 7-class label space (str.split gives a fresh list per instance) ----
    classes7: List[str] = field(default_factory="angry disgust fear happy neutral sad surprise".split)

    # ---- optimisation ----
    seed: int = 42
    batch_size: int = 8
    num_workers: int = 4
    epochs: int = 25
    lr: float = 1e-3
    weight_decay: float = 1e-4
    grad_clip: float = 5.0

    # ---- architecture ----
    audio_cnn_channels: Tuple[int, ...] = (32, 64, 128)
    audio_lstm_hidden: int = 128
    video_backbone: str = "resnet18"
    video_lstm_hidden: int = 256
    transformer_d_model: int = 128
    transformer_nhead: int = 4
    transformer_num_layers: int = 2
    fusion_k: int = 64
    dropout: float = 0.3


In [None]:
def load_audio(path: str, cfg: Config) -> np.ndarray:
    """Load a mono waveform at ``cfg.sample_rate``, drop the leading trim, and cut/pad to a fixed length.

    Decoding order:
      1. librosa (imported locally because the module-level ``librosa`` may be None).
      2. torchaudio, letting the backend infer the container from the file itself, so the
         .mp4/.avi/.mov/.mkv clips used when no separate .wav exists are all attempted.
         The original hard-coded ``format="mp4"`` would mis-decode the non-mp4 containers.
      3. If both fail, a warning is printed and silence is returned.

    Args:
        path: audio or video file path.
        cfg: provides ``sample_rate``, ``audio_duration_s`` and ``audio_trim_leading_s``.

    Returns:
        float32 array of length ``int(cfg.audio_duration_s * cfg.sample_rate)``.
    """
    dur = int(cfg.audio_duration_s * cfg.sample_rate)
    start = int(cfg.audio_trim_leading_s * cfg.sample_rate)
    name = os.path.basename(path)

    try:
        import librosa
        y, _ = librosa.load(path, sr=cfg.sample_rate)
    except Exception as e_audio:
        print(f"[WARN] Audio read failed for {name} ({e_audio}). Trying video backend...")
        try:
            import torchaudio
            wav, sr = torchaudio.load(path)  # no format override: infer the container
            if sr != cfg.sample_rate:
                wav = torchaudio.functional.resample(wav, sr, cfg.sample_rate)
            y = wav.mean(dim=0).numpy()  # downmix channels to mono
        except Exception as e_video:
            print(f"[ERROR] Failed to extract audio from {name} ({e_video}). Returning silence.")
            return np.zeros(dur, dtype=np.float32)

    # Keep a fixed window [start, start + dur); shorter clips are zero-padded at the end.
    y = np.asarray(y, dtype=np.float32)[start:start + dur]
    if len(y) < dur:
        y = np.pad(y, (0, dur - len(y)))
    return y

def audio_to_mfcc(y: np.ndarray, cfg: Config) -> np.ndarray:
    """Compute librosa MFCCs for waveform ``y`` and return them time-major as float32.

    The MFCC/STFT settings (n_mfcc, n_fft, hop_length, win_length, sample_rate) come from ``cfg``.
    """
    stft_opts = dict(n_fft=cfg.n_fft, hop_length=cfg.hop_length, win_length=cfg.win_length)
    coeffs = librosa.feature.mfcc(y=y, sr=cfg.sample_rate, n_mfcc=cfg.n_mfcc, **stft_opts)
    # librosa yields (n_mfcc, frames); transpose so rows are time steps.
    return np.transpose(coeffs).astype(np.float32)


In [None]:

def get_mtcnn(image_size: int):
    """Build a single-face MTCNN detector on DEVICE, or return None when facenet-pytorch is missing.

    ``post_process=False`` makes the detector return the aligned face crop with raw 0-255 pixel
    values. The frame loader converts that crop with ``.byte()``. With the facenet-pytorch default
    ``post_process=True``, the crop is standardized to roughly [-1, 1], and ``.byte()`` truncates it
    to an all-zero (black) image.

    Args:
        image_size: side length of the square face crop returned by the detector.
    """
    if MTCNN is None:
        return None
    return MTCNN(image_size=image_size, margin=20, post_process=False, keep_all=False, device=str(DEVICE))

def center_crop_resize(img: np.ndarray, size: int) -> np.ndarray:
    """Resize so the shorter side equals ``size`` (aspect ratio kept), then center-crop ``size x size``.

    Args:
        img: image array (H, W[, C]).
        size: output side length.

    Returns:
        The ``size x size`` center crop of the resized image.
    """
    h, w = img.shape[:2]
    scale = size / min(h, w)
    # Round, and never go below `size`: truncating a float product such as 111.99999 with int()
    # would give a 111-pixel side and a non-square crop that later breaks np.stack on frames.
    new_w = max(size, int(round(w * scale)))
    new_h = max(size, int(round(h * scale)))
    img = cv2.resize(img, (new_w, new_h))  # cv2 dsize is (width, height)
    y0 = (new_h - size) // 2
    x0 = (new_w - size) // 2
    return img[y0:y0 + size, x0:x0 + size]

def _frame_to_image(frame_bgr: np.ndarray, cfg: Config, mtcnn=None) -> np.ndarray:
    """Convert one OpenCV BGR frame to an RGB uint8 image: MTCNN face crop if possible, else center crop."""
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    if cfg.face_crop and mtcnn is not None:
        try:
            face = mtcnn(frame_rgb, save_path=None)
            if face is not None:
                if getattr(mtcnn, "post_process", False):
                    # Undo facenet-pytorch's fixed standardization (x - 127.5) / 128 so .byte()
                    # sees pixel values instead of ~[-1, 1] (which would truncate to black).
                    face = face * 128.0 + 127.5
                return face.clamp(0, 255).permute(1, 2, 0).byte().cpu().numpy()
        except Exception:
            pass  # best-effort detector: fall back to a center crop below
    return center_crop_resize(frame_rgb, cfg.image_size)


def load_video_frames_cv2(path: str, cfg: Config, mtcnn=None) -> List[np.ndarray]:
    """Read ``cfg.frames_per_clip`` RGB frames spread evenly over a video.

    Target frame indices are ``np.linspace(0, total - 1, frames_per_clip)`` truncated to int, where
    ``total`` is OpenCV's reported frame count. For clips with fewer frames than requested, an index
    appears several times; each occurrence is honoured by repeating that frame. This stretches short
    clips uniformly; the previous matcher got stuck on the first duplicate and sampled only one
    frame. If the frame count is unavailable, the first ``frames_per_clip`` frames are taken.

    Each selected frame is face-cropped via ``mtcnn`` when ``cfg.face_crop`` is set and a face is
    found, otherwise center-cropped to ``cfg.image_size``. An unreadable video yields black frames.
    A short read is padded by repeating the last frame.

    Returns:
        list of exactly ``cfg.frames_per_clip`` uint8 RGB arrays shaped (H, W, 3).
    """
    assert cv2 is not None, "opencv-python not installed"
    n_out = cfg.frames_per_clip

    cap = cv2.VideoCapture(path)
    frames: List[np.ndarray] = []
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        # frame index -> number of output slots sampling it (None: take frames sequentially)
        wanted: Optional[Dict[int, int]] = None
        if total > 0:
            wanted = {}
            for idx in np.linspace(0, total - 1, n_out).astype(int).tolist():
                wanted[idx] = wanted.get(idx, 0) + 1
        frame_idx = 0
        while len(frames) < n_out:
            ok, frame = cap.read()
            if not ok:
                break
            repeats = 1 if wanted is None else wanted.get(frame_idx, 0)
            if repeats:
                img = _frame_to_image(frame, cfg, mtcnn)
                frames.extend([img] * min(repeats, n_out - len(frames)))
            frame_idx += 1
    finally:
        cap.release()

    if not frames:
        frames = [np.zeros((cfg.image_size, cfg.image_size, 3), dtype=np.uint8) for _ in range(n_out)]
    if len(frames) < n_out:
        # the reported frame count can overshoot the decodable frames
        frames += [frames[-1]] * (n_out - len(frames))
    return frames


In [None]:

# RAVDESS emotion code (third dash-separated filename field) -> emotion name.
EMO_MAP_RAVDESS = {
    "01": "neutral",
    "02": "calm",
    "03": "happy",
    "04": "sad",
    "05": "angry",
    "06": "fearful",
    "07": "disgust",
    "08": "surprised",
}
# RAVDESS emotion name -> shared 7-class label ("calm" is merged into "neutral").
RAVDESS_TO_7 = {
    "neutral": "neutral",
    "calm": "neutral",
    "happy": "happy",
    "sad": "sad",
    "angry": "angry",
    "fearful": "fear",
    "disgust": "disgust",
    "surprised": "surprise",
}


def _ravdess_wav_names(mp4: Path) -> List[str]:
    """Candidate .wav file names for a RAVDESS video, most specific first.

    Per the RAVDESS filename convention, the first field is the modality (01 = audio-video,
    02 = video-only, 03 = audio-only), so the matching audio file carries "03" there.
    """
    parts = mp4.stem.split("-")
    names = []
    if parts[0] in ("01", "02"):
        names.append("-".join(["03", *parts[1:]]) + ".wav")
    names.append(mp4.stem + ".wav")  # same-name mirror, as originally assumed
    return names


def _audio_dir_names(video_dir_name: str) -> List[str]:
    """Sibling directory names that may mirror a video directory with audio files."""
    case_preserving = (video_dir_name.replace("Video", "Audio")
                       .replace("video", "audio").replace("VIDEO", "AUDIO"))
    legacy = video_dir_name.replace("Video", "Audio").replace("video", "Audio")
    return list(dict.fromkeys([case_preserving, legacy]))


def _find_ravdess_wav(mp4: Path, root: Path) -> Optional[Path]:
    """Locate the .wav that pairs with ``mp4``, or None if no candidate exists."""
    names = _ravdess_wav_names(mp4)
    search_dirs = []
    for parent in mp4.parents:
        if "video" not in parent.name.lower():
            continue
        rel = mp4.parent.relative_to(parent)  # e.g. "Actor_01" below "Video_Speech_Actor_01"
        for alt_name in _audio_dir_names(parent.name):
            alt = parent.with_name(alt_name)
            search_dirs.extend([alt / rel, alt])
    search_dirs.append(root)
    for d in search_dirs:
        for n in names:
            cand = d / n
            if cand.exists():
                return cand
    return None


def index_ravdess_7(root: str) -> List[Dict[str, Any]]:
    """Index RAVDESS .mp4 clips with a 7-class label, paired audio path and actor id.

    mp4 files are searched under root's sub-directories whose name contains "video", falling
    back to all of ``root``. The emotion comes from the third filename field. The paired audio
    is a matching .wav when one is found in a mirrored Audio* directory (with or without the
    Actor_* sub-folder) or in ``root``; otherwise the mp4 itself is used.

    NOTE(review): both modality 01 (audio-video) and 02 (video-only) clips are indexed; they may
    duplicate the same take, which matters for random train/test splits.

    Returns:
        list of dicts with keys "video", "audio", "label7", "actor" ("00" if no Actor_* folder).
    """
    root = Path(root)
    video_dirs = [p for p in root.iterdir() if p.is_dir() and "video" in p.name.lower()]
    mp4s = [m for base in (video_dirs or [root]) for m in base.rglob("*.mp4")]
    if not mp4s:
        mp4s = list(root.rglob("*.mp4"))

    items = []
    for mp4 in mp4s:
        parts = mp4.stem.split("-")
        if len(parts) < 3:
            continue
        emo_full = EMO_MAP_RAVDESS.get(parts[2])
        mapped = RAVDESS_TO_7.get(emo_full) if emo_full is not None else None
        if mapped is None:
            continue
        wav_path = _find_ravdess_wav(mp4, root)
        audio_path = wav_path if wav_path is not None else mp4
        actor_id = next((p.name.split("_")[-1] for p in mp4.parents if p.name.startswith("Actor_")), "00")
        items.append({"video": str(mp4), "audio": str(audio_path), "label7": mapped, "actor": actor_id})
    return items


In [None]:

# MELD emotion name -> shared 7-class label.
MELD_TO_7 = {
    "anger": "angry",
    "disgust": "disgust",
    "fear": "fear",
    "joy": "happy",
    "neutral": "neutral",
    "sadness": "sad",
    "surprise": "surprise",
}

# Per-split layout under MELD.Raw: (annotation CSV name, video folder name inside <root>/<split>/).
_MELD_SPLIT_LAYOUT = {
    "train": ("train_sent_emo.csv", "train_splits"),
    "dev": ("dev_sent_emo.csv", "dev_splits_complete"),
    "test": ("test_sent_emo.csv", "output_repeated_splits_test"),
}
_MELD_VIDEO_EXTS = (".mp4", ".avi", ".mov", ".mkv")


def _csv_rows(csv_path: Path) -> List[Dict[str, Any]]:
    """Return the rows of ``csv_path`` as header-keyed dicts, or [] if the file does not exist.

    ``utf-8-sig`` strips a leading byte-order mark, which would otherwise be glued onto the first
    header name and make that column unreadable. Undecodable bytes are still ignored.
    ``newline=""`` is what the csv module expects, so quoted fields may span lines.
    """
    if not csv_path.exists():
        return []
    with csv_path.open("r", encoding="utf-8-sig", errors="ignore", newline="") as f:
        return list(csv.DictReader(f))


def _meld_csv_path(root: Path, split: str) -> Path:
    """Annotation CSV for ``split``: <root>/train/ for train and <root>/ for dev/test, else the other place."""
    csv_name = _MELD_SPLIT_LAYOUT[split][0]
    inside, beside = root / split / csv_name, root / csv_name
    preferred = (inside, beside) if split == "train" else (beside, inside)
    return next((p for p in preferred if p.exists()), preferred[0])


def _meld_int(row: Dict[str, Any], key: str) -> int:
    """Parse an integer column, accepting the canonical or the lower-case header (TypeError/ValueError if bad)."""
    value = row.get(key)
    if value is None:
        value = row.get(key.lower())
    return int(value)


def _first_existing(directory: Path, stem: str, exts) -> Optional[Path]:
    """Return ``directory/<stem><ext>`` for the first extension that exists, else None."""
    for ext in exts:
        cand = directory / f"{stem}{ext}"
        if cand.exists():
            return cand
    return None


def index_meld_7_splitdirs(root: str, include_splits=("train","dev","test")) -> List[Dict[str, Any]]:
    """Index MELD utterances that have a video file and an emotion covered by MELD_TO_7.

    For each split, rows are read from its *_sent_emo.csv. The video is looked up as
    ``<root>/<split>/<video folder>/dia{D}_utt{U}`` with .mp4/.avi/.mov/.mkv. An optional
    ``<root>/audio/<split>/dia{D}_utt{U}.wav`` is used as the audio path when present; otherwise
    the video path is (audio is extracted from it later). Rows with unparsable ids, unmapped
    emotions or no video are skipped.

    Returns:
        list of dicts with keys "video", "audio", "label7", "split", "utt".

    Raises:
        ValueError: if ``include_splits`` names a split other than train/dev/test.
    """
    root = Path(root)
    unknown = [s for s in include_splits if s not in _MELD_SPLIT_LAYOUT]
    if unknown:
        raise ValueError(f"unknown MELD split(s) {unknown}; expected any of {sorted(_MELD_SPLIT_LAYOUT)}")

    items = []
    for split in include_splits:
        video_dir = root / split / _MELD_SPLIT_LAYOUT[split][1]
        audio_dir = root / "audio" / split  # optional pre-extracted .wav files
        for r in _csv_rows(_meld_csv_path(root, split)):
            try:
                dia = _meld_int(r, "Dialogue_ID")
                utt = _meld_int(r, "Utterance_ID")
            except (TypeError, ValueError):
                continue
            emo_raw = (r.get("Emotion") or r.get("emotion") or "").strip().lower()
            mapped = MELD_TO_7.get(emo_raw)
            if mapped is None:
                continue
            key = f"dia{dia}_utt{utt}"
            vid = _first_existing(video_dir, key, _MELD_VIDEO_EXTS)
            if vid is None:
                continue
            wav = _first_existing(audio_dir, key, (".wav",))
            items.append({
                "video": str(vid),
                "audio": str(wav if wav is not None else vid),
                "label7": mapped,
                "split": split,
                "utt": key,
            })
    return items


In [None]:

class AVItem:
    """Lightweight record for one sample: video path, audio path, 7-class label string, extra metadata."""

    __slots__ = ("video", "audio", "label7", "meta")

    def __init__(self, video: str, audio: str, label7: str, meta: Dict[str, Any]):
        self.video = video
        self.audio = audio
        self.label7 = label7
        self.meta = meta

class AVDataset(Dataset):
    """Map-style dataset yielding MFCC features, a fixed-length frame stack and a class index per clip.

    Items whose video file does not exist are dropped at construction. The audio path is not
    checked, because it may legitimately point at the video file, from which ``load_audio``
    extracts the soundtrack.
    """

    def __init__(self, items: List[Dict[str, Any]], label_list: List[str], cfg: Config):
        """
        Args:
            items: dicts with "video", "audio", "label7" keys; any other keys become ``meta``.
            label_list: ordered class names; position defines the integer label.
            cfg: pipeline configuration (face_crop / image_size used here, more in __getitem__).
        """
        self.items = []
        self.label_to_idx = {c: i for i, c in enumerate(label_list)}
        self.cfg = cfg
        # NOTE(review): get_mtcnn places the detector on DEVICE; with a CUDA DEVICE and
        # num_workers > 0, running it inside forked DataLoader workers may fail — confirm.
        self.mtcnn = get_mtcnn(cfg.image_size) if cfg.face_crop else None
        for it in items:
            if not Path(it["video"]).exists():
                continue
            meta = {k: v for k, v in it.items() if k not in ("video", "audio", "label7")}
            self.items.append(AVItem(it["video"], it["audio"], it["label7"], meta))
        print(f"AVDataset: {len(self.items)} items")

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx) -> Dict[str, Any]:
        """Return {"mfcc": (T, n_mfcc) float32, "a_len": T, "video": (frames, 3, H, W) in [0, 1],
        "label": class index, "meta": extra item fields}."""
        it = self.items[idx]
        y = load_audio(it.audio, self.cfg)
        mfcc = audio_to_mfcc(y, self.cfg)
        a_len = mfcc.shape[0]
        frames = load_video_frames_cv2(it.video, self.cfg, self.mtcnn)
        vid = np.stack(frames, 0).astype(np.float32) / 255.0
        vid = np.transpose(vid, (0, 3, 1, 2))  # (T, H, W, C) -> (T, C, H, W)
        label_idx = self.label_to_idx[it.label7]
        return {"mfcc": torch.from_numpy(mfcc),
                "a_len": torch.tensor(a_len, dtype=torch.long),
                "video": torch.from_numpy(vid),
                "label": torch.tensor(label_idx, dtype=torch.long),
                "meta": it.meta}

def av_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate AVDataset samples into a batch dict.

    MFCC sequences are right-padded on time with zeros; lengths, videos and labels are stacked.
    The per-sample "meta" entry is dropped.
    """
    labels = torch.stack([sample["label"] for sample in batch])
    audio_lengths = torch.tensor([sample["a_len"] for sample in batch], dtype=torch.long)
    padded_mfcc = pad_sequence_1d([sample["mfcc"] for sample in batch])
    videos = torch.stack([sample["video"] for sample in batch], dim=0)
    return {
        "mfcc": padded_mfcc.float(),
        "a_len": audio_lengths,
        "video": videos.float(),
        "label": labels,
    }


In [None]:

class Conv1DBlock(nn.Module):
    """Conv1d -> BatchNorm1d -> ReLU -> Dropout over (batch, channels, time) input."""

    def __init__(self, in_ch, out_ch, k=7, p=3, s=1, dropout=0.0):
        super().__init__()
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size=k, padding=p, stride=s)
        self.bn = nn.BatchNorm1d(out_ch)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return self.drop(x)


class AudioEncoder(nn.Module):
    """MFCC encoder: 3 Conv1D blocks (time downsampled x4 by two max-pools) -> BiLSTM -> linear
    projection -> Transformer encoder, plus a masked mean over valid time steps.
    """

    def __init__(self, cfg: Config):
        super().__init__()
        c = cfg.audio_cnn_channels
        self.conv1 = Conv1DBlock(cfg.n_mfcc, c[0], k=7, p=3, dropout=cfg.dropout)
        self.conv2 = Conv1DBlock(c[0], c[1], k=7, p=3, dropout=cfg.dropout)
        self.conv3 = Conv1DBlock(c[1], c[2], k=5, p=2, dropout=cfg.dropout)
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.lstm = nn.LSTM(input_size=c[-1], hidden_size=cfg.audio_lstm_hidden, num_layers=1, batch_first=True, bidirectional=True)
        self.proj = nn.Linear(2*cfg.audio_lstm_hidden, cfg.transformer_d_model)
        enc_layer = nn.TransformerEncoderLayer(d_model=cfg.transformer_d_model, nhead=cfg.transformer_nhead,
                                               dim_feedforward=4*cfg.transformer_d_model, batch_first=True,
                                               dropout=cfg.dropout, activation="relu")
        self.self_attn = nn.TransformerEncoder(enc_layer, num_layers=cfg.transformer_num_layers)
        self.out_dim = cfg.transformer_d_model

    def forward(self, mfcc: torch.Tensor, lengths: torch.Tensor):
        """
        Args:
            mfcc: (B, T, n_mfcc) zero-padded MFCC frames.
            lengths: (B,) number of valid frames per sample.

        Returns:
            (x_sel, a_summary): per-step encodings (B, T', d_model), with T' the longest pooled
            length in the batch, and their mean over valid steps (B, d_model).
        """
        x = mfcc.transpose(1, 2)               # (B, n_mfcc, T)
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        x = self.conv3(x).transpose(1, 2)      # (B, T//4, C)

        # The two MaxPool1d(2) layers floor-divide time by 4. Clamp to [1, pooled T] so packing
        # never receives a zero-length (very short clip) or over-long sequence.
        pooled_len = (lengths // 4).clamp(min=1, max=x.size(1))
        packed = nn.utils.rnn.pack_padded_sequence(x, pooled_len.cpu(), batch_first=True, enforce_sorted=False)
        x_enc, _ = self.lstm(packed)
        x_enc, _ = nn.utils.rnn.pad_packed_sequence(x_enc, batch_first=True)
        x_proj = self.proj(x_enc)

        steps = torch.arange(x_proj.size(1), device=x_proj.device)
        valid = steps[None, :] < pooled_len.to(x_proj.device)[:, None]   # (B, T') True = real step
        # Without a key-padding mask, self-attention would mix padded steps into real ones.
        x_sel = self.self_attn(x_proj, src_key_padding_mask=~valid)

        weights = valid.unsqueeze(-1).to(x_sel.dtype)
        a_summary = (x_sel * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1)
        return x_sel, a_summary

class FrameCNN(nn.Module):
    """Per-frame feature extractor: torchvision ResNet-18 without its classification head (512-d output).

    Inputs are normalized with the ImageNet channel statistics that the torchvision pretrained
    weights expect. The constants are non-persistent buffers, so state_dict keys are unchanged.
    """

    _MEAN = (0.485, 0.456, 0.406)
    _STD = (0.229, 0.224, 0.225)

    def __init__(self, backbone="resnet18", pretrained: bool = True):
        """
        Args:
            backbone: only "resnet18" is supported.
            pretrained: load ImageNet weights (requires download/cache); False gives random init,
                e.g. for offline shape checks.

        Raises:
            ValueError: for an unsupported backbone.
        """
        super().__init__()
        assert torchvision is not None, "torchvision not installed"
        if backbone != "resnet18":
            raise ValueError("Unsupported backbone")
        weights = torchvision.models.ResNet18_Weights.DEFAULT if pretrained else None
        net = torchvision.models.resnet18(weights=weights)
        self.backbone = nn.Sequential(*list(net.children())[:-1])  # drop the final fc layer
        self.out_dim = 512
        self.register_buffer("mean", torch.tensor(self._MEAN).view(1, 3, 1, 1), persistent=False)
        self.register_buffer("std", torch.tensor(self._STD).view(1, 3, 1, 1), persistent=False)

    def forward(self, x):
        """x: (N, 3, H, W) RGB frames scaled to [0, 1] (AVDataset divides uint8 by 255). Returns (N, 512)."""
        return self.backbone((x - self.mean) / self.std).flatten(1)

class VideoEncoder(nn.Module):
    """Frame-wise CNN features followed by a bidirectional LSTM over time."""

    def __init__(self, cfg: Config):
        super().__init__()
        self.frame_cnn = FrameCNN(backbone=cfg.video_backbone)
        self.lstm = nn.LSTM(
            input_size=self.frame_cnn.out_dim,
            hidden_size=cfg.video_lstm_hidden,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.out_dim = 2 * cfg.video_lstm_hidden  # forward + backward hidden states

    def forward(self, video: torch.Tensor):
        """video: (B, T, C, H, W). Returns (per-step features (B, T, out_dim), their mean over T)."""
        batch, steps, channels, height, width = video.shape
        per_frame = self.frame_cnn(video.reshape(batch * steps, channels, height, width))
        v_seq, _ = self.lstm(per_frame.view(batch, steps, -1))
        return v_seq, v_seq.mean(dim=1)

class CrossModalFusion(nn.Module):
    """Audio-guided temporal re-weighting of video features.

    Each video step t gets the score mean_k tanh(Wv v_t + Wa a)_k. Scores are softmax-normalised
    over time into alpha_t, and the output is v_t + alpha_t * v_t (same shape as the input).
    """

    def __init__(self, v_dim: int, a_dim: int, k: int):
        super().__init__()
        self.Wv = nn.Linear(v_dim, k, bias=True)   # video projection (carries the shared bias)
        self.Wa = nn.Linear(a_dim, k, bias=False)  # audio projection
        self.softmax = nn.Softmax(dim=1)           # normalises across the time axis

    def forward(self, v_seq: torch.Tensor, a_summary: torch.Tensor) -> torch.Tensor:
        # a_summary (B, a_dim) is broadcast across every time step of v_seq (B, T, v_dim).
        audio_query = self.Wa(a_summary)[:, None, :]
        joint = torch.tanh(self.Wv(v_seq) + audio_query)
        weights = self.softmax(joint.mean(dim=-1))[..., None]   # (B, T, 1)
        return v_seq + weights * v_seq

class AVFusionModel(nn.Module):
    """Audio-visual classifier.

    The audio summary re-weights the video sequence over time (CrossModalFusion). The fused video
    features are mean-pooled, concatenated with the audio summary, and classified by a
    LayerNorm -> Dropout -> Linear head.
    """

    def __init__(self, cfg: Config, n_classes: int):
        super().__init__()
        self.audio_enc = AudioEncoder(cfg)
        self.video_enc = VideoEncoder(cfg)
        self.fusion = CrossModalFusion(self.video_enc.out_dim, self.audio_enc.out_dim, cfg.fusion_k)
        width = self.video_enc.out_dim + self.audio_enc.out_dim
        self.cls = nn.Sequential(
            nn.LayerNorm(width),
            nn.Dropout(cfg.dropout),
            nn.Linear(width, n_classes),
        )

    def forward(self, mfcc, a_len, video):
        """Return class logits (B, n_classes)."""
        a_summary = self.audio_enc(mfcc, a_len)[1]
        v_seq = self.video_enc(video)[0]
        pooled_video = self.fusion(v_seq, a_summary).mean(dim=1)
        return self.cls(torch.cat((pooled_video, a_summary), dim=-1))


In [None]:

def accuracy_score_simple(y_true, y_pred):
    """Fraction of positions where ``y_pred`` equals ``y_true``, as a Python float (nan if empty)."""
    matches = torch.eq(torch.as_tensor(y_true), torch.as_tensor(y_pred))
    return matches.to(torch.float32).mean().item()

def f1_weighted_simple(y_true, y_pred, n_classes):
    """Support-weighted mean of per-class F1 over classes 0..n_classes-1.

    F1 for a class is 2*tp / (2*tp + fp + fn), or 0 when that denominator is 0. Weights are the
    class counts in ``y_true``. If no labels fall in range, the result is 0.
    """
    yt = torch.as_tensor(y_true, dtype=torch.long)
    yp = torch.as_tensor(y_pred, dtype=torch.long)
    weighted_sum, support_total = 0.0, 0
    for cls in range(n_classes):
        is_true = yt == cls
        is_pred = yp == cls
        tp = (is_true & is_pred).sum().item()
        fp = (is_pred & ~is_true).sum().item()
        fn = (is_true & ~is_pred).sum().item()
        support = is_true.sum().item()
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom > 0 else 0.0
        weighted_sum += f1 * support
        support_total += support
    return weighted_sum / (support_total if support_total > 0 else 1)

def train_one_epoch(model, loader, optim, scheduler, cfg: Config):
    """Run one training epoch with cross-entropy loss and gradient-norm clipping (cfg.grad_clip).

    If ``scheduler`` is given, it is stepped once at the end with the epoch's mean training loss.

    Returns:
        (mean training loss per sample, training accuracy)
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for batch in tqdm(loader, desc="Train", leave=False):
        mfcc, a_len, video, labels = (batch[key].to(DEVICE) for key in ("mfcc", "a_len", "video", "label"))

        optim.zero_grad()
        logits = model(mfcc, a_len, video)
        loss = F.cross_entropy(logits, labels)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip)
        optim.step()

        batch_size = labels.size(0)
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=-1) == labels).sum().item()
        n_seen += batch_size

    mean_loss = loss_sum / max(1, n_seen)
    if scheduler is not None:
        scheduler.step(mean_loss)
    return mean_loss, n_correct / max(1, n_seen)

@torch.no_grad()
def evaluate(model, loader):
    """Evaluate ``model`` on ``loader`` without gradients.

    Returns:
        (mean cross-entropy per sample, accuracy, weighted F1, (y_true, y_pred) lists).
        The F1 class count is inferred as 1 + the largest label seen in y_true or y_pred.
    """
    model.eval()
    loss_sum = 0.0
    targets, predictions = [], []
    for batch in tqdm(loader, desc="Eval", leave=False):
        mfcc = batch["mfcc"].to(DEVICE)
        a_len = batch["a_len"].to(DEVICE)
        video = batch["video"].to(DEVICE)
        labels = batch["label"].to(DEVICE)
        logits = model(mfcc, a_len, video)
        loss_sum += F.cross_entropy(logits, labels).item() * labels.size(0)
        targets += labels.cpu().tolist()
        predictions += logits.argmax(dim=-1).cpu().tolist()

    if targets:
        n_classes = int(max(max(targets), max(predictions)) + 1)
    else:
        n_classes = 1
    acc = accuracy_score_simple(targets, predictions)
    f1w = f1_weighted_simple(targets, predictions, n_classes=n_classes)
    return loss_sum / max(1, len(targets)), acc, f1w, (targets, predictions)


In [None]:

def kfold_from_actors(items: List[Dict[str, Any]], k: int = 5) -> List[Tuple[List[int], List[int]]]:
    """Actor-disjoint k-fold splits over item indices.

    Actors (the "actor" key, defaulting to "00") are sorted and dealt round-robin into ``k`` folds.
    Each fold in turn is the test set; the remaining folds, in fold order, form the training set.

    Returns:
        list of ``k`` (train_indices, test_indices) tuples.
    """
    from collections import defaultdict
    indices_by_actor = defaultdict(list)
    for idx, item in enumerate(items):
        indices_by_actor[item.get("actor", "00")].append(idx)

    folds = [[] for _ in range(k)]
    for pos, actor in enumerate(sorted(indices_by_actor)):
        folds[pos % k].extend(indices_by_actor[actor])

    result = []
    for held_out in range(k):
        train = [j for f, fold in enumerate(folds) if f != held_out for j in fold]
        result.append((train, folds[held_out]))
    return result

def meld_standard_split(items: List[Dict[str, Any]]):
    """Return (train, dev, test) index lists by each item's "split" value, preserving item order.

    Items with a missing or unrecognised split are ignored.
    """
    buckets = {"train": [], "dev": [], "test": []}
    for idx, item in enumerate(items):
        for name, bucket in buckets.items():
            if item.get("split") == name:
                bucket.append(idx)
    return buckets["train"], buckets["dev"], buckets["test"]

def make_loader(items, labels, cfg: Config, shuffle: bool):
    """Wrap ``items`` in an AVDataset and a DataLoader that batches with av_collate.

    Memory pinning is requested only when DEVICE is a CUDA device. Pinned host memory only speeds
    up host-to-GPU copies, and PyTorch warns when it is requested without an accelerator.

    Returns:
        (dataset, loader)
    """
    ds = AVDataset(items, labels, cfg)
    dl = DataLoader(ds, batch_size=cfg.batch_size, shuffle=shuffle, num_workers=cfg.num_workers,
                    collate_fn=av_collate, pin_memory=DEVICE.type == "cuda")
    return ds, dl


In [None]:
import os
import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from sklearn.model_selection import train_test_split

def _split_off_val(items: List[Dict[str, Any]], val_fraction: float, seed: int):
    """Randomly hold out ``val_fraction`` of ``items`` for validation and return (train, val).

    Returns (all items, []) when ``val_fraction <= 0`` or there are fewer than two items.
    """
    if val_fraction <= 0 or len(items) < 2:
        return list(items), []
    train_part, val_part = train_test_split(items, test_size=val_fraction, random_state=seed, shuffle=True)
    return train_part, val_part


def _fit_select_best(model: nn.Module, train_dl, val_dl, cfg: Config, tag: str) -> nn.Module:
    """Train ``model`` for ``cfg.epochs`` epochs and load back the weights with the best validation accuracy.

    The snapshot is cloned: on CPU, ``tensor.cpu()`` returns the same storage, so an un-cloned
    "best" state would silently keep tracking the latest weights. If no epoch produces a
    comparable validation accuracy (e.g. empty validation loader -> nan), the final weights are
    kept instead of calling load_state_dict(None).
    """
    optim = Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
    # `verbose` is omitted: it is deprecated for LR schedulers in recent PyTorch releases.
    sched = ReduceLROnPlateau(optim, mode="min", patience=3, factor=0.5)

    best_acc, best_state = float("-inf"), None
    for epoch in range(1, cfg.epochs + 1):
        tl, ta = train_one_epoch(model, train_dl, optim, sched, cfg)
        vl, va, vf1, _ = evaluate(model, val_dl)
        print(f"[{tag}] Epoch {epoch:02d} | train {tl:.4f}/{ta:.3f} | val {vl:.4f}/{va:.3f} f1w {vf1:.3f}")
        if va > best_acc:
            best_acc = va
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    if best_state is not None:
        model.load_state_dict(best_state)
    return model


def train_on_ravdess_eval_meld(cfg: Config, val_fraction: float = 0.1):
    """Train on RAVDESS, select the checkpoint on a RAVDESS validation split, then test on RAVDESS and MELD.

    The RAVDESS test split is the same 80/20 split (random_state=cfg.seed) as before. The validation
    split is carved from the training part, so the test set is never used for model selection.

    NOTE(review): the split is random over clips, so one actor can appear in both train and test;
    kfold_from_actors provides actor-disjoint folds if speaker-independent numbers are wanted.

    Returns:
        (RAVDESS acc, RAVDESS weighted F1, MELD test acc, MELD test weighted F1)
    """
    labels = cfg.classes7

    rav = index_ravdess_7(cfg.ravdess_root)
    assert len(rav) > 0, "No RAVDESS items found (7-class)."
    meld = index_meld_7_splitdirs(cfg.meld_root, include_splits=("train","dev","test"))
    assert len(meld) > 0, "No MELD items found (7-class)."

    tr_items, te_items = train_test_split(rav, test_size=0.2, random_state=cfg.seed, shuffle=True)
    tr_items, va_items = _split_off_val(tr_items, val_fraction, cfg.seed)
    _, _, meld_test_idx = meld_standard_split(meld)
    meld_test = [meld[i] for i in meld_test_idx]
    print(f"RAVDESS train: {len(tr_items)}, val: {len(va_items)}, test: {len(te_items)}, MELD test: {len(meld_test)}")

    set_seed(cfg.seed)
    _, train_dl = make_loader(tr_items, labels, cfg, shuffle=True)
    _, val_dl_r = make_loader(va_items, labels, cfg, shuffle=False)
    _, test_dl_r = make_loader(te_items, labels, cfg, shuffle=False)
    _, test_dl_m = make_loader(meld_test, labels, cfg, shuffle=False)

    model = AVFusionModel(cfg, n_classes=len(labels)).to(DEVICE)
    print("Params:", count_params(model))
    _fit_select_best(model, train_dl, val_dl_r, cfg, tag="RAV train/val")

    _, acc_r, f1_r, _ = evaluate(model, test_dl_r)
    _, acc_m, f1_m, _ = evaluate(model, test_dl_m)
    print(f"Best -> RAVDESS acc {acc_r:.3f}, MELD acc {acc_m:.3f}")
    return (acc_r, f1_r, acc_m, f1_m)


def train_on_meld_eval_ravdess(cfg: Config, val_fraction: float = 0.1):
    """Train on MELD train, select the checkpoint on MELD dev, then test on MELD test and all of RAVDESS.

    If no dev items were indexed, a ``val_fraction`` share of MELD train is held out for selection.
    The MELD test split is never used for model selection.

    Returns:
        (MELD test acc, MELD test weighted F1, RAVDESS acc, RAVDESS weighted F1)
    """
    labels = cfg.classes7

    meld = index_meld_7_splitdirs(cfg.meld_root, include_splits=("train","dev","test"))
    assert len(meld) > 0, "No MELD items found (7-class)."
    idx_tr, idx_dev, idx_te = meld_standard_split(meld)
    meld_train = [meld[i] for i in idx_tr]
    meld_val = [meld[i] for i in idx_dev]
    if not meld_val:
        meld_train, meld_val = _split_off_val(meld_train, val_fraction, cfg.seed)
    meld_test = [meld[i] for i in idx_te]

    rav = index_ravdess_7(cfg.ravdess_root)
    assert len(rav) > 0, "No RAVDESS items found (7-class)."

    set_seed(cfg.seed)
    _, train_dl = make_loader(meld_train, labels, cfg, shuffle=True)
    _, val_dl_m = make_loader(meld_val, labels, cfg, shuffle=False)
    _, test_dl_m = make_loader(meld_test, labels, cfg, shuffle=False)
    _, test_dl_r = make_loader(rav, labels, cfg, shuffle=False)

    model = AVFusionModel(cfg, n_classes=len(labels)).to(DEVICE)
    print("Params:", count_params(model))
    _fit_select_best(model, train_dl, val_dl_m, cfg, tag="MELD")

    _, acc_m, f1_m, _ = evaluate(model, test_dl_m)
    _, acc_r, f1_r, _ = evaluate(model, test_dl_r)
    print(f"MELD-trained best -> MELD acc {acc_m:.3f}, RAVDESS acc {acc_r:.3f}")
    return (acc_m, f1_m, acc_r, f1_r)


In [None]:
# Dry-run: push random tensors through the full model to check the end-to-end output shape.
set_seed(7)
cfg = Config()
B, Ta, Tv = 2, 200, cfg.frames_per_clip
mfcc = torch.randn(B, Ta, cfg.n_mfcc).to(DEVICE)
a_len = torch.tensor([Ta, Ta-10], dtype=torch.long).to(DEVICE)  # unequal lengths exercise the padding path
video = torch.randn(B, Tv, 3, cfg.image_size, cfg.image_size).to(DEVICE)
model = AVFusionModel(cfg, n_classes=len(cfg.classes7)).to(DEVICE)
model.eval()  # disable dropout / use BatchNorm running stats for a deterministic forward
with torch.no_grad():
    logits = model(mfcc, a_len, video)
print("Logits shape:", logits.shape)
assert logits.shape == (B, len(cfg.classes7)), f"unexpected logits shape {tuple(logits.shape)}"


In [None]:
# Full experiment: train on each corpus and evaluate within- and cross-corpus.
cfg = Config(
    ravdess_root="/kaggle/input/ravdess-dataset",
    meld_root="/kaggle/input/meld-dataset/MELD-RAW/MELD.Raw",   # MELD.Raw folder: train/train_sent_emo.csv, dev_sent_emo.csv, test_sent_emo.csv, per-split video folders (optional audio/<split>/*.wav)
    workdir="./artifacts_ravdess_meld_splitdirs",  # not referenced by the training code shown here
    # NOTE(review): face_crop=True builds the MTCNN detector on DEVICE inside each AVDataset, while the
    # DataLoaders use cfg.num_workers worker processes; with CUDA this may fail in forked workers — confirm.
    face_crop=True,
    epochs=10,
    batch_size=32,
)

# Each call returns (in-domain acc, in-domain weighted F1, cross-domain acc, cross-domain weighted F1).
rav_to_meld = train_on_ravdess_eval_meld(cfg)
meld_to_rav = train_on_meld_eval_ravdess(cfg)