In [None]:
!pip install -q transformers

In [None]:
import json
import warnings
from pathlib import Path
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torchaudio
from torch.amp import autocast
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from transformers import (
    Wav2Vec2ForSequenceClassification,
    HubertForSequenceClassification,
    Wav2Vec2FeatureExtractor,
    AutoImageProcessor,
    TimesformerForVideoClassification,
)

warnings.filterwarnings("ignore")  # silence every warning to keep notebook output readable

# Compute device: used for model placement and to enable CUDA autocast.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Input sample metadata (JSON list) and the directory holding saved encoder checkpoints.
METADATA = "/content/processed_data/metadata.json"
ENCODERS_DIR = Path("/content") / "trained_encoders"

# Class names; a label's integer value is its position in this list.
EMOTIONS = [
    "neutral", "calm", "happy", "sad",
    "angry", "fearful", "disgust", "surprised",
]

# Checkpoint sub-directories (under ENCODERS_DIR) evaluated below.
BEST_AUDIO, BEST_VIDEO = "hubert-lr5e5-w3s", "timesformer-lr3e5-16f"

print(f"Device: {DEVICE}")

In [None]:
class EmotionDataset(Dataset):
    """Samples of one split from the metadata JSON, loaded for a single modality.

    Each metadata record must carry "split" and "emotion_idx", plus "audio_path"
    (read with torchaudio) or "frames_path" (a .npy frame array) depending on
    ``modality``.

    Args:
        metadata_path: path to a JSON list of sample records.
        split: only records whose "split" equals this value are kept.
        modality: "audio" or "video"; any other value yields label-only items.
        target_sr: sample rate audio is resampled to. Defaults to 16 kHz, the rate
            prepare_audio hands to the feature extractor.
    """

    def __init__(self, metadata_path: str, split: str, modality: str, target_sr: int = 16000):
        with open(metadata_path) as f:
            data = json.load(f)
        self.samples = [s for s in data if s["split"] == split]
        self.modality = modality
        self.target_sr = target_sr

    def __len__(self):
        return len(self.samples)

    @staticmethod
    def _to_mono(wav, sr, target_sr):
        """Downmix a (channels, samples) waveform to 1-D and resample it to ``target_sr`` if needed."""
        mono = wav.mean(dim=0)  # identical to squeeze(0) for single-channel input
        if sr != target_sr:
            mono = torchaudio.functional.resample(mono, orig_freq=sr, new_freq=target_sr)
        return mono

    def __getitem__(self, idx):
        """Return {"emotion": int} plus "audio" (1-D waveform) or "video" ((T, C, H, W) floats in [0, 1])."""
        s = self.samples[idx]
        item = {"emotion": s["emotion_idx"]}
        if self.modality == "audio":
            wav, sr = torchaudio.load(s["audio_path"])
            # squeeze(0) alone left multi-channel files 2-D (crop_audio crops by numel),
            # and the file's native sample rate was ignored.
            item["audio"] = self._to_mono(wav, sr, self.target_sr)
        elif self.modality == "video":
            frames = np.load(s["frames_path"])
            # (T, H, W, C) uint8 -> (T, C, H, W) float in [0, 1]
            item["video"] = torch.from_numpy(frames).permute(0, 3, 1, 2).float() / 255.0
        return item


def collate_fn(batch):
    """Collate EmotionDataset items into a batch dict.

    Returns:
        dict with "emotion" (tensor of label ints). Also contains "audio" (list of
        waveforms, which may differ in length) and/or "video". "video" is stacked
        into one tensor when all clips share a shape; otherwise it is a list of clips.
    """
    out = {"emotion": torch.tensor([b["emotion"] for b in batch])}
    if "audio" in batch[0]:
        out["audio"] = [b["audio"] for b in batch]
    if "video" in batch[0]:
        clips = [b["video"] for b in batch]
        # Clip length is only normalised later by crop_video, so frame counts may
        # differ here; torch.stack would raise on such a batch.
        if all(c.shape == clips[0].shape for c in clips):
            out["video"] = torch.stack(clips)
        else:
            out["video"] = clips
    return out

In [None]:
def crop_audio(wav, sr, duration):
    """Centre-crop or right-pad a waveform to exactly ``duration`` seconds.

    Args:
        wav: waveform tensor (sized by ``numel()``, so presumably 1-D).
        sr: sample rate in Hz.
        duration: target length in seconds.

    Returns:
        ``int(round(duration * sr))`` samples. Shorter input is zero-padded on the right.
        Longer input yields its middle slice, with the start index rounded down.
    """
    target_len = int(round(duration * sr))
    excess = wav.numel() - target_len
    if excess > 0:
        offset = excess // 2
        return wav[offset:offset + target_len]
    return torch.nn.functional.pad(wav, (0, -excess))


def crop_video(video, n_frames):
    """Return exactly ``n_frames`` frames from a clip indexed by frame along dim 0.

    A clip with at most ``n_frames`` frames is resampled in time by taking evenly
    spaced (rounded) indices over [0, T-1], repeating frames as needed. A longer
    clip is centre-cropped to a contiguous window, with the start index rounded down.
    """
    num = video.shape[0]
    if num > n_frames:
        first = (num - n_frames) // 2
        return video[first:first + n_frames]
    picks = torch.linspace(0, num - 1, n_frames).round().long()
    return video[picks]


def prepare_audio(batch, processor, window_s, device, sampling_rate=16000):
    """Crop a batch of waveforms to a fixed window and run the feature extractor on it.

    Args:
        batch: collated batch with "audio" (list of waveform tensors) and "emotion"
            (label tensor).
        processor: feature extractor, called as
            ``processor(list_of_arrays, sampling_rate=..., return_tensors="pt", ...)``.
        window_s: window length in seconds; each waveform is centre-cropped or
            right-padded to it by crop_audio.
        device: device the model inputs and labels are moved to.
        sampling_rate: sample rate of the waveforms (defaults to 16 kHz, the value
            previously hard-coded).

    Returns:
        (kwargs, labels): model keyword arguments ("input_values", plus
        "attention_mask" when the processor produces one) and the labels on ``device``.
    """
    # Same rounding as crop_audio. Plain int() truncates, so a product that lands just
    # below an integer gave a max_length one sample shorter than the crops, and the
    # processor then truncated every crop by one sample.
    n_samples = int(round(window_s * sampling_rate))
    wavs = [crop_audio(a, sampling_rate, window_s).numpy() for a in batch["audio"]]
    enc = processor(wavs, sampling_rate=sampling_rate, return_tensors="pt", padding=True,
                    truncation=True, max_length=n_samples)
    kwargs = {"input_values": enc["input_values"].to(device)}
    if "attention_mask" in enc:
        kwargs["attention_mask"] = enc["attention_mask"].to(device)
    return kwargs, batch["emotion"].to(device)


def prepare_video(batch, processor, n_frames, device):
    """Crop each clip to ``n_frames`` and convert the batch with the image processor.

    Each clip (assumed (T, C, H, W) floats already scaled to [0, 1], hence
    ``do_rescale=False``) becomes a list of per-frame (H, W, C) numpy arrays before
    being handed to ``processor``.

    Returns:
        ({"pixel_values": tensor on device}, labels on device)
    """
    videos = []
    for clip in (crop_video(v, n_frames) for v in batch["video"]):
        frames_hwc = clip.permute(0, 2, 3, 1)  # (T, C, H, W) -> (T, H, W, C)
        videos.append(list(frames_hwc.numpy()))
    pixel_values = processor(videos, return_tensors="pt", do_rescale=False)["pixel_values"]
    labels = batch["emotion"].to(device)
    return {"pixel_values": pixel_values.to(device)}, labels

In [None]:
@torch.no_grad()
def collect_predictions(model, loader, prep_fn):
    """Run ``model`` over ``loader`` and gather per-sample predictions, labels and losses.

    Args:
        model: callable returning an object with ``.logits``; switched to eval mode first.
        loader: iterable of batches.
        prep_fn: maps a batch to ``(model_kwargs, label_tensor)``.

    Returns:
        (preds, labels, losses) as numpy arrays in loader order. ``preds`` is the
        argmax over logits; ``losses`` is the unreduced cross-entropy per sample.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction="none")
    preds, labels, losses = [], [], []
    for batch in loader:
        model_inputs, targets = prep_fn(batch)
        # Mixed precision only when running on CUDA.
        with autocast("cuda", enabled=DEVICE == "cuda"):
            logits = model(**model_inputs).logits
            per_sample = criterion(logits, targets)
        preds += logits.argmax(dim=1).cpu().tolist()
        labels += targets.cpu().tolist()
        losses += per_sample.cpu().tolist()
    return tuple(np.array(values) for values in (preds, labels, losses))


def per_emotion_report(preds, labels, losses, title):
    """Print a per-emotion evaluation report; return the confusion matrix and per-class F1.

    Args:
        preds: predicted class indices (positions in EMOTIONS), array-like.
        labels: true class indices, aligned with ``preds``.
        losses: per-sample loss values, aligned with ``preds``.
        title: heading printed above the report.

    Returns:
        (cm, per_f1): a len(EMOTIONS) x len(EMOTIONS) confusion matrix (rows = true
        class, columns = predicted class) and per-class F1 scores in EMOTIONS order.
    """
    preds = np.asarray(preds)
    labels = np.asarray(labels)
    losses = np.asarray(losses)
    class_ids = list(range(len(EMOTIONS)))

    print(f"\n{'='*70}")
    print(f"  {title}")
    print(f"{'='*70}\n")

    # labels= must be explicit. Without it sklearn uses only the classes occurring in
    # labels/preds and raises ValueError when that count differs from
    # len(target_names), e.g. when an emotion is missing from a small split.
    print(classification_report(
        labels, preds, labels=class_ids, target_names=EMOTIONS, digits=3, zero_division=0))

    cm = confusion_matrix(labels, preds, labels=class_ids)

    print(f"\n{'Emotion':<12s} {'N':>5s} {'Acc':>6s} {'F1':>6s} {'AvgLoss':>8s}")
    print("-" * 40)
    per_f1 = f1_score(labels, preds, average=None, labels=class_ids, zero_division=0)
    support = np.array([(labels == i).sum() for i in class_ids])
    for i, emo in enumerate(EMOTIONS):
        if support[i] == 0:
            continue
        mask = labels == i
        acc = (preds[mask] == i).mean()  # fraction of true-class-i samples predicted as i
        avg_loss = losses[mask].mean()
        print(f"{emo:<12s} {support[i]:5d} {acc:6.3f} {per_f1[i]:6.3f} {avg_loss:8.3f}")

    # Rank only emotions present in this split. Absent ones get F1 = 0 from
    # zero_division=0 and would otherwise be listed as the "weakest".
    ranked = [i for i in np.argsort(per_f1, kind="stable") if support[i] > 0]
    print(f"\nWeakest emotions (by F1): ", end="")
    print(" < ".join(f"{EMOTIONS[i]} ({per_f1[i]:.3f})" for i in ranked[:3]))

    return cm, per_f1

In [None]:
def plot_confusion_matrix(cm, title):
    """Show a row-normalised confusion-matrix heatmap with raw counts overlaid.

    Each cell is annotated with the fraction of its true-class row (rows with no
    samples stay at 0). Where the count is non-zero, the raw count is also written
    in grey just below it.
    """
    import matplotlib.pyplot as plt
    import seaborn as sns

    totals = cm.sum(axis=1, keepdims=True)
    totals[totals == 0] = 1  # avoid 0/0 for emotions with no true samples
    normalised = cm / totals

    fig, ax = plt.subplots(figsize=(9, 7))
    sns.heatmap(normalised, annot=True, fmt=".2f", cmap="Blues",
                xticklabels=EMOTIONS, yticklabels=EMOTIONS, ax=ax,
                vmin=0, vmax=1, linewidths=0.5)

    n_classes = len(EMOTIONS)
    nonzero_cells = ((r, c) for r in range(n_classes) for c in range(n_classes) if cm[r, c] > 0)
    for r, c in nonzero_cells:
        # Horizontally centred in cell (r, c), slightly below its vertical centre.
        ax.text(c + 0.5, r + 0.72, f"({cm[r, c]})",
                ha="center", va="center", fontsize=7, color="gray")

    ax.set(xlabel="Predicted", ylabel="True", title=title)
    plt.tight_layout()
    plt.show()

In [None]:
# Evaluate the selected audio checkpoint (BEST_AUDIO) on the validation split.
val_audio = EmotionDataset(METADATA, "val", "audio")
audio_loader = DataLoader(val_audio, batch_size=8, shuffle=False, collate_fn=collate_fn)

# Both the HuBERT classification model and its feature extractor are loaded from the same checkpoint directory.
model_path = str(ENCODERS_DIR / BEST_AUDIO)
audio_model = HubertForSequenceClassification.from_pretrained(model_path).to(DEVICE)
audio_processor = Wav2Vec2FeatureExtractor.from_pretrained(model_path)
# 3.0 s crop window; presumably the window the checkpoint was trained with
# (cf. the "w3s" suffix of BEST_AUDIO) — TODO confirm.
audio_prep = partial(prepare_audio, processor=audio_processor, window_s=3.0, device=DEVICE)

print(f"Loaded: {BEST_AUDIO}")
print(f"Val samples: {len(val_audio)}")

a_preds, a_labels, a_losses = collect_predictions(audio_model, audio_loader, audio_prep)
a_cm, a_f1 = per_emotion_report(a_preds, a_labels, a_losses, f"AUDIO — {BEST_AUDIO}")
plot_confusion_matrix(a_cm, f"Audio Confusion Matrix — {BEST_AUDIO}")

# Drop the audio model and release cached CUDA memory before the video encoder is loaded.
del audio_model
torch.cuda.empty_cache()

In [None]:
# Evaluate the selected video checkpoint (BEST_VIDEO) on the validation split.
# Batch size is kept at 2; presumably to bound GPU memory for 16-frame clips — TODO confirm.
val_video = EmotionDataset(METADATA, "val", "video")
video_loader = DataLoader(val_video, batch_size=2, shuffle=False, collate_fn=collate_fn)

# Both the TimeSformer classification model and its image processor are loaded from the same checkpoint directory.
model_path = str(ENCODERS_DIR / BEST_VIDEO)
video_model = TimesformerForVideoClassification.from_pretrained(model_path).to(DEVICE)
video_processor = AutoImageProcessor.from_pretrained(model_path)
# 16 frames per clip; presumably matches the "16f" suffix of BEST_VIDEO — TODO confirm.
video_prep = partial(prepare_video, processor=video_processor, n_frames=16, device=DEVICE)

print(f"Loaded: {BEST_VIDEO}")
print(f"Val samples: {len(val_video)}")

v_preds, v_labels, v_losses = collect_predictions(video_model, video_loader, video_prep)
v_cm, v_f1 = per_emotion_report(v_preds, v_labels, v_losses, f"VIDEO — {BEST_VIDEO}")
plot_confusion_matrix(v_cm, f"Video Confusion Matrix — {BEST_VIDEO}")

# Drop the video model and release cached CUDA memory.
del video_model
torch.cuda.empty_cache()

In [None]:
import matplotlib.pyplot as plt

# Grouped bar chart: an (audio, video) pair of bars per emotion, each labelled with its F1.
fig, ax = plt.subplots(figsize=(10, 5))
x = np.arange(len(EMOTIONS))  # one slot per emotion
w = 0.35                      # bar width; the pair sits at x - w/2 and x + w/2

bar_series = (
    (-w / 2, a_f1, f"Audio ({BEST_AUDIO})", "#4C72B0"),
    (+w / 2, v_f1, f"Video ({BEST_VIDEO})", "#DD8452"),
)
for offset, scores, series_label, series_color in bar_series:
    ax.bar(x + offset, scores, w, label=series_label, color=series_color)

ax.set_xticks(x)
ax.set_xticklabels(EMOTIONS, rotation=30, ha="right")
ax.set_ylabel("F1 Score")
ax.set_title("Per-Emotion F1: Audio vs Video Encoder")
ax.set_ylim(0, 1)
ax.legend()
ax.grid(axis="y", alpha=0.3)

# Value labels just above each bar (audio then video, per emotion).
for i in x:
    for offset, scores, _series_label, _series_color in bar_series:
        ax.text(i + offset, scores[i] + 0.02, f"{scores[i]:.2f}", ha="center", fontsize=8)

plt.tight_layout()
plt.show()

# Cross-modal comparison: absolute per-emotion F1 gap and which encoder scores higher.
print("\n" + "="*60)
print("CROSS-MODAL EMOTION GAPS")
print("="*60)
print(f"{'Emotion':<12s} {'Audio F1':>9s} {'Video F1':>9s} {'Gap':>7s}")
print("-" * 40)
gaps = []
for i, emo in enumerate(EMOTIONS):
    gap = abs(a_f1[i] - v_f1[i])
    # Equal scores (e.g. both 0.0 or both 1.0) used to be credited to video by the
    # bare ">" comparison; mark them as ties instead.
    if np.isclose(a_f1[i], v_f1[i]):
        better = "="
    elif a_f1[i] > v_f1[i]:
        better = "A"
    else:
        better = "V"
    gaps.append((emo, gap, better))
    print(f"{emo:<12s} {a_f1[i]:9.3f} {v_f1[i]:9.3f} {gap:7.3f} ({better})")

print("\nLargest cross-modal gaps:")
for emo, gap, better in sorted(gaps, key=lambda entry: -entry[1])[:3]:
    verdict = {"A": "audio stronger", "V": "video stronger"}.get(better, "tie")
    print(f"  {emo}: {gap:.3f} ({verdict})")