Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 83 additions & 54 deletions bitmind/cache/sampler/video_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
import math
import random
import tempfile

from pathlib import Path
from typing import Dict, List, Any, Optional
from io import BytesIO

import ffmpeg
import numpy as np
Expand Down Expand Up @@ -119,73 +121,100 @@ async def _sample_frames(

video_path = random.choice(files[source])

duration = random.uniform(min_duration, max_duration)

try:
if not video_path.exists():
files[source].remove(video_path)
continue

video_info = get_video_metadata(str(video_path))
total_duration = video_info.get("duration", 0)
duration = min(total_duration, duration)

max_start = total_duration - duration
start_time = random.uniform(0, max_start)
try:
video_info = get_video_metadata(str(video_path))
total_duration = video_info.get("duration", 0)
width = int(video_info.get("width", 256))
height = int(video_info.get("height", 256))
reported_fps = float(video_info.get("fps", max_fps))
except Exception as e:
self.cache_fs._log_error(
f"Unable to extract video metadata from {str(video_path)}: {e}"
)
files[source].remove(video_path)
continue

width = int(video_info.get("width", 256))
height = int(video_info.get("height", 256))
reported_fps = float(video_info.get("fps", max_fps))
if reported_fps > max_fps or reported_fps <= 0 or not math.isfinite(reported_fps):
if (
reported_fps > max_fps
or reported_fps <= 0
or not math.isfinite(reported_fps)
):
self.cache_fs._log_warning(
f"Unreasonable FPS ({reported_fps}) detected in {video_path}, capping at {max_fps}"
)
fps = max_fps
frame_rate = max_fps
else:
fps = reported_fps
frame_rate = reported_fps

temp_dir = tempfile.mkdtemp()
try:
# Extract frames as PNGs for v2 parity
temp_frame_dir = os.path.join(temp_dir, "frame%04d.png")
ffmpeg.input(
str(video_path), ss=str(start_time), t=str(duration)
).output(temp_frame_dir, format="image2", vcodec="png").global_args(
"-loglevel", "error"
).global_args(
"-r", str(fps)
).run()

frame_files = sorted(
[f for f in os.listdir(temp_dir) if f.endswith(".png")]
)
target_duration = random.uniform(min_duration, max_duration)
target_duration = min(target_duration, total_duration)

if not frame_files:
self.cache_fs._log_warning(
f"No frames extracted from {video_path}"
num_frames = int(target_duration * frame_rate) + 1

actual_duration = (num_frames - 1) / frame_rate

max_start = max(0, total_duration - actual_duration)
start_time = random.uniform(0, max_start)

frames = []
no_data = []

for i in range(num_frames):
timestamp = start_time + (i / frame_rate)

try:
out_bytes, err = (
ffmpeg.input(str(video_path), ss=str(timestamp))
.filter("select", "eq(n,0)")
.output(
"pipe:",
vframes=1,
format="image2",
vcodec="png",
loglevel="error",
)
.run(capture_stdout=True, capture_stderr=True)
)

if not out_bytes:
no_data.append(timestamp)
continue

try:
frame = Image.open(BytesIO(out_bytes))
frame.load() # Verify image can be loaded
frames.append(np.array(frame))
except Exception as e:
self.cache_fs._log_error(
f"Failed to process frame at {timestamp}s: {e}"
)
continue

except ffmpeg.Error as e:
self.cache_fs._log_error(
f"FFmpeg error at {timestamp}s: {e.stderr.decode()}"
)
files[source].remove(video_path)
continue

frames = []
for frame_file in frame_files:
img = Image.open(os.path.join(temp_dir, frame_file))
frames.append(np.array(img))
if len(no_data) > 0:
tmin, tmax = min(no_data), max(no_data)
self.cache_fs._log_warning(
f"No data received for {len(no_data)} frames between {tmin} and {tmax}"
)

frames = np.stack(frames, axis=0)
num_frames = len(frames)
if not frames:
self.cache_fs._log_warning(
f"No frames successfully extracted from {video_path}"
)
files[source].remove(video_path)
continue

finally:
# Clean up temp directory and files
for file in os.listdir(temp_dir):
try:
os.remove(os.path.join(temp_dir, file))
except:
pass
try:
os.rmdir(temp_dir)
except:
pass
frames = np.stack(frames, axis=0)

if as_float32:
frames = frames.astype(np.float32) / 255.0
Expand Down Expand Up @@ -214,11 +243,11 @@ async def _sample_frames(
"metadata": metadata,
"segment": {
"start_time": start_time,
"duration": duration,
"fps": fps,
"duration": actual_duration,
"fps": frame_rate,
"width": width,
"height": height,
"num_frames": num_frames,
"num_frames": len(frames),
},
}

Expand All @@ -233,7 +262,7 @@ async def _sample_frames(
)

self.cache_fs._log_info(
f"Successfully sampled {duration}s segment from {video_path} ({num_frames} frames)"
f"Successfully sampled {actual_duration}s segment from {video_path} ({len(frames)} frames)"
)
return result

Expand Down