In [3]:
import torch
import os
import pickle as pkl
import trimesh
import numpy as np

from pixtrack.utils.pose_utils import geodesic_distance_for_rotations


In [4]:
def get_pose_mat_from_tensor(pose_tensor):
    translation = pose_tensor.t.cpu().numpy()
    rotation = pose_tensor.R.cpu().numpy()
    mesh_pose_in_cam = np.eye(4)
    mesh_pose_in_cam[:3, :3] = rotation
    mesh_pose_in_cam[:3, -1] = translation
    return mesh_pose_in_cam

In [5]:
def similarity_transform(from_points, to_points):
    
    assert len(from_points.shape) == 2, \
        "from_points must be a m x n array"
    assert from_points.shape == to_points.shape, \
        "from_points and to_points must have the same shape"
    
    N, m = from_points.shape
    
    mean_from = from_points.mean(axis = 0)
    mean_to = to_points.mean(axis = 0)
    
    delta_from = from_points - mean_from # N x m
    delta_to = to_points - mean_to       # N x m
    
    sigma_from = (delta_from * delta_from).sum(axis = 1).mean()
    sigma_to = (delta_to * delta_to).sum(axis = 1).mean()
    
    cov_matrix = delta_to.T.dot(delta_from) / N
    
    U, d, V_t = np.linalg.svd(cov_matrix, full_matrices = True)
    cov_rank = np.linalg.matrix_rank(cov_matrix)
    S = np.eye(m)
    
    if cov_rank >= m - 1 and np.linalg.det(cov_matrix) < 0:
        S[m-1, m-1] = -1
    elif cov_rank < m-1:
        raise ValueError("colinearility detected in covariance matrix:\n{}".format(cov_matrix))
    
    R = U.dot(S).dot(V_t)
    c = (d * S.diagonal()).sum() / sigma_from
    t = mean_to - c*R.dot(mean_from)
    
    return R, c, t

In [6]:
def get_pose_offset(poses_file):
    from_trs = []
    to_trs = []
    for image_key in poses_file:
        if (not poses_file[image_key]["success"]):
            continue
        to_trs.append(poses_file[image_key]["T_refined"].t.cpu().numpy())
        from_trs.append(poses_file[image_key]["gt_pose"].t.cpu().numpy())
    R, c, t = similarity_transform(np.array(from_trs), np.array(to_trs))
    pose_from_res_to_gt = np.eye(4)
    pose_from_res_to_gt[:3, :3] = R
    pose_from_res_to_gt[:3, -1] = t
    return pose_from_res_to_gt

In [7]:
from scipy.spatial.transform import Rotation as R
from pytorch3d.loss import chamfer_distance

In [139]:
mesh_path = "/mnt/remote/data/prajwal/YCB_Video_Dataset/models/035_power_drill/textured.obj"

object_mesh = trimesh.load(mesh_path)
vertices = np.array(object_mesh.vertices)
vertices = np.hstack((vertices, np.ones((vertices.shape[0], 1))))

In [165]:
result_folder = "/mnt/remote/data/prajwal/pixtrack/results/021_bleach_cleanser/"
poses_file = os.path.join(result_folder, "poses.pkl")
with open(poses_file, "rb") as f:
    poses_file = pkl.load(f)

In [166]:
def get_metrics(poses_file, tr_threshold, rot_threshold):
    distances = []
    add_ss = []
    pose_dists = []

    pose_from_res_to_gt = get_pose_offset(poses_file)
    bad_count = 0

    for image_key in poses_file:
        if (not poses_file[image_key]["success"]):
            #print(f"skipped {image_key}")
            continue
        res_pose_mat = get_pose_mat_from_tensor(poses_file[image_key]["T_refined"])
        gt_pose_mat = get_pose_mat_from_tensor(poses_file[image_key]["gt_pose"])
        aligned_res_pose = np.dot(pose_from_res_to_gt, res_pose_mat)
        tr_dist = np.linalg.norm(gt_pose_mat[:3, -1] - aligned_res_pose[:3, -1]) * 100
        rot_dist = geodesic_distance_for_rotations(gt_pose_mat[:3, :3], aligned_res_pose[:3, :3]) * 180 / np.pi
        res_vertices = np.dot(pose_from_res_to_gt, np.dot(res_pose_mat, vertices.T)).T[:, :3] * 100
        gt_vertices = np.dot(gt_pose_mat, vertices.T).T[:, :3] * 100


        l2_distances = np.linalg.norm(gt_vertices - res_vertices, axis=1)
        pose_dists.append(tr_dist)
        l2_dist = np.mean(l2_distances)
        if tr_dist > tr_threshold or rot_dist > rot_threshold:
            bad_count += 1
            
        distances.append(l2_dist)
        
        
    results = {}
    results["average_error_vertices"] = np.mean(distances)
    results["max_error"] = np.max(distances)
    results["max_translation_error"] = np.max(pose_dists)
    results["average_translation_error_pose"] = np.mean(pose_dists)
    results["bad_count"] = bad_count
    results["total_frames"] = len(poses_file)
    results["accuracy"] = (1.0*(len(poses_file) - bad_count))/(1.0*len(poses_file) )
    return results

    

In [167]:
res1 = get_metrics(poses_file, tr_threshold=5, rot_threshold=5)
res2 = get_metrics(poses_file, tr_threshold=3, rot_threshold=3)



In [168]:
print(f"thresh_1: {res1} \n thresh_2:{res2}")

thresh_1: {'average_error_vertices': 1.3623460859796348, 'max_error': 8.831234976584756, 'max_translation_error': 8.82980144481684, 'average_translation_error_pose': 1.3492508181637395, 'bad_count': 3, 'total_frames': 1590, 'accuracy': 0.9981132075471698} 
 thresh_2:{'average_error_vertices': 1.3623460859796348, 'max_error': 8.831234976584756, 'max_translation_error': 8.82980144481684, 'average_translation_error_pose': 1.3492508181637395, 'bad_count': 164, 'total_frames': 1590, 'accuracy': 0.8968553459119497}


In [164]:
(2.909139344971845 + 9.88552130601735 + 2.504790297611688 + 3.699731260500495)/4

4.749795552275344

In [160]:
!ls /mnt/remote/data/prajwal/pixtrack/results/021_bleach_cleanser/