In [39]:
import json, numpy as np, pathlib, torch
from tqdm import tqdm
# Load the BABEL training annotations and locate the AMASS motion file of one example sequence.
# A context manager closes the annotation file after parsing; JSON is UTF-8 by spec, so the
# encoding is pinned instead of relying on the platform locale.
with open('babel_v1.0_release/train.json', encoding='utf-8') as ann_file:
    babel_ann = json.load(ann_file)
amass_root = pathlib.Path('amass')
seq_info  = babel_ann['9864']          # example sequence id
npz_file  = amass_root / seq_info['feat_p']


In [40]:
len(babel_ann),  # number of top-level annotation entries; displayed as (6615,)

(6615,)

In [41]:
# Load one AMASS motion file.
# NOTE(review): allow_pickle=True permits arbitrary code execution when loading a
# malicious .npz; only use it on trusted AMASS data.
bdata  = np.load(npz_file, allow_pickle=True)
poses  = bdata['poses']       # (T, 156) axis-angle, 3 values per joint
trans  = bdata['trans']       # (T, 3)   root translation
betas  = bdata['betas'][None] # (1, num_betas) body-shape coefficients (the model below uses 16)
if 'gender' in bdata:
    gender = bdata['gender'].item()
    # Depending on how the file was written, the 0-d gender array may yield bytes
    # (e.g. b'female'); decode it so it formats into the model path as a plain string.
    if isinstance(gender, bytes):
        gender = gender.decode('utf-8')
else:
    gender = 'neutral'
T, D   = poses.shape
J      = D // 3               # number of joints (52 for SMPL-H, 24 for SMPL)


In [42]:
(poses.shape,)  # one-element tuple holding (T, D)

((3157, 156),)

In [43]:
from pytorch3d.transforms import axis_angle_to_matrix
# Flatten to one axis-angle triple per (frame, joint), convert with PyTorch3D, then restore (T, J, 3, 3).
pose_aa = torch.from_numpy(poses).float().reshape(-1, 3)
rot_mats = axis_angle_to_matrix(pose_aa).reshape((T, J, 3, 3))

In [44]:
from human_body_prior.body_model.body_model import BodyModel

# Build the SMPL-H body model matching the sequence's gender; run it on the GPU when one is available.
model_path = './amass/smplh/{}/model.npz'.format(gender.lower())
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
body_model = BodyModel(
    bm_fname=model_path,
    num_betas=16,        # AMASS usually uses 10 or 16 betas
    num_dmpls=None,      # for AMASS data without DMPLs
    model_type='smplh',  # state the model type explicitly
).to(device)

In [49]:
# Pick one frame and run a single-frame forward pass as a demonstration.
frame_idx = 0  # any frame index of interest
num_frames = bdata['poses'].shape[0]
if not 0 <= frame_idx < num_frames:
    # Raise instead of calling exit(): exit() shuts down the notebook kernel.
    # Negative indices are rejected too, because the slice below would yield an empty batch.
    raise IndexError(f"错误: frame_idx ({frame_idx}) 超出范围 (0-{num_frames-1})。")

# Slicing [frame_idx:frame_idx+1] keeps a leading batch dimension of size 1.
all_poses_frame = torch.tensor(bdata['poses'][frame_idx:frame_idx+1], dtype=torch.float32, device=device)  # (1, 156)
root_orient = all_poses_frame[:, :3]       # global (root) orientation, axis-angle
pose_body = all_poses_frame[:, 3:66]       # body joints: 21 joints * 3 = 63 params (SMPL-H)
pose_hand = all_poses_frame[:, 66:66+90]   # both hands: (15 + 15) joints * 3 = 90 params (SMPL-H)
# SMPL-X data would additionally carry pose_jaw, pose_eye, etc.

# First 16 shape coefficients, matching BodyModel(num_betas=16).
betas = torch.tensor(bdata['betas'][:16], dtype=torch.float32, device=device).unsqueeze(0)  # (1, 16)
trans = torch.tensor(bdata['trans'][frame_idx:frame_idx+1], dtype=torch.float32, device=device)  # (1, 3)

# --- Forward kinematics through the body model ---
# pose_hand is passed so the hand joints follow the recorded hand pose; without it the
# model presumably falls back to its default hand pose, and the hand joints would disagree
# with the full-sequence forward pass that does pass pose_hand.
with torch.no_grad():
    body = body_model(
        root_orient=root_orient,
        pose_body=pose_body,
        pose_hand=pose_hand,
        betas=betas,
        trans=trans,
    )

In [50]:
# --- 4. Extract the 3D joints ---
# Per the notes on the body model output:
#   body.v   : model vertices      (1, num_vertices, 3)
#   body.Jtr : 3D joint positions  (1, num_joints, 3)
#   body.f   : model faces
joints_3d = body.Jtr.detach().cpu().numpy()
joints_3d = joints_3d.squeeze()  # drop size-1 axes -> (num_joints, 3)

print(joints_3d.shape)  # shape of the joint coordinate array
print(f"\n计算得到的第 {frame_idx} 帧的三维关节坐标 (部分示例):")
print(joints_3d[:5])  # coordinates of the first 5 joints



(52, 3)

计算得到的第 0 帧的三维关节坐标 (部分示例):
[[-0.15769717  0.04982522  0.7323806 ]
 [-0.14868572  0.12678236  0.64955086]
 [-0.19523332 -0.00525749  0.64258844]
 [-0.16806692  0.04579143  0.83666056]
 [ 0.08590449  0.06953263  0.41320103]]


In [47]:
from human_body_prior.tools.rotation_tools import aa2matrot
# Convert every frame's axis-angle pose to rotation matrices, this time with human_body_prior's aa2matrot.
pose_aa = torch.tensor(bdata['poses'], dtype=torch.float32).reshape(-1, 3)   # (T*J, 3)
rot_mats = aa2matrot(pose_aa).reshape((T, J, 3, 3))                          # (T, J, 3, 3)
(rot_mats.shape, rot_mats.dtype)

(torch.Size([3157, 52, 3, 3]), torch.float32)

In [None]:
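# Convert every frame's axis-angle poses to rotation matrices, then to the 6D rotation
# representation, then back to axis-angle; feed the result to body_model and measure the
# 3D joint error before vs. after the round trip.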


In [53]:
# --- PyTorch3D rotation utilities used by the round-trip experiment ---
try:
    from pytorch3d.transforms import (
        axis_angle_to_matrix,
        matrix_to_rotation_6d,
        rotation_6d_to_matrix,
        matrix_to_axis_angle
    )
    print("Using PyTorch3D for rotation conversions.")
except ImportError:
    print("PyTorch3D not found. Please install it to proceed with these conversions.")
    print("Installation: https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md")
    # Re-raise: the conversions below cannot run without PyTorch3D.
    raise


# --- Prepare AMASS data for BodyModel (all frames) ---
# Reads `bdata`, `device` and `body_model` from earlier cells.
poses_np = bdata['poses']   # (T, 156) axis-angle
trans_np = bdata['trans']   # (T, 3)   root translation
# NOTE(review): this cell has always rebound the global `trans` as well (overwriting the
# single-frame tensor from the demo cell); kept explicit for backward compatibility.
trans = trans_np
T = poses_np.shape[0]
num_pose_dims = poses_np.shape[1]
# The slicing below assumes the SMPL-H layout (root 3 + body 63 + hands 90 = 156);
# fail loudly instead of silently mis-splitting another layout.
if num_pose_dims != 156:
    raise ValueError(f"Expected SMPL-H poses with 156 dims, got {num_pose_dims}")

num_betas_to_use = 16  # must match BodyModel(num_betas=...) used to build body_model
betas_np = np.asarray(bdata['betas'], dtype=np.float32)[:num_betas_to_use]
if betas_np.shape[0] < num_betas_to_use:
    # Zero-pad missing shape coefficients so the beta vector matches the model's width.
    betas_np = np.pad(betas_np, (0, num_betas_to_use - betas_np.shape[0]))
betas_torch = torch.from_numpy(betas_np).unsqueeze(0).to(device)  # (1, num_betas_to_use), shared by all frames

trans_torch = torch.from_numpy(trans_np).float().to(device)       # (T, 3)
poses_torch_orig = torch.from_numpy(poses_np).float().to(device)  # (T, 156)

# Split the original poses into BodyModel inputs.
root_orient_torch_orig = poses_torch_orig[:, :3]      # (T, 3)  root orientation
pose_body_torch_orig   = poses_torch_orig[:, 3:66]    # (T, 63) 21 body joints
pose_hand_torch_orig   = poses_torch_orig[:, 66:156]  # (T, 90) 2 x 15 hand joints

# --- Reference 3D joints from the ORIGINAL poses ---
print("Calculating 3D joints with original poses...")
with torch.no_grad():  # inference only; no gradients needed
    body_model_orig_output = body_model(
        root_orient=root_orient_torch_orig,
        pose_body=pose_body_torch_orig,
        pose_hand=pose_hand_torch_orig,
        betas=betas_torch,
        trans=trans_torch
    )
joints_3d_orig = body_model_orig_output.Jtr  # (T, num_joints, 3)
print(f"Original 3D joints shape: {joints_3d_orig.shape}")

Using PyTorch3D for rotation conversions.
Calculating 3D joints with original poses...
Original 3D joints shape: torch.Size([3157, 52, 3])


In [54]:
# --- Round trip: axis-angle -> rotation matrix -> 6D -> rotation matrix -> axis-angle ---
print("Performing rotation conversions...")

# Flatten (T, 156) into one axis-angle row per joint rotation: (T * 52, 3).
num_joints_smplh = num_pose_dims // 3
poses_aa_flat_orig = poses_torch_orig.reshape(T * num_joints_smplh, 3)

rot_mats_flat = axis_angle_to_matrix(poses_aa_flat_orig)                         # (T*52, 3, 3)
poses_6d_flat = matrix_to_rotation_6d(rot_mats_flat)                             # (T*52, 6)
rot_mats_flat_reconstructed = rotation_6d_to_matrix(poses_6d_flat)               # (T*52, 3, 3)
poses_aa_flat_reconstructed = matrix_to_axis_angle(rot_mats_flat_reconstructed)  # (T*52, 3)

# Restore the per-frame layout expected by body_model: (T, 156).
poses_torch_reconstructed = poses_aa_flat_reconstructed.reshape(T, num_pose_dims)
print("Rotation conversions completed.")

Performing rotation conversions...
Rotation conversions completed.


In [56]:
# --- Split reconstructed poses for body_model input ---
root_orient_torch_recon = poses_torch_reconstructed[:, :3]
pose_body_torch_recon   = poses_torch_reconstructed[:, 3:66]
pose_hand_torch_recon   = poses_torch_reconstructed[:, 66:156]

# --- Get 3D joint coordinates using RECONSTRUCTED poses ---
print("Calculating 3D joints with reconstructed poses...")
with torch.no_grad():
    body_model_recon_output = body_model(
        root_orient=root_orient_torch_recon,
        pose_body=pose_body_torch_recon,
        pose_hand=pose_hand_torch_recon,
        betas=betas_torch,        # same betas as the reference pass
        trans=trans_torch         # same translation as the reference pass
    )
joints_3d_recon = body_model_recon_output.Jtr  # (T, num_joints, 3)
print(f"Reconstructed 3D joints shape: {joints_3d_recon.shape}")

# --- Calculate and display the error ---
if joints_3d_orig.shape != joints_3d_recon.shape:
    print("Error: Original and reconstructed 3D joint shapes do not match!")
    print(f"Original shape: {joints_3d_orig.shape}, Reconstructed shape: {joints_3d_recon.shape}")
else:
    # Euclidean distance for each joint in each frame: (T, num_joints)
    error_per_joint_per_frame = torch.norm(joints_3d_orig - joints_3d_recon, dim=2)

    # Mean Per Joint Position Error (MPJPE) per frame: (T,)
    mpjpe_per_frame = error_per_joint_per_frame.mean(dim=1)

    # Overall MPJPE for the sequence (mean over all frames)
    mean_mpjpe_sequence = mpjpe_per_frame.mean().item()

    # Worst single joint in any frame
    max_error = error_per_joint_per_frame.max().item()

    print(f"\n--- Reconstruction Error Report ---")
    print(f"Mean MPJPE over the sequence: {mean_mpjpe_sequence:.9f} (units of your 3D coordinates, likely meters)")
    print(f"Max single joint error in any frame: {max_error:.9f}")

    if T > 0:
        print(f"MPJPE for the first frame (frame 0): {mpjpe_per_frame[0].item():.9f}")

    # Direct comparison of axis-angle parameters. Axis-angle is not unique (angle θ about n
    # equals 2π-θ about -n), and matrix_to_axis_angle presumably returns a canonical angle in
    # [0, π], so this number can be large even when the rotations are identical.
    pose_reconstruction_error = torch.norm(poses_torch_orig - poses_torch_reconstructed, dim=1).mean().item()
    print(f"Mean L2 error between original and reconstructed axis-angle poses (per frame avg): {pose_reconstruction_error:.9f}")

    # Representation-independent check: the geodesic angle of R_orig^T @ R_recon.
    # Uses atan2(sin, cos) rather than acos(cos) because acos loses precision for tiny angles.
    relative_rot = rot_mats_flat.transpose(-1, -2) @ rot_mats_flat_reconstructed   # (T*52, 3, 3)
    skew = relative_rot - relative_rot.transpose(-1, -2)                           # 2 sin(θ) [n]_x
    sin_angle = torch.stack((skew[..., 2, 1], skew[..., 0, 2], skew[..., 1, 0]), dim=-1).norm(dim=-1) / 2
    cos_angle = (relative_rot.diagonal(dim1=-2, dim2=-1).sum(dim=-1) - 1) / 2      # (trace - 1) / 2
    geodesic_error = torch.atan2(sin_angle, cos_angle)                             # radians in [0, π]
    print(f"Mean geodesic rotation error (radians): {geodesic_error.mean().item():.9f}")
    print(f"Max geodesic rotation error (radians): {geodesic_error.max().item():.9f}")

Calculating 3D joints with reconstructed poses...
Reconstructed 3D joints shape: torch.Size([3157, 52, 3])

--- Reconstruction Error Report ---
Mean MPJPE over the sequence: 0.000000091 (units of your 3D coordinates, likely meters)
Max single joint error in any frame: 0.000000588
MPJPE for the first frame (frame 0): 0.000000071
Mean L2 error between original and reconstructed axis-angle poses (per frame avg): 0.000000493
