In [16]:
import cv2
import numpy as np
from typing import Generator, Tuple, List
from pathlib import Path


def pad_frame_buffer(frame_buffer: List[np.ndarray], buffer_size: int) -> List[np.ndarray]:
    """
    Pad frame buffer with last frame if needed.

    Args:
        frame_buffer: List of frames
        buffer_size: Desired buffer size

    Returns:
        Padded frame buffer
    """
    while len(frame_buffer) < buffer_size:
        frame_buffer.append(frame_buffer[-1])
    return frame_buffer

def extract_frame_tensors(
    video_path: str,
    frames_per_second: int = 8,
    buffer_size: int = 16,
    frame_stride: int = 8
) -> Generator[Tuple[np.ndarray], None, None]:
    """
    Extract frame tensors from video in a memory-efficient way using frame buffers.

    Args:
        video_path: Path to video file
        frames_per_second: Number of frames to extract per second
        buffer_size: Number of frames in each tensor buffer
        frame_stride: Number of frames to stride between buffers

    Returns:
        Generator yielding tuples of (frame_tensor, timestamp)
        frame_tensor shape: (buffer_size, height, width, channels)
    """
    if not Path(video_path).exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError("Failed to open video file")

    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        duration = total_frames / fps
        target_frames = np.linspace(0,
                                    total_frames - 1,
                                    num=int(duration * frames_per_second),
                                    dtype=np.int32)

        frame_buffer = []
        current_frame = 0

        for target_idx in target_frames:
            # Skip to target frame
            while current_frame < target_idx:
                cap.read()
                current_frame += 1

            ret, frame = cap.read()
            if not ret:
                break
            frame_buffer.append(frame)
            current_frame += 1

            if len(frame_buffer) == buffer_size:
                print(current_frame)
                frame_tensor = np.stack(frame_buffer, axis=0)
                yield frame_tensor

                # Slide buffer window by stride
                frame_buffer = frame_buffer[frame_stride:]

        # Handle remaining frames if any
        if len(frame_buffer) >= buffer_size // 2:
            frame_buffer = pad_frame_buffer(frame_buffer, buffer_size)
            frame_tensor = np.stack(frame_buffer, axis=0)
            yield frame_tensor

    finally:
        cap.release()

def process_video(video_path: str) -> None:
    """Example usage of the frame tensor extractor"""
    frame_tensors = extract_frame_tensors(
        video_path,
        frames_per_second=8,
        buffer_size=16,
        frame_stride=8
    )

    for idx, tensor in enumerate(frame_tensors):
        print(f"Frame tensor {idx}: shape {tensor.shape}")

In [17]:
process_video("./datasets/ATMA-V/videos/train/BT-aug/3-aug-h_flip.mp4")

57
Frame tensor 0: shape (16, 224, 224, 3)
87
Frame tensor 1: shape (16, 224, 224, 3)
118
Frame tensor 2: shape (16, 224, 224, 3)
148
Frame tensor 3: shape (16, 224, 224, 3)
178
Frame tensor 4: shape (16, 224, 224, 3)
208
Frame tensor 5: shape (16, 224, 224, 3)
238
Frame tensor 6: shape (16, 224, 224, 3)
268
Frame tensor 7: shape (16, 224, 224, 3)
299
Frame tensor 8: shape (16, 224, 224, 3)
329
Frame tensor 9: shape (16, 224, 224, 3)
359
Frame tensor 10: shape (16, 224, 224, 3)
389
Frame tensor 11: shape (16, 224, 224, 3)
419
Frame tensor 12: shape (16, 224, 224, 3)
450
Frame tensor 13: shape (16, 224, 224, 3)
480
Frame tensor 14: shape (16, 224, 224, 3)
510
Frame tensor 15: shape (16, 224, 224, 3)
540
Frame tensor 16: shape (16, 224, 224, 3)
570
Frame tensor 17: shape (16, 224, 224, 3)
601
Frame tensor 18: shape (16, 224, 224, 3)
631
Frame tensor 19: shape (16, 224, 224, 3)
661
Frame tensor 20: shape (16, 224, 224, 3)
691
Frame tensor 21: shape (16, 224, 224, 3)
721
Frame tensor 22: s