In [None]:
import time

import numpy as np
from IPython.display import clear_output
from pydrake.all import (
    PointCloud,
    Rgba,
    RigidTransform,
    RotationMatrix,
    Sphere,
    Concatenate,
    StartMeshcat,
    AbstractValue,
    AddMultibodyPlantSceneGraph,
    DiagramBuilder,
    JointSliders,
    LeafSystem,
    MeshcatPoseSliders,
    MeshcatVisualizer,
    MeshcatVisualizerParams,
    Parser,
    PointCloud,
    RandomGenerator,
    Rgba,
    RigidTransform,
    RotationMatrix,
    Simulator, 
    UniformlyRandomRotationMatrix,
)

from scipy.spatial import KDTree

from manipulation import running_as_notebook
from manipulation.meshcat_utils import AddMeshcatTriad
from manipulation.mustard_depth_camera_example import MustardExampleSystem
from manipulation.scenarios import AddFloatingRpyJoint, AddRgbdSensors, ycb
from manipulation.utils import ConfigureParser

In [None]:
# Start the visualizer.  The resulting `meshcat` handle is a module-level
# global: the functions defined below use it directly for drawing objects and
# for adding/removing UI controls.
meshcat = StartMeshcat()

In [None]:
def GraspCandidateCost(
    diagram,
    context,
    cloud,
    wsg_body_index=None,
    plant_system_name="plant",
    scene_graph_system_name="scene_graph",
    adjust_X_G=False,
    verbose=False,
    meshcat_path=None,
):
    """Score the gripper pose currently stored in `context` as a grasp.

    Args:
        diagram: A diagram containing a MultibodyPlant+SceneGraph that contains
            a free body gripper and any obstacles in the environment that we
            want to check collisions against. It should not include the objects
            in the point cloud; those are handled separately.
        context: The diagram context.  All positions in the context will be
            held fixed *except* the gripper free body pose.
        cloud: a PointCloud in world coordinates which represents candidate
            grasps.  Its normals are read, so it must carry normals.
        wsg_body_index: The body index of the gripper in plant.  If None, then
            a body named "body" will be searched for in the plant.
        plant_system_name: Name of the MultibodyPlant subsystem in `diagram`.
        scene_graph_system_name: Name of the SceneGraph subsystem in
            `diagram`.
        adjust_X_G: If True (and at least one cloud point lies between the
            fingers), the gripper is shifted along its own x axis so that
            those points are centered, and the new pose is written to the
            plant context.
        verbose: If True, print the cost (and the reason for an infinite
            cost).
        meshcat_path: If truthy, the cloud points that fall inside the finger
            crop box are drawn in the global `meshcat`.
            NOTE(review): the points are always published at
            "planning/points"; the value passed here only acts as an on/off
            switch.  Confirm whether it was meant to be used as the path.

    Returns:
        cost: The grasp cost.  `np.inf` if the gripper collides with the
        environment or with any point of the cloud; otherwise
        `20 * R_G[2, 1]` minus the sum of squared gripper-frame x components
        of the normals of the points inside the finger crop box (lower is
        better).

    If adjust_X_G is True, then it also updates the gripper pose in the plant
    context.
    """
    plant = diagram.GetSubsystemByName(plant_system_name)
    plant_context = plant.GetMyMutableContextFromRoot(context)
    scene_graph = diagram.GetSubsystemByName(scene_graph_system_name)
    scene_graph_context = scene_graph.GetMyMutableContextFromRoot(context)
    # Compare against None explicitly: a truthiness test would silently ignore
    # a caller-supplied index that happens to evaluate as falsy (e.g. 0).
    if wsg_body_index is not None:
        wsg = plant.get_body(wsg_body_index)
    else:
        wsg = plant.GetBodyByName("body")

    X_G = plant.GetFreeBodyPose(plant_context, wsg)

    # Transform cloud into gripper frame
    X_GW = X_G.inverse()
    p_GC = X_GW @ cloud.xyzs()

    # Crop to a region inside of the finger box.
    crop_min = [-0.05, 0.1, -0.00625]
    crop_max = [0.05, 0.1125, 0.00625]
    indices = np.all(
        (
            crop_min[0] <= p_GC[0, :],
            p_GC[0, :] <= crop_max[0],
            crop_min[1] <= p_GC[1, :],
            p_GC[1, :] <= crop_max[1],
            crop_min[2] <= p_GC[2, :],
            p_GC[2, :] <= crop_max[2],
        ),
        axis=0,
    )
    num_in_crop = int(np.sum(indices))

    if meshcat_path:
        # Highlight the points that ended up between the fingers.
        pc = PointCloud(num_in_crop)
        pc.mutable_xyzs()[:] = cloud.xyzs()[:, indices]
        meshcat.SetObject(
            "planning/points", pc, rgba=Rgba(1.0, 0, 0), point_size=0.01
        )

    if adjust_X_G and num_in_crop > 0:
        # Slide the gripper along its x axis so that the cropped points sit
        # midway between their extreme x coordinates in the gripper frame.
        p_GC_x = p_GC[0, indices]
        p_Gcenter_x = (p_GC_x.min() + p_GC_x.max()) / 2.0
        X_G.set_translation(X_G @ np.array([p_Gcenter_x, 0, 0]))
        plant.SetFreeBodyPose(plant_context, wsg, X_G)
        X_GW = X_G.inverse()

    # Evaluated after the (optional) pose update so that the collision
    # queries below see the adjusted gripper pose.
    query_object = scene_graph.get_query_output_port().Eval(
        scene_graph_context
    )

    # Check collisions between the gripper and the sink
    if query_object.HasCollisions():
        cost = np.inf
        if verbose:
            print("Gripper is colliding with the sink!\n")
            print(f"cost: {cost}")
        return cost

    # Check collisions between the gripper and the point cloud.  The margin
    # must be smaller than the margin used in the point cloud preprocessing.
    margin = 0.0
    for i in range(cloud.size()):
        distances = query_object.ComputeSignedDistanceToPoint(
            cloud.xyz(i), threshold=margin
        )
        # A non-empty result means some geometry is within `margin` of the
        # point, which is treated as a collision.
        if distances:
            cost = np.inf
            if verbose:
                print("Gripper is colliding with the point cloud!\n")
                print(f"cost: {cost}")
            return cost

    # Normals of the cropped points, expressed in the gripper frame.
    n_GC = X_GW.rotation().multiply(cloud.normals()[:, indices])

    # Penalize deviation of the gripper from vertical.
    # weight * -dot([0, 0, -1], R_G * [0, 1, 0]) = weight * R_G[2,1]
    cost = 20.0 * X_G.rotation().matrix()[2, 1]

    # Reward sum |dot product of normals with gripper x|^2
    cost -= np.sum(n_GC[0, :] ** 2)
    if verbose:
        print(f"cost: {cost}")
        print(f"normal terms: {n_GC[0,:]**2}")
    return cost


class ScoreSystem(LeafSystem):
    """Prints the grasp cost of the incoming gripper pose on forced publish.

    The system keeps its own context of `diagram` (the model used for
    scoring).  On every forced publish it copies the gripper pose found on the
    "body_poses" input into that context and calls `GraspCandidateCost` in
    verbose mode.

    Args:
        diagram: Diagram with a "plant" subsystem that has a body named
            "body" (the gripper).
        cloud: Point cloud handed to `GraspCandidateCost`.
        wsg_pose_index: Position of the gripper pose inside the list received
            on the "body_poses" input port.
    """

    def __init__(self, diagram, cloud, wsg_pose_index):
        LeafSystem.__init__(self)
        self._diagram = diagram
        self._cloud = cloud
        self._wsg_pose_index = wsg_pose_index

        # Private context of the scoring model; only the gripper pose in it is
        # ever modified (in Publish).
        self._context = diagram.CreateDefaultContext()
        self._plant = diagram.GetSubsystemByName("plant")
        self._plant_context = self._plant.GetMyMutableContextFromRoot(
            self._context
        )
        self._wsg_body_index = self._plant.GetBodyByName("body").index()

        pose_list_model = AbstractValue.Make([RigidTransform()])
        self.DeclareAbstractInputPort("body_poses", pose_list_model)
        self.DeclareForcedPublishEvent(self.Publish)

    def Publish(self, context):
        """Mirror the input gripper pose into the scoring model and score it."""
        body_poses = self.get_input_port(0).Eval(context)
        gripper_body = self._plant.get_body(self._wsg_body_index)
        self._plant.SetFreeBodyPose(
            self._plant_context,
            gripper_body,
            body_poses[self._wsg_pose_index],
        )
        GraspCandidateCost(
            diagram=self._diagram,
            context=self._context,
            cloud=self._cloud,
            verbose=True,
            meshcat_path="planning/cost",
        )
        # Replace the previous printout instead of appending to it.
        clear_output(wait=True)


def process_point_cloud(diagram, context, cameras, bin_name):
    """Build one cropped, down-sampled cloud (with normals) from the cameras.

    For every camera the cloud is read from the diagram output port
    "<camera>_point_cloud", cropped to the interior of the bin, given
    estimated normals, and those normals are flipped toward that camera.  The
    per-camera clouds are then concatenated and voxel down-sampled.

    Args:
        diagram: Diagram with a "plant" subsystem and one
            "<camera>_point_cloud" output port per entry of `cameras`.
        context: Root context of `diagram`.
        cameras: Iterable of camera names.  Each name selects both the output
            port and the plant model instance (with a body named "base")
            whose position the normals are flipped toward.
            NOTE(review): this assumes the port prefix and the model instance
            name coincide, as they do for "camera0".."camera2" -- confirm for
            other setups.
        bin_name: Name of the bin model instance (with a body named
            "bin_base") that defines the crop box.

    Returns:
        The merged point cloud, voxel down-sampled with a 5 mm voxel size.
    """
    plant = diagram.GetSubsystemByName("plant")
    plant_context = plant.GetMyContextFromRoot(context)

    # Compute crop box.
    bin_instance = plant.GetModelInstanceByName(bin_name)
    bin_body = plant.GetBodyByName("bin_base", bin_instance)
    X_B = plant.EvalBodyPoseInWorld(plant_context, bin_body)
    margin = 0.001  # only because simulation is perfect!
    a = X_B.multiply(
        [-0.22 + 0.025 + margin, -0.29 + 0.025 + margin, 0.015 + margin]
    )
    b = X_B.multiply([0.22 - 0.1 - margin, 0.29 - 0.025 - margin, 2.0])
    # The bin may be rotated, so order the transformed corners per axis.
    crop_min = np.minimum(a, b)
    crop_max = np.maximum(a, b)

    pcd = []
    # Iterate over the camera names actually requested (previously the loop
    # was hard-coded to three cameras named camera0..camera2 for the normal
    # flipping, regardless of `cameras`).
    for camera_name in cameras:
        cloud = diagram.GetOutputPort(f"{camera_name}_point_cloud").Eval(
            context
        )

        # Crop to region of interest.
        cropped = cloud.Crop(lower_xyz=crop_min, upper_xyz=crop_max)
        # Estimate normals
        cropped.EstimateNormals(radius=0.1, num_closest=30)

        # Flip normals toward camera
        camera = plant.GetModelInstanceByName(camera_name)
        body = plant.GetBodyByName("base", camera)
        X_C = plant.EvalBodyPoseInWorld(plant_context, body)
        cropped.FlipNormalsTowardPoint(X_C.translation())

        pcd.append(cropped)

    # Merge point clouds.
    merged_pcd = Concatenate(pcd)

    # Voxelize down-sample.  (Note that the normals still look reasonable)
    return merged_pcd.VoxelizedDownSample(voxel_size=0.005)


def make_environment_model(
    directive=None, draw=False, rng=None, num_ycb_objects=0, bin_name="bin0"
):
    """Build (and optionally settle) the simulated environment.

    Make one model of the environment, but the robot only gets to see the
    sensor outputs.

    Args:
        directive: URL of the model directive to load.  Defaults to
            "package://manipulation/two_bins_w_cameras.dmd.yaml".
        draw: If True, a MeshcatVisualizer (prefix "environment") is added to
            the diagram using the global `meshcat`.
        rng: A numpy random generator (anything offering `integers` and
            `uniform`), used to pick and place the YCB objects.  If None and
            objects are requested, a fresh `np.random.default_rng()` is used.
        num_ycb_objects: Number of randomly chosen YCB objects to add.
        bin_name: Model instance name of the bin the objects are dropped
            above.

    Returns:
        (diagram, context).  When objects were added, the context has been
        simulated forward (2.0 s in a notebook, 0.1 s otherwise) so that the
        objects can fall into the bin.
    """
    if not directive:
        directive = "package://manipulation/two_bins_w_cameras.dmd.yaml"

    # `rng` is only needed when objects are requested; fall back to a default
    # generator instead of failing on `None.integers(...)`.
    if num_ycb_objects > 0 and rng is None:
        rng = np.random.default_rng()

    builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0005)
    parser = Parser(plant)
    ConfigureParser(parser)
    # The same YCB model may be drawn more than once.
    parser.SetAutoRenaming(True)
    parser.AddModelsFromUrl(directive)

    for _ in range(num_ycb_objects):
        object_num = rng.integers(len(ycb))
        parser.AddModelsFromUrl(
            f"package://manipulation/hydro/{ycb[object_num]}"
        )

    plant.Finalize()
    AddRgbdSensors(builder, plant, scene_graph)

    if draw:
        MeshcatVisualizer.AddToBuilder(
            builder,
            scene_graph,
            meshcat,
            MeshcatVisualizerParams(prefix="environment"),
        )

    diagram = builder.Build()
    context = diagram.CreateDefaultContext()

    if num_ycb_objects > 0:
        generator = RandomGenerator(rng.integers(1000))  # this is for c++
        plant_context = plant.GetMyContextFromRoot(context)
        bin_instance = plant.GetModelInstanceByName(bin_name)
        bin_body = plant.GetBodyByName("bin_base", bin_instance)
        X_B = plant.EvalBodyPoseInWorld(plant_context, bin_body)
        # Stack the objects at increasing heights above the bin, with random
        # orientation and random horizontal offset, then let them fall.
        z = 0.2
        for body_index in plant.GetFloatingBaseBodies():
            tf = RigidTransform(
                UniformlyRandomRotationMatrix(generator),
                [rng.uniform(-0.15, 0.15), rng.uniform(-0.2, 0.2), z],
            )
            plant.SetFreeBodyPose(
                plant_context, plant.get_body(body_index), X_B.multiply(tf)
            )
            z += 0.1

        simulator = Simulator(diagram, context)
        simulator.AdvanceTo(2.0 if running_as_notebook else 0.1)
    elif draw:
        # Nothing to simulate; just push the static scene to the visualizer.
        diagram.ForcedPublish(context)

    return diagram, context


def make_internal_model():
    """Build the diagram of the objects the robot "knows about".

    This is another diagram holding the gripper, cameras and bins -- think of
    it as the model in the robot's head.  It is loaded from
    "package://manipulation/clutter_planning.dmd.yaml".

    Returns:
        The built diagram (plant + scene graph).
    """
    model_builder = DiagramBuilder()
    internal_plant, _ = AddMultibodyPlantSceneGraph(
        model_builder, time_step=0.001
    )

    model_parser = Parser(internal_plant)
    ConfigureParser(model_parser)
    model_parser.AddModelsFromUrl(
        "package://manipulation/clutter_planning.dmd.yaml"
    )

    internal_plant.Finalize()
    return model_builder.Build()


def grasp_score_inspector():
    """Interactive demo: move a gripper with sliders and print its grasp cost.

    Builds the mustard-bottle environment, extracts its point cloud, and then
    runs a JointSliders loop on a planning model whose gripper pose is fed to
    a `ScoreSystem`.  Uses the global `meshcat` and removes the controls it
    added before returning.
    """
    meshcat.Delete()

    # The "real" world, from which only the camera point clouds are used.
    environment, environment_context = make_environment_model(
        directive="package://manipulation/clutter_mustard.dmd.yaml", draw=True
    )

    # The model used for scoring.
    internal_model = make_internal_model()

    # Finally, we'll build a diagram for running our visualization
    viz_builder = DiagramBuilder()
    plant, scene_graph = AddMultibodyPlantSceneGraph(
        viz_builder, time_step=0.001
    )
    viz_parser = Parser(plant)
    ConfigureParser(viz_parser)
    viz_parser.AddModelsFromUrl(
        "package://manipulation/clutter_planning.dmd.yaml"
    )
    # Give the gripper x/y/z + roll/pitch/yaw joints for the sliders to drive.
    gripper_frame = plant.GetFrameByName("body")
    gripper_instance = plant.GetModelInstanceByName("gripper")
    AddFloatingRpyJoint(plant, gripper_frame, gripper_instance)
    plant.Finalize()

    meshcat.DeleteAddedControls()
    MeshcatVisualizer.AddToBuilder(
        viz_builder,
        scene_graph,
        meshcat,
        MeshcatVisualizerParams(prefix="planning"),
    )

    camera_names = ["camera0", "camera1", "camera2"]
    cloud = process_point_cloud(
        environment, environment_context, camera_names, "bin0"
    )
    meshcat.SetObject("planning/cloud", cloud, point_size=0.003)

    gripper_pose_index = plant.GetBodyByName("body").index()
    score = viz_builder.AddSystem(
        ScoreSystem(internal_model, cloud, gripper_pose_index)
    )
    viz_builder.Connect(
        plant.get_body_poses_output_port(), score.get_input_port()
    )

    # Slider ranges and start values, one entry per joint of the gripper.
    q0 = [-0.05, -0.5, 0.25, -np.pi / 2.0, 0, 0]
    lower_limit = [-1, -1, 0, -np.pi, -np.pi / 4.0, -np.pi / 4.0]
    upper_limit = [1, 1, 1, 0, np.pi / 4.0, np.pi / 4.0]
    decrement_keys = ["KeyQ", "KeyS", "KeyA", "KeyJ", "KeyK", "KeyU"]
    increment_keys = ["KeyE", "KeyW", "KeyD", "KeyL", "KeyI", "KeyO"]

    sliders = viz_builder.AddSystem(
        JointSliders(
            meshcat,
            plant,
            initial_value=q0,
            lower_limit=lower_limit,
            upper_limit=upper_limit,
            decrement_keycodes=decrement_keys,
            increment_keycodes=increment_keys,
        )
    )

    # No timeout when a human is driving; bail out quickly otherwise.
    if running_as_notebook:
        default_interactive_timeout = None
    else:
        default_interactive_timeout = 1.0

    diagram = viz_builder.Build()
    sliders.Run(diagram, default_interactive_timeout)
    meshcat.DeleteAddedControls()


# Run the interactive grasp-score inspector defined above.
grasp_score_inspector()

# Point cloud processing

I've produced a scene with multiple cameras looking at our favorite YCB mustard bottle.  I've taken the individual point clouds, cropped them (to get rid of the geometry from the other cameras), estimated their normals, merged the point clouds, and then down-sampled the merged point cloud.  (The order is important!)

I've pushed all of the point clouds to meshcat, but with many of them set to not be visible by default.  Use the drop-down menu to turn them on and off, and make sure you understand basically what is happening on each of the steps.

In [33]:
def point_cloud_processing_example():
    """Interactively estimate a local surface frame near a query point.

    The three camera clouds of the mustard-bottle scene are cropped, given
    normals, merged and down-sampled (each stage is pushed to the global
    `meshcat`).  Three sliders then move a query point; whenever it changes,
    its 5 nearest neighbors in the down-sampled cloud are found and a
    least-squares frame (z axis = estimated normal, oriented toward the query
    point) is drawn at the query point.

    The loop ends when the "Stop Normal Estimation" button is clicked, or
    after one pass when not running as a notebook.  The controls added here
    are removed before returning.
    """
    # This just sets up our mustard bottle with three depth cameras positioned
    # around it.
    system = MustardExampleSystem()

    plant = system.GetSubsystemByName("plant")

    # Evaluate the camera output ports to get the images.
    context = system.CreateDefaultContext()
    plant_context = plant.GetMyContextFromRoot(context)

    meshcat.Delete()
    meshcat.DeleteAddedControls()
    meshcat.SetProperty("/Background", "visible", True)

    meshcat.AddButton("Stop Normal Estimation", "Escape")
    print(
        "Press ESC or the 'Stop Normal Estimation' button in Meshcat to continue"
    )

    # One slider per coordinate of the query point.
    # NOTE(review): all three sliders are bound to the same arrow keys, so a
    # key press presumably moves all of them at once -- confirm whether
    # distinct keycodes were intended.
    slider_initial_values = {"query_x": 0, "query_y": 0.2, "query_z": 0.1}
    for slider_name, initial_value in slider_initial_values.items():
        meshcat.AddSlider(
            slider_name,
            min=-0.5,
            max=0.5,
            step=0.001,
            value=initial_value,
            decrement_keycode="ArrowLeft",
            increment_keycode="ArrowRight",
        )

    # The point cloud pipeline does not depend on the query point (the context
    # never changes), so it is run once here rather than on every slider
    # update.  This also keeps visibility toggles made in the Meshcat menu
    # from being reset each time a slider moves.
    pcd = []
    for i in range(3):
        cloud = system.GetOutputPort(f"camera{i}_point_cloud").Eval(context)
        meshcat.SetObject(f"pointcloud{i}", cloud, point_size=0.001)
        meshcat.SetProperty(f"pointcloud{i}", "visible", False)

        # Crop to region of interest.  (Would want to crop to the region where
        # the robot arm is located.)
        cropped = cloud.Crop(
            lower_xyz=[-0.3, -0.3, -0.3], upper_xyz=[0.3, 0.3, 0.3]
        )
        meshcat.SetObject(f"pointcloud{i}_cropped", cropped, point_size=0.001)
        meshcat.SetProperty(f"pointcloud{i}_cropped", "visible", False)

        cropped.EstimateNormals(radius=0.1, num_closest=30)

        # Flip the normals toward the camera that observed the points.
        camera = plant.GetModelInstanceByName(f"camera{i}")
        body = plant.GetBodyByName("base", camera)
        X_C = plant.EvalBodyPoseInWorld(plant_context, body)
        cropped.FlipNormalsTowardPoint(X_C.translation())

        pcd.append(cropped)

    # Merge point clouds.  (Note: You might need something more clever here for
    # noisier point clouds; but this can often work!)
    merged_pcd = Concatenate(pcd)
    meshcat.SetObject("merged", merged_pcd, point_size=0.001)
    meshcat.SetProperty("merged", "visible", False)

    # Voxelize down-sample.  (Note that the normals still look reasonable)
    down_sampled_pcd = merged_pcd.VoxelizedDownSample(voxel_size=0.005)
    meshcat.SetObject("down_sampled", down_sampled_pcd, point_size=0.001)

    meshcat.SetLineSegments(
        "down_sampled_normals",
        down_sampled_pcd.xyzs(),
        down_sampled_pcd.xyzs() + 0.01 * down_sampled_pcd.normals(),
    )
    meshcat.SetProperty("down_sampled_normals", "visible", False)

    kdtree = KDTree(down_sampled_pcd.xyzs().T)

    # Marker for the query point and the triad showing the estimated frame;
    # only their transforms change inside the loop.
    meshcat.SetObject("query", Sphere(0.005), Rgba(0, 1, 0))
    AddMeshcatTriad(
        meshcat, "Axis_of_point_of_interes", length=0.03, radius=0.0005
    )

    num_closest = 5
    last_query = None
    while meshcat.GetButtonClicks("Stop Normal Estimation") < 1:
        p_query = np.array(
            [meshcat.GetSliderValue(name) for name in slider_initial_values]
        )

        # Nothing to do until a slider moves.
        if last_query is not None and np.array_equal(p_query, last_query):
            time.sleep(0.1)
            continue
        last_query = p_query

        # Visualize the query point
        meshcat.SetTransform("query", RigidTransform(p_query))

        # Find the nearest neighbors
        (distances, indices) = kdtree.query(p_query, k=num_closest)

        # Add the relevant closest points to p_query to the neighbors point
        # cloud
        neighbors = PointCloud(len(indices))
        neighbors.mutable_xyzs()[:] = down_sampled_pcd.xyzs()[:, indices]

        # visualize
        meshcat.SetObject(
            "neighbors", neighbors, rgba=Rgba(0, 0, 1), point_size=0.001
        )
        neighbors.EstimateNormals(radius=0.1, num_closest=30)
        neighbors.FlipNormalsTowardPoint(p_query)

        meshcat.SetLineSegments(
            "neighbors_normals",
            neighbors.xyzs(),
            neighbors.xyzs() + 0.01 * neighbors.normals(),
        )

        # Estimate rotation matrix from the scatter of the neighbors about
        # their centroid.
        neighbor_pts = neighbors.xyzs().T
        pstar = np.mean(neighbor_pts, axis=0)
        prel = neighbor_pts - pstar
        W = np.matmul(prel.T, prel)
        _, V = np.linalg.eigh(W)
        # V[:, 0] corresponds to the smallest eigenvalue, and V[:, 2] to the
        # largest.
        R = np.fliplr(V)
        # R[:, 0] corresponds to the largest eigenvalue, and R[:, 2] to the
        # smallest (the normal).

        # Handle improper rotations
        R = R @ np.diag([1, 1, np.linalg.det(R)])

        # If the normal is not pointing towards the query point...
        # (The direction from the surface to the query point is
        # p_query - pstar; using p_query alone would test against the vector
        # from the world origin instead.)
        if (p_query - pstar).dot(R[:, 2]) < 0:
            # then flip the y and z axes.
            R = R @ np.diag([1, -1, -1])

        meshcat.SetTransform(
            "Axis_of_point_of_interes",
            RigidTransform(RotationMatrix(R), p_query),
        )

        # Without a user there is nobody to press the stop button.
        if not running_as_notebook:
            break

    meshcat.DeleteAddedControls()


# Run the point cloud processing demo defined above.
point_cloud_processing_example()

Press ESC or the 'Stop Normal Estimation' button in Meshcat to continue


# Estimation of the normal of the surface

In [None]:
def normal_estimation():
    """Interactively visualize least-squares normal estimation on one cloud.

    The cropped cloud of camera0 is shown in the global `meshcat`.  A slider
    selects one point of the cloud; up to 40 neighbors within 0.1 of it are
    highlighted and a frame is drawn at the point whose z axis is the
    estimated surface normal, oriented toward the camera.

    The loop ends when the "Stop Normal Estimation" button is clicked, or
    after one pass when not running as a notebook.  The controls added here
    are removed before returning.
    """
    system = MustardExampleSystem()
    context = system.CreateDefaultContext()

    meshcat.Delete()
    meshcat.DeleteAddedControls()
    meshcat.SetProperty("/Background", "visible", False)

    point_cloud = system.GetOutputPort("camera0_point_cloud").Eval(context)
    cloud = point_cloud.Crop(
        lower_xyz=[-0.3, -0.3, -0.3], upper_xyz=[0.3, 0.3, 0.3]
    )
    meshcat.SetObject("point_cloud", cloud)

    # Extract camera position
    plant = system.GetSubsystemByName("plant")
    p_WC = (
        plant.GetFrameByName("camera0_origin")
        .CalcPoseInWorld(plant.GetMyContextFromRoot(context))
        .translation()
    )

    kdtree = KDTree(cloud.xyzs().T)

    num_closest = 40
    neighbors = PointCloud(num_closest)
    AddMeshcatTriad(meshcat, "least_squares_basis", length=0.03, radius=0.0005)

    meshcat.AddSlider(
        "point",
        min=0,
        max=cloud.size() - 1,
        step=1,
        value=429,  # 4165,
        decrement_keycode="ArrowLeft",
        increment_keycode="ArrowRight",
    )
    meshcat.AddButton("Stop Normal Estimation", "Escape")
    print(
        "Press ESC or the 'Stop Normal Estimation' button in Meshcat to continue"
    )
    last_index = -1
    while meshcat.GetButtonClicks("Stop Normal Estimation") < 1:
        index = round(meshcat.GetSliderValue("point"))
        # Nothing to do until the slider moves.
        if index == last_index:
            time.sleep(0.1)
            continue
        last_index = index

        query = cloud.xyz(index)
        meshcat.SetObject("query", Sphere(0.001), Rgba(0, 1, 0))
        meshcat.SetTransform("query", RigidTransform(query))
        (distances, indices) = kdtree.query(
            query, k=num_closest, distance_upper_bound=0.1
        )
        # When fewer than k points lie within the distance bound, scipy pads
        # the result with infinite distances and an out-of-range index; drop
        # those entries so that the indexing below cannot fail.
        indices = indices[np.isfinite(distances)]

        neighbors.resize(len(indices))
        neighbors.mutable_xyzs()[:] = cloud.xyzs()[:, indices]

        meshcat.SetObject(
            "neighbors", neighbors, rgba=Rgba(0, 0, 1), point_size=0.001
        )

        # Scatter matrix of the neighbors about their centroid.
        neighbor_pts = neighbors.xyzs().T
        pstar = np.mean(neighbor_pts, axis=0)
        prel = neighbor_pts - pstar
        W = np.matmul(prel.T, prel)
        _, V = np.linalg.eigh(W)
        # V[:, 0] corresponds to the smallest eigenvalue, and V[:, 2] to the
        # largest.
        R = np.fliplr(V)
        # R[:, 0] corresponds to the largest eigenvalue, and R[:, 2] to the
        # smallest (the normal).

        # Handle improper rotations
        R = R @ np.diag([1, 1, np.linalg.det(R)])

        # If the normal is not pointing towards the camera...
        # (p_WC - query is the vector from the point to the camera.)
        if (p_WC - query).dot(R[:, 2]) < 0:
            # then flip the y and z axes.
            R = R @ np.diag([1, -1, -1])

        meshcat.SetTransform(
            "least_squares_basis", RigidTransform(RotationMatrix(R), query)
        )

        # Without a user there is nobody to press the stop button.
        if not running_as_notebook:
            break

    meshcat.DeleteAddedControls()


# Run the normal estimation demo defined above.
normal_estimation()

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=9122a98d-4e2b-4d7d-a028-9e4f57238976' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>