# 2D -> 3D Data Processing
Execution of this notebook is meant to follow `1_data_processing.ipynb`. 

## Sample the data

In [None]:
import os

# Constants
from PoseLifter import PoseLifter
from data_utils import crop_frame

# Choose dataset
# Both locations default to the original machine-specific paths. Set the
# environment variables TENNISET_DATASET_PATH / POSES_3D_WRITE_PATH to run the
# notebook on another machine without editing this cell.
dataset_path = os.environ.get(
    "TENNISET_DATASET_PATH",
    "/home/florsanders/adl_ai_tennis_coach/data/tenniset",
)
write_path = os.environ.get("POSES_3D_WRITE_PATH", "/home/georgetamer/3d_poses")

# Sub-directories of the dataset root
segments_path = os.path.join(dataset_path, "segments")
labels_path = os.path.join(dataset_path, "labels")

In [None]:
from heuristics import keep_largest_volume_3D_pose_heuristic

# Collect the PoseLifter settings in one place, then build the lifter from them.
# NOTE(review): duplicate_work=False presumably skips work that was already
# done on a previous run — confirm in PoseLifter.
pose_lifter_kwargs = {
    "crop_fn": crop_frame,
    "dedup_heuristic_fn": keep_largest_volume_3D_pose_heuristic,
    "dataset_path": dataset_path,
    "write_path": write_path,
    "duplicate_work": False,
}
pose_lifter = PoseLifter(**pose_lifter_kwargs)

In [None]:
# Run the 3D pose extraction with the PoseLifter instance built earlier.
# NOTE(review): presumably reads from dataset_path and writes results to
# write_path — confirm in PoseLifter.extract_3d_poses.
pose_lifter.extract_3d_poses()

## Post-process 3D Poses

The extracted 3D poses are not oriented upright and their orientations are not consistent between frames.

To resolve this, we align selected keypoints of each detected 3D pose with the corresponding, correctly oriented 2D pose by finding the [best-fitting rotation matrix between them](https://nghiaho.com/?page_id=671).

In [None]:
# Library reloading
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Load IPython's autoreload extension so edited modules are picked up without
# restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) before
# executing each cell.
%autoreload 2

In [None]:
# Import libraries
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

In [None]:
# Import utilities
from data import (
    read_segment_frames,
    read_segment_2d_annotations, 
    read_segment_3d_annotations,
    rotate_3d_poses,
    rotate_pose_3d_to_match_2d,
)
from visualizations import (
    make_3d_figax,
    plot_3d_pose,
    plot_2d_pose,
    plot_img,
    visualize_frame_2d_annotations,
)

In [None]:
# Locate the data folder two levels above the current working directory
# (abspath resolves the relative path against os.getcwd()).
data_path = os.path.abspath(os.path.join(os.pardir, os.pardir, "data"))
tenniset_path = os.path.join(data_path, "tenniset")
segments_path = os.path.join(tenniset_path, "shot_segments")
labels_path = os.path.join(tenniset_path, "shot_labels")

# Report each directory together with the number of entries it contains
for directory in (segments_path, labels_path):
    print(directory, len(os.listdir(directory)))

In [None]:
# Collect every .mp4 segment in the segments directory, in sorted order
segment_pattern = os.path.join(segments_path, "*.mp4")
segment_files = np.sort(glob.glob(segment_pattern))

# Report how many segments were found
n_segments = len(segment_files)
print("Number of segments:", n_segments)

In [None]:
# Rotate the 3D poses of both players for every segment and store the result
# next to the other labels as "<segment>_player_{btm,top}_pose_3d_rot.npy".
# Set `overwrite` to True to recompute segments whose outputs already exist.
overwrite = False
for segment_path in tqdm(segment_files):
    # Parse segment filename
    segment_dir, segment_filename = os.path.split(segment_path)
    segment_name, segment_ext = os.path.splitext(segment_filename)

    # Avoid doing double work: skip segments for which both outputs exist
    btm_path = os.path.join(labels_path, f"{segment_name}_player_btm_pose_3d_rot.npy")
    top_path = os.path.join(labels_path, f"{segment_name}_player_top_pose_3d_rot.npy")
    if not overwrite and os.path.exists(btm_path) and os.path.exists(top_path):
        continue

    # Load annotations; only the per-player pose sequences are needed here
    try:
        (
            _,
            _,
            _,
            _,
            player_btm_pose_sequence,
            player_top_pose_sequence,
        ) = read_segment_2d_annotations(segment_path, labels_path=labels_path)
        (
            _,
            _,
            player_btm_pose_3d_sequence,
            player_top_pose_3d_sequence,
        ) = read_segment_3d_annotations(segment_path, labels_path=labels_path, use_rotated=False)
    except Exception as e:
        # Best-effort: skip this segment, but show the underlying error so
        # failures other than a missing file are not hidden.
        print(f"Warning: missing annotation files for {segment_name} ({e!r})")
        continue

    n_frames = len(player_btm_pose_sequence)

    # Rotate poses: each player's 3D sequence is paired with that same
    # player's 2D sequence.
    player_btm_pose_3d_rot = rotate_3d_poses(
        player_btm_pose_3d_sequence, player_btm_pose_sequence
    )
    player_top_pose_3d_rot = rotate_3d_poses(
        player_top_pose_3d_sequence, player_top_pose_sequence
    )

    # Save rotated poses
    np.save(btm_path, player_btm_pose_3d_rot)
    np.save(top_path, player_top_pose_3d_rot)

### Reset validity labels

Label all data points as valid.

In [None]:
# Import libraries
import os
import glob
import numpy as np
import json
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

In [None]:
# Locate the data folder two levels above the current working directory
# (abspath resolves the relative path against os.getcwd()).
data_path = os.path.abspath(os.path.join(os.pardir, os.pardir, "data"))

# Directory holding the per-shot label files
labels_path = os.path.join(data_path, "tenniset", "shot_labels")

In [None]:
# Mark every "*_info.json" label file as valid, rewriting each file in place.
info_pattern = os.path.join(labels_path, "*_info.json")
info_files = np.sort(glob.glob(info_pattern))
for info_file in tqdm(info_files):
    # Read the existing metadata
    with open(info_file, "r") as info_in:
        info = json.load(info_in)

    # Set the validity flag and write the metadata back
    info.update({"is_valid": True})
    with open(info_file, "w") as info_out:
        json.dump(info, info_out)