In [85]:
import sys, os
import glob
import os
import sys
import pdb
import os.path as osp
sys.path.append(os.getcwd())
import importlib, time

import numpy as np

import torch
from torch.utils.data import DataLoader

from humor.utils.config import TestConfig
from humor.utils.logging import Logger, class_name_to_file_name, mkdir, cp_files
from humor.utils.torch import get_device, save_state, load_state
from humor.utils.stats import StatTracker
from humor.utils.transforms import rotation_matrix_to_angle_axis
from humor.body_model.utils import SMPL_JOINTS
from humor.datasets.amass_utils import NUM_KEYPT_VERTS, CONTACT_INDS
from humor.losses.humor_loss import CONTACT_THRESH

# Passed to DataLoader(num_workers=...); 0 loads batches in the main process.
NUM_WORKERS = 0

def parse_args(argv):
    '''
    Parse an argv-style list (e.g. ['@path/to/config.cfg']) with TestConfig.

    Prints any arguments TestConfig did not recognize and returns the parsed
    (known) argument namespace.
    '''
    known, unknown = TestConfig(argv).parse()
    print('Unrecognized args: ' + str(unknown))
    return known

def _package_root():
    '''
    Return the root directory of the humor package (the directory holding datasets/,
    models/ and test/).

    test() used to read a module-level `cur_file_path` global that this file never
    defines, which raised a NameError. If such a global exists it is still honored;
    otherwise the root is derived from the location of the already-imported
    humor.utils.config module (the module that provides TestConfig).
    '''
    cur_dir = globals().get('cur_file_path')
    if cur_dir is not None:
        return os.path.join(cur_dir, '..')
    config_module_file = os.path.abspath(sys.modules[TestConfig.__module__].__file__)
    # .../humor/utils/config.py -> .../humor/utils/.. == .../humor
    return os.path.join(os.path.dirname(config_module_file), '..')


def _select_split(args):
    '''
    Return the dataset split to evaluate: 'train' if args.test_on_train, else 'val' if
    args.test_on_val, else 'test'. Logs a warning for the debug-only splits.
    '''
    if args.test_on_train:
        Logger.log('WARNING: running evaluation on TRAINING data as requested...should only be used for debugging!')
        return 'train'
    if args.test_on_val:
        Logger.log('WARNING: running evaluation on VALIDATION data as requested...should only be used for debugging!')
        return 'val'
    return 'test'


def test(args_obj, config_file):
    '''
    Evaluate a trained model as configured by args_obj.

    args_obj    : parsed TestConfig namespace (see parse_args) exposing .base, .model,
                  .dataset, .loss and the matching model_dict / loss_dict / dataset_dict.
    config_file : path of the config used; copied into <out>/test_scripts.

    Writes test.log under args.out. Then, depending on the flags in args.base, runs
    the single-step evaluation over the whole split (eval_full_test, stats written to
    <out>/test_tensorboard) and/or the qualitative sampling evaluation.
    '''
    # set up output
    args = args_obj.base
    mkdir(args.out)

    # create logging system
    Logger.init(os.path.join(args.out, 'test.log'))

    # save arguments used
    Logger.log('Base args: ' + str(args))
    Logger.log('Model args: ' + str(args_obj.model))
    Logger.log('Dataset args: ' + str(args_obj.dataset))
    Logger.log('Loss args: ' + str(args_obj.loss))

    # save test script/model/dataset/config used
    test_scripts_path = os.path.join(args.out, 'test_scripts')
    mkdir(test_scripts_path)
    pkg_root = _package_root()
    dataset_file = class_name_to_file_name(args.dataset)
    model_file = class_name_to_file_name(args.model)
    loss_file = class_name_to_file_name(args.loss)
    dataset_file_path = os.path.join(pkg_root, 'datasets/' + dataset_file + '.py')
    model_file_path = os.path.join(pkg_root, 'models/' + model_file + '.py')
    test_file_path = os.path.join(pkg_root, 'test/test_humor.py')
    cp_files(test_scripts_path, [test_file_path, model_file_path, dataset_file_path, config_file])

    # load model class and instantiate
    # NOTE(review): these resolve top-level 'models.' / 'losses.' / 'datasets.' packages while
    # this file imports from 'humor.*'; confirm the package directory is on sys.path.
    model_class = importlib.import_module('models.' + model_file)
    Model = getattr(model_class, args.model)
    model = Model(**args_obj.model_dict,
                  model_smpl_batch_size=args.batch_size) # assumes model is HumorModel

    # load loss class and instantiate
    loss_class = importlib.import_module('losses.' + loss_file)
    Loss = getattr(loss_class, args.loss)
    loss_func = Loss(**args_obj.loss_dict,
                     smpl_batch_size=args.batch_size*args_obj.dataset.sample_num_frames) # assumes loss is HumorLoss

    device = get_device(args.gpu)
    model.to(device)
    loss_func.to(device)

    print(model)

    # count trainable params
    params = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
    Logger.log('Num model params: ' + str(params))

    # freeze params in loss
    for param in loss_func.parameters():
        param.requires_grad = False

    # load in pretrained weights if given
    if args.ckpt is not None:
        start_epoch, min_val_loss, _ = load_state(args.ckpt, model, optimizer=None, map_location=device, ignore_keys=model.ignore_keys)
        Logger.log('Successfully loaded saved weights...')
        Logger.log('Saved checkpoint is from epoch idx %d with min val loss %.6f...' % (start_epoch, min_val_loss))
    else:
        # evaluation continues with the freshly constructed weights
        Logger.log('ERROR: No weight specified to load!!')

    # load dataset class and instantiate the evaluation split
    split = _select_split(args)
    Dataset = getattr(importlib.import_module('datasets.' + dataset_file), args.dataset)
    test_dataset = Dataset(split=split, **args_obj.dataset_dict)
    # create loader
    test_loader = DataLoader(test_dataset,
                             batch_size=args.batch_size,
                             shuffle=args.shuffle_test,
                             num_workers=NUM_WORKERS,
                             pin_memory=True,
                             drop_last=False,
                             worker_init_fn=lambda _: np.random.seed())

    test_dataset.return_global = True
    model.dataset = test_dataset

    if args.eval_full_test:
        Logger.log('Running full test set evaluation...')
        # stats tracker
        tensorboard_path = os.path.join(args.out, 'test_tensorboard')
        mkdir(tensorboard_path)
        stat_tracker = StatTracker(tensorboard_path)

        # testing with same stats as training
        test_start_t = time.time()
        test_dataset.pre_batch()
        model.eval()
        # evaluation only: no autograd graph is needed for the loss/stats
        with torch.no_grad():
            for i, data in enumerate(test_loader):
                batch_start_t = time.time()
                # run model
                #   note we're always using ground truth input so this is only measuring single-step error, just like in training
                loss, stats_dict = model_class.step(model, loss_func, data, test_dataset, device, 0, mode='test', use_gt_p=1.0)

                # collect stats
                batch_elapsed_t = time.time() - batch_start_t
                total_elapsed_t = time.time() - test_start_t
                stats_dict['loss'] = loss
                stats_dict['time_per_batch'] = torch.Tensor([batch_elapsed_t])[0]

                stat_tracker.update(stats_dict, tag='test')

                if i % args.print_every == 0:
                    stat_tracker.print(i, len(test_loader),
                                       0, 1,
                                       total_elapsed_time=total_elapsed_t,
                                       tag='test')

                test_dataset.pre_batch()

    if args.eval_sampling or args.eval_sampling_debug:
        eval_sampling(model, test_dataset, test_loader, device,
                      out_dir=args.out if args.eval_sampling else None,
                      num_samples=args.eval_num_samples,
                      samp_len=args.eval_sampling_len,
                      viz_contacts=args.viz_contacts,
                      viz_pred_joints=args.viz_pred_joints,
                      viz_smpl_joints=args.viz_smpl_joints)

    Logger.log('Finished!')

def eval_sampling(model, test_dataset, test_loader, device,
                  out_dir=None,
                  num_samples=1,
                  samp_len=10.0,
                  viz_contacts=False,
                  viz_pred_joints=False,
                  viz_smpl_joints=False):
    '''
    Qualitatively evaluate sampling: for every batch, roll the model out from the first
    input step and visualize the resulting motion with viz_eval_samp.

    model        : model providing prepare_input() and roll_out().
    test_dataset : dataset behind test_loader (pre_batch() is called before iterating).
    device       : torch device the inputs and body models are moved to.
    out_dir      : if given, results go under <out_dir>/eval_sampling and are rendered
                   offscreen at 720x720; if None, they are shown on screen at 1080x1080.
    num_samples  : number of independent rollouts per batch.
    samp_len     : rollout length in seconds (converted to frames at 30 Hz).
    viz_*        : forwarded to viz_eval_samp.
    '''
    Logger.log('Evaluating sampling qualitatively...')
    from body_model.body_model import BodyModel
    from body_model.utils import SMPLH_PATH

    eval_qual_samp_len = int(samp_len * 30.0) # at 30 Hz

    res_out_dir = None
    if out_dir is not None:
        res_out_dir = os.path.join(out_dir, 'eval_sampling')
        # exist_ok avoids the exists/mkdir race and also creates missing parents
        os.makedirs(res_out_dir, exist_ok=True)

    # body models are sized to the rollout length so a whole sequence runs in one batch
    male_bm_path = os.path.join(SMPLH_PATH, 'male/model.npz')
    female_bm_path = os.path.join(SMPLH_PATH, 'female/model.npz')
    male_bm = BodyModel(bm_path=male_bm_path, num_betas=16, batch_size=eval_qual_samp_len).to(device)
    female_bm = BodyModel(bm_path=female_bm_path, num_betas=16, batch_size=eval_qual_samp_len).to(device)

    with torch.no_grad():
        test_dataset.pre_batch()
        model.eval()
        for i, data in enumerate(test_loader):
            # get inputs
            batch_in, batch_out, meta = data
            print(meta['path'])
            # drop the 4-character file extension from each sequence path
            seq_name_list = [spath[:-4] for spath in meta['path']]
            if res_out_dir is None:
                batch_res_out_list = [None]*len(seq_name_list)
            else:
                batch_res_out_list = [os.path.join(res_out_dir, seq_name.replace('/', '_') + '_b' + str(i) + 'seq' + str(sidx)) for sidx, seq_name in enumerate(seq_name_list)]
                print(batch_res_out_list)
            x_past, _, gt_dict, input_dict, global_gt_dict = model.prepare_input(batch_in, device,
                                                                                data_out=batch_out,
                                                                                return_input_dict=True,
                                                                                return_global_dict=True)

            # roll out predicted motion: only the first step of the input is needed
            x_past = x_past[:,0,:,:]
            rollout_input_dict = {k: v[:,0,:,:] for k, v in input_dict.items()}

            # sample same trajectory multiple times and save the joints/contacts output
            for samp_idx in range(num_samples):
                x_pred_dict = model.roll_out(x_past, rollout_input_dict, eval_qual_samp_len, gender=meta['gender'], betas=meta['betas'].to(device))

                # visualize and save
                print('Visualizing sample %d/%d!' % (samp_idx+1, num_samples))
                imsize = (1080, 1080)
                cur_res_out_list = batch_res_out_list
                if res_out_dir is not None:
                    cur_res_out_list = [out_path + '_samp%d' % (samp_idx) for out_path in batch_res_out_list]
                    imsize = (720, 720)
                viz_eval_samp(global_gt_dict, x_pred_dict, meta, male_bm, female_bm, cur_res_out_list,
                              imw=imsize[0],
                              imh=imsize[1],
                              show_smpl_joints=viz_smpl_joints,
                              show_pred_joints=viz_pred_joints,
                              show_contacts=viz_contacts
                              )

def viz_eval_samp(global_gt_dict, x_pred_dict, meta, male_bm, female_bm, out_path_list,
                    imw=720,
                    imh=720,
                    show_pred_joints=False,
                    show_smpl_joints=False,
                    show_contacts=False):
    '''
    Given x_pred_dict from the model rollout, runs the predicted motion through the SMPL
    body model and visualizes it, one sequence per batch element.

    x_pred_dict   : rollout output with 'root_orient' (B, T, 9) and 'pose_body' rotation
                    matrices, 'trans', 'joints' and optionally 'contacts' logits.
    meta          : batch metadata providing 'gender' (per sequence) and 'betas'.
    male_bm, female_bm : body models chosen per sequence from meta['gender'].
    out_path_list : per-sequence output path, or None to render on screen. With a path,
                    frames are rendered offscreen and encoded to <path>.mp4 at 30 fps.
    show_*        : overlay SMPL joints (takes precedence), predicted joints, or
                    thresholded predicted contacts.

    global_gt_dict is only used to pick the device for meta['betas']; the ground-truth
    motion itself is not rendered here.
    '''
    # loop-invariant: import the renderer once rather than once per sequence
    from viz.utils import viz_smpl_seq, create_video

    J = len(SMPL_JOINTS)

    # convert predicted rotation matrices to axis-angle before feeding the body model
    pred_world_root_orient = x_pred_dict['root_orient']
    B, T, _ = pred_world_root_orient.size()
    pred_world_root_orient = rotation_matrix_to_angle_axis(pred_world_root_orient.reshape((B*T, 3, 3))).reshape((B, T, 3))
    pred_world_pose_body = x_pred_dict['pose_body']
    pred_world_pose_body = rotation_matrix_to_angle_axis(pred_world_pose_body.reshape((B*T*(J-1), 3, 3))).reshape((B, T, (J-1)*3))
    pred_world_trans = x_pred_dict['trans']
    pred_world_joints = x_pred_dict['joints'].reshape((B, T, J, 3))

    viz_contacts = [None]*B
    if show_contacts and 'contacts' in x_pred_dict:
        # threshold contact probabilities and scatter them into a per-joint (B, T, J)
        # tensor; joints outside CONTACT_INDS stay 0
        pred_contacts = (torch.sigmoid(x_pred_dict['contacts']) > CONTACT_THRESH).to(torch.float)
        viz_contacts = torch.zeros((B, T, J)).to(pred_contacts)
        viz_contacts[:,:,CONTACT_INDS] = pred_contacts

    betas = meta['betas'].to(next(iter(global_gt_dict.values())).device)
    for b in range(B):
        bm_world = male_bm if meta['gender'][b] == 'male' else female_bm
        # pred
        body_pred = bm_world(pose_body=pred_world_pose_body[b],
                        pose_hand=None,
                        betas=betas[b,0].reshape((1, -1)).expand((T, 16)),
                        root_orient=pred_world_root_orient[b],
                        trans=pred_world_trans[b])

        pred_smpl_joints = body_pred.Jtr[:, :J]
        viz_joints = None
        if show_smpl_joints:
            viz_joints = pred_smpl_joints
        elif show_pred_joints:
            viz_joints = pred_world_joints[b]

        cur_offscreen = out_path_list[b] is not None
        # make the body translucent only when joints are overlaid in an offscreen render
        body_alpha = 0.5 if viz_joints is not None and cur_offscreen else 1.0
        viz_smpl_seq(body_pred,
                        imw=imw, imh=imh, fps=30,
                        render_body=True,
                        render_joints=viz_joints is not None,
                        render_skeleton=viz_joints is not None and cur_offscreen,
                        render_ground=True,
                        contacts=viz_contacts[b],
                        joints_seq=viz_joints,
                        body_alpha=body_alpha,
                        use_offscreen=cur_offscreen,
                        out_path=out_path_list[b],
                        wireframe=False,
                        RGBA=False,
                        follow_camera=True,
                        cam_offset=[0.0, 2.2, 0.9],
                        joint_color=[ 0.0, 1.0, 0.0 ],
                        point_color=[0.0, 0.0, 1.0],
                        skel_color=[0.5, 0.5, 0.5],
                        joint_rad=0.015,
                        point_rad=0.015
                )

        if cur_offscreen:
            create_video(out_path_list[b] + '/frame_%08d.png', out_path_list[b] + '.mp4', 30)


def main(args, config_file):
    '''Script entry point: run the evaluation described by args with config_file.'''
    return test(args, config_file)


In [86]:
# Parse the sampling test config; the '@' prefix tells the parser to read args from the file.
config_file = './configs/test_humor_sampling.cfg'
args_obj = parse_args(['@' + config_file])
args = args_obj.base

Unrecognized args: []


In [87]:
import humor.models.humor_model as HumorModel
from humor.losses.humor_loss import HumorLoss
from humor.datasets.amass_discrete_dataset import AmassDiscreteDataset

# Notebook version of the setup in test(): model/loss/dataset classes are imported
# directly instead of being resolved by name through importlib.

# load model class and instantiate
# model_class is the humor_model *module*; its step() function is used by the one-step error cell.
model_class = HumorModel
model = HumorModel.HumorModel(**args_obj.model_dict,
                              model_smpl_batch_size=args.batch_size) # assumes model is HumorModel

# load loss class and instantiate
loss_func = HumorLoss(**args_obj.loss_dict,
                      smpl_batch_size=args.batch_size*args_obj.dataset.sample_num_frames) # assumes loss is HumorLoss

device = get_device(args.gpu)
model.to(device)
loss_func.to(device)

print(model)

# count trainable params
params = sum(np.prod(p.size()) for p in model.parameters() if p.requires_grad)
print('Num model params: ' + str(params))

# freeze params in loss
for param in loss_func.parameters():
    param.requires_grad = False

# load in pretrained weights if given
if args.ckpt is not None:
    start_epoch, min_val_loss, min_train_loss = load_state(args.ckpt, model, optimizer=None, map_location=device, ignore_keys=model.ignore_keys)
    print('Successfully loaded saved weights...')
    print('Saved checkpoint is from epoch idx %d with min val loss %.6f...' % (start_epoch, min_val_loss))
else:
    # a notebook cell cannot return; evaluation continues without loaded weights
    print('ERROR: No weight specified to load!!')

# Pick the evaluation split. This notebook defaults to the validation split (test() defaults
# to 'test'), but test_on_train is honored so the printed warning matches the data loaded.
if args.test_on_train:
    print('WARNING: running evaluation on TRAINING data as requested...should only be used for debugging!')
    split = 'train'
else:
    if args.test_on_val:
        print('WARNING: running evaluation on VALIDATION data as requested...should only be used for debugging!')
    split = 'val'
test_dataset = AmassDiscreteDataset(split=split, **args_obj.dataset_dict)
# create loader
test_loader = DataLoader(test_dataset,
                         batch_size=args.batch_size,
                         shuffle=args.shuffle_test,
                         num_workers=NUM_WORKERS,
                         pin_memory=True,
                         drop_last=False,
                         worker_init_fn=lambda _: np.random.seed())

test_dataset.return_global = True
model.dataset = test_dataset


Using posterior architecture: mlp
Using decoder architecture: mlp
Using prior architecture: mlp
Using detected GPU...
HumorModel(
  (encoder): MLP(
    (net): ModuleList(
      (0): Linear(in_features=678, out_features=1024, bias=True)
      (1): GroupNorm(16, 1024, eps=1e-05, affine=True)
      (2): ReLU()
      (3): Linear(in_features=1024, out_features=1024, bias=True)
      (4): GroupNorm(16, 1024, eps=1e-05, affine=True)
      (5): ReLU()
      (6): Linear(in_features=1024, out_features=1024, bias=True)
      (7): GroupNorm(16, 1024, eps=1e-05, affine=True)
      (8): ReLU()
      (9): Linear(in_features=1024, out_features=1024, bias=True)
      (10): GroupNorm(16, 1024, eps=1e-05, affine=True)
      (11): ReLU()
      (12): Linear(in_features=1024, out_features=96, bias=True)
    )
  )
  (decoder): MLP(
    (net): ModuleList(
      (0): Linear(in_features=387, out_features=1024, bias=True)
      (1): GroupNorm(16, 1024, eps=1e-05, affine=True)
      (2): ReLU()
      (3): Linear(

## One-Step Error

In [4]:
# Single-step (teacher-forced) loss on the first batch only.
test_dataset.pre_batch()
model.eval()
data = next(iter(test_loader))
# evaluation only: skip building the autograd graph
with torch.no_grad():
    # note we're always using ground truth input so this is only measuring single-step error, just like in training
    loss, stats_dict = model_class.step(model, loss_func, data, test_dataset, device, 0, mode='test', use_gt_p=1.0)
print(loss)



tensor(0.0024, device='cuda:0', grad_fn=<AddBackward0>)


In [5]:
# data is (batch_in, batch_out, meta); list the feature names in batch_in.
data[0].keys()

dict_keys(['pose_body', 'root_orient', 'root_orient_vel', 'trans', 'trans_vel', 'joints', 'joints_vel', 'contacts'])

In [6]:
# Printed below as [1, 10, 1, 189]; presumably (batch, frames, 1, 21 joints x 9 rotation-matrix entries) -- TODO confirm.
data[0]['pose_body'].shape

torch.Size([1, 10, 1, 189])

# Sampling a Sequence

In [50]:
# Grab the first batch and convert it into the model's input format.
batch_in, batch_out, meta = next(iter(test_loader))
x_past, _, gt_dict, input_dict, global_gt_dict = model.prepare_input(
                batch_in,
                device,
                data_out=batch_out,
                return_input_dict=True,
                return_global_dict=True,
            )

In [51]:
# Sample one step from the model, conditioned on the first input frame of the first
# sequence, and split the decoder output into named fields. split_output is called with
# convert_rots=True; later cells reshape the rotations as 3x3 matrices.
# For multi-step rollouts use model.roll_out (see eval_sampling).
with torch.no_grad():
    sample_out = model.sample_step(x_past[0, 0])
    decoder_out = sample_out['decoder_out']
    x_pred_dict = model.split_output(decoder_out, convert_rots=True)

In [52]:
from copycat.utils.transform_utils import (
    convert_aa_to_orth6d,
    convert_orth_6d_to_aa,
    vertizalize_smpl_root,
    rotation_matrix_to_angle_axis,
    rot6d_to_rotmat,
    convert_orth_6d_to_mat,
    angle_axis_to_rotation_matrix,
    angle_axis_to_quaternion,
)
from copycat.smpllib.smpl_parser import SMPL_Parser, SMPL_BONE_ORDER_NAMES, SMPLH_Parser

# NOTE(review): this rebinds rotation_matrix_to_angle_axis (imported earlier from
# humor.utils.transforms), so viz_eval_samp uses copycat's version after this cell runs.

device_cpu = torch.device("cpu")
smpl_data_dir = "/hdd/zen/dev/copycat/Copycat/data/smpl"  # machine-specific model location

# Male SMPL-H parser with full (non-PCA) hand pose, kept on the CPU.
smplh_p = SMPLH_Parser(smpl_data_dir, gender="male", use_pca=False, create_transl=False)
smplh_p.to(device_cpu)




SMPLH_Parser(
  Gender: MALE
  Number of joints: 52
  Betas: 16
  Flat hand mean: False
  (vertex_joint_selector): VertexJointSelector()
)

In [58]:
import pyvista as pv

# Assemble a 52-joint SMPL-H axis-angle pose from the sampled rotation matrices:
# root (1) + body (21) + zeroed hands (30).
pose_aa_body = rotation_matrix_to_angle_axis(x_pred_dict['pose_body'].reshape(21, 3, 3))
root_aa = rotation_matrix_to_angle_axis(x_pred_dict['root_orient'].reshape(1, 3, 3))
hand_aa = torch.zeros((30, 3)).to(device)
pose_aa = torch.cat([root_aa, pose_aa_body, hand_aa])


In [57]:
# Axis-angle pose of the *input* (previous) frame, drawn as the red reference mesh by the
# overlay cell below (which reads pose_aa_prev). Distinct names keep the sampled
# prediction's pose_aa_body from being mixed into this pose.
pose_aa_body_prev = rotation_matrix_to_angle_axis(input_dict['pose_body'][0, 0].reshape(21, 3, 3))
root_aa_prev = rotation_matrix_to_angle_axis(input_dict['root_orient'][0, 0].reshape(1, 3, 3))
pose_aa_prev = torch.cat([root_aa_prev, pose_aa_body_prev, torch.zeros((30, 3)).to(device)])


In [16]:
# Overlay the sampled pose (yellow) on the previous/input pose (red) above a ground plane.
with torch.no_grad():
    pose = pose_aa
    betas = torch.zeros((1, 16))  # zero shape coefficients
    verts, jts = smplh_p.get_joints_verts(pose.cpu(), betas.cpu())
    vertices = verts[0].numpy()

    verts_prev, jts = smplh_p.get_joints_verts(pose_aa_prev.cpu(), betas.cpu())
    vertices_prev = verts_prev[0].numpy()

    # PyVista takes a flat face array: [3, i0, i1, i2] for every triangle
    faces = np.hstack([np.concatenate([[3], tri]) for tri in smplh_p.faces])
    mesh, mesh_prev = (pv.PolyData(v, faces=faces) for v in (vertices, vertices_prev))

pl = pv.Plotter()
plane = pv.Plane(i_size=5, j_size=5, i_resolution=10, j_resolution=10)
for surface, color in ((mesh, 'yellow'), (mesh_prev, 'red'), (plane, 'white')):
    pl.add_mesh(surface, show_edges=True, color=color)
pl.show(jupyter_backend='pythreejs', cpos=[-1, 1, 0.5])

Renderer(camera=PerspectiveCamera(aspect=1.3333333333333333, children=(DirectionalLight(color='#fefefe', inten…

In [84]:
# Shape of the contacts output (printed below as [1, 50, 9]).
x_pred_dict['contacts'].shape

torch.Size([1, 50, 9])

# Testing

In [60]:
import joblib

# NOTE(review): joblib.load unpickles the file -- only load test.pkl from a trusted source.
test_data = joblib.load('test.pkl')
# acc_data is iterated as a sequence of dicts by the plotting cell; x_pred_dict is a single dict
acc_data = test_data[0]
x_pred_dict = test_data[1]
i = 0
def dict_to_data(x_pred_dict):
    '''
    Convert a dict of predicted rotation matrices into an SMPL-H axis-angle pose and translation.

    x_pred_dict : dict with 'pose_body' (21 rotation matrices per frame), 'root_orient'
                  (one rotation matrix per frame) and 'trans'; any leading batch/singleton
                  dims are flattened into a single frame dim.

    Returns (pose_aa, trans): pose_aa is (B, 52, 3) -- root + 21 body joints + 30 zeroed
    hand joints -- and trans is reshaped to (B, -1), where B is the number of frames.
    '''
    pose_body = x_pred_dict['pose_body']
    # Derive the frame count from numel rather than squeeze(): squeeze() also drops the
    # frame axis when there is a single frame, turning B into the feature size.
    B = pose_body.numel() // (21 * 9)
    pose_aa_body = rotation_matrix_to_angle_axis(pose_body.reshape(B * 21, 3, 3)).reshape(B, 21, 3)
    root_pose = rotation_matrix_to_angle_axis(x_pred_dict['root_orient'].reshape(B, 3, 3)).reshape(B, 1, 3)
    # zero hand pose on the same device/dtype as the body pose (no reliance on a global `device`)
    hand_pose = pose_aa_body.new_zeros((B, 30, 3))
    pose_aa = torch.cat([root_pose, pose_aa_body, hand_pose], dim=1)
    trans = x_pred_dict['trans'].reshape(B, -1)
    return pose_aa, trans





In [61]:
from collections import defaultdict

# PyVista takes a flat face array: [3, i0, i1, i2] for every triangle
faces = smplh_p.faces
faces = np.hstack([np.concatenate([[3], f]) for f in faces])

pl = pv.Plotter()
plane = pv.Plane(i_size=5, j_size=5, i_resolution=10, j_resolution=10)

# Rollout stored as x_pred_dict in test.pkl: one yellow mesh per frame.
pose_aa, trans = dict_to_data(x_pred_dict)
B = pose_aa.shape[0]
with torch.no_grad():
    pose = pose_aa
    betas = torch.zeros((B, 16))
    verts, jts = smplh_p.get_joints_verts(pose.cpu(), betas.cpu(), trans.cpu())
    for i in range(verts.shape[0]):
        vertices = verts[i].numpy()
        mesh = pv.PolyData(vertices, faces=faces)
        pl.add_mesh(mesh, show_edges=True, color='yellow')

# Stack the per-step dicts in acc_data into one tensor per key.
x_pred_dict_acc = defaultdict(list)
for data_entry in acc_data:
    for k, v in data_entry.items():
        x_pred_dict_acc[k].append(v.cpu().numpy())
x_pred_dict_acc = {k: torch.from_numpy(np.array(v)).to(device) for k, v in x_pred_dict_acc.items()}
pose_aa, trans = dict_to_data(x_pred_dict_acc)
# Betas must match this rollout's own frame count, which may differ from B above.
num_acc_frames = pose_aa.shape[0]

# Accumulated rollout: one red mesh per frame.
with torch.no_grad():
    pose = pose_aa
    betas = torch.zeros((num_acc_frames, 16))
    verts, jts = smplh_p.get_joints_verts(pose.cpu(), betas.cpu(), trans.cpu())
    for i in range(verts.shape[0]):
        vertices = verts[i].numpy()
        mesh = pv.PolyData(vertices, faces=faces)
        pl.add_mesh(mesh, show_edges=True, color='red')


pl.add_mesh(plane, show_edges=True, color='white')
pl.show(jupyter_backend='pythreejs', cpos=[-1, 1, 0.5])

Renderer(camera=PerspectiveCamera(aspect=1.3333333333333333, children=(DirectionalLight(color='#fefefe', inten…

In [62]:
from copycat.smpllib.smpl_parser import SMPL_Parser, SMPLH_Parser

smpl_data_dir = "/hdd/zen/dev/copycat/Copycat/data/smpl"  # machine-specific model location
# Gender-neutral SMPL and SMPL-H parsers (smpl_p is used in the SMPL comparison cell).
smpl_p = SMPL_Parser(smpl_data_dir, gender="neutral")
smpl_hp = SMPLH_Parser(smpl_data_dir, gender="neutral")



In [80]:
# Run the SMPL-H parser and the SMPL parser side by side. Note the naming: `verts` comes
# from smplh_p (SMPL-H) and `verts_h` from smpl_p (SMPL).
# NOTE(review): `pose`, `betas`, `trans` and `B` are left over from earlier cells, while
# pose_aa_body/root_pose below come from x_pred_dict -- confirm this mix is intended.
with torch.no_grad():
    # use B (as the reshapes below do) instead of a hard-coded frame count
    pose_aa_body = rotation_matrix_to_angle_axis(x_pred_dict['pose_body'].reshape(-1, 3, 3)).reshape(B, 21, 3)
    verts, jts = smplh_p.get_joints_verts(pose.cpu(), betas.cpu(), trans.cpu())
    root_pose = rotation_matrix_to_angle_axis(x_pred_dict['root_orient'].squeeze().reshape(B, 3, 3)).reshape(B, 1, 3)
    # 24-joint pose: root + 21 body joints + 2 zeroed joints
    pose_aa = torch.cat([root_pose, pose_aa_body, torch.zeros((B, 2, 3)).to(device)], dim=1)
    verts_h, jts_h = smpl_p.get_joints_verts(pose_aa.cpu(), betas.cpu(), trans.cpu())

In [81]:
pl = pv.Plotter()
plane = pv.Plane(i_size=5, j_size=5, i_resolution=10, j_resolution=10)

# First frame of each parser's output: smplh_p result in red, smpl_p result in yellow.
for frame_verts, color in ((verts[0], 'red'), (verts_h[0], 'yellow')):
    vertices = frame_verts.numpy()
    mesh = pv.PolyData(vertices, faces=faces)
    pl.add_mesh(mesh, show_edges=True, color=color)

pl.add_mesh(plane, show_edges=True, color='white')
pl.show(jupyter_backend='pythreejs', cpos=[-1, 1, 0.5])

Renderer(camera=PerspectiveCamera(aspect=1.3333333333333333, children=(DirectionalLight(color='#fefefe', inten…