In [16]:
import os
import sys
from IPython.display import HTML
sys.path.append("../")
from app.npybvh.bvh import Bvh
import plotly.express as px
import plotly.graph_objects as go
from typeguard import typechecked
from jaxtyping import Int, Float, Bool, Num, jaxtyped
from pydantic import BaseModel
from typing import Optional, Union
from plotly.graph_objects import layout
from ipywidgets import interact, interactive, fixed, interact_manual, interactive_output
import ipywidgets as widgets
from IPython.display import clear_output
from dataclasses import dataclass
import numpy as np

In [17]:
# Path to the BVH clip to visualise. Set the BVH_PATH environment variable to use
# another file instead of editing this machine-specific default.
BVH_PATH = os.environ.get("BVH_PATH",
                          "/home/zlt/Documents/SkydivingPose/output/bvh/h36m_cxk.bvh")
anim = Bvh()
anim.parse_file(BVH_PATH)
# `anim.joints` is a mapping keyed by joint name; list the names in file order.
joint = anim.joints
for joint_name in joint:
    print(joint_name)

Hip
RightHip
RightKnee
RightAnkle
RightAnkle_end
LeftHip
LeftKnee
LeftAnkle
LeftAnkle_end
Spine
Thorax
Neck
Neck_end
LeftShoulder
LeftElbow
LeftWrist
LeftWrist_end
RightShoulder
RightElbow
RightWrist
RightWrist_end


In [18]:
# --- plotly layout helpers (not referenced in the cells shown) ---
Slider = layout.Slider
Step = layout.slider.Step

# --- array / value aliases ---
# Base array type used inside jaxtyping annotations such as Num[NDArray, "F J 3"].
NDArray = np.ndarray
# Generic scalar alias (not referenced in the cells shown).
number = Union[int, float]
# A plotly color: either an (r, g, b) tuple or a string such as "rgb(r,g,b)".
Color = tuple[int, int, int] | str

class Joint(BaseModel):
    """Static description of one skeleton joint: where it lives and how to draw it."""

    # Position of this joint along the joint axis of a stacked (frames, joints, 3)
    # array; PayloadJoint.from_stacked slices the payload with it.
    index: int
    # Not used in the cells shown; presumably the mirrored (left/right) joint's index.
    opposite_index: int | None = None
    # Display name, also used as the plotly trace name.
    name: str
    # Marker color passed to plotly.
    color: Color

# Plotly marker/line colors accept CSS-style strings, see
# https://plotly.com/python-api-reference/generated/plotly.graph_objects.scatter3d.marker.html
# (plotly.graph_objects.scatter3d.Marker)
def to_rgb_str(color: tuple[int, int, int]) -> str:
    """Format the first three channels of ``color`` as an ``"rgb(r,g,b)"`` string."""
    red, green, blue = color[0], color[1], color[2]
    return "rgb({},{},{})".format(red, green, blue)

class Bone(BaseModel):
    """A drawable segment connecting two joints."""

    joint1: Joint  # start point of the segment
    joint2: Joint  # end point of the segment
    name: str      # used as the plotly trace name
    color: Color   # line color passed to plotly

    @staticmethod
    def from_indexes(joints: list[Joint], idx_1: int, idx_2: int, name: str,
                     color: Color) -> "Bone":
        """Build a bone from two positions in ``joints``.

        ``idx_1``/``idx_2`` are list positions, not ``Joint.index`` values; the two
        coincide only when ``joints`` is ordered by index.
        """
        start = joints[idx_1]
        end = joints[idx_2]
        return Bone(joint1=start, joint2=end, name=name, color=color)

# Body-part colors as plotly "rgb(r,g,b)" strings.
COLOR_SPINE = to_rgb_str((138, 201, 38))    # green: spine & head
COLOR_ARMS = to_rgb_str((255, 202, 58))     # yellow: arms & shoulders
COLOR_LEGS = to_rgb_str((25, 130, 196))     # blue: legs & hips
COLOR_FINGERS = to_rgb_str((255, 0, 0))     # red: fingers
COLOR_HANDS = COLOR_FINGERS                 # hands share the finger color

# Drawing sizes: marker `size` for joints, line `width` for bones.
CIRCLE_SIZE = 2
LINE_WIDTH = 3

In [19]:
# H36M-style joint schema. Each joint's index is its position in this table, which
# must match the joint order of the parsed BVH file (printed above).
joints_map_list = [
    Joint(index=position, name=joint_name, color=joint_color)
    for position, (joint_name, joint_color) in enumerate([
        ("Hip", COLOR_SPINE),
        ("RightHip", COLOR_LEGS),
        ("RightKnee", COLOR_LEGS),
        ("RightAnkle", COLOR_LEGS),
        ("RightAnkle_end", COLOR_LEGS),
        ("LeftHip", COLOR_LEGS),
        ("LeftKnee", COLOR_LEGS),
        ("LeftAnkle", COLOR_LEGS),
        ("LeftAnkle_end", COLOR_LEGS),
        ("Spine", COLOR_SPINE),
        ("Thorax", COLOR_SPINE),
        ("Neck", COLOR_SPINE),
        ("Neck_end", COLOR_SPINE),
        ("LeftShoulder", COLOR_ARMS),
        ("LeftElbow", COLOR_ARMS),
        ("LeftWrist", COLOR_ARMS),
        ("LeftWrist_end", COLOR_ARMS),
        ("RightShoulder", COLOR_ARMS),
        ("RightElbow", COLOR_ARMS),
        ("RightWrist", COLOR_ARMS),
        ("RightWrist_end", COLOR_ARMS),
    ])
]

In [20]:
from typing import Callable


@dataclass
class PayloadJoint:
    """A joint schema paired with that joint's coordinates for every frame."""

    joint: Joint
    payload: Num[NDArray, "F 1 3"]

    @staticmethod
    def from_stacked(joint: Joint, payload: Num[NDArray,
                                                "F J 3"]) -> "PayloadJoint":
        """Slice ``joint``'s coordinates out of a stacked (frames, joints, 3) array.

        Returns a PayloadJoint whose payload has shape (frames, 1, 3).

        Raises:
            AssertionError: if ``payload`` is not rank 3 or its last axis is not 3.
        """
        # Check the rank first: reading shape[2] on a 1-D/2-D array would raise
        # IndexError before the intended assertion could report the problem.
        assert len(payload.shape) == 3, "must be (frames, joints, coordinates)"
        assert payload.shape[2] == 3, "must be 3D coordinates"
        p = payload[:, joint.index].reshape(-1, 1, 3)
        return PayloadJoint(joint=joint, payload=p)

    def create_scatter(self, frame: int) -> go.Scatter3d:
        """Return a single-marker Scatter3d trace for this joint at ``frame``.

        Negative ``frame`` values index from the end, following numpy indexing.

        Raises:
            AssertionError: if ``frame`` is not below the number of frames.
        """
        total = self.payload.shape[0]
        assert frame < total, f"frame {frame} out of range {total}"
        point = self.payload[frame, 0]
        return go.Scatter3d(x=[point[0]],
                            y=[point[1]],
                            z=[point[2]],
                            mode="markers",
                            marker=dict(size=CIRCLE_SIZE,
                                        color=self.joint.color),
                            name=self.joint.name)


@dataclass
class PayloadBone:
    """A bone schema paired with both endpoint coordinates for every frame."""

    bone: Bone
    payload: Num[NDArray, "F 1 2 3"]

    @staticmethod
    def from_stacked(bone: Bone, payload: Num[NDArray,
                                              "F J 3"]) -> "PayloadBone":
        """Gather ``bone``'s two endpoints from a stacked (frames, joints, 3) array.

        Returns a PayloadBone whose payload has shape (frames, 1, 2, 3); index 0
        on the third axis is ``bone.joint1``, index 1 is ``bone.joint2``.

        Raises:
            AssertionError: if ``payload`` is not rank 3 or its last axis is not 3.
        """
        # Rank check first so a low-rank input hits the assertion, not IndexError.
        assert len(payload.shape) == 3, "must be (frames, joints, coordinates)"
        assert payload.shape[2] == 3, "must be 3D coordinates"
        # One fancy index picks both endpoints: (F, 2, 3) -> (F, 1, 2, 3).
        endpoints = payload[:, [bone.joint1.index, bone.joint2.index]]
        return PayloadBone(bone=bone, payload=endpoints.reshape(-1, 1, 2, 3))

    def create_lines(self, frame: int) -> go.Scatter3d:
        """Return a two-point line trace for this bone at ``frame``.

        Raises:
            AssertionError: if ``frame`` is not below the number of frames.
        """
        total = self.payload.shape[0]
        assert frame < total, f"frame {frame} out of range {total}"
        segment = self.payload[frame, 0]  # (2, 3): one row per endpoint
        return go.Scatter3d(x=segment[:, 0],
                            y=segment[:, 1],
                            z=segment[:, 2],
                            mode="lines",
                            line=dict(width=LINE_WIDTH, color=self.bone.color),
                            name=self.bone.name)


@dataclass
class Skeleton:
    """Joint/bone schemas together with their per-frame coordinate payloads."""

    bone_schema: list[Bone]
    joint_schema: list[Joint]
    joints: list[PayloadJoint]
    bones: list[PayloadBone]

    @staticmethod
    def from_stacked(joints: list[Joint], bones: list[Bone],
                     payload: Num[NDArray, "F J 3"]) -> "Skeleton":
        """Split a stacked (frames, joints, 3) array into joint and bone payloads."""
        return Skeleton(
            bone_schema=bones,
            joint_schema=joints,
            joints=[
                PayloadJoint.from_stacked(joint, payload) for joint in joints
            ],
            bones=[PayloadBone.from_stacked(bone, payload) for bone in bones])

    @jaxtyped(typechecker=typechecked)
    def to_stacked_joints(self) -> Num[NDArray, "F J 3"]:
        """Re-assemble joint payloads into one (frames, joints, 3) array.

        Each joint payload is (F, 1, 3), so they are concatenated along axis 1.
        ``np.stack`` would insert a new axis and give (F, J, 1, 3), failing the
        declared return shape.
        """
        return np.concatenate([joint.payload for joint in self.joints], axis=1)

    @jaxtyped(typechecker=typechecked)
    def to_stacked_bones(self) -> Num[NDArray, "F J 2 3"]:
        """Re-assemble bone payloads into one (frames, bones, 2, 3) array.

        Each bone payload is (F, 1, 2, 3); concatenating on axis 1 keeps rank 4.
        """
        return np.concatenate([bone.payload for bone in self.bones], axis=1)

    @property
    def total_frames(self) -> int:
        """Number of frames, taken from the first joint payload (or bone payload).

        Raises:
            ValueError: if the skeleton has neither joints nor bones (for example
                after a ``filter`` that rejected everything).
        """
        if self.joints:
            return self.joints[0].payload.shape[0]
        if self.bones:
            return self.bones[0].payload.shape[0]
        raise ValueError("skeleton has no joints or bones; frame count is undefined")

    def get_joint_by_name(self, name: str) -> PayloadJoint:
        """Return the first joint payload whose schema name equals ``name``.

        Raises:
            ValueError: if no joint has that name.
        """
        for joint in self.joints:
            if joint.joint.name == name:
                return joint
        raise ValueError(f"Joint with name {name} not found")

    def filter(self, predicate: Callable[[str], bool]) -> "Skeleton":
        """Return a new Skeleton keeping only joints and bones whose name passes ``predicate``.

        Joints and bones are filtered independently: a bone is kept if its own
        name passes, even if one of its endpoint joints was removed.
        """
        joints = [joint for joint in self.joints if predicate(joint.joint.name)]
        bones = [bone for bone in self.bones if predicate(bone.bone.name)]
        bs = [bone for bone in self.bone_schema if predicate(bone.name)]
        js = [joint for joint in self.joint_schema if predicate(joint.name)]
        return Skeleton(joints=joints,
                        bones=bones,
                        joint_schema=js,
                        bone_schema=bs)


In [21]:
pos, rot = anim.all_frame_poses()
# Downstream cells treat `pos` as (frames, joints, xyz) and address joints by their
# position in `joints_map_list`; fail fast if the parsed BVH does not fit that layout.
assert pos.ndim == 3 and pos.shape[2] == 3, (
    f"expected (frames, joints, 3) positions, got shape {pos.shape}")
assert pos.shape[1] == len(joints_map_list), (
    f"BVH has {pos.shape[1]} joints but joints_map_list describes {len(joints_map_list)}")
display(pos.shape)

(30, 21, 3)

In [22]:
@jaxtyped(typechecker=typechecked)
def preprocess_data(pos: Num[NDArray, "F J 3"]) -> Num[NDArray, "F J 3"]:
    """Reorder the coordinate axis from (x, y, z) to (x, z, y).

    Presumably this converts the BVH's y-up convention into plotly's z-up 3-D
    scene (TODO confirm). Returns a new array; ``pos`` is not modified.
    """
    # Indexing with a list always produces a copy, so no explicit .copy() is needed.
    return pos[..., [0, 2, 1]]

In [23]:
# `joints_map_list` is already defined in an earlier cell. Re-declaring the same 21
# joints here created a second source of truth that silently shadowed the first.
# Instead of redefining it, check that the existing mapping is consistent with
# what the rest of the notebook relies on.

# Bone.from_indexes looks joints up by list position, while PayloadJoint.from_stacked
# slices the position array with Joint.index; both agree only if position == index.
assert all(j.index == i for i, j in enumerate(joints_map_list)), (
    "joints_map_list must be ordered so that joints_map_list[i].index == i")

# Names must follow the joint order parsed from the BVH file. This assumes
# anim.all_frame_poses() uses the same joint order as anim.joints — TODO confirm.
assert [j.name for j in joints_map_list] == list(anim.joints), (
    "joints_map_list does not match the joints parsed from the BVH file")

In [24]:
# Bone table: (start joint position, end joint position, bone name, color).
# Positions refer to entries of `joints_map_list`.
bone_map_lists = [
    Bone.from_indexes(joints_map_list, start, end, bone_name, bone_color)
    for start, end, bone_name, bone_color in (
        # Spine
        (0, 9, "hip_to_spine", COLOR_SPINE),
        (9, 10, "spine_to_thorax", COLOR_SPINE),
        (10, 11, "thorax_to_neck", COLOR_SPINE),
        (11, 12, "neck_to_neck_end", COLOR_SPINE),
        # Right arm
        (10, 17, "thorax_to_right_shoulder", COLOR_ARMS),
        (17, 18, "right_shoulder_to_elbow", COLOR_ARMS),
        (18, 19, "right_elbow_to_wrist", COLOR_ARMS),
        (19, 20, "right_wrist_to_end", COLOR_ARMS),
        # Left arm
        (10, 13, "thorax_to_left_shoulder", COLOR_ARMS),
        (13, 14, "left_shoulder_to_elbow", COLOR_ARMS),
        (14, 15, "left_elbow_to_wrist", COLOR_ARMS),
        (15, 16, "left_wrist_to_end", COLOR_ARMS),
        # Right leg
        (0, 1, "hip_to_right_hip", COLOR_LEGS),
        (1, 2, "right_hip_to_knee", COLOR_LEGS),
        (2, 3, "right_knee_to_ankle", COLOR_LEGS),
        (3, 4, "right_ankle_to_end", COLOR_LEGS),
        # Left leg
        (0, 5, "hip_to_left_hip", COLOR_LEGS),
        (5, 6, "left_hip_to_knee", COLOR_LEGS),
        (6, 7, "left_knee_to_ankle", COLOR_LEGS),
        (7, 8, "left_ankle_to_end", COLOR_LEGS),
    )
]

In [25]:
# Reorder axes to (x, z, y), then split the stacked array into per-joint and
# per-bone payloads.
pre = preprocess_data(pos)
sk = Skeleton.from_stacked(joints=joints_map_list,
                           bones=bone_map_lists,
                           payload=pre)

def filter_by_name(name: str) -> bool:
    """Return True if a joint/bone called ``name`` should be kept.

    Names containing any hand-detail keyword (case-insensitive substring match)
    are rejected; every other name is kept.
    """
    lowered = name.lower()
    excluded = ("finger", "index", "middle", "ring", "thumb", "pinky", "tip")
    return not any(keyword in lowered for keyword in excluded)

# Drop hand-detail joints/bones. With the H36M names used here no name matches,
# so every joint and bone is kept.
sk_f = sk.filter(predicate=filter_by_name)

In [26]:
from IPython.display import display, DisplayHandle

hdl: Optional[DisplayHandle] = None
# True until plot_frame has populated `fw` with traces.
is_first = True
# A single FigureWidget is created once and updated in place on every slider move,
# so the 3-D camera the user has set is not reset each time a frame is drawn.
fw = go.FigureWidget()
# https://stackoverflow.com/questions/52863305/plotly-scatter3d-how-can-i-force-3d-axes-to-have-the-same-scale-aspect-ratio
fw.update_layout(height=600, scene=dict(aspectmode="data"))


def plot_frame(sk: Skeleton, index: int):
    """Draw frame ``index`` of ``sk`` into the shared FigureWidget ``fw``.

    On the first call, or when the trace names differ from those already in
    ``fw``, all traces are replaced. Later calls only overwrite the x/y/z data of
    the existing traces inside one batched update.

    Raises:
        AssertionError: if ``index`` is outside ``[0, sk.total_frames)``.
    """
    global is_first
    frames = sk.total_frames
    assert 0 <= index < frames, f"index must be between 0 and {frames - 1} inclusive but got {index}; frames={frames}"

    traces = ([j.create_scatter(index) for j in sk.joints]
              + [b.create_lines(index) for b in sk.bones])

    same_traces = [t.name for t in fw.data] == [t.name for t in traces]
    if is_first or not same_traces:
        # Assigning an empty tuple is plotly's way to remove all existing traces.
        fw.data = ()
        fw.add_traces(traces)
        is_first = False
        return

    # NOTE: if the widget shows a JavaScript error, restarting the editor/kernel
    # has been the workaround so far.
    with fw.batch_update():
        for shown, fresh in zip(fw.data, traces):
            shown.x = fresh.x
            shown.y = fresh.y
            shown.z = fresh.z


slider = widgets.IntSlider(min=0,
                           max=sk_f.total_frames - 1,
                           step=1,
                           value=0,
                           # The trait is `continuous_update`; the former
                           # `continue_update` spelling is not an IntSlider trait
                           # and had no effect.
                           continuous_update=False)

# interactive() invokes plot_frame once with the initial slider value, so frame 0
# is already drawn into `fw` when it is displayed below.
p = interactive(plot_frame, sk=fixed(sk_f), index=slider)
display(p, clear=True)
hdl = display(fw, display_id=True)

interactive(children=(IntSlider(value=0, description='index', max=29), Output()), _dom_classes=('widget-intera…

None