# MultiviewSMPLifyX Inference Notebook

This notebook is a cell-by-cell version of `main.py` for interactive inference.

Current scope:
- Load config and dataset
- Build SMPL model + priors
- Build multiview cameras and keypoints
- Run fitting and export mesh/params

Overlay visualization cells that project the fitted mesh onto the input images can be added later.

In [None]:
import os
import os.path as osp
import sys
import time
import yaml

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import smplx

from utils import JointMapper
from cmd_parser import parse_config
from data_parser import create_dataset
from fit_single_frame import fit_single_frame
from camera import create_camera
from prior import create_prior

# Disable cuDNN globally for this kernel session. Presumably this mirrors
# main.py; confirm the reason there before re-enabling it.
torch.backends.cudnn.enabled = False

In [12]:
# Runtime paths (edit these as needed)
CONFIG_PATH = 'cfg_files/fit_smpl.yaml'
DATA_FOLDER = './dataset_example/image_data/rp_dennis_posed_004'
OUTPUT_FOLDER = './dataset_example/mesh_data/rp_dennis_posed_004/notebook'

# parse_config() is called without arguments, so it presumably reads sys.argv.
# Fake the command line main.py would receive so the same config flow is used.
# NOTE: this replaces sys.argv for the whole kernel session.
cli_overrides = (
    ('--config', CONFIG_PATH),
    ('--data_folder', DATA_FOLDER),
    ('--output_folder', OUTPUT_FOLDER),
)
sys.argv = ['notebook', *(token for flag_pair in cli_overrides for token in flag_pair)]

args = parse_config()
print('Loaded config with keys:', len(args))
print('Data folder:', args['data_folder'])
print('Output folder:', args['output_folder'])

Loaded config with keys: 70
Data folder: ./dataset_example/image_data/rp_dennis_posed_004
Output folder: ./dataset_example/mesh_data/rp_dennis_posed_004/notebook


In [13]:
# Step 1: output folder + dataset
# Resolve the output directory ($VARS expanded) and make sure it exists.
output_folder = osp.expandvars(args.pop('output_folder'))
os.makedirs(output_folder, exist_ok=True)

# Persist the effective configuration next to the results. This happens
# before the pops below, so those keys are still included in conf.yaml.
conf_fn = osp.join(output_folder, 'conf.yaml')
with open(conf_fn, 'w') as conf_file:
    yaml.dump(args, conf_file)
print(f'Config saved to {conf_fn}')

# Translate the configured precision name into a torch dtype.
float_dtype = args.get('float_dtype', 'float32')
if float_dtype not in ('float64', 'float32'):
    raise ValueError(f"Unknown float type {float_dtype}")
dtype = torch.float64 if float_dtype == 'float64' else torch.float32

# Fail early if CUDA was requested but is unavailable.
use_cuda = args.get('use_cuda', True)
if use_cuda and not torch.cuda.is_available():
    raise RuntimeError('CUDA is not available. Set use_cuda=False in config or start a CUDA runtime.')

img_folder = args.pop('img_folder', 'images')
dataset_obj = create_dataset(img_folder=img_folder, **args)

# NOTE: in a notebook this timestamp also counts any idle time between cells.
start = time.time()

# Remove these options from args. 'gender' in particular must not stay:
# Step 2 passes gender= explicitly alongside **args-derived kwargs.
input_gender = args.pop('gender', 'neutral')
gender_lbl_type = args.pop('gender_lbl_type', 'none')
max_persons = args.pop('max_persons', -1)

print('Views in dataset:', len(dataset_obj))
print('Gender mode:', input_gender)

Config saved to ./dataset_example/mesh_data/rp_dennis_posed_004/notebook/conf.yaml
Views in dataset: 360
Gender mode: neutral


In [14]:
# Step 2: body models + priors
# Presumably maps the model's joints onto the dataset's keypoint layout
# (built from dataset_obj.get_model2data()); confirm in utils.JointMapper.
joint_mapper = JointMapper(dataset_obj.get_model2data())

# Pose/shape components that are always created on the body model.
_always_created_parts = (
    'global_orient', 'betas', 'left_hand_pose', 'right_hand_pose',
    'expression', 'jaw_pose', 'leye_pose', 'reye_pose',
)

# Shared constructor options for every gendered model. The body pose is only
# created when VPoser is not used, and no translation is created on the model.
# Duplicate keys between these and args still raise TypeError, as before.
model_params = dict(
    model_path=args.get('model_folder'),
    joint_mapper=joint_mapper,
    create_body_pose=not args.get('use_vposer'),
    create_transl=False,
    dtype=dtype,
    **{f'create_{part}': True for part in _always_created_parts},
    **args,
)

male_model = smplx.create(gender='male', **model_params)
# No neutral model is built for smplh (Step 4 reports it as unavailable).
neutral_model = (
    smplx.create(gender='neutral', **model_params)
    if args.get('model_type') != 'smplh'
    else None
)
female_model = smplx.create(gender='female', **model_params)

use_hands = args.get('use_hands', True)
use_face = args.get('use_face', True)

body_pose_prior = create_prior(prior_type=args.get('body_prior_type'), dtype=dtype, **args)
shape_prior = create_prior(prior_type=args.get('shape_prior_type', 'l2'), dtype=dtype, **args)
angle_prior = create_prior(prior_type='angle', dtype=dtype)

# Face priors exist only when face fitting is enabled; otherwise they stay None.
jaw_prior = expr_prior = None
if use_face:
    jaw_prior = create_prior(prior_type=args.get('jaw_prior_type'), dtype=dtype, **args)
    expr_prior = create_prior(prior_type=args.get('expr_prior_type', 'l2'), dtype=dtype, **args)

# Hand priors use num_pca_comps as their num_gaussians. The copy of args is
# shared by both sides, since each call unpacks it into a fresh kwargs dict.
left_hand_prior = right_hand_prior = None
if use_hands:
    hand_args = dict(args, num_gaussians=args.get('num_pca_comps'))
    left_hand_prior = create_prior(prior_type=args.get('left_hand_prior_type'), dtype=dtype, use_left_hand=True, **hand_args)
    right_hand_prior = create_prior(prior_type=args.get('right_hand_prior_type'), dtype=dtype, use_right_hand=True, **hand_args)

print('Model and priors initialized.')

Model and priors initialized.


In [None]:
# Step 3: device transfer + multiview camera/keypoint assembly (use exactly 4 views)
# Number of views sampled evenly across the dataset. The Step 3.1 preview
# assumes a 2x2 grid, so check that cell too if you change this value.
NUM_VIEWS = 4

if use_cuda and torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move the body models and priors onto the fitting device. The optional ones
# are None when disabled: no neutral model for smplh, and no face/hand priors
# when use_face/use_hands is off.
female_model = female_model.to(device=device)
male_model = male_model.to(device=device)
if neutral_model is not None:
    neutral_model = neutral_model.to(device=device)

body_pose_prior = body_pose_prior.to(device=device)
shape_prior = shape_prior.to(device=device)
angle_prior = angle_prior.to(device=device)
if use_face:
    expr_prior = expr_prior.to(device=device)
    jaw_prior = jaw_prior.to(device=device)
if use_hands:
    left_hand_prior = left_hand_prior.to(device=device)
    right_hand_prior = right_hand_prior.to(device=device)

# Joint weights from the dataset, with a leading dimension added in place.
joint_weights = dataset_obj.get_joint_weights().to(device=device, dtype=dtype)
joint_weights.unsqueeze_(dim=0)

img_list, keypoints_list, camera_list = [], [], []
selected_image_ids = []

view_num = len(dataset_obj)
if view_num < NUM_VIEWS:
    raise ValueError(f'Need at least {NUM_VIEWS} images/views, but found {view_num}')

# Evenly spaced indices from the first to the last view (inclusive), truncated
# to int. set() guards against duplicate indices.
selected_indices = np.linspace(0, view_num - 1, num=NUM_VIEWS, dtype=int).tolist()
selected_indices = sorted(set(selected_indices))

# Loop-invariant. The configured focal length only seeds create_camera; it is
# overwritten below by each view's calibrated fx/fy.
focal_length = args.get('focal_length')

for idx in selected_indices:
    data = dataset_obj[idx]

    img = data['img']
    # [[0]] keeps only entry 0 along the first axis while preserving that axis
    # (presumably the first detected person; confirm against data_parser).
    keypoints = data['keypoints'][[0]]

    camera = create_camera(
        focal_length_x=focal_length,
        focal_length_y=focal_length,
        dtype=dtype,
        **args,
    )

    cam_R = data['cam_R']
    cam_t = data['cam_t']
    cam_fx = data['cam_fx']
    cam_fy = data['cam_fy']
    cam_cx = data['cam_cx']
    cam_cy = data['cam_cy']

    # Overwrite the default intrinsics with this view's calibration.
    camera.focal_length_x = torch.full([1], cam_fx, dtype=dtype)
    camera.focal_length_y = torch.full([1], cam_fy, dtype=dtype)
    camera.center = torch.tensor([cam_cx, cam_cy], dtype=dtype).unsqueeze(0)
    # Cast the extrinsics to the model dtype. torch.from_numpy keeps the
    # array's own dtype, so float64 calibration data would give a float32
    # pipeline float64 camera parameters. as_tensor also accepts non-ndarray
    # input, and it is zero-copy when the dtype already matches.
    camera.rotation.data = torch.as_tensor(cam_R, dtype=dtype).unsqueeze(0)
    camera.translation.data = torch.as_tensor(cam_t, dtype=dtype).unsqueeze(0)
    # Calibrated extrinsics stay fixed: disable gradients on them.
    camera.rotation.requires_grad = False
    camera.translation.requires_grad = False

    # `device` is already CPU when CUDA is off, so this is a no-op there.
    camera = camera.to(device=device)

    img_list.append(img)
    keypoints_list.append(keypoints)
    camera_list.append(camera)
    selected_image_ids.append(data['fn'])

    print('Selected:', data['img_path'])

print(f'Selected {len(selected_image_ids)} image ids:', selected_image_ids)

Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0000.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0018.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0036.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0054.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0072.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0090.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0108.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0126.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0144.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0162.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0180.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0198.jpg
Processing: ./dataset_example/image_data/rp_dennis_posed_004/color/0216.jpg
Processing: 

In [None]:
# Step 3.1: visualize the selected views in a grid (2x2 for the default 4 views)
n_views = len(img_list)
if n_views == 0:
    raise ValueError('No views selected; run Step 3 first.')

# Two columns, with as many rows as needed (ceiling division). squeeze=False
# keeps `axes` 2-D even for a single row, so ravel() always works.
# For 4 views this is the same 2x2, 12x10-inch figure as before.
ncols = 2
nrows = -(-n_views // ncols)
fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 5 * nrows), squeeze=False)
axes = axes.ravel()

for ax, img, img_id in zip(axes, img_list, selected_image_ids):
    ax.imshow(img)
    ax.set_title(f'Image ID: {img_id}')

# Hide every axis frame, including unused trailing cells when n_views is odd.
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Step 4: run fitting and save outputs
curr_result_fn = osp.join(output_folder, 'smpl_param.pkl')
curr_mesh_fn = osp.join(output_folder, 'smpl_mesh.obj')

# Pick the body model for the requested gender. neutral_model is None for
# model_type=smplh, since Step 2 does not build it in that case.
gender = input_gender
if gender == 'neutral':
    if neutral_model is None:
        raise RuntimeError('Neutral model is not available for model_type=smplh')
    body_model = neutral_model
elif gender == 'female':
    body_model = female_model
elif gender == 'male':
    body_model = male_model
else:
    raise ValueError(f'Unknown gender: {gender}')

# Time only the fitting call, using a monotonic clock. The `start` stamp taken
# in Step 1 would also count however long the notebook sat idle between cells.
fit_start = time.perf_counter()

fit_single_frame(
    img_list,
    keypoints_list,
    body_model=body_model,
    camera_list=camera_list,
    joint_weights=joint_weights,
    dtype=dtype,
    output_folder=output_folder,
    result_fn=curr_result_fn,
    mesh_fn=curr_mesh_fn,
    shape_prior=shape_prior,
    expr_prior=expr_prior,
    body_pose_prior=body_pose_prior,
    left_hand_prior=left_hand_prior,
    right_hand_prior=right_hand_prior,
    jaw_prior=jaw_prior,
    angle_prior=angle_prior,
    **args,
)

elapsed = time.perf_counter() - fit_start
print('Done. Result params:', curr_result_fn)
print('Done. Result mesh:', curr_mesh_fn)
print('Fitting elapsed:', time.strftime('%Hh %Mm %Ss', time.gmtime(elapsed)))

## Next (planned)
- ~~Add image display cells for the selected multiview inputs.~~ (Done: Step 3.1.)
- Project the fitted mesh vertices onto each image and render an overlay for visual QA.
- Save side-by-side overlay outputs under the output folder.