In [None]:
# Use the %pip magic (not !pip) so the package is installed into the
# environment of the running kernel rather than whatever `pip` is on PATH.
%pip install -q face-alignment



In [None]:
# %pip targets the running kernel's environment (bare `pip` only works via automagic).
# moviepy is held below 2.0: this notebook imports `moviepy.editor` and passes
# `verbose=` to write_audiofile, both of which were removed in moviepy 2.x.
%pip install -q torch torchaudio opencv-python numpy face-alignment "moviepy<2" tqdm



In [None]:
"""
RAVDESS Dataset Preprocessing and PyTorch Dataset
Prepares audio-visual emotion data for talking face generation research.

Directory Structure:
    raw_data/
    ├── Actor_01/
    │   ├── 01-01-01-01-01-01-01.mp4
    │   └── ...
    ├── Actor_02/
    └── ...

Author: Cross-Modal Emotion Synchronization Research
"""

import json
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Union, Optional
from dataclasses import dataclass

import cv2
import numpy as np
import torch
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from moviepy.editor import VideoFileClip
import face_alignment


# ============================================================================
# Configuration
# ============================================================================

@dataclass
class PreprocessingConfig:
    """
    Configuration for RAVDESS preprocessing.

    Field order is part of the positional constructor signature generated by
    ``@dataclass``; append new fields instead of inserting them.
    """
    # Sample rate (Hz) handed to AudioProcessor; extracted audio is saved at this rate.
    target_sample_rate: int = 16000
    # Output size (pixels) that every face crop is resized to.
    video_height: int = 224
    video_width: int = 224
    # Number of frames sampled from each clip and saved to disk.
    num_frames_per_clip: int = 32
    # NOTE(review): not referenced by any code visible in this section — confirm
    # whether it is consumed elsewhere or is dead configuration.
    min_face_detection_confidence: float = 0.5
    # Crop padding is max(min_face_padding_pixels,
    # face_padding_multiplier * inter-ocular distance), in pixels.
    face_padding_multiplier: float = 1.2
    min_face_padding_pixels: int = 40
    # Landmark count expected per face; any other count triggers the center-crop fallback.
    num_facial_landmarks: int = 68


# RAVDESS emotion field (third filename component, zero-padded) -> emotion name.
EMOTION_CODES = {
    f'{code_number:02d}': emotion_name
    for code_number, emotion_name in enumerate(
        (
            'neutral',
            'calm',
            'happy',
            'sad',
            'angry',
            'fearful',
            'disgust',
            'surprised',
        ),
        start=1,
    )
}

# Emotion name -> zero-based class index, in ascending RAVDESS code order
# ('01' -> 0, ..., '08' -> 7).
EMOTION_TO_INDEX = {
    EMOTION_CODES[code]: class_index
    for class_index, code in enumerate(sorted(EMOTION_CODES))
}


# ============================================================================
# Utility Functions
# ============================================================================

def sample_frame_indices(
    total_frames: int,
    target_num_frames: int
) -> np.ndarray:
    """
    Pick target_num_frames frame indices spread evenly over a clip.

    Clips with at most target_num_frames frames keep every frame once and
    repeat the final frame index until the target length is reached.

    Args:
        total_frames: Total number of frames in the video
        target_num_frames: Desired number of frames to sample

    Returns:
        Integer array of length target_num_frames with the indices to extract
        (all zeros, with a warning, when total_frames <= 0)
    """
    if total_frames <= 0:
        warnings.warn("Total frames <= 0, returning zero indices")
        return np.zeros(target_num_frames, dtype=int)

    last_index = total_frames - 1

    if total_frames > target_num_frames:
        # Enough frames: evenly spaced positions, truncated to integers
        return np.linspace(0, last_index, target_num_frames, dtype=int)

    # Too few frames: 0, 1, ..., last_index, then last_index repeated as padding
    return np.minimum(np.arange(target_num_frames), last_index)


def ensure_directory_exists(path: Union[str, Path]) -> Path:
    """
    Make sure a directory (and any missing parents) exists.

    Args:
        path: Directory location as a string or Path

    Returns:
        The directory as a Path object
    """
    directory = Path(path)
    # exist_ok makes repeated calls harmless
    directory.mkdir(parents=True, exist_ok=True)
    return directory


def validate_video_has_audio(video_path: Path) -> bool:
    """
    Check that a video file can be opened and contains an audio track.

    Args:
        video_path: Path to video file

    Returns:
        True when the file opens and has an audio track. This function never
        returns False; every failure is reported by raising.

    Raises:
        ValueError: If the video has no audio track, or if it cannot be
            opened/inspected (the underlying error is chained as __cause__)
    """
    try:
        with VideoFileClip(str(video_path)) as clip:
            has_audio = clip.audio is not None
    except Exception as error:
        raise ValueError(
            f"Cannot validate audio in {video_path.name}: {str(error)}"
        ) from error

    # Raised outside the try block so the message is not re-wrapped by the
    # generic "Cannot validate audio" handler above.
    if not has_audio:
        raise ValueError(
            f"Video file has no audio track: {video_path.name}"
        )

    return True


# ============================================================================
# RAVDESS Filename Parser
# ============================================================================

class RAVDESSFilenameParser:
    """
    Decode the seven dash-separated fields of a RAVDESS filename:
    Modality-VocalChannel-Emotion-Intensity-Statement-Repetition-Actor

    Example: 01-01-06-01-02-01-12.mp4
    """

    EXPECTED_PARTS = 7
    MODALITY_CODES = {'01': 'audio_video', '02': 'video_only', '03': 'audio_only'}
    VOCAL_CHANNEL_CODES = {'01': 'speech', '02': 'song'}

    @staticmethod
    def parse(filename: Union[str, Path]) -> Dict:
        """
        Turn a RAVDESS filename into a metadata dictionary.

        Only the file stem is inspected, so any leading directories and the
        extension are ignored. Codes missing from the lookup tables are
        reported as 'unknown' rather than rejected.

        Args:
            filename: RAVDESS format filename (string or Path)

        Returns:
            Dictionary with the descriptive name and raw code of modality,
            vocal channel and emotion, the integer intensity / statement /
            repetition / actor_id fields, the derived gender, and the
            bare filename

        Raises:
            ValueError: If the stem does not have exactly EXPECTED_PARTS
                fields, or a numeric field is not an integer
        """
        parser = RAVDESSFilenameParser
        file_path = Path(filename)
        fields = file_path.stem.split('-')

        if len(fields) != parser.EXPECTED_PARTS:
            raise ValueError(
                f"Invalid RAVDESS filename format: {file_path.name}. "
                f"Expected {parser.EXPECTED_PARTS} parts, "
                f"got {len(fields)}"
            )

        # First three fields stay as string codes, the remaining four are numeric
        modality_code, vocal_channel_code, emotion_code = fields[:3]
        intensity, statement, repetition, actor_id = map(int, fields[3:])

        # RAVDESS convention: odd actor ids are male, even actor ids are female
        is_even_actor = actor_id % 2 == 0

        return {
            'modality': parser.MODALITY_CODES.get(modality_code, 'unknown'),
            'modality_code': modality_code,
            'vocal_channel': parser.VOCAL_CHANNEL_CODES.get(
                vocal_channel_code, 'unknown'
            ),
            'vocal_channel_code': vocal_channel_code,
            'emotion': EMOTION_CODES.get(emotion_code, 'unknown'),
            'emotion_code': emotion_code,
            'intensity': intensity,
            'statement': statement,
            'repetition': repetition,
            'actor_id': actor_id,
            'gender': 'female' if is_even_actor else 'male',
            'filename': file_path.name
        }


# ============================================================================
# Audio Processor
# ============================================================================

class AudioProcessor:
    """Extract and preprocess audio from video files."""

    def __init__(self, target_sample_rate: int = 16000):
        """
        Args:
            target_sample_rate: Sample rate (Hz) of the saved WAV files
        """
        self.target_sample_rate = target_sample_rate

    def extract_audio_from_video(
        self,
        video_path: Path,
        output_audio_path: Path
    ) -> Dict:
        """
        Extract audio track from video and save as a mono WAV.

        The track is first written to a temporary ``<stem>.temp.wav`` next to
        the output file, then loaded, resampled to the target rate if needed,
        down-mixed to mono and saved to output_audio_path. The temporary file
        is removed even when extraction or loading fails, and a partially
        written output file is removed when saving fails.

        Args:
            video_path: Path to input video file
            output_audio_path: Path to save extracted audio

        Returns:
            Dictionary with audio metadata: duration_seconds, sample_rate,
            num_channels, num_samples

        Raises:
            ValueError: If video has no audio track or cannot be inspected
        """
        # Validate audio exists
        validate_video_has_audio(video_path)

        output_audio_path = Path(output_audio_path)
        output_audio_path.parent.mkdir(parents=True, exist_ok=True)
        temp_audio_path = output_audio_path.with_suffix(".temp.wav")

        try:
            # Extract audio using MoviePy
            with VideoFileClip(str(video_path)) as video_clip:
                video_clip.audio.write_audiofile(
                    str(temp_audio_path),
                    fps=self.target_sample_rate,
                    nbytes=2,
                    codec="pcm_s16le",
                    verbose=False,
                    logger=None
                )

            waveform, sample_rate = torchaudio.load(str(temp_audio_path))
        finally:
            # The waveform is in memory (or extraction failed); either way the
            # temp file must not be left behind.
            temp_audio_path.unlink(missing_ok=True)

        # Resample if needed
        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(
                sample_rate,
                self.target_sample_rate
            )
            waveform = resampler(waveform)

        # Convert to mono
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Save processed audio; drop a partial file on failure so a later
        # "output already exists" check cannot mistake it for a finished result.
        try:
            torchaudio.save(
                str(output_audio_path),
                waveform,
                self.target_sample_rate
            )
        except Exception:
            output_audio_path.unlink(missing_ok=True)
            raise

        return {
            'duration_seconds': waveform.shape[1] / self.target_sample_rate,
            'sample_rate': self.target_sample_rate,
            'num_channels': waveform.shape[0],
            'num_samples': waveform.shape[1]
        }


# ============================================================================
# Video Processor with Face Alignment
# ============================================================================

class VideoFaceProcessor:
    """
    Process video frames with face detection and alignment.

    Each sampled frame is cropped around the detected facial landmarks (or
    center-cropped when detection fails), resized to the configured size and
    saved as a float32 ``.npy`` array scaled to [0, 1].
    """

    def __init__(self, config: PreprocessingConfig):
        """
        Args:
            config: Preprocessing configuration (output size, padding,
                number of frames per clip, expected landmark count)
        """
        self.config = config
        # Stored as (width, height) because that is the order cv2.resize expects
        self.video_size = (config.video_width, config.video_height)

        # Initialize face alignment model, on GPU when one is available
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.face_aligner = face_alignment.FaceAlignment(
            face_alignment.LandmarksType.TWO_D,
            flip_input=False,
            device=device
        )

    def _compute_face_crop_region(
        self,
        frame_shape: Tuple[int, int],
        landmarks: np.ndarray
    ) -> Tuple[int, int, int, int]:
        """
        Compute face crop region with dynamic padding based on inter-ocular distance.

        Args:
            frame_shape: (height, width) of the original frame
            landmarks: (68, 2) array of facial landmarks

        Returns:
            (x_min, y_min, x_max, y_max) crop coordinates, clamped to the frame
        """
        frame_height, frame_width = frame_shape

        # Get bounding box from landmarks
        x_coords, y_coords = landmarks[:, 0], landmarks[:, 1]
        x_min, x_max = int(x_coords.min()), int(x_coords.max())
        y_min, y_max = int(y_coords.min()), int(y_coords.max())

        # Calculate inter-ocular distance for adaptive padding
        # (points 36-41 and 42-47 are the two eyes in the 68-point layout)
        left_eye_center = landmarks[36:42].mean(axis=0)
        right_eye_center = landmarks[42:48].mean(axis=0)
        inter_ocular_distance = np.linalg.norm(
            left_eye_center - right_eye_center
        )

        # Dynamic padding: larger faces get more padding, never below the minimum
        padding = int(max(
            self.config.min_face_padding_pixels,
            self.config.face_padding_multiplier * inter_ocular_distance
        ))

        # Apply padding with boundary checks
        x_min = max(0, x_min - padding)
        y_min = max(0, y_min - padding)
        x_max = min(frame_width, x_max + padding)
        y_max = min(frame_height, y_max + padding)

        return x_min, y_min, x_max, y_max

    def _fallback_center_crop(self, frame: np.ndarray) -> np.ndarray:
        """
        Fallback center crop when face detection fails.

        Args:
            frame: RGB frame

        Returns:
            Center-cropped region resized to the configured output size
        """
        height, width = frame.shape[:2]
        crop_width, crop_height = self.video_size

        # Center coordinates
        center_x, center_y = width // 2, height // 2
        half_crop_width = crop_width // 2
        half_crop_height = crop_height // 2

        # Calculate crop bounds
        x_min = max(0, center_x - half_crop_width)
        y_min = max(0, center_y - half_crop_height)
        x_max = min(width, center_x + half_crop_width)
        y_max = min(height, center_y + half_crop_height)

        cropped = frame[y_min:y_max, x_min:x_max]

        if cropped.size == 0:
            # Emergency fallback: return resized full frame
            cropped = frame

        return cv2.resize(cropped, self.video_size, interpolation=cv2.INTER_AREA)

    def detect_and_crop_face(
        self,
        frame_rgb: np.ndarray
    ) -> Tuple[np.ndarray, bool]:
        """
        Detect face landmarks and crop face region.

        Args:
            frame_rgb: RGB frame (H, W, 3)

        Returns:
            Tuple of (cropped_face, detection_success)
            - cropped_face: Resized face crop (H, W, 3)
            - detection_success: True if face was detected
        """
        # Attempt face detection
        detected_landmarks = self.face_aligner.get_landmarks(frame_rgb)

        # Explicit None/length test: plain truthiness would be ambiguous if
        # the detector ever returned an array instead of a list.
        if detected_landmarks is None or len(detected_landmarks) == 0:
            # No face detected - use fallback
            return self._fallback_center_crop(frame_rgb), False

        # Use first detected face
        landmarks = detected_landmarks[0]

        if landmarks.shape[0] != self.config.num_facial_landmarks:
            warnings.warn(
                f"Expected {self.config.num_facial_landmarks} landmarks, "
                f"got {landmarks.shape[0]}. Using fallback."
            )
            return self._fallback_center_crop(frame_rgb), False

        # Compute crop region
        x_min, y_min, x_max, y_max = self._compute_face_crop_region(
            frame_rgb.shape[:2],
            landmarks
        )

        # Crop and resize
        face_region = frame_rgb[y_min:y_max, x_min:x_max]

        if face_region.size == 0:
            return self._fallback_center_crop(frame_rgb), False

        cropped_face = cv2.resize(
            face_region,
            self.video_size,
            interpolation=cv2.INTER_AREA
        )

        return cropped_face, True

    def _read_face_frame(
        self,
        video_capture,
        source_frame_index: int
    ) -> Optional[Tuple[np.ndarray, bool]]:
        """
        Seek to one source frame and convert it into a normalized face crop.

        Args:
            video_capture: Opened cv2.VideoCapture
            source_frame_index: Index of the frame to decode

        Returns:
            (face crop as float32 in [0, 1], detection_success), or None when
            the frame could not be read
        """
        video_capture.set(cv2.CAP_PROP_POS_FRAMES, source_frame_index)
        success, frame_bgr = video_capture.read()

        if not success:
            return None

        # OpenCV decodes to BGR; the detector and saved frames use RGB
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        face_crop, detection_success = self.detect_and_crop_face(frame_rgb)

        # Normalize to [0, 1]
        return face_crop.astype(np.float32) / 255.0, detection_success

    def process_video_to_frames(
        self,
        video_path: Path,
        output_frames_dir: Path
    ) -> Dict:
        """
        Extract, detect faces, and save uniformly sampled frames.

        Frames are written as ``frame_00000.npy``, ``frame_00001.npy``, ...
        A frame that cannot be read is replaced by the previously saved frame
        (or a black frame if it is the first one). When the same source index
        is sampled twice in a row (clip shorter than the target length) the
        previous result is reused instead of decoding and detecting again.

        Args:
            video_path: Path to input video
            output_frames_dir: Directory to save frame arrays

        Returns:
            Dictionary with processing statistics: num_frames_saved,
            num_successful_face_detections, face_detection_rate,
            original_fps, frames_directory

        Raises:
            ValueError: If the video cannot be opened
        """
        ensure_directory_exists(output_frames_dir)

        # Open video
        video_capture = cv2.VideoCapture(str(video_path))
        if not video_capture.isOpened():
            video_capture.release()
            raise ValueError(f"Cannot open video: {video_path}")

        num_successful_detections = 0
        frames_saved = 0

        # try/finally so the capture handle is released even when detection
        # or saving raises part-way through the clip
        try:
            # Get video properties
            total_frames = int(video_capture.get(cv2.CAP_PROP_FRAME_COUNT))
            original_fps = video_capture.get(cv2.CAP_PROP_FPS) or 25.0

            # Determine frames to sample
            frame_indices = sample_frame_indices(
                total_frames,
                self.config.num_frames_per_clip
            )

            previous_source_index = None
            previous_frame = None  # last array written to disk
            previous_detected = False

            for output_index, source_frame_index in enumerate(frame_indices):
                source_frame_index = int(source_frame_index)

                is_repeat = (
                    previous_frame is not None
                    and source_frame_index == previous_source_index
                )

                if not is_repeat:
                    result = self._read_face_frame(
                        video_capture,
                        source_frame_index
                    )
                    if result is not None:
                        previous_frame, previous_detected = result
                    else:
                        # Frame read failed - keep previous frame, or start
                        # with a black frame when nothing was saved yet
                        previous_detected = False
                        if previous_frame is None:
                            previous_frame = np.zeros(
                                (
                                    self.config.video_height,
                                    self.config.video_width,
                                    3
                                ),
                                dtype=np.float32
                            )
                    previous_source_index = source_frame_index

                if previous_detected:
                    num_successful_detections += 1

                # Save as numpy array
                np.save(
                    output_frames_dir / f"frame_{output_index:05d}.npy",
                    previous_frame
                )
                frames_saved += 1
        finally:
            video_capture.release()

        return {
            'num_frames_saved': frames_saved,
            'num_successful_face_detections': num_successful_detections,
            'face_detection_rate': (
                num_successful_detections / len(frame_indices)
                if len(frame_indices) > 0 else 0.0
            ),
            'original_fps': float(original_fps),
            'frames_directory': str(output_frames_dir)
        }


# ============================================================================
# Main Preprocessor
# ============================================================================

class RAVDESSPreprocessor:
    """
    Main preprocessor for RAVDESS dataset.
    Extracts audio and video with face alignment.

    Expected directory structure:
        raw_data/
        ├── Actor_01/
        │   ├── 01-01-01-01-01-01-01.mp4
        │   └── ...
        ├── Actor_02/
        └── ...
    """

    # Sidecar file stored next to each clip's frames so that real detection
    # statistics survive across runs that reuse already-extracted frames.
    _VIDEO_STATS_FILENAME = 'processing_stats.json'

    def __init__(
        self,
        dataset_root: Union[str, Path],
        output_root: Union[str, Path],
        config: Optional[PreprocessingConfig] = None
    ):
        """
        Args:
            dataset_root: Root directory containing Actor_XX folders
            output_root: Directory to save processed data
            config: Preprocessing configuration (uses defaults if None)

        Raises:
            ValueError: If dataset_root has no Actor_* folders or no .mp4 files
        """
        self.dataset_root = Path(dataset_root)
        self.output_root = Path(output_root)
        self.config = config or PreprocessingConfig()

        # Validate dataset structure
        self._validate_dataset_structure()

        # Initialize processors
        self.audio_processor = AudioProcessor(self.config.target_sample_rate)
        self.video_processor = VideoFaceProcessor(self.config)

        # Create output directories
        self.audio_dir = ensure_directory_exists(self.output_root / 'audio')
        self.frames_dir = ensure_directory_exists(self.output_root / 'frames')

    def _actor_folders(self) -> List[Path]:
        """Return the sorted Actor_* directories under dataset_root."""
        return sorted(
            candidate
            for candidate in self.dataset_root.glob('Actor_*')
            if candidate.is_dir()
        )

    def _validate_dataset_structure(self):
        """Validate that the dataset has Actor_XX folders containing .mp4 files."""
        actor_folders = self._actor_folders()

        if not actor_folders:
            raise ValueError(
                f"No Actor_XX folders found in {self.dataset_root}.\n"
                f"Expected structure: {self.dataset_root}/Actor_01/, Actor_02/, etc."
            )

        print(f"Found {len(actor_folders)} actor folders in {self.dataset_root}")

        # Check for videos in actor folders
        total_videos = sum(
            len(list(actor_folder.glob('*.mp4')))
            for actor_folder in actor_folders
        )

        if total_videos == 0:
            raise ValueError(
                f"No .mp4 files found in Actor folders under {self.dataset_root}"
            )

        print(f"Found {total_videos} total video files across all actors")

    def get_all_video_files(self) -> List[Path]:
        """
        Get all video files from Actor folders.

        Returns:
            Sorted list of paths to video files
        """
        video_files = []

        for actor_folder in self._actor_folders():
            video_files.extend(actor_folder.glob('*.mp4'))

        return sorted(video_files)

    def _save_video_stats(self, frames_directory: Path, video_metadata: Dict) -> None:
        """Persist frame-extraction statistics next to the frames (best effort)."""
        stats_path = frames_directory / self._VIDEO_STATS_FILENAME
        try:
            with open(stats_path, 'w', encoding='utf-8') as file:
                json.dump(video_metadata, file, indent=2)
        except OSError as error:
            # Losing the sidecar only degrades statistics on later cached runs
            warnings.warn(f"Could not write {stats_path}: {str(error)}")

    def _load_cached_video_stats(self, frames_directory: Path) -> Dict:
        """
        Build video metadata for a clip whose frames already exist on disk.

        Statistics recorded by a previous run are reused when the sidecar file
        is present and readable. Frames produced before the sidecar existed
        have no recorded statistics; for those the historical placeholder
        values (all frames detected, 25 fps) are returned.
        """
        stats_path = frames_directory / self._VIDEO_STATS_FILENAME
        cached = None

        if stats_path.exists():
            try:
                with open(stats_path, 'r', encoding='utf-8') as file:
                    cached = json.load(file)
            except (OSError, ValueError):
                cached = None

        if not isinstance(cached, dict):
            cached = {}

        expected_frames = self.config.num_frames_per_clip
        return {
            'num_frames_saved': expected_frames,
            'num_successful_face_detections': cached.get(
                'num_successful_face_detections',
                expected_frames
            ),
            'face_detection_rate': cached.get('face_detection_rate', 1.0),
            'original_fps': cached.get('original_fps', 25.0),
            # Always the current location, in case output_root was moved
            'frames_directory': str(frames_directory)
        }

    def process_single_video(
        self,
        video_path: Path
    ) -> Optional[Dict]:
        """
        Process a single RAVDESS video file.

        Audio and frames that already exist in the output directories are
        reused. A frame directory holding the wrong number of frames is
        cleared and regenerated.

        Args:
            video_path: Path to video file (e.g., Actor_01/01-01-01-01-01-01-01.mp4)

        Returns:
            Metadata dictionary or None if processing failed (a warning is
            emitted with the reason)
        """
        try:
            # Parse filename
            file_metadata = RAVDESSFilenameParser.parse(video_path)
            sample_id = video_path.stem

            # Add actor folder info
            file_metadata['actor_folder'] = video_path.parent.name

            # Process audio
            audio_output_path = self.audio_dir / f"{sample_id}.wav"
            if not audio_output_path.exists():
                audio_metadata = self.audio_processor.extract_audio_from_video(
                    video_path,
                    audio_output_path
                )
            else:
                # Load existing audio for metadata
                waveform, sample_rate = torchaudio.load(str(audio_output_path))
                audio_metadata = {
                    'duration_seconds': waveform.shape[1] / sample_rate,
                    'sample_rate': sample_rate,
                    'num_channels': waveform.shape[0],
                    'num_samples': waveform.shape[1]
                }

            # Process video
            video_frames_output_dir = self.frames_dir / sample_id
            existing_frames = list(video_frames_output_dir.glob('frame_*.npy'))

            if len(existing_frames) != self.config.num_frames_per_clip:
                # Remove leftovers first: frames from a run with a larger
                # frame count would otherwise never be overwritten and the
                # directory would hold the wrong count forever.
                for stale_frame in existing_frames:
                    stale_frame.unlink(missing_ok=True)

                video_metadata = self.video_processor.process_video_to_frames(
                    video_path,
                    video_frames_output_dir
                )
                self._save_video_stats(video_frames_output_dir, video_metadata)
            else:
                # Frames already processed
                video_metadata = self._load_cached_video_stats(
                    video_frames_output_dir
                )

            # Combine all metadata
            complete_metadata = {
                'sample_id': sample_id,
                'original_video_path': str(video_path),
                'audio_path': str(audio_output_path),
                'video_frames_directory': video_metadata['frames_directory'],
                'num_frames': video_metadata['num_frames_saved'],
                'audio_duration_seconds': audio_metadata['duration_seconds'],
                'target_num_frames': self.config.num_frames_per_clip,
                'video_size': [self.config.video_width, self.config.video_height],
                'face_detection_rate': video_metadata['face_detection_rate'],
                **file_metadata
            }

            return complete_metadata

        except Exception as error:
            # Deliberate best effort: one bad clip must not abort the dataset run
            warnings.warn(
                f"Failed to process {video_path.name}: {str(error)}"
            )
            return None

    def process_dataset(
        self,
        modality_filter: str = 'audio_video',
        vocal_channel_filter: str = 'speech'
    ) -> List[Dict]:
        """
        Process entire RAVDESS dataset from Actor folders.

        The modality and vocal-channel filters are evaluated from the filename
        before any extraction, so files that will be discarded are never
        decoded and files without audio (video-only) are reported as skipped
        rather than failed.

        Args:
            modality_filter: 'audio_video', 'video_only', or 'audio_only'
            vocal_channel_filter: 'speech' or 'song'

        Returns:
            List of metadata dictionaries for all processed samples (also
            written to <output_root>/metadata.json)
        """
        # Get all video files from Actor folders
        video_files = self.get_all_video_files()
        print(f"\nFound {len(video_files)} video files across all actor folders")

        processed_metadata = []
        skipped_modality = 0
        skipped_vocal = 0
        failed = 0

        for video_path in tqdm(video_files, desc="Processing RAVDESS videos"):
            # Cheap filename check first
            try:
                file_metadata = RAVDESSFilenameParser.parse(video_path)
            except ValueError as error:
                warnings.warn(
                    f"Failed to process {video_path.name}: {str(error)}"
                )
                failed += 1
                continue

            # Apply modality filter
            if file_metadata['modality'] != modality_filter:
                skipped_modality += 1
                continue

            # Apply vocal channel filter
            if file_metadata['vocal_channel'] != vocal_channel_filter:
                skipped_vocal += 1
                continue

            # Expensive audio/frame extraction only for files that are kept
            metadata = self.process_single_video(video_path)

            if metadata is None:
                failed += 1
                continue

            processed_metadata.append(metadata)

        # Save metadata
        metadata_output_path = self.output_root / 'metadata.json'
        with open(metadata_output_path, 'w', encoding='utf-8') as file:
            json.dump(processed_metadata, file, indent=2, ensure_ascii=False)

        # Print summary
        print("\n" + "=" * 80)
        print("Processing Summary")
        print("=" * 80)
        print(f"Total videos found:          {len(video_files)}")
        print(f"Successfully processed:      {len(processed_metadata)}")
        print(f"Skipped (modality filter):   {skipped_modality}")
        print(f"Skipped (vocal filter):      {skipped_vocal}")
        print(f"Failed:                      {failed}")
        print(f"\nMetadata saved to: {metadata_output_path}")
        print("=" * 80)

        return processed_metadata


# ============================================================================
# PyTorch Dataset
# ============================================================================

class RAVDESSEmotionDataset(Dataset):
    """
    PyTorch Dataset for RAVDESS emotion recognition.

    Each item is a dictionary containing:
        - sample_id: String identifier
        - emotion_label: Integer in [0, 7]
        - audio: 1D waveform tensor (variable length), when modality is
          'audio' or 'both'
        - video: (T, C, H, W) tensor normalized to [0, 1], when modality is
          'video' or 'both'
    """

    _VALID_MODALITIES = ('audio', 'video', 'both')

    def __init__(
        self,
        metadata_path: Union[str, Path],
        modality: str = 'both',
        target_sample_rate: int = 16000
    ):
        """
        Args:
            metadata_path: Path to metadata.json from preprocessor
            modality: 'audio', 'video', or 'both'
            target_sample_rate: Sample rate (Hz) audio is resampled to on
                load; should match the rate used during preprocessing

        Raises:
            ValueError: If modality is not recognized or the metadata file
                contains no samples
        """
        # Reject typos up front; an unknown modality would otherwise yield
        # items with labels but neither audio nor video.
        if modality not in self._VALID_MODALITIES:
            raise ValueError(
                f"Invalid modality '{modality}'. "
                f"Expected one of {self._VALID_MODALITIES}"
            )

        self.modality = modality
        self.target_sample_rate = target_sample_rate
        self.metadata_path = Path(metadata_path)

        # Load metadata
        with open(self.metadata_path, 'r', encoding='utf-8') as file:
            self.samples = json.load(file)

        if len(self.samples) == 0:
            raise ValueError(f"No samples found in {metadata_path}")

        # Extract configuration from first sample ('video_size' is [width, height])
        first_sample = self.samples[0]
        self.num_frames = first_sample['target_num_frames']
        self.video_height = first_sample['video_size'][1]
        self.video_width = first_sample['video_size'][0]

        print(
            f"Loaded {len(self.samples)} samples from {metadata_path}\n"
            f"Configuration: {self.num_frames} frames, "
            f"{self.video_width}x{self.video_height} resolution"
        )

    def __len__(self) -> int:
        return len(self.samples)

    def _load_audio(self, audio_path: str) -> torch.Tensor:
        """Load a waveform as a mono 1D tensor at target_sample_rate."""
        waveform, sample_rate = torchaudio.load(audio_path)

        # Ensure target sample rate
        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(
                sample_rate,
                self.target_sample_rate
            )
            waveform = resampler(waveform)

        # Ensure mono
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        return waveform.squeeze(0)  # Shape: [num_samples]

    def _load_video(self, frames_directory: str) -> torch.Tensor:
        """
        Load video frames as a (T, C, H, W) float tensor.

        A missing frame directory yields an all-zero clip. A frame whose shape
        is not (video_height, video_width, 3) is replaced by a zero frame so
        that stacking cannot fail on a single bad file.
        """
        frames_dir = Path(frames_directory)
        frame_files = sorted(frames_dir.glob('frame_*.npy'))

        if len(frame_files) == 0:
            warnings.warn(f"No frames found in {frames_directory}")
            return torch.zeros(
                (self.num_frames, 3, self.video_height, self.video_width),
                dtype=torch.float32
            )

        # Ensure correct number of frames
        if len(frame_files) != self.num_frames:
            indices = sample_frame_indices(len(frame_files), self.num_frames)
            frame_files = [frame_files[idx] for idx in indices]

        expected_shape = (self.video_height, self.video_width, 3)

        # Load frames
        loaded_frames = []
        for frame_path in frame_files:
            frame_array = np.load(frame_path)  # (H, W, C) in [0, 1]

            # Validate full shape, not just the channel count
            if frame_array.shape != expected_shape:
                warnings.warn(f"Malformed frame: {frame_path}")
                frame_array = np.zeros(expected_shape, dtype=np.float32)

            loaded_frames.append(frame_array)

        # Stack and convert to tensor: (T, H, W, C) -> (T, C, H, W)
        frames_array = np.stack(loaded_frames, axis=0)
        frames_tensor = torch.from_numpy(frames_array).permute(0, 3, 1, 2)

        return frames_tensor.contiguous().float()

    def __getitem__(self, index: int) -> Dict:
        """Get a single sample as a dictionary (see class docstring for keys)."""
        sample = self.samples[index]

        output = {
            'sample_id': sample['sample_id'],
            'emotion_label': EMOTION_TO_INDEX[sample['emotion']]
        }

        # Load audio if requested
        if self.modality in ('audio', 'both'):
            output['audio'] = self._load_audio(sample['audio_path'])

        # Load video if requested
        if self.modality in ('video', 'both'):
            output['video'] = self._load_video(sample['video_frames_directory'])

        return output


def collate_batch(batch: List[Dict]) -> Dict:
    """
    Merge individual dataset items into one batch dictionary for DataLoader.

    Sample ids are gathered into a list and labels into a long tensor.
    Audio is kept as a list of waveforms because lengths differ per sample;
    videos are stacked into a single (B, T, C, H, W) tensor. Which of the
    optional modalities are present is decided from the first item.
    """
    ids = []
    labels = []
    for entry in batch:
        ids.append(entry['sample_id'])
        labels.append(entry['emotion_label'])

    result = {
        'sample_id': ids,
        'emotion_label': torch.tensor(labels, dtype=torch.long)
    }

    reference_item = batch[0]

    # Variable-length audio stays unpadded (padding is left to later processing)
    if 'audio' in reference_item:
        result['audio'] = [entry['audio'] for entry in batch]

    # Fixed-size clips are stacked along a new leading batch axis
    if 'video' in reference_item:
        result['video'] = torch.stack(
            [entry['video'] for entry in batch],
            dim=0
        )

    return result


# ============================================================================
# Validation Functions
# ============================================================================

def validate_processed_dataset(output_root: Path):
    """
    Run sanity checks on processed data and print a validation report.

    Reads ``metadata.json`` from ``output_root`` and prints, in order: the
    sample count, actor / gender / emotion / intensity distributions, face
    detection rate statistics (listing up to five samples below 50%), audio
    duration statistics, and whether every sample's audio path and frame
    directory exist on disk.

    Args:
        output_root: Directory containing ``metadata.json``. A plain string
            is accepted as well as a ``Path``.

    Raises:
        FileNotFoundError: If ``metadata.json`` does not exist.

    Note:
        When the metadata holds no samples the report states so and the
        function returns early, because every later statistic divides by the
        sample count or takes a min/max over it.
    """
    # Function-scope import keeps this block self-contained.
    from collections import Counter

    metadata_path = Path(output_root) / 'metadata.json'

    if not metadata_path.exists():
        raise FileNotFoundError(f"Metadata not found: {metadata_path}")

    with open(metadata_path, 'r') as f:
        samples = json.load(f)

    separator = "=" * 80
    total = len(samples)

    print("\n" + separator)
    print("Dataset Validation Report")
    print(separator)

    # Basic statistics
    print(f"\nTotal samples: {total}")

    if total == 0:
        # Nothing to average over; avoid ZeroDivisionError / min() of empty list.
        print("\n⚠️  Metadata contains no samples; skipping remaining checks")
        print("\n" + separator)
        return

    # Actor distribution
    actor_counts = Counter(sample['actor_id'] for sample in samples)

    print(f"\nActors represented: {len(actor_counts)}")
    print(f"Samples per actor: {total / len(actor_counts):.1f} average")

    # Gender distribution
    gender_counts = Counter(sample['gender'] for sample in samples)

    print("\nGender Distribution:")
    for gender, count in sorted(gender_counts.items()):
        print(f"  {gender}: {count} ({count/total*100:.1f}%)")

    # Emotion distribution, most frequent first (ties keep first-seen order)
    emotion_counts = Counter(sample['emotion'] for sample in samples)

    print("\nEmotion Distribution:")
    for emotion, count in sorted(emotion_counts.items(), key=lambda x: x[1], reverse=True):
        print(f"  {emotion:12s}: {count:3d} ({count/total*100:.1f}%)")

    # Intensity distribution
    intensity_counts = Counter(sample['intensity'] for sample in samples)

    print("\nIntensity Distribution:")
    for intensity, count in sorted(intensity_counts.items()):
        # Any stored value other than 1 is reported as "strong".
        intensity_name = "normal" if intensity == 1 else "strong"
        print(f"  {intensity_name}: {count} ({count/total*100:.1f}%)")

    # Face detection rates
    detection_rates = [s['face_detection_rate'] for s in samples]
    avg_rate = sum(detection_rates) / len(detection_rates)
    min_rate = min(detection_rates)
    max_rate = max(detection_rates)

    print("\nFace Detection Statistics:")
    print(f"  Average: {avg_rate:.2%}")
    print(f"  Min: {min_rate:.2%}")
    print(f"  Max: {max_rate:.2%}")

    # Find problematic samples (only the first five are listed individually)
    low_detection = [s for s in samples if s['face_detection_rate'] < 0.5]
    if low_detection:
        print(f"\n⚠️  {len(low_detection)} samples with <50% face detection:")
        for s in low_detection[:5]:
            print(f"  - {s['sample_id']} (Actor {s['actor_id']}): {s['face_detection_rate']:.2%}")
        if len(low_detection) > 5:
            print(f"  ... and {len(low_detection) - 5} more")
    else:
        print("\n✓ All samples have ≥50% face detection rate")

    # Audio duration statistics
    durations = [s['audio_duration_seconds'] for s in samples]
    avg_duration = sum(durations) / len(durations)
    min_duration = min(durations)
    max_duration = max(durations)

    print("\nAudio Duration Statistics:")
    print(f"  Average: {avg_duration:.2f}s")
    print(f"  Min: {min_duration:.2f}s")
    print(f"  Max: {max_duration:.2f}s")

    # Verify files exist (frames are checked at directory level only)
    print("\nVerifying file existence...")
    missing_audio = sum(
        1 for sample in samples if not Path(sample['audio_path']).exists()
    )
    missing_frames = sum(
        1 for sample in samples
        if not Path(sample['video_frames_directory']).exists()
    )

    if missing_audio == 0 and missing_frames == 0:
        print("✓ All audio and frame files exist")
    else:
        if missing_audio > 0:
            print(f"⚠️  {missing_audio} audio files missing")
        if missing_frames > 0:
            print(f"⚠️  {missing_frames} frame directories missing")

    print("\n" + separator)


# ============================================================================
# Example Usage
# ============================================================================

if __name__ == "__main__":
    # Horizontal rule shared by every section banner printed below.
    _BANNER_RULE = "=" * 80

    # Input and output locations (Colab-style absolute paths).
    DATASET_ROOT = Path("/content/raw_data")  # Contains Actor_01, Actor_02, etc.
    OUTPUT_ROOT = Path("/content/processed_data")

    # ------------------------------------------------------------------
    # Step 1: run the preprocessing pipeline over the raw videos
    # ------------------------------------------------------------------
    print(f"{_BANNER_RULE}\nRAVDESS Dataset Preprocessing Pipeline\n{_BANNER_RULE}")

    preprocessor = RAVDESSPreprocessor(
        dataset_root=DATASET_ROOT,
        output_root=OUTPUT_ROOT,
        config=PreprocessingConfig(
            target_sample_rate=16000,
            video_height=224,
            video_width=224,
            num_frames_per_clip=32,
        ),
    )

    # Restrict processing to audio-video files from the speech channel.
    metadata = preprocessor.process_dataset(
        modality_filter='audio_video',
        vocal_channel_filter='speech',
    )

    # ------------------------------------------------------------------
    # Step 2: print the validation report for what was just written
    # ------------------------------------------------------------------
    validate_processed_dataset(OUTPUT_ROOT)

    # ------------------------------------------------------------------
    # Step 3: wrap the processed data in the PyTorch dataset
    # ------------------------------------------------------------------
    print(f"\n{_BANNER_RULE}\nCreating PyTorch Dataset\n{_BANNER_RULE}")

    dataset = RAVDESSEmotionDataset(
        metadata_path=OUTPUT_ROOT / 'metadata.json',
        modality='both',
    )

    # Look at the first sample to confirm ids, labels and tensor shapes.
    sample = dataset[0]
    print(f"\nSample ID: {sample['sample_id']}")
    print(f"Emotion Label: {sample['emotion_label']}")
    print(f"Audio Shape: {sample['audio'].shape}")
    print(f"Video Shape: {sample['video'].shape}")

    # ------------------------------------------------------------------
    # Step 4: pull one batch through a DataLoader as a smoke test
    # ------------------------------------------------------------------
    print(f"\n{_BANNER_RULE}\nTesting DataLoader\n{_BANNER_RULE}")

    dataloader = DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_batch,
    )

    batch = next(iter(dataloader))
    print(f"\nBatch Size: {len(batch['sample_id'])}")
    print(f"Emotion Labels: {batch['emotion_label']}")
    print(f"Video Shape: {batch['video'].shape}")  # (B, T, C, H, W)
    print(f"Audio: {len(batch['audio'])} waveforms (variable length)")

    print(f"\n{_BANNER_RULE}\n✓ Preprocessing and Dataset Creation Complete!\n{_BANNER_RULE}")

RAVDESS Dataset Preprocessing Pipeline
Found 6 actor folders in /content/raw_data
Found 360 total video files across all actors

Found 360 video files across all actor folders



  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


  s = torchaudio.io.StreamReader(src, format, None, buffer_size)


  s = torchaudio.io.StreamWriter(uri, format=muxer, buffer_size=buffer_size)

Processing RAVDESS videos: 100%|██████████| 360/360 [2:43:51<00:00, 27.31s/it]


Processing Summary
Total videos found:          360
Successfully processed:      360
Skipped (modality filter):   0
Skipped (vocal filter):      0
Failed:                      0

Metadata saved to: /content/processed_data/metadata.json

Dataset Validation Report

Total samples: 360

Actors represented: 6
Samples per actor: 60.0 average

Gender Distribution:
  female: 180 (50.0%)
  male: 180 (50.0%)

Emotion Distribution:
  calm        :  48 (13.3%)
  happy       :  48 (13.3%)
  sad         :  48 (13.3%)
  angry       :  48 (13.3%)
  fearful     :  48 (13.3%)
  disgust     :  48 (13.3%)
  surprised   :  48 (13.3%)
  neutral     :  24 (6.7%)

Intensity Distribution:
  normal: 192 (53.3%)
  strong: 168 (46.7%)

Face Detection Statistics:
  Average: 99.91%
  Min: 96.88%
  Max: 100.00%

✓ All samples have ≥50% face detection rate

Audio Duration Statistics:
  Average: 3.80s
  Min: 3.07s
  Max: 5.27s

Verifying file existence...
✓ All audio and frame files exist


Creating PyTorch Dataset
L


