In [1]:
import cv2
import numpy as np
import torch
import json
from detectron2.config import get_cfg
from detectron2.engine import DefaultPredictor
from detectron2 import model_zoo
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog
from detectron2.structures import Instances
import os
import glob

In [2]:
# Step 1: Setup Detectron2
def setup_detectron2(score_thresh=0.5,
                     config_name="COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x.yaml"):
    """
    Build a detectron2 model-zoo predictor (Keypoint R-CNN by default).

    Args:
        score_thresh: minimum detection score kept at inference time
            (``MODEL.ROI_HEADS.SCORE_THRESH_TEST``); defaults to 0.5.
        config_name: model-zoo config path. Its config is merged and its
            pretrained checkpoint is used as the model weights.

    Returns:
        (predictor, metadata): a ``DefaultPredictor`` running on CUDA when
        available, otherwise CPU, and the ``MetadataCatalog`` entry of the
        config's first training dataset (used by ``Visualizer`` for drawing).
    """
    cfg = get_cfg()
    cfg.merge_from_file(model_zoo.get_config_file(config_name))
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = score_thresh
    cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url(config_name)
    cfg.MODEL.DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    predictor = DefaultPredictor(cfg)
    metadata = MetadataCatalog.get(cfg.DATASETS.TRAIN[0])
    return predictor, metadata

In [3]:
def select_roi(video_path, roi_dir="rois"):
    """
    Return the player ROI ``(x, y, w, h)`` for ``video_path``, prompting if needed.

    A previously saved ROI in ``<roi_dir>/<video stem>_roi.json`` is reused.
    Otherwise the first frame of the video is shown in an OpenCV window for
    manual selection, and the result is saved to that JSON file for later runs.

    ``cv2.selectROI`` returns a zero-size rectangle when the selection is
    cancelled. Such a rectangle is never saved, and a previously saved
    zero-size ROI is ignored so the user is prompted again.

    Raises:
        RuntimeError: if the first frame cannot be read or the selection is empty.
    """
    os.makedirs(roi_dir, exist_ok=True)
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    roi_path = os.path.join(roi_dir, f"{video_name}_roi.json")

    # Try to load a saved ROI; skip it if it has no area.
    if os.path.exists(roi_path):
        with open(roi_path, "r") as f:
            roi = tuple(json.load(f)["roi"])
        if roi[2] > 0 and roi[3] > 0:
            print(f"Loaded saved ROI from: {roi_path}")
            return roi
        print(f"Ignoring empty saved ROI in: {roi_path}")

    # No usable saved ROI: ask for manual selection on the first frame.
    cap = cv2.VideoCapture(video_path)
    ret, frame = cap.read()
    cap.release()
    if not ret:
        raise RuntimeError(f"Failed to read video: {video_path}")

    bbox = tuple(cv2.selectROI("Select Player Area", frame, fromCenter=False, showCrosshair=True))
    cv2.destroyWindow("Select Player Area")

    # A cancelled selection comes back as (0, 0, 0, 0). Saving it would make
    # every later run track whoever is closest to the image's top-left corner.
    if bbox[2] <= 0 or bbox[3] <= 0:
        raise RuntimeError(f"Empty ROI selected for {video_path}; not saved")

    with open(roi_path, "w") as f:
        json.dump({"roi": list(bbox)}, f)
    print(f"Saved ROI to: {roi_path}")

    return bbox

In [4]:
def load_roi(video_path, roi_dir="rois"):
    """
    Read the ROI that ``select_roi`` saved for ``video_path``.

    Looks for ``<roi_dir>/<video stem>_roi.json``. Does not prompt and does
    not create ``roi_dir``.

    Returns:
        A numpy array built from the file's ``"roi"`` list, or None when the
        file does not exist.
    """
    stem, _ext = os.path.splitext(os.path.basename(video_path))
    roi_file = os.path.join(roi_dir, stem + "_roi.json")

    if not os.path.exists(roi_file):
        return None

    with open(roi_file, "r") as fh:
        return np.array(json.load(fh)["roi"])


In [5]:
# Step 3: Score function based on proximity to selected ROI
def score_by_proximity(box, selected_box):
    """
    Score a detection by how close its centre is to the selected ROI's centre.

    Args:
        box: detection box as ``(x1, y1, x2, y2)`` corners.
        selected_box: ROI as ``(x, y, w, h)``, as saved by ``select_roi``.

    Returns:
        The negated Euclidean distance between the two centres, so the
        closest box gets the highest score (0 when the centres coincide).
    """
    roi_x, roi_y, roi_w, roi_h = selected_box
    left, top, right, bottom = box

    target_x = roi_x + roi_w / 2
    target_y = roi_y + roi_h / 2
    centre_x = (left + right) / 2
    centre_y = (top + bottom) / 2

    dx = centre_x - target_x
    dy = centre_y - target_y
    distance = np.sqrt(dx ** 2 + dy ** 2)
    return -distance

In [6]:
# Step 4: Main pipeline
def _render_selected(frame, instances, idx, metadata):
    """Return a BGR copy of ``frame`` with only detection ``idx`` of ``instances`` drawn."""
    # Indexing an Instances with an int yields a length-1 Instances that keeps
    # every field (boxes, keypoints, scores, classes) of that one detection.
    single_instance = instances[idx]
    # Visualizer works in RGB while OpenCV frames are BGR, hence the channel flips.
    vis = Visualizer(frame[:, :, ::-1], metadata=metadata, scale=1.0)
    return vis.draw_instance_predictions(single_instance).get_image()[:, :, ::-1]


def extract_keypoints_from_roi(input_video_path, output_npz_path, output_video_path):
    """
    Follow the ROI-selected person through a video and save their 2-D keypoints.

    In every frame with detections, the detection whose box centre is closest
    to the saved ROI's centre (see ``score_by_proximity``) is chosen. Its
    keypoint (x, y) coordinates are collected, and only that person is drawn
    on the output video.

    Args:
        input_video_path: video to process. Its ROI must already be saved
            (see ``select_roi`` / ``load_roi``).
        output_npz_path: destination ``.npz``. It holds a ``keypoints`` array
            with one (num_keypoints, 2) entry per frame that had detections.
        output_video_path: destination annotated video (mp4v codec).

    NOTE: frames with no detections are written unannotated and add no entry
    to ``keypoints``, so its rows are not aligned 1:1 with video frames.

    If no ROI is saved or the input video cannot be opened, a message is
    printed and nothing is written. The ``.npz`` is written last, after the
    video is finalised.

    Raises:
        IOError: if the output video writer cannot be opened.
    """
    selected_box = load_roi(input_video_path)
    if selected_box is None:
        # Without an ROI there is nothing to score detections against; stop
        # here instead of crashing inside score_by_proximity on the first frame.
        print(f"no ROI found for {input_video_path}")
        return

    cap = cv2.VideoCapture(input_video_path)
    out_vid = None
    all_keypoints = []
    try:
        if not cap.isOpened():
            print(f"could not open {input_video_path}")
            return

        # Only load the model once we know there is work to do.
        predictor, metadata = setup_detectron2()

        # The writer fails silently (frames dropped) if its directory is missing.
        for out_path in (output_npz_path, output_video_path):
            out_dir = os.path.dirname(out_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)

        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        out_vid = cv2.VideoWriter(output_video_path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
        if not out_vid.isOpened():
            raise IOError(f"could not open video writer for {output_video_path}")

        while True:
            ret, frame = cap.read()
            if not ret:
                break

            instances = predictor(frame)["instances"].to("cpu")
            if len(instances) == 0:
                out_vid.write(frame)
                continue

            boxes = instances.pred_boxes.tensor.numpy()
            keypoints = instances.pred_keypoints.numpy()

            scores = [score_by_proximity(box, selected_box) for box in boxes]
            best_idx = int(np.argmax(scores))
            all_keypoints.append(keypoints[best_idx][:, :2])  # drop confidence column

            out_vid.write(_render_selected(frame, instances, best_idx, metadata))
    finally:
        # Release even if inference or drawing raises part-way through.
        cap.release()
        if out_vid is not None:
            out_vid.release()

    if all_keypoints:
        keypoints_arr = np.array(all_keypoints)
    else:
        # Keep a 3-D shape so downstream [:, :, i] indexing still works when
        # no frame had a detection (np.array([]) would be shape (0,)).
        num_kpts = predictor.cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS
        keypoints_arr = np.empty((0, num_kpts, 2), dtype=np.float32)

    np.savez_compressed(output_npz_path, keypoints=keypoints_arr)
    print(f"Saved keypoints to: {output_npz_path}")
    print(f"Saved video to: {output_video_path}")


In [7]:
# ROI selection: make sure every input video has a saved player ROI.
input_folder = 'inputdir'

# Sorted so the interactive prompts appear in a reproducible order
# (glob order depends on the file system).
video_files = sorted(glob.glob(os.path.join(input_folder, '*.mp4')))

for video_path in video_files:
    # Reuses the cached ROI when present; otherwise opens a selection window.
    select_roi(video_path)

Loaded saved ROI from: rois\Clip101Miss_roi.json
Loaded saved ROI from: rois\Clip102Miss_roi.json
Loaded saved ROI from: rois\Clip103Miss_roi.json
Loaded saved ROI from: rois\Clip107Miss_roi.json
Loaded saved ROI from: rois\Clip108Miss_roi.json
Loaded saved ROI from: rois\Clip109Miss_roi.json
Loaded saved ROI from: rois\Clip110Miss_roi.json
Loaded saved ROI from: rois\Clip11Hit_roi.json
Loaded saved ROI from: rois\Clip12Hit_roi.json
Loaded saved ROI from: rois\Clip14Miss_roi.json
Loaded saved ROI from: rois\Clip16Hit_roi.json
Loaded saved ROI from: rois\Clip17Hit_roi.json
Loaded saved ROI from: rois\Clip20Miss_roi.json
Loaded saved ROI from: rois\Clip21Hit_roi.json
Loaded saved ROI from: rois\Clip22Hit_roi.json
Loaded saved ROI from: rois\Clip25Miss_roi.json
Loaded saved ROI from: rois\Clip26Hit_roi.json
Loaded saved ROI from: rois\Clip27Hit_roi.json
Loaded saved ROI from: rois\Clip28Miss_roi.json
Loaded saved ROI from: rois\Clip2Hit_roi.json
Loaded saved ROI from: rois\Clip30Miss_roi.

In [8]:
# KeyPoint Detection: run the extraction pipeline on every unprocessed clip.
input_folder = 'inputdir'
output_folder_data = 'data/custom'
output_folder_video = 'outputdir'

os.makedirs(output_folder_data, exist_ok=True)
os.makedirs(output_folder_video, exist_ok=True)

# Sorted so clips are processed in a reproducible order.
video_files = sorted(glob.glob(os.path.join(input_folder, '*.mp4')))

for video_path in video_files:
    video_name = os.path.splitext(os.path.basename(video_path))[0]  # e.g., "Clip14Miss"
    output_data = os.path.join(output_folder_data, f'{video_name}.npz')
    output_video = os.path.join(output_folder_video, f'New{video_name}.mp4')

    # extract_keypoints_from_roi writes the .npz last, after the video has been
    # released, so the .npz alone marks a finished clip. Also checking the video
    # would permanently skip clips whose earlier run was interrupted, because
    # cv2.VideoWriter creates the .mp4 as soon as it is opened.
    if os.path.exists(output_data):
        print(f"Skipping {video_name} — already processed.")
        continue

    print(f"Processing {video_name}...")
    extract_keypoints_from_roi(video_path, output_data, output_video)

Skipping Clip101Miss — already processed.
Skipping Clip102Miss — already processed.
Skipping Clip103Miss — already processed.
Skipping Clip107Miss — already processed.
Skipping Clip108Miss — already processed.
Skipping Clip109Miss — already processed.
Skipping Clip110Miss — already processed.
Skipping Clip11Hit — already processed.
Skipping Clip12Hit — already processed.
Skipping Clip14Miss — already processed.
Skipping Clip16Hit — already processed.
Skipping Clip17Hit — already processed.
Skipping Clip20Miss — already processed.
Skipping Clip21Hit — already processed.
Skipping Clip22Hit — already processed.
Skipping Clip25Miss — already processed.
Skipping Clip26Hit — already processed.
Skipping Clip27Hit — already processed.
Skipping Clip28Miss — already processed.
Skipping Clip2Hit — already processed.
Skipping Clip30Miss — already processed.
Skipping Clip32Hit — already processed.
Skipping Clip33Hit — already processed.
Skipping Clip35Miss — already processed.
Skipping Clip37Hit — 

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


Saved keypoints to: data/custom\Clip43Miss.npz
Saved video to: outputdir\NewClip43Miss.mp4
Processing Clip44Hit...
Saved keypoints to: data/custom\Clip44Hit.npz
Saved video to: outputdir\NewClip44Hit.mp4
Processing Clip46Hit...
Saved keypoints to: data/custom\Clip46Hit.npz
Saved video to: outputdir\NewClip46Hit.mp4
Processing Clip4Hit...
Saved keypoints to: data/custom\Clip4Hit.npz
Saved video to: outputdir\NewClip4Hit.mp4
Processing Clip50Hit...
Saved keypoints to: data/custom\Clip50Hit.npz
Saved video to: outputdir\NewClip50Hit.mp4
Processing Clip51Hit...
Saved keypoints to: data/custom\Clip51Hit.npz
Saved video to: outputdir\NewClip51Hit.mp4
Processing Clip53Miss...
Saved keypoints to: data/custom\Clip53Miss.npz
Saved video to: outputdir\NewClip53Miss.mp4
Processing Clip55Hit...
Saved keypoints to: data/custom\Clip55Hit.npz
Saved video to: outputdir\NewClip55Hit.mp4
Processing Clip56Hit...
Saved keypoints to: data/custom\Clip56Hit.npz
Saved video to: outputdir\NewClip56Hit.mp4
Proce