In [1]:
import torch 
import numpy as np
import open3d as o3d
from neurals.network import LargeScoreFunction

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


pybullet build time: May 20 2022 19:44:17


In [2]:
def project_to_pcd(tip_poses, kd_tree, points):
    """Snap each tip position to its nearest neighbour in a point cloud.

    Args:
        tip_poses: array-like of query positions, one (x, y, z) row per tip.
        kd_tree: KD-tree built over ``points`` (e.g. o3d.geometry.KDTreeFlann);
            only its ``search_knn_vector_3d(query, k)`` method is used.
        points: array of the cloud's points, indexed consistently with kd_tree.

    Returns:
        Array with the same shape as ``tip_poses`` where row i is the cloud
        point nearest to ``tip_poses[i]``. The dtype follows ``points`` so the
        projected coordinates are never truncated (e.g. when ``tip_poses`` is
        an integer array).
    """
    points = np.asarray(points)
    queries = np.asarray(tip_poses, dtype=np.float64)
    new_poses = np.empty(queries.shape, dtype=points.dtype)
    for i, query in enumerate(queries):
        # The second element of the search result holds the neighbour indices;
        # with k=1 there is exactly one, the nearest point.
        idx = kd_tree.search_knn_vector_3d(query, 1)[1]
        new_poses[i] = points[idx[0]]
    return new_poses
        
def project_vis_grasp(grasp, point_cloud, num_points=1024, sphere_radius=0.01, rng=None):
    """Snap grasp tips onto a point cloud and build Open3D geometry to show them.

    Args:
        grasp: array-like of tip positions, one (x, y, z) row per tip.
        point_cloud: o3d.geometry.PointCloud to project onto and to display.
        num_points: number of cloud points to sample into the returned
            down-sampled array (default 1024).
        sphere_radius: radius of the marker sphere drawn at each projected tip.
        rng: optional NumPy Generator/RandomState used for the down-sampling,
            for reproducible runs; defaults to NumPy's global random state.

    Returns:
        (projected_grasp, points_down, geometries):
          - projected_grasp: nearest cloud point for each tip (see project_to_pcd).
          - points_down: ``num_points`` rows sampled from the cloud's points.
          - geometries: the tip spheres followed by the cloud itself, ready to
            pass to o3d.visualization.draw_geometries.

    Raises:
        ValueError: if the point cloud has no points.
    """
    points = np.asarray(point_cloud.points)
    if len(points) == 0:
        raise ValueError("point_cloud has no points")

    kd_tree = o3d.geometry.KDTreeFlann(point_cloud)
    projected_grasp = project_to_pcd(grasp, kd_tree, points)

    tip_vis = []
    for point in projected_grasp:
        sphere = o3d.geometry.TriangleMesh.create_sphere(radius=sphere_radius)
        sphere.translate(point)
        # Without normals the viewer draws the spheres as flat, unshaded blobs.
        sphere.compute_vertex_normals()
        tip_vis.append(sphere)

    # Down-sample to a fixed size. Sample without replacement when the cloud is
    # large enough; otherwise fall back to sampling with replacement so small
    # clouds still yield num_points rows instead of raising ValueError.
    sampler = np.random if rng is None else rng
    replace = len(points) < num_points
    idx = sampler.choice(len(points), num_points, replace=replace)
    points_down = points[idx]
    return projected_grasp, points_down, tip_vis + [point_cloud]

In [3]:
# Build the score network and load the pretrained weights.
# NOTE(review): the constructor argument 2 presumably is the number of fingertips
# (fc1 in_features=134 looks like 128 point features + 2 * 3 tip coords) — confirm.
score_function = LargeScoreFunction(2)
# map_location="cpu" lets a checkpoint saved on GPU load on CPU-only machines;
# the model stays on CPU, matching the CPU tensors used for scoring below.
# torch.load unpickles the file, so only load trusted checkpoints.
state_dict = torch.load(
    "neurals/pretrained_score_function/only_score_model_larger_model_2980.pth",
    map_location="cpu",
)
score_function.load_state_dict(state_dict)
score_function.eval()

LargeScoreFunction(
  (pcn): PCN(
    (conv1): Conv1d(3, 32, kernel_size=(1,), stride=(1,))
    (conv2): Conv1d(32, 64, kernel_size=(1,), stride=(1,))
    (conv3): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
    (conv4): Conv1d(128, 128, kernel_size=(1,), stride=(1,))
  )
  (fc1): Linear(in_features=134, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=256, bias=True)
  (out): Linear(in_features=256, out_features=1, bias=True)
)

In [4]:
# Load the scene point cloud. read_point_cloud may return an empty cloud rather
# than raising on a bad path, so check explicitly.
pcd = o3d.io.read_point_cloud("data/seeds/pointclouds/pose_5_pcd.ply")
assert pcd.has_points(), "point cloud is empty - check the .ply path"

# Query grasp: one (x, y, z) row per tip; the tips are snapped onto the cloud.
grasp = np.array([[0.15, -0.05, 0.05], [0.2, 0.0, -0.03]])
projected_grasp, points_np, vis = project_vis_grasp(grasp, pcd)

# Flatten the projected tips into one conditioning row of shape (1, n_tips * 3),
# and add a batch dimension to the sampled points: (1, n_points, 3).
condition = torch.from_numpy(projected_grasp).view(1, -1).float()
points = torch.from_numpy(points_np).view(1, -1, 3).float()

# Inference only: skip building an autograd graph.
with torch.no_grad():
    score = score_function.pred_score(points, condition)
score

tensor([[-4.9595]], grad_fn=<AddmmBackward0>)


In [5]:
# Open a viewer showing the projected tip spheres together with the point cloud.
o3d.visualization.draw_geometries(geometry_list=vis)