In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.axes import Axes
import matplotlib.gridspec as gridspec
import numpy as np
from typing import cast

In [None]:
from numpy.typing import NDArray

# Open the pose archive and pull out the array stored under "reconstruction".
# plot_frame later annotates this array as (N, 17, 3).
d = np.load("front.npz")
kps = d["reconstruction"]
# Last expression of the cell: echo the array shape as the cell output.
kps.shape

```json
{ "keypoints": { "0": "nose", "1": "left_eye", "2": "right_eye", "3": "left_ear", "4": "right_ear", "5": "left_shoulder", "6": "right_shoulder", "7": "left_elbow", "8": "right_elbow", "9": "left_wrist", "10": "right_wrist", "11": "left_hip", "12": "right_hip", "13": "left_knee", "14": "right_knee", "15": "left_ankle", "16": "right_ankle" } }
```

In [None]:
import plotly.express as px
import plotly.graph_objects as go
from typeguard import typechecked
from jaxtyping import Int, Float, Bool, Num, jaxtyped
from pydantic import BaseModel
from typing import Optional, Union
from plotly.graph_objects import layout

# Array type used inside the jaxtyping annotations below.
# NOTE: this rebinds `NDArray` to plain `np.ndarray`; an earlier cell imported
# `numpy.typing.NDArray` under the same name.
NDArray = np.ndarray
# A plain Python scalar: int or float.
number = Union[int, float]
# A color given either as a triple of ints or as a string
# (presumably an "rgb(r,g,b)" string as built by to_rgb_str -- not validated here).
Color = tuple[int, int, int] | str
# Short aliases for plotly's slider layout classes.
Step = layout.slider.Step
Slider = layout.Slider


class Joint(BaseModel):
    """A single named keypoint of a skeleton layout."""
    # Row of this joint in a per-frame keypoint array (plot_frame reads `sel[j.index, ...]`).
    index: int
    # Index of the mirrored left/right counterpart; left as None for joints without one.
    opposite_index: Optional[int] = None
    # Label of the joint; plot_frame uses it as the plotly trace name.
    name: str
    # Color of the joint's marker.
    color: Color


class Bone(BaseModel):
    """A segment of the skeleton that joins two joints."""
    # The two end points of the segment.
    joint1: Joint
    joint2: Joint
    # Label and color used when the segment is drawn.
    name: str
    color: Color

    @staticmethod
    def from_indexes(joints: list[Joint], idx_1: int, idx_2: int, name: str,
                     color: Color) -> "Bone":
        """
        Build a bone between ``joints[idx_1]`` and ``joints[idx_2]``.

        The two indexes are positions in the ``joints`` list (not looked up
        through ``Joint.index``); ``name`` and ``color`` are stored unchanged.
        """
        start, end = joints[idx_1], joints[idx_2]
        fields = {"joint1": start, "joint2": end, "name": name, "color": color}
        return Bone(**fields)


# https://plotly.com/python-api-reference/generated/plotly.graph_objects.scatter3d.marker.html
# plotly.graph_objects.scatter3d.Marker
def to_rgb_str(color: tuple[int, int, int]) -> str:
    """Format the first three components of ``color`` as ``"rgb(r,g,b)"`` (no spaces)."""
    red, green, blue = color[0], color[1], color[2]
    return f"rgb({red},{green},{blue})"


# Color palette: one "rgb(r,g,b)" string per body region, shared by joints and bones.
COLOR_SPINE = to_rgb_str((138, 201, 38))  # green, spine & head
COLOR_ARMS = to_rgb_str((255, 202, 58))  # yellow, arms & shoulders
COLOR_LEGS = to_rgb_str((25, 130, 196))  # blue, legs & hips
# Marker size handed to plotly for every joint marker.
CIRCLE_SIZE = 4
# Line width handed to plotly for every bone line.
LINE_WIDTH = 3

# Joint table for the 17-keypoint COCO layout.
# The position of each entry in the list equals its `index`; `opposite_index`
# points at the left/right mirror joint (the nose has none).
# Head joints use the spine color, arms the arm color, hips and legs the leg color.
coco_joints = [
    Joint(index=0, name="nose", color=COLOR_SPINE),
    Joint(index=1, name="left_eye", opposite_index=2, color=COLOR_SPINE),
    Joint(index=2, name="right_eye", opposite_index=1, color=COLOR_SPINE),
    Joint(index=3, name="left_ear", opposite_index=4, color=COLOR_SPINE),
    Joint(index=4, name="right_ear", opposite_index=3, color=COLOR_SPINE),
    Joint(index=5, name="left_shoulder", opposite_index=6, color=COLOR_ARMS),
    Joint(index=6, name="right_shoulder", opposite_index=5, color=COLOR_ARMS),
    Joint(index=7, name="left_elbow", opposite_index=8, color=COLOR_ARMS),
    Joint(index=8, name="right_elbow", opposite_index=7, color=COLOR_ARMS),
    Joint(index=9, name="left_wrist", opposite_index=10, color=COLOR_ARMS),
    Joint(index=10, name="right_wrist", opposite_index=9, color=COLOR_ARMS),
    Joint(index=11, name="left_hip", opposite_index=12, color=COLOR_LEGS),
    Joint(index=12, name="right_hip", opposite_index=11, color=COLOR_LEGS),
    Joint(index=13, name="left_knee", opposite_index=14, color=COLOR_LEGS),
    Joint(index=14, name="right_knee", opposite_index=13, color=COLOR_LEGS),
    Joint(index=15, name="left_ankle", opposite_index=16, color=COLOR_LEGS),
    Joint(index=16, name="right_ankle", opposite_index=15, color=COLOR_LEGS),
]

# https://github.com/lllyasviel/ControlNet/discussions/266
# Joint table for the 17-joint Human3.6M layout; this is the table plot_frame draws.
# The position of each entry in the list equals its `index`; `opposite_index`
# links each left/right pair (torso and head joints have none).
# NOTE(review): some public Human3.6M skeleton definitions (e.g. VideoPose3D) treat
# joints 1-3 and 14-16 as the right side and 4-6 and 11-13 as the left side, i.e.
# the mirror image of the labels used here -- confirm which convention the data follows.
human_36_joints = [
    Joint(index=0, name="bottom_torso", color=COLOR_SPINE),
    Joint(index=1, name="left_hip", opposite_index=4, color=COLOR_LEGS),
    Joint(index=2, name="left_knee", opposite_index=5, color=COLOR_LEGS),
    Joint(index=3, name="left_foot", opposite_index=6, color=COLOR_LEGS),
    Joint(index=4, name="right_hip", opposite_index=1, color=COLOR_LEGS),
    Joint(index=5, name="right_knee", opposite_index=2, color=COLOR_LEGS),
    Joint(index=6, name="right_foot", opposite_index=3, color=COLOR_LEGS),
    Joint(index=7, name="center_torso", color=COLOR_SPINE),
    Joint(index=8, name="upper_torso", color=COLOR_SPINE),
    Joint(index=9, name="neck_base", color=COLOR_SPINE),
    Joint(index=10, name="center_head", color=COLOR_SPINE),
    Joint(index=11, name="right_shoulder", opposite_index=14, color=COLOR_ARMS),
    Joint(index=12, name="right_elbow", opposite_index=15, color=COLOR_ARMS),
    Joint(index=13, name="right_hand", opposite_index=16, color=COLOR_ARMS),
    Joint(index=14, name="left_shoulder", opposite_index=11, color=COLOR_ARMS),
    Joint(index=15, name="left_elbow", opposite_index=12, color=COLOR_ARMS),
    Joint(index=16, name="left_hand", opposite_index=13, color=COLOR_ARMS),
]

# Bones of the Human3.6M skeleton, given as (parent position, child position)
# pairs into human_36_joints. Each bone is named after, and colored like, its
# child joint.
human_36_bones = [
    # Legs: bottom of the torso -> hip -> knee -> foot.
    Bone.from_indexes(human_36_joints, 0, 1, "left_hip", COLOR_LEGS),
    Bone.from_indexes(human_36_joints, 1, 2, "left_knee", COLOR_LEGS),
    Bone.from_indexes(human_36_joints, 2, 3, "left_foot", COLOR_LEGS),
    Bone.from_indexes(human_36_joints, 0, 4, "right_hip", COLOR_LEGS),
    Bone.from_indexes(human_36_joints, 4, 5, "right_knee", COLOR_LEGS),
    Bone.from_indexes(human_36_joints, 5, 6, "right_foot", COLOR_LEGS),
    # Spine and head: bottom -> center -> upper torso -> neck -> head.
    Bone.from_indexes(human_36_joints, 0, 7, "center_torso", COLOR_SPINE),
    Bone.from_indexes(human_36_joints, 7, 8, "upper_torso", COLOR_SPINE),
    Bone.from_indexes(human_36_joints, 8, 9, "neck_base", COLOR_SPINE),
    Bone.from_indexes(human_36_joints, 9, 10, "center_head", COLOR_SPINE),
    # Arms: the shoulders hang off the upper torso (8), not the center of the
    # spine (7), matching the usual 17-joint Human3.6M parent table.
    Bone.from_indexes(human_36_joints, 8, 11, "right_shoulder", COLOR_ARMS),
    Bone.from_indexes(human_36_joints, 11, 12, "right_elbow", COLOR_ARMS),
    Bone.from_indexes(human_36_joints, 12, 13, "right_hand", COLOR_ARMS),
    Bone.from_indexes(human_36_joints, 8, 14, "left_shoulder", COLOR_ARMS),
    Bone.from_indexes(human_36_joints, 14, 15, "left_elbow", COLOR_ARMS),
    Bone.from_indexes(human_36_joints, 15, 16, "left_hand", COLOR_ARMS),
]

# note that the keypoints are in the format of (x, y, z)
# where x is the horizontal axis, z is the depth axis, and y is the vertical axis (upwards)
# i.e. the x-z plane is the ground plane


@jaxtyped(typechecker=typechecked)
def xz_ground_2_xy_ground(xyz: Num[NDArray, "17 3"]) -> Num[NDArray, "17 3"]:
    """
    Reorder the coordinate columns from (x, y, z) to (x, z, y).

    The second and third columns trade places, so data whose ground plane is
    x-z comes out with x-y as the ground plane. A new array is returned; the
    input is not modified.
    """
    first, second, third = xyz[:, 0], xyz[:, 1], xyz[:, 2]
    return np.stack((first, third, second), axis=1)


def swap_axes_direction(arr: Num[NDArray, "17 1"]) -> Num[NDArray, "17 1"]:
    """
    Return ``arr`` with every value negated, i.e. the axis pointing the other way.

    The input is left untouched. The shape annotations are not enforced at
    runtime because this function carries no ``jaxtyped`` decorator.
    """
    mirrored = -arr
    return mirrored



In [None]:
from ipywidgets import interact, interactive, fixed, interact_manual, interactive_output
import ipywidgets as widgets
from IPython.display import clear_output

# Display workaround (original note: "don't ask me why, it just works"):
# plot_frame skips its explicit `fw.show()` on the very first call and clears
# this flag; every later call shows the widget.
is_first = True
# https://stackoverflow.com/questions/63716543/plotly-how-to-update-redraw-a-plotly-express-figure-with-new-data
# Module-level widget placeholder; plot_frame rebinds `fw` to a fresh
# FigureWidget each time it runs.
fw = go.FigureWidget()
def plot_frame(kps: Num[NDArray, "N 17 3"], index: int):
    """
    Draw one frame of a keypoint sequence as a 3D skeleton.

    Args:
        kps: keypoint sequence; frame ``index`` is selected along the first axis.
        index: frame number, must satisfy ``0 <= index < kps.shape[0]``.

    Returns:
        A ``go.Figure`` holding one marker trace per entry of
        ``human_36_joints`` followed by one line trace per entry of
        ``human_36_bones``.

    Side effects:
        Rebinds the module-level ``fw`` to a new ``FigureWidget`` built from
        the figure. On every call except the first one (tracked through the
        module-level ``is_first`` flag) that widget is also shown.
    """
    global fw, is_first
    assert 0 <= index < kps.shape[0]

    # Put the ground on the x-y plane. The helper returns a fresh array, so
    # the in-place negation below does not write into `kps`.
    frame = xz_ground_2_xy_ground(kps[index])
    # reverse the upright axis
    frame[:, 2] = -frame[:, 2]

    traces = []
    # One single-point marker trace per joint.
    for joint in human_36_joints:
        row = joint.index
        traces.append(
            go.Scatter3d(x=[frame[row, 0]],
                         y=[frame[row, 1]],
                         z=[frame[row, 2]],
                         mode='markers',
                         marker=dict(size=CIRCLE_SIZE, color=joint.color),
                         name=joint.name))
    # One two-point line trace per bone, appended after all the markers.
    for bone in human_36_bones:
        head, tail = bone.joint1.index, bone.joint2.index
        traces.append(
            go.Scatter3d(x=[frame[head, 0], frame[tail, 0]],
                         y=[frame[head, 1], frame[tail, 1]],
                         z=[frame[head, 2], frame[tail, 2]],
                         mode='lines',
                         line=dict(color=bone.color, width=LINE_WIDTH),
                         name=bone.name))

    fig = go.Figure()
    fig.add_traces(traces)
    fw = go.FigureWidget(fig)
    # Skip the explicit show on the first call only; see the note on `is_first`.
    if not is_first:
        fw.show()
    is_first = False
    return fig


# Frame selector: one integer step per frame of `kps`.
# `continuous_update=False` makes the slider report its value only when the
# handle is released, so plot_frame is not re-run for every intermediate value
# while dragging. (The keyword was previously misspelled `continue_update`,
# which is not an IntSlider trait, so the setting was never applied.)
slider = widgets.IntSlider(min=0,
                           max=kps.shape[0] - 1,
                           step=1,
                           value=0,
                           continuous_update=False)

# Bind the slider to plot_frame; `kps` is passed through unchanged via fixed().
p = interactive(plot_frame, kps=fixed(kps), index=slider)
display(p)
display(fw) # ignore the JavaScript error (no idea why it's there)