# Label Shot Segmentation

In the initial data processing notebook, the tennis matches were split into point segments.  
However, the downstream goals have shifted and the tennis matches and its annotations should now be split into shot segments.  
This is a patch notebook that takes care of mapping the point segments labels to shot segment labels.

In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

### Step X - Map tennis point segments to shot segments

In [12]:
# Import libraries
import os
import json
import cv2
import numpy as np
from tqdm import tqdm

In [13]:
# Constants
from __init__ import data_path

# Choose dataset
dataset = "tenniset"
dataset_path = os.path.join(data_path, dataset)
videos_path = os.path.join(dataset_path, "videos")
annotations_path = os.path.join(dataset_path, "annotations")
segments_path = os.path.join(dataset_path, "segments")
labels_path = os.path.join(dataset_path, "labels")

# Read videos
videos = sorted(os.listdir(videos_path))
print("___VIDEOS___")
for video in videos: 
    print(video)

___VIDEOS___
V006.mp4
V007.mp4
V008.mp4
V009.mp4
V010.mp4


In [14]:
# Create new directories for the shot segments and shot labels
shot_segments_path = os.path.join(dataset_path, "shot_segments")
shot_labels_path = os.path.join(dataset_path, "shot_labels")
os.makedirs(shot_segments_path, exist_ok=True)
os.makedirs(shot_labels_path, exist_ok=True)

In [74]:
def frame_to_timestamp(frame_nr, fps=25):
    timestamp_seconds = frame_nr / fps
    timestamp_minutes = int(timestamp_seconds / 60)
    timestamp_seconds = int(timestamp_seconds % 60)
    return f"{timestamp_minutes:02d}:{timestamp_seconds:02d}"

In [78]:
from data_utils import read_segment_frames, read_segment_labels

In [112]:
missing_splits = {}
for video in videos:
    missing_splits[video] = []
    print(f"Processing video: {video}")
    # Video & annotation path
    video_name, video_ext = os.path.splitext(video)
    annotation_path = os.path.join(annotations_path, video.replace(video_ext, ".json"))

    # Load annotation
    with open(annotation_path) as annotation_file:
        annotation = json.load(annotation_file)
    
    # Load point segment split points
    point_splits = annotation["classes"]["Point"]
    print(f" - Number of point splits: {len(point_splits)}")
    point_split_start_frames = np.array([split["start"] for split in point_splits])
    point_split_end_frames = np.array([split["end"] for split in point_splits])

    # Load hit segments
    set_splits = annotation["classes"]["Set"]
    hit_splits = annotation["classes"]["Hit"]
    serve_splits = annotation["classes"]["Serve"]
    hits_and_serve_splits = sorted(hit_splits + serve_splits, key=lambda x: x["start"])
    # Process hit segments
    print(f" - Number of hit and serve splits: {len(hits_and_serve_splits)}")
    for split_index, split in enumerate(tqdm(hits_and_serve_splits)):
        # Parse split information
        split_name = f"{split_index}".zfill(4)
        split_name = f"{video_name}_{split_name}"
        start_frame = split["start"]
        end_frame = split["end"]

        # Figure out which player is performing the shot
        split_set = [set_split for set_split in set_splits if set_split["start"] <= start_frame and set_split["end"] >= end_frame]
        assert len(split_set) == 1, f"Hit segment {split_index} is part of {len(split_set)} set segments"
        split_set = split_set[0]
        near_player = split_set["custom"]["Near"]
        split_player = split["custom"]["Player"]
        split_info = {
            "player": split_player,
            "player_is_near":  near_player == split_player,
        }

        # Find point segment that this hit segment is part of
        mask_1 = start_frame >= point_split_start_frames
        mask_2 = end_frame <= point_split_end_frames
        mask = np.logical_and(mask_1, mask_2)
        indx = np.argwhere(mask).reshape(-1)

        # Handle case where both segments are not fully aligned
        if len(indx) == 0:
            #print(f"WARNING: Hit segment {split_index} not fully contained in any point segment, marking missing")
            missing_splits[video].append(split_index)
            continue

        # Make sure everythin makes sense
        assert len(indx) > 0, f"Hit segment {split_index} is part of no known point segment"
        assert len(indx) == 1, f"Hit segment {split_index} is part of multiple point segments"
        indx = indx[0]
        point_start_frame = point_split_start_frames[indx]
        point_end_frame = point_split_end_frames[indx]
        assert point_start_frame <= start_frame and point_end_frame >= end_frame, f"Hit segment {split_index} is not fully contained in point segment {indx}"

        # Load frames for the point segments
        point_split_name = f"{indx}".zfill(4)
        point_segment_path = os.path.join(segments_path, f"{video_name}_{point_split_name}.mp4")
        point_frames, fps = read_segment_frames(point_segment_path, labels_path=labels_path, load_valid_frames_only=False)
        assert point_end_frame - point_start_frame == len(point_frames), f"Unexpected number of point segment frames"
        
        # Load labels for the point segment
        # TODO: Load 3D Annotations
        (
            frame_validity,
            court_sequence,
            ball_sequence,
            player_btm_bbox_pose_sequence,
            player_top_bbox_pose_sequence,
            player_btm_pose_sequence,
            player_top_pose_sequence,
        ) = read_segment_labels(
            point_segment_path,
            labels_path=labels_path,
            load_frame_validity=True,
            load_court=True,
            load_ball=True,
            load_player_bbox=True,
            load_player_pose=True,
            use_pose_bbox=True,
        )
        (
            _,
            _,
            _,
            player_btm_bbox_sequence,
            player_top_bbox_sequence,
            _,
            _,
        ) = read_segment_labels(
            point_segment_path,
            labels_path=labels_path,
            load_frame_validity=False,
            load_court=False,
            load_ball=False,
            load_player_bbox=True,
            load_player_pose=False,
            use_pose_bbox=False,
        )
        
        # Extract frames, frame validity for the hit segment
        start_index = int(start_frame - point_start_frame)
        end_index = int(end_frame - point_start_frame)
        split_frames = point_frames[start_index:end_index]
        split_frame_validity = frame_validity[start_index:end_index]

        # Extract the labels for the valid frames
        start_label_index = np.count_nonzero(frame_validity[:start_index])
        end_label_index = np.count_nonzero(frame_validity[:end_index])
        split_court_sequence = court_sequence[start_label_index:end_label_index]
        split_ball_sequence = ball_sequence[start_label_index:end_label_index]
        split_player_btm_bbox_sequence = player_btm_bbox_sequence[start_label_index:end_label_index]
        split_player_top_bbox_sequence = player_top_bbox_sequence[start_label_index:end_label_index]
        split_player_btm_bbox_pose_sequence = player_btm_bbox_pose_sequence[start_label_index:end_label_index]
        split_player_top_bbox_pose_sequence = player_top_bbox_pose_sequence[start_label_index:end_label_index]
        split_player_btm_pose_sequence = player_btm_pose_sequence[start_label_index:end_label_index]
        split_player_top_pose_sequence = player_top_pose_sequence[start_label_index:end_label_index]
        # TODO: Add 3D Annotations

        # Sanity checks
        assert np.count_nonzero(split_frame_validity) == len(split_court_sequence), f"Extracted court sequence has wrong length"
        assert np.count_nonzero(split_frame_validity) == len(split_ball_sequence), f"Extracted ball sequence has wrong length"
        assert np.count_nonzero(split_frame_validity) == len(split_player_btm_bbox_sequence), f"Extracted bottom player bbox sequence has wrong length"
        assert np.count_nonzero(split_frame_validity) == len(split_player_top_bbox_sequence), f"Extracted top player bbox sequence has wrong length"
        assert np.count_nonzero(split_frame_validity) == len(split_player_btm_bbox_pose_sequence), f"Extracted bottom player bbox pose sequence has wrong length"
        assert np.count_nonzero(split_frame_validity) == len(split_player_top_bbox_pose_sequence), f"Extracted top player bbox pose sequence has wrong length"
        assert np.count_nonzero(split_frame_validity) == len(split_player_btm_pose_sequence), f"Extracted bottom player pose sequence has wrong length"
        assert np.count_nonzero(split_frame_validity) == len(split_player_top_pose_sequence), f"Extracted top player pose sequence has wrong length"
        # TODO: Add 3D Annotations

        # Save split info
        with open(os.path.join(shot_labels_path, f"{split_name}_info.json"), "w") as f:
            json.dump(split_info, f)

        # Save video
        frame_height, frame_width, _ = split_frames[0].shape
        split_video_path = os.path.join(shot_segments_path, f"{split_name}.mp4")
        writer = cv2.VideoWriter(split_video_path, cv2.VideoWriter_fourcc(*"avc1"), fps, (frame_width, frame_height))
        for frame in split_frames:
            writer.write(frame)
        writer.release()

        # Save labels
        np.save(os.path.join(shot_labels_path, f"{split_name}_frame_validity.npy"), split_frame_validity)
        np.save(os.path.join(shot_labels_path, f"{split_name}_court.npy"), split_court_sequence)
        np.save(os.path.join(shot_labels_path, f"{split_name}_ball.npy"), split_ball_sequence)
        np.save(os.path.join(shot_labels_path, f"{split_name}_player_btm_bbox.npy"), split_player_btm_bbox_sequence)
        np.save(os.path.join(shot_labels_path, f"{split_name}_player_top_bbox.npy"), split_player_top_bbox_sequence)
        np.save(os.path.join(shot_labels_path, f"{split_name}_player_btm_bbox_pose.npy"), split_player_btm_bbox_pose_sequence)
        np.save(os.path.join(shot_labels_path, f"{split_name}_player_top_bbox_pose.npy"), split_player_top_bbox_pose_sequence)
        np.save(os.path.join(shot_labels_path, f"{split_name}_player_btm_pose.npy"), split_player_btm_pose_sequence)
        np.save(os.path.join(shot_labels_path, f"{split_name}_player_top_pose.npy"), split_player_top_pose_sequence)
        
        
        if split_index > 20:
            break


        #print(point_split_name)
        
    print()
    break


print("Following split indices are missing")
for key, value in missing_splits.items():
    print(f"{key}: {len(value)} - {value}")

Processing video: V006.mp4
 - Number of point splits: 81
 - Number of hit and serve splits: 332


  6%|▋         | 21/332 [00:13<03:25,  1.51it/s]


Following split indices are missing
V006.mp4: 3 - [6, 8, 15]





In [123]:
# Extract segment videos from full 
for video, missing_indxs in missing_splits.items():
    print(f"Processing video: {video}")
    # Video & annotation path
    video_name, video_ext = os.path.splitext(video)
    annotation_path = os.path.join(annotations_path, video.replace(video_ext, ".json"))
    video_path = os.path.join(videos_path, video)

    # Load annotation
    with open(annotation_path) as annotation_file:
        annotation = json.load(annotation_file)

    # Load video
    capture = cv2.VideoCapture(video_path)
    frame = 0

    # Get resolution & framerate from capture
    frame_width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = capture.get(cv2.CAP_PROP_FPS)
    
    # Load point segment split points
    point_splits = annotation["classes"]["Point"]
    print(f" - Number of point splits: {len(point_splits)}")
    point_split_start_frames = np.array([split["start"] for split in point_splits])
    point_split_end_frames = np.array([split["end"] for split in point_splits])

    # Load hit segments
    set_splits = annotation["classes"]["Set"]
    hit_splits = annotation["classes"]["Hit"]
    serve_splits = annotation["classes"]["Serve"]
    hits_and_serve_splits = sorted(hit_splits + serve_splits, key=lambda x: x["start"])

    # Process hit segments
    print(f" - Number of hit and serve splits: {len(hits_and_serve_splits)}")
    for indx in tqdm(missing_indxs):
        # Parse split information
        split_name = f"{indx}".zfill(4)
        split_name = f"{video_name}_{split_name}"
        split = hits_and_serve_splits[indx]
        start, end = int(split["start"]), int(split["end"])

        # Figure out which player is performing the shot
        split_set = [set_split for set_split in set_splits if set_split["start"] <= start_frame and set_split["end"] >= end_frame]
        assert len(split_set) == 1, f"Hit segment {split_index} is part of {len(split_set)} set segments"
        split_set = split_set[0]
        near_player = split_set["custom"]["Near"]
        split_player = split["custom"]["Player"]
        split_info = {
            "player": split_player,
            "player_is_near":  near_player == split_player,
        }

        # Save split info
        with open(os.path.join(shot_labels_path, f"{split_name}_info.json"), "w") as f:
            json.dump(split_info, f)

        # Fast forward to start of segment
        frame = start
        capture.set(1, frame)

        # Open writer
        split_video_path = os.path.join(shot_segments_path, f"{split_name}.mp4")
        writer = cv2.VideoWriter(split_video_path, cv2.VideoWriter_fourcc(*"avc1"), fps, (frame_width, frame_height))

        # Save frames to segment
        while_safety = 0
        max_while_safety = 500
        while frame < end:
            # Read frame
            ret, img = capture.read()

            # Sometimes OpenCV reads None's during a video, in which case we want to just skip
            assert while_safety < max_while_safety, f"ERROR, cv2 read {max_while_safety} Nones"
            if ret == 0 or img is None: 
                while_safety += 1
                continue 
            while_safety = 0

            # Write frame
            writer.write(img)

            # Increase frame counter
            frame += 1

        # Release writer
        writer.release()
    
    capture.release()

Processing video: V006.mp4
 - Number of point splits: 81
 - Number of hit and serve splits: 332
{'custom': {'Player': 'Williams', 'Result': 'Fault'}, 'name': '0004', 'start': 27708, 'desc': '', 'end': 27779}
/home/florsanders/Code/columbia_university/advanced_deep_learning/adl_ai_tennis_coach/data/tenniset/shot_segments/V006_0006.mp4
{'custom': {'Player': 'Williams', 'Result': 'Fault'}, 'name': '0004', 'start': 27708, 'desc': '', 'end': 27779}
/home/florsanders/Code/columbia_university/advanced_deep_learning/adl_ai_tennis_coach/data/tenniset/shot_segments/V006_0008.mp4
{'custom': {'Player': 'Williams', 'Result': 'Fault'}, 'name': '0006', 'start': 29302, 'desc': '', 'end': 29365}
/home/florsanders/Code/columbia_university/advanced_deep_learning/adl_ai_tennis_coach/data/tenniset/shot_segments/V006_0015.mp4
