In [None]:
import sys
import os
# Make the grasp-transfer sources importable. Their location is read from the
# GRASP_TRANSFER_SOURCE_DIR environment variable.
grasp_transfer_path = os.getenv("GRASP_TRANSFER_SOURCE_DIR")
if grasp_transfer_path is not None:
    sys.path.insert(0, grasp_transfer_path)
else:
    # os.getenv returns None when the variable is missing. Only strings belong on
    # sys.path (other entries are ignored during import), so inserting None would
    # merely hide the real cause of a later ModuleNotFoundError.
    print(
        "GRASP_TRANSFER_SOURCE_DIR is not set; the project modules will only be "
        "importable if they are already on sys.path.",
        file=sys.stderr,
    )

In [None]:
from dataio import *
from networks import *
from torch.utils.data import DataLoader
from visualize import *
from trimesh.viewer import windowed
import time
from tqdm.autonotebook import tqdm
from sdf_estimator import *
import open3d as o3d
from scipy.spatial.transform import Rotation as R
import ipywidgets as widgets
import pickle
import yaml

from utility import *


In [None]:
# Object part visualised by this notebook ("handle" or "rim").
part_name = "handle" # "rim"

# Read the dataset description: the train/validation lists and the entry of the chosen part.
with open("../configs/part_dataset_info.yaml", 'r') as cfg_file:
    dataset_config = yaml.safe_load(cfg_file)

train, val = dataset_config['train_list'], dataset_config['val_list']
part = dataset_config[f"part_{part_name}"]

# Base path under which the trained model of this part is looked up.
output_name = f"../outputs/{part_name}"

In [None]:
# SDF datasets built from the configured training and validation lists.
sdf_dataset = Part_SDF_Dataset(train, 64, 2000)
sdf_dataset_val = Part_SDF_Dataset(val, 16, 500)

# Apply the RT stored for the selected part to the point clouds of both datasets.
RT = np.array(part['RT'])
sdf_dataset.transform_all_pc_with_Rt(RT)
sdf_dataset_val.transform_all_pc_with_Rt(RT)

# Both loaders share exactly the same settings.
loader_kwargs = dict(shuffle=True, batch_size=16, num_workers=0, drop_last=True)
train_dataloader = DataLoader(sdf_dataset, **loader_kwargs)
val_dataloader = DataLoader(sdf_dataset_val, **loader_kwargs)

In [None]:
# Build the network, move it to the first CUDA device and restore the trained
# weights stored under `output_name`.
model_device = torch.device('cuda:0')
model = MyNet(64, latent_dim=128, hyper_hidden_features=256, hidden_num=128)
model.to(device=model_device)
# map_location pins the checkpoint tensors to the device the model lives on, so a
# checkpoint that was saved from another device still loads here.
model.load_state_dict(torch.load(output_name+"/model", map_location=model_device))
# Switch to evaluation mode; being the last expression of the cell, the return
# value of eval() is what the notebook displays.
model.eval()

In [None]:
# Draw the contour plot of the loaded model with the second argument set to 2000.
contour_plot(model,2000) # Change to 0 (low-frequency positional encodings) to see the smoother output

# Validation Visualizations

In [None]:
# Ask the model for the refined pose of shape indices 0..63 (affine=False) and
# bring the result back to the host as a NumPy array.
RTs = (
    model.get_refine_poses(torch.arange(64).cuda(), affine=False)
    .detach()
    .cpu()
    .numpy()
)

# Hand the [:3, :4] part of every pose to the training dataset as its sampling RT.
sdf_dataset.set_sampling_RT(RTs[:, :3, :4])

# 4x4 matrix whose top three rows are the pose of shape 0; the visualisation cells
# below apply it after each object's inverse sampling RT.
RT_base = np.eye(4)
RT_base[:3, :] = RTs[0]

In [None]:
### Creating Interpolation Trajectories
# Shape-index sequences noted for each part:
#   Handle: [0,12,16,53,42,8]   Rim: [0,4,17,18,61,38]   Cuboid: [24,56,60,61,57,25]
# NOTE(review): the sequence used below is the Cuboid one although part_name is set to
# "handle" earlier in the notebook -- confirm which sequence is intended.
interp_indices = [24, 56, 60, 61, 57, 25]
mesh_list = estimate_mesh_traj_from_model(
    model,
    2000,
    interp_indices,
    num_of_steps=20,
    resolution=128,
    scale=0.5,
)
draw_mesh_list(mesh_list, f"{output_name}/interp")

In [None]:
# Single Object Grasp Visualization
# Shows one training object (scaled by the factor returned from
# compute_unit_sphere_transform), the points sampled around its part, and the gripper
# moved through the inverse of the object's sampling RT followed by RT_base.
# The unused sphere primitives and the dead `shape_tr` counter of the earlier version
# were dropped; they never reached the scene.
index = 0

# Gripper template. The later visualisation cells reuse `gripper_mesh` as well.
gripper_mesh = get_gripper_simple_mesh(np.zeros((3,)), np.eye(3), 0.10, 0.10, score=1)
gripper_mesh.apply_scale(4)

fc_color = np.random.rand(3) * 255

source_mesh = sdf_dataset.meshes_transformed[index]
_, scale = compute_unit_sphere_transform(source_mesh)
mesh2 = trimesh.Trimesh(vertices=source_mesh.vertices*scale,
                        face_colors=fc_color, faces=source_mesh.faces)

# Only the point-cloud indices are needed here; the two SDF index sets are discarded.
pc_inds, _, _ = sdf_dataset.sample_within_sphere(index)
# Keep at most the first 8000 of the selected points.
points = np.copy(sdf_dataset.pointclouds[index][pc_inds][:8000])

# 4x4 version of the object's sampling RT and its inverse.
RT = np.eye(4)
RT[:3, :] = sdf_dataset.sampling_RTs[index]
RT_inv = find_inverse_RT_4x4(RT)

gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices,
                                         faces=gripper_mesh.faces, face_colors=[175,175,175])
# NOTE(review): the grey colour is already passed through face_colors above; confirm
# that anything reads this extra `colors` attribute.
gripper_mesh_for_grasp.colors = np.tile(np.array([175,175,175]),
                                        (gripper_mesh_for_grasp.vertices.shape[0], 1))
gripper_mesh_for_grasp.apply_transform(RT_inv)
gripper_mesh_for_grasp.apply_transform(RT_base)

part_points_list = [points]
mesh_list = [mesh2, gripper_mesh_for_grasp]

scene = trimesh_show(part_points_list)
scene.add_geometry(mesh_list)

window = windowed.SceneViewer(scene)

In [None]:
# Multi Object Grasp Visualization
# For each of the first four objects two views are drawn, one row per object
# (rows are 1.5 apart along z):
#   x = 0 : object, sampled points, gripper template and coordinate frame, with the
#           gripper and frame only shifted to the row;
#   x = 3 : the same, but gripper and frame are first moved through the inverse of
#           the object's sampling RT and then through RT_base.
part_points_list = list()
mesh_list = list()

# The gripper template does not depend on the object, so it is built once instead of
# twice per iteration. Later cells reuse `gripper_mesh` too.
gripper_mesh = get_gripper_simple_mesh(np.zeros((3,)), np.eye(3), 0.10, 0.10, score=1)
gripper_mesh.apply_scale(4)

for obj_ind in range(0,4):
    color = np.random.rand(3) * 255
    row_shift = 1.5*obj_ind
    transformed_mesh = sdf_dataset.meshes_transformed[obj_ind]

    # Computed once per object (the earlier version recomputed the same value three times).
    # NOTE(review): the scale comes from `meshes` while the other cells derive it from
    # `meshes_transformed` -- confirm both give the same factor.
    _,scale = compute_unit_sphere_transform(sdf_dataset.meshes[obj_ind])

    # ---- view at x = 0 ----
    obj_mesh = trimesh.Trimesh(vertices=transformed_mesh.vertices,
                               face_colors=color, faces=transformed_mesh.faces)
    obj_mesh.apply_scale(scale)
    obj_mesh.apply_translation([0,0,-row_shift])

    pc_inds,_,_ = sdf_dataset.sample_within_sphere(obj_ind,radius=0.8)
    points = np.copy(sdf_dataset.pointclouds[obj_ind][pc_inds][:])
    points[:,2]-=row_shift
    part_points_list.append(points)

    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices,
                                             faces=gripper_mesh.faces,face_colors=[175,175,175])
    gripper_mesh_for_grasp.apply_translation([0,0,-row_shift])
    zero_frame = get_coordinate_frame_mesh(0.2)
    zero_frame.apply_translation([0,0,-row_shift])
    mesh_list.extend([obj_mesh, gripper_mesh_for_grasp, zero_frame])

    # ---- view at x = 3 ----
    RT= np.eye(4)
    RT[:3,:] = sdf_dataset.sampling_RTs[obj_ind]
    RT_inv = find_inverse_RT_4x4(RT)  # evaluated once, used for gripper and frame

    obj_mesh = trimesh.Trimesh(vertices=transformed_mesh.vertices,
                               face_colors=color, faces=transformed_mesh.faces)
    obj_mesh.apply_scale(scale)
    obj_mesh.apply_translation([3,0,-row_shift])

    # Same sampling call as above, but with the object's RT passed explicitly.
    pc_inds,_,_ = sdf_dataset.sample_within_sphere(obj_ind,RT[:3,:],radius=0.8)
    points = np.copy(sdf_dataset.pointclouds[obj_ind][pc_inds][:])
    points[:,2]-=row_shift
    points[:,0]+=3
    part_points_list.append(points)

    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices,
                                             faces=gripper_mesh.faces,face_colors=[175,175,175])
    gripper_mesh_for_grasp.apply_transform(RT_inv)
    gripper_mesh_for_grasp.apply_transform(RT_base)
    gripper_mesh_for_grasp.apply_translation([3,0,-row_shift])

    zero_frame = get_coordinate_frame_mesh(0.2)
    zero_frame.apply_transform(RT_inv)
    zero_frame.apply_transform(RT_base)
    zero_frame.apply_translation([3,0,-row_shift])
    mesh_list.extend([obj_mesh, gripper_mesh_for_grasp, zero_frame])

scene = trimesh_show(part_points_list)
scene.add_geometry(mesh_list)
window = windowed.SceneViewer(scene)

In [None]:
# Multi Object Grasp + Reconstruction Visualization
# Per object: the scaled source mesh, the points estimated by the model (moved through
# the inverse sampling RT and then RT_base), and the gripper moved the same way.
# Relies on `gripper_mesh` and `RT_base` defined by earlier cells.

mesh_list = []
part_points_list = []

for index in range(4):
    # Objects are stacked along z, two units apart.
    shape_Tr = [0, 0, index*2]
    fc_color = np.random.rand(3) * 255

    src_mesh = sdf_dataset.meshes_transformed[index]
    _, scale = compute_unit_sphere_transform(src_mesh)

    mesh3 = trimesh.Trimesh(vertices=src_mesh.vertices*scale,
                            faces=src_mesh.faces, face_colors=fc_color)
    mesh3.apply_translation(shape_Tr)

    # 4x4 version of this object's sampling RT and its inverse.
    RT = np.eye(4)
    RT[:3, :] = sdf_dataset.sampling_RTs[index]
    RT_inv = find_inverse_RT_4x4(RT)

    # Model-estimated points: inverse sampling RT first, then RT_base, then the row shift.
    points = estimate_points_from_model(model, 2000, index, 64, 0.5)
    for pose in (RT_inv, RT_base):
        points = ((pose @ to_hom_np(points).T).T)[:, :3]
    points[:, :3] += shape_Tr
    part_points_list.append(points)

    mesh_list.append(mesh3)

    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices,
                                             faces=gripper_mesh.faces,
                                             face_colors=[175, 175, 175])
    for pose in (RT_inv, RT_base):
        gripper_mesh_for_grasp.apply_transform(pose)
    gripper_mesh_for_grasp.apply_translation(shape_Tr)
    grey = np.array([175, 175, 175])
    gripper_mesh_for_grasp.colors = np.tile(grey, (gripper_mesh_for_grasp.vertices.shape[0], 1))
    mesh_list.append(gripper_mesh_for_grasp)

scene = trimesh_show_green(part_points_list)
scene.add_geometry(mesh_list)

window = windowed.SceneViewer(scene)

In [None]:
# Multi Object Grasp + Mesh Reconstruction Visualization
# Per object: the scaled source mesh, the mesh estimated by the model (vertices moved
# through the inverse sampling RT and then RT_base), and the gripper moved the same way.
# Relies on `gripper_mesh` and `RT_base` defined by earlier cells.
# The earlier version also ran estimate_points_from_model on every iteration, but the
# result was never added to the scene; that dead work was removed.

mesh_list = list()

for index in range(0,4):
    # Objects are stacked along z, two units apart.
    shape_Tr = [0,0,index*2]
    fc_color = np.random.rand(3) * 255

    source_mesh = sdf_dataset.meshes_transformed[index]
    _,scale = compute_unit_sphere_transform(source_mesh)

    mesh3 = trimesh.Trimesh(vertices=source_mesh.vertices*scale,
                            faces=source_mesh.faces,face_colors=fc_color)
    mesh3.apply_translation(shape_Tr)

    # 4x4 version of this object's sampling RT and its inverse.
    RT= np.eye(4)
    RT[:3,:] = sdf_dataset.sampling_RTs[index]
    RT_inv = find_inverse_RT_4x4(RT)

    # Reconstructed surface: only vertices and faces of the returned triple are used.
    v,t,_ = estimate_mesh_from_model(model, 2000, index, 128, scale=0.6)
    aligned_vertices = ((RT_inv@ to_hom_np(v).T).T)[:,:3]
    aligned_vertices = ((RT_base@ to_hom_np(aligned_vertices).T).T)[:,:3]
    aligned_vertices[:,:3]+=shape_Tr
    mesh = trimesh.Trimesh(vertices=aligned_vertices, faces=t, face_colors=[0,100,0])

    mesh_list.append(mesh)
    mesh_list.append(mesh3)

    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices,
                                             faces=gripper_mesh.faces,face_colors=[175,175,175])
    gripper_mesh_for_grasp.apply_transform(RT_inv)
    gripper_mesh_for_grasp.apply_transform(RT_base)
    gripper_mesh_for_grasp.apply_translation(shape_Tr)
    gripper_mesh_for_grasp.colors = np.tile(np.array([175,175,175]),
                                            (gripper_mesh_for_grasp.vertices.shape[0], 1))
    mesh_list.append(gripper_mesh_for_grasp)

scene = trimesh.Scene()
scene.add_geometry(mesh_list)

window = windowed.SceneViewer(scene)

In [None]:
# Multi Object Grasp + Reconstruction Visualization THREE-VIEW
# One row per object (rows are 2.5 apart along z), three columns:
#   x = -2.5 : mesh estimated by the model;
#   x =  0   : scaled source mesh with the gripper template only shifted to the row;
#   x =  3   : scaled source mesh with the gripper moved through the inverse sampling RT.
# Relies on `gripper_mesh` defined by an earlier cell.
# The unused sphere primitives and the manual row counter of the earlier version were
# dropped; the row is derived from `index` directly.

mesh_list = list()

for index in range(0,10):
    row_z = index*2.5
    fc_color = np.random.rand(3) * 255

    # Column x = -2.5: reconstruction, only vertices and faces of the triple are used.
    v,t,_ = estimate_mesh_from_model(model, 2000, index, 128, scale=0.6)
    mesh = trimesh.Trimesh(vertices=v, faces=t, face_colors=fc_color)
    mesh.apply_translation([-2.5, 0, row_z])

    source_mesh = sdf_dataset.meshes_transformed[index]
    _,scale = compute_unit_sphere_transform(source_mesh)

    # 4x4 version of this object's sampling RT and its inverse.
    RT= np.eye(4)
    RT[:3,:] = sdf_dataset.sampling_RTs[index]
    RT_inv = find_inverse_RT_4x4(RT)

    # Columns x = 0 and x = 3 show the same scaled source mesh.
    mesh2 = trimesh.Trimesh(vertices=source_mesh.vertices*scale,
                            face_colors=fc_color, faces=source_mesh.faces)
    mesh2.apply_translation([0, 0, row_z])

    mesh3 = trimesh.Trimesh(vertices=source_mesh.vertices*scale,
                            faces=source_mesh.faces,face_colors=fc_color)
    mesh3.apply_translation([3, 0, row_z])

    mesh_list.append(mesh)
    mesh_list.append(mesh2)
    mesh_list.append(mesh3)

    # Gripper for column x = 0: template shifted to the row.
    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices,
                                             faces=gripper_mesh.faces,face_colors=[175,175,175])
    gripper_mesh_for_grasp.apply_translation([0, 0, row_z])
    gripper_mesh_for_grasp.colors = np.tile(np.array([175,175,175]),
                                            (gripper_mesh_for_grasp.vertices.shape[0], 1))
    mesh_list.append(gripper_mesh_for_grasp)

    # Gripper for column x = 3: moved through the inverse sampling RT first.
    # NOTE(review): unlike the other grasp cells, RT_base is not applied after RT_inv
    # here -- confirm this is intended.
    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices,
                                             faces=gripper_mesh.faces,face_colors=[175,175,175])
    gripper_mesh_for_grasp.apply_transform(RT_inv)
    gripper_mesh_for_grasp.apply_translation([3, 0, row_z])
    gripper_mesh_for_grasp.colors = np.tile(np.array([175,175,175]),
                                            (gripper_mesh_for_grasp.vertices.shape[0], 1))
    mesh_list.append(gripper_mesh_for_grasp)

scene = trimesh.Scene()
scene.add_geometry(mesh_list)
window = windowed.SceneViewer(scene)