In [1]:

import json
import os
from pathlib import Path
from typing import List, Tuple

import av
import numpy as np
from PIL import Image
from tqdm import tqdm
import av
import matplotlib.pyplot as plt
from PIL import Image

from train import POLISH_TO_ENGLISH

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

"""Main preprocessing function."""
    # Configuration
# Dataset location and output clip geometry (edit DATA_DIR to match the local layout).
DATA_DIR = "Olympic Boxing Punch Classification Video Dataset"
OUTPUT_DIR = "preprocessed_clips/jitter"
NUM_FRAMES = 16
SIZE = 224

# Echo the active settings, framed by 60-character rules.
for _banner_line in (
    "="*60,
    "Video Clip Preprocessing",
    "="*60,
    f"Data directory: {DATA_DIR}",
    f"Output directory: {OUTPUT_DIR}",
    f"Frames per clip: {NUM_FRAMES}",
    f"Spatial size: {SIZE}x{SIZE}",
    f"Output format: uint8 numpy arrays",
    "="*60,
):
    print(_banner_line)

# Make sure the destination folder exists before any clip is written.
os.makedirs(OUTPUT_DIR, exist_ok=True)

Video Clip Preprocessing
Data directory: Olympic Boxing Punch Classification Video Dataset
Output directory: preprocessed_clips/jitter
Frames per clip: 16
Spatial size: 224x224
Output format: uint8 numpy arrays


In [3]:
# Build `data`: video path -> list of {"frame", "label", "shapes"} records, one per
# annotated track found in the task folder's annotations.json.
data = {}

task_dirs = os.listdir(DATA_DIR)
for task_dir in task_dirs:
    task_path = os.path.join(DATA_DIR, task_dir)
    annotations_path = os.path.join(task_path, "annotations.json")
    video_dir = os.path.join(task_path, "data")

    # Stray entries in DATA_DIR (plain files, folders without data/) are not tasks.
    if not os.path.isdir(video_dir):
        print(f"Skipping {task_path}: no data/ directory")
        continue

    video_files = os.listdir(video_dir)
    if not video_files:
        # Without this guard the previous task's video_path would be reused and
        # this task's annotations would be attached to the wrong video.
        print(f"Skipping {task_path}: data/ directory is empty")
        continue
    # Same choice as before: the first entry listed in data/.
    video_path = os.path.join(video_dir, video_files[0])

    # Read as UTF-8 explicitly so label text decodes the same way on every platform.
    with open(annotations_path, 'r', encoding='utf-8') as f:
        annotations = json.load(f)

    clips = []
    for annotation in annotations:
        for track in annotation["tracks"]:
            clips.append({
                "frame": track["frame"],
                # Labels are stored in Polish; translate to the English class name.
                "label": POLISH_TO_ENGLISH[track["label"]],
                "shapes": track["shapes"]
            })
    data[video_path] = clips

data

{'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh199681/data/GH199681.mp4': [{'frame': 288,
   'label': 'LHHP',
   'shapes': [{'type': 'rectangle',
     'occluded': False,
     'z_order': 0,
     'points': [895.9000000000015, 412.0, 1044.2998046875, 509.0],
     'frame': 288,
     'outside': False,
     'attributes': []},
    {'type': 'rectangle',
     'occluded': False,
     'z_order': 0,
     'points': [895.9000000000015,
      416.298828125,
      1042.8984375,
      517.4000000000015],
     'frame': 289,
     'outside': False,
     'attributes': []},
    {'type': 'rectangle',
     'occluded': False,
     'z_order': 0,
     'points': [908.5009765625,
      412.099609375,
      1055.4994140624985,
      513.2007812500015],
     'frame': 290,
     'outside': False,
     'attributes': []},
    {'type': 'rectangle',
     'occluded': False,
     'z_order': 0,
     'points': [911.5, 415.0986328125, 1058.4984374999985, 516.1998046875015],
     'frame': 291,
     'outside': T

In [4]:
# All annotated video paths, in the insertion order of `data`.
video_paths = list(data)
video_paths

['Olympic Boxing Punch Classification Video Dataset/task_kam4_gh199681/data/GH199681.mp4',
 'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh189681/data/GH189681.mp4',
 'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh179681/data/GH179681.mp4',
 'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh169681/data/GH169681.mp4',
 'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh159681/data/GH159681.mp4',
 'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh149681/data/GH149681.mp4',
 'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh139681/data/GH139681.mp4',
 'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh129681/data/GH129681.mp4',
 'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh119681/data/GH119681.mp4',
 'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh109681/data/GH109681.mp4',
 'Olympic Boxing Punch Classification Video Dataset/task_kam4_gh099681/data/GH099681.mp4',

In [5]:
# import matplotlib.patches as patches

# video_path = list(data.keys())[1]
# frame_index = data[video_path][0]["frame"]
# label = data[video_path][0]["label"]
# bbox = data[video_path][0]["bbox"]

# container = av.open(video_path)

# video_stream = container.streams.video[0]
# for i,frame in enumerate(container.decode(video_stream)):
#     if i==frame_index:
#         img = frame.to_ndarray(format='rgb24')
#         break
    
# container.close()

# # Display image with bounding box overlay
# fig, ax = plt.subplots(figsize=(10, 6))
# ax.imshow(img)

# # Extract bbox coordinates
# x1, y1, x2, y2 = bbox
# width = x2 - x1
# height = y2 - y1

# # Create rectangle patch
# rect = patches.Rectangle(
#     (x1, y1), width, height,
#     linewidth=2,
#     edgecolor='red',
#     facecolor='none'
# )

# # Add the patch to the axes
# ax.add_patch(rect)

# ax.axis('off')
# ax.set_title(f'Frame {frame_index} - {label}')
# plt.show()



In [6]:
# def pad_image(img):
#     h, w = img.shape[:2]

#     # Determine the size of the square (use the larger dimension)
#     max_dim = max(h, w)

#     # Calculate padding needed
#     if h < w:
#         # Pad top and bottom
#         pad_total = w - h
#         pad_top = pad_total // 2
#         pad_bottom = pad_total - pad_top
#         pad_left = 0
#         pad_right = 0
#     else:
#         # Pad left and right
#         pad_total = h - w
#         pad_left = pad_total // 2
#         pad_right = pad_total - pad_left
#         pad_top = 0
#         pad_bottom = 0

#     # Pad the image with zeros (black) or you can use other values
#     padded_img = np.pad(img, 
#                         ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
#                         mode='constant', 
#                         constant_values=0)
#     return padded_img

# h, w = img.shape[:2]

# padded_img = pad_image(img)

# # Display the original and padded images
# fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# axes[0].imshow(img)
# axes[0].set_title(f'Original Image ({h}x{w})')
# axes[0].axis('off')

# axes[1].imshow(padded_img)
# axes[1].set_title(f'Padded Square Image ({padded_img.shape[0]}x{padded_img.shape[1]})')
# axes[1].axis('off')

# plt.tight_layout()
# plt.show()

# print(f"Original shape: {img.shape}")
# print(f"Padded shape: {padded_img.shape}")

In [7]:


# video_path = list(data.keys())[1]
# frame_index = data[video_path][0]["frame"]
# label = data[video_path][0]["label"]

# container = av.open(video_path)

# imgs = []
# video_stream = container.streams.video[0]
# for i,frame in enumerate(container.decode(video_stream)):
#     if i>=frame_index-8 and i<=frame_index+8:
#         img = frame.to_ndarray(format='rgb24')
#         imgs.append(img)
    
# container.close()



In [8]:
# imgs = [pad_image(img) for img in imgs]

# # Multiple frames (e.g., 16 frames in a grid)
# # your list of frame arrays
# fig, axes = plt.subplots(4, 4, figsize=(12, 12))
# for idx, ax in enumerate(axes.flat):
#     ax.imshow(imgs[idx])
#     ax.axis('off')
# plt.tight_layout()
# plt.show()


In [4]:
import numpy as np

def save_frames_to_mp4(frames, output_path, fps=30):
    """
    Encode a sequence of RGB frames as an H.264 MP4 file.

    Args:
        frames: List or numpy array of frames, each shape (H, W, 3) uint8
        output_path: Path to save MP4 file
        fps: Frames per second (default 30)

    Raises:
        ValueError: If ``frames`` contains no frames.
    """
    # Fail early with a clear message instead of an IndexError on frames[0].
    if len(frames) == 0:
        raise ValueError("frames must contain at least one frame")

    # Convert to numpy array if it's a list
    if isinstance(frames, list):
        frames = np.stack(frames)

    # Every frame is assumed to share the first frame's dimensions.
    height, width = frames[0].shape[:2]

    container = av.open(output_path, mode='w')
    try:
        stream = container.add_stream('h264', rate=fps)
        stream.width = width
        stream.height = height
        # NOTE(review): yuv420p H.264 encoders typically need even width/height - confirm for odd-sized input.
        stream.pix_fmt = 'yuv420p'

        # Encode frame by frame, muxing packets as soon as the encoder emits them.
        for frame_array in frames:
            frame = av.VideoFrame.from_ndarray(frame_array, format='rgb24')
            for packet in stream.encode(frame):
                container.mux(packet)

        # Calling encode() with no frame drains the packets still buffered in the encoder.
        for packet in stream.encode():
            container.mux(packet)
    finally:
        # Close even when encoding fails so the output file handle is not leaked.
        container.close()

# Usage:
# save_frames_to_mp4(imgs, 'output_clip.mp4', fps=10)


In [10]:
# # Resize images using PIL and convert back to ndarray
# resized_frames = []
# for img in imgs:
#     # Convert numpy array to PIL Image
#     pil_img = Image.fromarray(img)
    
#     # Resize to target size (SIZE x SIZE)
#     pil_img_resized = pil_img.resize((SIZE, SIZE), Image.BILINEAR)
    
#     # Convert back to numpy array
#     img_array = np.array(pil_img_resized)
    
#     resized_frames.append(img_array)

# clip = np.stack(resized_frames)
# clip = clip.astype(np.uint8)
# output_path = f"preprocessed_clips/train/{label}/clip_{1}"
# os.makedirs(os.path.dirname(output_path), exist_ok=True)
# np.save(output_path, clip)

In [6]:
# Tally 16 frames for every annotated track, summed over all videos.
all_paths = list(data.keys())
count = 0
for video_path in all_paths:
    tracks_in_video = data[video_path]
    count = count + 16 * len(tracks_in_video)
count


73744

In [7]:
def flush_batch(clip_batch, batch_metadata):
    """Save each buffered clip under preprocessed_clips/<split>/<label>/.

    Args:
        clip_batch: Sequence of arrays to store with ``np.save``.
        batch_metadata: Sequence of ``(split, label, clip_num)`` tuples, paired
            positionally with ``clip_batch``.
    """
    for clip, metadata in zip(clip_batch, batch_metadata):
        split, label, clip_num = metadata
        output_path = f"preprocessed_clips/{split}/{label}/clip_{clip_num}"
        # Create the class folder on demand; np.save appends the .npy suffix itself.
        parent_dir, _ = os.path.split(output_path)
        os.makedirs(parent_dir, exist_ok=True)
        np.save(output_path, clip)
    print("Flushed batch")


In [5]:
def seek_and_decode_range(container, start_frame, end_frame):
    """Decode the frames numbered start_frame..end_frame (inclusive) from a video.

    Args:
        container: Open PyAV input container; its first video stream is used.
        start_frame: Index of the first frame wanted.
        end_frame: Index of the last frame wanted.

    Returns:
        Dict mapping frame index -> frame converted with ``to_ndarray(format='rgb24')``.
        Frame indices are derived from presentation timestamps, so an index the
        decoder never produces is simply absent from the dict.
    """
    video_stream = container.streams.video[0]
    framerate = video_stream.average_rate
    time_base = video_stream.time_base

    # Frame index -> seconds -> stream time_base units, which is what seek() expects here.
    # https://pyav.org/docs/9.0.2/api/container.html
    sec = start_frame / float(framerate)
    # backward=True: presumably lands on a keyframe at or before the target; frames
    # decoded ahead of start_frame are filtered out below.
    container.seek(int(sec / time_base), backward=True, stream=video_stream)

    frames_dict = {}
    for frame in container.decode(video_stream):
        pts = frame.pts
        if pts is None:
            # A frame without a timestamp cannot be mapped to an index; skip it
            # rather than crash on None arithmetic.
            continue
        frame_index = int(pts*time_base*framerate)
        if frame_index > end_frame:
            break
        if frame_index >= start_frame:
            frames_dict[frame_index] = frame.to_ndarray(format='rgb24')

    return frames_dict
    

In [14]:
# video_path = list(data.keys())[1]
# container = av.open(video_path)

# seek_and_decode_range(container, 500, 505)


In [9]:
import time

# Profiling accumulators: elapsed seconds (float) and call counts (int)
# for the seek/decode, preprocess and flush stages.
prof_seek_decode = prof_preprocess = prof_flush = 0.0
prof_seek_decode_n = prof_preprocess_n = prof_flush_n = 0


In [6]:
def flush_batch_2(clip_batch, batch_metadata):
    """Save each buffered clip to its precomputed output path.

    Args:
        clip_batch: Sequence of arrays to store with ``np.save``.
        batch_metadata: Sequence of output paths, paired positionally with
            ``clip_batch``. Missing parent directories are created.
    """
    for clip, out_path in zip(clip_batch, batch_metadata):
        parent_dir = os.path.dirname(out_path)
        # A bare filename has no parent component, and os.makedirs('') raises.
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        np.save(out_path, clip)
    print("Flushed batch")


In [None]:
# Single-process clip extraction.
# Assume a square window with clip_height sides
# If the window does not contain the bbox, shift it until it does
# Try to leave a margin between the window and the bbox
BATCH_SIZE = 1500  # ~3.6GB per batch
CLIPS_PER_VIDEO = 10  # After processing this many clips from a video, cull frames_dict

MARGIN = 100
SPATIAL_STEP = 100

skipped_clips = []

clip_batch = []
batch_metadata = []  # output .npy path for each clip in clip_batch

count = 0
for video_path in list(data.keys()):
    print(video_path)

    task_name = video_path.split("/")[-3]
    container = av.open(video_path)
    try:
        # Index of the final frame; clip ranges are clamped so they never run past it.
        last_frame = container.streams.video[0].frames - 1
        frames_dict = {}  # Persistent frames_dict for this video
        clips_processed_for_video = 0

        for clip_data in data[video_path]:
            frame_index = clip_data["frame"]
            label = clip_data["label"]
            shapes = clip_data["shapes"]

            shape_frames = [shape_data["frame"] for shape_data in shapes]
            start_frame = max((frame_index-5, 0)) # -5 because too often it seems the start of 'tracks' contain a cut
            end_frame = min(max(shape_frames) + 7, last_frame)

            if end_frame-start_frame+1 < 16:
                print(f"Adjusting {video_path} {frame_index} because frames_count < 16, actually {end_frame-start_frame+1}")
                start_frame = max(0, start_frame-2) # there are magic numbers
                end_frame = min(max(shape_frames) + 9, last_frame)

            # Check which frames we need but don't have yet
            needed_frames = set(range(start_frame, end_frame + 1))
            missing_frames = needed_frames - frames_dict.keys()

            if missing_frames:
                # Seek and decode only the missing range
                min_missing = min(missing_frames)
                max_missing = max(missing_frames)
                new_frames = seek_and_decode_range(container, min_missing, max_missing)
                frames_dict.update(new_frames)

            # The decoder may not return every requested index; record the clip as
            # skipped instead of failing later with a KeyError.
            if not (needed_frames | {frame_index}).issubset(frames_dict):
                print(f"Skipping {video_path} {frame_index} because some frames in {start_frame}-{end_frame} could not be decoded")
                skipped_clips.append((video_path, frame_index))
                continue

            # Frame geometry comes from the decoded frame rather than a fixed 1920x1080.
            img_height, img_width = frames_dict[frame_index].shape[:2]

            min_bbox_x = min([int(shape_data["points"][0]) for shape_data in shapes])
            max_bbox_x = max([int(shape_data["points"][2]) for shape_data in shapes])

            # Calculate spatial range: left edges of an img_height-wide window that
            # stay inside the frame, derived from the bbox extent and MARGIN.
            max_left = img_width - img_height
            x_1 = min(max(max_bbox_x + MARGIN - img_height, 0), max_left)
            x_2 = min(min_bbox_x - MARGIN, max_left)
            x_2 = max(x_2, 0)

            min_x = min(x_1, x_2)
            max_x = max(x_1, x_2)

            # Jitter candidates around the midpoint. The crop loop below steps through
            # range(min_x, max_x+1, SPATIAL_STEP) and does not read spatial_shifts.
            mid_x = (min_x + max_x) // 2
            jitter = int((max_x - mid_x) * 0.1)  # fraction of the half-range (was missing parentheses)

            spatial_shifts = []
            if jitter > 0:
                spatial_shifts = [mid_x-jitter, mid_x, mid_x+jitter]
            else:
                spatial_shifts = [mid_x]

            # Temporal windows: exactly 16 frames each. Longer ranges yield two
            # windows, one aligned to the start and one aligned to the end.
            windows = []
            frames_count = end_frame-start_frame+1
            if frames_count > 16:
                sf1, ef1 = start_frame, start_frame+15

                diff = frames_count - 16
                sf2, ef2 = start_frame+diff, start_frame+diff+15

                windows.append((sf1, ef1))
                windows.append((sf2, ef2))

            elif frames_count != 16:
                print(f"Skipping {video_path} {frame_index} because frames_count != 16, actually {frames_count}")
                skipped_clips.append((video_path, frame_index))
                continue
            else:
                windows.append((start_frame, end_frame))

            for spatial_min_x in range(min_x, max_x+1, SPATIAL_STEP):
                for temporal_window_i, (sf, ef) in enumerate(windows):
                    imgs = []
                    for f_i in range(sf, ef+1):
                        img = frames_dict[f_i]
                        # Full height, img_height-wide column window starting at spatial_min_x.
                        cropped_img = img[ :img_height, spatial_min_x:spatial_min_x+img_height]
                        imgs.append(cropped_img)
                    out_path = f"preprocessed_clips_3/{label}/clip_{task_name}_{frame_index}_{temporal_window_i}_{spatial_min_x}.npy"

                    pil_imgs = [Image.fromarray(img) for img in imgs]
                    pil_imgs_resized = [img.resize((SIZE, SIZE), Image.BILINEAR) for img in pil_imgs]
                    resized_frames = [np.array(img) for img in pil_imgs_resized]

                    clip = np.stack(resized_frames).astype(np.uint8)

                    clip_batch.append(clip)
                    batch_metadata.append(out_path)
                    count += 1
                    print(count, end="\r")
            clips_processed_for_video += 1

            if len(clip_batch) >= BATCH_SIZE:
                flush_batch_2(clip_batch, batch_metadata)

                clip_batch = []
                batch_metadata = []

            # Cull frames_dict periodically to manage memory
            if clips_processed_for_video % CLIPS_PER_VIDEO == 0:
                frames_dict = {}
    finally:
        # Release the decoder even if a clip fails part-way through the video.
        container.close()

    print('skipped clips', len(skipped_clips))
    flush_batch_2(clip_batch, batch_metadata)
    clip_batch = []
    batch_metadata = []
    



Olympic Boxing Punch Classification Video Dataset/task_kam4_gh199681/data/GH199681.mp4
288
label LHHP
bbox 895 1058
spatial  78 795 717


preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_78.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_88.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_98.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_108.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_118.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_128.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_138.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_148.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_158.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_168.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_178.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_188.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_198.npy
preprocessed_clips_3/LHHP/clip_task_kam4_gh199681_0_208.npy
preprocessed_clips_3/LHHP

KeyboardInterrupt: 

In [None]:
# Write out whatever is still buffered in clip_batch (presumably needed after
# interrupting the processing loop before its own flush ran).
flush_batch_2(clip_batch, batch_metadata)


In [None]:


# Debug preview: export the square crop of the clip range left over from the
# processing cell. Relies on start_frame, end_frame, frames_dict, min_x and
# img_height still being defined in the kernel.
imgs = []
for frame_index in range(start_frame, end_frame+1):
    img = frames_dict[frame_index]

    # Axis 0 is rows (y), axis 1 is columns (x): keep the full height and take an
    # img_height-wide window starting at column min_x. The previous slice applied
    # min_x to the rows, which cut the crop short vertically.
    cropped_img = img[:img_height, min_x:min_x+img_height]
    imgs.append(cropped_img)

save_frames_to_mp4(imgs, 'central_crop.mp4', fps=15)


In [13]:
import av
import numpy as np
import os
SPATIAL_STEP = 100

def _draw_box_outline(img, x1, y1, x2, y2, color, thickness):
    """Paint the four edges of the rectangle (x1, y1)-(x2, y2) onto img in place."""
    img[y1:y1+thickness, x1:x2] = color  # Top
    img[y2-thickness:y2, x1:x2] = color  # Bottom
    img[y1:y2, x1:x1+thickness] = color  # Left
    img[y1:y2, x2-thickness:x2] = color  # Right


def visualize_spatial_temporal_jitter(video_path, clip_index, data, MARGIN=100):
    """
    Visualize the spatial and temporal jittering for a specific clip.

    Every frame of the clip's temporal range is written to an MP4 with the
    candidate crop windows outlined in colour and the annotated bbox in white.

    Args:
        video_path: Path to the video file
        clip_index: Index of the clip in data[video_path]
        data: The data dictionary containing clip information
        MARGIN: Margin around bbox (default 100)

    Returns:
        Path of the MP4 written under spatial_temporal_tmp/.
    """
    os.makedirs('spatial_temporal_tmp', exist_ok=True)

    # 15 vibrant colors (RGB)
    vibrant_colors = [
        [255, 0, 0],      # Red
        [0, 255, 0],      # Lime
        [0, 0, 255],      # Blue
        [255, 255, 0],    # Yellow
        [255, 0, 255],    # Magenta
        [0, 255, 255],    # Cyan
        [255, 128, 0],    # Orange
        [128, 0, 255],    # Purple
        [255, 0, 128],    # Hot Pink
        [0, 255, 128],    # Spring Green
        [128, 255, 0],    # Chartreuse
        [0, 128, 255],    # Sky Blue
        [255, 192, 203],  # Pink
        [255, 165, 0],    # Deep Orange
        [75, 0, 130],     # Indigo
    ]

    clip_data = data[video_path][clip_index]

    frame_index = clip_data["frame"]
    label = clip_data["label"]
    shapes = clip_data["shapes"]

    container = av.open(video_path)
    try:
        last_frame = container.streams.video[0].frames - 1

        # Calculate frame range: 5 frames before the annotation to 7 after the last
        # annotated shape, widened slightly when that gives fewer than 16 frames.
        shape_frames = [shape_data["frame"] for shape_data in shapes]
        start_frame = max((frame_index-5, 0))
        end_frame = min(max(shape_frames) + 7, last_frame)

        if end_frame-start_frame+1 < 16:
            start_frame = max(0, start_frame-2)
            end_frame = min(max(shape_frames) + 9, last_frame)

        # Decode frames
        frames_dict = seek_and_decode_range(container, start_frame, end_frame)
    finally:
        # All decoding is done at this point; close even if it failed.
        container.close()

    # Frame geometry comes from the decoded frame rather than a fixed 1920x1080.
    img_height, img_width = frames_dict[frame_index].shape[:2]

    # Calculate bbox range
    min_bbox_x = min([int(shape_data["points"][0]) for shape_data in shapes])
    max_bbox_x = max([int(shape_data["points"][2]) for shape_data in shapes])

    # Calculate spatial range: left edges of an img_height-wide window that stay
    # inside the frame, derived from the bbox extent and MARGIN.
    max_left = img_width - img_height
    x_1 = min(max(max_bbox_x + MARGIN - img_height, 0), max_left)
    x_2 = min(min_bbox_x - MARGIN, max_left)
    x_2 = max(x_2, 0)

    min_x = min(x_1, x_2)
    max_x = max(x_1, x_2)

    mid_x = (min_x + max_x) // 2
    jitter = int((max_x - mid_x) * 0.2)

    if jitter > 0:
        spatial_shifts = [mid_x-jitter, mid_x, mid_x+jitter]
    else:
        spatial_shifts = [mid_x]

    # Collect all frames with overlays
    output_frames = []
    for f_i in range(start_frame, end_frame + 1):
        # Copy so the overlays never touch the decoded frame itself.
        img = frames_dict[f_i].copy()

        # Draw all possible crop boxes with different colors, cycling through the list
        for color_idx, spatial_min_x in enumerate(spatial_shifts):
            color = vibrant_colors[color_idx % len(vibrant_colors)]
            _draw_box_outline(
                img, spatial_min_x, 0, spatial_min_x + img_height, img_height, color, 3
            )

        # Draw bbox in white (high contrast) for shapes annotated on this frame
        for shape_data in shapes:
            if shape_data["frame"] == f_i:
                x1 = int(shape_data["points"][0])
                y1 = int(shape_data["points"][1])
                x2 = int(shape_data["points"][2])
                y2 = int(shape_data["points"][3])
                _draw_box_outline(img, x1, y1, x2, y2, [255, 255, 255], 5)

        output_frames.append(img)

    # Save output
    task_name = video_path.split("/")[-3]
    output_path = f'spatial_temporal_tmp/{label}_{task_name}_clip{clip_index}.mp4'
    save_frames_to_mp4(output_frames, output_path, fps=30)

    print(f"Saved visualization to {output_path}")
    print(f"Frames: {start_frame} to {end_frame} ({len(output_frames)} frames)")
    print(f"Label: {label}")
    print(f"Bbox range: {min_bbox_x} to {max_bbox_x}")
    print(f"Spatial crop range: {min_x} to {max_x} (jitter={jitter})")

    return output_path

# Usage example:
# video_path = list(data.keys())[0]
# visualize_spatial_temporal_jitter(video_path, clip_index=0, data=data)

# Render a sample of visualizations: every 10th video, and within each of those
# every 20th annotated clip.
sampled_video_paths = list(data.keys())[::10]
for video_path in sampled_video_paths:
    for clip_index in range(0, len(data[video_path]), 20):
        visualize_spatial_temporal_jitter(video_path, clip_index, data)


Saved visualization to spatial_temporal_tmp/LHHP_task_kam4_gh199681_clip0.mp4
Frames: 283 to 298 (16 frames)
Label: LHHP
Bbox range: 895 to 1058
Spatial crop range: 78 to 795 (jitter=71)
Saved visualization to spatial_temporal_tmp/LHHP_task_kam4_gh199681_clip20.mp4
Frames: 1435 to 1455 (21 frames)
Label: LHHP
Bbox range: 951 to 1127
Spatial crop range: 147 to 840 (jitter=69)
Saved visualization to spatial_temporal_tmp/RHBlP_task_kam4_gh199681_clip40.mp4
Frames: 2843 to 2861 (19 frames)
Label: RHBlP
Bbox range: 983 to 1087
Spatial crop range: 107 to 840 (jitter=73)
Saved visualization to spatial_temporal_tmp/RHHP_task_kam4_gh199681_clip60.mp4
Frames: 5395 to 5410 (16 frames)
Label: RHHP
Bbox range: 1000 to 1159
Spatial crop range: 179 to 840 (jitter=66)
Saved visualization to spatial_temporal_tmp/LHHP_task_kam4_gh199681_clip80.mp4
Frames: 7431 to 7449 (19 frames)
Label: LHHP
Bbox range: 721 to 852
Spatial crop range: 0 to 621 (jitter=62)
Saved visualization to spatial_temporal_tmp/RHHP_

In [7]:
import multiprocessing as mp
from functools import partial
from tqdm import tqdm
import glob


# Settings for the multi-process extraction run.
# Assume a square window with clip_height sides
# If the window does not contain the bbox, shift it until it does
# Try to leave a margin between the window and the bbox
BATCH_SIZE = 500  # clips buffered per worker before flushing; ~1.2GB at 16x224x224x3 uint8 per clip (the ~3.6GB figure applied to 1500)
CLIPS_PER_VIDEO = 10  # After processing this many clips from a video, cull frames_dict

MARGIN = 100  # pixels added around the bbox extent when placing the crop window
SPATIAL_STEP = 250  # NOTE(review): passed to process_single_video, which does not appear to read it
NUM_WORKERS = 10  # Number of parallel processes

def process_single_video(video_path, video_clips, MARGIN, SPATIAL_STEP, SIZE, CLIPS_PER_VIDEO, BATCH_SIZE):
    """Process a single video, flushing batches as we go to minimize memory usage.

    For every annotation a temporal range around the punch is decoded, one or two
    16-frame windows are cut from it, each window is cropped to a square at up to
    three horizontal offsets, resized to SIZE x SIZE and saved as a uint8 array
    under preprocessed_clips_3/<label>/.

    Args:
        video_path: Path to the video. Its third-from-last "/"-separated component
            is used as the task name in output file names.
        video_clips: List of dicts with "frame", "label" and "shapes" keys.
        MARGIN: Pixels added around the bbox extent when placing the crop window.
        SPATIAL_STEP: Accepted for call compatibility; not used, because the crop
            offsets come from the jitter computation.
        SIZE: Side length (pixels) the square crops are resized to.
        CLIPS_PER_VIDEO: The decoded-frame cache is cleared after this many
            processed annotations.
        BATCH_SIZE: Number of buffered clips that triggers a flush to disk.

    Returns:
        Dict with 'video_path', 'total_clips' (output clips generated) and
        'skipped' (list of (video_path, frame_index) that could not be processed).
    """
    skipped_clips = []
    clip_batch = []
    batch_metadata = []

    clips_processed_for_video = 0
    total_clips_for_video = 0

    # The task folder name doubles as progress label and output file name component.
    video_name = video_path.split("/")[-3]
    task_name = video_name

    container = av.open(video_path)
    try:
        # Index of the final frame; clip ranges are clamped so they never run past it.
        last_frame = container.streams.video[0].frames - 1
        frames_dict = {}  # Persistent frames_dict for this video

        for clip_idx, clip_data in enumerate(video_clips):
            frame_index = clip_data["frame"]
            label = clip_data["label"]
            shapes = clip_data["shapes"]

            # Skip if this clip has already been processed (any window-0 output exists)
            check_path = f"preprocessed_clips_3/{label}/clip_{task_name}_{frame_index}_0_*.npy"
            if glob.glob(check_path):
                continue

            shape_frames = [shape_data["frame"] for shape_data in shapes]
            start_frame = max((frame_index-5, 0)) # -5 because too often it seems the start of 'tracks' contain a cut
            end_frame = min(max(shape_frames) + 7, last_frame)

            if end_frame-start_frame+1 < 16:
                start_frame = max(0, start_frame-2) # there are magic numbers
                end_frame = min(max(shape_frames) + 9, last_frame)

            # Check which frames we need but don't have yet
            needed_frames = set(range(start_frame, end_frame + 1))
            missing_frames = needed_frames - frames_dict.keys()

            if missing_frames:
                # Seek and decode only the missing range
                min_missing = min(missing_frames)
                max_missing = max(missing_frames)
                new_frames = seek_and_decode_range(container, min_missing, max_missing)
                frames_dict.update(new_frames)

            # The decoder may not return every requested index; report the clip as
            # skipped instead of failing the whole video with a KeyError.
            if not (needed_frames | {frame_index}).issubset(frames_dict):
                skipped_clips.append((video_path, frame_index))
                continue

            # Frame geometry comes from the decoded frame rather than a fixed 1920x1080.
            img_height, img_width = frames_dict[frame_index].shape[:2]

            min_bbox_x = min([int(shape_data["points"][0]) for shape_data in shapes])
            max_bbox_x = max([int(shape_data["points"][2]) for shape_data in shapes])

            # Calculate spatial range: left edges of an img_height-wide window that
            # stay inside the frame, derived from the bbox extent and MARGIN.
            max_left = img_width - img_height
            x_1 = min(max(max_bbox_x + MARGIN - img_height, 0), max_left)
            x_2 = min(min_bbox_x - MARGIN, max_left)
            x_2 = max(x_2, 0)

            min_x = min(x_1, x_2)
            max_x = max(x_1, x_2)

            # Three crop offsets around the midpoint, or just the midpoint when the
            # range is too narrow for the jitter to be at least one pixel.
            mid_x = (min_x + max_x) // 2
            jitter = int((max_x - mid_x) * 0.2)

            if jitter > 0:
                spatial_shifts = [mid_x-jitter, mid_x, mid_x+jitter]
            else:
                spatial_shifts = [mid_x]

            # Temporal windows: exactly 16 frames each. Longer ranges yield two
            # windows, one aligned to the start and one aligned to the end.
            windows = []
            frames_count = end_frame-start_frame+1
            if frames_count > 16:
                sf1, ef1 = start_frame, start_frame+15

                diff = frames_count - 16
                sf2, ef2 = start_frame+diff, start_frame+diff+15

                windows.append((sf1, ef1))
                windows.append((sf2, ef2))

            elif frames_count != 16:
                skipped_clips.append((video_path, frame_index))
                continue
            else:
                windows.append((start_frame, end_frame))

            for spatial_min_x in spatial_shifts:
                for temporal_window_i, (sf, ef) in enumerate(windows):
                    imgs = []
                    for f_i in range(sf, ef+1):
                        img = frames_dict[f_i]
                        # Full height, img_height-wide column window starting at spatial_min_x.
                        cropped_img = img[ :img_height, spatial_min_x:spatial_min_x+img_height]
                        imgs.append(cropped_img)
                    out_path = f"preprocessed_clips_3/{label}/clip_{task_name}_{frame_index}_{temporal_window_i}_{spatial_min_x}.npy"

                    pil_imgs = [Image.fromarray(img) for img in imgs]
                    pil_imgs_resized = [img.resize((SIZE, SIZE), Image.BILINEAR) for img in pil_imgs]
                    resized_frames = [np.array(img) for img in pil_imgs_resized]

                    clip = np.stack(resized_frames).astype(np.uint8)

                    clip_batch.append(clip)
                    batch_metadata.append(out_path)
                    total_clips_for_video += 1

            clips_processed_for_video += 1

            # Print progress every 10 annotations
            if (clip_idx + 1) % 10 == 0:
                print(f"{video_name}: {clip_idx + 1}/{len(video_clips)} clips processed ({total_clips_for_video} output clips generated)")

            # Flush batch when it reaches BATCH_SIZE
            if len(clip_batch) >= BATCH_SIZE:
                flush_batch_2(clip_batch, batch_metadata)
                clip_batch = []
                batch_metadata = []

            # Cull frames_dict periodically to manage memory
            if clips_processed_for_video % CLIPS_PER_VIDEO == 0:
                frames_dict = {}

        # Flush any remaining clips for this video
        if clip_batch:
            flush_batch_2(clip_batch, batch_metadata)
    finally:
        # Release the decoder even if a clip fails part-way through the video.
        container.close()

    print(f"{video_name}: COMPLETE - {len(video_clips)} clips processed, {total_clips_for_video} output clips generated")

    return {
        'video_path': video_path,
        'total_clips': total_clips_for_video,
        'skipped': skipped_clips
    }


In [8]:

# Parallel processing: one task per video, spread over NUM_WORKERS processes.
video_paths = list(data.keys())

# Create a partial function with fixed parameters
process_func = partial(
    process_single_video,
    MARGIN=MARGIN,
    SPATIAL_STEP=SPATIAL_STEP,
    SIZE=SIZE,
    CLIPS_PER_VIDEO=CLIPS_PER_VIDEO,
    BATCH_SIZE=BATCH_SIZE
)

# Prepare arguments: (video_path, video_clips) tuples
video_args = [(vp, data[vp]) for vp in video_paths]

# Process videos in parallel
all_skipped_clips = []
total_clips_saved = 0

print(f"Processing {len(video_paths)} videos with {NUM_WORKERS} workers...")
print(f"Total annotations to process: {sum(len(data[vp]) for vp in video_paths)}\n")

with mp.Pool(NUM_WORKERS) as pool:
    # Submit every video as its own task and collect results one at a time.
    # pool.starmap only returns once *all* videos are finished, which kept the
    # progress bar at 0% until the very end; waiting on each AsyncResult in
    # submission order lets the bar advance while the pool is still working and
    # keeps `results` in the same order as `video_paths`.
    pending = [pool.apply_async(process_func, args) for args in video_args]
    results = []
    with tqdm(total=len(video_paths), desc="Videos completed", unit="video") as pbar:
        for job in pending:
            # .get() re-raises any exception raised inside the worker.
            results.append(job.get())
            pbar.update(1)

# Collect and report results
print("\n" + "="*60)
print("Processing complete!")
print("="*60)
for result in results:
    video_path = result['video_path']
    clips_count = result['total_clips']
    skipped = result['skipped']

    total_clips_saved += clips_count
    all_skipped_clips.extend(skipped)

print(f"\nTotal clips saved: {total_clips_saved}")
print(f"Total skipped clips: {len(all_skipped_clips)}")


Processing 29 videos with 10 workers...
Total annotations to process: 4609



Videos completed:   0%|          | 0/29 [00:00<?, ?video/s]

task_kam4_gh179681: COMPLETE - 77 clips processed, 0 output clips generated
task_kam4_gh189681: COMPLETE - 100 clips processed, 0 output clips generated
task_kam4_gh109681: COMPLETE - 141 clips processed, 0 output clips generated
task_kam4_gh199681: COMPLETE - 148 clips processed, 0 output clips generated
task_kam4_gh169681: COMPLETE - 202 clips processed, 0 output clips generated
task_kam2_gh218416: COMPLETE - 31 clips processed, 0 output clips generated
task_kam2_gh208416: COMPLETE - 72 clips processed, 0 output clips generated
task_kam4_gh079681: COMPLETE - 199 clips processed, 0 output clips generated
task_kam2_gh198416: COMPLETE - 101 clips processed, 0 output clips generated
task_kam4_gh149681: 100/248 clips processed (9 output clips generated)
task_kam4_gh159681: 100/279 clips processed (15 output clips generated)
task_kam4_gh129681: 100/253 clips processed (33 output clips generated)
task_kam4_gh099681: 100/274 clips processed (33 output clips generated)
task_kam4_gh119681: 100

Videos completed: 100%|██████████| 29/29 [07:46<00:00, 16.08s/video]  



Processing complete!

Total clips saved: 9795
Total skipped clips: 1


In [15]:
# Spot check: list the saved outputs for one annotation (label RHBlP, frame 1039).
# NOTE(review): this uses an absolute /workspace/facts prefix, unlike the relative
# preprocessed_clips_3/ paths used when writing - confirm it matches the working directory.
glob.glob("/workspace/facts/preprocessed_clips_3/RHBlP/clip_task_kam4_gh119681_1039_*.npy")

['/workspace/facts/preprocessed_clips_3/RHBlP/clip_task_kam4_gh119681_1039_0_750.npy',
 '/workspace/facts/preprocessed_clips_3/RHBlP/clip_task_kam4_gh119681_1039_0_650.npy',
 '/workspace/facts/preprocessed_clips_3/RHBlP/clip_task_kam4_gh119681_1039_0_550.npy',
 '/workspace/facts/preprocessed_clips_3/RHBlP/clip_task_kam4_gh119681_1039_0_450.npy',
 '/workspace/facts/preprocessed_clips_3/RHBlP/clip_task_kam4_gh119681_1039_0_350.npy',
 '/workspace/facts/preprocessed_clips_3/RHBlP/clip_task_kam4_gh119681_1039_0_250.npy',
 '/workspace/facts/preprocessed_clips_3/RHBlP/clip_task_kam4_gh119681_1039_0_150.npy']

In [10]:
# Every saved clip of task_kam2_gh108416, across all label folders.
clips_root = Path('preprocessed_clips_3')
files = [clip_file for clip_file in clips_root.glob('*/clip_task_kam2_gh108416_*')]
files

[PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_6784_0_480.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_6784_0_414.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_6784_0_348.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_6305_1_568.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_6305_0_568.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_6305_1_500.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_6305_0_500.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_6305_1_432.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_6305_0_432.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_5559_1_253.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_5559_0_253.npy'),
 PosixPath('preprocessed_clips_3/RHMP/clip_task_kam2_gh108416_5559_1_211.npy'),
 PosixPath('preprocessed_clips_3/RHMP/cl

In [None]:
import time
from collections import defaultdict

BATCH_SIZE = 1500  # ~3.6GB per batch
CLIP_OFFSETS = range(-8, 8)  # -8 to +7 around the annotated frame -> 16 frames
PROFILE_EVERY = 10  # print the profiling summary every N processed clips

clip_batch = []
batch_metadata = []  # track (split, label, clip_number) for each clip

# Profiling accumulators (total seconds / number of measurements). They are
# initialised here so the cell runs on a fresh kernel instead of relying on
# values left behind by another cell.
prof_seek_decode = 0.0
prof_seek_decode_n = 0
prof_preprocess = 0.0
prof_preprocess_n = 0
prof_flush = 0.0
prof_flush_n = 0


def release_clip_frames(frames_dict, frame_usage_count, clip_frames):
    """Mark one clip as done with `clip_frames` and evict frames nobody else needs.

    Each frame's remaining-use counter is decremented; once it reaches zero no
    later clip of the current video references that frame, so it is dropped
    from the decode cache to keep memory bounded.
    """
    for frame_number in clip_frames:
        frame_usage_count[frame_number] -= 1
        if frame_usage_count[frame_number] <= 0:
            frames_dict.pop(frame_number, None)


count = 0
for video_path, clips in data.items():

    # Count how many clips of this video need each frame, so cached frames can
    # be released as soon as their last consumer has been handled.
    frame_usage_count = defaultdict(int)
    for clip_data in clips:
        for offset in CLIP_OFFSETS:
            frame_usage_count[clip_data["frame"] + offset] += 1

    container = av.open(video_path)
    frames_dict = {}  # frame number -> decoded frame (cache shared by overlapping clips)

    try:
        # Now process all clips from this video
        for clip_data in clips:
            frame_index = clip_data["frame"]
            label = clip_data["label"]

            # -------------------------------
            # 1. SEEK + DECODE (missing frames only)
            # -------------------------------
            needed_frames = [frame_index + o for o in CLIP_OFFSETS]
            missing_frames = [f for f in needed_frames if f not in frames_dict]

            if missing_frames:
                start = min(missing_frames)
                end = max(missing_frames)

                if start < 0:
                    # The clip window starts before the first frame: skip it.
                    release_clip_frames(frames_dict, frame_usage_count, needed_frames)
                    continue

                t0 = time.perf_counter()
                new_frames = seek_and_decode_range(container, start, end)
                t1 = time.perf_counter()

                prof_seek_decode += (t1 - t0)
                prof_seek_decode_n += 1

                frames_dict.update(new_frames)

            try:
                imgs = [frames_dict[f] for f in needed_frames]
            except KeyError:
                # Decoding did not yield every requested frame; skip this clip.
                print(f"Missing frames for {video_path} {frame_index}")
                release_clip_frames(frames_dict, frame_usage_count, needed_frames)
                continue

            # `imgs` holds its own references, so the cache entries that no
            # later clip needs can be dropped right away.
            release_clip_frames(frames_dict, frame_usage_count, needed_frames)

            # -------------------------------
            # 2. PREPROCESS (pad + resize + stack)
            # -------------------------------
            t0 = time.perf_counter()

            imgs = [pad_image(img) for img in imgs]
            pil_imgs = [Image.fromarray(img) for img in imgs]
            pil_imgs_resized = [img.resize((SIZE, SIZE), Image.BILINEAR) for img in pil_imgs]
            resized_frames = [np.array(img) for img in pil_imgs_resized]

            clip = np.stack(resized_frames).astype(np.uint8)

            t1 = time.perf_counter()

            prof_preprocess += (t1 - t0)
            prof_preprocess_n += 1

            # -------------------------------
            # 3. BATCH FLUSH (only when needed)
            # -------------------------------
            clip_batch.append(clip)
            batch_metadata.append(("train", label, count))

            if len(clip_batch) >= BATCH_SIZE:
                t0 = time.perf_counter()
                flush_batch(clip_batch, batch_metadata)
                t1 = time.perf_counter()

                prof_flush += (t1 - t0)
                prof_flush_n += 1

                clip_batch = []
                batch_metadata = []

            count += 1

            # -------------------------------
            # PRINT EVERY `PROFILE_EVERY` clips
            # -------------------------------
            if count % PROFILE_EVERY == 0:
                print("\n--- Profiling summary ---")
                if prof_seek_decode_n > 0:
                    print(f"Seek/Decode avg: {prof_seek_decode / prof_seek_decode_n:.4f} sec")
                if prof_preprocess_n > 0:
                    print(f"Preprocess avg:  {prof_preprocess / prof_preprocess_n:.4f} sec")
                if prof_flush_n > 0:
                    print(f"Batch flush avg: {prof_flush / prof_flush_n:.4f} sec")
                print(f"Processed {count} clips\n")
    finally:
        # Free memory and release the video file even if a clip raised.
        del frames_dict
        container.close()

# NOTE: clips left in `clip_batch` (fewer than BATCH_SIZE) are not written here;
# they still need a final flush_batch(clip_batch, batch_metadata) call.
    

In [None]:
def find_greatest_clip_num(train_dir="preprocessed_clips/train/"):
    """Return the largest clip number among the saved ``.npy`` clips.

    The clip number is the text after the last underscore of the file name,
    without the ``.npy`` extension (``..._<number>.npy``).

    Args:
        train_dir: Folder containing one subfolder per label.

    Returns:
        The highest clip number found, or 0 when the folder does not exist or
        holds no ``.npy`` files.

    Raises:
        ValueError: If a ``.npy`` file name does not end in an integer.
    """
    root = Path(train_dir)
    if not root.is_dir():
        return 0

    greatest = 0
    for label_dir in root.iterdir():
        if not label_dir.is_dir():
            continue  # stray files next to the label folders
        for clip_file in label_dir.iterdir():
            if clip_file.suffix != ".npy":
                continue  # e.g. .DS_Store or checkpoint folders
            # `.stem` removes exactly the extension; str.strip(".npy") would
            # instead strip any of the characters '.', 'n', 'p', 'y' from both ends.
            clip_num = int(clip_file.stem.split("_")[-1])
            greatest = max(greatest, clip_num)
    return greatest


greatest_clip_num = find_greatest_clip_num("preprocessed_clips/train/")
greatest_clip_num

In [None]:
# Write out the clips still waiting in the batch (fewer than a full batch).
# The lists are reset afterwards, as after every other flush, so the flushed
# clips are released from memory and re-running this cell does nothing.
if clip_batch:
    flush_batch(clip_batch, batch_metadata)
    clip_batch = []
    batch_metadata = []


In [None]:
import av
import numpy as np
import matplotlib.patches as patches
from matplotlib import pyplot as plt


def draw_box_outline(img, bbox, color=(255, 0, 0), thickness=2):
    """Draw a rectangle outline onto ``img`` in place and return it.

    Args:
        img: Image array indexed as ``img[row, col]`` with a trailing colour
            axis matching ``color``.
        bbox: ``(x1, y1, x2, y2)`` corner coordinates; each is truncated with ``int``.
        color: Value written into the outline pixels (red by default).
        thickness: Outline width in pixels, measured inward from the box edges.
    """
    x1, y1, x2, y2 = (int(v) for v in bbox)
    img[y1:y1 + thickness, x1:x2] = color  # top edge
    img[y2 - thickness:y2, x1:x2] = color  # bottom edge
    img[y1:y2, x1:x1 + thickness] = color  # left edge
    img[y1:y2, x2 - thickness:x2] = color  # right edge
    return img


# Video path and frame range
video_path = "/workspace/facts/Olympic Boxing Punch Classification Video Dataset/task_kam2_gh078416/data/GH078416.mp4"
start_frame = 653 - 8  # 645
end_frame = 658 + 8    # 666

# Bounding boxes data - map frame number to [x1, y1, x2, y2].
# Frame 654 has no entry, so it is written to the video without a box.
bboxes = {
    653: [1062.2001953125, 421.16015625, 1205.1201171875, 496.8000000000011],
    655: [1067.0009765625, 425.9599609375, 1209.9208984375, 501.5998046875011],
    656: [1071.8017578125, 424.759765625, 1214.7216796875, 500.3996093750011],
    657: [1076.6015625, 424.759765625, 1219.521484375, 513.6000000000004],
    658: [1082.6015625, 424.759765625, 1225.521484375, 528.0]
}

# Extract frames using seek_and_decode_range; close the container even if
# decoding raises.
container = av.open(video_path)
try:
    frames_dict = seek_and_decode_range(container, start_frame, end_frame)
finally:
    container.close()

# Draw bounding boxes on frames that have them
frames_with_boxes = []
for frame_num in range(start_frame, end_frame + 1):
    if frame_num not in frames_dict:
        continue  # frame was not decoded; leave it out of the output video

    # Copy so the decoded frame itself is not modified by the drawing.
    img = frames_dict[frame_num].copy()
    if frame_num in bboxes:
        draw_box_outline(img, bboxes[frame_num])
    frames_with_boxes.append(img)

# Save to video
save_frames_to_mp4(frames_with_boxes, 'clip_with_boxes.mp4', fps=30)
print(f"Saved clip_with_boxes.mp4 with {len(frames_with_boxes)} frames")