In [1]:
# %matplotlib widget
# Environment setup. CUDA and PyOpenGL read these variables when they first
# initialize, so they are assigned before the project imports below (which
# presumably import torch / the renderer). Keep this ordering.
import os
# Number GPUs by PCI bus ID so the index below matches `nvidia-smi`.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Expose only physical GPU 4 to this process (machine-specific choice); inside
# the process it is addressed as the first visible CUDA device.
os.environ["CUDA_VISIBLE_DEVICES"] = "4"
# Select the EGL backend of PyOpenGL (headless OpenGL, no display needed).
os.environ['PYOPENGL_PLATFORM'] = 'egl'

# NOTE(review): the star imports below presumably supply torch, np, h5py and the
# project names used in later cells (PIFODataset, KeyPoint_Feature, Frame_vec,
# Pose_Feature_vec, RandomImageWarper, to_device, random_quaternions, ...).
# Confirm and replace with explicit imports.
from src.vector_object import *

from src.frame import Frame

from src.feature import SDF_Feature
from src.dataset import *
from src.utils import *
from src.data_gen_utils import *

from os import path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import trimesh

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"


def load_checkpoint(exp_name):
    """Load the checkpoint stored at ``network/<exp_name>.pth.tar``.

    The returned dict is used below through its 'config' (training config)
    and 'network' (model state_dict) entries.
    """
    # map_location: without it, a checkpoint written from the GPU cannot be
    # loaded when `device` has fallen back to "cpu".
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    return torch.load('network/' + exp_name + '.pth.tar', map_location=device)


def build_keypoint_frame(C, network_state):
    """Build a ``Frame`` with an SDF head and 'grasp'/'hang' keypoint heads.

    The architecture is taken from config ``C``; ``network_state`` is loaded
    into it and the model is moved to ``device`` and put in eval mode.
    """
    model = Frame()
    model.build_backbone(pretrained=True, **C)
    model.build_sdf_head(C['SDF_HEAD_HIDDEN'])
    model.build_keypoint_head('grasp', C['GRASP_HEAD_HIDDEN'], C['GRIPPER_POINTS'])
    model.build_keypoint_head('hang', C['HANG_HEAD_HIDDEN'], C['HOOK_POINTS'])
    model.load_state_dict(network_state)
    model.to(device).eval()
    return model


def make_dataset(filename, C):
    """Build a ``PIFODataset`` over ``filename`` with the sizes from config ``C``
    and random erasing disabled."""
    return PIFODataset(filename,
                       num_views=C['NUM_VIEWS'],
                       num_points=C['NUM_POINTS'],
                       num_grasps=C['NUM_GRASPS'],
                       num_hangs=C['NUM_HANGS'],
                       grasp_draw_points=C['GRASP_DRAW_POINTS'],
                       hang_draw_points=C['HANG_DRAW_POINTS'],
                       random_erase=False,
                       on_gpu_memory=True)


# --- PIFO. Its config also defines the datasets and the image warper, so those
# are built here, before `C` is rebound for the other checkpoints.
exp_name = 'PIFO_best'
state = load_checkpoint(exp_name)
C = state['config']

trainset = make_dataset(C['DATA_FILENAME'], C)
testset = make_dataset('data/test_batch.hdf5', C)

warper = RandomImageWarper(img_res=C['IMG_RES'],
                           sig_center=0,
                           return_cam_params=True)

obj1 = build_keypoint_frame(C, state['network'])
F_grasp1 = KeyPoint_Feature(obj1, 'grasp')
F_hang1 = KeyPoint_Feature(obj1, 'hang')

# --- noPixel: same keypoint architecture, different checkpoint.
exp_name = 'noPixelAligned_best'
state = load_checkpoint(exp_name)
C = state['config']
obj2 = build_keypoint_frame(C, state['network'])
F_grasp2 = KeyPoint_Feature(obj2, 'grasp')
F_hang2 = KeyPoint_Feature(obj2, 'hang')

# --- vecObj: Frame_vec with pose heads instead of keypoint heads.
exp_name = 'vectorObject_best'
state = load_checkpoint(exp_name)
C = state['config']
obj3 = Frame_vec()
obj3.build_backbone(pretrained=True, **C)
obj3.build_sdf_head(C['SDF_HEAD_HIDDEN'])
obj3.build_pose_head('grasp', C['GRASP_HEAD_HIDDEN'])
obj3.build_pose_head('hang', C['HANG_HEAD_HIDDEN'])
obj3.load_state_dict(state['network'])
obj3.to(device).eval()

F_grasp3 = Pose_Feature_vec(obj3, 'grasp')
F_hang3 = Pose_Feature_vec(obj3, 'hang')

In [3]:
# Grasp-pose evaluation. For each number of input views and each split:
# NUM_RESTARTS rounds of (re-warp every object's views, sample NUM_INIT random
# initial poses per object, optimize them with each grasp feature in two stages).
#   stage 1 -> "x_*"/"f1_*": feature cost only
#   stage 2 -> "y_*"/"f2_*": continued from stage 1 with a collision penalty
# The NUM_BEST best poses per object and round are checked for feasibility and
# written to EVAL_DIR/<split>_<num_views>.hdf5.
NUM_VIEWS_LIST = [2, 4, 8]   # the summary cell reads 2, 4 and 8 -- produce all of them
NUM_RESTARTS = 10            # optimization rounds per object
NUM_INIT = 10                # random initial poses per object and round
NUM_BEST = 1                 # poses kept per object, round and stage
W_COLL = 1e3                 # stage-2 collision weight (also used for ranking)
INIT_POS_STD = .2            # std of the Gaussian for initial positions
SEED = 0
EVAL_DIR = 'evals/grasp'

# method name -> (grasp feature, camera representation it consumes)
GRASP_METHODS = {
    'PIFO': (F_grasp1, 'projections'),
    'noPixel': (F_grasp2, 'projections'),
    'vecObj': (F_grasp3, 'cam_params'),
}


def select_best(x, score, k):
    """Return the ``k`` lowest-``score`` candidates of every object.

    x: (B, N, D) candidate poses (D == 7 here: 3 position + 4 quaternion).
    score: array-like of shape (B, N). Returns a (B, k, D) tensor.
    """
    order = torch.as_tensor(score).argsort(dim=1)[:, :k].to(x.device)
    index = order.unsqueeze(-1).expand(-1, -1, x.shape[-1])
    return torch.gather(x, dim=1, index=index)


def optimize_two_stage(feature, x_init, rgb, cams, k=NUM_BEST, w_coll=W_COLL):
    """Optimize ``x_init`` with ``feature``; return (best stage-1, best stage-2) poses."""
    x, cost, coll = feature.optimize(x_init.clone(), rgb, cams)
    best_free = select_best(x, cost, k)
    # Stage 2 continues from the stage-1 result with the collision term enabled.
    x, cost, coll = feature.optimize(x, rgb, cams, w_coll=w_coll)
    best_coll = select_best(x, np.square(cost) + np.square(coll * w_coll), k)
    return best_free, best_coll


def render_batch(dataset):
    """Warp the views of every object in ``dataset`` once.

    Returns (inputs, filenames, masses, coms) where ``inputs`` maps 'rgb',
    'projections' and 'cam_params' to tensors batched over objects.
    """
    rgb_list, proj_list, cam_list = [], [], []
    filenames, masses, coms = [], [], []
    for i in range(len(dataset)):
        data = to_device(dataset[i], device)
        rgb, projections, cam_pos, new_origin, cam_roll = warper(
            data['rgb'].unsqueeze(0),
            data['cam_extrinsic'].unsqueeze(0),
            data['cam_intrinsic'].unsqueeze(0))
        rgb_list.append(rgb)
        proj_list.append(projections)
        cam_list.append(torch.cat([cam_pos, cam_roll], dim=2))
        filenames.append(data['filenames'])
        masses.append(data['masses'])
        coms.append(data['coms'])
    inputs = {'rgb': torch.cat(rgb_list),
              'projections': torch.cat(proj_list),
              'cam_params': torch.cat(cam_list)}
    return inputs, filenames, masses, coms


torch.manual_seed(SEED)
np.random.seed(SEED)
os.makedirs(EVAL_DIR, exist_ok=True)

for num_views in NUM_VIEWS_LIST:
    print('======================== ' + str(num_views) + ' ========================')
    for data_name, dataset in (('train', trainset), ('test', testset)):
        # NOTE: this mutates the shared dataset objects; they keep the last value.
        dataset.num_views = num_views
        B = len(dataset)
        stage1 = {name: [] for name in GRASP_METHODS}
        stage2 = {name: [] for name in GRASP_METHODS}

        for _ in tqdm(range(NUM_RESTARTS), desc=f'{data_name} / {num_views} views', leave=False):
            inputs, filename_list, mass_list, com_list = render_batch(dataset)
            x_init = torch.cat([INIT_POS_STD * torch.randn(B, NUM_INIT, 3, device=device),
                                random_quaternions(B * NUM_INIT, device=device).view(B, NUM_INIT, 4)],
                               dim=2)
            # All methods start from the same initial poses.
            for name, (feature, cam_key) in GRASP_METHODS.items():
                best_free, best_coll = optimize_two_stage(feature, x_init, inputs['rgb'], inputs[cam_key])
                stage1[name].append(best_free)
                stage2[name].append(best_coll)

        # Feasibility uses the object metadata of the last round (assumes
        # filenames/masses/coms of object i do not change between rounds).
        results = {}
        for name, (feature, _) in GRASP_METHODS.items():
            poses1 = torch.cat(stage1[name], dim=1)
            poses2 = torch.cat(stage2[name], dim=1)
            feas1 = feature.check_feasibility(poses1, filename_list, mass_list, com_list)
            feas2 = feature.check_feasibility(poses2, filename_list, mass_list, com_list)
            results[name] = (poses1, poses2, feas1, feas2)

        for name, (_, _, feas1, feas2) in results.items():
            print(feas1.sum() / feas1.size, feas2.sum() / feas2.size)

        out_file = EVAL_DIR + '/' + data_name + '_' + str(num_views) + '.hdf5'
        with h5py.File(out_file, mode='w') as f:
            for name, (poses1, poses2, feas1, feas2) in results.items():
                f.create_dataset('x_' + name, data=poses1.cpu().numpy())
                f.create_dataset('y_' + name, data=poses2.cpu().numpy())
                f.create_dataset('f1_' + name, data=feas1)
                f.create_dataset('f2_' + name, data=feas2)

iter: 0, cost: 6.0169219970703125, coll: 0.0
iter: 0, cost: 3.974701166152954, coll: 233.3343048095703
iter: 0, cost: 5.02092981338501, coll: 0.0
iter: 0, cost: 2.5152018070220947, coll: 188.2131805419922
iter: 0, cost: 5.2272868156433105, coll: 0.0
iter: 0, cost: 2.806736946105957, coll: 278.63421630859375
iter: 0, cost: 5.037403106689453, coll: 0.0
iter: 0, cost: 2.8562281131744385, coll: 199.72267150878906
iter: 0, cost: 4.787069797515869, coll: 0.0
iter: 0, cost: 2.8516359329223633, coll: 217.81399536132812
iter: 0, cost: 4.920601844787598, coll: 0.0
iter: 0, cost: 2.92130446434021, coll: 277.5904235839844
iter: 0, cost: 5.023969650268555, coll: 0.0
iter: 0, cost: 3.83841609954834, coll: 242.57069396972656
iter: 0, cost: 4.968585014343262, coll: 0.0
iter: 0, cost: 2.5577192306518555, coll: 204.83006286621094
iter: 0, cost: 4.99035120010376, coll: 0.0
iter: 0, cost: 2.4886393547058105, coll: 282.6738586425781
iter: 0, cost: 5.9312920570373535, coll: 0.0
iter: 0, cost: 4.017564296722

In [7]:
# Summary of the saved grasp evaluations: for every (split, number of views)
# print one line per method (PIFO, noPixel, vecObj) with the feasible fraction
# of stage-1 poses followed by that of stage-2 poses.
for n_views in [2, 4, 8]:
    for split in ['train', 'test']:
        result_file = 'evals/grasp/' + split + '_' + str(n_views) + '.hdf5'
        with h5py.File(result_file, mode='r') as f:
            per_method = [(f['f1_' + method][:], f['f2_' + method][:])
                          for method in ('PIFO', 'noPixel', 'vecObj')]
        print('======================== ' + split + '_' + str(n_views) + ' ========================')
        for stage1, stage2 in per_method:
            print(stage1.sum() / stage1.size, stage2.sum() / stage2.size)

0.6576923076923077 0.8294871794871795
0.6756410256410257 0.808974358974359
0.13205128205128205 0.007692307692307693
0.5535714285714286 0.7714285714285715
0.6392857142857142 0.7035714285714286
0.12857142857142856 0.0035714285714285713
0.6897435897435897 0.8807692307692307
0.6230769230769231 0.8269230769230769
0.21153846153846154 0.005128205128205128
0.6392857142857142 0.825
0.6178571428571429 0.7571428571428571
0.225 0.0035714285714285713
0.7192307692307692 0.8871794871794871
0.7128205128205128 0.8397435897435898
0.28974358974358977 0.005128205128205128
0.6928571428571428 0.85
0.6714285714285714 0.7928571428571428
0.2392857142857143 0.007142857142857143
