# 2D -> 3D Data Processing
Execution of this notebook is meant to follow `1_data_processing.ipynb`. 

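The cells below crop each player out of the segment frames using the labelled bounding boxes, run 3D pose inference on every crop, and save the resulting keypoints per player and segment.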

## Sample the data

In [2]:
import os
import numpy as np
from tqdm import tqdm
import json
import cv2

# Constants
from __init__ import data_path
from data_utils import read_segment_frames, read_segment_labels, visualize_frame_annotations, visualize_segment_labels

# Choose dataset
# The defaults are the original authors' locations. Set the TENNISET_PATH and
# POSE3D_WRITE_PATH environment variables to run the notebook on another machine
# without editing this cell.
dataset_path = os.environ.get("TENNISET_PATH", "/home/florsanders/adl_ai_tennis_coach/data/tenniset")
write_path = os.environ.get("POSE3D_WRITE_PATH", "/home/georgetamer/3d_poses")

# Sub-folders of the dataset holding the video segments and their label files
segments_path = os.path.join(dataset_path, "segments")
labels_path = os.path.join(dataset_path, "labels")

In [None]:
def crop_frame(frame,
               bbox,
               crop_padding=50,
               crop_img_width=256):
    """Cut a square region around a bounding box out of a frame and resize it.

    Args:
        frame: Image array indexed as ``frame[y, x]``; its first two
            dimensions are taken as (height, width).
        bbox: Sequence ``(x1, y1, x2, y2)`` in pixel coordinates of ``frame``,
            or ``None`` / a sequence containing ``None`` for a missing box.
        crop_padding: Margin in pixels added on each side of the longer
            bounding-box edge.
        crop_img_width: Side length in pixels of the returned square image.

    Returns:
        The crop resized to ``crop_img_width`` x ``crop_img_width``, or
        ``None`` when ``bbox`` is missing or the crop region is empty.
    """
    # A missing detection cannot be cropped; report it with None
    if bbox is None or np.any(np.asarray(bbox, dtype=object) == None):
        return None

    # Frame size
    frame_height, frame_width = frame.shape[:2]

    # Square window centred on the bbox, sized by its longer edge plus padding
    x1, y1, x2, y2 = bbox
    xc, yc = (x1 + x2) / 2, (y1 + y2) / 2
    w, h = abs(x2 - x1), abs(y2 - y1)
    d = max(w, h) + crop_padding * 2

    # Define cropping indices
    x_crop1, x_crop2 = int(xc - d/2), int(xc + d/2)
    y_crop1, y_crop2 = int(yc - d/2), int(yc + d/2)

    # Slide the window back inside the frame when it sticks out on one side
    x_crop_offset = min(frame_width - x_crop2, max(-x_crop1, 0))
    y_crop_offset = min(frame_height - y_crop2, max(-y_crop1, 0))
    x_crop1 += x_crop_offset
    x_crop2 += x_crop_offset
    y_crop1 += y_crop_offset
    y_crop2 += y_crop_offset

    # A window larger than the frame still sticks out after sliding. Clamp it,
    # because a negative start index would make numpy wrap around the array.
    x_crop1, x_crop2 = max(x_crop1, 0), min(x_crop2, frame_width)
    y_crop1, y_crop2 = max(y_crop1, 0), min(y_crop2, frame_height)

    # Crop image
    img = frame[y_crop1:y_crop2, x_crop1:x_crop2].copy()
    if img.size == 0:
        # Degenerate bbox (zero extent and no padding): nothing to resize
        return None

    # Resize to the fixed square input size
    return cv2.resize(img, (crop_img_width, crop_img_width))

def detect_pose(
    frame,
    bbox,
    crop_padding=50,
    crop_img_width=256
):
    """Run the pose inferencer on the crop of ``frame`` around ``bbox``.

    Args:
        frame: Full video frame (image array).
        bbox: Player bounding box ``(x1, y1, x2, y2)``, or ``None`` / a
            sequence containing ``None`` when no box is available.
        crop_padding: Padding in pixels forwarded to ``crop_frame``.
        crop_img_width: Side length in pixels of the square crop fed to the
            inferencer.

    Returns:
        ``np.ndarray`` built from the ``"keypoints"`` entry of the last item the
        inferencer reported, or ``None`` when the bbox is missing, the crop is
        empty, or the inferencer reported no items.
    """
    # No bounding box means there is nothing to crop or detect
    if bbox is None or np.any(np.asarray(bbox, dtype=object) == None):
        return None

    # Crop image (the caller's crop_padding is honoured; it used to be
    # overwritten with a hard-coded 10 here)
    cropped_frame = crop_frame(frame, bbox, crop_padding, crop_img_width)
    if cropped_frame is None:
        return None

    # Detect pose
    # NOTE(review): `inferencer` is not defined in any cell visible in this
    # notebook; it has to exist in the kernel before this function is called
    # (presumably an MMPose 3D inferencer - confirm and add its construction).
    result_generator = inferencer(cropped_frame, show=False, vis_out_dir=write_path, return_vis=True)
    results = list(result_generator)

    # Keep the keypoints of the last reported item.
    # TODO: when several people are detected, select one with a heuristic
    # (e.g. closest to the crop centre, or largest extent) instead.
    keypoints = None
    for result in results:
        for prediction in result["predictions"]:
            for item in prediction:
                keypoints = item["keypoints"]

    if keypoints is None:
        return None
    return np.array(keypoints)

In [None]:
def clean_bbox_sequence(
    bbox_sequence, 
    court_sequence, 
    is_btm,
    derivative_threshold=5000,
    make_plot=False,
):
    """Flag unreliable player bounding boxes and fill them by interpolation.

    A frame is suspicious when the player's foot point (bottom-centre of the
    bbox, measured relative to the centre of the near/far court edge) moves by
    more than ``derivative_threshold`` (squared pixel distance) between
    consecutive frames, or when the bbox area drops below half of the mean
    area. Suspicious frames are grouped into runs, and every flagged bbox is
    replaced by a linear interpolation between the nearest unflagged, non-None
    boxes on either side (constant fill when only one side exists).

    Args:
        bbox_sequence: Per-frame boxes ``(x1, y1, x2, y2)``; a missing box is
            ``None`` or contains ``None``.
        court_sequence: Per-frame court points; the first four are unpacked as
            top-left, top-right, bottom-left, bottom-right ``(x, y)`` corners.
        is_btm: ``True`` to reference the bottom court edge, ``False`` for the
            top edge.
        derivative_threshold: Threshold applied both to the squared foot-point
            displacement and to the absolute frame-to-frame area change.
        make_plot: Currently unused; kept for interface compatibility.

    Returns:
        Tuple ``(missing_mask, bbox_sequence_clean)``: a boolean array marking
        the flagged frames, and a copy of ``bbox_sequence`` with the flagged
        entries replaced where an anchor box exists.
    """
    n_frames = len(bbox_sequence)
    center_points = np.zeros((n_frames, 2))
    bbox_areas = np.zeros(n_frames)
    bbox_sequence_clean = np.copy(bbox_sequence)
    missing_points = np.zeros(n_frames, dtype=int)
    if n_frames == 0:
        return missing_points.astype(bool), bbox_sequence_clean

    # Extract the player's foot point relative to the court reference point
    for i, (bbox, court_points) in enumerate(zip(bbox_sequence, court_sequence)):
        # Skip no bounding box detected (area stays 0, position is infinite)
        if np.any(bbox == None):
            center_points[i, :] = np.inf
            continue
        xb1, yb1, xb2, yb2 = bbox
        bbox_areas[i] = np.abs((xb2 - xb1) * (yb2 - yb1))

        # Skip no court outline detected
        if court_points is None:
            center_points[i, :] = np.inf
            continue
        court_outline = court_points[:4]
        if np.any(court_outline == None):
            center_points[i, :] = np.inf
            continue

        # Get relevant center point of the court
        (xtl, ytl), (xtr, ytr), (xbl, ybl), (xbr, ybr) = court_outline
        x_ref = (xbl + xbr) / 2 if is_btm else (xtl + xtr) / 2
        y_ref = (ybl + ybr) / 2 if is_btm else (ytl + ytr) / 2

        # Get center point of the player's feet
        x_player = (xb1 + xb2) / 2
        y_player = yb2

        # Save player center point referenced to court point
        center_points[i, 0] = x_player - x_ref
        center_points[i, 1] = y_player - y_ref

    # Frame-to-frame changes. Missing frames hold inf, so inf - inf yields NaN
    # here; NaN never exceeds the threshold, hence the warning is silenced.
    with np.errstate(invalid="ignore"):
        center_step = np.vstack(([[0, 0]], center_points[:-1] - center_points[1:]))
        center_points_derivative = center_step[:, 0]**2 + center_step[:, 1]**2
        bbox_areas_derivative = np.abs(np.concatenate(([0], bbox_areas[:-1] - bbox_areas[1:])))

        # Reference area: mean over the frames before the first area jump
        bbox_area_jumps = np.sort(np.argwhere(bbox_areas_derivative > derivative_threshold).reshape(-1))
        if len(bbox_area_jumps):
            mean_area = np.mean(bbox_areas[:bbox_area_jumps[0]])
        else:
            mean_area = np.mean(bbox_areas)

        # Determine jump points
        jump_points = np.argwhere(np.logical_or(
            center_points_derivative > derivative_threshold,
            bbox_areas < mean_area / 2,
        )).reshape(-1)

    # Return if no cleaning needs to be done
    if len(jump_points) == 0:
        return missing_points.astype(bool), bbox_sequence_clean

    # Group jump points into runs of flagged frames. A jump away is expected
    # to be followed by a jump back; everything in between is flagged too.
    indx_last = None
    missing_start = False
    for indx in jump_points:
        if indx_last is None:
            # First suspicious frame opens a run
            missing_points[indx] = 1
            missing_start = True
        elif np.any(bbox_sequence[indx] == None):
            # Frame without a bbox
            missing_points[indx] = 1
            missing_start = False
        elif indx_last == indx - 1:
            # Directly follows another suspicious frame
            missing_points[indx] = 1
            missing_start = False
        else:
            # Gap since the previous suspicious frame
            if missing_start:
                # Treat this jump as the end of the open run (hopefully)
                missing_points[indx_last:indx+1] = 1
                missing_start = False
            else:
                # Treat this jump as the start of a new run (hopefully)
                missing_points[indx] = 1
                missing_start = True

        # Update last indx
        indx_last = indx

    missing_mask = missing_points.astype(bool)

    # Frames usable as interpolation anchors: not flagged and holding a bbox
    usable = np.array(
        [not missing_mask[i] and not np.any(bbox_sequence[i] == None) for i in range(n_frames)],
        dtype=bool,
    )
    usable_indices = np.flatnonzero(usable)
    if len(usable_indices) == 0:
        # Nothing trustworthy to interpolate from
        return missing_mask, bbox_sequence_clean

    # First/last index of every flagged run. Padding with zeros keeps starts
    # and ends paired even when a run touches the first or last frame.
    edges = np.diff(np.concatenate(([0], missing_points, [0])))
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1) - 1

    # Fill every run from its nearest usable neighbours
    for run_start, run_end in zip(run_starts, run_ends):
        earlier = usable_indices[usable_indices < run_start]
        later = usable_indices[usable_indices > run_end]
        run_indices = np.arange(run_start, run_end + 1)

        if len(earlier) and len(later):
            # Linear interpolation weighted by the distance to both anchors
            lo, hi = earlier[-1], later[0]
            lo_value = np.asarray(bbox_sequence[lo], dtype=float)
            hi_value = np.asarray(bbox_sequence[hi], dtype=float)
            weights = (run_indices - lo) / (hi - lo)
            filled = lo_value + weights[:, None] * (hi_value - lo_value)
        else:
            # Run touches the start or end of the sequence: hold the one anchor
            anchor = earlier[-1] if len(earlier) else later[0]
            anchor_value = np.asarray(bbox_sequence[anchor], dtype=float)
            filled = np.tile(anchor_value, (len(run_indices), 1))

        # Assign frame by frame so both (N, 4) arrays and 1-D object arrays work
        for row, frame_index in zip(filled, run_indices):
            bbox_sequence_clean[frame_index] = row

    return missing_mask, bbox_sequence_clean

In [None]:
def process_segment(
    segment_path, 
    labels_path=labels_path,
    crop_padding=50,
    crop_width=224,
):
    """Detect a pose for both players in every frame of a segment and save them.

    Frames are read with ``read_segment_frames`` and court / player boxes with
    ``read_segment_labels``. Each player's boxes are cleaned with
    ``clean_bbox_sequence``; frames flagged there fall back to alternative
    boxes until ``detect_pose`` returns a valid result. One ``.npy`` file per
    player is written to ``write_path``.

    Args:
        segment_path: Path of the segment video file.
        labels_path: Directory holding the label files of the dataset.
        crop_padding: Padding in pixels around the bbox, forwarded to
            ``detect_pose``.
        crop_width: Side length in pixels of the square crop, forwarded to
            ``detect_pose`` as ``crop_img_width``.

    Returns:
        ``False`` when the segment has no frames, ``True`` after the pose
        files have been written.
    """
    # Load frames
    segment_name = os.path.splitext(os.path.basename(segment_path))[0]
    frames, _fps = read_segment_frames(segment_path, labels_path=labels_path, load_valid_frames_only=True)
    if not len(frames):
        return False

    # Load labels
    (
        _,
        court_sequence,
        _,
        player_btm_bbox_sequence,
        player_top_bbox_sequence,
        _,
        _,
    ) = read_segment_labels(
        segment_path, 
        labels_path=labels_path,
        load_frame_validity=True,
        load_court=True,
        load_ball=False,
        load_player_bbox=True,
        load_player_pose=False,
        use_pose_bbox=True,
    )

    # Flag unreliable boxes and compute interpolated replacements
    btm_missing_points, btm_bbox_clean = clean_bbox_sequence(
        player_btm_bbox_sequence,
        court_sequence,
        is_btm=True,
        make_plot=True,
    )
    top_missing_points, top_bbox_clean = clean_bbox_sequence(
        player_top_bbox_sequence,
        court_sequence,
        is_btm=False,
        make_plot=True,
    )

    # Process frames (index 0 = top player, index 1 = bottom player)
    # NOTE(review): players_bbox_last is never updated, so the first fallback
    # candidate below is always skipped - confirm what it should track.
    players_bbox_last = [None, None]
    players_pose_sequences = [[None] * len(frames), [None] * len(frames)]
    for frame_index, frame in tqdm(enumerate(frames), total=len(frames)):
        # Get frame labels
        players_bbox = [player_top_bbox_sequence[frame_index], player_btm_bbox_sequence[frame_index]]
        players_bbox_clean = [top_bbox_clean[frame_index], btm_bbox_clean[frame_index]]
        players_missing = [top_missing_points[frame_index], btm_missing_points[frame_index]]

        # Perform pose detection
        for is_btm, bbox in enumerate(players_bbox):
            if players_missing[is_btm]:
                # Try to recover player pose from best knowledge
                bbox_candidates = [players_bbox_last[is_btm], players_bbox_clean[is_btm], bbox]
            else:
                bbox_candidates = [bbox]

            # Start from "no pose" so a frame without any usable bbox does not
            # inherit the result of the previous player or frame
            pose_keypoints = None
            for bbox_candidate in bbox_candidates:
                # Skip invalid bboxes (None itself or containing None)
                if bbox_candidate is None or np.any(bbox_candidate == None):
                    continue

                # Detect pose
                pose_keypoints = detect_pose(
                    frame, 
                    bbox_candidate, 
                    crop_padding=crop_padding, 
                    crop_img_width=crop_width,
                )

                # Stop at the first valid result
                if pose_keypoints is not None and not np.any(pose_keypoints == None):
                    break

            # Save pose
            players_pose_sequences[is_btm][frame_index] = pose_keypoints

    # Export labels
    os.makedirs(write_path, exist_ok=True)
    for is_btm in range(2):
        player_name = "btm" if is_btm else "top"
        player_3d_pose_file = os.path.join(write_path, f"{segment_name}_player_{player_name}_pose_3d.npy")

        poses = players_pose_sequences[is_btm]
        pose_shapes = {np.shape(pose) for pose in poses if pose is not None}
        if all(pose is not None for pose in poses) and len(pose_shapes) == 1:
            # Every frame has a pose of the same shape: regular numeric array
            pose_array = np.array(poses)
        else:
            # Frames without a pose (None) make the list ragged; store an
            # object array explicitly (read back with allow_pickle=True)
            pose_array = np.empty(len(poses), dtype=object)
            for pose_index, pose in enumerate(poses):
                pose_array[pose_index] = pose
        np.save(player_3d_pose_file, pose_array)

    return True

In [4]:
# Run the full pipeline on one example segment
seg_path = os.path.join(segments_path, "V009_0061.mp4")
print("Segment path", seg_path)

# `success` is False when the segment contained no frames
success = process_segment(seg_path, labels_path=labels_path)

In [None]:
import numpy as np

# Read back the exported poses of the top player for the example segment.
# Alternative input for inspecting a bbox label file instead:
# sample_file = os.path.join(labels_path, 'V010_0071_player_top_bbox.npy')
sample_file = os.path.join(write_path, 'V009_0061_player_top_pose_3d.npy')

# allow_pickle=True unpickles arbitrary objects - only load files you produced yourself
data = np.load(file=sample_file, allow_pickle=True)

# Show what was stored
print(data)

In [None]:
# # visualize
# from mmpose.apis import visualize

# # extract first frame from segment
# img_path = os.path.join(write_path, '000000.jpg')
# keypoints = np.load(sample_file, allow_pickle=True)
# keypoint_scores = None

# metainfo = 'config/_base_/datasets/coco.py'

# visualize(
#     img_path,
#     keypoints,
#     keypoint_scores,
#     metainfo=metainfo,
#     show=True)

In [None]:
# take heuristic such as max area, or largest x-span, y-span
