Skip to content

Commit

Permalink
Allow passing SMPL initialization parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
UuuNyaa committed Nov 30, 2022
1 parent f69c8dd commit f299527
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 7 deletions.
8 changes: 6 additions & 2 deletions model/mdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class MDM(nn.Module):
def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot,
latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512,
arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs):
arch='trans_enc', emb_trans_dec=False, clip_version=None, smpl_model_path=None, joint_regressor_train_extra_path=None, **kargs):
super().__init__()

self.legacy = legacy
Expand Down Expand Up @@ -93,7 +93,11 @@ def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_re
self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints,
self.nfeats)

self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset)
self.rot2xyz = Rotation2xyz(
device='cpu', dataset=self.dataset,
smpl_model_path=smpl_model_path,
joint_regressor_train_extra_path=joint_regressor_train_extra_path
)

def parameters_wo_clip(self):
return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]
Expand Down
4 changes: 2 additions & 2 deletions model/rotation2xyz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@


class Rotation2xyz:
def __init__(self, device, dataset='amass'):
def __init__(self, device, dataset='amass', smpl_model_path=None, joint_regressor_train_extra_path=None):
self.device = device
self.dataset = dataset
self.smpl_model = SMPL().eval().to(device)
self.smpl_model = SMPL(model_path=smpl_model_path, joint_regressor_train_extra_path=joint_regressor_train_extra_path).eval().to(device)

def __call__(self, x, mask, pose_rep, translation, glob,
jointstype, vertstrans, betas=None, beta=0,
Expand Down
6 changes: 3 additions & 3 deletions model/smpl.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,14 @@
class SMPL(_SMPLLayer):
""" Extension of the official SMPL implementation to support more joints """

def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs):
kwargs["model_path"] = model_path
def __init__(self, model_path=None, joint_regressor_train_extra_path=None, **kwargs):
kwargs["model_path"] = model_path or SMPL_MODEL_PATH

# remove the verbosity for the 10-shapes beta parameters
with contextlib.redirect_stdout(None):
super(SMPL, self).__init__(**kwargs)

J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA)
J_regressor_extra = np.load(joint_regressor_train_extra_path or JOINT_REGRESSOR_TRAIN_EXTRA)
self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES])
a2m_indexes = vibe_indexes[action2motion_joints]
Expand Down

0 comments on commit f299527

Please sign in to comment.