In [1]:
%load_ext autoreload
%autoreload 2

In [40]:
import numpy as np
import jax
import jax.numpy as jnp
jax.config.update('jax_platform_name', 'cpu')
from iLQR import iLQR, Path
import csv
import random


def load_path(filepath: str) -> np.ndarray:
    """
    Gets the centerline of the track from a CSV file of trajectory data. We
    currently only support 2D tracks.

    The first row of the file is treated as a header and skipped. Only the
    first two columns of each data row are read (x from column 0, y from
    column 1). Blank rows, such as a trailing newline at the end of the file,
    are ignored.

    Args:
        filepath (str): the path to file consisting of the centerline position.

    Returns:
        np.ndarray: centerline, of the shape (2, N). N is 0 when the file has
        no data rows.

    Raises:
        ValueError: if a data row's first or second entry is not a number.
        IndexError: if a non-blank data row has fewer than two columns.
    """
    x = []
    y = []
    # newline='' is what the csv module expects for files it reads from.
    with open(filepath, newline='') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip the header row; tolerates an empty file
        for row in reader:
            if not row:
                # csv.reader yields [] for blank lines; skip them instead of
                # failing with IndexError on row[0].
                continue
            x.append(float(row[0]))
            y.append(float(row[1]))

    return np.array([x, y])

# Track geometry: read the centerline from CSV and wrap it as a looping Path.
# NOTE(review): the two positional 0.6 arguments to Path are unnamed here;
# confirm their meaning (e.g. track widths) against iLQR.Path.
CENTERLINE_CSV = 'outerloop_center_smooth.csv'
centerline = load_path(CENTERLINE_CSV)
path = Path(centerline, 0.6, 0.6, loop=True)

# Trajectory length: number of columns of the (5, n) state and (2, n) control
# arrays built in the cells below.
n = 10

In [41]:
from iLQR.config import Config
from iLQR.dynamics import Bicycle5D

config = Config()
dyn = Bicycle5D(config)

state = jnp.asarray(np.random.rand(5,n))
control = jnp.asarray(np.random.rand(2,n))

dyn.get_jacobian(state*random.random(), control*random.random())

%timeit dyn.get_jacobian(state*random.random(), control*random.random())[0].block_until_ready()

84.5 µs ± 1.71 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [77]:
from iLQR.cost import Cost

init_state = np.array([1, 5.4, 2, 3.14, 0])
controls = jnp.zeros((2, n))

# init_state = jnp.asarray(init_state)
states, controls = dyn.rollout_nominal(init_state, controls)

print("dyn.rollout_nominal")
%timeit dyn.rollout_nominal(init_state*random.random(), controls*random.random())[0].block_until_ready()

dyn.rollout_nominal
53.3 µs ± 252 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [43]:
refs = path.get_reference(states[:2, :])
s = refs[4,:]

print("path.get_reference use jax (Not good)")
%timeit path.get_reference(states[:2, :])

print("path.get_reference use np")
states_np = np.asarray(states)
%timeit path.get_reference(states_np[:2, :]*random.random(), 1e-3)

path.get_reference use jax (Not good)
6.35 ms ± 117 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
path.get_reference use np
3.05 ms ± 79 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [49]:
from iLQR.cost.collision_checker import CollisionChecker, Obstacle
collision_checker = CollisionChecker(config)

obs1 = np.array([[-1, -1, -0.5, -0.5], [5.2, 5.9, 5.9, 5.2]]).T
obs2 = np.array([[1, 1, 1.5, 1.5], [-0.2, 0.5, 0.5, -0.2]]).T
temp = [obs1 for _ in range(10)]

print("Create an obstacle")
%timeit Obstacle(temp)

Create an obstacle
2.22 ms ± 122 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [86]:


obstacle_list = [Obstacle(temp)]
# polygon = 
print("!!! BottleNeck")
print("Avoid conversion between np array and jax array!")
print("Avoid slice of jax array with [] in non-jited code")

print("collision_checker._check_collision") 
states_np = np.asarray(states)
%timeit collision_checker._check_collision(states_np[:,0]*random.random(), obstacle_list[0].at(0))
print("collision_checker.check_collisions")
# states = jnp.asarray(states)
%timeit collision_checker.check_collisions(states_np*random.random(), obstacle_list)

obs_refs = collision_checker.check_collisions(states_np, obstacle_list)


!!! BottleNeck
Avoid conversion between np array and jax array!
Avoid slice of jax array with [] in non-jited code
collision_checker._check_collision
82.3 µs ± 494 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
collision_checker.check_collisions
778 µs ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [87]:
# Build the cost model from the same config passed to Bicycle5D above.
cost = Cost(config)

# Untimed call with obs_refs as returned by check_collisions (result discarded).
cost.get_traj_cost(states, controls, refs, obs_refs)
print("!!! BottleNeck")
print("make sure obs_refs is np array")
# Second untimed call with obs_refs converted via jnp.asarray (result discarded).
# NOTE(review): the message above says obs_refs should be a NumPy array, yet this
# call passes a JAX array. Confirm whether this deliberately contrasts the two
# input types or is leftover experimentation.
cost.get_traj_cost(states, controls, refs, jnp.asarray(obs_refs))
print(" cost.get_traj_cost")
# block_until_ready() waits for JAX's asynchronous dispatch to complete so the
# timing covers the actual computation; inputs are rescaled on every loop.
%timeit cost.get_traj_cost(states*random.random(), controls*random.random(), refs*random.random(), obs_refs*random.random()).block_until_ready()



!!! BottleNeck
make sure obs_refs is np array
 cost.get_traj_cost
151 µs ± 2.08 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [90]:
# Time Cost.get_derivatives on freshly rescaled inputs each loop. Calling
# block_until_ready() on the last returned element waits for JAX's asynchronous
# dispatch so the measurement includes the computation itself.
print("cost.get_derivatives")
%timeit cost.get_derivatives(states*random.random(), controls*random.random(), refs*random.random(), obs_refs*random.random())[-1].block_until_ready()


cost.get_derivatives
463 µs ± 6.58 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
