In [13]:
"""
Cross-Modal Emotion Loss Module

This module implements the core novelty of the research:
Ensuring emotional consistency between audio (voice) and video (facial expressions)
in talking face generation.

The loss penalizes emotional misalignment between modalities.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Dict, Optional, Literal, Tuple
import numpy as np


class CrossModalEmotionLoss(nn.Module):
    """
    Cross-modal emotion loss that enforces emotional consistency
    between audio and video modalities.

    Supports multiple loss types:
    - 'cosine': 1 - cosine similarity between embeddings
    - 'mse': Mean squared error between embeddings
    - 'kl': KL divergence between softmax distributions over the embedding dims
    - 'ce': Cross-entropy of video logits against audio-predicted classes
    """

    def __init__(
        self,
        loss_type: Literal['cosine', 'mse', 'kl', 'ce'] = 'cosine',
        temperature: float = 0.07,
        weight: float = 1.0,
        normalize_embeddings: bool = True,
    ):
        """
        Args:
            loss_type: Type of loss function
                - 'cosine': 1 - cosine_similarity (encourages similar direction)
                - 'mse': Mean squared error (encourages similar magnitude)
                - 'kl': KL divergence (for probability distributions)
                - 'ce': Cross-entropy (for discrete emotion matching)
            temperature: Divisor applied to the embeddings before the softmax
                (only used by 'kl')
            weight: Weight multiplier for the loss
            normalize_embeddings: Whether to L2-normalize embeddings before comparison
        """
        super().__init__()
        self.loss_type = loss_type
        self.temperature = temperature
        self.weight = weight
        self.normalize_embeddings = normalize_embeddings

    def forward(
        self,
        audio_embeddings: torch.Tensor,
        video_embeddings: torch.Tensor,
        audio_logits: Optional[torch.Tensor] = None,
        video_logits: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute cross-modal emotion loss.

        Args:
            audio_embeddings: (B, D) audio emotion embeddings
            video_embeddings: (B, D) video emotion embeddings
            audio_logits: (B, num_classes) optional logits from audio encoder
            video_logits: (B, num_classes) optional logits from video encoder

        Returns:
            loss: Scalar tensor, already multiplied by ``self.weight``

        Raises:
            ValueError: If ``loss_type`` is unknown, or if it is 'ce' and
                either logits tensor is missing.
        """
        if self.normalize_embeddings:
            audio_embeddings = F.normalize(audio_embeddings, p=2, dim=-1)
            video_embeddings = F.normalize(video_embeddings, p=2, dim=-1)

        if self.loss_type == 'cosine':
            # Range: [0, 2], where 0 = perfect alignment, 2 = opposite
            cos_sim = F.cosine_similarity(audio_embeddings, video_embeddings, dim=-1)
            loss = (1.0 - cos_sim).mean()

        elif self.loss_type == 'mse':
            loss = F.mse_loss(audio_embeddings, video_embeddings)

        elif self.loss_type == 'kl':
            # Work in log space: softmax(...).log() underflows to log(0) = -inf
            # at small temperatures, which turns the loss into inf/NaN.
            audio_log_probs = F.log_softmax(audio_embeddings / self.temperature, dim=-1)
            video_log_probs = F.log_softmax(video_embeddings / self.temperature, dim=-1)

            # F.kl_div(input, target) computes KL(target || input), so this is
            # KL(audio || video): audio is the reference distribution.
            loss = F.kl_div(
                video_log_probs,
                audio_log_probs,
                reduction='batchmean',
                log_target=True,
            )

        elif self.loss_type == 'ce':
            if audio_logits is None or video_logits is None:
                raise ValueError("audio_logits and video_logits required for 'ce' loss")

            # Use audio predictions as pseudo-labels for video. argmax is not
            # differentiable, so gradients only reach the video logits.
            audio_preds = audio_logits.argmax(dim=-1)
            video_log_probs = F.log_softmax(video_logits, dim=-1)
            loss = F.nll_loss(video_log_probs, audio_preds)

        else:
            raise ValueError(f"Unknown loss_type: {self.loss_type}")

        return self.weight * loss


class AdaptiveCrossModalEmotionLoss(nn.Module):
    """
    Adaptive cross-modal emotion loss that combines multiple loss types
    with (optionally learnable) scalar weights.

    Each enabled term owns a 0-dim tensor named ``weight_<term>``; it is an
    ``nn.Parameter`` when ``learnable_weights`` is true and a buffer otherwise.
    """

    # (term name, fallback weight) in registration order.
    _TERM_DEFAULTS = (('cosine', 1.0), ('mse', 0.5), ('kl', 0.3))

    def __init__(
        self,
        use_cosine: bool = True,
        use_mse: bool = True,
        use_kl: bool = False,
        temperature: float = 0.07,
        initial_weights: Optional[Dict[str, float]] = None,
        learnable_weights: bool = True,
    ):
        """
        Args:
            use_cosine: Whether to include cosine similarity loss
            use_mse: Whether to include MSE loss
            use_kl: Whether to include KL divergence loss
            temperature: Temperature for KL divergence
            initial_weights: Initial weights for each loss type; missing keys
                fall back to cosine=1.0, mse=0.5, kl=0.3
            learnable_weights: Whether weights are learnable parameters
        """
        super().__init__()

        self.use_cosine = use_cosine
        self.use_mse = use_mse
        self.use_kl = use_kl
        self.temperature = temperature

        provided = {} if initial_weights is None else initial_weights
        enabled = {'cosine': use_cosine, 'mse': use_mse, 'kl': use_kl}

        for term, fallback in self._TERM_DEFAULTS:
            if not enabled[term]:
                continue
            # float(...) keeps the tensor floating point even when the caller
            # passes an int; nn.Parameter rejects integer tensors.
            value = torch.tensor(float(provided.get(term, fallback)))
            if learnable_weights:
                setattr(self, f'weight_{term}', nn.Parameter(value))
            else:
                self.register_buffer(f'weight_{term}', value)

    def forward(
        self,
        audio_embeddings: torch.Tensor,
        video_embeddings: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute adaptive cross-modal emotion loss.

        Args:
            audio_embeddings: (B, D) audio emotion embeddings
            video_embeddings: (B, D) video emotion embeddings

        Returns:
            total_loss: Combined weighted loss (a zero scalar tensor when no
                term is enabled)
            loss_dict: Python floats for each enabled term, its weight, and
                the total, intended for logging
        """
        loss_dict: Dict[str, float] = {}
        # Start from a tensor so the final .item() works even when every term
        # is disabled.
        total_loss = audio_embeddings.new_zeros(())

        # Cosine and MSE operate on unit-length embeddings.
        audio_norm = F.normalize(audio_embeddings, p=2, dim=-1)
        video_norm = F.normalize(video_embeddings, p=2, dim=-1)

        if self.use_cosine:
            cos_sim = F.cosine_similarity(audio_norm, video_norm, dim=-1)
            loss_cosine = (1.0 - cos_sim).mean()
            total_loss = total_loss + self.weight_cosine * loss_cosine
            loss_dict['cosine'] = loss_cosine.item()
            loss_dict['weight_cosine'] = self.weight_cosine.item()

        if self.use_mse:
            loss_mse = F.mse_loss(audio_norm, video_norm)
            total_loss = total_loss + self.weight_mse * loss_mse
            loss_dict['mse'] = loss_mse.item()
            loss_dict['weight_mse'] = self.weight_mse.item()

        if self.use_kl:
            # The KL term uses the raw (un-normalized) embeddings. log_softmax
            # avoids the log(0) = -inf that softmax(...).log() produces when a
            # probability underflows.
            audio_log_probs = F.log_softmax(audio_embeddings / self.temperature, dim=-1)
            video_log_probs = F.log_softmax(video_embeddings / self.temperature, dim=-1)
            # KL(audio || video): F.kl_div treats its second argument as the
            # reference distribution.
            loss_kl = F.kl_div(
                video_log_probs,
                audio_log_probs,
                reduction='batchmean',
                log_target=True,
            )
            total_loss = total_loss + self.weight_kl * loss_kl
            loss_dict['kl'] = loss_kl.item()
            loss_dict['weight_kl'] = self.weight_kl.item()

        loss_dict['total'] = total_loss.item()

        return total_loss, loss_dict


class WindowLevelEmotionLoss(nn.Module):
    """
    Window-level emotion loss for temporal alignment.
    Computes a per-window loss between audio and video embeddings and
    aggregates it over the T windows of each sample.
    """

    def __init__(
        self,
        loss_type: Literal['cosine', 'mse', 'kl'] = 'cosine',
        aggregation: Literal['mean', 'max', 'weighted'] = 'mean',
        temperature: float = 0.07,
        weight: float = 1.0,
        num_windows: Optional[int] = None,
    ):
        """
        Args:
            loss_type: Type of loss function
            aggregation: How to aggregate losses across windows
                - 'mean': Simple average
                - 'max': Maximum loss (focus on worst alignment)
                - 'weighted': Softmax-weighted sum with learnable logits
            temperature: Temperature for KL divergence
            weight: Weight multiplier for the loss
            num_windows: Number of temporal windows T for 'weighted'
                aggregation. When given, one learnable logit per window is
                created. When omitted, a single logit is shared by all
                windows; its softmax is always uniform, so 'weighted' then
                equals 'mean' and the logit receives no gradient.

        Raises:
            ValueError: If ``num_windows`` is given and smaller than 1.
        """
        super().__init__()
        self.loss_type = loss_type
        self.aggregation = aggregation
        self.temperature = temperature
        self.weight = weight

        if aggregation == 'weighted':
            n_logits = 1 if num_windows is None else int(num_windows)
            if n_logits < 1:
                raise ValueError(f"num_windows must be >= 1, got {num_windows}")
            # Equal logits -> uniform weights at initialization.
            self.temporal_weights = nn.Parameter(torch.ones(n_logits))

    def forward(
        self,
        audio_windows: torch.Tensor,
        video_windows: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute window-level emotion loss.

        Args:
            audio_windows: (B, T, D) audio embeddings for T windows
            video_windows: (B, T, D) video embeddings for T windows

        Returns:
            loss: Aggregated scalar loss, multiplied by ``self.weight``

        Raises:
            ValueError: On an unknown ``loss_type`` / ``aggregation``, or when
                per-window weights were created for a different T.
        """
        # Unpacking also enforces a 3-D input.
        _, T, _ = audio_windows.shape

        audio_norm = F.normalize(audio_windows, p=2, dim=-1)
        video_norm = F.normalize(video_windows, p=2, dim=-1)

        if self.loss_type == 'cosine':
            cos_sim = F.cosine_similarity(audio_norm, video_norm, dim=-1)  # (B, T)
            window_losses = 1.0 - cos_sim  # (B, T)

        elif self.loss_type == 'mse':
            window_losses = ((audio_norm - video_norm) ** 2).mean(dim=-1)  # (B, T)

        elif self.loss_type == 'kl':
            # KL(audio || video) per window on the raw embeddings, computed in
            # log space so underflowing probabilities do not yield
            # 0 * log(0) = NaN.
            audio_log_probs = F.log_softmax(audio_windows / self.temperature, dim=-1)
            video_log_probs = F.log_softmax(video_windows / self.temperature, dim=-1)

            window_losses = (audio_log_probs.exp() * (
                audio_log_probs - video_log_probs
            )).sum(dim=-1)  # (B, T)
        else:
            raise ValueError(f"Unknown loss_type: {self.loss_type}")

        if self.aggregation == 'mean':
            loss = window_losses.mean()
        elif self.aggregation == 'max':
            loss = window_losses.max(dim=-1)[0].mean()
        elif self.aggregation == 'weighted':
            logits = self.temporal_weights
            if logits.numel() == 1:
                # Legacy shared logit: broadcast to every window.
                logits = logits.expand(T)
            elif logits.numel() != T:
                raise ValueError(
                    f"Expected {logits.numel()} windows for weighted aggregation, got {T}"
                )
            weights = F.softmax(logits, dim=0)
            loss = (window_losses * weights.view(1, -1)).sum(dim=-1).mean()
        else:
            raise ValueError(f"Unknown aggregation: {self.aggregation}")

        return self.weight * loss


class EmotionAgreementMetric:
    """
    Streaming "Emotion Agreement" metric between audio and video.

    Accumulates, over any number of batches, the per-sample cosine similarity
    of the two embeddings and (when logits are supplied) the predicted class
    of each modality.
    """

    def __init__(
        self,
        num_classes: int = 8,
        threshold: float = 0.8,  # Cosine similarity threshold
    ):
        """
        Args:
            num_classes: Number of emotion classes (stored only)
            threshold: Minimum cosine similarity for a sample to count as "aligned"
        """
        self.num_classes = num_classes
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Clear every accumulated counter and list."""
        self.total = 0
        self.matches = 0
        self.cosine_sims = []
        self.audio_preds = []
        self.video_preds = []

    def update(
        self,
        audio_embeddings: torch.Tensor,
        video_embeddings: torch.Tensor,
        audio_logits: Optional[torch.Tensor] = None,
        video_logits: Optional[torch.Tensor] = None,
    ):
        """
        Fold one batch into the running statistics.

        Args:
            audio_embeddings: (B, D) audio emotion embeddings
            video_embeddings: (B, D) video emotion embeddings
            audio_logits: (B, num_classes) optional logits
            video_logits: (B, num_classes) optional logits
        """
        self.total += audio_embeddings.shape[0]

        unit_audio = F.normalize(audio_embeddings, p=2, dim=-1)
        unit_video = F.normalize(video_embeddings, p=2, dim=-1)
        similarity = F.cosine_similarity(unit_audio, unit_video, dim=-1)

        self.cosine_sims.extend(similarity.detach().cpu().tolist())
        # A sample "agrees" when its similarity reaches the threshold.
        self.matches += (similarity >= self.threshold).sum().item()

        # Class-level agreement is tracked only when both logits are present.
        if audio_logits is None or video_logits is None:
            return

        audio_classes = audio_logits.argmax(dim=-1)
        video_classes = video_logits.argmax(dim=-1)
        self.audio_preds.extend(audio_classes.detach().cpu().tolist())
        self.video_preds.extend(video_classes.detach().cpu().tolist())

    def compute(self) -> Dict[str, float]:
        """
        Summarize everything seen since the last reset.

        Returns:
            metrics: Dictionary containing:
                - agreement_rate: fraction of samples at or above the threshold
                - avg_cosine_sim: mean cosine similarity
                - class_agreement: fraction of samples whose predicted classes
                  match (0.0 when no logits were ever supplied)
            All three are 0.0 when no samples have been seen.
        """
        if self.total == 0:
            return dict.fromkeys(
                ('agreement_rate', 'avg_cosine_sim', 'class_agreement'), 0.0
            )

        n_with_classes = len(self.audio_preds)
        if n_with_classes > 0:
            n_same = sum(
                1 for a, v in zip(self.audio_preds, self.video_preds) if a == v
            )
            class_agreement = n_same / n_with_classes
        else:
            class_agreement = 0.0

        return {
            'agreement_rate': self.matches / self.total,
            'avg_cosine_sim': np.mean(self.cosine_sims),
            'class_agreement': class_agreement,
        }


# Example usage function
def compute_emotion_loss_example(
    audio_encoder,
    video_encoder,
    batch: Dict,
    loss_fn: CrossModalEmotionLoss,
    device: str = 'cuda',
) -> Tuple[torch.Tensor, Dict[str, float]]:
    """
    Worked example: embed one batch with both encoders and score it.

    Args:
        audio_encoder: Object exposing ``extract_embeddings_clip(audios, sr=..., window_seconds=...)``
        video_encoder: Object exposing ``extract_embeddings_clip(video, frames_for_model=...)``
        batch: Batch from dataloader with 'audio' and 'video'; an optional
            'sample_rate' entry defaults to 16000
        loss_fn: CrossModalEmotionLoss instance
        device: Accepted for API compatibility; not read by this function

    Returns:
        loss: Value returned by ``loss_fn`` for the two embedding batches
        metrics: Loss plus mean/min/max cosine similarity as Python floats
    """
    clip_audio = batch['audio']
    sample_rate = batch.get('sample_rate', 16000)
    audio_embs = audio_encoder.extract_embeddings_clip(
        clip_audio, sr=sample_rate, window_seconds=1.5
    )
    video_embs = video_encoder.extract_embeddings_clip(
        batch['video'], frames_for_model=16
    )

    loss = loss_fn(audio_embs, video_embs)

    # Diagnostics only: keep them out of the autograd graph.
    with torch.no_grad():
        unit_audio = F.normalize(audio_embs, p=2, dim=-1)
        unit_video = F.normalize(video_embs, p=2, dim=-1)
        similarity = F.cosine_similarity(unit_audio, unit_video, dim=-1)
        metrics = {
            'emotion_loss': loss.item(),
            'avg_cosine_sim': similarity.mean().item(),
            'min_cosine_sim': similarity.min().item(),
            'max_cosine_sim': similarity.max().item(),
        }

    return loss, metrics

In [14]:
if __name__ == "__main__":
    # Smoke-test / demo of the loss classes and the agreement metric on random
    # tensors. Inputs come from torch.randn without a fixed seed, so the
    # printed numbers differ from run to run.
    # Example usage
    print("="*60)
    print("Cross-Modal Emotion Loss - Example Usage")
    print("="*60)

    # Create sample embeddings
    batch_size = 4
    embed_dim = 768

    # Two independent random (batch_size, embed_dim) batches standing in for
    # the audio and video encoder outputs.
    audio_emb = torch.randn(batch_size, embed_dim)
    video_emb = torch.randn(batch_size, embed_dim)

    # Test different loss types
    # ('ce' is left out: it needs logits, which are not passed here.)
    loss_types = ['cosine', 'mse', 'kl']

    for loss_type in loss_types:
        loss_fn = CrossModalEmotionLoss(
            loss_type=loss_type,
            weight=1.0,
            normalize_embeddings=True
        )

        loss = loss_fn(audio_emb, video_emb)
        print(f"\n{loss_type.upper()} Loss: {loss.item():.4f}")

    # Test adaptive loss
    print("\n" + "="*60)
    print("Adaptive Cross-Modal Emotion Loss")
    print("="*60)

    # Cosine + MSE terms with learnable weights; the KL term is disabled.
    adaptive_loss_fn = AdaptiveCrossModalEmotionLoss(
        use_cosine=True,
        use_mse=True,
        use_kl=False,
        learnable_weights=True
    )

    total_loss, loss_dict = adaptive_loss_fn(audio_emb, video_emb)
    print(f"\nTotal Loss: {total_loss.item():.4f}")
    print("Components:")
    # loss_dict holds plain floats: each term, its weight, and the total.
    for key, value in loss_dict.items():
        print(f"  {key}: {value:.4f}")

    # Test emotion agreement metric
    print("\n" + "="*60)
    print("Emotion Agreement Metric")
    print("="*60)

    metric = EmotionAgreementMetric(num_classes=8, threshold=0.8)

    # Simulate predictions
    # Random (batch_size, 8) logits, one column per emotion class.
    audio_logits = torch.randn(batch_size, 8)
    video_logits = torch.randn(batch_size, 8)

    metric.update(audio_emb, video_emb, audio_logits, video_logits)
    results = metric.compute()

    print(f"\nAgreement Rate: {results['agreement_rate']:.2%}")
    print(f"Avg Cosine Sim: {results['avg_cosine_sim']:.4f}")
    print(f"Class Agreement: {results['class_agreement']:.2%}")

    print("\n✓ Cross-modal emotion loss module ready!")

Cross-Modal Emotion Loss - Example Usage

COSINE Loss: 0.9795

MSE Loss: 0.0026

KL Loss: 0.2595

Adaptive Cross-Modal Emotion Loss

Total Loss: 0.9808
Components:
  cosine: 0.9795
  weight_cosine: 1.0000
  mse: 0.0026
  weight_mse: 0.5000
  total: 0.9808

Emotion Agreement Metric

Agreement Rate: 0.00%
Avg Cosine Sim: 0.0205
Class Agreement: 25.00%

✓ Cross-modal emotion loss module ready!


In [15]:
from torch.utils.data import Dataset, DataLoader, Subset

In [16]:
# encoders_window_level_v3_fixed2.py
# -*- coding: utf-8 -*-
"""
Window-level Emotion Encoders (Audio + Video), stable on small VMs/Colab.

Fixes vs v3_fixed:
- Do NOT call enable_input_require_grads() (it caused AttributeError with tuple outputs).
- Gradient checkpointing is optional and disabled by default (use_checkpoint=False).
- Keeps do_rescale=False, AMP, workers=0, pin_memory=False, etc.
"""

import os
import json
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from torch.utils.data import Dataset, DataLoader, Subset

# Stable AMP API
from torch.amp import GradScaler, autocast

# Silence noisy torchaudio warnings
warnings.filterwarnings("ignore", message=".*StreamingMediaDecoder.*")
warnings.filterwarnings("ignore", message=".*load_with_torchcodec.*")

# Safer multiprocessing defaults (we still use workers=0 by default)
# NOTE: both calls below run at import time and change process-wide
# torch.multiprocessing settings as a side effect of importing this module.
import torch.multiprocessing as mp
try:
    # force=True asks for "spawn" even if a start method was already chosen.
    mp.set_start_method("spawn", force=True)
except RuntimeError:
    # Best effort: keep whatever start method is already in effect.
    pass
try:
    # Use the "file_system" strategy for sharing tensors between processes.
    mp.set_sharing_strategy("file_system")
except RuntimeError:
    # Best effort: keep the current sharing strategy.
    pass

from transformers import (
    Wav2Vec2ForSequenceClassification,
    HubertForSequenceClassification,
    Wav2Vec2FeatureExtractor,
    AutoImageProcessor,
    TimesformerForVideoClassification,
)

from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix


# -------------------
# Constants / Labels
# -------------------

# Emotion class names in label order: the list index of a name is its integer id.
EMOTION_NAMES = [
    "neutral",
    "calm",
    "happy",
    "sad",
    "angry",
    "fearful",
    "disgust",
    "surprised",
]
# Reverse lookup (name -> id), derived from the list so the two cannot drift apart.
EMOTION_TO_ID = {label: index for index, label in enumerate(EMOTION_NAMES)}


# -------------------
# Utils
# -------------------

def ensure_dir(p: Union[str, Path]):
    """Create directory ``p`` and any missing parents; do nothing if it already exists."""
    directory = Path(p)
    directory.mkdir(parents=True, exist_ok=True)


def uniform_indices(total: int, target: int) -> np.ndarray:
    """
    Pick ``target`` integer indices into a sequence of length ``total``.

    - ``total <= 0``: all zeros.
    - ``total <= target``: every index in order, then the last index repeated.
    - otherwise: ``target`` evenly spaced indices from 0 to ``total - 1``,
      rounded to the nearest integer.
    """
    if total <= 0:
        return np.zeros((target,), dtype=int)
    if total > target:
        evenly_spaced = np.linspace(0, total - 1, target)
        return np.round(evenly_spaced).astype(int)
    # Clamping 0..target-1 at the last valid index yields 0..total-1 followed
    # by (target - total) copies of total-1.
    return np.minimum(np.arange(target, dtype=int), total - 1)


def set_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy, and torch (CPU and all CUDA devices) with ``seed``."""
    import random

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)


def set_backbone_trainable_timesformer(model: nn.Module, trainable: bool):
    """
    Set ``requires_grad`` to ``trainable`` on every parameter whose name does
    not contain "classifier"; classifier parameters are left as they are.
    """
    for param_name, param in model.named_parameters():
        if "classifier" not in param_name:
            param.requires_grad = trainable


def safe_freeze_wav2vec_feature_encoder(model: nn.Module):
    """
    Freeze the feature encoder of ``model``.

    Delegates to ``model.freeze_feature_encoder()`` when the model has that
    attribute. Otherwise disables gradients on every parameter whose name
    does not contain "classifier".
    """
    if hasattr(model, "freeze_feature_encoder"):
        model.freeze_feature_encoder()
        return

    for param_name, param in model.named_parameters():
        if "classifier" not in param_name:
            param.requires_grad = False


def safe_unfreeze_wav2vec_feature_encoder(model: nn.Module):
    """Re-enable gradients on every parameter of ``model``, classifier included."""
    for param in model.parameters():
        param.requires_grad = True


# -------------------
# Dataset
# -------------------

class EmotionDataset(Dataset):
    """
    Audio/video emotion dataset driven by a JSON metadata file.

    The metadata file holds a list of records. Each record provides
    ``sample_id``, optionally ``emotion``, ``audio_path`` when audio is
    loaded, and either ``video_npz`` (one .npz per clip) or
    ``video_frames_dir`` (a directory of ``frame_*.npy`` files) when video is
    loaded. Which video layout is used is decided once, from the first record.
    """

    def __init__(
        self,
        metadata_path: Union[str, Path],
        video_max_frames: int = 64,
        audio_target_sr: int = 16000,
        load_audio: bool = True,
        load_video: bool = True,
    ):
        """
        Args:
            metadata_path: Path to the JSON list of sample records.
            video_max_frames: Clips with more frames are uniformly subsampled
                down to this many; shorter clips are returned as-is.
            audio_target_sr: Sample rate every waveform is resampled to.
            load_audio: Whether ``__getitem__`` loads the waveform.
            load_video: Whether ``__getitem__`` loads the frames.

        Raises:
            ValueError: If the metadata list is empty.
        """
        with open(metadata_path, "r", encoding="utf-8") as f:
            self.meta: List[Dict] = json.load(f)
        if len(self.meta) == 0:
            raise ValueError("Empty metadata file.")

        self.video_max_frames = int(video_max_frames)
        self.audio_target_sr = int(audio_target_sr)
        self.load_audio = bool(load_audio)
        self.load_video = bool(load_video)

        # Dataset-wide settings are taken from the first record only.
        m0 = self.meta[0]
        self.uses_npz = "video_npz" in m0

        self.frames_per_clip = int(m0.get("frames_per_clip", m0.get("fixed_T", 32)))
        # frame_size is unpacked as (W, H) by the loaders below.
        if "frame_size" in m0:
            self.frame_size = tuple(m0["frame_size"])
        elif "video_size" in m0:
            # "video_size" is read as (H, W) and swapped into (W, H).
            H, W = m0["video_size"]
            self.frame_size = (W, H)
        else:
            self.frame_size = (224, 224)

    def __len__(self):
        return len(self.meta)

    def _load_audio(self, audio_path: str) -> torch.Tensor:
        """Load a waveform, down-mix to mono, resample to the target rate, return it 1-D."""
        wav, sr = torchaudio.load(audio_path)
        if wav.shape[0] > 1:
            wav = wav.mean(dim=0, keepdim=True)
        if sr != self.audio_target_sr:
            wav = torchaudio.transforms.Resample(sr, self.audio_target_sr)(wav)
        return wav.squeeze(0)

    def _load_video_npz(self, npz_path: str) -> Tuple[torch.Tensor, np.ndarray]:
        """
        Load frames (and timestamps) from an .npz archive.

        The archive must contain ``frames``; ``timestamps`` is optional and is
        synthesized at 25 frames per second when absent. The frames are
        permuted with (0, 3, 1, 2), i.e. treated as channels-last, and scaled
        by 1/255.

        Returns:
            (frames as a float32 tensor, timestamps as an ndarray)
        """
        # Close the archive as soon as the arrays are read; an unclosed
        # NpzFile keeps its file handle open, which adds up over a whole epoch.
        with np.load(npz_path) as data:
            frames = data["frames"]
            ts = data["timestamps"] if "timestamps" in data.files else None
        if ts is None:
            T = frames.shape[0]
            ts = np.linspace(0.0, float(T - 1) / 25.0, num=T, dtype=np.float32)
        if frames.shape[0] > self.video_max_frames:
            idx = uniform_indices(frames.shape[0], self.video_max_frames)
            frames = frames[idx]
            ts = ts[idx]
        frames = frames.astype(np.float32) / 255.0
        tchw = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous()
        return tchw, ts

    def _load_video_frames_dir(self, frames_dir: str) -> torch.Tensor:
        """
        Load a clip stored as ``frame_*.npy`` files, in sorted filename order.

        An empty directory yields an all-zero clip of ``video_max_frames``
        frames. Any frame that is not a 3-D array with 3 channels in its last
        axis is replaced by a black frame of ``self.frame_size``.
        """
        frame_files = sorted(Path(frames_dir).glob("frame_*.npy"))
        if len(frame_files) == 0:
            W, H = self.frame_size
            return torch.zeros((self.video_max_frames, 3, H, W), dtype=torch.float32)
        if len(frame_files) > self.video_max_frames:
            idx = uniform_indices(len(frame_files), self.video_max_frames)
            frame_files = [frame_files[i] for i in idx]
        frames = []
        for f in frame_files:
            arr = np.load(f, mmap_mode="r")
            if arr.ndim != 3 or arr.shape[2] != 3:
                W, H = self.frame_size
                arr = np.zeros((H, W, 3), dtype=np.uint8)
            # astype copies the memory-mapped data into a regular array.
            frames.append(arr.astype(np.float32) / 255.0)
        frames = np.stack(frames, axis=0)
        tchw = torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous()
        return tchw

    def __getitem__(self, idx: int) -> Dict:
        """
        Return one sample as a dict.

        Always present: ``sample_id``, ``emotion_label`` (-1 when the record's
        emotion is missing or unknown) and ``meta`` (the raw record). When
        enabled: ``audio`` / ``sample_rate`` and ``video`` / ``timestamps``.
        """
        rec = self.meta[idx]
        out = {
            "sample_id": rec["sample_id"],
            "emotion_label": EMOTION_TO_ID.get(rec.get("emotion", ""), -1),
            "meta": rec,
        }
        if self.load_audio:
            out["audio"] = self._load_audio(rec["audio_path"])
            out["sample_rate"] = self.audio_target_sr
        if self.load_video:
            if self.uses_npz:
                v, ts = self._load_video_npz(rec["video_npz"])
                out["video"] = v
                out["timestamps"] = torch.from_numpy(ts)
            else:
                v = self._load_video_frames_dir(rec["video_frames_dir"])
                out["video"] = v
                # No stored timestamps in this layout: derive them from the
                # record's frame rate (25 fps when none is given).
                fps = rec.get("target_fps", rec.get("original_fps", 25.0))
                T = v.shape[0]
                ts = np.arange(T, dtype=np.float32) / float(fps)
                out["timestamps"] = torch.from_numpy(ts)
        return out


def emotion_collate(batch: List[Dict]) -> Dict:
    """
    Collate per-sample dicts into one batch dict.

    ``sample_id`` and ``meta`` become lists and ``emotion_label`` a long
    tensor. Audio waveforms and timestamps stay as lists, while videos are
    stacked along a new leading dimension. Whether audio / video are present
    is decided from the first sample.
    """
    def gather(key: str) -> List:
        return [sample[key] for sample in batch]

    collated: Dict[str, Union[List, torch.Tensor]] = {
        "sample_id": gather("sample_id"),
        "emotion_label": torch.tensor(gather("emotion_label"), dtype=torch.long),
        "meta": gather("meta"),
    }

    first = batch[0]
    if "audio" in first:
        collated["audio"] = gather("audio")
        # One sample rate for the whole batch, taken from the first sample.
        collated["sample_rate"] = first["sample_rate"]
    if "video" in first:
        collated["video"] = torch.stack(gather("video"), dim=0)
        collated["timestamps"] = gather("timestamps")
    return collated


# Window cropping helpers
def crop_audio_random(wav_1d: torch.Tensor, sr: int, dur_s: float) -> torch.Tensor:
    """
    Return a window of ``round(dur_s * sr)`` samples from ``wav_1d``.

    Longer waveforms are cropped at a uniformly random offset. Waveforms that
    are not longer than the window are padded at the end by repeating their
    last sample (zeros when the waveform is empty).
    """
    n_samples = wav_1d.numel()
    window = int(round(dur_s * sr))

    if n_samples > window:
        offset = torch.randint(0, n_samples - window + 1, ()).item()
        return wav_1d[offset:offset + window]

    if n_samples > 0:
        fill = wav_1d[-1]
    else:
        fill = torch.tensor(0.0, device=wav_1d.device)
    tail = fill.repeat(window - n_samples)
    return torch.cat([wav_1d, tail], 0)


def crop_audio_center(wav_1d: torch.Tensor, sr: int, dur_s: float) -> torch.Tensor:
    """
    Return the centered window of ``round(dur_s * sr)`` samples from ``wav_1d``.

    Waveforms that are not longer than the window are padded at the end by
    repeating their last sample (zeros when the waveform is empty).
    """
    n_samples = wav_1d.numel()
    window = int(round(dur_s * sr))

    if n_samples > window:
        offset = max(0, (n_samples - window) // 2)
        return wav_1d[offset:offset + window]

    if n_samples > 0:
        fill = wav_1d[-1]
    else:
        fill = torch.tensor(0.0, device=wav_1d.device)
    tail = fill.repeat(window - n_samples)
    return torch.cat([wav_1d, tail], 0)


def crop_video_random_T(video_TCHW: torch.Tensor, Ts: int) -> torch.Tensor:
    """
    Return ``Ts`` frames from ``video_TCHW`` along its first dimension.

    Clips longer than ``Ts`` get a contiguous run starting at a uniformly
    random frame. Otherwise frames are resampled (with repeats) at ``Ts``
    evenly spaced, rounded positions.
    """
    n_frames = video_TCHW.shape[0]

    if n_frames > Ts:
        first = torch.randint(0, n_frames - Ts + 1, ()).item()
        return video_TCHW[first:first + Ts]

    positions = torch.linspace(0, n_frames - 1, Ts)
    return video_TCHW[positions.round().long()]


def crop_video_center_T(video_TCHW: torch.Tensor, Ts: int) -> torch.Tensor:
    """
    Return ``Ts`` frames from ``video_TCHW`` along its first dimension.

    Clips longer than ``Ts`` get the centered contiguous run. Otherwise frames
    are resampled (with repeats) at ``Ts`` evenly spaced, rounded positions.
    """
    n_frames = video_TCHW.shape[0]

    if n_frames > Ts:
        first = (n_frames - Ts) // 2
        return video_TCHW[first:first + Ts]

    positions = torch.linspace(0, n_frames - 1, Ts)
    return video_TCHW[positions.round().long()]


# Audio encoder
class AudioEmotionEncoder:
    def __init__(
        self,
        model_name: str = "superb/wav2vec2-base-superb-er",
        num_emotions: int = 8,
        lr: float = 1e-5,
        device: Optional[str] = None,
        window_seconds: float = 1.5,
        grad_clip: float = 1.0,
        use_amp: bool = True,
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.num_emotions = num_emotions
        self.window_seconds = float(window_seconds)
        self.grad_clip = float(grad_clip)
        self.use_amp = bool(use_amp)
        self.lr = lr

        if "hubert" in model_name.lower():
            self.model = HubertForSequenceClassification.from_pretrained(
                model_name, num_labels=num_emotions, ignore_mismatched_sizes=True
            )
        else:
            self.model = Wav2Vec2ForSequenceClassification.from_pretrained(
                model_name, num_labels=num_emotions, ignore_mismatched_sizes=True
            )

        self.model.config.output_hidden_states = True
        self.model.to(self.device)

        # Try load feature extractor; if model_name is a local dir without files this may throw.
        try:
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
        except Exception:
            # fallback to a default extractor so validation/inference works
            self.processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-er")

        # create optimizer only from trainable params
        self.optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr)
        self.crit = nn.CrossEntropyLoss()
        self.emotion_names = EMOTION_NAMES

        # proper GradScaler init
        self.scaler = GradScaler(enabled=(self.use_amp and torch.cuda.is_available()))

    def update_optimizer(self, lr: Optional[float] = None):
        """Re-build the optimizer from current trainable parameters (call after unfreezing)."""
        if lr is not None:
            self.lr = lr
        self.optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr)

    def save(self, path: Union[str, Path]):
        """Save model + processor so HF's from_pretrained(path) works later."""
        p = Path(path)
        p.mkdir(parents=True, exist_ok=True)
        self.model.save_pretrained(str(p))
        try:
            # processor may be Wav2Vec2FeatureExtractor or Processor
            self.processor.save_pretrained(str(p))
        except Exception:
            # best effort
            pass

    def _prepare(self, batch: Dict, train: bool = True):
        sr = batch["sample_rate"]
        audios = []
        for a in batch["audio"]:
            a = a.to(self.device)
            seg = crop_audio_random(a, sr, self.window_seconds) if train else crop_audio_center(a, sr, self.window_seconds)
            audios.append(seg.cpu().numpy())
        proc = self.processor(
            audios, sampling_rate=sr, return_tensors="pt",
            padding=True, truncation=True, max_length=int(self.window_seconds * sr)
        )
        x = proc["input_values"].to(self.device)
        m = proc.get("attention_mask")
        m = m.to(self.device) if m is not None else None
        y = batch["emotion_label"].to(self.device)
        return x, m, y

    def train_epoch(self, loader: DataLoader) -> Dict[str, float]:
        self.model.train()
        total, preds_all, labels_all = 0.0, [], []
        for batch in tqdm(loader, desc="Training (Audio)"):
            x, m, y = self._prepare(batch, train=True)
            self.optim.zero_grad(set_to_none=True)
            with autocast("cuda", enabled=self.use_amp):
                out = self.model(input_values=x, attention_mask=m)
                loss = self.crit(out.logits, y)
            self.scaler.scale(loss).backward()
            if self.grad_clip is not None:
                self.scaler.unscale_(self.optim)
                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.scaler.step(self.optim)
            self.scaler.update()
            total += loss.item()
            preds_all.extend(out.logits.argmax(dim=1).detach().cpu().numpy())
            labels_all.extend(y.detach().cpu().numpy())
        return {
            "loss": total / len(loader),
            "accuracy": accuracy_score(labels_all, preds_all),
            "f1_score": f1_score(labels_all, preds_all, average="weighted")
        }

    @torch.no_grad()
    def validate(self, loader: DataLoader) -> Dict[str, float]:
        self.model.eval()
        total, preds_all, labels_all = 0.0, [], []
        for batch in tqdm(loader, desc="Validation (Audio)"):
            x, m, y = self._prepare(batch, train=False)
            with autocast("cuda", enabled=self.use_amp):
                out = self.model(input_values=x, attention_mask=m)
                loss = self.crit(out.logits, y)
            total += loss.item()
            preds_all.extend(out.logits.argmax(dim=1).detach().cpu().numpy())
            labels_all.extend(y.detach().cpu().numpy())
        cm = confusion_matrix(labels_all, preds_all)
        return {
            "loss": total / len(loader),
            "accuracy": accuracy_score(labels_all, preds_all),
            "f1_score": f1_score(labels_all, preds_all, average="weighted"),
            "confusion_matrix": cm,
            "predictions": preds_all,
            "labels": labels_all
        }

    @torch.no_grad()
    def extract_embeddings_clip(self, audios_1d: List[torch.Tensor], sr: int = 16000, window_seconds: float = 1.5) -> torch.Tensor:
        self.model.eval()
        crops = [crop_audio_center(a.to(self.device), sr, window_seconds).cpu().numpy() for a in audios_1d]
        proc = self.processor(crops, sampling_rate=sr, return_tensors="pt", padding=True, truncation=True,
                              max_length=int(window_seconds * sr))
        x = proc["input_values"].to(self.device)
        m = proc.get("attention_mask"); m = m.to(self.device) if m is not None else None
        out = self.model(input_values=x, attention_mask=m, output_hidden_states=True)
        last = getattr(out, "hidden_states", None)
        last = last[-1] if last is not None else getattr(out, "last_hidden_state")
        return last.mean(dim=1)  # (B, D)

    @torch.no_grad()
    def extract_embeddings_window(self, audio_1d: torch.Tensor, sr: int, t0: float, t1: float) -> torch.Tensor:
        self.model.eval()
        start = int(max(0, round(t0 * sr)))
        end = int(max(start + 1, round(t1 * sr)))
        seg = audio_1d[start:end]
        L = max(1, end - start)
        if seg.numel() < L:
            pad_val = seg[-1] if seg.numel() > 0 else torch.tensor(0.0, device=audio_1d.device)
            seg = torch.cat([seg, pad_val.repeat(L - seg.numel())], 0)
        proc = self.processor([seg.cpu().numpy()], sampling_rate=sr, return_tensors="pt", padding=True, truncation=True, max_length=L)
        x = proc["input_values"].to(self.device)
        m = proc.get("attention_mask"); m = m.to(self.device) if m is not None else None
        out = self.model(input_values=x, attention_mask=m, output_hidden_states=True)
        last = getattr(out, "hidden_states", None); last = last[-1] if last is not None else getattr(out, "last_hidden_state")
        return last.mean(dim=1)


# Video encoder (checkpointing disabled by default to avoid tuple hook crash)
class VideoEmotionEncoder:
    """
    Video emotion classifier built on ``TimesformerForVideoClassification``.

    Bundles the model, its image processor, an AdamW optimizer, a
    cross-entropy criterion and a gradient scaler, and exposes helpers to
    train, validate, save and extract pooled hidden-state embeddings.
    """

    def __init__(
        self,
        model_name: str = "facebook/timesformer-base-finetuned-k400",
        num_emotions: int = 8,
        lr: float = 1e-5,
        frames_for_model: int = 16,
        device: Optional[str] = None,
        grad_clip: Optional[float] = 1.0,
        use_amp: bool = True,
        use_checkpoint: bool = False,
    ):
        """
        Args:
            model_name: Hub id or local path passed to ``from_pretrained``.
            num_emotions: Number of output classes for the classification head.
            lr: Learning rate for AdamW.
            frames_for_model: Number of frames sampled from each clip.
            device: Target device; defaults to CUDA when available, else CPU.
            grad_clip: Max gradient norm; ``None`` disables clipping.
            use_amp: Request mixed precision (only honoured on a CUDA device).
            use_checkpoint: Try to enable gradient checkpointing.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.num_emotions = num_emotions
        self.frames_for_model = int(frames_for_model)
        # train_epoch() treats None as "no clipping", so None must survive here.
        self.grad_clip = float(grad_clip) if grad_clip is not None else None
        self.use_amp = bool(use_amp)
        self.lr = lr

        self.model = TimesformerForVideoClassification.from_pretrained(
            model_name, num_labels=num_emotions, ignore_mismatched_sizes=True
        )
        self.model.config.output_hidden_states = True

        if use_checkpoint and hasattr(self.model, "gradient_checkpointing_enable"):
            # Best effort: training still works without checkpointing.
            try:
                self.model.gradient_checkpointing_enable()
            except Exception as e:
                print(f"⚠ Could not enable gradient checkpointing: {e}")

        self.model.to(self.device)

        # Robust processor loading
        try:
            self.processor = AutoImageProcessor.from_pretrained(model_name)
        except Exception as e:
            # fallback (the HF model should usually have a processor)
            print(f"⚠ No image processor found at {model_name} ({e}); using the default TimeSformer processor")
            self.processor = AutoImageProcessor.from_pretrained("facebook/timesformer-base-finetuned-k400")

        self.optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr)
        self.crit = nn.CrossEntropyLoss()
        self.emotion_names = EMOTION_NAMES
        self.scaler = GradScaler(enabled=self._amp_active)

    @property
    def _amp_active(self) -> bool:
        """True only when AMP was requested and the encoder actually lives on a CUDA device."""
        return bool(self.use_amp and torch.cuda.is_available() and str(self.device).startswith("cuda"))

    @staticmethod
    def _clip_to_frame_list(clip: torch.Tensor) -> list:
        """Convert a clip tensor into a list of per-frame NumPy arrays with dims (1, 2, 0) of each frame."""
        return [frame.permute(1, 2, 0).cpu().numpy() for frame in clip]

    @staticmethod
    def _pool_last_hidden(out) -> torch.Tensor:
        """Average the last hidden layer over dim 1 (falls back to ``last_hidden_state``)."""
        hs = getattr(out, "hidden_states", None)
        last = hs[-1] if hs is not None else getattr(out, "last_hidden_state")
        return last.mean(dim=1)

    def update_optimizer(self, lr: Optional[float] = None):
        """Rebuild AdamW over the currently trainable parameters, optionally with a new learning rate."""
        if lr is not None:
            self.lr = lr
        self.optim = torch.optim.AdamW(filter(lambda p: p.requires_grad, self.model.parameters()), lr=self.lr)

    def save(self, path: Union[str, Path]):
        """Save the model (and, best effort, the processor) into directory ``path``."""
        p = Path(path)
        p.mkdir(parents=True, exist_ok=True)
        self.model.save_pretrained(str(p))
        try:
            self.processor.save_pretrained(str(p))
        except Exception as e:
            # The model is already on disk; a missing processor is recoverable.
            print(f"⚠ Could not save image processor to {p}: {e}")

    def _prepare(self, batch: Dict, train: bool = True):
        """
        Turn a dataloader batch into ``(pixel_values, labels)`` on the device.

        Reads ``batch["video"]`` and ``batch["emotion_label"]``. Training uses
        a random temporal crop, evaluation a centered one.
        """
        vids = []
        Ts = self.frames_for_model
        for v in batch["video"]:
            s = crop_video_random_T(v, Ts) if train else crop_video_center_T(v, Ts)
            vids.append(self._clip_to_frame_list(s))
        proc = self.processor(vids, return_tensors="pt", do_rescale=False)  # frames already in [0,1]
        x = proc["pixel_values"].to(self.device)
        y = batch["emotion_label"].to(self.device)
        return x, y

    def train_epoch(self, loader: DataLoader) -> Dict[str, float]:
        """Run one optimisation pass over ``loader``; return mean loss, accuracy and weighted F1."""
        self.model.train()
        amp_on = self._amp_active
        total, preds_all, labels_all = 0.0, [], []
        for batch in tqdm(loader, desc="Training (Video)"):
            x, y = self._prepare(batch, train=True)
            self.optim.zero_grad(set_to_none=True)
            with autocast("cuda", enabled=amp_on):
                out = self.model(pixel_values=x)
                loss = self.crit(out.logits, y)
            self.scaler.scale(loss).backward()
            if self.grad_clip is not None:
                # Gradients must be unscaled before their norm is clipped.
                self.scaler.unscale_(self.optim)
                nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
            self.scaler.step(self.optim)
            self.scaler.update()
            total += loss.item()
            preds_all.extend(out.logits.argmax(dim=1).detach().cpu().numpy())
            labels_all.extend(y.detach().cpu().numpy())
        return {
            "loss": total / len(loader),
            "accuracy": accuracy_score(labels_all, preds_all),
            "f1_score": f1_score(labels_all, preds_all, average="weighted"),
        }

    @torch.no_grad()
    def validate(self, loader: DataLoader) -> Dict[str, float]:
        """Evaluate on ``loader``; return loss, accuracy, weighted F1, confusion matrix, predictions and labels."""
        self.model.eval()
        amp_on = self._amp_active
        total, preds_all, labels_all = 0.0, [], []
        for batch in tqdm(loader, desc="Validation (Video)"):
            x, y = self._prepare(batch, train=False)
            with autocast("cuda", enabled=amp_on):
                out = self.model(pixel_values=x)
                loss = self.crit(out.logits, y)
            total += loss.item()
            preds_all.extend(out.logits.argmax(dim=1).detach().cpu().numpy())
            labels_all.extend(y.detach().cpu().numpy())
        cm = confusion_matrix(labels_all, preds_all)
        return {
            "loss": total / len(loader),
            "accuracy": accuracy_score(labels_all, preds_all),
            "f1_score": f1_score(labels_all, preds_all, average="weighted"),
            "confusion_matrix": cm,
            "predictions": preds_all,
            "labels": labels_all
        }

    @torch.no_grad()
    def extract_embeddings_clip(self, video_TCHW: torch.Tensor, frames_for_model: Optional[int] = None) -> torch.Tensor:
        """
        Embed one clip or a batch of clips.

        A 4-D input is treated as a single clip and gets a leading batch
        dimension. Each clip is center-cropped in time and run through the
        model on its own; the pooled embeddings are concatenated on dim 0.

        NOTE: runs under ``torch.no_grad()`` and goes through NumPy, so the
        result carries no gradient back to ``video_TCHW``.
        """
        self.model.eval()
        Ts = frames_for_model or self.frames_for_model
        if video_TCHW.dim() == 4:
            video_TCHW = video_TCHW.unsqueeze(0)
        batch_embs = []
        for v in video_TCHW:
            s = crop_video_center_T(v, Ts)
            frames = self._clip_to_frame_list(s)
            proc = self.processor([frames], return_tensors="pt", do_rescale=False)
            x = proc["pixel_values"].to(self.device)
            out = self.model(pixel_values=x, output_hidden_states=True)
            batch_embs.append(self._pool_last_hidden(out))
        return torch.cat(batch_embs, dim=0)

    @torch.no_grad()
    def extract_embeddings_window_from_npz(
        self,
        npz_path: str,
        t0: float,
        t1: float,
        Ts: Optional[int] = None,
        default_fps: float = 25.0,
    ) -> torch.Tensor:
        """
        Embed the frames of an ``.npz`` clip whose timestamps fall in ``[t0, t1]``.

        Args:
            npz_path: Archive with a ``frames`` array and optionally ``timestamps``.
            t0: Window start in seconds (inclusive).
            t1: Window end in seconds (inclusive).
            Ts: Number of frames to sample; defaults to ``self.frames_for_model``.
            default_fps: Frame rate used to synthesise timestamps when the
                archive has none.

        Returns:
            The pooled last-hidden-layer embedding of the sampled frames.
        """
        self.model.eval()
        Ts = Ts or self.frames_for_model
        # NpzFile keeps the file open until closed; read what is needed and release it.
        with np.load(npz_path) as data:
            frames = data["frames"]
            if "timestamps" in data:
                ts = data["timestamps"].astype(np.float32)
            else:
                ts = np.arange(frames.shape[0], dtype=np.float32) / default_fps
        mask = (ts >= t0) & (ts <= t1)
        sub = frames[mask]
        if sub.shape[0] == 0:
            # No frame inside the window: use the single frame nearest its centre.
            center = 0.5 * (t0 + t1)
            idx = int(np.argmin(np.abs(ts - center)))
            sub = frames[idx:idx+1]
        idx = uniform_indices(sub.shape[0], Ts)
        # Scale to [0, 1] here because the processor is called with do_rescale=False.
        sub = sub[idx].astype(np.float32) / 255.0
        frames_list = list(sub)
        proc = self.processor([frames_list], return_tensors="pt", do_rescale=False)
        x = proc["pixel_values"].to(self.device)
        out = self.model(pixel_values=x, output_hidden_states=True)
        return self._pool_last_hidden(out)


# Trainer
def train_encoders(
    metadata_path: str,
    output_dir: str,
    audio_model: str = "superb/wav2vec2-base-superb-er",
    video_model: str = "facebook/timesformer-base-finetuned-k400",
    num_epochs: int = 20,
    batch_size: int = 4,
    val_split: float = 0.2,
    audio_window_s: float = 1.5,
    video_Ts: int = 16,
    video_max_frames: int = 64,
    use_wandb: bool = True,
    seed: int = 42,
    audio_freeze_epochs: int = 2,
    video_freeze_epochs: int = 1,
):
    """
    Train the audio and video emotion encoders on one random train/val split.

    Both encoders see the same split. Each epoch trains and validates the
    audio encoder, then the video encoder, and saves whichever improves its
    best validation F1 under ``output_dir``.

    Args:
        metadata_path: Path handed to ``EmotionDataset``.
        output_dir: Directory for the ``best_audio_encoder`` and
            ``best_video_encoder`` checkpoints.
        audio_model: Name or path of the audio model.
        video_model: Name or path of the video model.
        num_epochs: Number of epochs.
        batch_size: Batch size for every loader.
        val_split: Fraction of samples held out for validation.
        audio_window_s: Audio window length in seconds.
        video_Ts: Frames per clip fed to the video model.
        video_max_frames: Forwarded to ``EmotionDataset``.
        use_wandb: Log to Weights & Biases when it can be initialised.
        seed: Seed passed to ``set_seed``.
        audio_freeze_epochs: Epoch index at which the audio feature encoder
            is unfrozen.
        video_freeze_epochs: Epoch index at which the video backbone is
            unfrozen.

    Returns:
        ``(audio_enc, video_enc, best_audio_f1, best_video_f1)``.

    Raises:
        ValueError: If the split leaves the training or validation set empty.
    """
    set_seed(seed)
    ensure_dir(output_dir)

    WANDB = False
    if use_wandb:
        try:
            import wandb
            wandb.init(project="almost-human-encoders", config=dict(
                audio_model=audio_model, video_model=video_model, num_epochs=num_epochs,
                batch_size=batch_size, val_split=val_split, audio_window_s=audio_window_s,
                video_Ts=video_Ts, seed=seed
            ))
            WANDB = True
        except Exception as e:
            print(f"⚠ W&B init failed: {e}")
            WANDB = False

    try:
        # One dataset per modality; train and val are index views over the same instance.
        ds_audio = EmotionDataset(metadata_path, video_max_frames=video_max_frames, load_audio=True, load_video=False)
        ds_video = EmotionDataset(metadata_path, video_max_frames=video_max_frames, load_audio=False, load_video=True)

        N = len(ds_audio)
        val_size = int(N * val_split)
        train_size = N - val_size
        if train_size <= 0 or val_size <= 0:
            raise ValueError(
                f"val_split={val_split} on {N} samples gives train={train_size}, val={val_size}; both must be non-empty"
            )

        # Plain Python ints, so the dataset is never indexed with 0-dim tensors.
        indices = torch.randperm(N).tolist()
        train_idx, val_idx = indices[:train_size], indices[train_size:]

        ds_audio_train = Subset(ds_audio, train_idx)
        ds_audio_val   = Subset(ds_audio, val_idx)
        ds_video_train = Subset(ds_video, train_idx)
        ds_video_val   = Subset(ds_video, val_idx)

        train_loader_audio = DataLoader(ds_audio_train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False, collate_fn=emotion_collate)
        val_loader_audio   = DataLoader(ds_audio_val,   batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False, collate_fn=emotion_collate)
        train_loader_video = DataLoader(ds_video_train, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=False, collate_fn=emotion_collate)
        val_loader_video   = DataLoader(ds_video_val,   batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=False, collate_fn=emotion_collate)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Device: {device} | Train: {train_size} | Val: {val_size}")

        audio_enc = AudioEmotionEncoder(model_name=audio_model, device=device, window_seconds=audio_window_s, use_amp=True)
        video_enc = VideoEmotionEncoder(model_name=video_model, device=device, frames_for_model=video_Ts, use_amp=True, use_checkpoint=False)

        # Start with frozen feature extractor / backbone; they are released inside the loop.
        safe_freeze_wav2vec_feature_encoder(audio_enc.model)
        set_backbone_trainable_timesformer(video_enc.model, trainable=False)

        best_audio_f1 = 0.0
        best_video_f1 = 0.0

        for epoch in range(num_epochs):
            print("\n" + "="*60)
            print(f"Epoch {epoch+1}/{num_epochs}")
            print("="*60)

            if epoch == audio_freeze_epochs:
                safe_unfreeze_wav2vec_feature_encoder(audio_enc.model)
                print("→ Unfroze Wav2Vec2/HubERT feature encoder")
            if epoch == video_freeze_epochs:
                set_backbone_trainable_timesformer(video_enc.model, trainable=True)
                print("→ Unfroze TimeSformer backbone")

            a_train = audio_enc.train_epoch(train_loader_audio)
            a_val = audio_enc.validate(val_loader_audio)
            print(f"[Audio] Train: loss={a_train['loss']:.4f} acc={a_train['accuracy']:.4f} f1={a_train['f1_score']:.4f}")
            print(f"[Audio]   Val: loss={a_val['loss']:.4f} acc={a_val['accuracy']:.4f} f1={a_val['f1_score']:.4f}")

            v_train = video_enc.train_epoch(train_loader_video)
            v_val = video_enc.validate(val_loader_video)
            print(f"[Video] Train: loss={v_train['loss']:.4f} acc={v_train['accuracy']:.4f} f1={v_train['f1_score']:.4f}")
            print(f"[Video]   Val: loss={v_val['loss']:.4f} acc={v_val['accuracy']:.4f} f1={v_val['f1_score']:.4f}")

            if WANDB:
                wandb.log({
                    "epoch": epoch + 1,
                    "audio/train_loss": a_train["loss"], "audio/train_acc": a_train["accuracy"], "audio/train_f1": a_train["f1_score"],
                    "audio/val_loss": a_val["loss"], "audio/val_acc": a_val["accuracy"], "audio/val_f1": a_val["f1_score"],
                    "video/train_loss": v_train["loss"], "video/train_acc": v_train["accuracy"], "video/train_f1": v_train["f1_score"],
                    "video/val_loss": v_val["loss"], "video/val_acc": v_val["accuracy"], "video/val_f1": v_val["f1_score"],
                })

            if a_val["f1_score"] > best_audio_f1:
                best_audio_f1 = a_val["f1_score"]
                save_path = Path(output_dir) / "best_audio_encoder"
                audio_enc.model.save_pretrained(str(save_path))
                # Keep the processor next to the weights so the directory can be reloaded on its own.
                audio_proc = getattr(audio_enc, "processor", None)
                if audio_proc is not None and hasattr(audio_proc, "save_pretrained"):
                    audio_proc.save_pretrained(str(save_path))
                print(f"✓ Saved best audio encoder → {save_path} (F1={best_audio_f1:.4f})")

            if v_val["f1_score"] > best_video_f1:
                best_video_f1 = v_val["f1_score"]
                save_path = Path(output_dir) / "best_video_encoder"
                # save() writes the model and, best effort, the image processor.
                video_enc.save(save_path)
                print(f"✓ Saved best video encoder → {save_path} (F1={best_video_f1:.4f})")

        print("\n" + "="*60)
        print("Training complete!")
        print(f"Best Audio F1: {best_audio_f1:.4f} | Best Video F1: {best_video_f1:.4f}")
        print("="*60)
    finally:
        # Close the W&B run even when training is interrupted or fails.
        if WANDB:
            wandb.finish()

    return audio_enc, video_enc, best_audio_f1, best_video_f1


In [17]:
"""
Integration Guide: Using Cross-Modal Emotion Loss with Trained Encoders

This script shows how to:
1. Load trained emotion encoders
2. Compute emotion loss during inference
3. Evaluate emotion agreement on validation set
4. Integrate with talking face models (Wav2Lip/SadTalker)
"""

import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, Tuple
import json



# ============================================================================
# Step 1: Load Trained Encoders
# ============================================================================

def load_trained_encoders(
    audio_encoder_path: str,
    video_encoder_path: str,
    device: str = 'cuda',
) -> Tuple[AudioEmotionEncoder, VideoEmotionEncoder]:
    """
    Build both emotion encoders from saved checkpoints and put them in eval mode.

    Args:
        audio_encoder_path: Checkpoint location used as the audio model name
        video_encoder_path: Checkpoint location used as the video model name
        device: Device both encoders are created on

    Returns:
        audio_encoder: Audio encoder with its model in eval mode
        video_encoder: Video encoder with its model in eval mode
    """
    print("Loading trained emotion encoders...")

    # Mixed precision is switched off for both encoders (inference only).
    audio_settings = dict(
        model_name=audio_encoder_path,
        num_emotions=8,
        device=device,
        window_seconds=1.5,
        use_amp=False,
    )
    video_settings = dict(
        model_name=video_encoder_path,
        num_emotions=8,
        frames_for_model=16,
        device=device,
        use_amp=False,
        use_checkpoint=False,
    )

    audio_enc = AudioEmotionEncoder(**audio_settings)
    video_enc = VideoEmotionEncoder(**video_settings)

    for encoder in (audio_enc, video_enc):
        encoder.model.eval()

    for label, location in (("Audio", audio_encoder_path), ("Video", video_encoder_path)):
        print(f"✓ {label} encoder loaded from: {location}")

    return audio_enc, video_enc


# ============================================================================
# Step 2: Compute Emotion Loss for a Batch
# ============================================================================

@torch.no_grad()
def compute_emotion_alignment(
    batch: Dict,
    audio_encoder: AudioEmotionEncoder,
    video_encoder: VideoEmotionEncoder,
    loss_fn: CrossModalEmotionLoss,
    device: str = 'cuda',
) -> Dict[str, float]:
    """
    Compute emotion alignment metrics for a batch.

    Args:
        batch: Batch from dataloader; reads 'audio', 'video' and, if present,
            'sample_rate' (16000 is used otherwise)
        audio_encoder: Trained audio encoder
        video_encoder: Trained video encoder
        loss_fn: Emotion loss function, called as loss_fn(audio_embs, video_embs)
        device: Unused here; the encoders run on their own devices. Kept for
            interface compatibility.

    Returns:
        metrics: Dictionary with 'emotion_loss' and the mean / min / max / std
            of the per-sample cosine similarity. The std is 0.0 for a
            single-sample batch.
    """
    audio_encoder.model.eval()
    video_encoder.model.eval()

    # Extract audio embeddings
    audio_embs = audio_encoder.extract_embeddings_clip(
        batch['audio'],
        sr=batch.get('sample_rate', 16000),
        window_seconds=1.5
    )

    # Extract video embeddings
    video_embs = video_encoder.extract_embeddings_clip(
        batch['video'],
        frames_for_model=16
    )

    # Compute loss
    loss = loss_fn(audio_embs, video_embs)

    # Compute cosine similarity
    audio_norm = torch.nn.functional.normalize(audio_embs, p=2, dim=-1)
    video_norm = torch.nn.functional.normalize(video_embs, p=2, dim=-1)
    cos_sim = torch.nn.functional.cosine_similarity(audio_norm, video_norm, dim=-1)

    # The sample standard deviation of a single value is NaN; report 0.0 instead.
    std_cos = cos_sim.std().item() if cos_sim.numel() > 1 else 0.0

    metrics = {
        'emotion_loss': loss.item(),
        'avg_cosine_sim': cos_sim.mean().item(),
        'min_cosine_sim': cos_sim.min().item(),
        'max_cosine_sim': cos_sim.max().item(),
        'std_cosine_sim': std_cos,
    }

    return metrics


# ============================================================================
# Step 3: Evaluate on Validation Set
# ============================================================================

def evaluate_emotion_agreement(
    metadata_path: str,
    audio_encoder: AudioEmotionEncoder,
    video_encoder: VideoEmotionEncoder,
    batch_size: int = 8,
    device: str = 'cuda',
) -> Dict[str, float]:
    """
    Evaluate audio/video emotion agreement over a dataset.

    NOTE: every sample listed in ``metadata_path`` is scored; no train/val
    split is applied here, so pass metadata for the validation subset if
    that is what should be measured.

    Args:
        metadata_path: Path to metadata.json
        audio_encoder: Trained audio encoder
        video_encoder: Trained video encoder
        batch_size: Batch size for evaluation
        device: Unused here; the encoders run on their own devices. Kept for
            interface compatibility.

    Returns:
        results: Output of EmotionAgreementMetric.compute() plus 'avg_loss',
            the mean per-batch cosine loss (NaN when no batch was processed)
    """
    from torch.utils.data import DataLoader

    print("\n" + "="*60)
    print("Evaluating Emotion Agreement on Validation Set")
    print("="*60)

    # Load dataset
    dataset = EmotionDataset(
        metadata_path,
        video_max_frames=64,
        load_audio=True,
        load_video=True,
    )

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        collate_fn=emotion_collate,
    )

    # Initialize metric
    metric = EmotionAgreementMetric(num_classes=8, threshold=0.8)

    # Initialize loss
    loss_fn = CrossModalEmotionLoss(
        loss_type='cosine',
        weight=1.0,
        normalize_embeddings=True,
    )

    total_loss = 0.0
    num_batches = 0

    print(f"\nProcessing {len(dataset)} samples...")

    with torch.no_grad():
        for batch in loader:
            # Extract embeddings
            audio_embs = audio_encoder.extract_embeddings_clip(
                batch['audio'],
                sr=batch.get('sample_rate', 16000),
                window_seconds=1.5
            )

            video_embs = video_encoder.extract_embeddings_clip(
                batch['video'],
                frames_for_model=16
            )

            # Update metric
            metric.update(audio_embs, video_embs)

            # Compute loss
            loss = loss_fn(audio_embs, video_embs)
            total_loss += loss.item()
            num_batches += 1

    # Compute final metrics
    results = metric.compute()
    # An empty dataset yields zero batches; avoid ZeroDivisionError and report NaN.
    results['avg_loss'] = total_loss / num_batches if num_batches else float('nan')

    print(f"\nResults:")
    print(f"  Agreement Rate: {results['agreement_rate']:.2%}")
    print(f"  Avg Cosine Sim: {results['avg_cosine_sim']:.4f}")
    print(f"  Class Agreement: {results['class_agreement']:.2%}")
    print(f"  Avg Loss: {results['avg_loss']:.4f}")

    return results


# ============================================================================
# Step 4: Integration with Talking Face Model
# ============================================================================

class TalkingFaceWithEmotionLoss(nn.Module):
    """
    Wrapper for talking face model (Wav2Lip/SadTalker) with emotion loss.

    This is a template showing how to integrate emotion loss into
    your talking face model training.

    NOTE(review): VideoEmotionEncoder.extract_embeddings_clip runs under
    torch.no_grad() and converts frames to NumPy, so the emotion loss
    computed here carries no gradient back to the base model. As written it
    is a monitoring signal only; a differentiable video embedding path is
    needed before it can steer training.
    """

    def __init__(
        self,
        base_model: nn.Module,  # Your Wav2Lip or SadTalker model
        audio_encoder: AudioEmotionEncoder,
        video_encoder: VideoEmotionEncoder,
        emotion_loss_weight: float = 0.1,
        freeze_encoders: bool = True,
    ):
        """
        Args:
            base_model: Base talking face model (Wav2Lip/SadTalker)
            audio_encoder: Trained audio emotion encoder
            video_encoder: Trained video emotion encoder
            emotion_loss_weight: Weight passed to CrossModalEmotionLoss
            freeze_encoders: Whether to freeze encoder weights
        """
        super().__init__()

        self.base_model = base_model
        self.audio_encoder = audio_encoder
        self.video_encoder = video_encoder
        self.emotion_loss_weight = emotion_loss_weight

        # Freeze emotion encoders (they're pre-trained)
        if freeze_encoders:
            for param in self.audio_encoder.model.parameters():
                param.requires_grad = False
            for param in self.video_encoder.model.parameters():
                param.requires_grad = False

        # Initialize emotion loss
        self.emotion_loss_fn = CrossModalEmotionLoss(
            loss_type='cosine',
            weight=self.emotion_loss_weight,
            normalize_embeddings=True,
        )

    def forward(
        self,
        audio: torch.Tensor,
        reference_frame: torch.Tensor,
        **kwargs
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        """
        Forward pass with emotion loss.

        Args:
            audio: Input audio: a single 1-D waveform, a batched tensor whose
                first dimension indexes waveforms, or a list/tuple of waveforms
            reference_frame: Reference face image
            **kwargs: Additional arguments for base model

        Returns:
            generated_video: Generated video frames
            losses: Dictionary of loss components
        """
        # Generate video with base model
        generated_video = self.base_model(audio, reference_frame, **kwargs)

        # The audio encoder expects one waveform per list entry. Wrapping a
        # batched tensor (or an existing list) in another list would collapse
        # the batch into a single malformed item.
        if isinstance(audio, (list, tuple)):
            waveforms = list(audio)
        elif audio.dim() == 1:
            waveforms = [audio]
        else:
            waveforms = list(audio.unbind(0))

        # Extract audio emotion embeddings
        audio_embs = self.audio_encoder.extract_embeddings_clip(
            waveforms,
            sr=16000,
            window_seconds=1.5
        )

        # Extract video emotion embeddings from generated frames
        # (no gradient flows through this call - see the class note)
        video_embs = self.video_encoder.extract_embeddings_clip(
            generated_video,
            frames_for_model=16
        )

        # Compute emotion loss
        emotion_loss = self.emotion_loss_fn(audio_embs, video_embs)

        losses = {
            'emotion_loss': emotion_loss,
            # Add your base model losses here (reconstruction, adversarial, etc.)
        }

        return generated_video, losses


# ============================================================================
# Step 5: Training Loop Example
# ============================================================================

def training_step_with_emotion_loss(
    model: TalkingFaceWithEmotionLoss,
    batch: Dict,
    optimizer: torch.optim.Optimizer,
    reconstruction_loss_fn: nn.Module,
    device: str = 'cuda',
) -> Dict[str, float]:
    """
    Run a single optimisation step that adds the emotion loss to the
    reconstruction loss.

    Args:
        model: Wrapped talking face model; returns (video, losses dict)
        batch: Training batch; reads 'audio', 'reference_frame' and 'video'
        optimizer: Optimizer stepped once per call
        reconstruction_loss_fn: Called as fn(generated_video, target_video)
        device: Not used by this function

    Returns:
        metrics: Python floats for 'total_loss', 'recon_loss', 'emotion_loss'
    """
    model.train()

    # Pull every input out of the batch before running the model.
    audio, reference_frame, target_video = (
        batch['audio'],
        batch['reference_frame'],  # First frame or specific ref
        batch['video'],
    )

    generated_video, losses = model(audio, reference_frame)

    # Objective = base reconstruction term + emotion term from the wrapper.
    recon_loss = reconstruction_loss_fn(generated_video, target_video)
    emotion_term = losses['emotion_loss']
    total_loss = recon_loss + emotion_term

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    return {
        'total_loss': total_loss.item(),
        'recon_loss': recon_loss.item(),
        'emotion_loss': losses['emotion_loss'].item(),
    }




In [18]:
from transformers import (
    Wav2Vec2ForSequenceClassification,
    HubertForSequenceClassification,
    Wav2Vec2FeatureExtractor,
    AutoImageProcessor,
    TimesformerForVideoClassification,
)

In [19]:
import torch
# Environment check: print the CUDA version string exposed by torch.version.cuda,
# then whether torch reports a usable CUDA device in this session.
print(torch.version.cuda)
print(torch.cuda.is_available())

12.6
False


In [20]:
# ============================================================================
# Main Example Usage
# ============================================================================

def main():
    """
    Run the end-to-end example: load the encoders, score one batch, evaluate
    the whole dataset, write the results to JSON and print integration notes.
    """
    rule = "=" * 60

    def banner(title, lead="\n"):
        # Three-line section header; `lead` is empty only for the first banner.
        print(lead + rule)
        print(title)
        print(rule)

    banner("Cross-Modal Emotion Loss - Complete Integration Example", lead="")

    # Checkpoint and data locations
    audio_encoder_path = "/content/trained_encoders/best_audio_encoder"
    video_encoder_path = "/content/trained_encoders/best_video_encoder"
    metadata_path = "/content/processed_data/metadata.json"

    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    print(f"\nUsing device: {device}")

    # Step 1: encoders
    audio_enc, video_enc = load_trained_encoders(
        audio_encoder_path,
        video_encoder_path,
        device=device
    )

    # Step 2: alignment metrics on the first batch only
    banner("Testing on Single Batch")

    from torch.utils.data import DataLoader

    dataset = EmotionDataset(
        metadata_path,
        video_max_frames=64,
        load_audio=True,
        load_video=True,
    )
    loader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=False,
        collate_fn=emotion_collate,
    )
    first_batch = next(iter(loader))

    loss_fn = CrossModalEmotionLoss(
        loss_type='cosine',
        weight=1.0,
        normalize_embeddings=True,
    )

    metrics = compute_emotion_alignment(
        first_batch,
        audio_enc,
        video_enc,
        loss_fn,
        device=device
    )

    print("\nSingle Batch Metrics:")
    for name, number in metrics.items():
        print(f"  {name}: {number:.4f}")

    # Step 3: agreement over the full dataset
    results = evaluate_emotion_agreement(
        metadata_path,
        audio_enc,
        video_enc,
        batch_size=8,
        device=device
    )

    # Step 4: persist the results
    output_dir = Path("/content/emotion_evaluation")
    output_dir.mkdir(parents=True, exist_ok=True)
    results_file = output_dir / "emotion_agreement_results.json"

    with results_file.open("w") as f:
        json.dump(results, f, indent=2)

    print(f"\n✓ Results saved to: {results_file}")

    # Step 5: how to plug this into a talking face model
    banner("Integration with Talking Face Model")
    for line in (
        "\nTo integrate with Wav2Lip/SadTalker:",
        "1. Wrap your base model with TalkingFaceWithEmotionLoss",
        "2. Add emotion_loss to your total loss",
        "3. Use training_step_with_emotion_loss() in your training loop",
        "\nExample:",
    ):
        print(line)
    print("""
    # Pseudo-code
    base_model = Wav2Lip()  # or SadTalker()
    model = TalkingFaceWithEmotionLoss(
        base_model,
        audio_enc,
        video_enc,
        emotion_loss_weight=0.1
    )

    for batch in train_loader:
        metrics = training_step_with_emotion_loss(
            model, batch, optimizer, reconstruction_loss_fn
        )
    """)

    print("\n✓ Complete integration example finished!")


if __name__ == "__main__":
    main()

Cross-Modal Emotion Loss - Complete Integration Example

Using device: cpu
Loading trained emotion encoders...
✓ Audio encoder loaded from: /content/trained_encoders/best_audio_encoder
✓ Video encoder loaded from: /content/trained_encoders/best_video_encoder

Testing on Single Batch

Single Batch Metrics:
  emotion_loss: 1.0330
  avg_cosine_sim: -0.0330
  min_cosine_sim: -0.0446
  max_cosine_sim: -0.0114
  std_cosine_sim: 0.0147

Evaluating Emotion Agreement on Validation Set

Processing 720 samples...


KeyboardInterrupt: 