In [2]:
import numpy as np
import pickle
import os
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import torch
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import ipywidgets
from ipywidgets import interact, IntSlider
from smplx import SMPLX
import torch



In [3]:
# Names of the 55 skeleton joints, in storage order: a joint's position in this
# list is the index used to look it up in the (T, 55, 3) arrays that the
# visualisation helpers below plot (they build a name -> index map from it).
joint_names = [
    'pelvis', 'left_hip', 'right_hip', 'spine1', 'left_knee', 'right_knee',
    'spine2', 'left_ankle', 'right_ankle', 'spine3', 'left_foot', 'right_foot',
    'neck', 'left_collar', 'right_collar', 'head', 'left_shoulder',
    'right_shoulder', 'left_elbow', 'right_elbow', 'left_wrist', 'right_wrist',
    'jaw', 'left_eye_smplhf', 'right_eye_smplhf', 'left_index1', 'left_index2',
    'left_index3', 'left_middle1', 'left_middle2', 'left_middle3',
    'left_pinky1', 'left_pinky2', 'left_pinky3', 'left_ring1', 'left_ring2',
    'left_ring3', 'left_thumb1', 'left_thumb2', 'left_thumb3', 'right_index1',
    'right_index2', 'right_index3', 'right_middle1', 'right_middle2',
    'right_middle3', 'right_pinky1', 'right_pinky2', 'right_pinky3',
    'right_ring1', 'right_ring2', 'right_ring3', 'right_thumb1',
    'right_thumb2', 'right_thumb3'
]

In [4]:
# Two files for the same clip name, resolved relative to the notebook's working
# directory: one under smplx/ (model parameters) and one under pos3d/ (joint positions).
sample_point_path = 'data/motion/smplx/train/Ballet_001_001_00.npy'
pos3d_path = 'data/motion/pos3d/train/Ballet_001_001_00.npy'

In [5]:
# The smplx file stores a pickled dict wrapped in a 0-d object array, so it needs
# allow_pickle=True and .item() to unwrap it.
# NOTE(review): unpickling can execute arbitrary code -- only load trusted files this way.
smplx_data = np.load(sample_point_path, allow_pickle=True).item()
# The pos3d file is loaded as a plain numeric array (no pickle needed).
pos3d_data = np.load(pos3d_path)

In [6]:
# Print the Python type stored under every top-level entry of the loaded dict.
for key in smplx_data:
    value = smplx_data[key]
    print(f'{key} {type(value)}')

transl <class 'numpy.ndarray'>
poses <class 'numpy.ndarray'>
betas <class 'numpy.ndarray'>
global_orient <class 'numpy.ndarray'>
meta <class 'dict'>


In [7]:
# Print the shape of every array-valued entry; non-array entries (e.g. 'meta') are skipped.
for key, value in smplx_data.items():
    if not isinstance(value, np.ndarray):
        continue
    print(key, value.shape)

transl (2293, 3)
poses (2293, 165)
betas (2293, 16)
global_orient (2293, 3)


In [8]:
def vis_one_point(data):
    '''Show an interactive 3D viewer for one 55-joint skeleton sequence.

    Args:
        data: array of joint positions that reshapes to (T, 55, 3), e.g. (T, 165).
            Joints must be ordered like the module-level ``joint_names``.

    Returns:
        Whatever ``ipywidgets.interact`` returns for the per-frame plot callback.

    Raises:
        ValueError: if ``data`` contains no frames.
    '''
    joints_data = data.reshape(-1, 55, 3)

    # Count frames on the reshaped array so the slider range is correct even when
    # ``data`` is handed over flattened.
    num_frames = joints_data.shape[0]
    if num_frames == 0:
        raise ValueError('data must contain at least one frame')

    # Pairs of joint names that get connected by a line segment.
    skeleton_connections = [
        # spine and head
        ('pelvis', 'spine1'), ('spine1', 'spine2'), ('spine2', 'spine3'),
        ('spine3', 'neck'), ('neck', 'head'),
        ('head', 'jaw'), ('jaw', 'left_eye_smplhf'), ('jaw', 'right_eye_smplhf'),
        # legs
        ('pelvis', 'right_hip'), ('pelvis', 'left_hip'),
        ('left_hip', 'left_knee'), ('right_hip', 'right_knee'),
        ('left_knee', 'left_ankle'), ('right_knee', 'right_ankle'),
        ('left_ankle', 'left_foot'), ('right_ankle', 'right_foot'),
        # arms
        ('spine3', 'left_collar'), ('spine3', 'right_collar'),
        ('left_collar', 'left_shoulder'), ('right_collar', 'right_shoulder'),
        ('left_shoulder', 'left_elbow'), ('right_shoulder', 'right_elbow'),
        ('left_elbow', 'left_wrist'), ('right_elbow', 'right_wrist'),
    ]
    # Every finger is a three-segment chain hanging off the wrist of its side.
    for side in ('left', 'right'):
        for finger in ('thumb', 'index', 'middle', 'ring', 'pinky'):
            skeleton_connections.extend([
                (f'{side}_wrist', f'{side}_{finger}1'),
                (f'{side}_{finger}1', f'{side}_{finger}2'),
                (f'{side}_{finger}2', f'{side}_{finger}3'),
            ])

    # Resolve names to array indices once, not on every redraw.
    joint_indices = {name: i for i, name in enumerate(joint_names)}
    bone_indices = [
        (joint_indices[first], joint_indices[second])
        for first, second in skeleton_connections
    ]

    # Axis limits span the whole sequence so the view stays fixed while scrubbing.
    # They do not depend on the frame, so compute them once up front.
    axis_limits = [
        [np.min(joints_data[:, :, axis]), np.max(joints_data[:, :, axis])]
        for axis in range(3)
    ]

    def plot_skeleton(joints, ax, c):
        '''Draw every bone of one frame (``joints`` is (55, 3)) on ``ax`` in colour ``c``.'''
        for first_index, second_index in bone_indices:
            joint1 = joints[first_index]
            joint2 = joints[second_index]
            ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], [joint1[2], joint2[2]], color=c)

    def plot_frame(t):
        '''Render frame ``t`` (0-based) as a 3D scatter plus skeleton.'''
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        frame = joints_data[t]
        ax.scatter(frame[:, 0], frame[:, 1], frame[:, 2], c='r')
        plot_skeleton(frame, ax, 'r')

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title(f'Skeleton Visualization - Frame {t+1}')

        ax.set_xlim(axis_limits[0])
        ax.set_ylim(axis_limits[1])
        ax.set_zlim(axis_limits[2])

        plt.show()

    return ipywidgets.interact(plot_frame, t=IntSlider(min=0, max=num_frames - 1, step=1))


In [9]:
def vis_two_point(data, data1):
    '''Show an interactive 3D viewer overlaying two 55-joint skeleton sequences.

    ``data`` is drawn in red and ``data1`` in blue on shared axes.

    Args:
        data: array of joint positions that reshapes to (T, 55, 3), e.g. (T, 165).
        data1: second sequence in the same layout. If the two sequences differ
            in length, only the frames present in both can be selected.
            Joints must be ordered like the module-level ``joint_names``.

    Returns:
        Whatever ``ipywidgets.interact`` returns for the per-frame plot callback.

    Raises:
        ValueError: if either sequence contains no frames.
    '''
    joints_data = data.reshape(-1, 55, 3)
    joints_data_1 = data1.reshape(-1, 55, 3)

    # Limit the slider to frames that exist in BOTH sequences; otherwise a shorter
    # second sequence raises IndexError inside the plot callback.
    num_frames = min(joints_data.shape[0], joints_data_1.shape[0])
    if num_frames == 0:
        raise ValueError('data and data1 must each contain at least one frame')

    # Pairs of joint names that get connected by a line segment.
    skeleton_connections = [
        # spine and head
        ('pelvis', 'spine1'), ('spine1', 'spine2'), ('spine2', 'spine3'),
        ('spine3', 'neck'), ('neck', 'head'),
        ('head', 'jaw'), ('jaw', 'left_eye_smplhf'), ('jaw', 'right_eye_smplhf'),
        # legs
        ('pelvis', 'right_hip'), ('pelvis', 'left_hip'),
        ('left_hip', 'left_knee'), ('right_hip', 'right_knee'),
        ('left_knee', 'left_ankle'), ('right_knee', 'right_ankle'),
        ('left_ankle', 'left_foot'), ('right_ankle', 'right_foot'),
        # arms
        ('spine3', 'left_collar'), ('spine3', 'right_collar'),
        ('left_collar', 'left_shoulder'), ('right_collar', 'right_shoulder'),
        ('left_shoulder', 'left_elbow'), ('right_shoulder', 'right_elbow'),
        ('left_elbow', 'left_wrist'), ('right_elbow', 'right_wrist'),
    ]
    # Every finger is a three-segment chain hanging off the wrist of its side.
    for side in ('left', 'right'):
        for finger in ('thumb', 'index', 'middle', 'ring', 'pinky'):
            skeleton_connections.extend([
                (f'{side}_wrist', f'{side}_{finger}1'),
                (f'{side}_{finger}1', f'{side}_{finger}2'),
                (f'{side}_{finger}2', f'{side}_{finger}3'),
            ])

    # Resolve names to array indices once, not on every redraw.
    joint_indices = {name: i for i, name in enumerate(joint_names)}
    bone_indices = [
        (joint_indices[first], joint_indices[second])
        for first, second in skeleton_connections
    ]

    # Shared axis limits span both full sequences so the view stays fixed while
    # scrubbing. They do not depend on the frame, so compute them once up front.
    axis_limits = [
        [
            min(np.min(joints_data[:, :, axis]), np.min(joints_data_1[:, :, axis])),
            max(np.max(joints_data[:, :, axis]), np.max(joints_data_1[:, :, axis])),
        ]
        for axis in range(3)
    ]

    def plot_skeleton(joints, ax, c):
        '''Draw every bone of one frame (``joints`` is (55, 3)) on ``ax`` in colour ``c``.'''
        for first_index, second_index in bone_indices:
            joint1 = joints[first_index]
            joint2 = joints[second_index]
            ax.plot([joint1[0], joint2[0]], [joint1[1], joint2[1]], [joint1[2], joint2[2]], color=c)

    def plot_frame(t):
        '''Render frame ``t`` (0-based) of both sequences on the same axes.'''
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        # Tiny markers keep the two overlaid skeletons readable.
        for frame, colour in ((joints_data[t], 'r'), (joints_data_1[t], 'b')):
            ax.scatter(frame[:, 0], frame[:, 1], frame[:, 2], c=colour, s=0.1)
        plot_skeleton(joints_data[t], ax, 'r')
        plot_skeleton(joints_data_1[t], ax, 'b')

        ax.set_xlabel('X')
        ax.set_ylabel('Y')
        ax.set_zlabel('Z')
        ax.set_title(f'Skeleton Visualization - Frame {t+1}')

        ax.set_xlim(axis_limits[0])
        ax.set_ylim(axis_limits[1])
        ax.set_zlim(axis_limits[2])

        plt.show()

    return ipywidgets.interact(plot_frame, t=IntSlider(min=0, max=num_frames - 1, step=1))

In [10]:
# Browse the joint positions loaded from pos3d_path, one frame at a time.
vis = vis_one_point(pos3d_data)

interactive(children=(IntSlider(value=0, description='t', max=2292), Output()), _dom_classes=('widget-interact…

In [11]:
def smplx_to_pos3d(data, global_orient=None, betas=None, model_path='./smplx'):
    '''Run SMPL-X parameters through the body model and return 55 joint positions per frame.

    Args:
        data: dict with array entries 'poses' (columns 3:66 body, 66:69 jaw,
            69:72 left eye, 72:75 right eye, 75:120 left hand, 120: right hand),
            'transl', 'betas', 'global_orient', and data['meta']['gender'].
        global_orient: optional per-frame replacement for data['global_orient'].
        betas: optional per-frame replacement for data['betas']; only the first
            10 columns are used.
        model_path: location of the SMPL-X model files passed to ``SMPLX``.

    Returns:
        np.ndarray of shape (nframes, 165): the first 55 joints of the model
        output, flattened per frame.
    '''
    if global_orient is None:
        global_orient = data['global_orient']
    if betas is None:
        betas = data['betas']

    num_betas = 10  # number of shape coefficients fed to the model
    num_joints = 55  # joints kept from the model output
    poses = data['poses']
    shape_coeffs = betas[:, :num_betas]

    def as_float_tensor(array):
        '''Convert a numpy array to a float32 torch tensor.'''
        return torch.from_numpy(array).float()

    body_model = SMPLX(model_path=model_path, betas=shape_coeffs, gender=data['meta']['gender'],
        batch_size=len(data['betas']), num_betas=num_betas, use_pca=False, use_face_contour=True, flat_hand_mean=True)

    # Inference only: skip autograd bookkeeping, which is otherwise kept for every
    # frame of the clip just to be discarded by detach().
    with torch.no_grad():
        output = body_model.forward(
            global_orient=as_float_tensor(global_orient),
            body_pose=as_float_tensor(poses[:, 3:66]),
            jaw_pose=as_float_tensor(poses[:, 66:69]),
            leye_pose=as_float_tensor(poses[:, 69:72]),
            reye_pose=as_float_tensor(poses[:, 72:75]),
            left_hand_pose=as_float_tensor(poses[:, 75:120]),
            right_hand_pose=as_float_tensor(poses[:, 120:]),
            transl=as_float_tensor(data['transl']),
            betas=as_float_tensor(shape_coeffs),
        )

    keypoints3d = output.joints.detach().numpy()[:, :num_joints]
    nframes = keypoints3d.shape[0]
    return keypoints3d.reshape(nframes, -1)

In [12]:
# Convert the loaded SMPL-X parameters to joint positions with smplx_to_pos3d,
# then browse the result frame by frame.
smplx_pos3d = smplx_to_pos3d(smplx_data)
vis = vis_one_point(smplx_pos3d)

interactive(children=(IntSlider(value=0, description='t', max=2292), Output()), _dom_classes=('widget-interact…

In [13]:
# Experiment: replace the per-frame global orientation with one constant 3-vector
# and overlay the result (blue) on the original joints (red).
# The row count follows the loaded clip instead of a hard-coded frame number, so
# the cell keeps working when sample_point_path points at another file.
global_orient_sample = np.array([-2.20529366, -2.71022929, -2.19756915])[None, :].repeat(len(smplx_data['poses']), axis=0)
smplx_pos3d_with_changed_global_orient = smplx_to_pos3d(smplx_data, global_orient_sample)
vis = vis_two_point(smplx_pos3d, smplx_pos3d_with_changed_global_orient)

interactive(children=(IntSlider(value=0, description='t', max=2292), Output()), _dom_classes=('widget-interact…

In [14]:
# Experiment: replace the shape coefficients with a constant row (two values, each
# repeated 8 times -> 16 columns) and overlay the result (blue) on the original (red).
# The row count follows the loaded clip instead of a hard-coded frame number, so
# the cell keeps working when sample_point_path points at another file.
betas_sample = np.array([-1.20529366, -0.71022929])[None, :].repeat(len(smplx_data['poses']), axis=0).repeat(8, axis=1)
smplx_pos3d_with_changed_betas = smplx_to_pos3d(smplx_data, betas=betas_sample)
vis = vis_two_point(smplx_pos3d, smplx_pos3d_with_changed_betas)

interactive(children=(IntSlider(value=0, description='t', max=2292), Output()), _dom_classes=('widget-interact…

In [15]:
# Five SMPL-X parameter files from the train split, one per file-name prefix,
# used below to compare parameter values across clips.
# sample_point_path_1 is the same file as sample_point_path above.
sample_point_path_1 = 'data/motion/smplx/train/Ballet_001_001_00.npy'
sample_point_path_2 = 'data/motion/smplx/train/Jive_001_002_00.npy'
sample_point_path_3 = 'data/motion/smplx/train/Lumba_007_004_00.npy'
sample_point_path_4 = 'data/motion/smplx/train/Samba_009_002_00.npy'
sample_point_path_5 = 'data/motion/smplx/train/Waltz_009_000_01.npy'

In [16]:
# Load all five sample clips; each file is unwrapped with .item() to get the
# stored object out of the 0-d array that np.load returns.
(smplx_data_1, smplx_data_2, smplx_data_3, smplx_data_4, smplx_data_5) = [
    np.load(clip_path, allow_pickle=True).item()
    for clip_path in (
        sample_point_path_1,
        sample_point_path_2,
        sample_point_path_3,
        sample_point_path_4,
        sample_point_path_5,
    )
]

In [27]:
# Pull the four array entries out of each clip into individually named variables,
# one unpacking line per clip.
global_orient_1, betas_1, transl_1, poses_1 = (
    smplx_data_1[field] for field in ('global_orient', 'betas', 'transl', 'poses'))
global_orient_2, betas_2, transl_2, poses_2 = (
    smplx_data_2[field] for field in ('global_orient', 'betas', 'transl', 'poses'))
global_orient_3, betas_3, transl_3, poses_3 = (
    smplx_data_3[field] for field in ('global_orient', 'betas', 'transl', 'poses'))
global_orient_4, betas_4, transl_4, poses_4 = (
    smplx_data_4[field] for field in ('global_orient', 'betas', 'transl', 'poses'))
global_orient_5, betas_5, transl_5, poses_5 = (
    smplx_data_5[field] for field in ('global_orient', 'betas', 'transl', 'poses'))

In [18]:
global_orient_1

array([[-1.20529366, -2.71022929, -2.19756915],
       [-1.21856464, -2.71093919, -2.19654832],
       [-1.23163479, -2.71233633, -2.19318053],
       ...,
       [ 0.57785498,  2.29601516,  1.82616812],
       [ 0.57932835,  2.29418565,  1.82392229],
       [ 0.58006274,  2.29270154,  1.82238811]])

In [19]:
global_orient_2

array([[-4.68571992, -0.42934413,  0.07549469],
       [-4.67228904, -0.4606938 ,  0.09051098],
       [-4.65406505, -0.471385  ,  0.06196224],
       ...,
       [ 1.42776035, -0.99183342, -1.02153071],
       [ 1.43381482, -1.02973781, -1.08804582],
       [ 1.44991503, -1.08463763, -1.12267701]])

In [20]:
global_orient_3

array([[-0.57672734,  2.27797361,  2.35620388],
       [-0.57572755,  2.2784426 ,  2.35571599],
       [-0.57546977,  2.27836701,  2.35562341],
       ...,
       [-3.28442229, -1.97932754, -2.18627699],
       [-3.3049245 , -1.98592084, -2.21367749],
       [-3.32779602, -2.0146971 , -2.19443053]])

In [24]:
betas_1

array([[ 0.10251013, -0.61846808,  0.91269899, ...,  1.65483133,
         0.45901326, -0.17232701],
       [ 0.10251013, -0.61846808,  0.91269899, ...,  1.65483133,
         0.45901326, -0.17232701],
       [ 0.10251013, -0.61846808,  0.91269899, ...,  1.65483133,
         0.45901326, -0.17232701],
       ...,
       [ 0.10251013, -0.61846808,  0.91269899, ...,  1.65483133,
         0.45901326, -0.17232701],
       [ 0.10251013, -0.61846808,  0.91269899, ...,  1.65483133,
         0.45901326, -0.17232701],
       [ 0.10251013, -0.61846808,  0.91269899, ...,  1.65483133,
         0.45901326, -0.17232701]])

In [25]:
betas_2

array([[-0.12324779,  0.14492362,  2.34369865, ...,  2.20512181,
         0.64467053, -0.77347555],
       [-0.12324779,  0.14492362,  2.34369865, ...,  2.20512181,
         0.64467053, -0.77347555],
       [-0.12324779,  0.14492362,  2.34369865, ...,  2.20512181,
         0.64467053, -0.77347555],
       ...,
       [-0.12324779,  0.14492362,  2.34369865, ...,  2.20512181,
         0.64467053, -0.77347555],
       [-0.12324779,  0.14492362,  2.34369865, ...,  2.20512181,
         0.64467053, -0.77347555],
       [-0.12324779,  0.14492362,  2.34369865, ...,  2.20512181,
         0.64467053, -0.77347555]])

In [50]:
betas_3

array([[ 0.97719742, -0.61611486,  0.39504089, ...,  1.4639554 ,
         2.18054136,  1.11218266],
       [ 0.97719742, -0.61611486,  0.39504089, ...,  1.4639554 ,
         2.18054136,  1.11218266],
       [ 0.97719742, -0.61611486,  0.39504089, ...,  1.4639554 ,
         2.18054136,  1.11218266],
       ...,
       [ 0.97719742, -0.61611486,  0.39504089, ...,  1.4639554 ,
         2.18054136,  1.11218266],
       [ 0.97719742, -0.61611486,  0.39504089, ...,  1.4639554 ,
         2.18054136,  1.11218266],
       [ 0.97719742, -0.61611486,  0.39504089, ...,  1.4639554 ,
         2.18054136,  1.11218266]])

In [28]:
transl_1

array([[-0.65621082, -1.64838262,  0.89533397],
       [-0.65358141, -1.64641348,  0.89558762],
       [-0.65075645, -1.64410038,  0.89569508],
       ...,
       [ 0.31251716,  1.28121588,  0.89527482],
       [ 0.31204591,  1.28155545,  0.89525168],
       [ 0.31165998,  1.28201863,  0.89523992]])

In [29]:
transl_2

array([[-0.1392291 ,  0.85531709,  0.87647833],
       [-0.13817478,  0.85791543,  0.87572498],
       [-0.13695065,  0.86209409,  0.87633207],
       ...,
       [ 0.0819093 ,  1.46657049,  0.82925938],
       [ 0.08618805,  1.49243552,  0.82033775],
       [ 0.09191696,  1.51709278,  0.81286401]])

In [30]:
transl_3

array([[ 0.48883599, -1.67993234,  0.91055326],
       [ 0.48860488, -1.67950831,  0.91050977],
       [ 0.4884215 , -1.67921145,  0.91035904],
       ...,
       [ 0.24785432,  2.19335491,  0.93813843],
       [ 0.24655881,  2.19034343,  0.93676046],
       [ 0.2459825 ,  2.18775017,  0.93498332]])

In [31]:
poses_1

array([[-1.20529366e+00, -2.71022929e+00, -2.19756915e+00, ...,
        -1.35963051e-08, -1.83402816e-08,  1.03306387e-01],
       [-1.21856464e+00, -2.71093919e+00, -2.19654832e+00, ...,
         2.55313225e-08,  7.40311933e-08,  1.02987865e-01],
       [-1.23163479e+00, -2.71233633e+00, -2.19318053e+00, ...,
        -1.93153325e-08, -5.61626370e-08,  1.03213361e-01],
       ...,
       [ 5.77854983e-01,  2.29601516e+00,  1.82616812e+00, ...,
        -3.17678840e-08,  2.86997946e-08,  8.88824630e-02],
       [ 5.79328348e-01,  2.29418565e+00,  1.82392229e+00, ...,
         2.39310321e-08,  4.02893858e-08,  8.92121556e-02],
       [ 5.80062743e-01,  2.29270154e+00,  1.82238811e+00, ...,
         5.05226760e-08,  3.17170247e-08,  9.18301495e-02]])

In [32]:
poses_2

array([[-4.68571992e+00, -4.29344128e-01,  7.54946885e-02, ...,
        -1.74372591e-07, -9.16162478e-09,  1.04984487e-01],
       [-4.67228904e+00, -4.60693800e-01,  9.05109809e-02, ...,
        -1.39355985e-07,  3.24940511e-07,  1.04300719e-01],
       [-4.65406505e+00, -4.71384997e-01,  6.19622352e-02, ...,
         6.88481161e-08,  2.10488461e-08,  1.03451023e-01],
       ...,
       [ 1.42776035e+00, -9.91833418e-01, -1.02153071e+00, ...,
        -5.49431830e-09,  1.57015567e-07,  6.99558333e-02],
       [ 1.43381482e+00, -1.02973781e+00, -1.08804582e+00, ...,
         5.53979482e-08, -8.53949890e-08,  7.01191612e-02],
       [ 1.44991503e+00, -1.08463763e+00, -1.12267701e+00, ...,
         4.43862447e-08, -2.60418594e-07,  7.25361979e-02]])

In [33]:
poses_3

array([[-5.76727338e-01,  2.27797361e+00,  2.35620388e+00, ...,
         1.08262203e-07, -8.28945931e-08,  8.25499895e-02],
       [-5.75727545e-01,  2.27844260e+00,  2.35571599e+00, ...,
        -5.95362381e-08,  1.72271904e-07,  8.25760124e-02],
       [-5.75469774e-01,  2.27836701e+00,  2.35562341e+00, ...,
         2.04966909e-07,  1.13363483e-07,  8.31172914e-02],
       ...,
       [-3.28442229e+00, -1.97932754e+00, -2.18627699e+00, ...,
        -3.49561837e-08,  1.73538594e-08,  5.69439848e-03],
       [-3.30492450e+00, -1.98592084e+00, -2.21367749e+00, ...,
         8.64797904e-08,  1.57514957e-07,  1.00069500e-02],
       [-3.32779602e+00, -2.01469710e+00, -2.19443053e+00, ...,
        -2.12159408e-07,  3.82480500e-07,  1.41854427e-02]])

In [40]:
poses_3.min()

-4.704158503118254

In [41]:
poses_3.max()

2.817045597871275

In [42]:
poses_3.mean()

0.0007535540178554263

In [46]:
poses_3.std()

0.3877961526511464