In [2]:
import jax
import jax.numpy as jnp
from trajax import optimizers



In [3]:
# Integration step size; `dynamics` multiplies the state derivative by it.
dt = 0.1
# Number of control steps; used as the length of the initial control sequence.
horizon = 10

@jax.jit
def system(x, u, t):
  """Continuous-time model of a classic (omnidirectional) wheeled robot.

  The first two control components are rotated by the angle ``x[2]``; the
  third control component passes through unchanged.

  Args:
    x: state, (3, ) array; only ``x[2]`` is read here.
    u: control, (3, ) array
    t: scalar time (not used by this model).
  Returns:
    xdot: state time derivative, (3, )
  """
  heading = x[2]
  cos_h, sin_h = jnp.cos(heading), jnp.sin(heading)

  # Planar rotation by the heading, padded with an identity row/column so the
  # third control component maps straight onto the third state derivative.
  rotation = jnp.array([
      [cos_h, -sin_h, 0],
      [sin_h, cos_h, 0],
      [0, 0, 1],
  ])
  return jnp.matmul(rotation, u)

def dynamics(x, u, t):
  """Discrete-time transition: one explicit Euler step of `system`.

  Args:
    x: current state.
    u: control applied during the step.
    t: time, forwarded unchanged to `system`.
  Returns:
    The next state, ``x + dt * system(x, u, t)``, using the module-level
    step size ``dt``.
  """
  xdot = system(x, u, t)
  increment = dt * xdot
  return x + increment

In [50]:
def cost(x, u, t, *args):
  """Stage cost built from the module-level weights `P` and `q`.

  With ``z = concatenate([x, u])`` the value is
  ``z.P.z + q.z + (0.4 - u[0])**2 + u.u``.

  Args:
    x: state, 1-D array.
    u: control, 1-D array; ``u[0]`` is penalised for deviating from 0.4.
    t: time step (not used; the same expression is returned for every t).
    *args: accepted and ignored.
  Returns:
    Scalar cost.
  """
  z = jnp.concatenate([x, u])
  u_x_err = 0.4 - u[0]

  quadratic = jnp.dot(z, jnp.dot(P, z))
  linear = jnp.dot(q, z)
  tracking = jnp.dot(u_x_err, u_x_err)
  effort = jnp.dot(u, u)
  return quadratic + linear + tracking + effort


# Weights that `cost` reads as module-level globals: P is applied on both
# sides of the stacked [x, u] vector, q is dotted with it.
P = jnp.ones((6, 6))
q = jnp.ones(6)

# Initial state and an all-zero initial guess for the control sequence
# (one 3-vector per step of the horizon).
x0 = jnp.array([0.0, 0.0, 0.0])
u0 = jnp.zeros((horizon, 3))

# Run iLQR. X, U, obj and iteration are kept; the three outputs in between
# are discarded.
X, U, obj, _, _, _, iteration = optimizers.ilqr(cost, dynamics, x0, u0, maxiter=1000)

print(type(U))

<class 'jaxlib.xla_extension.ArrayImpl'>


In [39]:
def cost(x, u, t, embedding):
  """Stage cost whose quadratic and linear weights are packed in `embedding`.

  With ``z = concatenate([x, u])`` of length ``n`` the value is
  ``z.P.z + q.z + (0.4 - u[0])**2 + u.u``.

  Args:
    x: state, 1-D array.
    u: control, 1-D array; ``u[0]`` is penalised for deviating from 0.4.
    t: time step (not used; the same expression is returned for every t).
    embedding: flat array of length ``n*n + n``. The first ``n*n`` entries
      are ``P`` in row-major order, the remaining ``n`` entries are ``q``.
      For 3-D state and 3-D control this is the 36 + 6 = 42 layout.
  Returns:
    Scalar cost.
  """
  z = jnp.concatenate([x, u])

  # Derive the split point from the inputs instead of hard-coding 36 (= 6*6),
  # so the function also works when len(x) + len(u) != 6. Array shapes are
  # static Python ints, so this is safe under jit / grad / hessian.
  n = z.shape[0]
  P = jnp.reshape(embedding[:n * n], (n, n))
  q = embedding[n * n:]

  u_x_err = 0.4 - u[0]

  quadratic = jnp.dot(z, jnp.dot(P, z))
  linear = jnp.dot(q, z)
  tracking = jnp.dot(u_x_err, u_x_err)
  effort = jnp.dot(u, u)
  return quadratic + linear + tracking + effort


In [47]:
# Second derivative of `cost` with respect to its fourth argument (the
# embedding). The embedding enters `cost` only linearly (its entries are the
# coefficients P and q), so this Hessian is identically zero for any x and u.
hessian = jax.hessian(cost, argnums=3)

# Evaluate at all-ones inputs: 3-D state, 3-D control, 42-element embedding.
hessian(jnp.ones(3), jnp.ones(3), 0, jnp.ones(42))

Array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]], dtype=float32)