In [None]:
import sys
import os

# The local modules imported in the next cells (dataio, networks, visualize,
# sdf_estimator) are presumably found via this directory.
grasp_transfer_path = os.getenv("GRASP_TRANSFER_SOURCE_DIR")
if not grasp_transfer_path:
    # Without this guard, sys.path.insert(0, None) succeeds silently and the
    # failure only surfaces later as a confusing ModuleNotFoundError.
    raise RuntimeError(
        "GRASP_TRANSFER_SOURCE_DIR is not set; point it at the grasp-transfer "
        "source directory before running this notebook."
    )
# Keep the cell idempotent: re-running it must not stack duplicate entries.
if grasp_transfer_path not in sys.path:
    sys.path.insert(0, grasp_transfer_path)

In [None]:
from dataio import *
from networks import *
from torch.utils.data import DataLoader
from visualize import *
from trimesh.viewer import windowed
import time
from tqdm import tqdm
from sdf_estimator import *
import open3d as o3d
from scipy.spatial.transform import Rotation as R
import ipywidgets as widgets


In [None]:
from mesh_to_sdf import sample_sdf_near_surface

In [None]:
class Shape_SDF_Dataset_Interp(Dataset):
    """Synthetic box/cylinder SDF dataset for training the shape-interpolation model.

    The constructor builds ``2 * 4 * 8 = 64`` meshes. For every radius ``r`` in
    ``np.linspace(0.05, 0.4, 8)`` and height ``h`` in ``np.linspace(0.25, 1, 4)``
    it creates:

    * a box with extents ``(2r, 2r, h)``;
    * a cylinder of radius ``r`` and height ``h``.

    Both are rotated by pi/2 about the x axis and translated by ``R @ [r, 0, 0]``.
    For each mesh it caches 25000 surface samples with interpolated vertex
    normals. It also caches near-surface SDF samples from
    ``mesh_to_sdf.sample_sdf_near_surface``, split by sign.

    Args:
        num_of_samples (int): points returned per ``__getitem__`` call. Half
            are surface points; each SDF sign contributes
            ``num_of_samples // 4`` points.
    """

    def __init__(self, num_of_samples):
        self.meshes = list()
        # Alias, not a copy: meshes are already transformed when appended.
        self.meshes_transformed = self.meshes
        self.sdf_points_pos = list()
        self.sdf_points_neg = list()
        self.sdf_values_pos = list()
        self.sdf_values_neg = list()
        self.pointclouds = list()
        self.pointcloud_normals = list()
        radii = np.linspace(0.05, 0.4, 8)
        heights = np.linspace(0.25, 1, 4)
        # One box and one cylinder per (radius, height) pair.
        number_of_scenes = 2 * len(heights) * len(radii)
        # Per-scene 3x4 [R | t] poses used by __getitem__ to crop samples;
        # they start as identity rotation / zero translation.
        self.sampling_RTs = np.zeros((number_of_scenes, 3, 4))
        self.sampling_RTs[:, :3, :3] = np.eye(3)
        r1 = R.from_rotvec(np.array([np.pi / 2, 0, 0]))
        self.num_of_samples = num_of_samples
        self.number_of_scenes = number_of_scenes
        with tqdm(total=number_of_scenes, miniters=1) as pbar:
            start_time = time.time()
            for r in radii:
                for h in heights:
                    box = trimesh.primitives.Box(extents=(r * 2, r * 2, h))
                    cylinder = trimesh.primitives.Cylinder(radius=r, height=h)
                    # Same rigid transform for both primitives of this (r, h).
                    tr = np.eye(4)
                    tr[:3, :3] = r1.as_matrix()
                    tr[:3, 3] = tr[:3, :3] @ np.array([r, 0, 0])
                    for m in [box, cylinder]:
                        m.apply_transform(tr)
                        translation, scale = compute_unit_sphere_transform(m)
                        self.meshes.append(m)

                        samples, fid = m.sample(25000, return_index=True)
                        # Interpolate vertex normals at each sample from its
                        # barycentric coordinates. The summed normal is
                        # re-unitized, so the weights' overall scale is irrelevant.
                        bary = trimesh.triangles.points_to_barycentric(
                            triangles=m.triangles[fid], points=samples)
                        weights = trimesh.unitize(bary).reshape((-1, 3, 1))
                        normals = trimesh.unitize(
                            (m.vertex_normals[m.faces[fid]] * weights).sum(axis=1))
                        self.pointclouds.append(samples)
                        self.pointcloud_normals.append(normals)

                        points, sdf_values = sample_sdf_near_surface(
                            m, number_of_points=25000, scan_count=100,
                            scan_resolution=400, sign_method='depth')
                        # NOTE(review): assumes the SDF samples come back in a
                        # normalised frame and (translation, scale) from
                        # compute_unit_sphere_transform map them back to the
                        # mesh frame. The SDF values are NOT rescaled — confirm.
                        points = points / scale - translation
                        pos_mask = sdf_values > 0
                        neg_mask = sdf_values < 0
                        self.sdf_points_pos.append(points[pos_mask])
                        self.sdf_points_neg.append(points[neg_mask])
                        pos_sdf_values = sdf_values[pos_mask]
                        neg_sdf_values = sdf_values[neg_mask]
                        # Clamp magnitudes just below 1. If sdf_values is
                        # float32, this bound rounds to exactly 1.0.
                        pos_sdf_values[pos_sdf_values >= 0.999999999999] = 0.999999999999
                        neg_sdf_values[neg_sdf_values <= -0.999999999999] = -0.999999999999
                        self.sdf_values_pos.append(pos_sdf_values)
                        self.sdf_values_neg.append(neg_sdf_values)

                        pbar.update(1)
                        pbar.set_postfix(time=time.time() - start_time)

    def __len__(self):
        return len(self.meshes)

    def _inds_within_radius(self, pts, refRT, radius):
        """Indices of ``pts`` whose image under ``refRT`` lies strictly inside ``radius``.

        Only the first three rows of the transformed homogeneous points are
        used. A 3x4 ``[R | t]`` and a 4x4 homogeneous matrix therefore give
        the same result.
        """
        moved = (refRT @ to_hom_np(np.copy(pts)).T).T[:, :3]
        return np.where(np.linalg.norm(moved, axis=1) < radius)[0]

    def sample_within_sphere(self, ind, refRT=np.eye(4)[:3, :], radius=0.5):
        """Select cached samples of scene ``ind`` that fall inside a sphere.

        Args:
            ind (int): scene index.
            refRT (np.ndarray): pose applied to the points before the radius
                test (3x4 ``[R | t]``; defaults to identity). Not modified.
            radius (float): sphere radius around the origin (strict ``<``).

        Returns:
            tuple of np.ndarray: index arrays into ``pointclouds[ind]``,
            ``sdf_points_pos[ind]`` and ``sdf_points_neg[ind]``.
        """
        pc_inds = self._inds_within_radius(self.pointclouds[ind], refRT, radius)
        sdf_pos_inds = self._inds_within_radius(self.sdf_points_pos[ind], refRT, radius)
        sdf_neg_inds = self._inds_within_radius(self.sdf_points_neg[ind], refRT, radius)
        return pc_inds, sdf_pos_inds, sdf_neg_inds

    def set_sampling_RT(self, RTs):
        """Replace the per-scene sampling poses (indexed by scene in ``__getitem__``)."""
        self.sampling_RTs = RTs

    def __getitem__(self, index):
        """Draw a training batch element for scene ``index``.

        Samples are drawn with replacement from the points inside the unit
        sphere under ``self.sampling_RTs[index]``:

        * ``num_of_samples // 2`` surface points (SDF 0, interpolated normals);
        * ``num_of_samples // 4`` positive SDF samples;
        * ``num_of_samples // 4`` negative SDF samples.

        Off-surface normals are filled with -1 placeholders. ``np.random.choice``
        raises ValueError if any category has no point inside the sphere.

        Returns:
            (observations, ground_truth): ``observations`` has keys ``coords``
            (M, 3), ``sdf`` (M, 1), ``normals`` (M, 3) and ``instance_idx``
            (0-d long tensor). ``ground_truth`` holds the ``sdf`` and
            ``normals`` entries.
        """
        pc_inds, sdf_pos_inds, sdf_neg_inds = self.sample_within_sphere(
            index, self.sampling_RTs[index])
        n_surface = self.num_of_samples // 2
        n_per_sign = self.num_of_samples // 4
        random_points_on_surface = np.random.choice(pc_inds, n_surface)
        random_points_pos = np.random.choice(sdf_pos_inds, n_per_sign)
        random_points_neg = np.random.choice(sdf_neg_inds, n_per_sign)
        # n_surface may differ from 2 * n_per_sign when num_of_samples % 4 != 0,
        # so size the off-surface placeholders from what was actually drawn.
        n_off_surface = len(random_points_pos) + len(random_points_neg)

        x = torch.from_numpy(np.vstack([
            self.pointclouds[index][random_points_on_surface],
            self.sdf_points_pos[index][random_points_pos],
            self.sdf_points_neg[index][random_points_neg]])).float()

        sdf = torch.from_numpy(np.vstack([
            np.zeros((len(random_points_on_surface), 1)),
            self.sdf_values_pos[index][random_points_pos].reshape(-1, 1),
            self.sdf_values_neg[index][random_points_neg].reshape(-1, 1)])).float()
        normals = torch.from_numpy(np.vstack([
            self.pointcloud_normals[index][random_points_on_surface],
            -1 * np.ones((n_off_surface, 3))])).float()

        observations = {'coords': x,
                        'sdf': sdf,
                        'normals': normals,
                        'instance_idx': torch.Tensor([index]).squeeze().long()}
        ground_truth = {'sdf': observations['sdf'],
                        'normals': observations['normals']}
        return observations, ground_truth

In [None]:
# 2000 points per item; the constructor builds and samples all 64 scenes up front.
sdf_dataset = Shape_SDF_Dataset_Interp(num_of_samples=2000)
train_dataloader = DataLoader(
    sdf_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

In [None]:
# NOTE(review): the first positional argument (64) presumably matches the number
# of dataset instances (len(sdf_dataset)); confirm against MyNet.
# Requires a CUDA device: the model is pinned to cuda:0.
model = MyNet(
    64,
    latent_dim=128,
    hyper_hidden_features=256,
    hidden_num=128,
)
model.to(device=torch.device('cuda:0'))
param_groups = [
    {'params': model.sdf_net.parameters()},
    {'params': model.hyper_net.parameters()},
    {'params': model.latent_codes.parameters(), 'lr': 1e-4},
]
optim = torch.optim.Adam(param_groups, lr=1e-4)

In [None]:
# Preview scenes 0-7 side by side. Each scene's surface points inside radius 0.5
# (identity pose), its mesh and a 4x-scaled reference gripper are stacked along
# +z, 1.5 units apart.
part_points_list = list()
gripper_mesh_list = list()
object_mesh_list = list()
gripper_mesh = get_gripper_simple_mesh(np.zeros((3,)), np.eye(3), 0.10, 0.10, score=1)
gripper_mesh.apply_scale(4)
for obj_ind in range(8):
    z_offset = [0, 0, 1.5 * obj_ind]

    pc_inds, sdf_pos_inds, sdf_neg_inds = sdf_dataset.sample_within_sphere(obj_ind, radius=0.5)
    points = sdf_dataset.pointclouds[obj_ind][pc_inds][:4000].copy()
    points[:, 2] += z_offset[2]
    part_points_list.append(points)

    gripper_mesh_for_grasp = trimesh.Trimesh(
        vertices=gripper_mesh.vertices,
        faces=gripper_mesh.faces,
        face_colors=[175, 175, 175],
    )
    gripper_mesh_for_grasp.apply_translation(z_offset)
    gripper_mesh_list.append(gripper_mesh_for_grasp)

    source_mesh = sdf_dataset.meshes[obj_ind]
    obj_mesh_for_grasp = trimesh.Trimesh(
        vertices=source_mesh.vertices,
        faces=source_mesh.faces,
        face_colors=[100, 100, 100],
    )
    obj_mesh_for_grasp.apply_translation(z_offset)
    object_mesh_list.append(obj_mesh_for_grasp)

# NOTE(review): on a fresh kernel trimesh_show must already exist here (e.g. from
# `from visualize import *`); the notebook's own definition is in a later cell.
# That definition calls scene.show() by default, so a points-only viewer may open
# before the SceneViewer below.
scene = trimesh_show(part_points_list)
scene.add_geometry(gripper_mesh_list)
scene.add_geometry(object_mesh_list)
window = windowed.SceneViewer(scene)

In [None]:
import numpy as np
import trimesh
from matplotlib import cm


def _evenly_spaced_colors(cmap_name, n):
    """Return ``n`` RGBA int32 colours (0-255) sampled from ``cmap_name`` on [0.05, 0.95]."""
    if n == 0:
        return []
    try:
        # matplotlib >= 3.6 registry API; ``matplotlib.cm.get_cmap`` was
        # removed in matplotlib 3.9.
        from matplotlib import colormaps
        colormap = colormaps[cmap_name].resampled(n)
    except (ImportError, AttributeError):
        colormap = cm.get_cmap(cmap_name, n)
    return [
        (np.asarray(colormap(val)) * 255).astype(np.int32)
        for val in np.linspace(0.05, 0.95, num=n)
    ]


def trimesh_show(np_pcd_list, color_list=None, rand_color=False, show=True):
    """Build a trimesh Scene holding one coloured PointCloud per input array.

    Args:
        np_pcd_list: sequence of point arrays accepted by ``trimesh.PointCloud``.
        color_list: optional per-cloud colour (RGB/RGBA, 0-255), indexed by the
            cloud's position in ``np_pcd_list``.
        rand_color: when ``color_list`` is None, draw a random opaque colour per
            cloud from ``np.random`` instead of using the 'brg' colormap.
        show: if True, call ``scene.show()`` before returning.

    Returns:
        trimesh.Scene containing the point clouds.
    """
    n_clouds = len(np_pcd_list)
    if color_list is None:
        if rand_color:
            color_list = [
                (np.random.rand(3) * 255).astype(np.int32).tolist() + [255]
                for _ in range(n_clouds)
            ]
        else:
            # Only build the colormap when it is actually needed.
            color_list = _evenly_spaced_colors('brg', n_clouds)

    tpcd_list = []
    for i, pcd in enumerate(np_pcd_list):
        tpcd = trimesh.PointCloud(pcd)
        tpcd.colors = np.tile(color_list[i], (tpcd.vertices.shape[0], 1))
        tpcd_list.append(tpcd)

    scene = trimesh.Scene()
    scene.add_geometry(tpcd_list)
    if show:
        scene.show()

    return scene
def trimesh_show_green(np_pcd_list, color_list=None, rand_color=False, show=True):
    """Build a trimesh Scene with every point cloud painted pure green.

    ``color_list`` and ``rand_color`` are accepted for signature parity with
    ``trimesh_show`` but are ignored.

    Args:
        np_pcd_list: sequence of point arrays accepted by ``trimesh.PointCloud``.
        show: if True, call ``scene.show()`` before returning.

    Returns:
        trimesh.Scene containing the point clouds.
    """
    green = np.array([0, 255, 0])
    clouds = []
    for points in np_pcd_list:
        cloud = trimesh.PointCloud(points)
        n_vertices = cloud.vertices.shape[0]
        cloud.colors = np.tile(green, (n_vertices, 1))
        clouds.append(cloud)

    scene = trimesh.Scene()
    scene.add_geometry(clouds)
    if show:
        scene.show()
    return scene

In [None]:
total_steps = 0
epochs = 2000
# `visualize` is not guaranteed to exist: `from visualize import *` imports the
# module's public names, not the module name itself. Default to False rather than
# raising NameError at the first visualization epoch.
visualize_meshes = bool(globals().get("visualize", False))
with tqdm(total=len(train_dataloader) * epochs) as pbar:
    train_losses = []
    for epoch in range(epochs):
        # Refresh the dataset's per-instance sampling poses from the model
        # before each epoch, so __getitem__ crops with the latest poses.
        instance_ids = torch.arange(0, len(sdf_dataset)).cuda()
        RTs = model.get_refine_poses(instance_ids)[0].detach().cpu().numpy()
        sdf_dataset.set_sampling_RT(RTs[:, :3, :4])
        model.train()
        for step, (model_input, gt) in enumerate(train_dataloader):
            start_time = time.time()
            model_input = {key: value.cuda() for key, value in model_input.items()}
            gt = {key: value.cuda() for key, value in gt.items()}

            losses = model(model_input, gt, epoch)

            train_loss = 0.
            for loss_name, loss in losses.items():
                single_loss = loss.mean()
                if epoch % 100 == 0 and step == 0:
                    print(loss_name, single_loss)
                train_loss += single_loss

            # One host sync per step; reuse the Python float for logging.
            train_loss_value = train_loss.item()
            train_losses.append(train_loss_value)
            optim.zero_grad()
            train_loss.backward()
            optim.step()
            pbar.update(1)
            pbar.set_postfix(loss=train_loss_value, time=time.time() - start_time, epoch=epoch)
            total_steps += 1

        if visualize_meshes and epoch > 0 and epoch % 400 == 0:
            try:
                for subject_idx in range(2):
                    v, t, _ = estimate_mesh_from_model(model, epoch, subject_idx, 64, 0.6)
                    show_mesh(v - 1, t)
            except Exception as exc:
                # Best-effort preview: early in training mesh extraction can
                # fail, but report why instead of hiding every error.
                print("Data not good enough yet:", repr(exc))

In [None]:
# Reconstruct and display meshes for instances 56-63.
# NOTE(review): `epoch` is left over from the training loop; on a fresh kernel
# where training was skipped (e.g. weights loaded instead) this cell raises NameError.
model.eval()
for index in range(56, 64):
    v, t, _ = estimate_mesh_from_model(model, epoch, index, 64, 0.5)
    show_mesh(v, t)

In [None]:
# 500 is forwarded unchanged to contour_plot (presumably the plot resolution — confirm).
contour_plot(model, 500)

In [None]:
output_path = '../outputs/cuboid'

# makedirs(exist_ok=True) lets this cell be re-run and creates ../outputs if it is
# missing; os.mkdir would raise FileExistsError / FileNotFoundError instead.
os.makedirs(output_path, exist_ok=True)
torch.save(model.state_dict(), os.path.join(output_path, 'model'))


In [None]:
# Restore the weights written by the save cell (../outputs/cuboid/model) and
# switch the model to inference mode.
saved_state = torch.load("../outputs/cuboid/model")  # my_model
model.load_state_dict(saved_state)
model.eval()

In [None]:
# Mesh trajectories between consecutive instance pairs (56, 57), (58, 59), ...,
# (62, 63), saved as <output_path>/interp_<a>_<b>. The '_' separator matches the
# next cell and avoids ambiguous names such as "interp_5657".
for k in range(28, 32):
    first, second = 2 * k, 2 * k + 1
    mesh_list = estimate_mesh_traj_from_model(model, 2000, [first, second], 128, scale=0.5)
    draw_mesh_list(mesh_list, output_path + "/interp_" + str(first) + "_" + str(second))

In [None]:
# Mesh trajectories from each of instances 0/1 to each of instances 6/7.
pairs = [(a, b) for a in [0, 1] for b in [6, 7]]
for first, second in pairs:
    mesh_list = estimate_mesh_traj_from_model(
        model, 2000, [first, second], 64, scale=0.5)
    target = output_path + "/interp_" + str(first) + '_' + str(second)
    draw_mesh_list(mesh_list, target)

In [None]:
# Show instance 0's surface points (radius 0.5, identity pose) with its mesh and an
# unscaled reference gripper at the origin. Also computes RT_inv, the inverse of
# instance 0's current sampling pose; the next cell relies on it.
mesh_list = list()
part_points_list = list()
gripper_mesh = get_gripper_simple_mesh(np.zeros((3,)), np.eye(3), 0.10, 0.10, score=1)
shape_tr = 0
for index in [0]:
    fc_color = np.random.rand(3) * 255
    source_mesh = sdf_dataset.meshes_transformed[index]
    _, scale = compute_unit_sphere_transform(source_mesh)
    mesh2 = trimesh.Trimesh(
        vertices=source_mesh.vertices,
        face_colors=fc_color,
        faces=source_mesh.faces,
    )

    pc_inds, sdf_pos_inds, sdf_neg_inds = sdf_dataset.sample_within_sphere(index)
    points = sdf_dataset.pointclouds[index][pc_inds][:16000].copy()
    part_points_list.append(points)

    # Invert the 3x4 pose [R | T] as [R^-1 | -R^-1 T].
    RT = np.eye(4)
    RT[:3, :] = sdf_dataset.sampling_RTs[index]
    R_ = RT[:3, :3]
    T_ = RT[:3, 3]
    RT_inv = np.eye(4)
    RT_inv[:3, :3] = np.linalg.inv(R_)
    RT_inv[:3, 3] = -RT_inv[:3, :3] @ T_

    mesh_list.append(mesh2)

    gripper_mesh_for_grasp = trimesh.Trimesh(
        vertices=gripper_mesh.vertices,
        faces=gripper_mesh.faces,
        face_colors=[175, 175, 175],
    )
    # NOTE(review): face_colors above already colour the mesh; this `colors`
    # attribute may have no visual effect on a Trimesh — confirm.
    n_vertices = gripper_mesh_for_grasp.vertices.shape[0]
    gripper_mesh_for_grasp.colors = np.tile(np.array([175, 175, 175]), (n_vertices, 1))
    mesh_list.append(gripper_mesh_for_grasp)

    shape_tr += 1

scene = trimesh_show(part_points_list)
scene.add_geometry(mesh_list)
window = windowed.SceneViewer(scene)

In [None]:
# For instances 0-7, stacked 2 units apart along +z: show points produced by
# estimate_points_from_model in green, mapped through RT_inv, next to a reference
# gripper.
# NOTE(review): RT_inv comes from the previous cell (instance 0's pose) and is
# applied to every instance here; that cell must run first.
mesh_list = list()
part_points_list = list()
sphere_base = trimesh.primitives.Sphere(radius=0.5, subdivisions=6)
sphere_base2 = trimesh.primitives.Sphere(radius=0.1, subdivisions=3)

shape_tr = 0
gripper_mesh = get_gripper_simple_mesh(np.zeros((3,)), np.eye(3), 0.45, 0.40, score=1, color=None)
for index in range(8):
    shape_Tr = [0, 0, shape_tr * 2]
    fc_color = np.random.rand(3) * 255

    # NOTE(review): mesh3 is built and shifted but never added to mesh_list,
    # so it is not displayed.
    source_mesh = sdf_dataset.meshes_transformed[index]
    mesh3 = trimesh.Trimesh(
        vertices=source_mesh.vertices,
        faces=source_mesh.faces,
        face_colors=fc_color,
    )
    mesh3.apply_translation(shape_Tr)

    points = estimate_points_from_model(model, 2000, index, 64, 0.4)
    homogeneous = to_hom_np(points)
    points = (RT_inv @ homogeneous.T).T[:, :3]
    points[:, :3] += shape_Tr
    part_points_list.append(points)

    gripper_mesh_for_grasp = trimesh.Trimesh(
        vertices=gripper_mesh.vertices,
        faces=gripper_mesh.faces,
        face_colors=[175, 175, 175],
    )
    gripper_mesh_for_grasp.apply_translation(shape_Tr)
    n_vertices = gripper_mesh_for_grasp.vertices.shape[0]
    gripper_mesh_for_grasp.colors = np.tile(np.array([175, 175, 175]), (n_vertices, 1))
    mesh_list.append(gripper_mesh_for_grasp)

    shape_tr += 1

scene = trimesh_show_green(part_points_list)
scene.add_geometry(mesh_list)
window = windowed.SceneViewer(scene)