# SMPL-details


## Environment Preparation

In [1]:
# Packages you may use very often.
import torch
import numpy as np
from smplx import SMPL
from smplx.vertex_ids import vertex_ids as VERTEX_IDS
from pytorch3d import transforms  # You may use this package when performing rotation representation transformation.

# Things you don't need to care about. They are just for driving the tutorials.
from lib.logger.look_tool import look_tensor
from lib.utils.path_manager import PathManager
from lib.viewer.wis3d_utils import HWis3D as Wis3D
from lib.skeleton import Skeleton_SMPL24

pm = PathManager()

In [2]:
body_model_smpl = SMPL(
        model_path = pm.inputs / 'body_models' / 'smpl',
        gender     = 'neutral',
    )



## Understand the Joint Regressor

You can get `J_regressor` from `smplx.SMPL(...).J_regressor`. It regresses the joint positions from the mesh vertices: it is a matrix of size `(24, 6890)`, and each joint position is a linear combination of the vertex positions, weighted by the corresponding row of the matrix.
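
In matrix form this is simply `J = J_regressor @ V`. As a minimal, self-contained illustration (toy NumPy arrays stand in for the real `(24, 6890)` regressor and `(6890, 3)` vertices, and the numbers are made up), the matrix product gives the same result as summing the weighted vertices joint by joint:

```python
import numpy as np

# Toy stand-ins: 4 "vertices" and 2 "joints" instead of 6890 and 24.
vertices = np.array([
    [0.0, 0.0, 0.0],
    [2.0, 0.0, 0.0],
    [0.0, 2.0, 0.0],
    [0.0, 0.0, 2.0],
])                                   # (V, 3)
J_regressor = np.array([
    [0.5, 0.5, 0.0, 0.0],            # joint 0: midpoint of vertices 0 and 1
    [0.0, 0.0, 0.25, 0.75],          # joint 1: weighted mix of vertices 2 and 3
])                                   # (J, V)

# Matrix form: all joints at once.
joints = J_regressor @ vertices      # (J, 3)

# Explicit form: joint j is the sum of the vertex positions weighted by row j.
joints_loop = np.stack([
    sum(J_regressor[j, k] * vertices[k] for k in range(len(vertices)))
    for j in range(len(J_regressor))
])

print(joints)  # joint 0 -> [1, 0, 0], joint 1 -> [0, 0.5, 1.5]
```

The cell below does the same thing with a batch dimension, via `torch.matmul(J_regressor[None], v)`.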

In [31]:
J_regressor = body_model_smpl.J_regressor  # (24, 6890)
print(J_regressor.shape)

# Get the reference ('ground-truth') output of SMPL with the default parameters (zero pose, mean shape).
smpl_output = body_model_smpl()
v = smpl_output.vertices.detach()
j = smpl_output.joints[:, :24, :].detach()
print(f'v shape = {v.shape}, j shape = {j.shape}')

# Calculate the joints through J_regressor.
j_by_hand = torch.matmul(J_regressor[None], v)
print(f'j_by_hand shape = {j_by_hand.shape}')

# Check the difference.
delta = j - j_by_hand
_ = look_tensor(delta)

torch.Size([24, 6890])
v shape = torch.Size([1, 6890, 3]), j shape = torch.Size([1, 24, 3])
j_by_hand shape = torch.Size([1, 24, 3])
shape = (1, 24, 3)	dtype = torch.float32	device = cpu	min/max/mean/std = [ -0.000000 -> 0.000000 ] ~ ( 0.000000, 0.000000 )


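As you can see, the two results agree up to floating-point error (for the first 24 joints). Note that this check uses the zero pose: for posed meshes, regressing from the final vertices generally does not reproduce SMPL's joints exactly, because SMPL regresses the joints from the shaped rest-pose template and then poses them with forward kinematics. So whenever the standard output object is available, you are encouraged to take the joints from it.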

> Sometimes we want the SMPL-style 24 joints, but we only have SMPL-X-style parameters. As we have already shown, you can't use SMPL-X parameters to get SMPL's output directly. Instead, you can first use a regressor to map the 10475 SMPL-X vertices to 6890 SMPL-style vertices, and then apply `J_regressor` to the regressed vertices to get the SMPL-style 24 joints. We won't go into the details here, but keep this in mind.
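
Since both steps are matrix multiplications, they compose into a single linear map. A minimal sketch, assuming you already have such an SMPL-X-to-SMPL vertex regressor as a dense `(6890, 10475)` matrix; `smplx2smpl` is a hypothetical name, and small random toy matrices stand in for the real data:

```python
import numpy as np

rng = np.random.default_rng(0)

# Shrunk, hypothetical sizes (real ones: 10475 SMPL-X vertices, 6890 SMPL vertices, 24 joints).
N_SMPLX, N_SMPL, N_JOINTS = 50, 30, 4

# Hypothetical SMPL-X -> SMPL vertex regressor (N_SMPL, N_SMPLX): each row mixes a few SMPL-X vertices.
smplx2smpl = rng.random((N_SMPL, N_SMPLX)) * (rng.random((N_SMPL, N_SMPLX)) < 0.1)
# Toy stand-in for SMPL's J_regressor (N_JOINTS, N_SMPL).
J_regressor = rng.random((N_JOINTS, N_SMPL))

v_smplx = rng.normal(size=(N_SMPLX, 3))      # SMPL-X vertices from some fitted parameters

# Two-step route: SMPL-X vertices -> SMPL-style vertices -> SMPL-style joints.
v_smpl = smplx2smpl @ v_smplx                # (N_SMPL, 3)
joints_two_step = J_regressor @ v_smpl       # (N_JOINTS, 3)

# Both maps are linear, so they can be fused once into a (N_JOINTS, N_SMPLX) regressor.
J_regressor_smplx = J_regressor @ smplx2smpl
joints_fused = J_regressor_smplx @ v_smplx   # (N_JOINTS, 3)
```

Fusing the two matrices is a design choice worth considering when you only need the joints: the fused regressor is computed once and then skips the intermediate SMPL-style vertices entirely.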

Another thing worth knowing is that the regressor is sparse: most of its elements are zero, so each joint is regressed from only a handful of vertices. Let's look at the details.
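
Because of this sparsity, the regressor can be stored as just its non-zero entries. A minimal NumPy sketch (the helpers `to_sparse` and `regress_sparse` are hypothetical names, and a random toy matrix stands in for the real `(24, 6890)` regressor); in PyTorch, `J_regressor.to_sparse()` gives a comparable sparse tensor:

```python
import numpy as np

def to_sparse(J_regressor):
    """Keep only the non-zero entries of a dense (J, V) regressor as (rows, cols, vals)."""
    rows, cols = np.nonzero(J_regressor)
    return rows, cols, J_regressor[rows, cols]

def regress_sparse(rows, cols, vals, vertices, num_joints):
    """Same result as J_regressor @ vertices, but only touches the non-zero entries."""
    joints = np.zeros((num_joints, vertices.shape[1]))
    # Scatter-add each weighted vertex into its joint; np.add.at accumulates repeated row indices.
    np.add.at(joints, rows, vals[:, None] * vertices[cols])
    return joints

rng = np.random.default_rng(0)
J_regressor = rng.random((24, 100)) * (rng.random((24, 100)) < 0.05)  # toy regressor, ~95% zeros
vertices = rng.normal(size=(100, 3))

rows, cols, vals = to_sparse(J_regressor)
print(f'non-zeros per joint: {np.bincount(rows, minlength=24)}')
joints = regress_sparse(rows, cols, vals, vertices, num_joints=24)
```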

In [40]:
# Statistics of the J_regressor.
zero_cnt = (J_regressor == 0).sum().item()
total_cnt = J_regressor.numel()
print(f'Among the total {total_cnt}, there are {zero_cnt} zeros, zero rate: {zero_cnt / total_cnt:.2%}.')

# Visualize, for each joint, the vertices used to regress it.
active_masks = (J_regressor != 0)  # (24, 6890), True where a vertex has a non-zero weight for a joint
J_regressor_wis3d = Wis3D(
        pm.outputs / 'wis3d',
        'SMPL-J_regressor',
    )
# Add the reference mesh and skeleton; 25 copies = 1 reference scene + 24 joint scenes.
J_regressor_wis3d.add_motion_verts(verts=v.repeat(25, 1, 1), name='vertices', offset=0)
J_regressor_wis3d.add_motion_skel(joints=j.repeat(25, 1, 1), bones=Skeleton_SMPL24.bones, colors=Skeleton_SMPL24.bone_colors, name='skeleton', offset=0)

# Visualize each part of the J_regressor.
for i in range(24):
    mask = active_masks[i]
    v_masked = v[0, mask]
    # Visualize the things of interest.
    J_regressor_wis3d.set_scene_id(i+1)
    J_regressor_wis3d.add_point_cloud(vertices=v_masked, name=f'VOI-{i}')  # Vertices of interest (VOI) used to regress the i-th joint.
    J_regressor_wis3d.add_spheres(centers=v_masked, radius=0.01, name=f'VOI-{i}')  # VOI used to regress the i-th joint.
    J_regressor_wis3d.add_spheres(centers=j[:1, i], radius=0.02, name=f'joint-{i}')  # The i-th joint.

Among the total 165360, there are 165124 zeros, zero rate: 99.86%.


The cell above visualizes the active (non-zero) elements of the regressor, showing exactly which vertices contribute to each regressed joint: scene `i+1` highlights the vertices used for joint `i`. Start the Wis3D server below and use its UI to inspect the joints you are interested in.

In [None]:
# Start the server. (Remember to terminate the cell before going on.)
!wis3d --vis_dir {pm.outputs / 'wis3d'} --host 0.0.0.0 --port 19090

### Understand the 45 Joints

You may have noticed that when I mention "joints", I always mean the first 24. However, the SMPL model returns 45 joints by default. So what are the other 21?

Actually, the 45 joints fall into two groups:

1. The standard SMPL joints, which are first regressed from the shaped vertices and then transformed by the pose parameters. These are the first 24 joints.
2. Joints selected directly from vertices (chosen by hand, in advance). There are 21 of them by default. Their names and vertex indices are stored in `smplx.vertex_ids.vertex_ids['smplh']`; I print them below.
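
A minimal sketch of the second group: the extra joints are plain copies of vertex positions, gathered by index and concatenated after the first 24. To my understanding this is what `smplx`'s `VertexJointSelector` does internally, but the order it uses for the 21 extra joints is fixed inside that class and does not necessarily follow the dict order, so check it before relying on indices 24 to 44. The helper `append_vertex_joints` is hypothetical, and the tensors are random toy data:

```python
import numpy as np

def append_vertex_joints(joints, vertices, extra_vids):
    """Append joints picked directly from mesh vertices.

    joints:     (B, 24, 3) joints from the kinematic skeleton
    vertices:   (B, V, 3) final (posed) mesh vertices
    extra_vids: sequence of vertex indices, one per extra joint
    """
    extra = vertices[:, np.asarray(extra_vids)]     # (B, len(extra_vids), 3)
    return np.concatenate([joints, extra], axis=1)  # (B, 24 + len(extra_vids), 3)

rng = np.random.default_rng(0)
joints24 = rng.normal(size=(1, 24, 3))
vertices = rng.normal(size=(1, 6890, 3))
# Three of the vertex ids from the table below (nose, reye, leye), in an illustrative order.
extra_vids = [332, 6260, 2800]

joints_all = append_vertex_joints(joints24, vertices, extra_vids)
print(joints_all.shape)  # (1, 27, 3)
```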


In [51]:
print(f'Selected joints from vertices.')
print(f'---------------------------------')
print(f'|  joint_name \t|\tvid\t|')
print(f'---------------------------------')
for name, vid in VERTEX_IDS['smplh'].items():  # don't reuse `v`, which still holds the vertices from above
    print(f'|  {name}  \t|\t{vid}\t|')
print(f'---------------------------------')

Selected joints from vertices.
---------------------------------
|  joint_name 	|	vid	|
---------------------------------
|  nose  	|	332	|
|  reye  	|	6260	|
|  leye  	|	2800	|
|  rear  	|	4071	|
|  lear  	|	583	|
|  rthumb  	|	6191	|
|  rindex  	|	5782	|
|  rmiddle  	|	5905	|
|  rring  	|	6016	|
|  rpinky  	|	6133	|
|  lthumb  	|	2746	|
|  lindex  	|	2319	|
|  lmiddle  	|	2445	|
|  lring  	|	2556	|
|  lpinky  	|	2673	|
|  LBigToe  	|	3216	|
|  LSmallToe  	|	3226	|
|  LHeel  	|	3387	|
|  RBigToe  	|	6617	|
|  RSmallToe  	|	6624	|
|  RHeel  	|	6787	|
---------------------------------


See the implementation of `smplx.SMPL` for more details. `smplx.SMPL` offers some options to customize its outputs, which are useful if you need to change the definition of the output joints (e.g., their order, or selecting more joints from vertices).
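
One such option, as I understand the constructor, is the `joint_mapper` argument of `smplx.SMPL`: a callable applied to the `(B, 45, 3)` joints after the vertex-selected joints are appended. The class below is a hypothetical sketch of a mapper that keeps and reorders joints by index; treat the exact call site inside `smplx` as an assumption and verify it in the source:

```python
import numpy as np

class ReorderJoints:
    """Hypothetical joint mapper: keep/reorder joints by index along the joint axis.

    Assumption: smplx.SMPL calls `joint_mapper(joints)` on the (B, 45, 3) joints.
    Indexing with a Python list works for both NumPy arrays and torch tensors.
    """
    def __init__(self, joint_idxs):
        self.joint_idxs = [int(i) for i in joint_idxs]

    def __call__(self, joints):
        return joints[:, self.joint_idxs]

# Example: keep the pelvis (joint 0), then the 21 vertex-selected joints (24..44).
mapper = ReorderJoints([0] + list(range(24, 45)))
joints = np.arange(45 * 3, dtype=np.float32).reshape(1, 45, 3)
out = mapper(joints)
print(out.shape)  # (1, 22, 3)
```

It could then be passed as `SMPL(model_path=..., joint_mapper=mapper)`, again assuming that argument behaves as described above.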

### Become Lighter

The sparsity of the regressor means that, if we only need the joint positions, we don't have to compute all the vertices: only those with non-zero regressor weights matter. Recalling SMPL's forward pass, this means we can skip much of the computation in Linear Blend Skinning (LBS).
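
For the 24 skeleton joints the saving is even larger than it first seems: `smplx.lbs.lbs()` regresses them from the shaped template $\bar{T} + B_S(\beta)$ and then poses them with forward kinematics, so no per-vertex skinning is needed for them at all. A minimal NumPy sketch of the rest-pose part, assuming SMPL-like array layouts (`shapedirs` as `(V, 3, 10)`); the function names are hypothetical and the data is random:

```python
import numpy as np

def active_vertex_ids(J_regressor):
    """Indices of vertices with a non-zero weight for at least one joint."""
    return np.flatnonzero(np.any(J_regressor != 0, axis=0))

def rest_joints_full(J_regressor, v_template, shapedirs, betas):
    """Reference: shape the full template, then regress."""
    v_shaped = v_template + shapedirs @ betas              # (V, 3, 10) @ (10,) -> (V, 3)
    return J_regressor @ v_shaped                          # (J, 3)

def rest_joints_light(J_regressor, v_template, shapedirs, betas):
    """Shape only the vertices that the regressor actually uses."""
    active = active_vertex_ids(J_regressor)
    v_shaped_active = v_template[active] + shapedirs[active] @ betas  # (A, 3)
    return J_regressor[:, active] @ v_shaped_active        # (J, 3)

rng = np.random.default_rng(0)
V, J = 500, 24
J_regressor = rng.random((J, V)) * (rng.random((J, V)) < 0.01)  # toy sparse regressor
v_template = rng.normal(size=(V, 3))
shapedirs = rng.normal(size=(V, 3, 10))
betas = rng.normal(size=10)

print(f'{len(active_vertex_ids(J_regressor))} of {V} vertices are needed')
joints_full = rest_joints_full(J_regressor, v_template, shapedirs, betas)
joints_light = rest_joints_light(J_regressor, v_template, shapedirs, betas)
```

The resulting rest-pose joints would then be posed with forward kinematics (e.g. `smplx.lbs.batch_rigid_transform` with the rotation matrices and `parents`). Only the 21 vertex-selected joints need skinning, and then only on those 21 vertices.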

## Understand the Linear Blend Skinning

In [6]:
def load_eg_params(eg_path):
    eg_params = np.load(eg_path, allow_pickle=True).item()
    betas = torch.from_numpy(eg_params['betas']).squeeze(0)  # (10,)
    poses = torch.cat([
            torch.from_numpy(eg_params['global_orient']).squeeze(0),  # (1, 3)
            torch.from_numpy(eg_params['body_pose']).squeeze(0),  # (23, 3)
        ], dim=0)  # (24, 3)
    poses[0, :] = 0  # clear the global rotation for easy visualization
    poses = transforms.axis_angle_to_matrix(poses)  # (24, 3, 3)
    return betas, poses

In [7]:
# Load example data.
betas, poses = load_eg_params(pm.inputs / 'examples/ballerina.npy')

### Reproduce the LBS

To test different cases, simply change the loaded data and rerun this section (4 cells in total).
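
For reference, the cells in this section reproduce the SMPL model as formulated in the SMPL paper. In the notation of the code comments ($\bar{T}$, $B_S$, $B_P$), with $\mathcal{J}$ the joint regressor, $w_{k,i}$ the skinning weights, and $G'_k$ the global transformation of joint $k$ with its rest-pose location removed:

$$
\begin{aligned}
T_P(\beta, \theta) &= \bar{T} + B_S(\beta) + B_P(\theta), \\
B_S(\beta) &= \sum_{n=1}^{10} \beta_n S_n, \qquad
B_P(\theta) = \sum_{n=1}^{207} \big(R_n(\theta) - R_n(\theta^*)\big) P_n, \\
J(\beta) &= \mathcal{J}\,\big(\bar{T} + B_S(\beta)\big), \\
t'_i &= \sum_{k=1}^{24} w_{k,i}\, G'_k\big(\theta, J(\beta)\big)\, t_{P,i}.
\end{aligned}
$$

Here $S_n$ are the shape blend shapes (`shape_disp`), $P_n$ the pose blend shapes (rows of `poses_disp`), $R_n(\theta)$ the $23 \times 9 = 207$ rotation-matrix entries of the non-root joints, $\theta^*$ the rest pose (identity rotations), and $t_{P,i}$ the $i$-th vertex of $T_P$ in homogeneous coordinates. Note that $J(\beta)$ depends on the shape only.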

In [8]:
# Obtain the necessary data matrices from the body model for reproducing.
# These things are gender-specific and are contained in the SMPL model files.
v_template  = body_model_smpl.v_template   # (6890, 3)
shape_disp  = body_model_smpl.shapedirs    # (6890, 3, 10)
poses_disp  = body_model_smpl.posedirs     # (207=23*9, 20670=6890*3)
lbs_weights = body_model_smpl.lbs_weights  # (6890, 24)
parents     = body_model_smpl.parents      # (24,), the kinematic tree: parent index of each joint
J_regressor = body_model_smpl.J_regressor  # (24, 6890)

In [9]:
# Reproduce smplx.lbs.lbs().

# \bar{T}, the template in the rest pose.
v_temp = v_template.clone()  # (6890, 3)

# B_S, the shape-specific template deformation.
Bs = torch.einsum('l,mkl->mk', betas, shape_disp).clone()  # (6890, 3)
v_temp_Bs = v_temp + Bs  # (6890, 3)

# B_P, the pose-dependent template deformation.
ident = torch.eye(3)  # (3, 3), subtracted so that the rest pose (identity rotations) gives zero pose features
poses_feat = (poses[1:] - ident[None]).reshape(-1)  # (207=23*9,) root orientation not included.
Bp = torch.einsum('l,lk->k', poses_feat, poses_disp).reshape(6890, 3).clone()  # (6890, 3)
v_temp_Bs_Bp = v_temp_Bs + Bp  # (6890, 3)
v_temp_Bp = v_temp + Bp  # (6890, 3), B_P without B_S, for comparison

# W(.), the blend skinning: perform forward kinematics (FK) to get each joint's global transformation, then blend the transformations per vertex.
from smplx.lbs import batch_rigid_transform, vertices2joints

def rig_v_temp(v, v_shaped):
    # Like smplx.lbs.lbs(), regress the rest-pose joints from the shaped template (T_bar + B_S) only;
    # B_P deforms the surface but does not move the joints.
    J_temp = vertices2joints(J_regressor, v_shaped[None])  # (1, 24, 3), the skeleton the transformations are based on
    J_final, A = batch_rigid_transform(poses, J_temp, parents)  # FK: local rotations -> global transformations
    A = A.squeeze(0)  # (24, 4, 4), global transformations with the rest-pose joint locations removed
    W = lbs_weights  # (6890, 24)
    T = torch.einsum('vj,jxy->vxy', W, A)  # (6890, 4, 4), per-vertex blend of the joint transformations
    homogen_coord = torch.ones_like(v[:, :1])  # (6890, 1)
    v_posed_homo = torch.cat([v, homogen_coord], dim=1)  # (6890, 4)
    v_homo = torch.einsum('vij,vj->vi', T, v_posed_homo)  # (6890, 4)
    return {
        'j_temp' : J_temp.squeeze(0),
        'j'      : J_final.squeeze(0),
        'v'      : v_homo[:, :3]
    }

final_on_temp_Bs_Bp = rig_v_temp(v_temp_Bs_Bp, v_shaped=v_temp_Bs)
final_on_temp_Bs    = rig_v_temp(v_temp_Bs,    v_shaped=v_temp_Bs)
final_on_temp       = rig_v_temp(v_temp,       v_shaped=v_temp)
final_on_temp_Bp    = rig_v_temp(v_temp_Bp,    v_shaped=v_temp)

In [None]:
# Visualize.
LBS_wis3d = Wis3D(
        pm.outputs / 'wis3d',
        'SMPL-LBS',
    )
faces = body_model_smpl.faces

LBS_wis3d.set_scene_id(0)
LBS_wis3d.add_mesh(vertices = v_temp, faces = faces, name = 'v_template')
LBS_wis3d.add_motion_skel(
    joints = final_on_temp['j_temp'][None],
    bones  = Skeleton_SMPL24.bones,
    colors = Skeleton_SMPL24.bone_colors,
    name   = 'j_temp at rest pose, mean shape',
    offset = 0,
)

LBS_wis3d.set_scene_id(1)
LBS_wis3d.add_mesh(vertices = v_temp_Bs, faces = faces, name = 'v_shaped')
LBS_wis3d.add_motion_skel(
    joints = final_on_temp_Bs['j_temp'][None],
    bones  = Skeleton_SMPL24.bones,
    colors = Skeleton_SMPL24.bone_colors,
    name   = 'j_temp after shape deformation',
    offset = 1,
)

LBS_wis3d.set_scene_id(2)
LBS_wis3d.add_mesh(vertices = v_temp_Bs_Bp, faces = faces, name = 'v_posed')
LBS_wis3d.add_motion_skel(
    joints = final_on_temp_Bs_Bp['j_temp'][None],
    bones  = Skeleton_SMPL24.bones,
    colors = Skeleton_SMPL24.bone_colors,
    name   = 'j_temp (regressed from v_shaped, so B_P does not move it)',
    offset = 2,
)

LBS_wis3d.set_scene_id(3)
LBS_wis3d.add_mesh(vertices=final_on_temp_Bs_Bp['v'], faces=faces, name='v_final')
LBS_wis3d.add_mesh(vertices=final_on_temp_Bs['v'],    faces=faces, name='v_final without B_P')
LBS_wis3d.add_mesh(vertices=final_on_temp_Bp['v'],    faces=faces, name='v_final without B_S')
LBS_wis3d.add_mesh(vertices=final_on_temp['v'],       faces=faces, name='v_final without B_P and B_S')
LBS_wis3d.add_motion_skel(
    joints = final_on_temp_Bs_Bp['j'][None],
    bones  = Skeleton_SMPL24.bones,
    colors = Skeleton_SMPL24.bone_colors,
    name   = 'j_final',
    offset = 3,
)

torch.Size([24, 3])
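
To see what `batch_rigid_transform` and the `einsum` blending in `rig_v_temp` compute, here is a self-contained NumPy sketch of forward kinematics plus LBS on a toy two-joint chain. It mirrors my reading of `smplx.lbs.batch_rigid_transform` (chain local transforms along `parents`, then subtract the rotated rest-pose joint so each transform acts on rest-pose coordinates), but it is a loop-based illustration rather than the library's implementation, and all names are hypothetical:

```python
import numpy as np

def rot_z(angle):
    """Rotation matrix about the z axis."""
    c, s = np.cos(angle), np.sin(angle)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def make_transform(R, t):
    """Pack a 3x3 rotation and a translation into a 4x4 homogeneous matrix."""
    T = np.eye(4)
    T[:3, :3] = R
    T[:3, 3] = t
    return T

def forward_kinematics(rot_mats, joints, parents):
    """rot_mats (K, 3, 3) local rotations, joints (K, 3) rest-pose joints,
    parents (K,) with -1 for the root and parents listed before children.
    Returns posed joints (K, 3) and transforms (K, 4, 4) with the rest pose removed."""
    K = len(parents)
    G = np.zeros((K, 4, 4))
    for k in range(K):
        p = parents[k]
        offset = joints[k] if p < 0 else joints[k] - joints[p]  # bone vector relative to the parent
        local = make_transform(rot_mats[k], offset)
        G[k] = local if p < 0 else G[p] @ local                  # chain to the global transform
    posed_joints = G[:, :3, 3].copy()
    # Remove the rest pose: A_k maps a rest-pose point to its posed position under joint k alone.
    A = G.copy()
    A[:, :3, 3] -= np.einsum('kij,kj->ki', G[:, :3, :3], joints)
    return posed_joints, A

def skin(vertices, weights, A):
    """LBS: blend the 4x4 transforms per vertex, then apply them. vertices (V, 3), weights (V, K)."""
    T = np.einsum('vk,kij->vij', weights, A)                   # (V, 4, 4)
    v_homo = np.concatenate([vertices, np.ones((len(vertices), 1))], axis=1)
    return np.einsum('vij,vj->vi', T, v_homo)[:, :3]

# Toy chain along x: root at the origin, child at x = 1; bend the child by 90 degrees about z.
parents = np.array([-1, 0])
joints = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
vertices = np.array([[0.5, 0.0, 0.0], [1.5, 0.0, 0.0], [2.0, 0.0, 0.0]])
weights = np.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
rot_mats = np.stack([np.eye(3), rot_z(np.pi / 2)])

posed_joints, A = forward_kinematics(rot_mats, joints, parents)
posed_vertices = skin(vertices, weights, A)  # approx. [[0.5, 0, 0], [1.25, 0.25, 0], [1, 1, 0]]

# LBS is linear in the transforms: skinning with each joint separately, then blending, gives the same result.
v_homo = np.concatenate([vertices, np.ones((len(vertices), 1))], axis=1)
per_joint = np.einsum('kij,vj->vki', A, v_homo)[..., :3]        # (V, K, 3)
blended_after = np.einsum('vk,vki->vi', weights, per_joint)     # (V, 3)
```

That linearity is why `rig_v_temp` can blend the `(24, 4, 4)` matrices into one `(4, 4)` matrix per vertex first and apply it only once.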


In [None]:
# Start the server. (Remember to terminate the cell before going on.)
!wis3d --vis_dir {pm.outputs / 'wis3d'} --host 0.0.0.0 --port 19090