In [36]:
import os
from pathlib import Path as P

import torch
import torch.nn.functional as F

from dataset.io_utils import load_multipart_gt
from dataset.pose_utils import quaternion_to_axis_angle, get_rotation_axis_angle
from models.utils import axis_metrics, geodesic_distance, translational_error
# --- Evaluation configuration -------------------------------------------------
# Ground-truth articulation file (SAPIEN fridge_mp, object 10612, per the path).
gt_path = 'data_paris/sapien/fridge_mp/10612/textured_objs/trans.json'
# Motion type forwarded to load_multipart_gt. The evaluation cell takes the
# predicted axis from the quaternion when this is 'r' and from the predicted
# 'dir' parameter otherwise (presumably revolute vs. prismatic — confirm).
motion_type = 'r'
# Which ground-truth articulation state to load ('end' here).
state = 'end'
gt_infos = load_multipart_gt(gt_path, state=state, motion_type=motion_type)
# Folder holding one sub-directory per training run (an 'eval' folder is also
# expected and skipped when searching for checkpoints). Set PARIS_RESULTS_ROOT
# to point at another machine's results instead of editing this default.
root_path = P(os.environ.get(
    'PARIS_RESULTS_ROOT',
    '/home/dj/Downloads/project/nerfacc_ngp/results_stable/fridge_mp_start_to_end_f16',
))

In [31]:
# Load `best_ckpt.pth` from the first run folder under root_path that has one.
# Folders are visited in sorted order so the pick does not depend on the
# filesystem's listing order; the 'eval' folder and plain files are skipped.
ckpt_dict = None  # reset so a checkpoint left over in the kernel cannot leak through
for exp_path in sorted(root_path.glob('*')):
    if exp_path.name == 'eval' or not exp_path.is_dir():
        continue
    ckpt_fname = exp_path / 'ckpt' / 'best_ckpt.pth'
    if not ckpt_fname.is_file():
        # A missing checkpoint is expected for some runs; anything else that
        # goes wrong while loading (corrupt file, device error) should surface.
        print(f'ckpt file does not exist in folder {ckpt_fname}')
        continue
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    ckpt_dict = torch.load(str(ckpt_fname))
    break

if ckpt_dict is None:
    print(f'no best_ckpt.pth found under {root_path}')

In [32]:

# Evaluate one explicitly chosen run/step. This overrides whatever checkpoint
# the previous cell selected.
RUN_ID = '1713347432'  # run folder name under root_path
CKPT_STEP = 10_000     # checkpoint file name appears to be the zero-padded step
ckpt_dict = torch.load(str(root_path / RUN_ID / 'ckpt' / f'{CKPT_STEP:06d}.pth'))
ckpt_dict.keys()

dict_keys(['estimator', 'model', 'optimizer', 'prop_networks', 'prop_optimizer', 'pose_params'])

In [37]:
# Compare every predicted part (entries of ckpt_dict['pose_params']) against
# every ground-truth part in gt_infos. Each predicted part yields one list of
# metric dicts (one per GT part); those lists are also kept in all_eval_dicts.
all_eval_dicts = []
for p in ckpt_dict['pose_params']:
    # Third return value of quaternion_to_axis_angle is not needed here.
    axis_dir, radian, _ = quaternion_to_axis_angle(p['Q'])
    # detach() before .numpy(): .numpy() raises on tensors that require grad.
    pred_R = get_rotation_axis_angle(
        axis_dir.detach().cpu().numpy(), radian.detach().cpu().numpy()
    )
    pred_R_t = torch.Tensor(pred_R)

    # Axis direction comes from the quaternion for motion_type 'r', otherwise
    # from the part's separately learned 'dir' parameter.
    if motion_type == 'r':
        pred_axis_dir = axis_dir.detach()
    else:
        pred_axis_dir = p['dir'].detach()

    pred_info = {
        "axis_o": p['axis_origin'].detach(),
        "axis_d": F.normalize(pred_axis_dir.view(1, -1)).view(-1),
        "R": pred_R_t,
        "theta": radian.detach(),
        "dist": p['scale'].detach(),
    }
    print(pred_info)

    eval_dicts = []
    for gt_info in gt_infos:
        ang_err, pos_err = axis_metrics(pred_info, gt_info)
        trans_err = translational_error(pred_info, gt_info)
        geo_dist = geodesic_distance(pred_R_t, gt_info['R'])
        eval_dicts.append({
            "ang_err": ang_err.item(),
            "pos_err": pos_err.item(),
            "geo_dist": geo_dist.item(),
            "trans_err": trans_err.item(),
        })
    print(eval_dicts)
    all_eval_dicts.append(eval_dicts)

{'axis_o': tensor([ 0.4008, -0.3101,  0.0799], device='cuda:0'), 'axis_d': tensor([0.0106, 0.0020, 0.9999], device='cuda:0'), 'R': tensor([[ 8.2234e-01, -5.6899e-01,  3.0297e-03],
        [ 5.6900e-01,  8.2232e-01, -5.6738e-03],
        [ 7.3695e-04,  6.3897e-03,  9.9998e-01]]), 'theta': tensor([0.6053], device='cuda:0'), 'dist': tensor([0.], device='cuda:0')}
[{'ang_err': 0.6180230975151062, 'pos_err': 0.0006404438754543662, 'geo_dist': 0.4877876341342926, 'trans_err': 0.0}, {'ang_err': 0.6180230975151062, 'pos_err': 0.16591235995292664, 'geo_dist': 79.6811752319336, 'trans_err': 0.0}]
{'axis_o': tensor([-0.4001, -0.2844,  0.0864], device='cuda:0'), 'axis_d': tensor([ 0.0033, -0.0083, -1.0000], device='cuda:0'), 'R': tensor([[ 7.2711e-01,  6.8649e-01, -6.6276e-03],
        [-6.8651e-01,  7.2712e-01,  2.5808e-05],
        [ 4.8368e-03,  4.5312e-03,  9.9998e-01]]), 'theta': tensor([0.7567], device='cuda:0'), 'dist': tensor([0.], device='cuda:0')}
[{'ang_err': 0.5139620304107666, 'pos_er

In [29]:
# Inspect the ground truth loaded by load_multipart_gt in the config cell —
# judging by the displayed output, a list of per-part dicts with keys
# 'axis_o', 'axis_d', 'R', 'theta', 'dist' — to compare with the predictions.
gt_infos

[{'axis_o': tensor([ 0.4272, -0.3057,  0.8436]),
  'axis_d': tensor([0., 0., 1.]),
  'R': tensor([[ 0.8192, -0.5736,  0.0000],
          [ 0.5736,  0.8192,  0.0000],
          [ 0.0000,  0.0000,  1.0000]]),
  'theta': 0.6108652381980153,
  'dist': tensor([0.])},
 {'axis_o': tensor([-0.4028, -0.2939,  0.8677]),
  'axis_d': tensor([0., 0., 1.]),
  'R': tensor([[ 0.7071,  0.7071,  0.0000],
          [-0.7071,  0.7071,  0.0000],
          [ 0.0000,  0.0000,  1.0000]]),
  'theta': -0.7853981633974483,
  'dist': tensor([0.])}]