In [None]:
from collections import Counter
from pathlib import Path
from typing import Generator, List, Tuple

import cv2
import numpy as np


def pad_frame_buffer(frame_buffer: List[np.ndarray], buffer_size: int) -> List[np.ndarray]:
    """
    Pad a frame buffer up to ``buffer_size`` by repeating its last frame.

    The list is extended in place and also returned, so both
    ``pad_frame_buffer(buf, n)`` and ``buf = pad_frame_buffer(buf, n)`` work.
    A buffer that already holds ``buffer_size`` or more frames is returned
    untouched (it is never truncated).

    Args:
        frame_buffer: List of frames; mutated in place when padding is needed.
        buffer_size: Desired minimum buffer length.

    Returns:
        The same list object, now at least ``buffer_size`` long.

    Raises:
        ValueError: If padding is needed but ``frame_buffer`` is empty, since
            there is no frame to repeat.
    """
    missing = buffer_size - len(frame_buffer)
    if missing <= 0:
        return frame_buffer
    if not frame_buffer:
        raise ValueError("Cannot pad an empty frame buffer: there is no frame to repeat")
    # Repeats the same array object (no copies), exactly like appending frame_buffer[-1].
    frame_buffer.extend([frame_buffer[-1]] * missing)
    return frame_buffer

def extract_frame_tensors(
    video_path: str,
    frames_per_second: int = 8,
    buffer_size: int = 16,
    frame_stride: int = 8
) -> Generator[np.ndarray, None, None]:
    """
    Lazily yield sliding windows of frames sampled evenly from a video.

    Frame indices are spread evenly over the whole video so that about
    ``frames_per_second`` frames are taken per second of footage. Sampled frames
    are collected into a window of ``buffer_size`` frames. Each time the window
    is full it is stacked into one array and yielded, and then the window
    advances by ``frame_stride`` sampled frames. Only one window of frames is
    kept in memory at a time.

    Args:
        video_path: Path to the video file.
        frames_per_second: Number of frames to sample per second of video. If
            this exceeds the video's own frame rate, source frames are repeated
            rather than read ahead.
        buffer_size: Number of sampled frames in each yielded tensor.
        frame_stride: Number of sampled frames the window advances after each
            yield. Any value >= ``buffer_size`` gives non-overlapping windows.

    Yields:
        np.ndarray of shape ``(buffer_size, channels, height, width)`` holding
        the decoded frames in OpenCV's channel order (BGR by default). A
        trailing partial window (fewer than ``buffer_size`` frames) is dropped.

    Raises:
        ValueError: If ``frames_per_second`` <= 0, or ``buffer_size`` or
            ``frame_stride`` < 1.
        FileNotFoundError: If ``video_path`` does not exist.
        RuntimeError: If OpenCV cannot open the file or reports a non-positive
            frame rate.

    Because this is a generator, these checks run on the first ``next()``,
    not at call time.
    """
    if frames_per_second <= 0:
        raise ValueError(f"frames_per_second must be > 0, got {frames_per_second}")
    if buffer_size < 1:
        raise ValueError(f"buffer_size must be >= 1, got {buffer_size}")
    if frame_stride < 1:
        # 0 would never shrink the buffer again (unbounded growth, no more yields).
        raise ValueError(f"frame_stride must be >= 1, got {frame_stride}")

    if not Path(video_path).exists():
        raise FileNotFoundError(f"Video file not found: {video_path}")

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise RuntimeError("Failed to open video file")

    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            # Some backends report 0 here; the duration computation would divide by it.
            raise RuntimeError(f"Video reports an invalid frame rate ({fps}): {video_path}")
        if total_frames <= 0:
            return

        duration = total_frames / fps
        target_frames = np.linspace(0,
                                    total_frames - 1,
                                    num=int(duration * frames_per_second),
                                    dtype=np.int32)

        frame_buffer: List[np.ndarray] = []
        current_frame = 0  # index of the frame the next grab()/read() will return
        last_frame = None

        for target_idx in target_frames:
            if target_idx < current_frame:
                # Duplicate index (more samples requested than the video has
                # frames): reuse the frame already decoded. Reading here would
                # pull the *next* frame and shift every later sample forward.
                frame = last_frame
            else:
                # grab() advances without decoding, so skipping is cheaper than read().
                while current_frame < target_idx:
                    if not cap.grab():
                        return
                    current_frame += 1

                ret, frame = cap.read()
                if not ret:
                    return
                current_frame += 1
                last_frame = frame

            frame_buffer.append(frame)

            if len(frame_buffer) == buffer_size:
                # (T, H, W, C) -> (T, C, H, W)
                yield np.stack(frame_buffer, axis=0).transpose(0, 3, 1, 2)
                # Slide the window forward by `frame_stride` sampled frames.
                frame_buffer = frame_buffer[frame_stride:]

    finally:
        cap.release()

def process_video(video_path: str) -> None:
    """Print the shape of every frame tensor extracted from ``video_path``.

    Demo of ``extract_frame_tensors``: 8 sampled frames per second, 16 sampled
    frames per tensor, and the window advancing by 8 sampled frames per tensor.
    """
    clip_stream = extract_frame_tensors(
        video_path,
        frames_per_second=8,
        buffer_size=16,
        frame_stride=8,
    )

    clip_no = 0
    for clip in clip_stream:
        print(f"Frame tensor {clip_no}: shape {clip.shape}")
        clip_no += 1


In [None]:
# Smoke-test the extractor on a single augmented training video.
process_video(video_path="./datasets/ATMA-V/videos/train/BT-aug/11-aug-gauss_blur.mp4")

In [1]:
from dataset import TimesformerData

# Build the training dataset from the augmented videos and their label file.
dataset = TimesformerData(
    label_path="./datasets/ATMA-V/labels/labels.txt",
    vid_folder_path="./datasets/ATMA-V/videos/train/aug",
)

In [2]:
# Inspect the first sample: its tensor shape and its label.
# NOTE(review): the recorded output shows a scalar label (`tensor(0)`), while a
# later cell compares labels against a one-hot vector — confirm both paths
# return the same label format.
tensor, label = dataset[0]
tensor.shape, label

(torch.Size([8, 3, 224, 224]), tensor(0))

In [None]:
# Number of tensors each video contributes to the dataset, keyed by the first
# field of each mapping entry (the video path), in first-seen order. Kept as a
# plain dict so downstream cells see the same type as before.
vid_tensor_count = dict(Counter(entry[0] for entry in dataset.video_tensor_idx_mapping))

In [None]:
# Full index of (video_path, tensor_idx) entries built by the dataset.
# NOTE(review): this renders every entry; consider showing a slice for large datasets.
dataset.video_tensor_idx_mapping

In [None]:
# Videos ordered from most to fewest tensors; ties keep their original order
# because Python's sort is stable.
sorted_vid_tensor_count = dict(
    sorted(vid_tensor_count.items(), key=lambda kv: -kv[1])
)
sorted_vid_tensor_count

In [None]:
# Print every entry of the sequence mapping whose first field contains `stt`.
# NOTE(review): `stt in tup[0]` is a substring test, so "13" also matches e.g.
# "113" or "130" — tighten the match if only video 13 is wanted.
# NOTE(review): the other cells use `video_tensor_idx_mapping`; confirm that
# `video_tensor_sequence_mapping` still exists on TimesformerData.
stt = "13"
for tup in dataset.video_tensor_sequence_mapping:
    if stt in tup[0]:
        print(tup)

In [None]:
import torch

# One-hot vector this cell treats as the "normal" label; every other label is
# tallied as an anomaly. Built once instead of on every loop iteration.
NORMAL_LABEL = torch.tensor([1., 0.])

# vid_path -> (normal_tensor_count, anomaly_tensor_count)
vid_tensor_normal_anomaly_count = {}

for vid_path, tensor_idx in dataset.video_tensor_idx_mapping:
    # NOTE(review): relies on the private `_load_frame_tensor` and discards the
    # tensor it returns; presumably this decodes each clip just to read the label.
    # NOTE(review): `dataset[0]` returned a scalar label (`tensor(0)`). If this
    # call does too, torch.equal against a one-hot vector is always False and
    # every tensor lands in the anomaly bucket — confirm the label format.
    _, label = dataset._load_frame_tensor(vid_path, tensor_idx)

    normal, anomaly = vid_tensor_normal_anomaly_count.get(vid_path, (0, 0))
    if torch.equal(label, NORMAL_LABEL):
        normal += 1
    else:
        anomaly += 1
    vid_tensor_normal_anomaly_count[vid_path] = (normal, anomaly)

In [None]:
# Per-video (normal, anomaly) tensor counts.
vid_tensor_normal_anomaly_count

In [None]:
# Collapse the per-video (normal, anomaly) pairs into dataset-wide totals.
normal_count = sum(pair[0] for pair in vid_tensor_normal_anomaly_count.values())
anomaly_count = sum(pair[1] for pair in vid_tensor_normal_anomaly_count.values())

In [None]:
# Dataset-wide totals of normal vs. anomaly tensors.
normal_count, anomaly_count