In [1]:
import torch
import cv2
from models.ngp_wrapper import NGP_Prop_Art_Seg_Wrapper
from dataset.pose_utils import quaternion_to_axis_angle, get_quaternion_axis_angle
from config import get_opts
from dataset.sapien import SapienParisDataset
import math
from pathlib import Path as P
from dataset.io_utils import load_gt_from_json
from tqdm import tqdm
import numpy as np


from PIL import Image
def draw_axis(pil_img, axis_info, c2w, K, thickness=8):
    '''
    info_type
        options: gt or pred
    '''
    cv_img = np.array(pil_img)
    # cv_img = cv2.cvtColor(cv_img, cv2.COLOR_RGB2BGR)
    pred_info = axis_info['pred']
    gt_info = axis_info['gt']
    pred_pix = proj_axis(pred_info, c2w, K).round().detach().cpu().numpy().astype(np.int16)
    cv2.arrowedLine(cv_img, pred_pix[0], pred_pix[1], thickness=thickness, color=(255, 0, 0))
    
    gt_pix = proj_axis(gt_info, c2w, K).round().detach().cpu().numpy().astype(np.int16)
    cv2.arrowedLine(cv_img, gt_pix[0], gt_pix[1], thickness=thickness, color=(0, 255, 0))
    
    return Image.fromarray(cv_img)

@torch.inference_mode()
def proj_axis(axis_info, c2w, K):
    '''
    return [2, 2] in pixel coordinate
    '''
    # Add a dimension of ones to the point cloud to make it homogeneous
    ones = torch.ones((axis_info.shape[0], 1), device=axis_info.device)
    homogeneous_point_cloud = torch.cat((axis_info, ones), dim=1)
    
    # Transform the point cloud from world coordinates to camera coordinates
    points_in_camera_coordinates = torch.inverse(c2w.to(axis_info)) @ homogeneous_point_cloud.t()
    # Normalize the coordinates
    points_in_camera_coordinates /= points_in_camera_coordinates[3, :].clone()
    
    # Project the points onto the 2D plane using the intrinsic matrix
    projected_points = K.to(axis_info) @ points_in_camera_coordinates[:3, :]
    
    # Normalize the projected points
    projected_points /= projected_points[2, :].clone()
    
    return projected_points[:2, :].T



Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# eval_args
fridge_dict = {
    "ckpt_fname": "results_ablation/fridge_end_to_start_f64/1714724128/ckpt/best_ckpt.pth",
    "cfg_file": "results_ablation/fridge_end_to_start_f64/1714724128/eval/config.json",
    "test_path": "vis/fridge",
    "fix_view": 33,
    "num_step": 4
}

stapler_dict = {
    "ckpt_fname": "results_ablation/stapler_end_to_start_f100/1714808215/ckpt/best_ckpt.pth",
    "cfg_file": "results_ablation/stapler_end_to_start_f100/1714808215/eval/config.json",
    "test_path": "vis/stapler",
    "fix_view": 33,
    "num_step": 4
}

oven_dict = {
    "ckpt_fname": "results_ablation/oven_start_to_end_f100/1714874463/ckpt/best_ckpt.pth",
    "cfg_file": "results_ablation/oven_start_to_end_f100/1714874463/eval/config.json",
    "test_path": "vis/oven",
    "fix_view": 33,
    "num_step": 4
}
blade_dict = {
    "ckpt_fname": "results_ablation/blade_start_to_end_f100/1714857261/ckpt/best_ckpt.pth",
    "cfg_file": "results_ablation/blade_start_to_end_f100/1714857261/eval/config.json",
    "test_path": "vis/blade",
    "fix_view": 33,
    "num_step": 4
}
storage_dict = {
    "ckpt_fname": "results_ablation/storage_end_to_start_f100/1714855450/ckpt/best_ckpt.pth",
    "cfg_file": "results_ablation/storage_end_to_start_f100/1714855450/eval/config.json",
    "test_path": "vis/storage",
    "fix_view": 33,
    "num_step": 4
}

laptop_dict = {
    "ckpt_fname": "results_ablation/laptop_start_to_end_f100/1714857261/ckpt/best_ckpt.pth",
    "cfg_file": "results_ablation/laptop_start_to_end_f100/1714857261/eval/config.json",
    "test_path": "vis/laptop",
    "fix_view": 33,
    "num_step": 4
}


In [3]:


def save_vis(ckpt_fname, cfg_file, test_path, num_step=30, fix_view=None):

    test_path = P(test_path)


    test_path.mkdir(exist_ok=True, parents=True)
    opts = get_opts(['--config', cfg_file])
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
        
    setattr(opts, 'device', device)
    opts.pre_trained_weights = None
    # num_step = 30

    model = NGP_Prop_Art_Seg_Wrapper(config=opts, training=False, 
                                    ignore_empty=False, use_timestamp=False, use_se3=opts.use_se3)



    gt_info = load_gt_from_json(opts.motion_gt_json, opts.state, opts.motion_type)

    test_dataset = SapienParisDataset(
            root_dir = opts.root_dir,
            near = opts.near_plane,
            far = opts.far_plane,
            img_wh = opts.img_wh, 
            batch_size=opts.batch_size,
            split='test',
            render_bkgd='white',
            state=opts.state
        )

    model.load_ckpt(ckpt_fname)
    motion_list = []
    axis_list = []
    
    vis_scale = 1
    if 'fridge' in ckpt_fname:
        vis_scale = 0.5
    if opts.motion_type == 'r':
        # gen_art_pose
        for pose_param in model.pose_module_list:
            pose_param.norm_Q()
            cur_quat = pose_param.Q
            axis_d, angle, deg = quaternion_to_axis_angle(cur_quat)
            angle_detach = angle.cpu().detach()
            # print(torch.arange(0, num_step+1) / (num_step))
            if angle > 0:
                cur_list = torch.arange(0, num_step+1) / (num_step) * angle_detach
            else:
                cur_list = angle_detach - torch.arange(0, num_step+1) / (num_step) * angle_detach
            half_q_list = [get_quaternion_axis_angle(axis_d.cpu().detach().numpy(), cur_q) for cur_q in cur_list]
            motion_list += [half_q_list]
            axis_o = pose_param.axis_origin
            sim = torch.nn.functional.cosine_similarity(axis_o.view(1, -1), axis_d.view(1, -1), dim=1)
            
            cur_axis = torch.stack([axis_o -sim*axis_d, axis_o + (1-vis_scale*sim)*axis_d])
            axis_list += [cur_axis]
    else:
        for pose_param in model.pose_module_list:
            # print(pose_param)
            pose_param.norm_dir()
            scale = pose_param.scale.cpu().detach()
            axis_d = pose_param.dir
            if scale > 0:
                # cur_list = torch.arange(0, scale, step=scale/num_step)
                cur_list = torch.arange(0, num_step+1) / (num_step ) * scale
            else:
                cur_list = scale - torch.arange(0, num_step+1) / (num_step ) * scale
            # print(f'cur_list: {cur_list}')
            motion_list += [cur_list]
            # print(f'motion list: {motion_list}')
            cur_axis = torch.stack([torch.zeros_like(axis_d), axis_d])
            axis_list += [cur_axis]

    # generate camera pose
    theta_list = torch.arange(0, 2*math.pi, step=math.pi/20)
    phi = torch.zeros_like(theta_list) * 0.25 * math.pi

    K = test_dataset.K
    pred_axis_info = axis_list[0]
    print(vis_scale)
    gt_axis_info = torch.stack([gt_info['axis_o'], gt_info['axis_o'] + vis_scale * gt_info['axis_d']])
    axis_info = {
        'pred': pred_axis_info,
        'gt': gt_axis_info
    }

    model.load_ckpt(ckpt_fname)
    # test_path = P(ckpt_fname).parent.parent / 'test'
    print(axis_info)
    static_path = test_path / 'target'
    static_path.mkdir(exist_ok=True)
    psnrs = []
    gt_img_list = []
    pred_img_list = []
    for i in tqdm(range(len(test_dataset))):
        test_data = test_dataset.__getitem__(i)
        eval_dict = model.eval(test_data)
        psnrs += [eval_dict['psnr'].item()]
        img_gt = eval_dict['img_gt']
        img_pred = eval_dict['img_pred']
        
        img_gt.save(static_path / f'img_gt_{i:04d}.png')
        img_pred.save(static_path / f'img_pred_{i:04d}.png')
        # print(axis_info)
        img_arrow = draw_axis(img_pred, axis_info, test_data['c2w'].squeeze(0), K)
        img_arrow.save(static_path / f'img_pred_arrow_{i:04d}.png')
        gt_img_list += [img_gt]
        pred_img_list += [img_arrow]
    avg_psnr = sum(psnrs) / len(psnrs)
    print(avg_psnr)

    art_path = test_path / 'art'
    art_path.mkdir(exist_ok=True)
    total_art_part = len(model.pose_module_list)
    art_img_list = []
    dataset_len = len(test_dataset)
    # print(pose_param.scale)
    # print(motion_list[0])
    # print(axis_info)
    for i in tqdm(range(len(motion_list[0]))):
        if fix_view is None:
            test_data = test_dataset.__getitem__(i % dataset_len)
        else:
            test_data = test_dataset.__getitem__(fix_view)
        
        if opts.motion_type == 'r':
            for p in range(total_art_part):
                cur_Q = torch.Tensor(motion_list[p][i])
                cur_pose_param = model.pose_module_list[p]
                cur_Q = cur_Q.to(cur_pose_param.Q)
                cur_pose_param.Q = torch.nn.Parameter(cur_Q)
        else:
            for p in range(total_art_part):
                cur_scale = torch.Tensor(motion_list[p][i])
                cur_pose_param = model.pose_module_list[p]
                cur_scale = cur_scale.to(cur_pose_param.scale)
                cur_pose_param.scale = torch.nn.Parameter(cur_scale)
        
        eval_dict = model.eval(test_data)
        img_pred = eval_dict['img_pred']
        # print(axis_info)
        img_arrow = draw_axis(img_pred, axis_info, test_data['c2w'].squeeze(0), K, thickness=10)
        img_arrow.save(art_path / f'img_pred_art_{i:04d}.png')
        img_pred.save(art_path / f'img_pred_art_clean_{i:04d}.png')
        art_img_list += [img_arrow]
        pass

    # gif_path = test_path / 'gif'
    # gif_path.mkdir(exist_ok=True)
    # gt_img_list[0].save(str(gif_path / 'gt.gif'), save_all=True, append_images=gt_img_list[1:], duration=5, optimize=False, loop=0)
    # pred_img_list[0].save(str(gif_path / 'pred_target.gif'), save_all=True, append_images=pred_img_list[1:], duration=5, optimize=False, loop=0)
    # art_img_list[0].save(str(gif_path / 'pred_art.gif'), save_all=True, append_images=art_img_list[1:], duration=5, optimize=False, loop=0)

In [74]:
save_vis(**laptop_dict)

1
{'pred': tensor([[-0.2553,  0.3327, -0.0066],
        [-0.2592,  1.3327, -0.0077]], device='cuda:0',
       grad_fn=<StackBackward0>), 'gt': tensor([[-0.2471,  0.0000, -0.0077],
        [-0.2471,  1.0000, -0.0077]])}


100%|██████████| 5/5 [00:11<00:00,  2.33s/it]


In [75]:
save_vis(**stapler_dict)

1
{'pred': tensor([[-0.7505,  0.0731,  0.1030],
        [-0.7568, -0.9268,  0.1013]], device='cuda:0',
       grad_fn=<StackBackward0>), 'gt': tensor([[-0.7521,  0.0000,  0.1050],
        [-0.7521,  1.0000,  0.1050]])}


100%|██████████| 5/5 [00:11<00:00,  2.25s/it]


In [4]:
save_vis(**fridge_dict)

0.5
{'pred': tensor([[ 0.1664,  0.2306, -0.1386],
        [ 0.1669,  0.2342, -1.0403]], device='cuda:0',
       grad_fn=<StackBackward0>), 'gt': tensor([[0.1669, 0.2269, 0.4569],
        [0.1669, 0.2269, 0.9569]])}


100%|██████████| 50/50 [02:01<00:00,  2.43s/it]


36.83317230224609


100%|██████████| 5/5 [00:11<00:00,  2.37s/it]


In [77]:
save_vis(**oven_dict)

1
{'pred': tensor([[ 0.4048,  0.0576, -0.4440],
        [ 0.4025,  1.0576, -0.4497]], device='cuda:0',
       grad_fn=<StackBackward0>), 'gt': tensor([[ 0.4194, -0.6382, -0.4423],
        [ 0.4194,  0.3618, -0.4423]])}


100%|██████████| 5/5 [00:11<00:00,  2.33s/it]


In [78]:
save_vis(**storage_dict)

1
{'pred': tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.9999,  0.0047, -0.0143]], device='cuda:0',
       grad_fn=<StackBackward0>), 'gt': tensor([[ 0.3823, -0.0146,  0.1356],
        [ 1.3823, -0.0146,  0.1356]])}


100%|██████████| 5/5 [00:11<00:00,  2.34s/it]


In [79]:
save_vis(**blade_dict)

1
{'pred': tensor([[ 0.0000,  0.0000,  0.0000],
        [ 0.9990,  0.0406, -0.0183]], device='cuda:0',
       grad_fn=<StackBackward0>), 'gt': tensor([[0., 0., 0.],
        [1., 0., 0.]])}


100%|██████████| 5/5 [00:11<00:00,  2.21s/it]


In [None]:

def save_vis_multipart(ckpt_fname, cfg_file, test_path, num_step=30, fix_view=None):

    test_path = P(test_path)


    test_path.mkdir(exist_ok=True)
    opts = get_opts(['--config', cfg_file])
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
        
    setattr(opts, 'device', device)
    opts.pre_trained_weights = None
    # num_step = 30

    model = NGP_Prop_Art_Seg_Wrapper(config=opts, training=False, 
                                    ignore_empty=False, use_timestamp=False, use_se3=opts.use_se3)



    gt_info = load_gt_from_json(opts.motion_gt_json, opts.state, opts.motion_type)

    test_dataset = SapienParisDataset(
            root_dir = opts.root_dir,
            near = opts.near_plane,
            far = opts.far_plane,
            img_wh = opts.img_wh, 
            batch_size=opts.batch_size,
            split='test',
            render_bkgd='white',
            state=opts.state
        )

    model.load_ckpt(ckpt_fname)
    motion_list = []
    axis_list = []
    if opts.motion_type == 'r':
        # gen_art_pose
        for pose_param in model.pose_module_list:
            # motion_list = []
            # axis_list = []
            pose_param.norm_Q()
            cur_quat = pose_param.Q
            axis_d, angle, deg = quaternion_to_axis_angle(cur_quat)
            print(f'angle = {angle}')
            if angle > 0:
                cur_list = torch.arange(0, angle.item(), step=angle.item()/num_step)
            else:
                cur_list = torch.arange(angle.item(), 0, step=angle.item()/num_step)
            half_q_list = [get_quaternion_axis_angle(axis_d.cpu().detach().numpy(), cur_q) for cur_q in cur_list]
            cur_q_list = half_q_list + half_q_list[::-1]
            motion_list += [cur_q_list]
            axis_o = pose_param.axis_origin
            cur_axis = torch.stack([axis_o, axis_o + axis_d])
            axis_list += [cur_axis]
            # motion_lists += [motion_list]
            # axis_lists += [axis_list]
    else:
        for pose_param in model.pose_module_list:
            pose_param.norm_dir()
            scale = pose_param.scale()
            axis_d = pose_param.dir()
            if scale > 0:
                cur_list = torch.arange(0, scale, step=scale/num_step)
            else:
                cur_list = torch.arange(scale, 0, step=scale/num_step)
            motion_list += [axis_d * cur_scale for cur_scale in cur_list]
            cur_axis = torch.stack([torch.zeros_like(axis_d), axis_d])
            axis_list += [cur_axis]

    # generate camera pose
    theta_list = torch.arange(0, 2*math.pi, step=math.pi/20)
    phi = torch.zeros_like(theta_list) * 0.25 * math.pi

    K = test_dataset.K
    pred_axis_info = axis_list[0]
    gt_axis_info = torch.stack([gt_info['axis_o'], gt_info['axis_o'] + gt_info['axis_d']])
    axis_info = {
        'pred': pred_axis_info,
        'gt': gt_axis_info
    }

    model.load_ckpt(ckpt_fname)
    # test_path = P(ckpt_fname).parent.parent / 'test'

    static_path = test_path / 'target'
    static_path.mkdir(exist_ok=True)
    psnrs = []
    gt_img_list = []
    pred_img_list = []
    for i in tqdm(range(len(test_dataset))):
        test_data = test_dataset.__getitem__(i)
        eval_dict = model.eval(test_data)
        psnrs += [eval_dict['psnr'].item()]
        img_gt = eval_dict['img_gt']
        img_pred = eval_dict['img_pred']
        
        img_gt.save(static_path / f'img_gt_{i:04d}.png')
        img_pred.save(static_path / f'img_pred_{i:04d}.png')
        
        img_arrow = draw_axis(img_pred, axis_info, test_data['c2w'].squeeze(0), K)
        img_arrow.save(static_path / f'img_pred_arrow_{i:04d}.png')
        gt_img_list += [img_gt]
        pred_img_list += [img_arrow]
    avg_psnr = sum(psnrs) / len(psnrs)
    print(avg_psnr)

    art_path = test_path / 'art'
    art_path.mkdir(exist_ok=True)
    total_art_part = len(model.pose_module_list)
    art_img_list = []
    dataset_len = len(test_dataset)
    for i in tqdm(range(len(motion_list[0]))):
        if fix_view is None:
            test_data = test_dataset.__getitem__(i % dataset_len)
        else:
            test_data = test_dataset.__getitem__(fix_view)
        
        for p in range(total_art_part):
            cur_Q = torch.Tensor(motion_list[p][i])
            cur_pose_param = model.pose_module_list[p]
            cur_Q = cur_Q.to(cur_pose_param.Q)
            cur_pose_param.Q = torch.nn.Parameter(cur_Q)
        
        eval_dict = model.eval(test_data)
        img_pred = eval_dict['img_pred']
        img_arrow = draw_axis(img_pred, axis_info, test_data['c2w'].squeeze(0), K)
        img_arrow.save(art_path / f'img_pred_art_{i:04d}.png')
        art_img_list += [img_arrow]
        pass
