In [20]:
!pip install -q face-alignment torchaudio

In [21]:
import json
import subprocess
import warnings
from collections import Counter
from dataclasses import dataclass
from pathlib import Path

import cv2
import face_alignment
import numpy as np
import torch
import torchaudio
from tqdm import tqdm

# Suppress every Python warning for the rest of the session. No category or
# module filter is given, so warnings from all libraries used below are hidden.
warnings.filterwarnings("ignore")


@dataclass
class Config:
    """Settings shared by the preprocessing cells of this notebook."""

    # Root folder scanned for "Actor_*/*.mp4" source videos.
    raw_dir: Path = Path("/content/raw_data")
    # Root folder receiving the "audio" and "frames" sub-folders and metadata.json.
    out_dir: Path = Path("/content/processed_data")
    # Sample rate (Hz) passed to ffmpeg when extracting the mono wav track.
    sample_rate: int = 16000
    # Side length, in pixels, of the square face crops.
    frame_size: int = 224
    # Number of frames sampled from each video.
    n_frames: int = 32
    # Padding around the landmarks, as a multiple of the distance between the eye centers.
    face_pad_factor: float = 1.2
    # Lower bound, in pixels, on the padding around the landmarks.
    min_face_pad: int = 40
    # Fraction of actors assigned to the validation split.
    val_ratio: float = 0.15
    # Fraction of actors assigned to the test split.
    test_ratio: float = 0.15
    # Seed of the random generator that shuffles actors before splitting.
    split_seed: int = 42


# Emotion code (third field of a RAVDESS file name) -> emotion name.
EMOTIONS = dict(
    zip(
        ("01", "02", "03", "04", "05", "06", "07", "08"),
        ("neutral", "calm", "happy", "sad", "angry", "fearful", "disgust", "surprised"),
    )
)
# Emotion name -> class index, following the insertion order of EMOTIONS.
EMOTION_TO_IDX = {name: index for index, name in enumerate(EMOTIONS.values())}

# Shared configuration instance built from the default values declared on Config.
cfg = Config()

In [22]:
def parse_ravdess(path: Path) -> "dict | None":
    """Parse the metadata encoded in a RAVDESS file name.

    File stems follow the pattern
    Modality-Channel-Emotion-Intensity-Statement-Repetition-Actor.

    Args:
        path: Path to a media file; only its stem is parsed.

    Returns:
        A dict with the keys "video_path", "sample_id", "modality",
        "channel", "emotion", "emotion_idx", "intensity", "actor_id" and
        "gender", or None when the stem does not consist of seven
        dash-separated fields or its intensity/actor fields are not integers.
        An emotion code missing from EMOTIONS yields "unknown" with index -1.
    """
    parts = path.stem.split("-")
    if len(parts) != 7:
        return None
    try:
        intensity = int(parts[3])
        actor_id = int(parts[6])
    except ValueError:
        # A non-numeric field is just another malformed name: report it as
        # unparseable instead of aborting the caller's directory scan.
        return None
    emotion = EMOTIONS.get(parts[2], "unknown")
    return {
        "video_path": str(path),
        "sample_id": path.stem,
        "modality": parts[0],
        "channel": parts[1],
        "emotion": emotion,
        "emotion_idx": EMOTION_TO_IDX.get(emotion, -1),
        "intensity": intensity,
        "actor_id": actor_id,
        # Even actor ids are labelled female, odd ones male.
        "gender": "female" if actor_id % 2 == 0 else "male",
    }


def collect_av_speech_samples(raw_dir: Path) -> list[dict]:
    """Gather metadata for every clip with modality "01" and channel "01".

    Videos are looked up as "Actor_*/*.mp4" below ``raw_dir`` and visited in
    sorted path order; files whose names cannot be parsed are skipped.
    """
    parsed = (parse_ravdess(video) for video in sorted(raw_dir.glob("Actor_*/*.mp4")))
    return [
        entry
        for entry in parsed
        if entry and entry["modality"] == "01" and entry["channel"] == "01"
    ]


def assign_splits(samples: list[dict], val_ratio: float, test_ratio: float, seed: int):
    """Tag every sample in place with a "split" of "train", "val" or "test".

    The split is made per actor, so all clips of one actor land in the same
    split. Actors are shuffled with a seeded generator; the first slice goes
    to test, the next to val and the remainder to train. At least one actor
    is requested for test and for val.
    """
    actor_order = sorted({sample["actor_id"] for sample in samples})
    np.random.RandomState(seed).shuffle(actor_order)

    test_count = max(1, round(len(actor_order) * test_ratio))
    val_count = max(1, round(len(actor_order) * val_ratio))

    # Only held-out actors are listed; everyone else defaults to "train".
    held_out = {actor: "test" for actor in actor_order[:test_count]}
    held_out.update(
        (actor, "val") for actor in actor_order[test_count:test_count + val_count]
    )

    for sample in samples:
        sample["split"] = held_out.get(sample["actor_id"], "train")


# Discover the audio-visual speech clips and tag each one with its split.
samples = collect_av_speech_samples(cfg.raw_dir)
assign_splits(samples, cfg.val_ratio, cfg.test_ratio, cfg.split_seed)

# Report the overall size and the number of clips per split.
print(f"Total: {len(samples)} AV speech samples")
split_sizes = Counter(s["split"] for s in samples)
for split in ("train", "val", "test"):
    n = split_sizes[split]
    print(f"  {split}: {n}")

Total: 1380 AV speech samples
  train: 1020
  val: 180
  test: 180


In [23]:
def extract_audio(video_path: str, out_path: Path, sr: int) -> float:
    """Extract the audio track of a video as a mono wav file via ffmpeg.

    ffmpeg writes to a temporary sibling file that is renamed to ``out_path``
    only after it exits successfully, so a failed or interrupted run never
    leaves a truncated file at ``out_path``.

    Args:
        video_path: Path of the source video.
        out_path: Destination wav file.
        sr: Target sample rate in Hz.

    Returns:
        Duration of the extracted audio in seconds.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
        FileNotFoundError: If the ffmpeg executable cannot be found.
    """
    out_path = Path(out_path)
    # Keep the original suffix last so ffmpeg still infers the output format from it.
    tmp_path = out_path.with_name(f"{out_path.stem}.partial{out_path.suffix}")
    try:
        subprocess.run(
            ["ffmpeg", "-y", "-loglevel", "error", "-i", video_path,
             "-ar", str(sr), "-ac", "1", str(tmp_path)],
            check=True,
        )
        tmp_path.replace(out_path)
    finally:
        # No-op after a successful rename; removes leftovers after a failure.
        tmp_path.unlink(missing_ok=True)
    info = torchaudio.info(str(out_path))
    return info.num_frames / info.sample_rate


audio_dir = cfg.out_dir / "audio"
audio_dir.mkdir(parents=True, exist_ok=True)

# Extract one wav per sample and record its path and duration on the sample.
for s in tqdm(samples, desc="Extracting audio"):
    wav_path = audio_dir / f"{s['sample_id']}.wav"
    s["audio_path"] = str(wav_path)
    # Both branches run inside the try block so that a single unreadable
    # cached file is reported like any other failure instead of aborting the run.
    try:
        if wav_path.exists():
            # Already extracted by an earlier run: only read the duration back.
            info = torchaudio.info(str(wav_path))
            s["duration"] = info.num_frames / info.sample_rate
        else:
            s["duration"] = extract_audio(s["video_path"], wav_path, cfg.sample_rate)
    except Exception as e:
        print(f"FAILED {s['sample_id']}: {e}")
        s["duration"] = 0.0

Extracting audio: 100%|██████████| 1380/1380 [03:02<00:00,  7.56it/s]


In [24]:
class FaceProcessor:
    """Crop a face out of a frame and resize it to a fixed square size.

    Landmarks come from a ``face_alignment`` 2D detector that runs on CUDA
    when available and on the CPU otherwise.
    """

    def __init__(self, size: int, pad_factor: float, min_pad: int):
        """Store the crop settings and build the landmark detector.

        Args:
            size: Side length of the square output crop, in pixels.
            pad_factor: Padding as a multiple of the distance between eye centers.
            min_pad: Lower bound on the padding, in pixels.
        """
        self.size = size
        self.pad_factor = pad_factor
        self.min_pad = min_pad
        self.fa = face_alignment.FaceAlignment(
            face_alignment.LandmarksType.TWO_D,
            flip_input=False,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )

    def crop_face(self, frame: np.ndarray) -> tuple[np.ndarray, bool]:
        """Return ``(crop, detected)`` for one frame.

        The crop is the padded bounding box of the first detected face's
        landmarks, clipped to the frame. When no 68-point landmark set is
        found, or the box is empty, the centered square crop is returned
        with ``detected`` set to False.
        """
        detections = self.fa.get_landmarks(frame)
        if not detections or detections[0].shape[0] != 68:
            return self._center_crop(frame), False

        points = detections[0]
        height, width = frame.shape[:2]
        xs, ys = points[:, 0], points[:, 1]

        # Slices 36:42 and 42:48 are presumably the two eye contours of the
        # 68-point layout; the distance between their centers scales the padding.
        first_eye = points[36:42].mean(0)
        second_eye = points[42:48].mean(0)
        eye_distance = np.linalg.norm(first_eye - second_eye)
        margin = int(max(self.min_pad, self.pad_factor * eye_distance))

        left = max(0, int(xs.min()) - margin)
        top = max(0, int(ys.min()) - margin)
        right = min(width, int(xs.max()) + margin)
        bottom = min(height, int(ys.max()) + margin)

        region = frame[top:bottom, left:right]
        if region.size == 0:
            return self._center_crop(frame), False
        return cv2.resize(region, (self.size, self.size)), True

    def _center_crop(self, frame: np.ndarray) -> np.ndarray:
        """Return the largest centered square of ``frame``, resized to the output size."""
        height, width = frame.shape[:2]
        side = min(height, width)
        top = (height - side) // 2
        left = (width - side) // 2
        square = frame[top:top + side, left:left + side]
        return cv2.resize(square, (self.size, self.size))


def sample_indices(total: int, n: int) -> np.ndarray:
    """Pick ``n`` frame indices out of a video with ``total`` frames.

    Longer videos are sampled at evenly spaced positions from first to last
    frame. Shorter ones use every frame and repeat the last index to reach
    ``n``. A non-positive ``total`` yields ``n`` zeros.
    """
    if total < 1:
        return np.zeros(n, dtype=int)
    if n >= total:
        every_frame = np.arange(total)
        missing = n - total
        return np.pad(every_frame, (0, missing), mode="edge")
    last_frame = total - 1
    return np.linspace(0, last_frame, n, dtype=int)


def process_video(path: str, fp: FaceProcessor, n_frames: int) -> tuple[np.ndarray, float]:
    """Sample frames from a video and crop the face in each of them.

    Args:
        path: Path of the video file.
        fp: Face processor used to crop every sampled frame.
        n_frames: Number of frames to sample.

    Returns:
        A tuple ``(frames, det_rate)``: ``frames`` stacks the uint8 crops
        along a new first axis, and ``det_rate`` is the fraction of sampled
        frames in which a face was detected.

    Raises:
        OSError: If the video cannot be opened. Without this check an
            unreadable file would silently produce all-black frames.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video: {path}")
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        indices = sample_indices(total, n_frames)
        frames, hits = [], 0
        for idx in indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, bgr = cap.read()
            if not ok:
                # Best effort: a frame that cannot be decoded becomes a black placeholder.
                frames.append(np.zeros((fp.size, fp.size, 3), dtype=np.uint8))
                continue
            face, detected = fp.crop_face(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            hits += int(detected)
            frames.append(face.astype(np.uint8))
    finally:
        # Release the capture even when decoding or face detection raises.
        cap.release()
    return np.stack(frames), hits / max(len(indices), 1)


fp = FaceProcessor(cfg.frame_size, cfg.face_pad_factor, cfg.min_face_pad)
frames_dir = cfg.out_dir / "frames"
frames_dir.mkdir(parents=True, exist_ok=True)

# Crop faces for every sample and store them as one .npy file per clip.
for s in tqdm(samples, desc="Processing video"):
    npy_path = frames_dir / f"{s['sample_id']}.npy"
    s["frames_path"] = str(npy_path)
    if npy_path.exists():
        # Cached by an earlier run: the detection rate is unknown, marked with -1.
        s["face_det_rate"] = -1.0
        continue
    # Save under a temporary name and rename afterwards, so an interrupted run
    # never leaves a truncated file that a later run would treat as cached.
    tmp_path = frames_dir / f"{s['sample_id']}.partial.npy"
    try:
        frames, det_rate = process_video(s["video_path"], fp, cfg.n_frames)
        np.save(tmp_path, frames)
        tmp_path.replace(npy_path)
        s["face_det_rate"] = det_rate
    except Exception as e:
        tmp_path.unlink(missing_ok=True)
        print(f"FAILED {s['sample_id']}: {e}")
        s["face_det_rate"] = 0.0

Processing video: 100%|██████████| 1380/1380 [1:03:53<00:00,  2.78s/it]


In [25]:
# Persist the per-sample metadata next to the processed audio and frames.
meta_path = cfg.out_dir / "metadata.json"
with open(meta_path, "w") as f:
    json.dump(samples, f, indent=2)
print(f"Saved {len(samples)} samples -> {meta_path}\n")

# Per-split summary: emotion counts, mean face detection rate, mean duration.
for split in ("train", "val", "test"):
    sub = [s for s in samples if s["split"] == split]
    print(f"{split} ({len(sub)} samples)")

    emotion_counts = Counter(s["emotion"] for s in sub)
    for emo in sorted(emotion_counts):
        cnt = emotion_counts[emo]
        print(f"  {emo:12s} {cnt}")

    # Negative rates mark samples whose frames were reused from cache.
    dets = [rate for rate in (s["face_det_rate"] for s in sub) if rate >= 0]
    durs = [s.get("duration", 0) for s in sub]
    if len(dets) > 0:
        print(f"  face_det:    {np.mean(dets):.1%}")
    print(f"  duration:    {np.mean(durs):.1f}s avg\n")

Saved 1380 samples -> /content/processed_data/metadata.json

train (1020 samples)
  angry        136
  calm         136
  disgust      136
  fearful      136
  happy        136
  neutral      68
  sad          136
  surprised    136
  face_det:    98.9%
  duration:    3.7s avg

val (180 samples)
  angry        24
  calm         24
  disgust      24
  fearful      24
  happy        24
  neutral      12
  sad          24
  surprised    24
  face_det:    99.0%
  duration:    3.5s avg

test (180 samples)
  angry        24
  calm         24
  disgust      24
  fearful      24
  happy        24
  neutral      12
  sad          24
  surprised    24
  face_det:    99.0%
  duration:    3.8s avg



In [26]:
class RAVDESSDataset(torch.utils.data.Dataset):
    """Dataset over the preprocessed clips listed in a metadata JSON file.

    Every item is a dict with "sample_id" and "emotion" (the class index),
    plus, depending on ``modality``:

    * "audio": the waveform loaded from the sample's wav file with its
      leading channel dimension squeezed out;
    * "video": a float tensor of shape (T, 3, H, W) scaled to [0, 1].
    """

    # Accepted values of the ``modality`` argument.
    MODALITIES = ("audio", "video", "both")

    def __init__(self, metadata_path: str, split: str = "train", modality: str = "both"):
        """Load the metadata and keep the samples of one split.

        Args:
            metadata_path: Path of the metadata JSON file.
            split: Name of the split to keep.
            modality: One of "audio", "video" or "both".

        Raises:
            ValueError: If ``modality`` is not a supported value, or if no
                sample belongs to ``split``.
        """
        # Reject typos early: an unknown modality would otherwise yield items
        # that silently carry no audio and no video tensor.
        if modality not in self.MODALITIES:
            raise ValueError(
                f"Unknown modality='{modality}', expected one of {self.MODALITIES}"
            )
        with open(metadata_path) as f:
            data = json.load(f)
        self.samples = [s for s in data if s["split"] == split]
        self.modality = modality
        if not self.samples:
            raise ValueError(f"No samples for split='{split}'")

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load and return the item at position ``idx``."""
        s = self.samples[idx]
        item = {"sample_id": s["sample_id"], "emotion": s["emotion_idx"]}

        if self.modality in ("audio", "both"):
            wav, _ = torchaudio.load(s["audio_path"])
            item["audio"] = wav.squeeze(0)

        if self.modality in ("video", "both"):
            frames = np.load(s["frames_path"])  # (T, H, W, 3) uint8
            # Channels first, then scale the 8-bit values to [0, 1].
            item["video"] = torch.from_numpy(frames).permute(0, 3, 1, 2).float() / 255.0

        return item


def collate_fn(batch):
    """Merge dataset items into one batch dict.

    Sample ids stay a list and emotions become one tensor. Audio clips are
    kept as a list of tensors rather than stacked; video tensors are stacked
    along a new first dimension. Which modalities are present is decided
    from the first item of the batch.
    """
    first = batch[0]
    ids = [entry["sample_id"] for entry in batch]
    labels = [entry["emotion"] for entry in batch]
    collated = {"sample_id": ids, "emotion": torch.tensor(labels)}

    if "audio" in first:
        collated["audio"] = [entry["audio"] for entry in batch]
    if "video" in first:
        clips = [entry["video"] for entry in batch]
        collated["video"] = torch.stack(clips)
    return collated

In [27]:
# Smoke test: build the training dataset from the metadata written above.
ds = RAVDESSDataset(str(meta_path), split="train", modality="both")
print(f"Train: {len(ds)} samples\n")

# Inspect the first item: id, class index and the shapes of both modalities.
sample = ds[0]
print(f"sample_id: {sample['sample_id']}")
print(f"emotion:   {sample['emotion']}")
print(f"audio:     {sample['audio'].shape}")
print(f"video:     {sample['video'].shape}")

# Pull one shuffled batch through the custom collate function.
loader = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
batch = next(iter(loader))
print(f"\nBatch: {len(batch['audio'])} audio, video {batch['video'].shape}, emotions {batch['emotion']}")

Train: 1020 samples

sample_id: 01-01-01-01-01-01-02
emotion:   0
audio:     torch.Size([58368])
video:     torch.Size([32, 3, 224, 224])

Batch: 4 audio, video torch.Size([4, 32, 3, 224, 224]), emotions tensor([6, 7, 4, 7])
