In [13]:
import cv2
import torch
import smplx
import numpy as np
import open3d as o3d
from pytransform3d.rotations import (
    quaternion_from_compact_axis_angle,
    compact_axis_angle_from_quaternion,
    quaternion_slerp,
)

# Number of per-joint axis-angle rotations in a pose as handled by this
# notebook: 24 joints x 3 values = 72 pose parameters. When a pose is fed to
# the SMPL model below, the first joint (first 3 values) is passed as
# `global_orient` and the remaining values as `body_pose`.
NUM_BODY_JOINTS = 24

def interpolate_smpl_poses(pose1, pose2, num_frames=10):
    """
    Interpolates between two SMPL poses using per-joint SLERP.

    Every joint rotation is converted from compact axis-angle to a unit
    quaternion, interpolated along the shortest arc and converted back to
    compact axis-angle. The first and last frames are the input poses copied
    verbatim (they are not round-tripped through quaternions).

    Args:
        pose1 (ndarray): The first pose in axis-angle of shape (NUM_BODY_JOINTS, 3)
        pose2 (ndarray): The second pose in axis-angle of shape (NUM_BODY_JOINTS, 3)
        num_frames (int): Total number of frames to produce, including both
            end poses. Must be at least 2.
    Returns:
        interpolated_poses (ndarray): The interpolated poses in axis-angle, a
            float32 array of shape (num_frames, NUM_BODY_JOINTS, 3)
    Raises:
        AssertionError: If the poses are not numpy arrays of the expected shape.
        ValueError: If num_frames is smaller than 2.
    """

    # Ensure the poses are in the correct format
    assert isinstance(pose1, np.ndarray) and isinstance(pose2, np.ndarray), "Poses must be numpy arrays"
    assert pose1.shape == pose2.shape == (NUM_BODY_JOINTS, 3), f"Pose shapes must be ({NUM_BODY_JOINTS}, 3)"

    # With a single frame the start pose would be silently overwritten by the
    # end pose, and with zero frames indexing fails; reject both explicitly.
    if num_frames < 2:
        raise ValueError(f"num_frames must be at least 2 (start and end pose), got {num_frames}")

    # Allocate directly in the returned (frame, joint, xyz) layout.
    interpolated_poses = np.zeros((num_frames, NUM_BODY_JOINTS, 3), dtype=np.float32)
    interpolated_poses[0] = pose1
    interpolated_poses[-1] = pose2

    for i in range(NUM_BODY_JOINTS):
        quat1 = quaternion_from_compact_axis_angle(pose1[i])
        quat2 = quaternion_from_compact_axis_angle(pose2[i])

        # q and -q encode the same rotation. Plain SLERP follows the arc
        # between the two 4D vectors as given, which is the long way round when
        # their dot product is negative. Flip the target so that the joint
        # always takes the shortest rotation.
        if np.dot(quat1, quat2) < 0.0:
            quat2 = -quat2

        # Only interior frames are interpolated; the end frames are set above.
        for j in range(1, num_frames - 1):
            t = j / (num_frames - 1)
            interp_quat = quaternion_slerp(quat1, quat2, t)
            interpolated_poses[j, i] = compact_axis_angle_from_quaternion(interp_quat)
    return interpolated_poses

def render_pose_as_image(vertices, faces):
    """
    Renders a triangle mesh in a hidden Open3D window and returns the frame.

    The mesh is painted uniformly gray and viewed with a fixed camera
    (look-at the origin, front direction -Y, up direction +Z, zoom 1).

    Args:
        vertices (array-like): Vertex positions of shape (V, 3).
        faces (array-like): Triangle vertex indices of shape (F, 3).
    Returns:
        image (ndarray): The captured float screen buffer as a numpy array.
            The caller in this notebook scales it by 255 before converting to
            uint8, i.e. values are expected to lie in [0, 1].
    """
    mesh = o3d.geometry.TriangleMesh()
    # Hand Open3D the dtypes its vector converters are built for (float64
    # points, int32 indices) instead of relying on implicit conversion of
    # e.g. float32 vertices or unsigned face indices.
    mesh.vertices = o3d.utility.Vector3dVector(np.asarray(vertices, dtype=np.float64))
    mesh.triangles = o3d.utility.Vector3iVector(np.asarray(faces, dtype=np.int32))
    mesh.compute_vertex_normals()
    mesh.paint_uniform_color([0.3, 0.3, 0.3])

    vis = o3d.visualization.Visualizer()
    vis.create_window(visible=False)
    try:
        vis.add_geometry(mesh)
        vis.poll_events()
        vis.update_renderer()

        ctr = vis.get_view_control()
        ctr.set_lookat([0, 0, 0])
        ctr.set_front([0, -1, 0])
        ctr.set_up([0, 0, 1])
        ctr.set_zoom(1)

        # do_render=True re-renders with the camera configured above.
        image = vis.capture_screen_float_buffer(True)
    finally:
        # This function is called once per frame, so always tear the window
        # down, even if rendering fails, to avoid leaking windows.
        vis.destroy_window()
    return np.asarray(image)

def save_video(frames, output_path="output.mp4", fps=30):
    """
    Writes a stack of frames to a video file with OpenCV.

    Args:
        frames (ndarray): Frames of shape (N, H, W, 3) for color video or
            (N, H, W, 1) for grayscale video. cv2.VideoWriter interprets
            3-channel frames as BGR. Frames are expected to be 8-bit
            (uint8) -- TODO confirm for other dtypes.
        output_path (str): Path of the video file to create.
        fps (int): Frame rate of the video.
    Raises:
        AssertionError: If frames is not a 4D array.
        ValueError: If the last axis has neither 1 nor 3 channels.
        IOError: If the video writer cannot be opened (for example because
            the codec is not available or the path is not writable).
    """
    assert frames.ndim == 4, "Frames should be a 4D numpy array."
    height, width = frames.shape[1:3]
    num_channels = frames.shape[-1]
    if num_channels not in (1, 3):
        raise ValueError(
            f"Frames must have 1 (grayscale) or 3 (color) channels, got {num_channels}.")
    is_color = num_channels == 3

    fourcc = cv2.VideoWriter_fourcc(*'avc1')  # H.264; alternatives: 'mp4v', 'XVID', etc.
    writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height), isColor=is_color)
    try:
        # An unopened writer accepts write() calls without producing any
        # output, so fail loudly instead of silently creating no video.
        if not writer.isOpened():
            raise IOError(f"Could not open video writer for '{output_path}'.")

        # All frames of one ndarray share the same shape, so the size and
        # channel count only need to be validated once (done above).
        for frame in frames:
            if not is_color:
                # A writer created with isColor=False expects single-channel
                # (H, W) frames, so drop the trailing channel axis.
                frame = frame[..., 0]
            writer.write(np.ascontiguousarray(frame))
    finally:
        writer.release()

def render_pose_sequence_as_video(poses, smpl_model, output_path="output.mp4", fps=30):
    """
    Renders every pose of a sequence with the SMPL model and saves a video.

    Args:
        poses (Tensor): Sequence of flattened axis-angle poses; iterating
            yields one 1D pose per frame whose first 3 values are the global
            orientation and whose remaining values are the body pose.
        smpl_model: Body model that is called with `global_orient` and
            `body_pose` and exposes `faces`; its output must provide
            `vertices`.
        output_path (str): Path of the video file to create.
        fps (int): Frame rate of the video.
    """
    images = []
    for pose in poses:
        # Pure inference: skip building an autograd graph for every frame.
        with torch.no_grad():
            output = smpl_model(
                global_orient=pose[None, :3],
                body_pose=pose[None, 3:],
                transl=None,
                return_verts=True)
        vertices = output.vertices.detach().cpu().numpy().squeeze()

        image = render_pose_as_image(vertices, smpl_model.faces)
        image = (image * 255).astype(np.uint8)
        # Open3D captures RGB, whereas OpenCV's video writer expects BGR.
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        images.append(image)
    images = np.array(images)
    save_video(images, output_path, fps)


In [14]:
# To download the SMPL model, register on the SMPL project website
# (https://smpl.is.tue.mpg.de/, male and female models) and on the SMPLify
# project website (https://smplify.is.tue.mpg.de/, gender-neutral model) to
# get access to their downloads sections.

MODEL_PATH = "smplx/models"

POSE1_PATH = "wham_pose_30.pt"
POSE2_PATH = "wham_pose_75_edited.pt"
POSE3_PATH = "wham_pose_120.pt"

smpl_model = smplx.create(
    model_path=MODEL_PATH,
    model_type='smpl',
    gender='neutral',
    ext='npz')


def load_pose(path):
    """
    Loads a single saved pose as a (NUM_BODY_JOINTS, 3) axis-angle array.

    Accepts a pose stored either as a 1D tensor or as a 2D tensor with a
    leading singleton batch dimension; only the first NUM_BODY_JOINTS * 3
    values per row are used.

    Args:
        path (str): Path to a tensor saved with torch.save.
    Returns:
        pose (ndarray): The pose of shape (NUM_BODY_JOINTS, 3)
    """
    # NOTE: torch.load unpickles the file; only load pose files you trust.
    # map_location/detach make the .numpy() conversion work even if the tensor
    # was saved on a GPU or with requires_grad set.
    pose = torch.load(path, map_location="cpu").detach()
    if pose.ndim == 1:
        pose = pose[None]
    return pose[:, :NUM_BODY_JOINTS * 3].reshape(NUM_BODY_JOINTS, 3).numpy()


pose1 = load_pose(POSE1_PATH)
pose2 = load_pose(POSE2_PATH)
pose3 = load_pose(POSE3_PATH)

In [16]:
# Number of frames per interpolated segment, including both end poses.
FRAMES_PER_SEGMENT = 46
# Number of trailing joints whose rotation is reset to zero in every frame
# (presumably the two hand joints in SMPL's joint ordering -- verify).
NUM_ZEROED_TRAILING_JOINTS = 2


def interpolate_segment(start_pose, end_pose, num_frames=FRAMES_PER_SEGMENT):
    """
    Interpolates between two poses and returns flattened per-frame poses.

    Args:
        start_pose (ndarray): Axis-angle pose of shape (NUM_BODY_JOINTS, 3)
        end_pose (ndarray): Axis-angle pose of shape (NUM_BODY_JOINTS, 3)
        num_frames (int): Number of frames, including both end poses.
    Returns:
        segment (Tensor): Poses of shape (num_frames, NUM_BODY_JOINTS * 3)
            with the last NUM_ZEROED_TRAILING_JOINTS joints set to zero.
    """
    segment = interpolate_smpl_poses(start_pose, end_pose, num_frames=num_frames)
    segment = segment.reshape(-1, NUM_BODY_JOINTS * 3)
    segment[:, (NUM_BODY_JOINTS - NUM_ZEROED_TRAILING_JOINTS) * 3:] = 0.
    return torch.from_numpy(segment)


pose1_to_pose2 = interpolate_segment(pose1, pose2)
pose2_to_pose3 = interpolate_segment(pose2, pose3)

# Both segments contain pose2 (last frame of the first, first frame of the
# second); keep it only once when joining them.
assert torch.all(pose1_to_pose2[-1] == pose2_to_pose3[0])
pose1_to_pose3 = torch.cat([pose1_to_pose2, pose2_to_pose3[1:]], dim=0)
assert pose1_to_pose3.shape == (2 * FRAMES_PER_SEGMENT - 1, NUM_BODY_JOINTS * 3)
torch.save(pose1_to_pose3, "pose1_to_pose3.pt")

In [None]:
# Single source of truth for the output location, so the status message cannot
# drift from the file that is actually written.
OUTPUT_VIDEO_PATH = "pose1_to_pose3.mp4"

render_pose_sequence_as_video(pose1_to_pose3, smpl_model, output_path=OUTPUT_VIDEO_PATH, fps=30)
print(f"Video saved to {OUTPUT_VIDEO_PATH}")