# Data & Model Results Visualizations

## Setup

In [None]:
# Library reloading
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to project modules (e.g. data, visualizations) are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Import libraries
import os
import glob
import sys
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [None]:
# Import utils
from __init__ import base_path
from data import (
    read_segment_frames,
    read_segment_info,
    read_segment_2d_annotations,
    read_segment_3d_annotations,
)
from visualizations import (
    plot_img,
    fig_to_img,
    crop_img_to_bbox,
    crop_img_to_content,
    visualize_frame_2d_annotations,
    visualize_segment_2d_annotations,
    plot_3d_pose,
    plot_3d_court,
    make_3d_figax,
    animate_3d_pose,
    visualize_segment_3d_annotations,
)

In [None]:
# Define data directories
figures_dir = "./figures"
data_dir = os.path.abspath(os.path.join(base_path, "..", "data"))
dataset_dir = os.path.join(data_dir, "tenniset")
segments_dir = os.path.join(dataset_dir, "segments")
labels_dir = os.path.join(dataset_dir, "labels")

# All figures/videos below are written into figures_dir. cv2.imwrite and
# cv2.VideoWriter do not create missing parent directories (imwrite merely
# returns False), so create the output folder up front.
os.makedirs(figures_dir, exist_ok=True)

In [None]:
# Load segments: every *.mp4 clip in the segments folder, in sorted order
segment_glob = os.path.join(segments_dir, "*.mp4")
segment_paths = np.sort(glob.glob(segment_glob))
n_segments = segment_paths.size
print(f"Found {n_segments} segments")

## Data Visualizations

In [None]:
# Pick a segment to visualize.
# A random segment is drawn first; it is then pinned to a fixed clip
# (V006_0032) so the generated figures are reproducible. The fixed clip is
# resolved relative to `segments_dir` rather than through a hard-coded
# absolute path that only exists on one machine; if it is not present, the
# random pick is kept.
if len(segment_paths) == 0:
    raise FileNotFoundError(f"No *.mp4 segments found in {segments_dir}")
segment_path = np.random.choice(segment_paths)
fixed_segment_path = os.path.join(segments_dir, "V006_0032.mp4")
if os.path.exists(fixed_segment_path):
    segment_path = fixed_segment_path
print(segment_path)

### Frame Visualizations

In [None]:
# Load frames & annotations
frames, fps = read_segment_frames(segment_path, labels_path=labels_dir)
n_frames = len(frames)

# Segment metadata is optional for the plots below: fall back to an empty
# dict, but report why it is unavailable. A bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and hide the actual failure.
try:
    info = read_segment_info(segment_path, labels_path=labels_dir)
except Exception as err:
    print(f"No segment info for {os.path.basename(segment_path)}: {err!r}")
    info = {}

# Per-frame 2D annotations (court, ball, player bboxes and poses)
(
    court_sequence,
    ball_sequence,
    player_btm_bbox_sequence,
    player_top_bbox_sequence,
    player_btm_pose_sequence,
    player_top_pose_sequence,
) = read_segment_2d_annotations(segment_path, labels_path=labels_dir)

# Per-frame 3D annotations (player court positions and 3D poses)
(
    player_btm_position_sequence,
    player_top_position_sequence,
    player_btm_pose_3d_sequence,
    player_top_pose_3d_sequence,
) = read_segment_3d_annotations(segment_path, labels_path=labels_dir)

print(f"Loaded {n_frames} valid frames for {os.path.basename(segment_path)}")
for key, value in info.items():
    print(f"- {key}: {value}")

In [None]:
# Pick a random frame.
# np.random.randint's upper bound is exclusive, so pass len(frames) to make
# every frame (including the last) eligible; the previous `len(frames) - 1`
# skipped the last frame and raised ValueError for single-frame segments.
frame_index = np.random.randint(0, len(frames))
print(f"Picked frame {frame_index}")

In [None]:
# Save the unannotated frame for side-by-side comparison with the overlays
raw_frame_path = os.path.join(figures_dir, "frame.jpg")
cv2.imwrite(raw_frame_path, frames[frame_index])

#### 2D Frame Visualizations

In [None]:
# Visualize frame: court lines, ball and player bounding boxes (poses hidden)
frame_annotations = (
    court_sequence[frame_index],
    ball_sequence[frame_index],
    player_btm_bbox_sequence[frame_index],
    player_top_bbox_sequence[frame_index],
    player_btm_pose_sequence[frame_index],
    player_top_pose_sequence[frame_index],
)
img = visualize_frame_2d_annotations(
    frames[frame_index],
    *frame_annotations,
    show_court=True,
    show_court_numbers=False,
    show_ball=True,
    show_bbox=True,
    show_pose=False,
    show_picture_in_picture=False,
    show_img=True,
)
cv2.imwrite(os.path.join(figures_dir, "frame_2d_annotation.jpg"), img)

In [None]:
# Visualize frame: court lines and player poses, with picture-in-picture enabled
pip_frame_path = os.path.join(figures_dir, "frame_2d_annotation_pip.jpg")
pip_flags = dict(
    show_court=True,
    show_court_numbers=False,
    show_ball=False,
    show_bbox=False,
    show_pose=True,
    show_picture_in_picture=True,
    show_img=True,
)
img = visualize_frame_2d_annotations(
    frames[frame_index],
    court_sequence[frame_index],
    ball_sequence[frame_index],
    player_btm_bbox_sequence[frame_index],
    player_top_bbox_sequence[frame_index],
    player_btm_pose_sequence[frame_index],
    player_top_pose_sequence[frame_index],
    **pip_flags,
)
cv2.imwrite(pip_frame_path, img)

In [None]:
# Crop each player's bounding box out of the frame, show it, and save it
for crop_filename, bbox_sequence in (
    ("player_btm_bbox.jpg", player_btm_bbox_sequence),
    ("player_top_bbox.jpg", player_top_bbox_sequence),
):
    cropped_img = crop_img_to_bbox(
        frames[frame_index], bbox_sequence[frame_index], show_img=True, resize_to=500
    )
    cv2.imwrite(os.path.join(figures_dir, crop_filename), cropped_img)

# The ball annotation is a single (x, y) point, so crop around a degenerate
# (x, y, x, y) box and rely on the padding to give it some context
x, y = ball_sequence[frame_index]
ball_box = (x, y, x, y)
cropped_img = crop_img_to_bbox(
    frames[frame_index], ball_box, padding=25, show_img=True, resize_to=500
)
cv2.imwrite(os.path.join(figures_dir, "ball_bbox.jpg"), cropped_img)

#### 3D Frame Visualizations

In [None]:
# Plot the bottom player's 3D pose on its own. x_global/y_global are zero, so
# the pose is drawn at the origin rather than at the player's court position.
fig, ax = plot_3d_pose(
    player_btm_pose_3d_sequence[frame_index],
    x_global=0,
    y_global=0,
    fig=None,
    ax=None,
    color="blue",
)
ax.set_aspect("equal", adjustable="box")
plt.show()

# Rasterize the figure, trim it to the drawn content, and save it
img = fig_to_img(fig)
btm_pose_path = os.path.join(figures_dir, "player_btm_pose_3d.png")
cv2.imwrite(btm_pose_path, crop_img_to_content(img))
# NOTE(review): dpi is reset to 100 after saving; presumably fig_to_img
# changes it — TODO confirm
fig.set_dpi(100)

In [None]:
# Plot the top player's 3D pose on its own. x_global/y_global are zero, so the
# pose is drawn at the origin rather than at the player's court position.
fig, ax = plot_3d_pose(
    player_top_pose_3d_sequence[frame_index],
    x_global=0,
    y_global=0,
    fig=None,
    ax=None,
    color="green",
)
ax.set_aspect("equal", adjustable="box")
# Lay out the figure before showing it, so the displayed plot matches the
# saved image (tight_layout previously ran only after plt.show()).
fig.tight_layout()
plt.show()

# Save as image, trimmed to the drawn content
img = fig_to_img(fig)
cv2.imwrite(os.path.join(figures_dir, "player_top_pose_3d.png"), crop_img_to_content(img))
# NOTE(review): dpi is reset to 100 after saving; presumably fig_to_img
# changes it — TODO confirm
fig.set_dpi(100)

In [None]:
# Plot both players on the court for the selected frame
fig, ax = make_3d_figax()

# Plot 3D court (presumably metres: 10.97 x 23.77 are doubles-court dimensions)
court_width = 10.97
court_length = 23.77
fig, ax = plot_3d_court(fig, ax, court_width, court_length, half=False)

# Plot each player's pose at their global court position
for pose_3d_sequence, position_sequence, color in (
    (player_btm_pose_3d_sequence, player_btm_position_sequence, "blue"),
    (player_top_pose_3d_sequence, player_top_position_sequence, "green"),
):
    fig, ax = plot_3d_pose(
        pose_3d_sequence[frame_index],
        x_global=position_sequence[frame_index][0],
        y_global=position_sequence[frame_index][1],
        fig=fig,
        ax=ax,
        color=color,
        marker=None,
    )

# Set axis properties.
# The full court is drawn (half=False) and both players are plotted, so the
# y-range must span the whole court length. The previous (-L/2, 0) range only
# covered the near half, leaving the top player (presumably at positive y)
# outside the axes box.
ax.set_xlim(-court_width / 2, court_width / 2)
ax.set_ylim(-court_length / 2, court_length / 2)
ax.set_aspect("equal", adjustable="box")
fig.tight_layout()
plt.show()

# Save as img, trimmed to the drawn content
img = fig_to_img(fig)
cv2.imwrite(os.path.join(figures_dir, "frame_3d_annotation.jpg"), crop_img_to_content(img))
fig.set_dpi(100)


### Segment Visualizations

#### Segment 2D Visualizations

In [None]:
# Render the whole segment with court, ball and bounding-box overlays to video
segment_2d_output_path = os.path.join(figures_dir, "segment_2d_annotation.mp4")
visualize_segment_2d_annotations(
    segment_path,
    segment_2d_output_path,
    labels_path=labels_dir,
    show_court=True,
    show_court_numbers=False,
    show_ball=True,
    show_bbox=True,
    show_pose=False,
    show_picture_in_picture=False,
    show_img=False,
)

In [None]:
# Render the whole segment with court lines, player poses and
# picture-in-picture to video. The filename is joined without a "./" prefix,
# consistent with the other outputs (avoids "./figures/./...").
visualize_segment_2d_annotations(
    segment_path,
    os.path.join(figures_dir, "segment_2d_annotation_pip.mp4"),
    labels_path=labels_dir,
    show_court=True,
    show_court_numbers=False,
    show_ball=False,
    show_bbox=False,
    show_pose=True,
    show_picture_in_picture=True,
    show_img=False,
)

#### Segment 3D Visualizations

In [None]:
# Animate the bottom player's 3D pose twice: once with the court, once without
for plot_court, animation_filename in (
    (True, "player_btm_pose_3d_court.mp4"),
    (False, "player_btm_pose_3d.mp4"),
):
    ani = animate_3d_pose(
        player_btm_pose_3d_sequence,
        player_btm_position_sequence,
        color="blue",
        plot_court=plot_court,
        save_path=os.path.join(figures_dir, animation_filename),
    )
plt.show()

In [None]:
# Animate the top player's 3D pose twice (with and without the court).
# The pose is multiplied by [-1, -1, 1] along its last axis (presumably x/y/z,
# so x and y are negated) and the court position is negated, mirroring the
# player through the origin. The mirrored arrays are rebuilt for each call,
# as before.
pose_mirror = np.expand_dims([-1, -1, 1], (0, 1))
for plot_court, animation_filename in (
    (True, "player_top_pose_3d_court.mp4"),
    (False, "player_top_pose_3d.mp4"),
):
    ani = animate_3d_pose(
        player_top_pose_3d_sequence * pose_mirror,
        player_top_position_sequence * (-1),
        color="green",
        plot_court=plot_court,
        save_path=os.path.join(figures_dir, animation_filename),
    )
plt.show()

In [None]:
# Render the segment's 3D annotations to video
segment_3d_output_path = os.path.join(figures_dir, "segment_3d_annotation.mp4")
ani = visualize_segment_3d_annotations(
    segment_path, segment_3d_output_path, labels_path=labels_dir
)
plt.show()

In [None]:
# Postprocess 3D pose animations: crop away the white margin around the
# plotted content and write "<name>_cropped<ext>" next to each source video.
import cv2


def _content_bbox(video_frames):
    """Return the (x1, y1, x2, y2) box enclosing all non-white content, or None.

    A column/row counts as content when, averaged over all frames, channels
    and the other image axis, it is not exactly 255 (i.e. at least one pixel
    value in it is below 255). Returns None if every pixel is white.
    """
    # Mean over frames and channels; the original indexing implies an
    # (n_frames, H, W, C) stack -> (H, W) here
    frames_mean = np.mean(np.asarray(video_frames), axis=(0, 3))
    content_cols = np.flatnonzero(np.mean(frames_mean, axis=0) != 255)
    content_rows = np.flatnonzero(np.mean(frames_mean, axis=1) != 255)
    if content_cols.size == 0 or content_rows.size == 0:
        return None
    return content_cols[0], content_rows[0], content_cols[-1], content_rows[-1]


videos_to_process = glob.glob(os.path.join(figures_dir, "*3d*.mp4"))
margin = 50
for video_path in videos_to_process:
    video_dir, video_filename = os.path.split(video_path)
    video_name, video_ext = os.path.splitext(video_filename)
    if "cropped" in video_name:
        continue  # output of a previous run
    print(video_path)

    # Load frames under separate names so this cell does not overwrite the
    # segment `frames`/`fps` that earlier cells rely on.
    video_frames, video_fps = read_segment_frames(video_path, load_valid_frames_only=False)
    if len(video_frames) == 0:
        print(f"Skipping {video_filename}: no frames")
        continue

    # Determine content bbox as (x1, y1, x2, y2) = (col_min, row_min, col_max, row_max)
    bbox = _content_bbox(video_frames)
    if bbox is None:
        print(f"Skipping {video_filename}: no non-white content found")
        continue

    # Crop the first frame once to learn the output frame size
    first_crop = crop_img_to_bbox(video_frames[0], bbox, padding=margin, resize_to=None, square=False)
    h, w = first_crop.shape[:2]

    # Video writer (cv2 expects the frame size as (width, height))
    output_path = os.path.join(video_dir, video_name + "_cropped" + video_ext)
    fourcc = cv2.VideoWriter_fourcc(*"avc1")
    writer = cv2.VideoWriter(output_path, fourcc, video_fps, (w, h))
    try:
        if not writer.isOpened():
            print(f"Could not open a video writer for {output_path} (is the avc1 codec available?)")
            continue
        # Write cropped frames
        for frame in video_frames:
            writer.write(crop_img_to_bbox(frame, bbox, padding=margin, resize_to=None, square=False))
    finally:
        # Release even if cropping/writing raises, so the file handle is not leaked
        writer.release()

## Model Results Visualizations