In [None]:
# @title Load dependencies and define utilities

import mediapy
import numpy as np
from PIL import Image


def read_and_preprocess_video(
    filename: str, target_num_frames: int, target_frame_size: tuple[int, int]
):
  """Loads a video file and turns it into a fixed-size float clip.

  Args:
    filename: Path handed to `mediapy.read_video`.
    target_num_frames: How many frames to keep. Frames are picked at evenly
      spaced positions starting from the first frame, so a frame may repeat
      when the source video is shorter than this value.
    target_frame_size: Output `(height, width)`. It must have the same aspect
      ratio as the source frames.

  Returns:
    The sampled, resized frames after `mediapy.to_float01`, i.e. with pixel
    values normalized to [0.0, 1.0].

  Raises:
    AssertionError: If the source and target aspect ratios differ.
  """
  video = mediapy.read_video(filename)

  # Evenly spaced positions over [0, len(video)); the int32 cast truncates
  # each position down to a valid frame index.
  sample_positions = np.linspace(
      0, len(video), num=target_num_frames, endpoint=False, dtype=np.int32
  )
  sampled = np.asarray([video[position] for position in sample_positions])

  # Compare aspect ratios by cross-multiplication so the check stays in exact
  # integer arithmetic.
  source_height, source_width = sampled.shape[-3:-1]
  wanted_height, wanted_width = target_frame_size
  same_aspect_ratio = (
      source_height * wanted_width == source_width * wanted_height
  )
  assert same_aspect_ratio, 'Currently does not support aspect ratio mismatch.'

  resized = mediapy.resize_video(sampled, shape=target_frame_size)

  # Scale pixel values into [0.0, 1.0].
  return mediapy.to_float01(resized)

In [None]:
import glob
import os
import re

import mediapy
import numpy as np
from PIL import Image

def read_and_preprocess_frames_from_folder(
    folder_path: str,
    target_num_frames: int,
    target_frame_size: tuple[int, int] = (288, 288),
    overlap: float = 0.0
):
  """Loads an image sequence from a folder and slices it into fixed-length clips.

  Every `*.jpg`, `*.jpeg` and `*.png` file directly inside `folder_path` is
  loaded in natural filename order (`frame_2` before `frame_10`), converted to
  RGB, resized, and then cut into sliding windows of `target_num_frames`
  frames. A trailing window that would be incomplete is dropped.

  Args:
      folder_path (str): Path to folder containing image sequences. It is
          treated literally, so glob metacharacters in it are not expanded.
      target_num_frames (int): Number of frames per batch. Must be >= 1.
      target_frame_size (tuple): (height, width) for resizing. Defaults to
          (288, 288).
      overlap (float): Fraction of overlap between batches [0.0, 1.0).
          0.0 = distinct batches, 0.5 = 50% overlap. The window stride is
          `target_num_frames * (1 - overlap)` rounded down, but never below 1.

  Returns:
      np.ndarray: A numpy array of shape
      (num_batches, target_num_frames, H, W, 3), normalized with
      `mediapy.to_float01`. An empty array (`np.array([])`) is returned when
      no images are found or there are too few frames for a single batch.

  Raises:
      ValueError: If `target_num_frames` is < 1 or `overlap` is outside
          [0.0, 1.0).
  """

  # Validate arguments up front, before touching the filesystem.
  if target_num_frames < 1:
    raise ValueError("target_num_frames must be >= 1")
  if not (0.0 <= overlap < 1.0):
    raise ValueError("Overlap must be >= 0.0 and < 1.0")

  def natural_key(path):
    # re.split with a capturing group alternates text / digit-run tokens, so
    # odd positions are always digit runs. Comparing those as integers puts
    # "frame_2" before "frame_10". The full path is a deterministic tie-break
    # for names such as "frame_1" vs "frame_01".
    tokens = re.split(r'([0-9]+)', os.path.basename(path))
    numbered = [
        int(token) if position % 2 else token
        for position, token in enumerate(tokens)
    ]
    return numbered, path

  # 1. Find all image files. The folder name is escaped so characters such as
  # "[" or "*" in it are matched literally instead of acting as a pattern.
  search_root = glob.escape(folder_path)
  image_paths = []
  for pattern in ('*.jpg', '*.jpeg', '*.png'):
    image_paths.extend(glob.glob(os.path.join(search_root, pattern)))

  # 2. Sort naturally to ensure temporal order.
  image_paths.sort(key=natural_key)

  if not image_paths:
    print(f"[WARN] No images found in {folder_path}")
    return np.array([])

  if len(image_paths) < target_num_frames:
    print(f"[WARN] Found {len(image_paths)} frames, but need {target_num_frames} for a batch.")
    return np.array([])

  print(f"Folder: {os.path.basename(folder_path)} | Loading {len(image_paths)} frames...")

  # 3. Load and resize ALL frames into a continuous buffer.
  buffer_frames = []
  target_height, target_width = target_frame_size

  for path in image_paths:
    try:
      # The context manager releases the file handle as soon as the pixels
      # have been decoded by convert().
      with Image.open(path) as source_image:
        rgb_image = source_image.convert("RGB")
      # PIL expects (width, height), the reverse of target_frame_size.
      rgb_image = rgb_image.resize(
          (target_width, target_height), Image.Resampling.BILINEAR
      )
      buffer_frames.append(np.array(rgb_image))
    except Exception as e:
      # Best effort: an unreadable frame is reported and skipped rather than
      # aborting the whole folder.
      print(f"[ERROR] Could not load {path}: {e}")

  buffer_frames = np.array(buffer_frames)  # Shape: (Total_Frames, H, W, 3)

  # 4. Create batches using a sliding window. The small epsilon keeps float
  # error from truncating an exact stride (10 * (1 - 0.8) evaluates to
  # 1.9999999999999996, which int() alone would turn into 1 instead of 2).
  stride = max(1, int(target_num_frames * (1 - overlap) + 1e-9))

  num_frames = len(buffer_frames)
  last_start = num_frames - target_num_frames  # Negative => no full window.
  batched_frames = [
      buffer_frames[start_idx:start_idx + target_num_frames]
      for start_idx in range(0, last_start + 1, stride)
  ]

  if not batched_frames:
    return np.array([])

  # Stack into final shape: (Num_Batches, Sequence_Length, H, W, C)
  batched_frames = np.stack(batched_frames)

  # 5. Normalize pixel values to [0.0, 1.0]
  batched_frames = mediapy.to_float01(batched_frames)

  print(f"Generated {len(batched_frames)} batches with stride {stride}.")

  return batched_frames

In [None]:
# @title Load model

import jax
import jax.numpy as jnp
from videoprism import models as vp

MODEL_NAME = 'videoprism_public_v1_large'  # @param ['videoprism_public_v1_base', 'videoprism_public_v1_large'] {allow-input: false}
USE_BFLOAT16 = False  # @param { type: "boolean" }
# Clip length and square frame side length (pixels) requested when
# preprocessing inputs for the model.
NUM_FRAMES = 16
FRAME_SIZE = 288

# Forward-pass dtype: bfloat16 when requested, otherwise None so that
# `vp.get_model` uses its own default.
if USE_BFLOAT16:
  fprop_dtype = jnp.bfloat16
else:
  fprop_dtype = None

# Build the Flax module first, then fetch the matching pretrained weights.
flax_model = vp.get_model(MODEL_NAME, fprop_dtype=fprop_dtype)
loaded_state = vp.load_pretrained_weights(MODEL_NAME)


def forward_fn(inputs, train=False):
  """Applies the loaded model to a batch of inputs.

  Args:
    inputs: Array passed unchanged to `flax_model.apply` as the model input.
    train: Python bool forwarded to the model as its `train` keyword. It is a
      static (compile-time) argument, so each distinct value gets its own
      compiled version of the function.

  Returns:
    Whatever `flax_model.apply` returns for the given inputs.
  """
  return flax_model.apply(loaded_state, inputs, train=train)


# `train` is a mode switch, not data. Marking it static keeps it a plain Python
# bool inside the trace, so a model that branches on it in Python
# (`if train:`, `not train`) works when the flag is passed explicitly instead
# of failing on a traced boolean.
forward_fn = jax.jit(forward_fn, static_argnames=('train',))

In [None]:
VIDEO_FILE_PATH = 'videoprism_repo/videoprism/assets/water_bottle_drumming.mp4'  # @param {type: "string"}

# Decode the clip, sample NUM_FRAMES frames and resize them to a square of
# FRAME_SIZE pixels. The size is passed as a tuple to match the helper's
# declared `tuple[int, int]` parameter.
frames = read_and_preprocess_video(
    VIDEO_FILE_PATH,
    target_num_frames=NUM_FRAMES,
    target_frame_size=(FRAME_SIZE, FRAME_SIZE),
)
# Preview the preprocessed clip before it is converted to a JAX array.
mediapy.show_video(frames, fps=6.0)

frames = jnp.asarray(frames[None, ...])  # Add batch dimension.
if USE_BFLOAT16:
  # Cast the input to bfloat16 when the model was built for bfloat16.
  frames = frames.astype(jnp.bfloat16)
print(f'Input shape: {frames.shape} [type: {frames.dtype}]')

# The model output is unpacked as a pair; only the first element (the
# embeddings) is used here.
embeddings, _ = forward_fn(frames)
print(f'Encoded embedding shape: {embeddings.shape} [type: {embeddings.dtype}]')
