In [1]:
import os
from config import get_opts
from models.ngp_wrapper import NGP_Prop_Art_Wrapper, NGP_Prop_Wrapper, NGP_Prop_Art_Seg_Wrapper
import torch
from torch.utils.data import DataLoader
from dataset.sapien import SapienParisDataset
from tqdm import tqdm
import open3d as o3d
import sys
import traceback
from pose_estimation import PoseEstimator
from test_ngp import NGPevaluator
from dataset.pose_utils import quaternion_to_axis_angle, get_rotation_axis_angle
from dataset.io_utils import load_gt_from_json
from models.utils import axis_metrics, geodesic_distance, translational_error
import numpy as np
import torchvision.transforms.functional as tvf

# Experiment configuration: parse the stapler segmentation config as if it
# had been passed on the command line.
argv_string = ["--config", "configs_seg/stapler_f16.json"]
opts = get_opts(argv_string)

rule = "=" * 100
print(rule)
print(f"running exp: {opts.exp_name}")
print(rule)

# Prefer the GPU when one is visible; the chosen device is stored on `opts`
# so that everything constructed from `opts` below sees it.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
opts.device = device

# Ground-truth motion parameters for the configured state / motion type.
# NOTE(review): `gt_info` is not referenced by any visible cell.
gt_info = load_gt_from_json(opts.motion_gt_json, opts.state, opts.motion_type)

# --- articulation model -------------------------------------------------
ignore_empty = False
if not opts.use_art_seg_estimator:
    model = NGP_Prop_Art_Wrapper(config=opts, training=True, ignore_empty=ignore_empty)
else:
    model = NGP_Prop_Art_Seg_Wrapper(
        config=opts,
        training=True,
        ignore_empty=ignore_empty,
        use_timestamp=True,
        use_se3=opts.use_se3,
    )

# NOTE(review): this path is never used -- `train_dataset` below is built with
# co_mask=None. Confirm whether the saved mask was meant to be passed in.
co_mask = 'stapler_art_mask.pth'

# --- datasets -----------------------------------------------------------
# Keyword arguments common to the train and val splits.
shared_dataset_kwargs = dict(
    root_dir=opts.root_dir,
    near=opts.near_plane,
    far=opts.far_plane,
    img_wh=opts.img_wh,
    batch_size=opts.batch_size,
    render_bkgd='white',
    state=opts.state,
)
train_dataset = SapienParisDataset(
    split='train',
    ignore_empty=ignore_empty,
    co_mask=None,
    **shared_dataset_kwargs,
)
test_dataset = SapienParisDataset(split='val', **shared_dataset_kwargs)

pose_lr = 1e-1  # NOTE(review): not referenced by any visible cell

# --- coarse pose estimator ----------------------------------------------
print(rule)
print('=' * 40, 'loading coarse pose estimator', '=' * 40)
print(rule)

# The coarse (non-articulated) NGP model is built from the same `opts`,
# not from a separate `opts.pretrained_config`.
pretran_model = NGP_Prop_Wrapper(opts, training=False)
renderer = NGPevaluator(opts, dataset=train_dataset, model=pretran_model)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
running exp: stapler_end_to_start_f16


In [2]:
# Pose estimator driven by the coarse renderer; results are written to the
# articulation model's eval directory.
# NOTE(review): `opts.use_num_frames` is passed both as `use_num_frames` and
# as `select_frame` -- confirm that is intended.
estimator = PoseEstimator(
    renderer=renderer,
    dataset=train_dataset,
    output_dir=model.eval_path,
    use_num_frames=opts.use_num_frames,
    select_frame=opts.use_num_frames,
    device=device,
    scaling=0.5,
    use_se3=opts.use_se3,
    motion_type=opts.motion_type,
    idx_list=None,
    eps=opts.eps,
)

------------------------------------------
collecting points for parts
------------------------------------------


100%|██████████| 16/16 [00:03<00:00,  4.45it/s]


training with frames: tensor([ 5,  6,  2,  8, 10,  7,  4, 15,  3, 13, 11, 12,  0, 14,  1,  9],
       device='cuda:0')
------------------------------------------
collecting points for parts again after frame selection
------------------------------------------
-----------removing outliers from nerf scan-----------


100%|██████████| 16/16 [00:03<00:00,  5.32it/s]
100%|██████████| 16/16 [00:02<00:00,  5.36it/s]


In [3]:
# Inspect the first co-mask collected by the estimator.
co_masks = estimator.co_mask_list
co_masks[0]

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]], device='cuda:0')

In [4]:
# Project the 3D points for index 0 to 2D (with vis=True) and unpack the four
# returned point sets.
projections = estimator.proj_3d_to_2d(0, vis=True)
final_2d_pts, tgt_pts, dy_pts, co_pts = projections

In [10]:
def _points_to_pil(points):
    """Rasterize 2D points with `estimator.pts_2d_to_img` and return a PIL image.

    All three point sets go through the same conversion. `tvf.to_pil_image`
    already adds the channel axis to a 2-D tensor itself, so no manual
    `unsqueeze(0)` is applied (on a 3-D CHW result it would produce a 4-D
    tensor that `to_pil_image` rejects).
    """
    return tvf.to_pil_image(estimator.pts_2d_to_img(points))


dy_img = _points_to_pil(dy_pts)
co_img = _points_to_pil(co_pts)
tgt_img = _points_to_pil(tgt_pts)

In [12]:
# Nothing else in the notebook creates `vis/`; make sure it exists so the
# saves below do not fail with FileNotFoundError on a fresh checkout.
os.makedirs('vis', exist_ok=True)
dy_img.save('vis/dy_img.png')
co_img.save('vis/co_img.png')
tgt_img.save('vis/tgt_img.png')

In [13]:
# Shape of the stored RGB buffer (torch.Size) for the estimator's dataset.
estimator.dataset.rgb.size()

torch.Size([100, 640000, 3])

In [17]:
# Rebuild the first selected frame's flattened RGB pixels as a CHW tensor,
# using the configured resolution instead of a hard-coded 800x800.
# NOTE(review): assumes `opts.img_wh` is (width, height) and that pixels are
# stored row-major per frame -- confirm.
img_w, img_h = opts.img_wh
assert img_w * img_h == estimator.dataset.rgb.shape[1], "opts.img_wh does not match dataset.rgb"
rgb0 = estimator.dataset.rgb[estimator.idx_list[0]].reshape(img_h, img_w, 3).permute(2, 0, 1)

In [18]:
# Convert the CHW RGB tensor to a PIL image for saving.
rgb0_img = tvf.to_pil_image(pic=rgb0)

In [19]:
# Ensure the output directory exists so this cell does not depend on an
# earlier cell (or a manual mkdir) having created it.
os.makedirs('vis', exist_ok=True)
rgb0_img.save('vis/tgt_rgb.png')