In [1]:
import torch
from smotdm.renderer.matplotlib import SceneRenderer
from smotdm.data.motion import MotionLoader
from smotdm.rifke import rifke_to_joints
from smplx import SMPLX
import numpy as np


# Path to the Inter-X dataset file; in the visible cells it is only referenced
# by the commented-out MotionLoader snippet.
INTERX_DATASET_FILE = "deps/interx/dataset_tiny.h5"

# Every tensor created in this notebook is moved to this device.
device = torch.device("cpu")

# Two five-entry colour tuples, one per rendered skeleton: all red for the
# first, all black for the second.
renderer = SceneRenderer(colors1=("red",) * 5, colors2=("black",) * 5)

In [2]:
from smotdm.rifke import joints_to_rifke, ungroup


@torch.no_grad()
def get_joints(smplx_params):
    """Run the SMPL-X model and return its joints with the last two axes swapped.

    Args:
        smplx_params: Mapping of keyword arguments forwarded verbatim to the
            module-level ``smplx_model``.

    Returns:
        The first ``smplx_model.NUM_JOINTS`` joints of the model output, with
        each coordinate triple reordered from ``(a, b, c)`` to ``(a, c, b)``.
        The result keeps the dtype and device of the model output.
    """
    output = smplx_model(**smplx_params)

    # Trim first so the axis swap is not computed for joints that are discarded.
    # NOTE(review): confirm that `NUM_JOINTS` is the intended cut-off (i.e. that
    # no wanted joint sits at index NUM_JOINTS or beyond).
    joints = output.joints[:, : smplx_model.NUM_JOINTS, :]

    # Permutation matrix exchanging the second and third coordinates. It is
    # created with `new_tensor` so it inherits the dtype and device of `joints`
    # instead of relying on the notebook-global `device` and the default dtype,
    # which would make the matmul fail for e.g. float64 or non-CPU outputs.
    axis_swap = joints.new_tensor(
        [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
    )
    return torch.matmul(joints, axis_swap)


# Load the stored SMPL-X parameter archives for the two people (P1 and P2)
# of one motion clip.
P1_smplx, P2_smplx = (
    np.load("deps/interx/motions/G001T000A001R005/P1.npz"),
    np.load("deps/interx/motions/G001T000A001R005/P2.npz"),
)

# Build the body model with a batch size equal to the length of P1's
# "pose_body" array along its first axis, then move it to the working device.
# NOTE(review): P2 is evaluated with this same model, so this presumably relies
# on both archives having the same first-axis length — confirm.
smplx_model = SMPLX(
    model_path="deps/smplx/SMPLX_NEUTRAL.npz",
    batch_size=len(P1_smplx["pose_body"]),
    num_betas=10,
    use_face_contour=True,
    use_pca=False,
)
smplx_model = smplx_model.to(device)

# difference_in_transl = torch.tensor(P2_smplx["trans"]).to(device) - torch.tensor(
#     P1_smplx["trans"][0]
# ).to(device)

# Evaluate the body model once per person. Both people use the same parameter
# layout; only the translation differs: P1's "trans" is passed through
# unchanged, while P2's has P1's "trans" subtracted element-wise.
# NOTE(review): the commented-out alternative above subtracts only the first
# entry of P1's "trans" — confirm which convention is intended.
j1, j2 = (
    get_joints(
        {
            "body_pose": torch.tensor(person["pose_body"]).to(device),
            "left_hand_pose": torch.tensor(person["pose_lhand"]).to(device),
            "right_hand_pose": torch.tensor(person["pose_rhand"]).to(device),
            "transl": torch.tensor(transl).to(device),
            "global_orient": torch.tensor(person["root_orient"]).to(device),
        }
    )
    for person, transl in (
        (P1_smplx, P1_smplx["trans"]),
        (P2_smplx, P2_smplx["trans"] - P1_smplx["trans"]),
    )
)

# Convert both joint sequences to RIFKE features. The two extra return values
# are kept for j1 (as `translation` and `angles`) and discarded for j2.
reactor_feats, translation, angles = joints_to_rifke(j1)
actor_feats, _, _ = joints_to_rifke(j2)

In [None]:
# loader = MotionLoader(INTERX_DATASET_FILE, fps=20.0)
# sample1 = loader("G001T000A001R005", 0, 0.0, 27.0)
# sample2 = loader("G001T000A001R005", 1, 0.0, 27.0)
# motion1 = sample1["x"]
# motion2 = sample2["x"]

In [None]:
# Render the joints reconstructed from the RIFKE features of both people.
renderer.render(
    *(
        rifke_to_joints(feats).detach().cpu().numpy()
        for feats in (reactor_feats, actor_feats)
    ),
    output="test_recons.mp4",
)

# Render the joints that came straight out of the body model, so the two
# videos can be compared side by side.
renderer.render(
    *(joints.detach().cpu().numpy() for joints in (j1, j2)),
    output="test.mp4",
)