# Drake & OMPL integration example

This example shows how to use Drake for collision checking and visualization, and OMPL for motion planning, for a setup of two UR5e arms fitted with Robotiq 2F-85 grippers.

In [None]:
import os
from pathlib import Path
import numpy as np
from airo_planner.utils import files
from pydrake.geometry import Meshcat
from pydrake.geometry import Meshcat
from pydrake.math import RigidTransform, RollPitchYaw
from pydrake.geometry import MeshcatVisualizer
from pydrake.planning import RobotDiagramBuilder, SceneGraphCollisionChecker
from pydrake.multibody.plant import DiscreteContactSolver

# from pydrake.visualization import ApplyVisualizationConfig, VisualizationConfig

In [None]:
# Assemble a RobotDiagramBuilder, which gives access to a plant, a scene graph, a diagram
# builder and a parser. Passing time_step=0.001 to the constructor did not silence the
# mimic-joint warning, so the constructor is called without arguments here.
robot_diagram_builder = RobotDiagramBuilder()
scene_graph = robot_diagram_builder.scene_graph()
plant = robot_diagram_builder.plant()
builder = robot_diagram_builder.builder()
parser = robot_diagram_builder.parser()

# Add visualizer: the scene graph's geometry is shown in a Meshcat browser view.
meshcat = Meshcat()
visualizer = MeshcatVisualizer.AddToBuilder(builder, scene_graph, meshcat)
# config = VisualizationConfig(publish_contacts=True, enable_alpha_sliders=True)
# ApplyVisualizationConfig(config, builder=builder, plant=plant, meshcat=meshcat)

# This gets rid of the warning for the mimic joints in the Robotiq gripper.
plant.set_discrete_contact_solver(DiscreteContactSolver.kSap)

# Load URDF files.
# The arm and gripper URDFs come from airo_planner's bundled resources. The cube and table
# URDFs are plain relative paths, so they are resolved against the current working directory.
# NOTE(review): cube_urdf is never loaded in this notebook.
resources_root = str(files.get_resources_dir())
ur5e_urdf = Path(resources_root) / "robots" / "ur5e" / "ur5e.urdf"
robotiq_2f_85_gripper_urdf = Path(resources_root) / "grippers" / "2f_85_gripper" / "urdf" / "robotiq_2f_85_static.urdf"
cube_urdf = "cube_and_cylinder.urdf"
table_urdf = "table.urdf"

# The arm and gripper URDFs are each loaded twice, so every copy gets its own model_name.
arm_left_index = parser.AddModelFromFile(str(ur5e_urdf), model_name="arm_left")
arm_right_index = parser.AddModelFromFile(str(ur5e_urdf), model_name="arm_right")
gripper_left_index = parser.AddModelFromFile(str(robotiq_2f_85_gripper_urdf), model_name="gripper_left")
gripper_right_index = parser.AddModelFromFile(str(robotiq_2f_85_gripper_urdf), model_name="gripper_right")
table_index = parser.AddModelFromFile(str(table_urdf))

# Weld some frames together: first look up every frame involved in the welds below.
world_frame = plant.world_frame()
arm_left_frame = plant.GetFrameByName("base_link", arm_left_index)
arm_right_frame = plant.GetFrameByName("base_link", arm_right_index)
arm_left_wrist_frame = plant.GetFrameByName("wrist_3_link", arm_left_index)
arm_right_wrist_frame = plant.GetFrameByName("wrist_3_link", arm_right_index)
gripper_left_frame = plant.GetFrameByName("base_link", gripper_left_index)
gripper_right_frame = plant.GetFrameByName("base_link", gripper_right_index)
table_frame = plant.GetFrameByName("base_link", table_index)

# Offset of the right arm base from the left arm base along the world x-axis (meters).
distance_between_arms = 0.9
distance_between_arms_half = distance_between_arms / 2

# Left arm base at the world origin; right arm base distance_between_arms further along x.
plant.WeldFrames(world_frame, arm_left_frame)
plant.WeldFrames(world_frame, arm_right_frame, RigidTransform([distance_between_arms, 0, 0]))
# Mount each gripper on its arm's wrist_3_link, rotated by pi/2 about z (yaw only).
plant.WeldFrames(
    arm_left_wrist_frame, gripper_left_frame, RigidTransform(p=[0, 0, 0], rpy=RollPitchYaw([0, 0, np.pi / 2]))
)
plant.WeldFrames(
    arm_right_wrist_frame, gripper_right_frame, RigidTransform(p=[0, 0, 0], rpy=RollPitchYaw([0, 0, np.pi / 2]))
)
# The table sits halfway between the two arm bases.
plant.WeldFrames(world_frame, table_frame, RigidTransform([distance_between_arms_half, 0, 0]))


# plant.Finalize() is not called explicitly here.
# NOTE(review): presumably robot_diagram_builder.Build() finalizes the plant — confirm.

# Set up collision checking: build the diagram that the collision checker below wraps.
diagram = robot_diagram_builder.Build()


# Not sure if this is needed
def _configuration_distance(q1, q2):
    return np.linalg.norm(q1 - q2)


# Collision checker over the full plant configuration vector. Only the two arm model
# instances are declared as the robot.
# NOTE(review): confirm how the grippers (separate model instances welded to the arms) are
# classified by the checker.
collision_checker = SceneGraphCollisionChecker(
    model=diagram,
    robot_model_instances=[arm_left_index, arm_right_index],
    configuration_distance_function=_configuration_distance,
    edge_step_size=0.125,  # step size for edge checks, measured with the distance function above
)

# Create default contexts ~= state
context = diagram.CreateDefaultContext()
# The plant's own sub-context, used below to read and write joint positions.
plant_context = plant.GetMyContextFromRoot(context)
# Publish the current (default) configuration, e.g. to the Meshcat visualizer.
diagram.ForcedPublish(context)

In [None]:
# Start configuration for both arms (degrees converted to radians).
start_joints_left = np.deg2rad([0, -90, -90, -90, 90, 0])
start_joints_right = np.deg2rad([-136, -116, -110, -133, 40, 0])

plant.SetPositions(plant_context, arm_left_index, start_joints_left)
plant.SetPositions(plant_context, arm_right_index, start_joints_right)

# Visualize the start configuration.
diagram.ForcedPublish(context)

# Read back the full plant position vector; the cell shows whether it is collision-free.
q_all = plant.GetPositions(plant_context)
collision_checker.CheckConfigCollisionFree(q_all)

In [None]:
# Home configuration for both arms (degrees converted to radians); used later as the planning goal.
home_joints_left = np.deg2rad([180, -135, 95, -50, -90, -90])
home_joints_right = np.deg2rad([-180, -45, -95, -130, 90, 90])

plant.SetPositions(plant_context, arm_left_index, home_joints_left)
plant.SetPositions(plant_context, arm_right_index, home_joints_right)
# Visualize the home configuration.
diagram.ForcedPublish(context)

# Read back the full plant position vector; the cell shows whether it is collision-free.
q_all = plant.GetPositions(plant_context)
collision_checker.CheckConfigCollisionFree(q_all)

## Configuring OMPL

### Moving the right arm home 🏠

The left arm is kept fixed at its start configuration; only the right arm's 6 joints are planned.

In [None]:
# !pip install https://github.com/ompl/ompl/releases/download/prerelease/ompl-1.6.0-cp310-cp310-manylinux_2_28_x86_64.whl

In [None]:
from ompl import base as ob
from ompl import geometric as og

# 6-D joint space for one UR5e arm; every joint is bounded to [-2*pi, 2*pi] rad.
# NOTE(review): these bounds are not read from the URDF joint limits; confirm they match.
joint_bound = 2 * np.pi
bounds = ob.RealVectorBounds(6)
bounds.setLow(-joint_bound)
bounds.setHigh(joint_bound)

space = ob.RealVectorStateSpace(6)
space.setBounds(bounds)

print(space.settings())

In [None]:
def state_to_numpy(state_ompl: "ob.State", num_dims: int = 6) -> np.ndarray:
    """Copy the first ``num_dims`` coordinates of an OMPL state into a NumPy array.

    Args:
        state_ompl: An OMPL real-vector state; any object indexable by int works.
        num_dims: Number of coordinates to read. Defaults to 6 (one UR5e arm).

    Returns:
        A float64 array of shape ``(num_dims,)``.
    """
    return np.array([state_ompl[i] for i in range(num_dims)], dtype=np.float64)


def is_state_valid(state):
    """Return whether the right arm at ``state`` is collision-free.

    The left arm is held at ``start_joints_left`` (kept fixed for now). The full plant
    configuration is laid out as left arm (first 6 values) then right arm (last 6).
    """
    right_arm_q = state_to_numpy(state)
    full_q = np.concatenate((start_joints_left, right_arm_q))
    return collision_checker.CheckConfigCollisionFree(full_q)


# Build the OMPL start and goal states for the right arm, joint by joint.
start_state = ob.State(space)
goal_state = ob.State(space)

for joint_idx, (q_start, q_goal) in enumerate(zip(start_joints_right, home_joints_right)):
    start_state()[joint_idx] = q_start
    goal_state()[joint_idx] = q_goal


# Print whether the start and goal configurations are collision-free.
print(is_state_valid(start_state))
print(is_state_valid(goal_state))

In [None]:
# Bundle the 6-D space, the validity checker and the start/goal pair in OMPL's SimpleSetup.
# No planner is set explicitly, so SimpleSetup chooses a default one when solving.
simple_setup = og.SimpleSetup(space)
simple_setup.setStateValidityChecker(ob.StateValidityCheckerFn(is_state_valid))
simple_setup.setStartAndGoalStates(start_state, goal_state)

In [None]:
simple_setup.solve(5.0)  # plan for at most 5.0 seconds; the cell shows the returned status

In [None]:
# Number of waypoints to resample the simplified right-arm path into.
n_states = 100

if simple_setup.haveSolutionPath():
    # Simplify the raw planner output, then interpolate it to exactly n_states waypoints.
    simple_setup.simplifySolution()
    solution_path = simple_setup.getSolutionPath()
    solution_path.interpolate(n_states)
    print(solution_path.printAsMatrix())
else:
    # Report the failure explicitly instead of producing no output at all.
    print("No solution path found for the right arm; try a longer planning time.")

In [None]:
import time

# Total wall-clock duration of the playback, in seconds.
total_time = 8.0

# Replay the right-arm path in Meshcat with the left arm held at its start configuration.
# Guarded so the cell does not raise NameError (or replay something stale) when planning failed.
if simple_setup.haveSolutionPath():
    pause = total_time / solution_path.getStateCount()
    # The left arm never moves during playback, so set it once outside the loop.
    plant.SetPositions(plant_context, arm_left_index, start_joints_left)
    for state in solution_path.getStates():
        plant.SetPositions(plant_context, arm_right_index, state_to_numpy(state))
        diagram.ForcedPublish(context)
        time.sleep(pause)
else:
    print("No right-arm solution path to replay.")

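### Dual arm planning 🤼

Both arms are planned jointly in a single 12-dimensional joint space (left arm first, then right arm).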

In [None]:
# Reset both arms to the start configuration and visualize it before dual-arm planning.
plant.SetPositions(plant_context, arm_left_index, start_joints_left)
plant.SetPositions(plant_context, arm_right_index, start_joints_right)
diagram.ForcedPublish(context)

In [None]:
from ompl import base as ob
from ompl import geometric as og

# 12-D joint space for both arms (left arm first, then right arm); every joint is
# bounded to [-2*pi, 2*pi] rad.
# NOTE(review): these bounds are not read from the URDF joint limits; confirm they match.
joint_bound_dual = 2 * np.pi
bounds_dual = ob.RealVectorBounds(12)
bounds_dual.setLow(-joint_bound_dual)
bounds_dual.setHigh(joint_bound_dual)

space_dual = ob.RealVectorStateSpace(12)
space_dual.setBounds(bounds_dual)

print(space_dual.settings())

In [None]:
def state_to_numpy_dual(state_ompl: ob.State):
    """Copy all 12 coordinates of a dual-arm OMPL state into a float64 NumPy array."""
    return np.array([state_ompl[k] for k in range(12)], dtype=np.float64)


def is_state_valid_dual(state):
    """Return whether the 12-D dual-arm configuration ``state`` is collision-free."""
    return collision_checker.CheckConfigCollisionFree(state_to_numpy_dual(state))


# Build the 12-D start and goal states: indices 0-5 are the left arm, 6-11 the right arm.
start_q_dual = np.concatenate((start_joints_left, start_joints_right))
goal_q_dual = np.concatenate((home_joints_left, home_joints_right))

start_state_dual = ob.State(space_dual)
goal_state_dual = ob.State(space_dual)

for idx in range(12):
    start_state_dual()[idx] = start_q_dual[idx]
    goal_state_dual()[idx] = goal_q_dual[idx]


# Print whether the start and goal configurations are collision-free.
print(is_state_valid_dual(start_state_dual))
print(is_state_valid_dual(goal_state_dual))

In [None]:
%%timeit
# Benchmark a single call of the dual-arm validity checker (one 12-D collision check).
is_state_valid_dual(start_state_dual)

In [None]:
# Bundle the 12-D space, the validity checker and the start/goal pair in OMPL's SimpleSetup.
# No planner is set explicitly, so SimpleSetup chooses a default one when solving.
simple_setup_dual = og.SimpleSetup(space_dual)
simple_setup_dual.setStateValidityChecker(ob.StateValidityCheckerFn(is_state_valid_dual))
simple_setup_dual.setStartAndGoalStates(start_state_dual, goal_state_dual)

In [None]:
simple_setup_dual.solve(30.0)  # 12-D search space, so allow up to 30.0 seconds of planning

In [None]:
# Number of waypoints to resample the simplified dual-arm path into.
n_states = 100

if simple_setup_dual.haveSolutionPath():
    # Simplify the raw planner output, then interpolate it to exactly n_states waypoints.
    simple_setup_dual.simplifySolution()
    solution_path = simple_setup_dual.getSolutionPath()
    solution_path.interpolate(n_states)
    print(solution_path.printAsMatrix())
else:
    # Report the failure explicitly instead of producing no output at all.
    print("No dual-arm solution path found; try a longer planning time.")

In [None]:
import time

# Total wall-clock duration of the playback, in seconds.
total_time = 8.0

# Replay the dual-arm path in Meshcat. Without this guard, a failed dual-arm plan leaves
# `solution_path` pointing at the earlier 6-D right-arm path (or undefined), and reading
# 12 coordinates from those 6-D states would run past the end of each state.
if simple_setup_dual.haveSolutionPath():
    pause = total_time / solution_path.getStateCount()
    for state in solution_path.getStates():
        q_all = state_to_numpy_dual(state)
        # The left arm occupies the first 6 coordinates, the right arm the last 6.
        plant.SetPositions(plant_context, arm_left_index, q_all[:6])
        plant.SetPositions(plant_context, arm_right_index, q_all[6:])
        diagram.ForcedPublish(context)
        time.sleep(pause)
else:
    print("No dual-arm solution path to replay.")