In [2]:
from ilqr import iLQR
import gym
import numpy as np

from aprl.agents import MujocoFiniteDiffDynamics, MujocoFiniteDiffCost

In [3]:
def on_iteration(iteration_count, xs, us, J_opt, accepted, converged):
    info = "converged" if converged else ("accepted" if accepted else "failed")
    print("iteration", iteration_count, info, J_opt, xs[-1])

In [7]:
# Environment setup
env = gym.make('Reacher-v2').unwrapped
env.reset()
dynamics = MujocoFiniteDiffDynamics(env)
x0 = dynamics.get_state()
N = 50  # planning horizon
us_init = np.array([env.action_space.sample() for _ in range(N)])

# Finite difference cost

In [9]:
finite_cost = MujocoFiniteDiffCost(env)
finite_ilqr = iLQR(dynamics, finite_cost, N)

In [10]:
finite_xs, finite_us = finite_ilqr.fit(x0, us_init, n_iterations=100, on_iteration=on_iteration)

iteration 0 accepted 31.503252071646976 [ 1.          0.80515276 -2.89302955 -0.10886765  0.0151158  -0.6736436
 -4.67787549  0.          0.        ]
iteration 1 accepted 31.169911245071532 [ 1.          1.68549552 -2.95233647 -0.10886765  0.0151158  -0.57630312
 -4.213096    0.          0.        ]
iteration 2 accepted 27.68207734296457 [ 1.          2.83242984 -2.9748163  -0.10886765  0.0151158  -0.28743043
 -3.43852481  0.          0.        ]
iteration 3 failed 27.682077342964572 [ 1.          2.83242984 -2.9748163  -0.10886765  0.0151158  -0.28743043
 -3.43852481  0.          0.        ]
iteration 4 failed 27.682077342964572 [ 1.          2.83242984 -2.9748163  -0.10886765  0.0151158  -0.28743043
 -3.43852481  0.          0.        ]
iteration 5 accepted 25.381046962201076 [ 1.          4.83426061 -2.94636317 -0.10886765  0.0151158  -0.42449212
 -3.20644243  0.          0.        ]
iteration 6 accepted 25.308472522867778 [ 1.          4.74544457 -3.01319983 -0.10886765  0.0151158 

iteration 56 accepted 22.82935461042388 [ 1.          2.21533187 -3.05904446 -0.10886765  0.0151158  -7.45646789
  2.47987691  0.          0.        ]
iteration 57 failed 22.829354610423877 [ 1.          2.21533187 -3.05904446 -0.10886765  0.0151158  -7.45646789
  2.47987691  0.          0.        ]
iteration 58 failed 22.829354610423877 [ 1.          2.21533187 -3.05904446 -0.10886765  0.0151158  -7.45646789
  2.47987691  0.          0.        ]
iteration 59 accepted 22.828965922967488 [ 1.          2.20704675 -3.05900626 -0.10886765  0.0151158  -7.46599812
  2.47585278  0.          0.        ]
iteration 60 converged 22.82896521450159 [ 1.          2.20685722 -3.0590116  -0.10886765  0.0151158  -7.46590504
  2.47619624  0.          0.        ]


# Analytic cost

In [11]:
import theano
from theano import tensor as T
from ilqr.cost import AutoDiffCost

def make_reacher_cost(control_weight=1.0):
    x_inputs = [T.dscalar('_x1'), T.dscalar('theta'), T.dscalar('phi'), T.dscalar('targetx'), T.dscalar('targety'),
                T.dscalar('_x5'), T.dscalar('_x6'), T.dscalar('_x7'), T.dscalar('_x8')]
    u_inputs = [T.dscalar('thetadotdot'), T.dscalar('phidotdot')]
    x = T.stack(x_inputs)
    u = T.stack(u_inputs)
    
    control_cost = T.dot(u, u)
    target_xpos = x[3:5]
    body1_xpos = 0.1 * T.stack([T.cos(x[1]), T.sin(x[1])])
    fingertip_xpos_delta = 0.11 * T.stack([T.cos(x[2]), T.sin(x[2])])
    fingertip_xpos = body1_xpos + fingertip_xpos_delta
    delta = fingertip_xpos - target_xpos
    state_cost = T.sqrt(T.dot(delta, delta))
    l = state_cost + control_weight * control_cost
    l_terminal = T.zeros(())
    return AutoDiffCost(l, l_terminal, x_inputs, u_inputs)
analytic_cost = make_reacher_cost()
analytic_ilqr = iLQR(dynamics, analytic_cost, N)

In [12]:
analytic_xs, analytic_us = analytic_ilqr.fit(x0, us_init, n_iterations=100, on_iteration=on_iteration)

iteration 0 accepted 11.529334923966342 [ 1.          0.39344263 -2.74474423 -0.10886765  0.0151158  -0.32013018
 -4.67728636  0.          0.        ]
iteration 1 accepted 10.734666700881823 [ 1.          0.63199452 -2.79961197 -0.10886765  0.0151158  -0.35707333
 -4.59049351  0.          0.        ]
iteration 2 accepted 10.018588848608541 [ 1.          0.68160732 -2.91154566 -0.10886765  0.0151158  -0.3805866
 -4.35061088  0.          0.        ]
iteration 3 failed 10.018588848608541 [ 1.          0.68160732 -2.91154566 -0.10886765  0.0151158  -0.3805866
 -4.35061088  0.          0.        ]
iteration 4 accepted 8.842045816112698 [ 1.          3.75999298 -2.15240289 -0.10886765  0.0151158   0.19061677
 -3.28279928  0.          0.        ]
iteration 5 accepted 8.696877894270953 [ 1.          4.76429844 -3.00136766 -0.10886765  0.0151158   1.34812178
 -0.01047241  0.          0.        ]
iteration 6 accepted 8.650918706408529 [ 1.          3.42003518 -3.0013279  -0.10886765  0.0151158  

iteration 55 failed 7.258948308822207 [ 1.          2.05596659 -2.18576684 -0.10886765  0.0151158   2.11635868
  2.17683682  0.          0.        ]
iteration 56 accepted 7.258744944462422 [ 1.          2.04771439 -2.18923855 -0.10886765  0.0151158   2.10779276
  2.16230065  0.          0.        ]
iteration 57 failed 7.258744944462421 [ 1.          2.04771439 -2.18923855 -0.10886765  0.0151158   2.10779276
  2.16230065  0.          0.        ]
iteration 58 accepted 7.258735752595286 [ 1.          2.04611872 -2.18997128 -0.10886765  0.0151158   2.10614738
  2.15922319  0.          0.        ]
iteration 59 failed 7.258735752595286 [ 1.          2.04611872 -2.18997128 -0.10886765  0.0151158   2.10614738
  2.15922319  0.          0.        ]
iteration 60 failed 7.258735752595286 [ 1.          2.04611872 -2.18997128 -0.10886765  0.0151158   2.10614738
  2.15922319  0.          0.        ]
iteration 61 failed 7.258735752595286 [ 1.          2.04611872 -2.18997128 -0.10886765  0.0151158   2.

# Receding horizon

In [33]:
from ilqr.controller import RecedingHorizonController

controller = RecedingHorizonController(x0, analytic_ilqr)  # can also use finite_ilqr
rew = []
receding_xs = []
receding_us = []
for x, u in controller.control(us_init, subsequent_n_iterations=10):
    ob, r, done, info = env.step(u)
    receding_xs.append(x)
    receding_us.append(u)
    rew.append(r)
    print('iteration', len(rew), r, x, u)
    if len(rew) == 50:
        break

iteration 1 -2.2629039864544644 [[ 0.00000000e+00 -5.72154454e-02  1.66843426e-02 -1.08867654e-01
   1.51157970e-02  3.07588298e-03  3.24761839e-03  0.00000000e+00
   0.00000000e+00]
 [ 2.00000000e-02 -9.68604896e-02  5.64919316e-02 -1.08867654e-01
   1.51157970e-02 -3.95439911e+00  3.96430418e+00  0.00000000e+00
   0.00000000e+00]] [[-1.0277718  1.0263855]]
iteration 2 -1.7445505974161928 [[ 0.02       -0.09686049  0.05649193 -0.10886765  0.0151158  -3.95439911
   3.96430418  0.          0.        ]
 [ 0.04       -0.20578382  0.17420181 -0.10886765  0.0151158  -6.92803386
   7.79390958  0.          0.        ]] [[-0.77109252  0.98670814]]
iteration 3 -1.1243071611342692 [[ 0.04       -0.20578382  0.17420181 -0.10886765  0.0151158  -6.92803386
   7.79390958  0.          0.        ]
 [ 0.06       -0.36570322  0.35982352 -0.10886765  0.0151158  -9.05683125
  10.7583604   0.          0.        ]] [[-0.57246426  0.78746068]]
iteration 4 -0.7941646107293611 [[  0.06        -0.36570322   0.3

iteration 26 -0.6584565770488134 [[ 5.00000000e-01 -2.04337058e+00 -6.07803701e-02 -1.08867654e-01
   1.51157970e-02  8.79703382e+00 -2.37486714e+01  0.00000000e+00
   0.00000000e+00]
 [ 5.20000000e-01 -1.84642687e+00 -5.48100550e-01 -1.08867654e-01
   1.51157970e-02  1.08903573e+01 -2.49791902e+01  0.00000000e+00
   0.00000000e+00]] [[ 0.57305754 -0.42934729]]
iteration 27 -0.517001755650502 [[ 5.20000000e-01 -1.84642687e+00 -5.48100550e-01 -1.08867654e-01
   1.51157970e-02  1.08903573e+01 -2.49791902e+01  0.00000000e+00
   0.00000000e+00]
 [ 5.40000000e-01 -1.60995674e+00 -1.05571333e+00 -1.08867654e-01
   1.51157970e-02  1.27505250e+01 -2.57793359e+01  0.00000000e+00
   0.00000000e+00]] [[ 0.52458413 -0.32693386]]
iteration 28 -0.37443237532147305 [[ 5.40000000e-01 -1.60995674e+00 -1.05571333e+00 -1.08867654e-01
   1.51157970e-02  1.27505250e+01 -2.57793359e+01  0.00000000e+00
   0.00000000e+00]
 [ 5.60000000e-01 -1.33990924e+00 -1.57493385e+00 -1.08867654e-01
   1.51157970e-02  1.4

# Rollouts

In [14]:
import time

def rollout(env, dynamics, x0, us, render=False):
    dynamics.set_state(x0)
    if render:
        env.render()
    rew = []
    actual_xs = []
    for u in us:
        _obs, r, done, info = env.step(u)
        rew.append(r)
        actual_xs.append(dynamics.get_state())
        assert not done
        if render:
            env.render()
            time.sleep(0.05)
    return rew, actual_xs

In [37]:
rew, actual_xs = rollout(env.unwrapped, dynamics, x0, receding_us, render=True)