Skip to content

Commit

Permalink
Add hmr head for human mesh estimation task. (open-mmlab#162)
Browse files Browse the repository at this point in the history
* Add the hmr head and discriminator for SMPL parameters. Add test codes and test data.

* Modify the codes according to review.
  • Loading branch information
zengwang430521 committed Sep 30, 2020
1 parent 3c519cd commit cbc3f99
Show file tree
Hide file tree
Showing 7 changed files with 526 additions and 0 deletions.
1 change: 1 addition & 0 deletions mmpose/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .detectors import * # noqa
from .keypoint_heads import * # noqa
from .losses import * # noqa
from .mesh_heads import * # noqa
from .registry import BACKBONES, HEADS, LOSSES, POSENETS

__all__ = [
Expand Down
3 changes: 3 additions & 0 deletions mmpose/models/mesh_heads/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .hmr_head import MeshHMRHead

__all__ = ['MeshHMRHead']
286 changes: 286 additions & 0 deletions mmpose/models/mesh_heads/discriminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,286 @@
# ------------------------------------------------------------------------------
# Adapted from https://github.com/akanazawa/hmr
# Original licence: Copyright (c) 2018 akanazawa, under the MIT License.
# ------------------------------------------------------------------------------

from abc import abstractmethod

import torch
import torch.nn as nn

from .geometric_layers import batch_rodrigues


class BaseDiscriminator(nn.Module):
"""Base linear module for SMPL parameter discriminator.
Args:
fc_layers (Tuple): Tuple of neuron count,
such as (9, 32, 32, 1)
use_dropout (Tuple): Tuple of bool define use dropout or not
for each layer, such as (True, True, False)
drop_prob (Tuple): Tuple of float defined the drop prob,
such as (0.5, 0.5, 0)
use_activation(Tuple): Tuple of bool define use active function
or not, such as (True, True, False)
"""

def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
super().__init__()
self.fc_layers = fc_layers
self.use_dropout = use_dropout
self.drop_prob = drop_prob
self.use_activation = use_activation
self._check()
self.create_layers()

def _check(self):
"""Check input to avoid ValueError."""
if not isinstance(self.fc_layers, tuple):
raise TypeError(f'fc_layers require tuple, '
f'get {type(self.fc_layers)}')

if not isinstance(self.use_dropout, tuple):
raise TypeError(f'use_dropout require tuple, '
f'get {type(self.use_dropout)}')

if not isinstance(self.drop_prob, tuple):
raise TypeError(f'drop_prob require tuple, '
f'get {type(self.drop_prob)}')

if not isinstance(self.use_activation, tuple):
raise TypeError(f'use_activation require tuple, '
f'get {type(self.use_activation)}')

l_fc_layer = len(self.fc_layers)
l_use_drop = len(self.use_dropout)
l_drop_prob = len(self.drop_prob)
l_use_activation = len(self.use_activation)

pass_check = (
l_fc_layer >= 2 and l_use_drop < l_fc_layer
and l_drop_prob < l_fc_layer and l_use_activation < l_fc_layer
and l_drop_prob == l_use_drop)

if not pass_check:
msg = 'Wrong BaseDiscriminator parameters!'
raise ValueError(msg)

def create_layers(self):
"""Create layers."""
l_fc_layer = len(self.fc_layers)
l_use_drop = len(self.use_dropout)
l_use_activation = len(self.use_activation)

self.fc_blocks = nn.Sequential()

for i in range(l_fc_layer - 1):
self.fc_blocks.add_module(
name=f'regressor_fc_{i}',
module=nn.Linear(
in_features=self.fc_layers[i],
out_features=self.fc_layers[i + 1]))

if i < l_use_activation and self.use_activation[i]:
self.fc_blocks.add_module(
name=f'regressor_af_{i}', module=nn.ReLU())

if i < l_use_drop and self.use_dropout[i]:
self.fc_blocks.add_module(
name=f'regressor_fc_dropout_{i}',
module=nn.Dropout(p=self.drop_prob[i]))

@abstractmethod
def forward(self, inputs):
"""Forward function."""
msg = 'the base class [BaseDiscriminator] is not callable!'
raise NotImplementedError(msg)


class ShapeDiscriminator(BaseDiscriminator):
"""Discriminator for SMPL shape parameters, the inputs is (batch_size x 10)
Args:
fc_layers (Tuple): Tuple of neuron count,
such as (10, 5, 1)
use_dropout (Tuple): Tuple of bool define use dropout or
not for each layer, such as (True, True, False)
drop_prob (Tuple): Tuple of float defined the drop prob,
such as (0.5, 0)
use_activation(Tuple): Tuple of bool define use active
function or not, such as (True, False)
"""

def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
if fc_layers[-1] != 1:
msg = f'the neuron count of the last layer ' \
f'must be 1, but got {fc_layers[-1]}'
raise ValueError(msg)

super().__init__(fc_layers, use_dropout, drop_prob, use_activation)

def forward(self, inputs):
"""Forward function."""
return self.fc_blocks(inputs)


class PoseDiscriminator(nn.Module):
"""Discriminator for SMPL pose parameters of each joint. It is composed of
discriminators for each joints. The inputs is (batch_size x joint_count x
9)
Args:
channels (Tuple): Tuple of channel number,
such as (9, 32, 32, 1)
joint_count (int): Joint number, such as 23
"""

def __init__(self, channels, joint_count):
super().__init__()
if channels[-1] != 1:
msg = f'the neuron count of the last layer ' \
f'must be 1, but got {channels[-1]}'
raise ValueError(msg)
self.joint_count = joint_count

self.conv_blocks = nn.Sequential()
len_channels = len(channels)
for idx in range(len_channels - 2):
self.conv_blocks.add_module(
name=f'conv_{idx}',
module=nn.Conv2d(
in_channels=channels[idx],
out_channels=channels[idx + 1],
kernel_size=1,
stride=1))

self.fc_layer = nn.ModuleList()
for idx in range(joint_count):
self.fc_layer.append(
nn.Linear(
in_features=channels[len_channels - 2], out_features=1))

def forward(self, inputs):
"""Forward function.
The input is (batch_size x joint_count x 9)
"""
# shape: batch_size x 9 x 1 x joint_count
inputs = inputs.transpose(1, 2).unsqueeze(2).contiguous()
# shape: batch_size x c x 1 x joint_count
internal_outputs = self.conv_blocks(inputs)
outputs = []
for idx in range(self.joint_count):
outputs.append(self.fc_layer[idx](internal_outputs[:, :, 0, idx]))

return torch.cat(outputs, 1), internal_outputs


class FullPoseDiscriminator(BaseDiscriminator):
"""Discriminator for SMPL pose parameters of all joints.
Args:
fc_layers (Tuple): Tuple of neuron count,
such as (736, 1024, 1024, 1)
use_dropout (Tuple): Tuple of bool define use dropout or not
for each layer, such as (True, True, False)
drop_prob (Tuple): Tuple of float defined the drop prob,
such as (0.5, 0.5, 0)
use_activation(Tuple): Tuple of bool define use active
function or not, such as (True, True, False)
"""

def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
if fc_layers[-1] != 1:
msg = f'the neuron count of the last layer must be 1,' \
f' but got {fc_layers[-1]}'
raise ValueError(msg)

super().__init__(fc_layers, use_dropout, drop_prob, use_activation)

def forward(self, inputs):
"""Forward function."""
return self.fc_blocks(inputs)


class SMPLDiscriminator(nn.Module):
"""Discriminator for SMPL pose and shape parameters. It is composed of a
discriminator for SMPL shape parameters, a discriminator for SMPL pose
parameters of all joints and a discriminator for SMPL pose parameters of
each joint.
Args:
beta_channel (tuple of int): Tuple of neuron count of the
discriminator of shape parameters. Defaults to (10, 5, 1)
per_joint_channel (tuple of int): Tuple of neuron count of the
discriminator of each joint. Defaults to (9, 32, 32, 1)
full_pose_channel (tuple of int): Tuple of neuron count of the
discriminator of full pose. Defaults to (23*32, 1024, 1024, 1)
"""

def __init__(self,
beta_channel=(10, 5, 1),
per_joint_channel=(9, 32, 32, 1),
full_pose_channel=(23 * 32, 1024, 1024, 1)):
super().__init__()
self.joint_count = 23
# The count of SMPL shape parameter is 10.
assert beta_channel[0] == 10
# Use 3 x 3 rotation matrix as the pose parameters
# of each joint, so the input channel is 9.
assert per_joint_channel[0] == 9
assert self.joint_count * per_joint_channel[-2] \
== full_pose_channel[0]

self.beta_channel = beta_channel
self.per_joint_channel = per_joint_channel
self.full_pose_channel = full_pose_channel
self._create_sub_modules()

def _create_sub_modules(self):
"""Create sub discriminators."""

# create theta discriminator for each joint
self.pose_discriminator = PoseDiscriminator(self.per_joint_channel,
self.joint_count)

# create full pose discriminator for total joints
fc_layers = self.full_pose_channel
use_dropout = tuple([False] * (len(fc_layers) - 1))
drop_prob = tuple([0.5] * (len(fc_layers) - 1))
use_activation = tuple([True] * (len(fc_layers) - 2) + [False])

self.full_pose_discriminator = FullPoseDiscriminator(
fc_layers, use_dropout, drop_prob, use_activation)

# create shape discriminator for betas
fc_layers = self.beta_channel
use_dropout = tuple([False] * (len(fc_layers) - 1))
drop_prob = tuple([0.5] * (len(fc_layers) - 1))
use_activation = tuple([True] * (len(fc_layers) - 2) + [False])
self.shape_discriminator = ShapeDiscriminator(fc_layers, use_dropout,
drop_prob,
use_activation)

def forward(self, thetas):
"""Forward function."""
cams, poses, shapes = thetas

batch_size = poses.shape[0]
shape_disc_value = self.shape_discriminator(shapes)

# The first rotation matrix is global rotation
# and is NOT used in discriminator.
if poses.dim() == 2:
rotate_matrixs = \
batch_rodrigues(poses.contiguous().view(-1, 3)
).view(batch_size, 24, 9)[:, 1:, :]
else:
rotate_matrixs = poses.contiguous().view(batch_size, 24,
9)[:, 1:, :].contiguous()
pose_disc_value, pose_inter_disc_value \
= self.pose_discriminator(rotate_matrixs)
full_pose_disc_value = self.full_pose_discriminator(
pose_inter_disc_value.contiguous().view(batch_size, -1))
return torch.cat(
(pose_disc_value, full_pose_disc_value, shape_disc_value), 1)
67 changes: 67 additions & 0 deletions mmpose/models/mesh_heads/geometric_layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import torch
from torch.nn import functional as F


def rot6d_to_rotmat(x):
"""Convert 6D rotation representation to 3x3 rotation matrix.
Based on Zhou et al., "On the Continuity of Rotation
Representations in Neural Networks", CVPR 2019
Input:
(B,6) Batch of 6-D rotation representations
Output:
(B,3,3) Batch of corresponding rotation matrices
"""
x = x.view(-1, 3, 2)
a1 = x[:, :, 0]
a2 = x[:, :, 1]
b1 = F.normalize(a1)
b2 = F.normalize(a2 - torch.einsum('bi,bi->b', b1, a2).unsqueeze(-1) * b1)
b3 = torch.cross(b1, b2)
return torch.stack((b1, b2, b3), dim=-1)


def batch_rodrigues(theta):
"""Convert axis-angle representation to rotation matrix.
Args:
theta: size = [B, 3]
Returns:
Rotation matrix corresponding to the quaternion
-- size = [B, 3, 3]
"""
l2norm = torch.norm(theta + 1e-8, p=2, dim=1)
angle = torch.unsqueeze(l2norm, -1)
normalized = torch.div(theta, angle)
angle = angle * 0.5
v_cos = torch.cos(angle)
v_sin = torch.sin(angle)
quat = torch.cat([v_cos, v_sin * normalized], dim=1)
return quat_to_rotmat(quat)


def quat_to_rotmat(quat):
"""Convert quaternion coefficients to rotation matrix.
Args:
quat: size = [B, 4] 4 <===>(w, x, y, z)
Returns:
Rotation matrix corresponding to the quaternion
-- size = [B, 3, 3]
"""
norm_quat = quat
norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True)
w, x, y, z = norm_quat[:, 0], norm_quat[:, 1],\
norm_quat[:, 2], norm_quat[:, 3]

B = quat.size(0)

w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2)
wx, wy, wz = w * x, w * y, w * z
xy, xz, yz = x * y, x * z, y * z

rotMat = torch.stack([
w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy,
w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz,
w2 - x2 - y2 + z2
],
dim=1).view(B, 3, 3)
return rotMat

0 comments on commit cbc3f99

Please sign in to comment.