In [115]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [116]:
import numpy as np
import jax
import jax.numpy as jnp
from iLQR import iLQR, Path
import csv


def load_path(filepath: str, delimiter: str = ',', skip_header: bool = True):
    """
    Gets the centerline of the track from a CSV file. We currently only
    support 2D tracks: the first two columns of every data row are read as
    x and y; any further columns are ignored.

    Args:
        filepath (str): the path to the CSV file holding the centerline
            positions.
        delimiter (str): column separator. Defaults to ','.
        skip_header (bool): if True (the default), the first row is treated
            as a header and skipped.

    Returns:
        np.ndarray: float centerline of shape (2, N); row 0 holds x and row 1
            holds y. N is 0 when the file has no data rows.
    """
    x = []
    y = []
    # The csv module expects files opened with newline='' so that quoted
    # fields and platform line endings are handled correctly.
    with open(filepath, newline='') as f:
        reader = csv.reader(f, delimiter=delimiter)
        if skip_header:
            # The default avoids StopIteration on a completely empty file.
            next(reader, None)
        for row in reader:
            # Blank lines come back as [] (or all-whitespace cells) and would
            # otherwise crash on row[0]; skip them.
            if not any(cell.strip() for cell in row):
                continue
            x.append(float(row[0]))
            y.append(float(row[1]))

    return np.array([x, y], dtype=float)

# Track geometry: centerline read from CSV, wrapped as a closed-loop Path.
TRACK_CSV = 'outerloop_center_smooth.csv'
centerline = load_path(TRACK_CSV)
# NOTE(review): the two positional 0.6 arguments are interpreted by
# iLQR.Path (not visible here) — TODO confirm their meaning and name them.
path = Path(centerline, 0.6, 0.6, loop=True)

# Number of time steps (columns) in the state/control trajectories below.
n = 10

In [117]:
from iLQR.config import Config
from iLQR.dynamics import Bicycle5D

config = Config()
dyn = Bicycle5D(config)

state = jnp.array(np.random.rand(5,n))
control = jnp.array(np.random.rand(2,n))

dyn.get_jacobian(state, control)

%timeit dyn.get_jacobian(state, control)[0].block_until_ready()

201 µs ± 13.5 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [118]:
from iLQR.cost import Cost

init_state = np.array([1, 5.4, 2, 3.14, 0])
controls = jnp.zeros((2, n))

# init_state = jnp.asarray(init_state)
states, controls = dyn.rollout_nominal(init_state, controls)

print("dyn.rollout_nominal")
%timeit dyn.rollout_nominal(init_state, controls)[0].block_until_ready()

dyn.rollout_nominal
763 µs ± 30.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [119]:
refs = path.get_reference(states[:2, :])
s = refs[4,:]

print("path.get_reference use jax (Not good)")
%timeit path.get_reference(states[:2, :])

print("path.get_reference use np")
%timeit path.get_reference(np.asarray(states)[:2, :])

path.get_reference use jax (Not good)
6.1 ms ± 243 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
path.get_reference use np
2.98 ms ± 19.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [120]:
from iLQR.cost.collision_checker import CollisionChecker, Obstacle
collision_checker = CollisionChecker(config)

obs1 = np.array([[-1, -1, -0.5, -0.5], [5.2, 5.9, 5.9, 5.2]]).T
obs2 = np.array([[1, 1, 1.5, 1.5], [-0.2, 0.5, 0.5, -0.2]]).T
temp = [obs1 for _ in range(10)]

print("Create an obstacle")
%timeit Obstacle(temp)

Create an obstacle
1.83 ms ± 169 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [230]:


obstacle_list = [Obstacle(temp)]
# polygon = 
print("!!! BottleNeck")
print("Avoid conversion between np array and jax array!")
print("Avoid slice of jax array with [] in non-jited code")

print("collision_checker._check_collision") 
states_np = np.asarray(states)
%timeit collision_checker._check_collision(states_np[:,0], obstacle_list[0].at(0))
print("collision_checker.check_collisions")
# states = jnp.asarray(states)
%timeit collision_checker.check_collisions(states_np, obstacle_list)

obs_refs = collision_checker.check_collisions(states_np, obstacle_list)


!!! BottleNeck
Avoid conversion between np array and jax array!
Avoid slice of jax array with [] in non-jited code
collision_checker._check_collision
49.5 µs ± 4 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
collision_checker.check_collisions
500 µs ± 6.06 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [237]:
cost = Cost(config)
print("!!! BottleNeck with obs_refs")
%timeit cost.get_traj_cost(states, controls, refs, obs_refs).block_until_ready()

1.5 ms ± 271 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
