In [None]:
%load_ext autoreload
%autoreload 2

import sys
import os
import time
import numpy as np
import pandas as pd
import pickle

from pydrake.common import FindResourceOrThrow, RandomGenerator
from pydrake.geometry import Box, Role, Sphere, RoleAssign
from pydrake.geometry import SceneGraph, IllustrationProperties
from pydrake.geometry.optimization import IrisInConfigurationSpace, IrisOptions
from pydrake.math import RigidTransform, RollPitchYaw, RotationMatrix
from pydrake.multibody.optimization import CalcGridPointsOptions, Toppra
from pydrake.multibody.parsing import LoadModelDirectives, Parser, ProcessModelDirectives
from pydrake.multibody.plant import AddMultibodyPlantSceneGraph, CoulombFriction, MultibodyPlant 
from pydrake.multibody.tree import RevoluteJoint, SpatialInertia, UnitInertia
from pydrake.systems.framework import DiagramBuilder
from pydrake.systems.meshcat_visualizer import ConnectMeshcatVisualizer
from pydrake.all import MultibodyPositionToGeometryPose
from pydrake.systems.primitives import TrajectorySource, Multiplexer, ConstantVectorSource
from pydrake.systems.analysis import Simulator

from comparison.planning import PRM, BiRRT

from pydrake.trajectories import PiecewisePolynomial

from meshcat import Visualizer
import meshcat.geometry as g

#from linear_spp import *
#from bspline_spp import *
#from iris_helpers import *
#from comparison_helpers import *

import matplotlib.pyplot as plt
from matplotlib.colors import to_hex
from IPython.display import HTML, SVG
import pickle
import multiprocessing as mp
from itertools import combinations, chain

from pydrake.solvers.gurobi import GurobiSolver
from pydrake.solvers.mosek import MosekSolver
# AcquireLicense() returns a license handle; the license stays acquired only
# while a reference to that handle is alive. Keep the handles in module-level
# variables instead of discarding them (which releases the licenses at once).
gurobi_license = GurobiSolver.AcquireLicense()
mosek_license = MosekSolver.AcquireLicense()
import cdd

In [None]:
# Setup meshcat
from meshcat.servers.zmqserver import start_zmq_server_as_subprocess
# Launch a meshcat ZMQ server as a subprocess and keep its handle and URLs;
# `zmq_url` is used by every visualizer created below.
proc, zmq_url, web_url = start_zmq_server_as_subprocess(
    server_args=[],
)

# Occasionally `pkill -f meshcat` is needed to clean up old meshcat servers.

In [None]:
def lower_alpha(plant, inspector, model_instances, alpha, scene_graph):
    """Replace the alpha of the illustration geometries of `model_instances`.

    For every body of every given model instance, each geometry with the
    illustration role gets a copy of its illustration properties whose
    ("phong", "diffuse") colour keeps its RGB but uses `alpha`. The copy is
    re-assigned to `scene_graph` with RoleAssign.kReplace.
    """
    for model_instance in model_instances:
        for body_index in plant.GetBodyIndices(model_instance):
            body_frame_id = plant.GetBodyFrameIdOrThrow(body_index)
            illustration_ids = inspector.GetGeometries(body_frame_id, Role.kIllustration)
            for geometry_id in illustration_ids:
                current_props = inspector.GetIllustrationProperties(geometry_id)
                updated_props = IllustrationProperties(current_props)
                diffuse = current_props.GetProperty("phong", "diffuse")
                # Keep the colour, change only the transparency.
                diffuse.set(diffuse.r(), diffuse.g(), diffuse.b(), alpha)
                updated_props.UpdateProperty("phong", "diffuse", diffuse)
                scene_graph.AssignRole(
                    plant.get_source_id(), geometry_id, updated_props, RoleAssign.kReplace)

In [None]:
vis = Visualizer(zmq_url=zmq_url)
vis.delete()
display(vis.jupyter_cell())

# Build the planning plant: a welded floor box plus the scene described by the
# model directives file (iiwa, wsg, shelf and two bins).
builder = DiagramBuilder()
plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)
parser = Parser(plant)
#parser.package_map().Add( "wsg_50_description", os.path.dirname(FindResourceOrThrow(
 #           "drake/manipulation/models/wsg_50_description/package.xml")))

floor_dim = np.array([4, 4, 0.2])
floor = plant.AddRigidBody("floor", SpatialInertia(
        mass=1.0, p_PScm_E=np.array([0., 0., 0.]), G_SP_E=UnitInertia(1.0, 1.0, 1.0)))
# Weld the floor box so that its top surface lies at z = 0.
plant.WeldFrames(plant.world_frame(), floor.body_frame(), RigidTransform(p=np.array([0, 0, -floor_dim[2]/2.])))
plant.RegisterVisualGeometry(floor, RigidTransform(), Box(*floor_dim), "floor_vis",
                             np.array([0.5, 0.5, 0.5, 1.]))
plant.RegisterCollisionGeometry(floor, RigidTransform(), Box(*floor_dim), "florr_collision",
                                CoulombFriction(0.9, 0.8))

directives_file = "./models/iiwa14_spheres_collision_welded_gripper.yaml"
directives = LoadModelDirectives(directives_file)
models = ProcessModelDirectives(directives, plant, parser)
[iiwa, wsg, shelf, binR, binL] =  models

visual_iiwas = []
visual_wsgs = []
#iiwa_file = FindResourceOrThrow("./models/iiwa_description/urdf/iiwa14_spheres_collision.urdf")
#wsg_file = FindResourceOrThrow("drake/planning/models/schunk_wsg_50_welded_fingers.sdf")

plant.Finalize()

viz_role = Role.kIllustration
# viz_role = Role.kProximity
visualizer = ConnectMeshcatVisualizer(builder, scene_graph, zmq_url=zmq_url,
                                      delete_prefix_on_load=False, role=viz_role)
diagram = builder.Build()

visualizer.load()
context = diagram.CreateDefaultContext()
# SetPositions writes to the plant context, so fetch the mutable one directly
# (a preceding const lookup was immediately overwritten and has been removed).
plant_context = plant.GetMyMutableContextFromRoot(context)
# Initial configuration published to meshcat.
q0 =      [-0.43303849430001307,
    0.15450520762404665,
    0.30334346818001523,
    -1.0376976962796667,
    0.11739903437607266,
    0.5348776947156673,
    1.0783084430904017]
plant.SetPositions(plant_context,q0)
diagram.Publish(context)

# Joint limits, used by the samplers further below.
PositionUpperLimits = plant.GetPositionUpperLimits()
PositionLowerLimits = plant.GetPositionLowerLimits()

In [None]:
def visualize_trajectory(traj_list, show_line = False, iiwa_ghosts = None, alpha = 0.5, regions = None):
    """Play back one or more trajectories in meshcat.

    All trajectories in `traj_list` are concatenated with `combine_trajectory`
    and executed one after another by a single iiwa + wsg model in a fresh
    diagram. The result is recorded and published to meshcat.

    Args:
        traj_list: a trajectory or a list of trajectories, as accepted by
            `combine_trajectory`.
        show_line: if True, draw a point trace of the poses returned by
            `ForwardKinematics` for each trajectory. Also draw a point cloud
            (via `spider_web`) for each entry of `regions`.
        iiwa_ghosts: joint configurations (7 values each), each drawn as a
            static, semi-transparent copy of the iiwa + wsg. When non-empty,
            the moving robot is made semi-transparent as well.
        alpha: alpha applied to the ghosted geometry.
        regions: regions drawn when `show_line` is True. May be a list or a
            dict (its values are used).

    Returns:
        The meshcat `Visualizer` holding the drawn paths and regions.
    """
    if not isinstance(traj_list, list):
        traj_list = [traj_list]
    # None defaults instead of shared mutable `[]` default arguments.
    if iiwa_ghosts is None:
        iiwa_ghosts = []
    if regions is None:
        regions = []

    combined_traj = combine_trajectory(traj_list)
    vis = Visualizer(zmq_url=zmq_url)
    vis.delete()

    builder = DiagramBuilder()
    scene_graph = builder.AddSystem(SceneGraph())
    plant = MultibodyPlant(time_step=0.0)
    plant.RegisterAsSourceForSceneGraph(scene_graph)
    inspector = scene_graph.model_inspector()

    parser = Parser(plant, scene_graph)
    parser.package_map().Add( "wsg_50_description", os.path.dirname(FindResourceOrThrow(
            "drake/manipulation/models/wsg_50_description/package.xml")))

    # Floor geometry; its visual alpha is 0.0, so it is not visible.
    floor_dim = np.array([4, 4, 0.2])
    floor = plant.AddRigidBody("floor", SpatialInertia(
        mass=1.0, p_PScm_E=np.array([0., 0., 0.]), G_SP_E=UnitInertia(1.0, 1.0, 1.0)))
    plant.WeldFrames(plant.world_frame(), floor.body_frame(), RigidTransform(p=np.array([0, 0, -floor_dim[2]/2.])))
    plant.RegisterVisualGeometry(floor, RigidTransform(), Box(*floor_dim), "floor_vis",
                             np.array([0.5, 0.5, 0.5, 0.0]))
    plant.RegisterCollisionGeometry(floor, RigidTransform(), Box(*floor_dim), "florr_collision",
                                CoulombFriction(0.9, 0.8))

    directives_file = FindResourceOrThrow("drake/planning/models/iiwa14_spheres_collision_welded_gripper.yaml")
    directives = LoadModelDirectives(directives_file)
    models = ProcessModelDirectives(directives, plant, parser)
    [iiwa, wsg, shelf, binR, binL] =  models

    # When ghosts are shown, make the moving robot semi-transparent too.
    if len(iiwa_ghosts):
        lower_alpha(plant, inspector, [iiwa.model_instance, wsg.model_instance], alpha,scene_graph)
    visual_iiwas = []
    visual_wsgs = []
    iiwa_file = FindResourceOrThrow("drake/manipulation/models/iiwa_description/urdf/iiwa14_spheres_collision.urdf")
    wsg_file = FindResourceOrThrow("drake/planning/models/schunk_wsg_50_welded_fingers.sdf")

    # One welded, semi-transparent iiwa + wsg copy per ghost configuration.
    for i, q in enumerate(iiwa_ghosts):
        new_iiwa = parser.AddModelFromFile(iiwa_file, "vis_iiwa_"+str(i))
        new_wsg = parser.AddModelFromFile(wsg_file, "vis_wsg_"+str(i))
        plant.WeldFrames(plant.world_frame(), plant.GetFrameByName("base", new_iiwa), RigidTransform())
        plant.WeldFrames(plant.GetFrameByName("iiwa_link_7", new_iiwa), plant.GetFrameByName("body", new_wsg),
                         RigidTransform(rpy=RollPitchYaw([np.pi/2., 0, 0]), p=[0, 0, 0.114]))
        visual_iiwas.append(new_iiwa)
        visual_wsgs.append(new_wsg)
        lower_alpha(plant, inspector, [new_iiwa, new_wsg], alpha, scene_graph)
        index = 0
        for joint_index in plant.GetJointIndices(visual_iiwas[i]):
            joint = plant.get_mutable_joint(joint_index)
            if isinstance(joint, RevoluteJoint):
                joint.set_default_angle(q[index])
                index += 1

    plant.Finalize()

    # The geometry poses are driven by one position vector. It holds the
    # combined trajectory for the main robot, followed by one constant
    # 7-vector per ghost.
    to_pose = builder.AddSystem(MultibodyPositionToGeometryPose(plant))
    builder.Connect(to_pose.get_output_port(), scene_graph.get_source_pose_port(plant.get_source_id()))

    traj_system = builder.AddSystem(TrajectorySource(combined_traj))

    mux = builder.AddSystem(Multiplexer([7 for _ in range(1 + len(iiwa_ghosts))]))
    builder.Connect(traj_system.get_output_port(), mux.get_input_port(0))

    for i, q in enumerate(iiwa_ghosts):
        ghost_pos = builder.AddSystem(ConstantVectorSource(q))
        builder.Connect(ghost_pos.get_output_port(), mux.get_input_port(1+i) )

    builder.Connect(mux.get_output_port(), to_pose.get_input_port())

    viz_role = Role.kIllustration
    # viz_role = Role.kProximity
    visualizer = ConnectMeshcatVisualizer(builder, scene_graph, zmq_url=zmq_url,
                                      delete_prefix_on_load=False, role=viz_role)

    diagram = builder.Build()

    if show_line:
        X_lists = []
        for traj in traj_list:
            sample_times = np.linspace(traj.start_time(), traj.end_time(), 15000)
            X_lists.append(ForwardKinematics(traj.vector_values(sample_times).T.tolist()))

        # RGBA colours (0..1) for the trajectory traces. They are cycled so
        # that more than three trajectories do not raise IndexError.
        c_list_rgb = [[i/255 for i in (0, 0, 255, 255)],
                      [i/255 for i in (255, 191, 0, 255)],
                      [i/255 for i in (255, 64, 0, 255)]]

        for i, X_list in enumerate(X_lists):
            vertices = list(map(lambda X: X.translation(), X_list))
            trace_color = np.array(c_list_rgb[i % len(c_list_rgb)])
            colors = [trace_color for _ in range(len(X_list))]
            vertices = np.stack(vertices).astype(np.float32).T
            colors = np.array(colors).astype(np.float32).T
            vis["paths"][str(i)].set_object(g.Points(g.PointsGeometry(vertices, color=colors),
                                            g.PointsMaterial(size=0.015)))

        # Accept both a list of regions and a dict name -> region.
        region_list = list(regions.values()) if isinstance(regions, dict) else list(regions)
        reg_colors = plt.cm.viridis(np.linspace(0, 1,len(region_list)))
        reg_colors[:,3] = 1.0

        for i, reg in enumerate(region_list):
            X_reg = ForwardKinematics(spider_web(reg))
            vertices = list(map(lambda X: X.translation(), X_reg))
            colors = [reg_colors[i] for _ in range(len(X_reg))]
            vertices = np.stack(vertices).astype(np.float32).T
            colors = np.array(colors).astype(np.float32).T
            vis["regions"][str(i)].set_object(g.Points(g.PointsGeometry(vertices, color=colors),
                                                       g.PointsMaterial(size=0.015)))

    visualizer.load()
    simulator = Simulator(diagram)
    visualizer.start_recording()
    simulator.AdvanceTo(combined_traj.end_time())
    visualizer.publish_recording()
    return vis


# Generate IRIS Regions
### Via manual seeds

In [None]:
# Number of worker processes used when growing regions in parallel.
CORE_CNT = mp.cpu_count() # you may edit this

# IRIS settings shared by every region computation below.
iris_options = IrisOptions()
iris_options.iteration_limit = 10
# A negative absolute threshold disables the absolute volume-change test, so
# only the relative threshold (1%) and the iteration limit stop IRIS.
iris_options.termination_threshold = -1
iris_options.relative_termination_threshold = 0.01
iris_options.require_sample_point_is_contained = True
iris_options.enable_ibex = False

In [None]:
# Seed configurations used for the paper, one per named region.
# (The next cell redefines `seed_points`; run only the one you want.)
seed_points = dict([
    ("Above Shelve",    [0, 0.4, 0, -0.8, 0, 0.35, 1.57]),
    ("Top Rack",        [0, 0.45, 0, -1.35, 0, -0.25, 1.57]),
    ("Middle Rack",     [0, 0.8, 0, -1.5, 0, -0.7, 1.57]),
    ("Right Bin",       [1.57, 0.7, 0, -1.6, 0, 0.8, 1.57]),
    ("Left Bin",        [-1.57, 0.7, 0, -1.6, 0, 0.8, 1.57]),
    ("Front to Shelve", [0, 0.2, 0, -2.09, 0, -0.3, 1.57]),
    ("Right to Shelve", [0.8, 0.7, 0, -1.6, 0, 0, 1.57]),
    ("Left to Shelve",  [-0.8, 0.7, 0, -1.6, 0, 0, 1.57]),
])

In [None]:
# Seed configurations optimized for the paper: the five task regions plus
# four hand-tuned "helper" regions that connect them.
seed_points = dict([
    ("Above Shelve", [0, 0.4, 0, -0.8, 0, 0.35, 1.57]),
    ("Top Rack",     [0, 0.45, 0, -1.35, 0, -0.25, 1.57]),
    ("Middle Rack",  [0, 0.8, 0, -1.5, 0, -0.7, 1.57]),
    ("Right Bin",    [1.57, 0.7, 0, -1.6, 0, 0.8, 0]),
    ("Left Bin",     [-1.57, 0.7, 0, -1.6, 0, 0.8, 0]),
    ("Helper 1",     [0.0, 0.175, 0.0, -1.675, 0.0, -0.275, 1.57]),
    ("Helper 2",     [-0.23670698873773643, 0.0869880191351135,
                      -0.14848147591054467, -1.8845683628360246,
                      0.11582904486351761, -0.004801857412368382, 1.2343852721657547]),
    ("Helper 3",     [0.1901439226352184, 0.24791935747561322, 0.026105794277764514,
                      -1.862822617783086, 0.2509538957570004, -0.2466088852764272, 1.3435623440091402]),
    # Same values as the initial configuration q0.
    ("Helper 4",     [-0.43303849430001307, 0.15450520762404665, 0.30334346818001523,
                      -1.0376976962796667, 0.11739903437607266, 0.5348776947156673, 1.0783084430904017]),
])

In [None]:
def calcRegion(seed, verbose):
    """Grow one IRIS region in configuration space around `seed`.

    Each call creates its own diagram context, sets the plant to `seed` and
    runs IrisInConfigurationSpace with the notebook-level `iris_options`.

    Args:
        seed: joint configuration the region is grown from.
        verbose: if True, print the seed and the elapsed time.

    Returns:
        The region returned by IrisInConfigurationSpace.
    """
    start_time = time.time()
    context = diagram.CreateDefaultContext()
    # SetPositions writes to the context, so request the mutable sub-context.
    plant_context = plant.GetMyMutableContextFromRoot(context)
    plant.SetPositions(plant_context, seed)
    hpoly = IrisInConfigurationSpace(plant, plant_context, iris_options)
    if verbose:
        print("Seed:", seed, "\tTime:", time.time() - start_time, flush=True)
    return hpoly

def generateRegions(seed_points, verbose = True):
    """Compute one IRIS region per seed in a process pool.

    Args:
        seed_points: a dict name -> seed configuration (any dict subclass is
            accepted), or a sequence of seed configurations.
        verbose: forwarded to `calcRegion`. If True, the total loop time is
            also printed.

    Returns:
        A dict name -> region when `seed_points` is a dict. Otherwise a list of
        regions in seed order.
    """
    is_named = isinstance(seed_points, dict)
    seeds = list(seed_points.values()) if is_named else seed_points
    loop_time = time.time()
    # NOTE(review): workers run calcRegion, which reads the notebook globals
    # `diagram`, `plant` and `iris_options`. This presumably relies on the
    # 'fork' start method; confirm on platforms that default to 'spawn'.
    with mp.Pool(processes = CORE_CNT) as pool:
        regions = pool.starmap(calcRegion, [(seed, verbose) for seed in seeds])

    if verbose:
        print("Loop time:", time.time() - loop_time)

    if is_named:
        return dict(zip(seed_points.keys(), regions))
    return regions

In [None]:
# Grow one region per seed (dict in -> dict name -> region out).
regions = generateRegions(seed_points, verbose=True)

### Save regions

In [None]:
# Persist the computed regions so later sessions can reload them.
with open(file='./IRIS.reg', mode='wb') as region_file:
    pickle.dump(regions, region_file)

# Generate PRM

In [None]:
# PRM / BiRRT parameters.
K = 5                                   # passed to PRM and AddNodeToRoadmap (presumably the neighbour count)
solve_timeout = 100
roadmap_size = 15000                    # number of nodes for prm.Grow
step_size = np.pi / 32                  # maximum joint-space extension step
collision_step_size = step_size / 128   # finer resolution for edge collision checks
#collision_step_size = step_size/8

In [None]:
prm = PRM(K, solve_timeout, plant, plant_context, step_size, collision_step_size)

# Growing the roadmap is expensive; it is normally loaded from disk further
# below. Set GROW_ROADMAP = True to regenerate it. (Previously `print(stats)`
# ran even though the Grow call was commented out, raising NameError.)
GROW_ROADMAP = False
if GROW_ROADMAP:
    stats = prm.Grow(roadmap_size)
    roadmap = prm.roadmap
    print(stats)


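## For the final demonstration
Unfortunately, these cells are out of order: they use `RRT_Connect`, `get_prm_path` and the demo sequences, which are defined further down, so run those cells first. PRM alone cannot reach into the shelf directly,
so we use BiRRT to connect some of the roadmap's disconnected subgraphs.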

In [None]:
# Insert each keypoint configuration into the roadmap.
for keypoint in keypoints:
    AddNodeToRoadmap(
        keypoint, NNDistanceDirection(), roadmap,
        distance_fn, check_edge_validity_fn,
        K, False, True, False)
    

In [None]:
# Bridge "Above Shelve" -> "Top Rack" with BiRRT and add every waypoint of the
# resulting path to the roadmap (RRT_Connect is defined further below).
for waypoint in RRT_Connect([0, 0.4, 0, -0.8, 0, 0.35, 1.57],
                            [0, 0.45, 0, -1.35, 0, -0.25, 1.57]).Path():
    AddNodeToRoadmap(
        waypoint, NNDistanceDirection(), roadmap,
        distance_fn, check_edge_validity_fn,
        K, False, True, False)
  

In [None]:
# Bridge "Top Rack" -> "Middle Rack" with BiRRT and add every waypoint of the
# resulting path to the roadmap (RRT_Connect is defined further below).
for waypoint in RRT_Connect([0, 0.45, 0, -1.35, 0, -0.25, 1.57],
                            [0, 0.8, 0, -1.5, 0, -0.7, 1.57]).Path():
    AddNodeToRoadmap(
        waypoint, NNDistanceDirection(), roadmap,
        distance_fn, check_edge_validity_fn,
        K, False, True, False)
  

In [None]:
# NOTE(review): `demo2` is not defined in this notebook (the demos are named
# demo_a ... demo_e and demo_circle); confirm which sequence was intended.
get_prm_path(sequence=demo2, seed=5, verbose=True, smoothing=False, roadmap=roadmap)

### Save roadmap

In [None]:
# Persist the (augmented) roadmap for later sessions.
with open(file='./PRM.rmp', mode='wb') as roadmap_file:
    pickle.dump(roadmap, roadmap_file)

# Load IRIS Regions and Roadmap

### Make sure to run SimpleLinearSPP before running the comparison below (the cells below expect `SimplerLinearSPP` and `BsplineSPP` to be defined)

In [None]:
# Roadmap used for demo 1. Only unpickle trusted, locally generated files.
with open(file='./PRM_BiRRT_Demo_1.rmp', mode='rb') as roadmap_file:
    roadmap = pickle.load(roadmap_file)

In [None]:
# Alternative: the super-dense roadmap (overrides the cell above if run).
with open(file='./PRM_BiRRT_super_dense.rmp', mode='rb') as roadmap_file:
    roadmap = pickle.load(roadmap_file)

In [None]:
# Regions grown from the "intuitive" paper seeds.
with open(file='./IRIS_paper_intuitive.reg', mode='rb') as region_file:
    regions = pickle.load(region_file)

In [None]:
# Alternative: regions grown from the optimized seeds (overrides the cell above if run).
with open(file='./IRIS_paper_optimal.reg', mode='rb') as region_file:
    regions = pickle.load(region_file)

In [None]:
# Pass a shallow copy so changes SimplerLinearSPP might make to the container
# do not affect `regions`.
spp = SimplerLinearSPP(
    regions.copy()
)

In [None]:
# Render the linear SPP region graph.
SVG(data=spp.VisualizeGraph())

In [None]:
# B-spline SPP over a shallow copy of the regions.
bspp = BsplineSPP(
    regions.copy(),
    5,
)

In [None]:
# Render the B-spline SPP region graph.
SVG(data=bspp.VisualizeGraph())

# Load PRM & Run Comparison

In [None]:
def comparison_sampling_fn():
    """Draw one joint-space sample that lies in some region and is collision-free.

    The sample is drawn uniformly within the plant's position limits.

    Returns:
        The sample (numpy array), or None if it lies outside every region in
        `regions` or the configuration is in collision.
    """
    nq = len(PositionLowerLimits)
    sample = np.random.rand(nq) * (PositionUpperLimits - PositionLowerLimits) + PositionLowerLimits

    # `regions` may be a dict name -> region (as built by generateRegions from
    # named seeds) or a plain list; iterating a dict would yield its key strings.
    region_list = regions.values() if isinstance(regions, dict) else regions
    if not any(region.PointInSet(sample) for region in region_list):
        return None

    # Reject configurations that are in collision.
    plant.SetPositions(plant_context, sample)
    query_object = plant.get_geometry_query_input_port().Eval(plant_context)
    if query_object.HasCollisions():
        return None

    return sample


In [None]:
def get_prm_path(sequence, seed, verbose = False, smoothing = False, roadmap = None):
    """Chain roadmap queries through consecutive configurations of `sequence`.

    Args:
        sequence: ordered joint configurations. A roadmap query is run between
            each consecutive pair.
        seed: seed of the RandomGenerator used by ShortcutSmoothPath. It is only
            used when `smoothing` is True.
        verbose: print failed segments and the total wall-clock time.
        smoothing: shortcut-smooth each queried sub-path.
        roadmap: roadmap to query. Defaults to the notebook-global `roadmap`,
            looked up at call time. A `roadmap=roadmap` default would freeze
            whatever roadmap existed when this cell ran, ignoring later reloads.

    Returns:
        An (nq, N) array of waypoints, or None if any segment query fails.
    """
    if roadmap is None:
        # The parameter shadows the global name, so look the global up explicitly.
        roadmap = globals()["roadmap"]
    path = [sequence[0]]
    start_time = time.time()
    for start_pt, goal_pt in zip(sequence[:-1], sequence[1:]):
        prm_path = QueryPath([start_pt], [goal_pt], roadmap, distance_fn,
                check_edge_validity_fn, K,
                use_parallel=False,
                distance_is_symmetric=True,
                add_duplicate_states=False,
                limit_astar_pqueue_duplicates=True).Path()
        # Detect failure before smoothing so an empty path is never smoothed.
        if len(prm_path) == 0:
            if verbose:
                print(f"Failed between {start_pt} and {goal_pt}")
            return None
        if smoothing:
            prm_path = ShortcutSmoothPath(prm_path, 100, 100, 1, 0.5, 0.5, False,
                               check_edge_validity_fn,distance_fn,
                               InterpolateWaypoint, RandomGenerator(seed))

        # Skip the first waypoint: it repeats the previous segment's end.
        path += prm_path[1:]

    if verbose:
        print(f"Time: {round(time.time() - start_time, 3)}s")

    return np.stack(path).T

def get_simple_spp_path(sequence, verbose = False):
    """Chain linear-SPP solves through consecutive configurations of `sequence`.

    Returns an (nq, N) array of waypoints starting at `sequence[0]`, or None
    (printing the failed pair when `verbose`) if any segment cannot be solved.
    """
    waypoint_list = [sequence[0]]
    t_begin = time.time()
    for start_pt, goal_pt in zip(sequence, sequence[1:]):
        segment, _, _ = spp.SolvePath(start_pt, goal_pt, True, verbose)
        if segment is None:
            if verbose:
                print(f"Failed between {start_pt} and {goal_pt}")
            return None
        # Drop the segment's first waypoint; it duplicates the previous end.
        waypoint_list.extend(segment.T[1:].tolist())

    if verbose:
        print(f"Time: {round(time.time() - t_begin, 3)}s")
    return np.stack(waypoint_list).T


# Planner wrappers for the comparison: each returns (trajectory, None), or (None, None) if no path is found.

def simple_linear_rebuild(start_pt, goal_pt):
    """Linear-SPP trajectory from `start_pt` to `goal_pt`.

    Returns (trajectory, None), or (None, None) if no path was found.
    """
    waypoints = get_simple_spp_path([start_pt, goal_pt])
    return (None, None) if waypoints is None else (make_traj(waypoints), None)

def bspline_rebuild(start_pt, goal_pt):
    """B-spline SPP trajectory from `start_pt` to `goal_pt`.

    Returns (trajectory, None), or (None, None) if no path was found.
    """
    trajectory, _, _ = bspp.SolvePath(start_pt, goal_pt, True, True)
    if trajectory is not None:
        return trajectory, None
    return None, None

def dense_prm(start_pt, goal_pt):
    """Unsmoothed PRM trajectory from `start_pt` to `goal_pt`.

    Returns (trajectory, None), or (None, None) if the roadmap query fails.
    """
    waypoints = get_prm_path([start_pt, goal_pt], 5)
    if waypoints is not None:
        return make_traj(waypoints), None
    return None, None

def smoothed_dense_prm(start_pt, goal_pt):
    """Shortcut-smoothed PRM trajectory from `start_pt` to `goal_pt`.

    Returns (trajectory, None), or (None, None) if the roadmap query fails.
    """
    waypoints = get_prm_path([start_pt, goal_pt], 5, verbose=False, smoothing=True)
    if waypoints is not None:
        return make_traj(waypoints), None
    return None, None

# Demonstration

In [None]:
# Demonstration sequences of joint configurations.
# Bin labels below follow the `seed_points` names (base yaw +1.57 is
# "Right Bin", -1.57 is "Left Bin").
# NOTE(review): earlier comments labelled the bins the other way round;
# confirm which physical bin is which.
demo_a = [
    [0, 0.4, 0, -0.8, 0, 0.35, 1.57],     # above shelf
    [0, 0.45, 0, -1.35, 0, -0.25, 1.57],  # in shelf 1 (top rack)
]

demo_b = [
    [0, 0.45, 0, -1.35, 0, -0.25, 1.57],  # in shelf 1 (top rack)
    [0, 0.8, 0, -1.5, 0, -0.7, 1.57],     # in shelf 2 (middle rack)
]

demo_c = [
    [0, 0.8, 0, -1.5, 0, -0.7, 1.57],     # in shelf 2 (middle rack)
    [1.57, 0.7, 0, -1.6, 0, 0.8, 1.57],   # in bin at yaw +1.57 ("Right Bin")
]

demo_d = [
    [1.57, 0.7, 0, -1.6, 0, 0.8, 1.57],   # in bin at yaw +1.57 ("Right Bin")
    [-1.57, 0.7, 0, -1.6, 0, 0.8, 1.57],  # in bin at yaw -1.57 ("Left Bin")
]

demo_e = [
    [-1.57, 0.7, 0, -1.6, 0, 0.8, 1.57],  # in bin at yaw -1.57 ("Left Bin")
    [0, 0.4, 0, -0.8, 0, 0.35, 1.57],     # above shelf
]

# Full loop through all of the above, ending where it started.
demo_circle = [
    [0, 0.4, 0, -0.8, 0, 0.35, 1.57],     # above shelf
    [0, 0.45, 0, -1.35, 0, -0.25, 1.57],  # in shelf 1
    [0, 0.8, 0, -1.5, 0, -0.7, 1.57],     # in shelf 2
    [1.57, 0.7, 0, -1.6, 0, 0.8, 1.57],   # bin at yaw +1.57
    [-1.57, 0.7, 0, -1.6, 0, 0.8, 1.57],  # bin at yaw -1.57
    [0, 0.4, 0, -0.8, 0, 0.35, 1.57],     # above shelf
]

In [None]:
def combine_trajectory(traj_list, wait = 2):
    """Concatenate trajectories into one first-order-hold PiecewisePolynomial.

    Each trajectory contributes its values at its segment (break) times. The
    first trajectory starts at t = 0. Each later one starts 0.1 s after the
    previous piece ends, and keeps its own break-time spacing; the old code
    spread the knots uniformly, which distorted e.g. the distance-based timing
    produced by make_traj. If `wait` > 0, the last knot of each trajectory is
    repeated `wait` seconds later, so the robot holds its final pose.

    Args:
        traj_list: iterable of trajectories exposing get_segment_times(),
            vector_values(), start_time() and end_time().
        wait: hold time appended after each trajectory.

    Returns:
        PiecewisePolynomial.FirstOrderHold through all re-timed knots.
    """
    knotList = []
    time_list = []
    for traj in traj_list:
        segment_times = traj.get_segment_times()
        knots = traj.vector_values(segment_times).T
        knotList.append(knots)

        # Explicit first-iteration check instead of a bare `except:`.
        offset = time_list[-1][-1] + 0.1 if time_list else 0
        time_list.append(np.asarray(segment_times, dtype=float) - traj.start_time() + offset)

        # Hold the final pose for `wait` seconds.
        if wait > 0.0:
            knotList.append(knotList[-1][-1,:])
            time_list.append(np.array([time_list[-1][-1] + wait]))

    path = np.vstack(knotList).T
    time_break = np.hstack(time_list)

    return PiecewisePolynomial.FirstOrderHold(time_break, path)
        
def make_traj(path, speed = 2):
    """Time-parameterize a waypoint path with linear interpolation.

    Each segment lasts its Euclidean joint-space length divided by `speed`,
    so the path is traversed at a constant joint-space speed starting at t=0.

    Args:
        path: (n_joints, n_waypoints) array; columns are waypoints.
        speed: joint-space speed used to derive segment durations.

    Returns:
        PiecewisePolynomial.FirstOrderHold through the waypoints.
    """
    waypoints = path.T
    deltas = waypoints[1:, :] - waypoints[:-1, :]
    segment_lengths = np.sqrt(np.sum(np.square(deltas), axis = 1))
    breaks = [0]
    for duration in segment_lengths / speed:
        breaks.append(breaks[-1] + duration)
    return PiecewisePolynomial.FirstOrderHold(breaks, path)


def get_traj_length(trajectory, bspline = False, verbose = True):
    """Return the joint-space length of the polyline through sampled knots.

    Args:
        trajectory: object exposing vector_values(), get_segment_times(),
            start_time() and end_time().
        bspline: if True, sample 1000 evenly spaced times instead of using the
            trajectory's segment times.
        verbose: if True (default, matching previous behaviour), print the
            total absolute displacement of each joint.

    Returns:
        Sum of Euclidean distances between consecutive knots.
    """
    if bspline:
        knots = trajectory.vector_values(np.linspace(trajectory.start_time(), trajectory.end_time(), 1000))
    else:
        knots = trajectory.vector_values(trajectory.get_segment_times())
    knots = np.asarray(knots, dtype=float)

    # One column per knot, so differences along axis 1 are the per-step moves.
    steps = np.diff(knots, axis=1)
    path_length = float(np.sum(np.linalg.norm(steps, axis=0)))

    if verbose:
        # Works for any joint count (previously hard-coded to 7 joints).
        print(np.sum(np.abs(steps), axis=1))
    return path_length

In [None]:
execute_demo = demo_a
linear_spp_traj = make_traj(get_simple_spp_path(execute_demo, verbose = True), speed = 2)
print(f"Linear SPP length: {get_traj_length(linear_spp_traj)}")

# prm_traj is required by the visualization cells below (it was commented out,
# which made them fail with NameError).
prm_traj = make_traj(get_prm_path(execute_demo, 5, verbose = True, smoothing = False, roadmap=roadmap), speed = 2)
print(f"PRM length: {get_traj_length(prm_traj)}")

smoothed_prm_traj = make_traj(get_prm_path(execute_demo, 5, verbose = True, smoothing = True, roadmap=roadmap), speed = 2)
print(f"smoothed PRM length: {get_traj_length(smoothed_prm_traj)}")


In [None]:
# Play the three planners' trajectories with traces, ghosting the demo endpoints.
vis_meshcat = visualize_trajectory(
    [linear_spp_traj, prm_traj, smoothed_prm_traj],
    show_line=True,
    iiwa_ghosts=execute_demo,
    alpha=0.3,
    regions=[],
)


In [None]:
# Same playback without traces; ghosts are the "Front/Right/Left to Shelve" seeds.
vis_meshcat = visualize_trajectory(
    [linear_spp_traj, prm_traj, smoothed_prm_traj],
    show_line=False,
    iiwa_ghosts=[
        [0, 0.2, 0, -2.09, 0, -0.3, 1.57],
        [0.8, 0.7, 0, -1.6, 0, 0, 1.57],
        [-0.8, 0.7, 0, -1.6, 0, 0, 1.57],
    ],
    alpha=0.3,
    regions=[],
)

In [None]:
# Standalone HTML of the meshcat scene returned by visualize_trajectory (shown as this cell's output).
vis_meshcat.static_html()

In [None]:
# Save the standalone meshcat page so it can be viewed without a server.
with open(file="SPP_PRM_comparison_simple.html", mode="w") as html_file:
    html_file.write(vis_meshcat.static_html())

In [None]:
# visualize_trajectory requires at least one trajectory; calling it with no
# arguments raised TypeError. Play back the linear SPP result on its own.
visualize_trajectory(linear_spp_traj);

### Tools to analyze the roadmap

In [None]:
def return_subgraphs_idx(roadmap, seeds = None):
    """Partition roadmap nodes into groups reachable by breadth-first search.

    Repeatedly picks an arbitrary unassigned node and follows out-edges until
    no new node is reached, then records that group. With bidirectional edges
    the groups are the connected components. While expanding, it prints the
    group index and value of every node whose configuration equals one of the
    seed configurations.

    Args:
        roadmap: roadmap exposing GetNodesMutable() and GetNodeImmutable(i);
            nodes expose GetValueImmutable() and GetOutEdgesImmutable(), and
            edges expose GetToIndex().
        seeds: dict name -> configuration, or a list of configurations, to
            look for. Defaults to the notebook-global `seed_points`.

    Returns:
        (milestones_idx, subgraphs). milestones_idx is the set of unassigned
        node indices (always empty on return; kept for compatibility).
        subgraphs is a list of sets of node indices.
    """
    from collections import deque

    if seeds is None:
        seeds = seed_points
    # Compare node values with the seed *configurations*. `value in seed_points`
    # tested against the dict keys and raised TypeError for a list value.
    seed_configs = [np.asarray(s) for s in (seeds.values() if isinstance(seeds, dict) else seeds)]

    milestones_idx = set(range(len(roadmap.GetNodesMutable())))
    subgraphs = []

    idx = 0
    while milestones_idx:
        # Start from an arbitrary unassigned node (set.pop order is unspecified).
        root = milestones_idx.pop()
        subgraph = set()
        queue = deque([root])

        while queue:
            expand_node_idx = queue.popleft()
            if expand_node_idx in subgraph:
                # Already expanded.
                continue

            subgraph.add(expand_node_idx)
            expand_node = roadmap.GetNodeImmutable(expand_node_idx)
            value = expand_node.GetValueImmutable()

            if any(np.array_equal(value, seed) for seed in seed_configs):
                print(idx, value)
            for edge in expand_node.GetOutEdgesImmutable():
                child = edge.GetToIndex()
                if child not in subgraph:
                    queue.append(child)

        subgraphs.append(subgraph)
        idx += 1
        # Remove this group from the unassigned nodes.
        milestones_idx -= subgraph
    return milestones_idx, subgraphs

milestones_idx, subgraphs = return_subgraphs_idx(roadmap)

# Report how fragmented the roadmap is.
print(f"Found {len(subgraphs)} subgraphs, with sizes {[len(s) for s in subgraphs]}")

## RRT-Connect (BiRRT) functions

In [None]:
def connect_fn(nearest, sample, is_start_tree):
    """Extend from `nearest` toward `sample` in straight-line steps.

    Each step has length at most `step_size` (measured with distance_fn), and
    the last step lands exactly on `sample`. Extension stops at the first
    edge rejected by check_edge_validity_fn, or once within 1e-6 of `sample`.
    `is_start_tree` is unused; it is part of the propagation callback signature
    passed to BiRRTPlanSinglePath.

    Returns:
        List of PropagatedState, each pointing to its predecessor through
        relative_parent_index (-1 for the first, i.e. `nearest`).
    """
    n_steps = int(np.ceil(distance_fn(nearest, sample) / step_size))

    extension = []
    parent_offset = -1
    current = nearest
    for _ in range(n_steps):
        remaining = distance_fn(current, sample)
        if remaining > step_size:
            # Move one step_size along the straight line toward the sample.
            target = current + step_size / remaining * (sample - current)
        elif remaining < 1e-6:
            break
        else:
            target = sample

        if not check_edge_validity_fn(current, target):
            break

        extension.append(PropagatedState(state=target, relative_parent_index=parent_offset))
        parent_offset += 1
        current = target

    return extension

def states_connected_fn(source, target, is_start_tree):
    """Treat two tree states as connected when they are within 1e-6 (Euclidean).

    `is_start_tree` is unused; it is part of the callback signature.
    """
    gap = np.linalg.norm(source - target)
    return gap < 1e-6

In [None]:
def RRT_Connect(q_ini, q_final, solve_timeout_rrt = 500, goal_bias = 0.05, seed = 5):
    """Plan a single path from `q_ini` to `q_final` with a bidirectional RRT.

    Args:
        q_ini: start joint configuration.
        q_final: goal joint configuration.
        solve_timeout_rrt: timeout given to MakeBiRRTTimeoutTerminationFunction.
        goal_bias: probability that the sampler returns a fixed configuration
            instead of a uniform sample within the joint limits. The fixed
            configuration is the start (50%), a hand-picked intermediate pose
            (25%) or the goal (25%).
        seed: seed of the RandomGenerator passed to the planner. NOTE: the
            sampler itself draws from numpy's global random state.

    Returns:
        The result of BiRRTPlanSinglePath; callers use its .Path().
    """
    # Store the endpoints as float arrays: connect_fn computes `sample - current`
    # and states_connected_fn computes `source - target`, which fail for lists.
    q_ini = np.asarray(q_ini, dtype=float)
    q_final = np.asarray(q_final, dtype=float)
    # Hand-picked intermediate configuration, used as an extra biased sample.
    q_via = np.array([0.18731861, 0.37941396, -0.27091993, -0.84213184,
                      -0.21627076, 0.45351121, 1.55646851])

    start_tree = [SimpleRRTPlannerState(q_ini)]
    end_tree = [SimpleRRTPlannerState(q_final)]

    def birrt_sampling():
        if np.random.rand() < goal_bias:
            # Same sequence of random draws as before: start / via / goal.
            if np.random.rand() < 0.5:
                return q_ini
            if np.random.rand() < 0.5:
                return q_via
            return q_final
        return np.random.rand(len(PositionLowerLimits))*(PositionUpperLimits-PositionLowerLimits) + PositionLowerLimits

    nearest_neighbor_fn = MakeKinematicLinearBiRRTNearestNeighborsFunction(distance_fn=distance_fn, use_parallel = False)

    termination_fn = MakeBiRRTTimeoutTerminationFunction(solve_timeout_rrt)

    connect_result = BiRRTPlanSinglePath(
            start_tree=start_tree, goal_tree=end_tree,
            state_sampling_fn=birrt_sampling,
            nearest_neighbor_fn=nearest_neighbor_fn, propagation_fn=connect_fn,
            state_added_callback_fn=None,
            states_connected_fn=states_connected_fn,
            goal_bridge_callback_fn=None,
            tree_sampling_bias=0.5, p_switch_tree=0.25,
            termination_check_fn=termination_fn, rng=RandomGenerator(seed))
    return connect_result

In [None]:
def get_rrt_path(sequence, seed, verbose = False, smoothing = False):
    """Chain BiRRT plans through consecutive configurations of `sequence`.

    Args:
        sequence: ordered joint configurations. RRT_Connect is run between
            each consecutive pair.
        seed: seed of the RandomGenerator used by ShortcutSmoothPath. It is only
            used when `smoothing` is True.
        verbose: print failed segments and the total wall-clock time.
        smoothing: shortcut-smooth each planned sub-path.

    Returns:
        An (nq, N) array of waypoints, or None if any segment fails.
    """
    path = [sequence[0]]
    start_time = time.time()
    for start_pt, goal_pt in zip(sequence[:-1], sequence[1:]):
        rrt_path = RRT_Connect(start_pt, goal_pt).Path()
        # Detect failure before smoothing so an empty path is never smoothed.
        if len(rrt_path) == 0:
            if verbose:
                print(f"Failed between {start_pt} and {goal_pt}")
            return None
        if smoothing:
            rrt_path = ShortcutSmoothPath(rrt_path, 400, 400, 2, 0.5, 0.5, False,
                               check_edge_validity_fn,distance_fn,
                               InterpolateWaypoint, RandomGenerator(seed))

        # Skip the first waypoint: it repeats the previous segment's end.
        path.extend(rrt_path[1:])

    if verbose:
        print(f"Time: {round(time.time() - start_time, 3)}s")

    return np.stack(path).T