# Tutorial 2: SMPL-X principle

Author: [Hejia Chen](http://harryxd2018.github.io), BUAA

In this tutorial, we will learn how to use the SMPL-X model to generate a 3D human mesh.

## Prerequisites

In [1]:
import torch
from utils import write_obj, load_pkl

In [2]:
smplx_neutral = load_pkl('./models/smplx/SMPLX_NEUTRAL.pkl', to_torch=True)
smplx_neutral.keys()

dict_keys(['dynamic_lmk_bary_coords', 'hands_componentsl', 'ft', 'lmk_faces_idx', 'f', 'J_regressor', 'hands_componentsr', 'kintree_table', 'hands_coeffsr', 'joint2num', 'hands_meanl', 'lmk_bary_coords', 'weights', 'posedirs', 'dynamic_lmk_faces_idx', 'part2num', 'vt', 'hands_meanr', 'hands_coeffsl', 'v_template', 'shapedirs'])

## Step 1: Shape and expression control

The SMPL-X model is parameterized by shape and expression parameters. The shape parameters control the body shape, while the expression parameters control the facial expression. In this model file they are 300-dimensional and 100-dimensional vectors, respectively. Each parameter scales one blendshape, i.e. a set of per-vertex offsets; the blendshapes for both are stored together in `shapedirs` of the SMPL-X model, with the first 300 channels for shape and the last 100 for expression. The weighted blendshapes are added to the template mesh `v_template` to generate the final mesh, so setting all parameters to zero gives the neutral body.

In [3]:
print('shapedirs dim:', smplx_neutral['shapedirs'].shape)

betas = torch.zeros([1, 300], dtype=torch.float32)      # as the shape parameters
psi = torch.zeros([1, 100], dtype=torch.float32)        # as the expression parameters

shapedirs dim: torch.Size([10475, 3, 400])


In [4]:
shape_deformation = torch.sum(smplx_neutral['shapedirs'][..., :300] * betas, dim=-1)
expression_deformation = torch.sum(smplx_neutral['shapedirs'][..., 300:] * psi, dim=-1)
print(f"shape_deformation shape: {shape_deformation.shape}")
print(f"expression_deformation shape: {expression_deformation.shape}")
# new_vertice = smpl_neutral['v_template'].unsqueeze(0) + smpl_neutral['shapedirs'][..., :300] * betas.unsqueeze(-1) + smpl_neutral['expressions'][..., :100] * psi.unsqueeze(-1)
# new_vertice.shape

shape_deformation shape: torch.Size([10475, 3])
expression_deformation shape: torch.Size([10475, 3])
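
The broadcast-and-sum used above is a linear combination of blendshapes, so it can equivalently be written as a single `torch.einsum` contraction. The sketch below checks this on small random stand-in tensors; the sizes `V`, `S` and the names `dirs`, `coeffs` are illustrative assumptions, not part of the model file.

```python
import torch

torch.manual_seed(0)
V, S = 6, 4                       # toy sizes: 6 vertices, 4 blendshape coefficients
dirs = torch.randn(V, 3, S)       # stand-in for shapedirs[..., :S]
coeffs = torch.randn(1, S)        # stand-in for betas

# Broadcast multiply, then sum over the coefficient axis (as in the tutorial)
offsets_sum = torch.sum(dirs * coeffs, dim=-1)               # V x 3

# Same contraction expressed with einsum, keeping the batch axis
offsets_einsum = torch.einsum('bs,vks->bvk', coeffs, dirs)   # 1 x V x 3

print(torch.allclose(offsets_sum, offsets_einsum[0], atol=1e-5))
```

The einsum form keeps an explicit batch dimension, which becomes convenient once more than one set of parameters is processed at a time.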


In [5]:
new_vertice = smplx_neutral['v_template'] + shape_deformation + expression_deformation
# print(f"new_vertice shape: {new_vertice.shape}")
write_obj(vertices=new_vertice, faces=smplx_neutral['f'], file_name='./obj/smplx_neutral.obj')

Of course, the result is expected to be a neutral body, because we haven't changed the parameters yet. Let's try to change the shape parameters to see what will happen.

In [6]:
betas[0, 0] = 1.0
shape_deformation = torch.sum(smplx_neutral['shapedirs'][..., :300] * betas, dim=-1)
expression_deformation = torch.sum(smplx_neutral['shapedirs'][..., 300:] * psi, dim=-1)
new_vertice = smplx_neutral['v_template'] + shape_deformation + expression_deformation
write_obj(vertices=new_vertice, faces=smplx_neutral['f'], file_name='./obj/smplx_tall.obj')

Now we can see that the body is taller than the neutral body. We can also change the expression parameters to see what will happen.

In [7]:
psi[0, 0] = 1.0
shape_deformation = torch.sum(smplx_neutral['shapedirs'][..., :300] * betas, dim=-1)
expression_deformation = torch.sum(smplx_neutral['shapedirs'][..., 300:] * psi, dim=-1)
new_vertice = smplx_neutral['v_template'] + shape_deformation + expression_deformation
write_obj(vertices=new_vertice, faces=smplx_neutral['f'], file_name='./obj/smplx_smile.obj')

## Step 2: Pose deformation

Body pose also affects the deformation of the body. The SMPL-X model has 55 joints, each with 3 degrees of freedom, so the pose parameters `theta` form a 55x3 matrix. The pose of each joint is represented as a rotation vector (axis-angle). We can set the pose parameters to zero to get the T-pose body. The pose deformation blendshapes are stored as `posedirs` of the SMPL-X model, and each one corresponds to an element of a joint's **rotation matrix**. So to get the final pose deformation, we first need to convert the rotation vectors to rotation matrices.

For those who are not familiar with rotation vectors, you can refer to the following blog posts (in Chinese):
- [blog 1](https://zhuanlan.zhihu.com/p/451579313)
- [blog 2](https://blog.csdn.net/Crystal_YS/article/details/103622853)
- [blog 3](https://zhuanlan.zhihu.com/p/147791525)

The conversion from rotation vector to rotation matrix is calculated by the **Rodrigues' rotation formula**:

In [8]:
def batch_rodrigues(
    rot_vecs: torch.Tensor,
    epsilon: float = 1e-8,
) -> torch.Tensor:
    ''' Calculates the rotation matrices for a batch of rotation vectors
        Parameters
        ----------
        rot_vecs: torch.tensor Nx3
            array of N axis-angle vectors
        Returns
        -------
        R: torch.tensor Nx3x3
            The rotation matrices for the given axis-angle parameters
    '''

    batch_size = rot_vecs.shape[0]
    device, dtype = rot_vecs.device, rot_vecs.dtype

    # Adding epsilon keeps the norm non-zero for all-zero rotation vectors
    angle = torch.norm(rot_vecs + epsilon, dim=1, keepdim=True)
    rot_dir = rot_vecs / angle

    cos = torch.unsqueeze(torch.cos(angle), dim=1)
    sin = torch.unsqueeze(torch.sin(angle), dim=1)

    # Bx1 arrays
    rx, ry, rz = torch.split(rot_dir, 1, dim=1)
    K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)

    zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
        .view((batch_size, 3, 3))

    ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
    rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
    return rot_mat
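
For a unit axis with cross-product matrix $K$ and angle $\theta$, the formula reads $R = I + \sin\theta\,K + (1-\cos\theta)\,K^2$. As a quick sanity check, the sketch below applies it to a single rotation vector; `rodrigues_single` is a hypothetical helper written only for this illustration, and it assumes a non-zero input vector.

```python
import math
import torch

def rodrigues_single(rot_vec: torch.Tensor) -> torch.Tensor:
    """Rotation matrix for one non-zero axis-angle vector (hypothetical helper)."""
    angle = torch.linalg.norm(rot_vec)
    x, y, z = (rot_vec / angle).tolist()
    # Skew-symmetric cross-product matrix of the unit axis
    K = torch.tensor([[0.0, -z, y],
                      [z, 0.0, -x],
                      [-y, x, 0.0]], dtype=rot_vec.dtype)
    ident = torch.eye(3, dtype=rot_vec.dtype)
    return ident + torch.sin(angle) * K + (1 - torch.cos(angle)) * (K @ K)

# 90 degrees about the z axis should map the x axis onto the y axis
R = rodrigues_single(torch.tensor([0.0, 0.0, math.pi / 2]))
x_axis = torch.tensor([1.0, 0.0, 0.0])
print(R @ x_axis)

# A valid rotation matrix is orthogonal with determinant +1
print(torch.allclose(R @ R.T, torch.eye(3), atol=1e-6))
print(torch.isclose(torch.det(R), torch.tensor(1.0), atol=1e-6))
```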

Now we can get the pose deformation blendshape:

In [9]:
theta = torch.zeros([1, 55, 3], dtype=torch.float32)
batch_size = theta.shape[0]
ident = torch.eye(3, dtype=torch.float32)
rot_mats = batch_rodrigues(theta.reshape(-1, 3)).reshape(batch_size, -1, 3, 3)
pose_feature = (rot_mats[:, 1:, :, :] - ident).reshape(batch_size, -1)
print('pose_feature shape: ', pose_feature.shape)

pose_feature shape:  torch.Size([1, 486])


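The pose_feature is a 486-dimensional vector, obtained as 9*(55-1): each of the 54 non-root joints contributes the 9 elements of its rotation matrix minus the identity, so a zero pose gives an all-zero feature. Each element corresponds to one pose deformation blendshape. We can get the final pose deformation by multiplying the pose_feature with the posedirs: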

In [10]:
pose_deformation = torch.sum(smplx_neutral['posedirs'] * pose_feature, dim=-1)
print('pose_deformation shape: ', pose_deformation.shape)

pose_deformation shape:  torch.Size([10475, 3])


In [11]:
new_vertice = smplx_neutral['v_template']  + pose_deformation
write_obj(vertices=new_vertice, faces=smplx_neutral['f'], file_name='./obj/smplx_tpose.obj')

Now let's try to change the pose parameters to see what will happen. Before that, note that the index of each joint is stored in `joint2num`:

In [12]:
smplx_neutral['joint2num']

array({'L_Middle3': 30, 'R_Wrist': 21, 'R_Foot': 11, 'Jaw': 22, 'L_Eye': 23, 'Spine1': 3, 'Spine3': 9, 'Spine2': 6, 'R_Thumb1': 52, 'R_Thumb3': 54, 'R_Thumb2': 53, 'R_Elbow': 19, 'Head': 15, 'L_Collar': 13, 'R_Hip': 2, 'R_Eye': 24, 'L_Ring1': 34, 'L_Ring2': 35, 'L_Ring3': 36, 'L_Thumb3': 39, 'L_Thumb2': 38, 'L_Thumb1': 37, 'R_Ring2': 50, 'R_Ring3': 51, 'R_Ring1': 49, 'L_Index3': 27, 'L_Index2': 26, 'L_Index1': 25, 'R_Shoulder': 17, 'Neck': 12, 'L_Foot': 10, 'R_Index1': 40, 'R_Index3': 42, 'R_Index2': 41, 'L_Knee': 4, 'L_Elbow': 18, 'R_Middle3': 45, 'R_Middle2': 44, 'R_Middle1': 43, 'L_Pinky1': 31, 'L_Pinky2': 32, 'L_Pinky3': 33, 'L_Middle1': 28, 'R_Ankle': 8, 'R_Collar': 14, 'L_Middle2': 29, 'R_Pinky2': 47, 'L_Wrist': 20, 'R_Pinky3': 48, 'L_Shoulder': 16, 'L_Hip': 1, 'R_Knee': 5, 'Pelvis': 0, 'R_Pinky1': 46, 'L_Ankle': 7},
      dtype=object)

In [13]:
theta = torch.zeros([1, 55, 3], dtype=torch.float32)
theta[0, 15, 0] = torch.pi/6
batch_size = theta.shape[0]
ident = torch.eye(3, dtype=torch.float32)
rot_mats = batch_rodrigues(theta.reshape(-1, 3)).reshape(batch_size, -1, 3, 3)
pose_feature = (rot_mats[:, 1:, :, :] - ident).reshape(batch_size, -1)
pose_deformation = torch.sum(smplx_neutral['posedirs'] * pose_feature, dim=-1)
new_vertice = smplx_neutral['v_template'] + pose_deformation
write_obj(vertices=new_vertice, faces=smplx_neutral['f'], file_name='./obj/smplx_tpose_1.obj')

As you can see, joint 15 is the head, and we rotate it by 30 degrees (`torch.pi/6`) about the x axis. The body is still in a T-pose, because only the pose deformation blendshapes have been applied, but the mesh around the neck is slightly different from the original one.

<img src='images/2_pose_deformation.png' width='80%'>
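
Rotating a single joint changes only that joint's 9 entries of the pose feature, so only the blendshapes tied to those entries contribute to the offset. The sketch below illustrates this without the model file: it writes a 30-degree rotation about x directly as a matrix (an assumption made to keep the example self-contained) and lists the non-zero feature entries.

```python
import math
import torch

num_joints, head = 55, 15
c, s = math.cos(math.pi / 6), math.sin(math.pi / 6)

# Identity rotation for every joint, then a 30-degree rotation about x for the head
rot_mats = torch.eye(3).repeat(1, num_joints, 1, 1)              # 1 x 55 x 3 x 3
rot_mats[0, head] = torch.tensor([[1.0, 0.0, 0.0],
                                  [0.0, c, -s],
                                  [0.0, s, c]])

# Drop the root joint and subtract the identity, as in the tutorial
pose_feature = (rot_mats[:, 1:] - torch.eye(3)).reshape(1, -1)   # 1 x 486

nonzero = torch.nonzero(pose_feature[0]).flatten()
print(pose_feature.shape)
print(nonzero.tolist())
```

All non-zero indices fall inside the slice `[(head - 1) * 9, head * 9)`, i.e. the block that belongs to the head joint once the root has been dropped.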

## Step 3: Mesh transformation
So far we only have a T-pose body, and we need to transform it to the target pose. This is done by the **linear blend skinning** (LBS) algorithm, a widely used method for deforming meshes. The basic idea is to move each vertex with a weighted combination of the transformations of the joints.


In [14]:
new_vertice = new_vertice.unsqueeze(0)              # Nx3 -> BxNx3 (B = 1)

In [15]:
joints = torch.einsum('bik,ji->bjk', [new_vertice, smplx_neutral['J_regressor']])
joints = torch.unsqueeze(joints, dim=-1)            # BxNx3 -> BxNx3x1
print('joints shape: ', joints.shape)

joints shape:  torch.Size([1, 55, 3, 1])
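
Each row of `J_regressor` holds the weights used to place one joint as a weighted combination of mesh vertices. A minimal sketch on a toy mesh shows what the einsum computes; `toy_regressor` and the square mesh are made-up values for illustration only.

```python
import torch

# Toy mesh: 4 vertices of a unit square in the z = 0 plane, batch size 1
vertices = torch.tensor([[[0.0, 0.0, 0.0],
                          [1.0, 0.0, 0.0],
                          [1.0, 1.0, 0.0],
                          [0.0, 1.0, 0.0]]])                 # 1 x 4 x 3

# Hypothetical regressor: joint 0 averages all vertices, joint 1 averages the first two
toy_regressor = torch.tensor([[0.25, 0.25, 0.25, 0.25],
                              [0.50, 0.50, 0.00, 0.00]])     # 2 x 4

# Same contraction as in the tutorial: sum over the vertex index i
toy_joints = torch.einsum('bik,ji->bjk', vertices, toy_regressor)   # 1 x 2 x 3
print(toy_joints)
```

Joint 0 lands at the centre of the square and joint 1 at the midpoint of the first edge.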


We can get the transformation matrix of each joint by the rotation matrix and translation vector:

In [16]:
def transform_mat(R: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    ''' Creates a batch of transformation matrices
        Args:
            - R: Bx3x3 array of a batch of rotation matrices
            - t: Bx3x1 array of a batch of translation vectors
        Returns:
            - T: Bx4x4 Transformation matrix
    '''
    # No padding left or right, only add an extra row
    return torch.cat([torch.nn.functional.pad(R, [0, 0, 0, 1]),
                      torch.nn.functional.pad(t, [0, 0, 0, 1], value=1)], dim=2)
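
To see what the padding produces, the sketch below builds one 4x4 matrix from an identity rotation and a translation, then applies it to the origin. `make_transform` is a stand-in name that repeats the same padding logic so the example runs on its own.

```python
import torch
import torch.nn.functional as F

def make_transform(R, t):
    # Same padding trick as transform_mat: append a zero row to R and a 1 to t
    return torch.cat([F.pad(R, [0, 0, 0, 1]),
                      F.pad(t, [0, 0, 0, 1], value=1)], dim=2)

R = torch.eye(3).unsqueeze(0)                          # 1 x 3 x 3, no rotation
t = torch.tensor([[[1.0], [2.0], [3.0]]])              # 1 x 3 x 1 translation
T = make_transform(R, t)                               # 1 x 4 x 4

point = torch.tensor([[[0.0], [0.0], [0.0], [1.0]]])   # origin in homogeneous coordinates
print(T[0])
print((T @ point)[0, :3, 0])
```

The bottom row of `T` is `[0, 0, 0, 1]`, and the origin is moved to the translation `t`.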

In [17]:
rel_joints = joints.clone()                         # BxNx3x1
rel_joints[:, 1:] -= joints[:, smplx_neutral['kintree_table'][0][1:]]         # vector pointing from parent joint to child joint, translation

transforms_mat = transform_mat(
    rot_mats.reshape(-1, 3, 3),                     # BxNx3x3 -> (BxN)x3x3
    rel_joints.reshape(-1, 3, 1)                    # BxNx3x1 -> (BxN)x3x1
).reshape(-1, joints.shape[1], 4, 4)                # (BxN)x4x4 -> BxNx4x4

Now we have the local transformation matrix of each joint, expressed relative to its parent. The global transformation matrix of a joint is obtained by multiplying its parent's global transformation matrix by its own local one. The root joint has no parent, so its global transformation matrix is simply its local one. Walking down the kinematic tree, the global transformation matrices of the other joints can be calculated as:

In [18]:
transform_chain = [transforms_mat[:, 0]]            # Bx4x4, the root's global transformation
for i in range(1, smplx_neutral['kintree_table'][0].shape[0]):
    # Global transform of joint i = global transform of its parent
    # composed with the local transform of joint i
    parent = int(smplx_neutral['kintree_table'][0][i])
    curr_res = torch.matmul(transform_chain[parent], transforms_mat[:, i])
    transform_chain.append(curr_res)
transforms = torch.stack(transform_chain, dim=1)    # BxNx4x4
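
The effect of chaining is easiest to see on a tiny skeleton. The following sketch uses a made-up three-joint chain (the `parents` list and offsets are illustrative assumptions) and rotates the middle joint, so that its child follows the rotation.

```python
import torch

def translation(x, y, z):
    # 4x4 homogeneous transform with identity rotation (illustrative helper)
    T = torch.eye(4)
    T[:3, 3] = torch.tensor([x, y, z])
    return T

# Toy chain of three joints: 0 is the root, 1 hangs off 0, 2 hangs off 1
parents = [-1, 0, 1]
local = [translation(0.0, 1.0, 0.0),      # root placed at y = 1
         translation(0.0, 0.5, 0.0),      # offset from joint 0 to joint 1
         translation(0.3, 0.0, 0.0)]      # offset from joint 1 to joint 2

# Rotate joint 1 by 90 degrees about z; its child must follow
local[1][:3, :3] = torch.tensor([[0.0, -1.0, 0.0],
                                 [1.0, 0.0, 0.0],
                                 [0.0, 0.0, 1.0]])

chain = [local[0]]
for i in range(1, len(parents)):
    chain.append(chain[parents[i]] @ local[i])   # parent's global times own local

posed = torch.stack(chain)[:, :3, 3]             # posed joint locations, 3 x 3
print(posed)
```

Without the rotation, joint 2 would sit at (0.3, 1.5, 0); with it, the offset along x is turned into an offset along y.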

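The last column of each global transformation contains the posed joint location. To make the transformations act on the rest-pose mesh, we also need to remove the rest-pose joint locations: the joints are padded with a 0 to get 4-dimensional vectors, so that multiplying them by the transformations applies only the rotation, and the result is subtracted from the translation column. The relative transformation matrix of each joint can be calculated as: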

In [19]:
posed_joints = transforms[:, :, :3, 3]              # BxNx3

joints_homogen = torch.nn.functional.pad(joints, [0, 0, 0, 1])        # BxNx3x1 -> BxNx4x1, padded with 0

rel_transforms = transforms - torch.nn.functional.pad(
    torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])     # BxNx4x4
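
Writing a global transform as $[R \mid t]$ and the rest-pose joint as $j$, the subtraction above yields $[R \mid t - Rj]$, which sends the rest-pose joint location to its posed location $t$. The sketch below checks this for a single joint with made-up values.

```python
import torch
import torch.nn.functional as F

# One joint: 90-degree rotation about z, posed location t, rest location j
G = torch.eye(4)
G[:3, :3] = torch.tensor([[0.0, -1.0, 0.0],
                          [1.0, 0.0, 0.0],
                          [0.0, 0.0, 1.0]])
G[:3, 3] = torch.tensor([0.0, 2.0, 0.0])            # posed joint location t
j = torch.tensor([[1.0], [1.0], [0.0]])             # rest joint location, 3 x 1

j_homo = F.pad(j, [0, 0, 0, 1])                     # 4 x 1, last entry is 0
# G @ j_homo = [R j; 0]; left-pad with three zero columns to get a 4 x 4 matrix
rel = G - F.pad(G @ j_homo, [3, 0, 0, 0])           # [R | t - R j]

rest_point = torch.tensor([1.0, 1.0, 0.0, 1.0])     # the joint itself, homogeneous
print(rel @ rest_point)
```

The printed point has the posed joint location `t` in its first three entries, which is the property that lets the relative transforms be applied directly to rest-pose vertices.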


Finally, we can get the posed mesh by applying the relative transformation matrices to the T-pose mesh `new_vertice`. For each vertex, we compute the weighted sum of the relative transformation matrices of the joints that affect it. The weight of each joint for each vertex is stored in the `weights` tensor. The final posed mesh can be calculated as:

In [20]:
W = smplx_neutral['weights'].unsqueeze(dim=0).expand([batch_size, -1, -1])
num_joints = smplx_neutral['J_regressor'].shape[0]
T = torch.matmul(W, rel_transforms.view(batch_size, num_joints, 16)) \
        .view(batch_size, -1, 4, 4)
homogen_coord = torch.ones([batch_size, new_vertice.shape[1], 1])
v_posed_homo = torch.cat([new_vertice, homogen_coord], dim=2)
v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))

verts = v_homo[:, :, :3, 0]
write_obj(vertices=verts.squeeze(0), faces=smplx_neutral['f'], file_name='obj/smplx_look_down.obj')
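
The blending step can be checked by hand on a tiny example. The sketch below uses two joints and three vertices with made-up skinning weights; all names and values are illustrative assumptions, not taken from the model.

```python
import torch

def translation(x, y, z):
    # 4x4 homogeneous transform with identity rotation (illustrative helper)
    T = torch.eye(4)
    T[:3, 3] = torch.tensor([x, y, z])
    return T

# Two joint transforms: joint 0 stays put, joint 1 shifts by +2 along x
joint_T = torch.stack([translation(0.0, 0.0, 0.0),
                       translation(2.0, 0.0, 0.0)])        # 2 x 4 x 4

# Three vertices with hypothetical skinning weights (each row sums to 1)
toy_weights = torch.tensor([[1.0, 0.0],
                            [0.5, 0.5],
                            [0.0, 1.0]])                    # 3 x 2
toy_verts = torch.tensor([[0.0, 0.0, 0.0],
                          [1.0, 0.0, 0.0],
                          [2.0, 0.0, 0.0]])                 # 3 x 3

# Blend the 4x4 matrices per vertex, then apply them in homogeneous coordinates
blended = (toy_weights @ joint_T.view(2, 16)).view(3, 4, 4)
verts_homo = torch.cat([toy_verts, torch.ones(3, 1)], dim=1).unsqueeze(-1)   # 3 x 4 x 1
skinned = (blended @ verts_homo)[:, :3, 0]
print(skinned)
```

The first vertex follows joint 0 and does not move, the last follows joint 1 and shifts by 2, and the middle vertex, weighted half and half, shifts by 1.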

By now, we can manipulate the SMPL-X body model on our own. Check out the final pose:
<img src='./images/2_pose_look_down.png' width='400px'>