In [None]:
import sys
import os
# Make the grasp-transfer project modules (dataio, networks, ...) importable by
# prepending the directory named in GRASP_TRANSFER_SOURCE_DIR to sys.path.
grasp_transfer_path = os.getenv("GRASP_TRANSFER_SOURCE_DIR")
if grasp_transfer_path is None:
    # Previously None was inserted into sys.path; instead rely on whatever is
    # already importable (e.g. the notebook's working directory) and say so.
    import warnings
    warnings.warn("GRASP_TRANSFER_SOURCE_DIR is not set; project modules must already be importable.")
elif grasp_transfer_path not in sys.path:
    # Membership guard keeps the cell idempotent when it is re-run.
    sys.path.insert(0, grasp_transfer_path)

In [None]:
from dataio import *
from networks import *
from torch.utils.data import DataLoader
from visualize import *
from trimesh.viewer import windowed
import time
from sdf_estimator import *
from utility import *
from tqdm import tqdm
import torch
import numpy as np
import random
import pickle
import torch
from skimage import measure
import numpy as np
import plotly.graph_objects as go
import sys, os
import plotly.offline    
import numpy as np
from scipy.spatial import cKDTree as KDTree
import trimesh
from scipy.spatial.transform import Rotation as R

def compute_trimesh_chamfer(gt_mesh, gen_mesh, num_mesh_samples=30000):
    """Symmetric Chamfer distance between two meshes (sum of both one-sided terms).

    Each mesh surface is sampled with ``num_mesh_samples`` points; each one-sided
    term is the mean squared distance from one sample set to its nearest
    neighbour in the other set.

    Args:
        gt_mesh: reference mesh (a ``trimesh.Trimesh`` in this notebook).
        gen_mesh: mesh compared against ``gt_mesh`` (a ``trimesh.Trimesh``).
        num_mesh_samples: number of surface points sampled from each mesh.

    Returns:
        Mean squared NN distance gt->gen plus mean squared NN distance gen->gt.

    NOTE(review): the result depends on surface sampling, which is presumably
    random (trimesh), so repeated calls can differ slightly unless RNGs are seeded.

    Reference: DeepSDF, https://github.com/facebookresearch/DeepSDF/blob/main/deep_sdf/metrics/chamfer.py
    """
    # Sample gt first, then gen, so the RNG consumption order is unchanged.
    gt_points = trimesh.sample.sample_surface(gt_mesh, num_mesh_samples)[0]
    gen_points = trimesh.sample.sample_surface(gen_mesh, num_mesh_samples)[0]

    # gt -> gen: each GT sample against its nearest generated sample.
    gt_to_gen_distances, _ = KDTree(gen_points).query(gt_points)
    gt_to_gen_chamfer = np.mean(np.square(gt_to_gen_distances))

    # gen -> gt: each generated sample against its nearest GT sample.
    gen_to_gt_distances, _ = KDTree(gt_points).query(gen_points)
    gen_to_gt_chamfer = np.mean(np.square(gen_to_gt_distances))

    return gt_to_gen_chamfer + gen_to_gt_chamfer


In [None]:
import glob  # explicit: glob was previously only reachable through the star imports

# Train one model per (category, seed), save it, then score Chamfer distances.
num_of_scenes = 64
epochs = 2000
device = torch.device('cuda:0')
chamfer_file_names = ('chamfer_distances_reconstructed',
                      'chamfer_distances_perturbed',
                      'chamfer_distances_perturbed_transformed')


def _train_model(model, train_dataloader, optim, epochs, device):
    """Train ``model`` in place for ``epochs`` epochs and return the per-step losses.

    ``model(model_input, gt, epoch)`` returns a dict of loss tensors; each one is
    averaged and the averages are summed into the loss that is back-propagated.
    """
    train_losses = []
    with tqdm(total=len(train_dataloader) * epochs) as pbar:
        for epoch in range(epochs):
            model.train()
            for model_input, gt in train_dataloader:
                start_time = time.time()
                model_input = {key: value.to(device) for key, value in model_input.items()}
                gt = {key: value.to(device) for key, value in gt.items()}

                losses = model(model_input, gt, epoch)
                train_loss = sum(loss.mean() for loss in losses.values())

                train_losses.append(train_loss.item())
                optim.zero_grad()
                train_loss.backward()
                optim.step()
                pbar.update(1)
                pbar.set_postfix(loss=train_loss.item(), time=time.time() - start_time, epoch=epoch)
    return train_losses


def _evaluate_chamfer(model, sdf_dataset, num_of_scenes, device):
    """Chamfer distances of scenes 1..num_of_scenes-1 against their GT meshes.

    Returns three arrays of shape (num_of_scenes - 1,); entry ``i - 1`` is scene ``i``:
      * reconstructed: mesh extracted from the model, mapped with RT_base;
      * perturbed: the transformed dataset mesh as given;
      * perturbed_transformed: the transformed mesh after the learned pose RTs[i].
    Scene 0 is not scored; the inverse of its refined pose is RT_base.
    """
    scene_ids = torch.arange(0, num_of_scenes, device=device)
    RTs = model.get_refine_poses(scene_ids)[0].cpu().detach().numpy()
    RT_base = find_inverse_RT_3x4(
        model.get_refine_poses(scene_ids, affine=False).cpu().detach().numpy()[0])

    num_scored = num_of_scenes - 1
    reconstructed = np.zeros((num_scored,))
    perturbed = np.zeros((num_scored,))
    perturbed_transformed = np.zeros((num_scored,))
    # Iterate up to num_of_scenes - 1 inclusive so every slot is filled; stopping one
    # earlier left the last entry at 0 and biased the saved means downwards.
    for index in range(1, num_of_scenes):
        color = np.random.rand(3) * 255
        v, t, _ = estimate_mesh_from_model(model, 2000, index, 128)

        gt_src = sdf_dataset.meshes[index]
        perturbed_src = sdf_dataset.meshes_transformed[index]

        gt_mesh = trimesh.Trimesh(vertices=gt_src.vertices, faces=gt_src.faces, face_colors=color)
        translations, scale = compute_unit_sphere_transform(gt_mesh)

        perturbed_mesh = trimesh.Trimesh(vertices=perturbed_src.vertices,
                                         faces=perturbed_src.faces, face_colors=color)
        # Presumably undoes the unit-sphere normalisation of the extracted mesh,
        # then applies the inverse of scene 0's refined pose.
        reconstructed_mesh = trimesh.Trimesh(vertices=v / scale - translations, faces=t, face_colors=color)
        reconstructed_mesh.apply_transform(RT_base)

        aligned_vertices = (RTs[index] @ to_hom_np(perturbed_src.vertices).T).T
        aligned_mesh = trimesh.Trimesh(vertices=aligned_vertices,
                                       faces=perturbed_src.faces, face_colors=color)

        reconstructed[index - 1] = compute_trimesh_chamfer(reconstructed_mesh, gt_mesh)
        perturbed[index - 1] = compute_trimesh_chamfer(perturbed_mesh, gt_mesh)
        perturbed_transformed[index - 1] = compute_trimesh_chamfer(aligned_mesh, gt_mesh)
    return reconstructed, perturbed, perturbed_transformed


for obj_name in ['mugs', 'chairs', 'planes', 'cars']:
    for seed in [10, 20, 30, 40, 50]:
        # Seed every RNG used below so object choice, poses and training repeat.
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        # glob order is filesystem-dependent; sort so the seeded choice is reproducible.
        obj_paths = sorted(glob.glob("../datasets/" + obj_name + "/*"))
        obj_list = list(np.random.choice(obj_paths, num_of_scenes, replace=False))
        sdf_dataset = Shape_SDF_Dataset_Noisy(obj_list, num_of_scenes, 8000)
        dataloader = DataLoader(sdf_dataset, shuffle=True, batch_size=4,
                                num_workers=0, drop_last=True)

        output_path = '../outputs/train_' + obj_name + '_' + str(seed)
        # makedirs also creates ../outputs, but still refuses to overwrite an existing run.
        os.makedirs(output_path)
        with open(output_path + '/obj_list.pickle', 'wb') as handle:
            pickle.dump(obj_list, handle, protocol=pickle.HIGHEST_PROTOCOL)
        np.save(output_path + '/random_poses', sdf_dataset.random_poses)

        model = MyNet(num_of_scenes, affine=False)
        model.to(device=device)
        optim = torch.optim.Adam([
                        {'params': model.sdf_net.parameters()},
                        {'params': model.hyper_net.parameters()},
                        {'params': model.latent_codes.parameters(), 'lr': 1e-3},
                        {'params': model.se3_refine.parameters(), 'lr': 1e-3},
                    ],
            lr=1e-4)
        _train_model(model, dataloader, optim, epochs, device)
        torch.save(model.state_dict(), output_path + '/model')

        try:
            model.eval()  # same mode the reload cell uses for mesh extraction
            results = _evaluate_chamfer(model, sdf_dataset, num_of_scenes, device)
            for file_name, values in zip(chamfer_file_names, results):
                np.save(output_path + '/' + file_name, values)
        except Exception as exc:
            # Best effort: report and continue with the next run instead of aborting the sweep.
            print("There are some errors in " + obj_name + '_' + str(seed) + ": " + repr(exc))
        # Free this run's model, optimiser and data before the next run allocates GPU memory.
        del model, optim, dataloader, sdf_dataset
        torch.cuda.empty_cache()

In [None]:
# For each category, collect one value per seed: the mean Chamfer distance over
# scenes, read back from the .npy files written by the training cell.
cdist_reconstructed_dict = dict()
cdist_perturbed_dict = dict()
cdist_perturbed_transformed_dict = dict()
# (file suffix, destination dict) pairs; the undefined
# `cdist_reconstructed_no_smoothing_dict` reference (a NameError) was removed.
cdist_sources = (
    ('/chamfer_distances_reconstructed.npy', cdist_reconstructed_dict),
    ('/chamfer_distances_perturbed.npy', cdist_perturbed_dict),
    ('/chamfer_distances_perturbed_transformed.npy', cdist_perturbed_transformed_dict),
)
for obj_name in ['mugs', 'chairs', 'planes', 'cars']:
    for _, target in cdist_sources:
        target[obj_name] = list()
    for seed in [10, 20, 30, 40, 50]:
        output_path = '../outputs/train_' + obj_name + '_' + str(seed)
        for file_suffix, target in cdist_sources:
            target[obj_name].append(np.mean(np.load(output_path + file_suffix)))


In [None]:
from tabulate import tabulate

EMIT_LATEX = False  # set True to also print LaTeX table rows ("name & ours & pre-alignment")


def _mean_pm_std(values, pm=r' \pm ', decimals=4):
    """Format ``values`` as '<mean><pm><std>' with a fixed number of decimals.

    Raw strings avoid the invalid '\\p' escape that the previous literals used.
    """
    return f"{np.mean(values):.{decimals}f}{pm}{np.std(values):.{decimals}f}"


rows = list()
for row_name in ['mugs', 'chairs', 'planes', 'cars']:
    rows.append([row_name,
                 _mean_pm_std(cdist_reconstructed_dict[row_name]),
                 _mean_pm_std(cdist_perturbed_dict[row_name])])
col_names = ['Object Names', 'Our Method', 'Pre-Alignment']

# display table
print(tabulate(rows, headers=col_names, tablefmt="fancy_grid"))

# LaTeX rows for the paper table, generated from the same statistics.
if EMIT_LATEX:
    for row_name in ['mugs', 'chairs', 'planes', 'cars']:
        print(row_name + ' & '
              + _mean_pm_std(cdist_reconstructed_dict[row_name], pm=r' $\pm$ ')
              + ' & '
              + _mean_pm_std(cdist_perturbed_dict[row_name], pm=r' $\pm$ '))


In [None]:
# Reload one trained run (category + seed) for visualisation.
obj_name = 'mugs'
seed = 10
num_of_scenes = 64
output_path = '../outputs/train_' + obj_name + '_' + str(seed)
device = torch.device('cuda:0')

# NOTE: pickle.load can execute code stored in the file; only load run
# directories produced by the training cell.
with open(output_path + '/obj_list.pickle', 'rb') as handle:
    obj_list = pickle.load(handle)
# Reuse the poses saved at training time so the dataset matches the trained run.
random_poses = np.load(output_path + '/random_poses.npy')
sdf_dataset = Shape_SDF_Dataset_Noisy(obj_list, num_of_scenes, 8000, random_poses=random_poses)
dataloader = DataLoader(sdf_dataset, shuffle=True, batch_size=4, pin_memory=True, num_workers=0, drop_last=True)
model = MyNet(num_of_scenes, affine=False)
model.to(device=device)
# map_location lets the checkpoint load onto the chosen device regardless of where it was saved.
model.load_state_dict(torch.load(output_path + '/model', map_location=device))
model.eval()

# Inverse of scene 0's refined (non-affine) pose; used to map reconstructions back.
scene_ids = torch.arange(0, num_of_scenes, device=device)
RT_base = find_inverse_RT_3x4(model.get_refine_poses(scene_ids, affine=False).cpu().detach().numpy()[0])
translation_base, scale_base = compute_unit_sphere_transform(sdf_dataset.meshes[0])

In [None]:
# Camera rotation and base RGB colour per category, as used for the figures.
# Keyed by `obj_name` so the view always matches the category that was loaded
# (previously the PLANES setting was hard-coded even when viewing mugs).
_chair_plane_rotation = (R.from_rotvec(np.array([0, -5 * np.pi / 6, 0])).as_matrix()
                         @ R.from_rotvec(np.array([-np.pi / 9, 0, -np.pi / 9])).as_matrix())
VIZ_SETTINGS = {
    'mugs': (R.from_rotvec(np.array([1 * np.pi / 6, 0, 0])).as_matrix(), [125, 50, 200]),
    'cars': (R.from_rotvec(np.array([0, -7 * np.pi / 9, 0])).as_matrix()
             @ R.from_rotvec(np.array([-np.pi / 6, 0, -np.pi / 6])).as_matrix(), [50, 50, 250]),
    'chairs': (_chair_plane_rotation, [50, 200, 125]),
    'planes': (_chair_plane_rotation, [175, 100, 50]),
}

viz_rotation, color = VIZ_SETTINGS[obj_name]
R_viz_fix = np.eye(4)
R_viz_fix[:3, :3] = viz_rotation
translucent_color = [color[0], color[1], color[2], 125]

for index in range(4):
    print(index)
    gt_src = sdf_dataset.meshes[index]
    transformed_src = sdf_dataset.meshes_transformed[index]
    translation, scale = compute_unit_sphere_transform(gt_src)

    # Left: semi-transparent GT mesh overlaid on the transformed (perturbed) mesh.
    mesh = trimesh.Trimesh(vertices=gt_src.vertices * scale,
                           faces=gt_src.faces,
                           face_colors=translucent_color)
    mesh2 = trimesh.Trimesh(vertices=transformed_src.vertices * scale,
                            faces=transformed_src.faces,
                            face_colors=color)
    mesh.apply_transform(R_viz_fix)
    mesh2.apply_transform(R_viz_fix)

    # Right (shifted +2.5 in x): model reconstruction mapped with RT_base, over the GT.
    v, t, _ = estimate_mesh_from_model(model, 1000, index, 128)
    mesh3 = trimesh.Trimesh(vertices=v, faces=t, face_colors=color)
    mesh3.apply_translation(-scale * translation)
    mesh3.apply_transform(RT_base)
    mesh3.apply_transform(R_viz_fix)

    mesh4 = trimesh.Trimesh(vertices=gt_src.vertices * scale,
                            faces=gt_src.faces,
                            face_colors=translucent_color)
    mesh4.apply_transform(R_viz_fix)

    mesh3.apply_translation([2.5, 0, 0])
    mesh4.apply_translation([2.5, 0, 0])

    # `mesh_list` of the last index is re-rendered by the pyrender cell below.
    mesh_list = [mesh, mesh2, mesh3, mesh4]

    scene = trimesh.Scene()
    scene.add_geometry(mesh_list)
    window = windowed.SceneViewer(scene, smooth=False)

In [None]:
# For better-looking meshes: re-render the last `mesh_list` from the cell above with pyrender.
import pyrender

mesh_list_pyrender = list()
scene = pyrender.Scene()
for source_mesh in mesh_list:
    mesh_pyrender = pyrender.Mesh.from_trimesh(source_mesh, smooth=False)
    mesh_list_pyrender.append(mesh_pyrender)
    scene.add(mesh_pyrender)

cam = pyrender.PerspectiveCamera(yfov=(np.pi / 3.0))
# Camera placed 2 units along +z with identity rotation.
cam_pose = np.array([
    [1,  0,  0, 0],
    [0,  1,  0, 0.0],
    [0,  0,  1, 2],
    [0,  0,  0, 1.0]
])
cam_node = scene.add(cam, pose=cam_pose)

# The former `central_node=mesh` argument passed a trimesh object leaked from the loop
# variable; NOTE(review): pyrender's Viewer documents no `central_node` option, so it was dropped.
pyrender.Viewer(scene, use_raymond_lighting=True)