In [None]:
import sys
import os
import warnings

# The project modules star-imported in the next cell (dataio, networks, visualize,
# sdf_estimator, utility) are located through this environment variable.
grasp_transfer_path = os.getenv("GRASP_TRANSFER_SOURCE_DIR")
if grasp_transfer_path:
    # Only prepend when not already first, so re-running the cell stays idempotent.
    if sys.path[:1] != [grasp_transfer_path]:
        sys.path.insert(0, grasp_transfer_path)
else:
    # Previously `None` was inserted into sys.path, and the next cell then failed with a
    # bare ModuleNotFoundError that gave no hint about the missing variable.
    warnings.warn(
        "GRASP_TRANSFER_SOURCE_DIR is not set; project modules such as `dataio` "
        "are only importable if they are already on sys.path."
    )

In [None]:
from dataio import *
from networks import *
from torch.utils.data import DataLoader
from visualize import *
from trimesh.viewer import windowed
import time
from tqdm.autonotebook import tqdm
from sdf_estimator import *
import open3d as o3d
from scipy.spatial.transform import Rotation as R
import ipywidgets as widgets
import pickle
import yaml

from utility import *


In [None]:
# Notebook-wide switches.
visualize = True        # show intermediate reconstructions during training
reuse = False           # NOTE(review): not referenced in the visible cells
part_name = "handle"    # alternatively "rim"

# Train/val object lists and per-part settings come from the shared YAML config.
config_path = "../configs/part_dataset_info.yaml"
with open(config_path, 'r') as config_file:
    dataset_config = yaml.safe_load(config_file)

train, val = dataset_config['train_list'], dataset_config['val_list']
part = dataset_config[f"part_{part_name}"]
output_name = f"../outputs/{part_name}"

In [None]:
# NOTE(review): the integer arguments are presumably (number of objects, samples per
# object) — confirm against Part_SDF_Dataset in dataio.
sdf_dataset = Part_SDF_Dataset(train, 64, 2000)
sdf_dataset_val = Part_SDF_Dataset(val, 16, 500)

# Apply the part-specific RT from the config to every point cloud of both splits.
RT = np.array(part['RT'])
for split_dataset in (sdf_dataset, sdf_dataset_val):
    split_dataset.transform_all_pc_with_Rt(RT)

# Both loaders share the same settings.
loader_kwargs = dict(shuffle=True, batch_size=16, num_workers=0, drop_last=True)
train_dataloader = DataLoader(sdf_dataset, **loader_kwargs)
val_dataloader = DataLoader(sdf_dataset_val, **loader_kwargs)

In [None]:
cuda_device = torch.device('cuda:0')
model = MyNet(64, latent_dim=128, hyper_hidden_features=256, hidden_num=128)
model.to(device=cuda_device)

# Per-submodule learning rates; groups without an explicit 'lr' use the default 1e-4.
param_groups = [
    {'params': model.sdf_net.parameters()},
    {'params': model.hyper_net.parameters()},
    {'params': model.latent_codes.parameters(), 'lr': 1e-4},
    {'params': model.se3_refine.parameters(), 'lr': 1e-3},
    {'params': model.affine_tr.parameters(), 'lr': 1e-3},
]
optim = torch.optim.Adam(param_groups, lr=1e-4)

In [None]:
## Visualize
# Objects 30-35 are stacked along -z. Each is shown with its sampled part points, a
# reference gripper at the origin and a coordinate frame.

part_points_list = list()
mesh_list = list()

# The gripper geometry does not depend on the object, so build it once instead of on
# every iteration.
gripper_mesh = get_gripper_simple_mesh(np.zeros((3,)), np.eye(3), 0.10, 0.10, score=1)
gripper_mesh.apply_scale(4)

for obj_ind in range(30, 36):
    z_offset = -2 * obj_ind
    # One scale per object (it was previously computed twice with identical arguments).
    _, scale = compute_unit_sphere_transform(sdf_dataset.meshes[obj_ind])
    src_mesh = sdf_dataset.meshes_transformed[obj_ind]
    obj_mesh = trimesh.Trimesh(vertices=src_mesh.vertices,
                               face_colors=np.random.rand(3) * 255, faces=src_mesh.faces)
    obj_mesh.apply_scale(scale)
    obj_mesh.apply_translation([0, 0, z_offset])

    # Only the point-cloud indices are needed here.
    pc_inds, _, _ = sdf_dataset.sample_within_sphere(obj_ind)
    points = np.copy(sdf_dataset.pointclouds[obj_ind][pc_inds])
    points[:, 2] += z_offset
    part_points_list.append(points)

    # Copy the vertices so the shared gripper geometry cannot be modified through a
    # per-object instance.
    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices.copy(),
                                             faces=gripper_mesh.faces, face_colors=[175, 175, 175])
    gripper_mesh_for_grasp.apply_translation([0, 0, z_offset])
    zero_frame = get_coordinate_frame_mesh(0.2)
    zero_frame.apply_translation([0, 0, z_offset])
    mesh_list.extend([obj_mesh, gripper_mesh_for_grasp, zero_frame])

scene = trimesh_show(part_points_list)
scene.add_geometry(mesh_list)
window = windowed.SceneViewer(scene)

In [None]:
# Train the shape model. Each epoch first pushes the current refined poses into the
# dataset's sampling, then runs one pass over the training loader.
tik = time.time()
total_steps = 0
epochs = 2000
with tqdm(total=len(train_dataloader) * epochs) as pbar:
    train_losses = []
    for epoch in range(epochs):
        # The poses are only copied to numpy here, so no autograd graph is needed.
        with torch.no_grad():
            RTs = model.get_refine_poses(torch.arange(0, 64).cuda())[0].cpu().detach().numpy()
        sdf_dataset.set_sampling_RT(RTs[:, :3, :4])
        sdf_dataset.set_curr_epoch(epoch)

        model.train()
        for step, (model_input, gt) in enumerate(train_dataloader):
            start_time = time.time()
            model_input = {key: value.cuda() for key, value in model_input.items()}
            gt = {key: value.cuda() for key, value in gt.items()}

            losses = model(model_input, gt, epoch)
            train_loss = 0.
            for loss_name, loss in losses.items():
                single_loss = loss.mean()
                if epoch % 100 == 0 and step == 0:
                    print(loss_name, single_loss)
                train_loss += single_loss

            train_losses.append(train_loss.item())
            optim.zero_grad()
            train_loss.backward()
            optim.step()
            pbar.update(1)
            pbar.set_postfix(loss=train_losses[-1], time=time.time() - start_time, epoch=epoch)
            total_steps += 1

        if visualize and epoch > 0 and epoch % 400 == 0:
            # Early in training the reconstruction may presumably fail (e.g. marching cubes
            # finding no surface). Report the error and keep training, but no longer
            # swallow KeyboardInterrupt/SystemExit as the bare `except:` did.
            try:
                for subject_idx in range(2):
                    v, t, _ = estimate_mesh_from_model(model, epoch, subject_idx, 64, 0.4)
                    show_mesh(v, t)
            except Exception as err:
                print("Data not good enough yet", repr(err))
tok = time.time()
print(tok - tik)

In [None]:
# Reconstruct and display training object 0. The epoch argument is fixed at 1000 and
# forwarded to the model (presumably selecting the final positional-encoding stage).
model.eval()
subject_idx = 0
v, t, _ = estimate_mesh_from_model(model, 1000, subject_idx, 64, 0.4)
show_mesh(v, t)

# Save Model

In [None]:
# Persist the trained weights. makedirs(exist_ok=True) also creates a missing parent
# directory (os.mkdir raised FileNotFoundError when ../outputs did not exist). It is a
# no-op when the directory already exists.
os.makedirs(output_name, exist_ok=True)
torch.save(model.state_dict(), output_name + "/model")

# Load Model

In [None]:
# Restore the weights written by the save cell. map_location puts every tensor on cuda:0
# (the device the model was moved to), even if the checkpoint was written on another
# device.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model.load_state_dict(torch.load(output_name + "/model", map_location=torch.device('cuda:0')))
model.eval()

# Validation Visualizations

In [None]:
# Qualitative check of the trained model with the project's `contour_plot` helper.
contour_plot(model,1000) # Change to 0 or any other intermediate value (low-frequency positional encodings) to see a smoother output

In [None]:
# Refined poses (affine=False) for the 64 training objects drive the dataset's sampling.
# Object 0's pose becomes RT_base, which later cells use to map grippers and points.
refined_poses = model.get_refine_poses(torch.arange(0, 64).cuda(), affine=False)
RTs = refined_poses.cpu().detach().numpy()
sdf_dataset.set_sampling_RT(RTs[:, :3, :4])

RT_base = np.eye(4)
RT_base[:3, :] = RTs[0]  # 3x4 block into the top rows of a 4x4 homogeneous matrix

In [None]:
### Interpolation between different embeddings
## Writes one HTML file per target object into the $output_name directory.
for obj_id in [3, 4, 5]:
    # Bug fix: the body previously read `obj_ind`, a leftover global from an earlier
    # visualisation loop. Every iteration therefore used the same object index and
    # overwrote the same file.
    interp_meshes = estimate_mesh_traj_from_model(model, 2000, [0, obj_id], 50, 64)
    draw_mesh_list(interp_meshes, output_name + "/interp_" + str(obj_id + 1))

In [None]:
# Multi Object Grasp Visualization
# Left column: each object (20-29) with a gripper and frame at the origin.
# Right column (x + 3): points sampled with the object's sampling RT, with the gripper
# and frame mapped through inverse(RT) followed by RT_base.
part_points_list = list()
mesh_list = list()

# Object-independent gripper geometry, built once instead of twice per object.
gripper_mesh = get_gripper_simple_mesh(np.zeros((3,)), np.eye(3), 0.10, 0.10, score=1)
gripper_mesh.apply_scale(4)

for obj_ind in range(20, 30):
    color = np.random.rand(3) * 255
    z_offset = -1.5 * obj_ind
    # The scale was previously computed three times per object with identical arguments.
    _, scale = compute_unit_sphere_transform(sdf_dataset.meshes[obj_ind])
    src_mesh = sdf_dataset.meshes_transformed[obj_ind]

    # --- left column --------------------------------------------------------------
    obj_mesh = trimesh.Trimesh(vertices=src_mesh.vertices,
                               face_colors=color, faces=src_mesh.faces)
    obj_mesh.apply_scale(scale)
    obj_mesh.apply_translation([0, 0, z_offset])

    pc_inds, sdf_pos_inds, sdf_neg_inds = sdf_dataset.sample_within_sphere(obj_ind, radius=0.8)
    points = np.copy(sdf_dataset.pointclouds[obj_ind][pc_inds])
    points[:, 2] += z_offset
    part_points_list.append(points)

    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices.copy(),
                                             faces=gripper_mesh.faces, face_colors=[175, 175, 175])
    gripper_mesh_for_grasp.apply_translation([0, 0, z_offset])
    zero_frame = get_coordinate_frame_mesh(0.2)
    zero_frame.apply_translation([0, 0, z_offset])
    mesh_list.extend([obj_mesh, gripper_mesh_for_grasp, zero_frame])

    # --- right column: with the object's sampling RT --------------------------------
    RT = np.eye(4)
    RT[:3, :] = sdf_dataset.sampling_RTs[obj_ind]
    RT_inv = find_inverse_RT_4x4(RT)  # computed once, reused for gripper and frame

    obj_mesh = trimesh.Trimesh(vertices=src_mesh.vertices,
                               face_colors=color, faces=src_mesh.faces)
    obj_mesh.apply_scale(scale)
    obj_mesh.apply_translation([3, 0, z_offset])

    pc_inds, sdf_pos_inds, sdf_neg_inds = sdf_dataset.sample_within_sphere(obj_ind, RT[:3, :], radius=0.8)
    points = np.copy(sdf_dataset.pointclouds[obj_ind][pc_inds])
    points[:, 2] += z_offset
    points[:, 0] += 3
    part_points_list.append(points)

    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices.copy(),
                                             faces=gripper_mesh.faces, face_colors=[175, 175, 175])
    gripper_mesh_for_grasp.apply_transform(RT_inv)
    gripper_mesh_for_grasp.apply_transform(RT_base)
    gripper_mesh_for_grasp.apply_translation([3, 0, z_offset])

    zero_frame = get_coordinate_frame_mesh(0.2)
    zero_frame.apply_transform(RT_inv)
    zero_frame.apply_transform(RT_base)
    zero_frame.apply_translation([3, 0, z_offset])
    mesh_list.extend([obj_mesh, gripper_mesh_for_grasp, zero_frame])

scene = trimesh_show(part_points_list)
scene.add_geometry(mesh_list)
window = windowed.SceneViewer(scene)

In [None]:
# Multi Object Grasp + Reconstruction Visualization
# Ground-truth meshes of objects 20-29 stacked along +z (step 2). They are overlaid with
# points reconstructed by the model and a gripper. Both are mapped through
# inverse(sampling RT) followed by RT_base.

mesh_list = list()
part_points_list = list()

# Build the gripper here instead of relying on `gripper_mesh` leaked from another cell's loop.
gripper_mesh = get_gripper_simple_mesh(np.zeros((3,)), np.eye(3), 0.10, 0.10, score=1)
gripper_mesh.apply_scale(4)

for slot, index in enumerate(range(20, 30)):
    shape_Tr = [0, 0, slot * 2]
    fc_color = np.random.rand(3) * 255
    src_mesh = sdf_dataset.meshes_transformed[index]

    _, scale = compute_unit_sphere_transform(src_mesh)
    mesh3 = trimesh.Trimesh(vertices=src_mesh.vertices * scale,
                            faces=src_mesh.faces, face_colors=fc_color)
    mesh3.apply_translation(shape_Tr)

    RT = np.eye(4)
    RT[:3, :] = sdf_dataset.sampling_RTs[index]
    RT_inv = find_inverse_RT_4x4(RT)

    # Reconstructed points: undo the sampling RT, then express them in the RT_base frame.
    points = estimate_points_from_model(model, 2000, index, 64, 0.5)
    points = ((RT_inv @ to_hom_np(points).T).T)[:, :3]
    points = ((RT_base @ to_hom_np(points).T).T)[:, :3]
    points[:, :3] += shape_Tr
    part_points_list.append(points)

    mesh_list.append(mesh3)

    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices.copy(),
                                             faces=gripper_mesh.faces, face_colors=[175, 175, 175])
    gripper_mesh_for_grasp.apply_transform(RT_inv)
    gripper_mesh_for_grasp.apply_transform(RT_base)
    gripper_mesh_for_grasp.apply_translation(shape_Tr)
    # NOTE(review): `colors` does not look like a documented Trimesh attribute, and the
    # face colours are already set above. Confirm this line has any effect.
    gripper_mesh_for_grasp.colors = np.tile(np.array([175, 175, 175]), (gripper_mesh_for_grasp.vertices.shape[0], 1))
    mesh_list.append(gripper_mesh_for_grasp)

scene = trimesh_show_green(part_points_list)
scene.add_geometry(mesh_list)
window = windowed.SceneViewer(scene)

In [None]:
def estimate_mesh_from_model(model, epoch, subject_idx, resolution = 64, scale = 1, level = 0.01):
    """Extract a triangle mesh from the model's SDF for one latent code (or their mean).

    The SDF is evaluated on a ``resolution**3`` grid covering ``[-scale, scale]`` on every
    axis, in chunks of ``64**3`` points. The ``level`` iso-surface is then extracted with
    ``measure.marching_cubes``.

    NOTE(review): cells earlier in the notebook call a function of this name before this
    cell runs, presumably a star-imported project version. This definition shadows it
    for all later calls.

    Args:
        model: network exposing ``get_latent_code(idx)`` and
            ``inference(points, embedding, epoch)['model_out']``. It is put in eval mode.
        epoch: forwarded unchanged to ``model.inference``.
        subject_idx: latent-code index to decode, or ``-1`` for the mean of the codes
            of indices 0..63.
        resolution: number of grid points per axis.
        scale: half-width of the sampled cube.
        level: iso-value passed to marching cubes (default 0.01, the previous constant).

    Returns:
        ``(vertices, faces, normals)`` from marching cubes, with vertices mapped back to
        the ``[-scale, scale]`` sampling coordinates.

    Raises:
        Whatever ``measure.marching_cubes`` raises, e.g. when ``level`` lies outside the
        range of the predicted SDF values.
    """
    with torch.no_grad():
        N = resolution
        max_batch = 64 ** 3
        model.eval()

        # voxel_origin is the minimum (bottom, left, down) corner of the unit grid, not its centre.
        voxel_origin = [-1, -1, -1]
        voxel_size = 2.0 / (N - 1)

        overall_index = torch.arange(0, N ** 3, 1, dtype=torch.long)
        samples = torch.zeros(N ** 3, 4)

        # Integer grid index per flattened sample; the last axis varies fastest.
        # Floor division is required: on current PyTorch, `/` on integer tensors is true
        # division, which produced fractional indices (a scrambled grid) here.
        samples[:, 2] = overall_index % N
        samples[:, 1] = (overall_index // N) % N
        samples[:, 0] = ((overall_index // N) // N) % N

        # Grid index -> coordinate in [-scale, scale].
        for axis in range(3):
            samples[:, axis] = (samples[:, axis] * scale * voxel_size) + scale * voxel_origin[axis]

        num_samples = N ** 3

        if subject_idx == -1:
            # Index tensor of shape (1, 64); the codes are then averaged over dim 1.
            # NOTE(review): assumes get_latent_code returns (1, n_codes, latent_dim).
            subject_idx = torch.Tensor([range(64)]).squeeze().long().cuda()[None, ...]
            embedding = model.get_latent_code(subject_idx)
            embedding = embedding.mean(dim=1)
        else:
            # Index tensor of shape (1,).
            subject_idx = torch.Tensor([subject_idx]).squeeze().long().cuda()[None, ...]
            embedding = model.get_latent_code(subject_idx)

        # Evaluate the SDF chunk by chunk to bound GPU memory use.
        for head in range(0, num_samples, max_batch):
            tail = min(head + max_batch, num_samples)
            sample_subset = samples[head:tail, 0:3].cuda()[None, ...]
            samples[head:tail, 3] = (
                model.inference(sample_subset, embedding, epoch)['model_out']
                .squeeze()
                .detach()
                .cpu()
            )

        sdf_values = samples[:, 3].reshape(N, N, N)
        # The spacing must equal the sampling step (voxel_size = 2/(N-1)). The previous
        # 2/N shrank the mesh by (N-1)/N and shifted it off the sampled cube.
        v, t, n, _ = measure.marching_cubes(sdf_values.numpy(), level, step_size=1,
                                            spacing=[voxel_size] * 3)
        # Vertices lie in [0, 2]; shift to [-1, 1] and rescale to [-scale, scale].
        return scale * (v - 1), t, n

In [None]:
# Multi Object Grasp + Reconstruction Visualization THREE-VIEW
# Reconstruction - Before Alignment - After Alignment
# Columns per object 0-7 (rows step 2.5 along z):
#   x=-2.5: mesh reconstructed by the model
#   x=0:    ground-truth mesh with a gripper at the origin
#   x=3:    ground-truth mesh with a gripper mapped through inverse(sampling RT)
mesh_list = list()

# Object-independent gripper geometry, built once instead of per object.
# (The unused sphere primitives that used to be created here were removed.)
gripper_mesh = get_gripper_simple_mesh(np.zeros((3,)), np.eye(3), 0.10, 0.10, score=1)
gripper_mesh.apply_scale(4)

for index in range(0, 8):
    z_offset = index * 2.5
    fc_color = np.random.rand(3) * 255
    v, t, _ = estimate_mesh_from_model(model, 1000, index, 64, scale=0.5)
    mesh = trimesh.Trimesh(vertices=v, faces=t, face_colors=fc_color)
    mesh.apply_translation([-2.5, 0, z_offset])

    src_mesh = sdf_dataset.meshes_transformed[index]
    _, scale = compute_unit_sphere_transform(src_mesh)

    RT = np.eye(4)
    RT[:3, :] = sdf_dataset.sampling_RTs[index]
    RT_inv = find_inverse_RT_4x4(RT)

    mesh2 = trimesh.Trimesh(vertices=src_mesh.vertices * scale,
                            face_colors=fc_color, faces=src_mesh.faces)
    mesh2.apply_translation([0, 0, z_offset])
    mesh3 = trimesh.Trimesh(vertices=src_mesh.vertices * scale,
                            faces=src_mesh.faces, face_colors=fc_color)
    mesh3.apply_translation([3, 0, z_offset])
    mesh_list.extend([mesh, mesh2, mesh3])

    # Gripper at the origin ("before alignment" column).
    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices.copy(),
                                             faces=gripper_mesh.faces, face_colors=[175, 175, 175])
    gripper_mesh_for_grasp.apply_translation([0, 0, z_offset])
    gripper_mesh_for_grasp.colors = np.tile(np.array([175, 175, 175]), (gripper_mesh_for_grasp.vertices.shape[0], 1))
    mesh_list.append(gripper_mesh_for_grasp)

    # Gripper mapped through inverse(sampling RT) ("after alignment" column).
    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices.copy(),
                                             faces=gripper_mesh.faces, face_colors=[175, 175, 175])
    gripper_mesh_for_grasp.apply_transform(RT_inv)
    gripper_mesh_for_grasp.apply_translation([3, 0, z_offset])
    gripper_mesh_for_grasp.colors = np.tile(np.array([175, 175, 175]), (gripper_mesh_for_grasp.vertices.shape[0], 1))
    mesh_list.append(gripper_mesh_for_grasp)

scene = trimesh.Scene()
scene.add_geometry(mesh_list)
window = windowed.SceneViewer(scene)

In [None]:
# Transfer model wrapping the trained `model`; 16 presumably matches the 16 validation
# objects (see sdf_dataset_val).
model_transfer = MyNetTransfer(16, model, latent_dim=128)
model_transfer.to(device=torch.device('cuda:0'))

# Only these three transfer submodules are given to the optimizer, all at the same lr.
transfer_lr = 1e-3
transfer_modules = (model_transfer.latent_codes,
                    model_transfer.se3_refine,
                    model_transfer.affine_tr)
optim = torch.optim.Adam(
    [{'params': module.parameters(), 'lr': transfer_lr} for module in transfer_modules],
    lr=transfer_lr)


In [None]:
# Fit the transfer model on the validation objects. The epoch argument passed to the
# model is fixed at 1000 for every step.
total_steps = 0
epochs = 100
with tqdm(total=len(val_dataloader) * epochs) as pbar:
    train_losses = []
    for epoch in range(epochs):
        # The poses are only copied to numpy, so no autograd graph is needed.
        with torch.no_grad():
            RTs, _ = model_transfer.get_refine_poses(torch.arange(0, 16).cuda())
        RTs = RTs.cpu().detach().numpy()
        sdf_dataset_val.set_sampling_RT(RTs[:, :3, :4])
        model_transfer.train()
        for step, (model_input, gt) in enumerate(val_dataloader):
            start_time = time.time()
            model_input = {key: value.cuda() for key, value in model_input.items()}
            gt = {key: value.cuda() for key, value in gt.items()}

            losses = model_transfer(model_input, gt, 1000)

            train_loss = 0.
            for loss_name, loss in losses.items():
                train_loss += loss.mean()

            loss_value = train_loss.item()
            train_losses.append(loss_value)
            optim.zero_grad()
            train_loss.backward()
            optim.step()
            pbar.update(1)
            # Pass a Python float, as the main training loop does. A tensor would be
            # shown via its full repr (device, grad_fn) instead of a number.
            pbar.set_postfix(loss=loss_value, time=time.time() - start_time, epoch=epoch)
            total_steps += 1

In [None]:
# Transfer results on the 16 validation objects, two per column (x step 4.5),
# alternating between z=0 and z=2.5. Each shows:
#   - the model_transfer reconstruction (offset by x-2, y-0.2),
#   - the ground-truth mesh,
#   - a gripper at the origin (RGBA colour, alpha 75),
#   - a gripper mapped through inverse(sampling RT).
mesh_list = list()

# Build the gripper here instead of relying on `gripper_mesh` leaked from earlier cells.
# (The unused sphere primitives that used to be created here were removed.)
gripper_mesh = get_gripper_simple_mesh(np.zeros((3,)), np.eye(3), 0.10, 0.10, score=1)
gripper_mesh.apply_scale(4)

for index in range(0, 16):
    offset = [index // 2 * 4.5, 0, index % 2 * 2.5]
    fc_color = np.random.rand(3) * 255
    v, t, _ = estimate_mesh_from_model(model_transfer, 1000, index, 64, 0.5)
    mesh = trimesh.Trimesh(vertices=v, faces=t, face_colors=fc_color)
    mesh.apply_translation([-2 + offset[0], -0.2, offset[2]])

    src_mesh = sdf_dataset_val.meshes_transformed[index]
    _, scale = compute_unit_sphere_transform(src_mesh)
    mesh2 = trimesh.Trimesh(vertices=src_mesh.vertices * scale,
                            face_colors=fc_color, faces=src_mesh.faces)
    mesh2.apply_translation(offset)

    RT = np.eye(4)
    RT[:3, :] = sdf_dataset_val.sampling_RTs[index]
    RT_inv = find_inverse_RT_4x4(RT)
    mesh_list.extend([mesh, mesh2])

    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices.copy(),
                                             faces=gripper_mesh.faces, face_colors=[175, 175, 175, 75])
    gripper_mesh_for_grasp.apply_translation(offset)
    gripper_mesh_for_grasp.colors = np.tile(np.array([175, 175, 175]), (gripper_mesh_for_grasp.vertices.shape[0], 1))
    mesh_list.append(gripper_mesh_for_grasp)

    gripper_mesh_for_grasp = trimesh.Trimesh(vertices=gripper_mesh.vertices.copy(),
                                             faces=gripper_mesh.faces, face_colors=[175, 175, 175])
    gripper_mesh_for_grasp.apply_transform(RT_inv)
    gripper_mesh_for_grasp.apply_translation(offset)
    gripper_mesh_for_grasp.colors = np.tile(np.array([175, 175, 175]), (gripper_mesh_for_grasp.vertices.shape[0], 1))
    mesh_list.append(gripper_mesh_for_grasp)

scene = trimesh.Scene()
scene.add_geometry(mesh_list)
window = windowed.SceneViewer(scene)