# Drake & OMPL integration example

This example shows how to use Drake for collision checking and visualization, and OMPL for planning.

In [None]:
import numpy as np
from pydrake.planning import RobotDiagramBuilder, SceneGraphCollisionChecker
from cloth_tools.drake.building import add_meshcat_to_builder, finish_build
from cloth_tools.drake.scenes import add_dual_ur5e_and_table_to_builder

In [None]:
# Start from an empty Drake RobotDiagramBuilder and let the cloth_tools helpers populate it.
robot_diagram_builder = RobotDiagramBuilder()
# The Meshcat handle is kept because finish_build() needs it in the next cell.
meshcat = add_meshcat_to_builder(robot_diagram_builder)
# Presumably adds two UR5e arms with grippers plus a table (going by the helper's name).
# The returned model-instance indices are used below to set joint positions and to
# tell the collision checker which models belong to the robot.
arm_indices, gripper_indices = add_dual_ur5e_and_table_to_builder(robot_diagram_builder)

In [None]:
# Finish building; `context` is the root context that the plant context is derived from.
diagram, context = finish_build(robot_diagram_builder, meshcat)
plant = diagram.plant()
# Joint positions are read and written through the plant's own sub-context.
plant_context = plant.GetMyContextFromRoot(context)

In [None]:
# Collision checker over the full diagram; arms and grippers together form "the robot".
collision_checker = SceneGraphCollisionChecker(
    model=diagram,
    robot_model_instances=[*arm_indices, *gripper_indices],
    edge_step_size=0.125,  # Arbitrary value: this notebook never calls CheckEdgeCollisionFree
    env_collision_padding=0.005,  # presumably meters (Drake's default length unit) - TODO confirm
    self_collision_padding=0.005,
)

In [None]:
# Start configuration of each arm: six joint angles written in degrees, stored in radians.
start_joints_left = np.deg2rad([0, -90, -90, -90, 90, 0])
start_joints_right = np.deg2rad([-136, -116, -110, -133, 40, 0])

# arm_indices holds exactly two model instances, unpacked here as (left, right).
arm_left_index, arm_right_index = arm_indices
plant.SetPositions(plant_context, arm_left_index, start_joints_left)
plant.SetPositions(plant_context, arm_right_index, start_joints_right)

# Push the new configuration to the visualizer.
diagram.ForcedPublish(context)

# Full position vector of the plant; the last expression shows the check result as cell output.
q_all = plant.GetPositions(plant_context)
collision_checker.CheckConfigCollisionFree(q_all)

In [None]:
# Home configuration of each arm (degrees converted to radians); used as the planning goals below.
home_joints_left = np.deg2rad([180, -120, 60, -30, -90, -90])
home_joints_right = np.deg2rad([-180, -60, -60, -150, 90, 90])

plant.SetPositions(plant_context, arm_left_index, home_joints_left)
plant.SetPositions(plant_context, arm_right_index, home_joints_right)
# Show the home pose in the visualizer.
diagram.ForcedPublish(context)

# Check that the home pose itself is collision-free; the result is shown as cell output.
q_all = plant.GetPositions(plant_context)
collision_checker.CheckConfigCollisionFree(q_all)

## Configuring OMPL

### Moving the right arm home 🏠

In [None]:
# !pip install https://github.com/ompl/ompl/releases/download/prerelease/ompl-1.6.0-cp310-cp310-manylinux_2_28_x86_64.whl

In [None]:
from ompl import base as ob
from ompl import geometric as og

# Planning space for a single arm: one real-valued dimension per joint (6 joints).
space = ob.RealVectorStateSpace(6)
bounds = ob.RealVectorBounds(6)
# The same limits apply to every joint: [-2*pi, 2*pi] radians.
bounds.setLow(-2 * np.pi)
bounds.setHigh(2 * np.pi)
space.setBounds(bounds)

# Print the configured space for inspection.
print(space.settings())

In [None]:
def state_to_numpy(state_ompl: "ob.State", n_joints: int = 6) -> np.ndarray:
    """Copy the leading components of an OMPL state into a NumPy array.

    Args:
        state_ompl: Anything indexable by integer position. In this notebook
            that is either an ``ob.State`` wrapper or a state taken from
            ``solution_path.getStates()``.
        n_joints: Number of leading components to copy. Defaults to 6, the
            dimension of the single-arm state space.

    Returns:
        A float array of shape ``(n_joints,)`` holding the copied values.
    """
    # Index explicitly rather than iterating the state: only `state[i]` is
    # relied upon here, not `len()` or iteration support.
    return np.array([state_ompl[i] for i in range(n_joints)], dtype=float)


def is_state_valid(state):
    """Return whether the right arm at `state` is collision-free.

    The left arm is held at `start_joints_left`; `state` supplies the six
    right-arm joint values. The 12-element vector handed to the collision
    checker is laid out as [left arm, right arm].
    """
    right_arm_joints = state_to_numpy(state)
    q_all = np.concatenate([start_joints_left, right_arm_joints])
    return collision_checker.CheckConfigCollisionFree(q_all)


start_state = ob.State(space)
goal_state = ob.State(space)

# Copy the right arm's start and home joint angles into the OMPL states.
for i, (q_start, q_goal) in enumerate(zip(start_joints_right, home_joints_right)):
    start_state()[i] = q_start
    goal_state()[i] = q_goal


# Sanity check: print whether each planning endpoint is collision-free.
for endpoint in (start_state, goal_state):
    print(is_state_valid(endpoint))

In [None]:
# Single-arm planning problem: right arm from its start to its home configuration.
simple_setup = og.SimpleSetup(space)
simple_setup.setStateValidityChecker(ob.StateValidityCheckerFn(is_state_valid))
simple_setup.setStartAndGoalStates(start_state, goal_state)

# OMPL takes the validity-checking resolution as a fraction of the space's maximum
# extent, so a desired 5 degree step is divided by that extent.
# TODO: investigate the effect of this step size further.
step = float(np.deg2rad(5))
resolution = step / space.getMaximumExtent()
simple_setup.getSpaceInformation().setStateValidityCheckingResolution(resolution)

planner = og.RRTConnect(simple_setup.getSpaceInformation())
simple_setup.setPlanner(planner)

In [None]:
# Run the planner with a 20 second time budget; the returned status is shown as cell output.
simple_setup.solve(20.0)

In [None]:
# Number of states the path is resampled to; the animation cell also uses it for frame timing.
n_states = 100

# NOTE(review): when no solution exists, `solution_path` is not assigned by this cell.
if simple_setup.haveSolutionPath():
    simple_setup.simplifySolution()
    solution_path = simple_setup.getSolutionPath()
    # Print the simplified path before it is resampled.
    print(solution_path.printAsMatrix())
    solution_path.interpolate(n_states)

In [None]:
import time

total_time = 8.0  # wall-clock duration of the whole animation, in seconds

# `solution_path` is only assigned when planning succeeded. Without this guard a
# failed plan would raise a NameError here, or replay a stale path from an earlier run.
if simple_setup.haveSolutionPath():
    frame_delay = total_time / n_states
    for state in solution_path.getStates():
        # The left arm stays at its start pose; only the right arm follows the path.
        plant.SetPositions(plant_context, arm_left_index, start_joints_left)
        plant.SetPositions(plant_context, arm_right_index, state_to_numpy(state))
        diagram.ForcedPublish(context)
        time.sleep(frame_delay)
else:
    print("No single-arm solution path available; nothing to animate.")

### Dual arm planning 🤼

In [None]:
# Put both arms back in their start configuration and refresh the visualizer
# before planning for the two arms together.
plant.SetPositions(plant_context, arm_left_index, start_joints_left)
plant.SetPositions(plant_context, arm_right_index, start_joints_right)
diagram.ForcedPublish(context)

In [None]:
from ompl import base as ob
from ompl import geometric as og

# Joint planning space for both arms: 12 dimensions, filled below as
# [left arm (6 joints), right arm (6 joints)].
space_dual = ob.RealVectorStateSpace(12)
bounds_dual = ob.RealVectorBounds(12)
# The same limits apply to every joint: [-2*pi, 2*pi] radians.
bounds_dual.setLow(-2 * np.pi)
bounds_dual.setHigh(2 * np.pi)
space_dual.setBounds(bounds_dual)

# Print the configured space for inspection.
print(space_dual.settings())

In [None]:
def state_to_numpy_dual(state_ompl: "ob.State", n_joints: int = 12) -> np.ndarray:
    """Copy the leading components of a dual-arm OMPL state into a NumPy array.

    Args:
        state_ompl: Anything indexable by integer position. In this notebook
            that is either an ``ob.State`` wrapper or a state taken from
            ``solution_path.getStates()``.
        n_joints: Number of leading components to copy. Defaults to 12, the
            dimension of the dual-arm state space (6 joints per arm).

    Returns:
        A float array of shape ``(n_joints,)`` holding the copied values.
    """
    # Index explicitly rather than iterating the state: only `state[i]` is
    # relied upon here, not `len()` or iteration support.
    return np.array([state_ompl[i] for i in range(n_joints)], dtype=float)


def is_state_valid_dual(state):
    """Return whether the 12-joint dual-arm configuration in `state` is collision-free."""
    return collision_checker.CheckConfigCollisionFree(state_to_numpy_dual(state))


start_state_dual = ob.State(space_dual)
goal_state_dual = ob.State(space_dual)

# Fill the 12-D states: slots 0-5 hold the left arm, slots 6-11 the right arm.
for i in range(6):
    start_state_dual()[i] = start_joints_left[i]
    start_state_dual()[i + 6] = start_joints_right[i]
    goal_state_dual()[i] = home_joints_left[i]
    goal_state_dual()[i + 6] = home_joints_right[i]


# Sanity check: print whether each planning endpoint is collision-free.
print(is_state_valid_dual(start_state_dual))
print(is_state_valid_dual(goal_state_dual))

In [None]:
%%timeit
# Baseline: time the collision query itself on a ready-made NumPy configuration.
collision_checker.CheckConfigCollisionFree(np.concatenate([start_joints_left, start_joints_right]))

In [None]:
%%timeit
# Same query through the OMPL validity callback, i.e. including the state-to-NumPy conversion.
is_state_valid_dual(start_state_dual)

In [None]:
# Dual-arm planning problem: both arms move from their start to their home configuration.
simple_setup_dual = og.SimpleSetup(space_dual)
simple_setup_dual.setStateValidityChecker(ob.StateValidityCheckerFn(is_state_valid_dual))
simple_setup_dual.setStartAndGoalStates(start_state_dual, goal_state_dual)

# Use the same 5 degree validity-checking step as the single-arm setup. OMPL's default
# resolution is 1% of the space's maximum extent, which for this 12-D space with
# [-2*pi, 2*pi] bounds is about 25 degrees between checked states. That is coarse
# enough for an edge to pass through an obstacle between two checks.
step_dual = float(np.deg2rad(5))
resolution_dual = step_dual / space_dual.getMaximumExtent()
simple_setup_dual.getSpaceInformation().setStateValidityCheckingResolution(resolution_dual)

planner = og.RRTConnect(simple_setup_dual.getSpaceInformation())
simple_setup_dual.setPlanner(planner)

In [None]:
# Run the dual-arm planner with a 30 second time budget; the returned status is shown as cell output.
simple_setup_dual.solve(30.0)

In [None]:
# Number of states the path is resampled to; the animation cell also uses it for frame timing.
n_states = 100

# NOTE(review): `solution_path` reuses the name from the single-arm section. When no
# dual-arm solution exists, it keeps pointing at the earlier 6-D path.
if simple_setup_dual.haveSolutionPath():
    simple_setup_dual.simplifySolution()
    solution_path = simple_setup_dual.getSolutionPath()
    # Print the simplified path before it is resampled.
    print(solution_path.printAsMatrix())
    solution_path.interpolate(n_states)

In [None]:
import time

total_time = 8.0  # wall-clock duration of the whole animation, in seconds

# `solution_path` is shared with the single-arm section. If dual-arm planning failed,
# it still holds the 6-D single-arm path, which must not be read as 12-D states.
if simple_setup_dual.haveSolutionPath():
    frame_delay = total_time / n_states
    for state in solution_path.getStates():
        q_all = state_to_numpy_dual(state)
        # Layout of q_all: left arm in slots 0-5, right arm in slots 6-11.
        plant.SetPositions(plant_context, arm_left_index, q_all[:6])
        plant.SetPositions(plant_context, arm_right_index, q_all[6:])
        diagram.ForcedPublish(context)
        time.sleep(frame_delay)
else:
    print("No dual-arm solution path available; nothing to animate.")