In [None]:
!pip install -q transformers

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import (
    Wav2Vec2ForSequenceClassification,
    HubertForSequenceClassification,
    Wav2Vec2FeatureExtractor,
    TimesformerForVideoClassification,
)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EMOTIONS = ["happy", "angry", "disgust"]
NUM_EMOTIONS = len(EMOTIONS)  # derived so the count can never drift from the label list

# Checkpoint directories for the audio and video encoders (Colab /content paths).
BEST_AUDIO_PATH = "/content/trained_encoders_v2/w2v2-lg-lr2e5"
BEST_VIDEO_PATH = "/content/trained_encoders_v2/tsf-lr3e5-16f-nf"

print(f"Device: {DEVICE}")

In [None]:
class CrossModalEmotionLoss(nn.Module):
    """Weighted cosine distance between paired audio and video embeddings.

    loss = weight * mean_i(1 - cos(audio_emb[i], video_emb[i]))

    Similarity is taken along the last dimension, so both inputs must be
    broadcast-compatible with the embedding on the last axis.
    """

    def __init__(self, weight=1.0):
        super().__init__()
        # Scalar multiplier applied to the averaged distance.
        self.weight = weight

    def forward(self, audio_emb, video_emb):
        audio_unit = F.normalize(audio_emb, p=2, dim=-1)
        video_unit = F.normalize(video_emb, p=2, dim=-1)
        similarity = F.cosine_similarity(audio_unit, video_unit, dim=-1)
        distance = 1.0 - similarity
        return self.weight * distance.mean()


class DifferentiableVideoPreprocess(nn.Module):
    """Replaces HF AutoImageProcessor with differentiable PyTorch ops.
    Allows gradient to flow from emotion loss back through generated frames.

    Frames are resized directly to ``size x size`` with bilinear interpolation
    (no aspect-preserving resize / center crop), then normalized per channel.

    Args:
        size: side length of the square output frames.
        mean: per-channel mean subtracted after resizing (default: ImageNet).
        std: per-channel std divided after resizing (default: ImageNet).
            NOTE(review): confirm these match the preprocessor config of the
            video encoder checkpoint actually in use; pass its values if not.
    """

    def __init__(self, size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super().__init__()
        if len(mean) != len(std):
            raise ValueError(f"mean and std must have the same length, got {len(mean)} and {len(std)}")
        self.size = size
        dtype = torch.get_default_dtype()
        # Buffers (not parameters): they follow .to(device) but are never trained.
        self.register_buffer("mean", torch.tensor(mean, dtype=dtype).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std, dtype=dtype).view(1, -1, 1, 1))

    def forward(self, frames):
        """frames: (B, T, C, H, W) in [0, 1] -> (B, T, C, size, size) normalized.

        Raises:
            ValueError: if ``frames`` is not 5-D or its channel count differs
                from the number of normalization statistics.
        """
        if frames.dim() != 5:
            raise ValueError(f"expected frames of shape (B, T, C, H, W), got {tuple(frames.shape)}")
        B, T, C, H, W = frames.shape
        if C != self.mean.shape[1]:
            raise ValueError(f"expected {self.mean.shape[1]} channels, got {C}")
        # Fold time into batch so 2-D interpolate can process every frame at once.
        x = frames.reshape(B * T, C, H, W)
        if H != self.size or W != self.size:
            x = F.interpolate(x, size=(self.size, self.size), mode="bilinear", align_corners=False)
        x = (x - self.mean) / self.std
        return x.reshape(B, T, C, self.size, self.size)

In [None]:
def load_frozen_audio_encoder(path, device="cuda"):
    """Load an audio sequence classifier with all weights frozen, plus its processor.

    The model class is chosen from the checkpoint path: a path containing
    "hubert" (case-insensitive) loads HubertForSequenceClassification,
    anything else loads Wav2Vec2ForSequenceClassification.

    Returns:
        (model, processor): the model in eval mode on ``device`` with
        ``requires_grad=False`` on every parameter and ``output_hidden_states``
        enabled in its config, and the matching Wav2Vec2FeatureExtractor.
    """
    model_cls = (
        HubertForSequenceClassification
        if "hubert" in str(path).lower()
        else Wav2Vec2ForSequenceClassification
    )
    model = model_cls.from_pretrained(path)
    model.config.output_hidden_states = True
    model.eval().to(device)
    model.requires_grad_(False)
    processor = Wav2Vec2FeatureExtractor.from_pretrained(path)
    return model, processor


def load_frozen_video_encoder(path, device="cuda"):
    """Load a TimeSformer video classifier with all weights frozen.

    The returned model is in eval mode on ``device``, has
    ``output_hidden_states`` enabled in its config, and has
    ``requires_grad=False`` on every parameter. Gradients can still flow
    through it to its inputs.
    """
    encoder = TimesformerForVideoClassification.from_pretrained(path)
    encoder.config.output_hidden_states = True
    encoder.eval().to(device)
    encoder.requires_grad_(False)
    return encoder


@torch.no_grad()
def extract_audio_embedding(model, processor, audio_list, sr=16000, window_s=3.0, device="cuda"):
    """Embed a batch of waveforms with a frozen audio encoder.

    No gradient needed -- audio embedding is the target.

    Every clip is flattened to 1-D and then forced to exactly
    ``int(window_s * sr)`` samples: shorter clips are right-padded with zeros,
    longer clips are center-cropped.

    Args:
        model: encoder called as ``model(input_values=..., attention_mask=...,
            output_hidden_states=True)``; its output must expose ``hidden_states``.
        processor: feature extractor (e.g. Wav2Vec2FeatureExtractor) mapping a
            list of 1-D numpy arrays to a dict with ``"input_values"`` and
            optionally ``"attention_mask"``.
        audio_list: non-empty sequence of torch tensors of raw samples at ``sr`` Hz.
            Both (N,) and (1, N) are accepted; multi-channel input is NOT
            supported (channels would be concatenated by the flatten).
        sr: sampling rate in Hz.
        window_s: fixed window length in seconds.
        device: device the model lives on.

    Returns:
        Tensor of shape (len(audio_list), D): the last hidden state averaged over
        dim 1 (assumes hidden states are (batch, time, D) -- TODO confirm per model).

    Raises:
        ValueError: if ``audio_list`` is empty.
    """
    if len(audio_list) == 0:
        raise ValueError("audio_list is empty")
    L = int(window_s * sr)
    wavs = []
    for a in audio_list:
        # Flatten first: on a (1, N) tensor, a[start:start + L] would slice the
        # channel axis rather than time, silently yielding an empty/uncropped clip.
        a = a.reshape(-1)
        n = a.numel()
        if n <= L:
            a = F.pad(a, (0, L - n))
        else:
            start = (n - L) // 2
            a = a[start:start + L]
        wavs.append(a.cpu().numpy())
    # NOTE(review): every clip is now exactly L samples, so the processor sees
    # equal-length inputs and the manual zero padding is NOT masked out -- it
    # is included in the mean pool below. Confirm this matches training.
    enc = processor(wavs, sampling_rate=sr, return_tensors="pt",
                    padding=True, truncation=True, max_length=L)
    x = enc["input_values"].to(device)
    mask = enc.get("attention_mask")
    if mask is not None:
        mask = mask.to(device)
    out = model(input_values=x, attention_mask=mask, output_hidden_states=True)
    return out.hidden_states[-1].mean(dim=1)


def extract_video_embedding(model, preprocess, frames, device="cuda"):
    """Embed frames with a frozen video encoder, keeping the autograd graph.

    Unlike the audio path this is not wrapped in no_grad, so the loss gradient
    reaches ``frames`` (and the generator that produced them) through
    ``preprocess`` and the encoder.

    Args:
        model: encoder called as ``model(pixel_values=..., output_hidden_states=True)``.
        preprocess: differentiable callable applied to ``frames`` first.
        frames: presumably (B, T, C, H, W) in [0, 1] -- whatever ``preprocess`` expects.
        device: device the preprocessed tensor is moved to before the encoder.

    Returns:
        The last hidden state averaged over dim 1.
    """
    pixel_values = preprocess(frames).to(device)
    hidden_states = model(pixel_values=pixel_values, output_hidden_states=True).hidden_states
    return hidden_states[-1].mean(dim=1)

In [None]:
class EmotionAgreementMetric:
    """Accumulates audio/video cosine similarities and summarizes agreement.

    Call ``update()`` once per batch, ``compute()`` at the end, and ``reset()``
    between evaluation runs.
    """

    def __init__(self, threshold=0.8):
        # A pair "agrees" when its cosine similarity is >= threshold.
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Discard all accumulated similarities."""
        self.sims = []

    def update(self, audio_emb, video_emb):
        """Record the cosine similarity of each (audio, video) pair.

        Accepts batched (B, D) embeddings or a single unbatched (D,) pair.
        Inputs are detached, so this never affects gradients.
        """
        a = F.normalize(audio_emb.detach(), p=2, dim=-1)
        v = F.normalize(video_emb.detach(), p=2, dim=-1)
        sims = F.cosine_similarity(a, v, dim=-1)
        # reshape(-1): an unbatched pair yields a 0-d tensor whose .tolist() is a
        # bare float, which list.extend() would reject with a TypeError.
        self.sims.extend(sims.reshape(-1).cpu().tolist())

    def compute(self):
        """Summarize everything recorded since the last reset.

        Returns:
            dict with ``avg_cosine_sim``, ``agreement_rate`` (fraction of pairs
            with similarity >= threshold) and ``std_cosine_sim`` (population
            std). All three are NaN when nothing has been recorded.
        """
        if not self.sims:
            # Explicit NaNs instead of numpy's empty-array RuntimeWarnings.
            nan = float("nan")
            return {"avg_cosine_sim": nan, "agreement_rate": nan, "std_cosine_sim": nan}
        s = np.asarray(self.sims, dtype=np.float64)
        return {
            "avg_cosine_sim": float(s.mean()),
            "agreement_rate": float((s >= self.threshold).mean()),
            "std_cosine_sim": float(s.std()),
        }

In [None]:
# Smoke test: exercise each module on random CPU tensors (no checkpoints needed).
torch.manual_seed(0)  # reproducible numbers under Restart & Run All

loss_fn = CrossModalEmotionLoss(weight=0.1)
a = torch.randn(4, 768)
v = torch.randn(4, 768)
print(f"Emotion loss (random): {loss_fn(a, v).item():.4f}")

preprocess = DifferentiableVideoPreprocess(224).to("cpu")
frames = torch.rand(2, 8, 3, 96, 96, requires_grad=True)
out = preprocess(frames)
loss = out.sum()
loss.backward()
# Fail loudly instead of only printing: these are the properties the pipeline relies on.
assert out.shape == (2, 8, 3, 224, 224), f"unexpected preprocess output shape {tuple(out.shape)}"
assert frames.grad is not None and frames.grad.abs().sum() > 0, "no gradient reached the input frames"
print(f"Preprocess: {frames.shape} -> {out.shape}, grad flows: {frames.grad is not None}")

metric = EmotionAgreementMetric()
metric.update(a, v)
print(f"Agreement: {metric.compute()}")
print("All modules ready.")

In [None]:
%%writefile /content/emotion_utils.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformers import (
    Wav2Vec2ForSequenceClassification,
    HubertForSequenceClassification,
    Wav2Vec2FeatureExtractor,
    TimesformerForVideoClassification,
)

EMOTIONS = ["happy", "angry", "disgust"]
NUM_EMOTIONS = len(EMOTIONS)  # derived so the count can never drift from the label list


class CrossModalEmotionLoss(nn.Module):
    """Weighted cosine distance between paired audio and video embeddings.

    loss = weight * mean_i(1 - cos(audio_emb[i], video_emb[i]))

    Similarity is taken along the last dimension, so both inputs must be
    broadcast-compatible with the embedding on the last axis.
    """

    def __init__(self, weight=1.0):
        super().__init__()
        # Scalar multiplier applied to the averaged distance.
        self.weight = weight

    def forward(self, audio_emb, video_emb):
        audio_unit = F.normalize(audio_emb, p=2, dim=-1)
        video_unit = F.normalize(video_emb, p=2, dim=-1)
        similarity = F.cosine_similarity(audio_unit, video_unit, dim=-1)
        distance = 1.0 - similarity
        return self.weight * distance.mean()


class DifferentiableVideoPreprocess(nn.Module):
    """Differentiable stand-in for the HF image processor on video frames.

    Lets gradient flow from the emotion loss back through generated frames.
    Frames are resized directly to ``size x size`` with bilinear interpolation
    (no aspect-preserving resize / center crop), then normalized per channel.

    Args:
        size: side length of the square output frames.
        mean: per-channel mean subtracted after resizing (default: ImageNet).
        std: per-channel std divided after resizing (default: ImageNet).
            NOTE(review): confirm these match the preprocessor config of the
            video encoder checkpoint actually in use; pass its values if not.
    """

    def __init__(self, size=224, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        super().__init__()
        if len(mean) != len(std):
            raise ValueError(f"mean and std must have the same length, got {len(mean)} and {len(std)}")
        self.size = size
        dtype = torch.get_default_dtype()
        # Buffers (not parameters): they follow .to(device) but are never trained.
        self.register_buffer("mean", torch.tensor(mean, dtype=dtype).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(std, dtype=dtype).view(1, -1, 1, 1))

    def forward(self, frames):
        """frames: (B, T, C, H, W) in [0, 1] -> (B, T, C, size, size) normalized.

        Raises:
            ValueError: if ``frames`` is not 5-D or its channel count differs
                from the number of normalization statistics.
        """
        if frames.dim() != 5:
            raise ValueError(f"expected frames of shape (B, T, C, H, W), got {tuple(frames.shape)}")
        B, T, C, H, W = frames.shape
        if C != self.mean.shape[1]:
            raise ValueError(f"expected {self.mean.shape[1]} channels, got {C}")
        # Fold time into batch so 2-D interpolate can process every frame at once.
        x = frames.reshape(B * T, C, H, W)
        if H != self.size or W != self.size:
            x = F.interpolate(x, size=(self.size, self.size), mode="bilinear", align_corners=False)
        x = (x - self.mean) / self.std
        return x.reshape(B, T, C, self.size, self.size)


def load_frozen_audio_encoder(path, device="cuda"):
    """Load an audio sequence classifier with all weights frozen, plus its processor.

    The model class is chosen from the checkpoint path: a path containing
    "hubert" (case-insensitive) loads HubertForSequenceClassification,
    anything else loads Wav2Vec2ForSequenceClassification.

    Returns:
        (model, processor): the model in eval mode on ``device`` with
        ``requires_grad=False`` on every parameter and ``output_hidden_states``
        enabled in its config, and the matching Wav2Vec2FeatureExtractor.
    """
    model_cls = (
        HubertForSequenceClassification
        if "hubert" in str(path).lower()
        else Wav2Vec2ForSequenceClassification
    )
    model = model_cls.from_pretrained(path)
    model.config.output_hidden_states = True
    model.eval().to(device)
    model.requires_grad_(False)
    processor = Wav2Vec2FeatureExtractor.from_pretrained(path)
    return model, processor


def load_frozen_video_encoder(path, device="cuda"):
    """Load a TimeSformer video classifier with all weights frozen.

    The returned model is in eval mode on ``device``, has
    ``output_hidden_states`` enabled in its config, and has
    ``requires_grad=False`` on every parameter. Gradients can still flow
    through it to its inputs.
    """
    encoder = TimesformerForVideoClassification.from_pretrained(path)
    encoder.config.output_hidden_states = True
    encoder.eval().to(device)
    encoder.requires_grad_(False)
    return encoder


@torch.no_grad()
def extract_audio_embedding(model, processor, audio_list, sr=16000, window_s=3.0, device="cuda"):
    """Embed a batch of waveforms with a frozen audio encoder (no gradient).

    Every clip is flattened to 1-D and then forced to exactly
    ``int(window_s * sr)`` samples: shorter clips are right-padded with zeros,
    longer clips are center-cropped.

    Args:
        model: encoder called as ``model(input_values=..., attention_mask=...,
            output_hidden_states=True)``; its output must expose ``hidden_states``.
        processor: feature extractor (e.g. Wav2Vec2FeatureExtractor) mapping a
            list of 1-D numpy arrays to a dict with ``"input_values"`` and
            optionally ``"attention_mask"``.
        audio_list: non-empty sequence of torch tensors of raw samples at ``sr`` Hz.
            Both (N,) and (1, N) are accepted; multi-channel input is NOT
            supported (channels would be concatenated by the flatten).
        sr: sampling rate in Hz.
        window_s: fixed window length in seconds.
        device: device the model lives on.

    Returns:
        Tensor of shape (len(audio_list), D): the last hidden state averaged over
        dim 1 (assumes hidden states are (batch, time, D) -- TODO confirm per model).

    Raises:
        ValueError: if ``audio_list`` is empty.
    """
    if len(audio_list) == 0:
        raise ValueError("audio_list is empty")
    L = int(window_s * sr)
    wavs = []
    for a in audio_list:
        # Flatten first: on a (1, N) tensor, a[start:start + L] would slice the
        # channel axis rather than time, silently yielding an empty/uncropped clip.
        a = a.reshape(-1)
        n = a.numel()
        if n <= L:
            a = F.pad(a, (0, L - n))
        else:
            start = (n - L) // 2
            a = a[start:start + L]
        wavs.append(a.cpu().numpy())
    # NOTE(review): every clip is now exactly L samples, so the processor sees
    # equal-length inputs and the manual zero padding is NOT masked out -- it
    # is included in the mean pool below. Confirm this matches training.
    enc = processor(wavs, sampling_rate=sr, return_tensors="pt",
                    padding=True, truncation=True, max_length=L)
    x = enc["input_values"].to(device)
    mask = enc.get("attention_mask")
    if mask is not None:
        mask = mask.to(device)
    out = model(input_values=x, attention_mask=mask, output_hidden_states=True)
    return out.hidden_states[-1].mean(dim=1)


def extract_video_embedding(model, preprocess, frames, device="cuda"):
    """Embed frames with a frozen video encoder, keeping the autograd graph.

    Not wrapped in no_grad, so the loss gradient reaches ``frames`` (and the
    generator that produced them) through ``preprocess`` and the encoder.

    Args:
        model: encoder called as ``model(pixel_values=..., output_hidden_states=True)``.
        preprocess: differentiable callable applied to ``frames`` first.
        frames: presumably (B, T, C, H, W) in [0, 1] -- whatever ``preprocess`` expects.
        device: device the preprocessed tensor is moved to before the encoder.

    Returns:
        The last hidden state averaged over dim 1.
    """
    pixel_values = preprocess(frames).to(device)
    hidden_states = model(pixel_values=pixel_values, output_hidden_states=True).hidden_states
    return hidden_states[-1].mean(dim=1)


class EmotionAgreementMetric:
    """Accumulates audio/video cosine similarities and summarizes agreement.

    Call ``update()`` once per batch, ``compute()`` at the end, and ``reset()``
    between evaluation runs.
    """

    def __init__(self, threshold=0.8):
        # A pair "agrees" when its cosine similarity is >= threshold.
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Discard all accumulated similarities."""
        self.sims = []

    def update(self, audio_emb, video_emb):
        """Record the cosine similarity of each (audio, video) pair.

        Accepts batched (B, D) embeddings or a single unbatched (D,) pair.
        Inputs are detached, so this never affects gradients.
        """
        a = F.normalize(audio_emb.detach(), p=2, dim=-1)
        v = F.normalize(video_emb.detach(), p=2, dim=-1)
        sims = F.cosine_similarity(a, v, dim=-1)
        # reshape(-1): an unbatched pair yields a 0-d tensor whose .tolist() is a
        # bare float, which list.extend() would reject with a TypeError.
        self.sims.extend(sims.reshape(-1).cpu().tolist())

    def compute(self):
        """Summarize everything recorded since the last reset.

        Returns:
            dict with ``avg_cosine_sim``, ``agreement_rate`` (fraction of pairs
            with similarity >= threshold) and ``std_cosine_sim`` (population
            std). All three are NaN when nothing has been recorded.
        """
        if not self.sims:
            # Explicit NaNs instead of numpy's empty-array RuntimeWarnings.
            nan = float("nan")
            return {"avg_cosine_sim": nan, "agreement_rate": nan, "std_cosine_sim": nan}
        s = np.asarray(self.sims, dtype=np.float64)
        return {
            "avg_cosine_sim": float(s.mean()),
            "agreement_rate": float((s >= self.threshold).mean()),
            "std_cosine_sim": float(s.std()),
        }