In [1]:
from ilqr import iLQR
import gym
import numpy as np

from aprl.agents import MujocoFiniteDiffDynamics, MujocoFiniteDiffCost
from aprl.agents.mujoco_control import MujocoFiniteDiffDynamicsLowLevel



Logging to /tmp/openai-2019-01-25-19-17-01-302287
Choosing the latest nvidia driver: /usr/lib/nvidia-396, among ['/usr/lib/nvidia-375', '/usr/lib/nvidia-396']
Choosing the latest nvidia driver: /usr/lib/nvidia-396, among ['/usr/lib/nvidia-375', '/usr/lib/nvidia-396']


In [2]:
# Environment setup
env = gym.make('Reacher-v2').unwrapped
env.seed(42)
_obs = env.reset()

In [3]:
# Planning setup
N = 100  # planning horizon
us_init = np.array([env.action_space.sample() for _ in range(N)])

dynamics = {
    # Uses mujoco_py's MjSimState. This saves time, qpos, qvel, act and udd_state.
    'mujoco_py': MujocoFiniteDiffDynamics(env),
    # Uses my MujocoRelevantSimState, which contains all fields that MuJoCo's derivative.cpp copies.
    'my_all': MujocoFiniteDiffDynamicsLowLevel(env, kind='all'),
    # All fields I think matter; excludes qfrc and xfrc_applied
    'my_recommended': MujocoFiniteDiffDynamicsLowLevel(env, kind='recommended'),
    # qpos, qvel and qacc; no warmstart.
    'my_basic_plus': MujocoFiniteDiffDynamicsLowLevel(env, kind='basic_plus'),
    # As above, but restricts to fields qpos & qvel. 
    # I expect this to match mujoco_py, since time does not matter, and act and udd_state are blank for Reacher.
    'my_basic': MujocoFiniteDiffDynamicsLowLevel(env, kind='basic'),
}
x0s = {k: dyn.get_state() for k, dyn in dynamics.items()}

# Finite difference cost

In [None]:
finite_cost = MujocoFiniteDiffCost(env)
finite_ilqr = iLQR(dynamics, finite_cost, N)

In [None]:
finite_xs, finite_us = finite_ilqr.fit(x0, us_init, n_iterations=100, on_iteration=on_iteration)

# Analytic cost

In [4]:
import theano
from theano import tensor as T
from ilqr.cost import AutoDiffCost

# Reacher, Gym observation:
# obs[0:1]: xs; np.cos(qpos[0:2]) (qpos[0] is joint0, qpos[1] is joint1)
# obs[2:3]: ys; np.sin(qpos[0:2]);
# obs[4:5]: goal x and y; qpos[2:]; (target_x and target_y)
# obs[6:7]: theta dot
# obs[8:9]: xy of fingertip - target

def make_reacher_cost(kind, control_weight=1.0):
    # qpos[0:3]: theta of joint 0, theta of joint 1; target x and y.
    qpos_inputs = [T.dscalar('theta'), T.dscalar('phi'), T.dscalar('targetx'), T.dscalar('targety')]
    # qvel: derivatives of the above; note target x and y are constant so have derivative zero.
    qvel_inputs = [T.dscalar('thetadot'), T.dscalar('phidot'), T.dscalar('_zero1'), T.dscalar('_zero2')]
    # qacc: second derivatives of qpos. We don't actually use these in the cost.
    qacc_inputs = [T.dscalar('_acc{}'.format(i)) for i in range(len(qpos_inputs))]
    # qacc_warmstart: same shape as qacc
    qacc_warmstart_inputs = [T.dscalar('_accwarm{}'.format(i)) for i in range(len(qpos_inputs))]
    # qfrc_applied: same shape as qacc
    qfrc_applied_inputs = [T.dscalar('_qfrc_applied{}'.format(i)) for i in range(len(qpos_inputs))]
    # xfrc_applied: (5,6)
    xfrc_applied_inputs = [T.dscalar('_xfrc_applied{}'.format(i)) for i in range(5 * 6)]
    if kind == 'mujoco_py':
        # Reacher, MJSimState.flatten():
        # obs[0]: time step, obs[1:4]: qpos[0:3]; obs[5:8]: qvel[0:3]
        # In general might include action and udd_state, but not for Reacher.
        x_inputs = [T.dscalar('_time')] + qpos_inputs + qvel_inputs
    elif kind == 'my_all':
        # Reacher, MujocoRelevantState.flatten()
        x_inputs = qpos_inputs + qvel_inputs + qacc_inputs + qacc_warmstart_inputs + qfrc_applied_inputs + xfrc_applied_inputs
    elif kind == 'my_recommended':
        x_inputs = qpos_inputs + qvel_inputs + qacc_inputs + qacc_warmstart_inputs
    elif kind == 'my_basic_plus':
        x_inputs = qpos_inputs + qvel_inputs + qacc_inputs
    elif kind == 'my_basic':
        x_inputs = qpos_inputs + qvel_inputs
    else:
        raise ValueError("Unrecognised kind: '{}'".format(kind))
    u_inputs = [T.dscalar('thetadotdot'), T.dscalar('phidotdot')]
    qpos = T.stack(qpos_inputs)
    u = T.stack(u_inputs)
    
    control_cost = T.dot(u, u)
    target_xpos = qpos[2:4]
    body1_xpos = 0.1 * T.stack([T.cos(qpos[0]), T.sin(qpos[1])])
    fingertip_xpos_delta = 0.11 * T.stack([T.cos(qpos[1]), T.sin(qpos[1])])
    fingertip_xpos = body1_xpos + fingertip_xpos_delta
    delta = fingertip_xpos - target_xpos
    state_cost = T.sqrt(T.dot(delta, delta))
    l = state_cost + control_weight * control_cost
    l_terminal = T.zeros(())
    return AutoDiffCost(l, l_terminal, x_inputs, u_inputs)

In [5]:
def on_iteration(iteration_count, xs, us, J_opt, accepted, converged):
    info = "converged" if converged else ("accepted" if accepted else "failed")
    print("iteration", iteration_count, info, J_opt, xs[-1])

costs = {k: make_reacher_cost(k) for k in dynamics.keys()}
ilqrs = {k: iLQR(dyn, costs[k], N) for k, dyn in dynamics.items()}
xs = {}
us = {}
print(ilqrs.keys())
for k, ilqr in ilqrs.items():
    print('*** Fitting {} ***'.format(k))
    x0 = x0s[k]
    xs[k], us[k] = ilqr.fit(x0, us_init, n_iterations=100, on_iteration=on_iteration)

dict_keys(['mujoco_py', 'my_all', 'my_recommended', 'my_basic_plus', 'my_basic'])
*** Fitting mujoco_py ***
iteration -1 converged 79.96275887661503 [ 2.          7.77702008  1.21356644  0.02243766  0.07369059  3.36928471
 -8.26249354  0.          0.        ]
iteration 0 accepted 29.265523052069145 [ 2.          7.70043922  1.385693    0.02243766  0.07369059  3.60590171
 -8.192679    0.          0.        ]
iteration 1 accepted 23.25980857049346 [ 2.          7.54924921  1.55563349  0.02243766  0.07369059  3.85633321
 -8.08002884  0.          0.        ]
iteration 2 accepted 20.571476283879594 [ 2.          7.48510752  1.91564321  0.02243766  0.07369059  4.1656226
 -7.47378228  0.          0.        ]
iteration 3 accepted 19.314515201704687 [ 2.          7.05905434  2.27970603  0.02243766  0.07369059  3.3447239
 -5.31276474  0.          0.        ]
iteration 4 accepted 19.270514329106753 [ 2.          6.82783495  2.63384984  0.02243766  0.07369059  1.71902347
 -2.00996315  0.          

iteration 1 accepted 79.26610987757593 [ 7.74641531e+00  1.21524421e+00  2.24376640e-02  7.36905865e-02
  3.34649536e+00 -8.25581721e+00  0.00000000e+00  0.00000000e+00
 -1.79986737e+02 -1.79249592e+01  0.00000000e+00  0.00000000e+00
 -1.79986737e+02 -1.79249592e+01  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00]
iteration 2 accepted 79.2601534342358 [ 7.72579799e+00  1.21428814e+00  2.24376640e-02  7.36905865e-02
  3.34246028e+00 -8.25919360e+00  0.00000000

iteration 13 accepted 26.897422833397016 [ 18.20800332   2.27735588   0.02243766   0.07369059  14.43727856
  -1.07484165   0.           0.         -14.43653293   1.04040941
   0.           0.         -14.43653293   1.04040941   0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.        ]
iteration 14 accepted 26.791625484347236 [16.78959612  1.53880965  0.02243766  0.07369059  8.76467883 -2.98204726
  0.          0.         -8.76859402  2.96630747  0.          0.
 -8.76859402  2.96630747  0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.     

iteration 26 accepted 24.698671507196277 [ 18.4138264    2.21424491   0.02243766   0.07369059  14.07665191
  -1.12148933   0.           0.         -14.01375834   1.25737646
   0.           0.         -14.01375834   1.25737646   0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.        ]
iteration 27 failed 24.698671507196284 [ 18.4138264    2.21424491   0.02243766   0.07369059  14.07665191
  -1.12148933   0.           0.         -14.01375834   1.25737646
   0.           0.         -14.01375834   1.25737646   0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.    

iteration 39 accepted 21.950124464440965 [ 13.29876674   2.74478408   0.02243766   0.07369059   3.41982764
   0.06593536   0.           0.         -12.57037146   1.13310907
   0.           0.         -12.57037146   1.13310907   0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.
   0.           0.           0.           0.           0.        ]
iteration 40 accepted 21.804287067927053 [12.68090097  2.77663537  0.02243766  0.07369059  0.4689909   0.19680666
  0.          0.         -7.18477383  0.54176146  0.          0.
 -7.18477383  0.54176146  0.          0.          0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.          0.          0.     

iteration 51 accepted 21.512946684027757 [ 1.25085449e+01  2.75565743e+00  2.24376640e-02  7.36905865e-02
  1.15091647e-01  1.75261360e-02  0.00000000e+00  0.00000000e+00
 -4.83819055e-01 -1.05453683e-02  0.00000000e+00  0.00000000e+00
 -4.83819055e-01 -1.05453683e-02  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00]
iteration 52 failed 21.51294668402777 [ 1.25085449e+01  2.75565743e+00  2.24376640e-02  7.36905865e-02
  1.15091647e-01  1.75261360e-02  0.000000

iteration -1 converged 79.96275887661503 [ 7.77702008e+00  1.21356644e+00  2.24376640e-02  7.36905865e-02
  3.36928471e+00 -8.26249354e+00  0.00000000e+00  0.00000000e+00
 -1.80000140e+02 -1.79251765e+01  0.00000000e+00  0.00000000e+00
 -1.80000140e+02 -1.79251765e+01  0.00000000e+00  0.00000000e+00]
iteration 0 accepted 79.3541596217529 [ 7.77051152e+00  1.21481979e+00  2.24376640e-02  7.36905865e-02
  3.36582069e+00 -8.25763157e+00  0.00000000e+00  0.00000000e+00
 -1.79995963e+02 -1.79262821e+01  0.00000000e+00  0.00000000e+00
 -1.79995963e+02 -1.79262821e+01  0.00000000e+00  0.00000000e+00]
iteration 1 accepted 79.26610987757593 [ 7.74641531e+00  1.21524421e+00  2.24376640e-02  7.36905865e-02
  3.34649536e+00 -8.25581721e+00  0.00000000e+00  0.00000000e+00
 -1.79986737e+02 -1.79249592e+01  0.00000000e+00  0.00000000e+00
 -1.79986737e+02 -1.79249592e+01  0.00000000e+00  0.00000000e+00]
iteration 2 accepted 79.2601534342358 [ 7.72579799e+00  1.21428814e+00  2.24376640e-02  7.36905865e

iteration 33 accepted 24.546389897578482 [ 17.63971951   2.20571726   0.02243766   0.07369059  13.41329664
  -1.11854826   0.           0.         -13.84616121   1.30026141
   0.           0.         -13.84616121   1.30026141   0.
   0.        ]
iteration 34 accepted 24.36416474473091 [ 16.19893132   1.87984618   0.02243766   0.07369059  11.32768013
  -1.69257883   0.           0.         -13.79354211   1.31322168
   0.           0.         -13.79354211   1.31322168   0.
   0.        ]
iteration 35 failed 24.364164744730903 [ 16.19893132   1.87984618   0.02243766   0.07369059  11.32768013
  -1.69257883   0.           0.         -13.79354211   1.31322168
   0.           0.         -13.79354211   1.31322168   0.
   0.        ]
iteration 36 failed 24.364164744730903 [ 16.19893132   1.87984618   0.02243766   0.07369059  11.32768013
  -1.69257883   0.           0.         -13.79354211   1.31322168
   0.           0.         -13.79354211   1.31322168   0.
   0.        ]
iteration 37 accepted

iteration 2 failed 79.45141719726762 [ 7.76475012e+00  1.21132293e+00  2.24376640e-02  7.36905865e-02
  3.36602213e+00 -8.27035752e+00  0.00000000e+00  0.00000000e+00
 -1.79992194e+02 -1.79153793e+01  0.00000000e+00  0.00000000e+00]
iteration 3 accepted 79.29343498839117 [ 7.59289365e+00  1.21178163e+00  2.24376640e-02  7.36905865e-02
  3.22742790e+00 -8.26805679e+00  0.00000000e+00  0.00000000e+00
 -1.79953232e+02 -1.79100671e+01  0.00000000e+00  0.00000000e+00]
iteration 4 accepted 79.19898728224125 [ 7.74602116e+00  1.21047684e+00  2.24376640e-02  7.36905865e-02
  3.35543650e+00 -8.27248328e+00  0.00000000e+00  0.00000000e+00
 -1.79945533e+02 -1.79012823e+01  0.00000000e+00  0.00000000e+00]
iteration 5 accepted 78.97860240082639 [ 8.15353185e+00  1.20584760e+00  2.24376640e-02  7.36905865e-02
  3.71293283e+00 -8.28787940e+00  0.00000000e+00  0.00000000e+00
 -1.79665105e+02 -1.78581489e+01  0.00000000e+00  0.00000000e+00]
iteration 6 accepted 77.25365544202994 [ 7.89436327e+00  1.343

iteration 44 failed 16.146162144797845 [ 0.49071538  2.72267212  0.02243766  0.07369059  2.23837856  0.05972772
  0.          0.         -4.33310257  0.81044702  0.          0.        ]
iteration 45 accepted 16.129999274379568 [ 0.19892593  2.74675209  0.02243766  0.07369059  0.97405025  0.08004285
  0.          0.         -1.76941833  0.1322781   0.          0.        ]
iteration 46 failed 16.129999274379568 [ 0.19892593  2.74675209  0.02243766  0.07369059  0.97405025  0.08004285
  0.          0.         -1.76941833  0.1322781   0.          0.        ]
iteration 47 accepted 16.098823059404282 [ 0.1959196   2.71240428  0.02243766  0.07369059  0.97612047  0.02217629
  0.          0.         -1.765779    0.18848328  0.          0.        ]
iteration 48 accepted 16.09727650018056 [ 0.18350292  2.74104139  0.02243766  0.07369059  0.91693798 -0.00877375
  0.          0.         -1.09279693  0.05001971  0.          0.        ]
iteration 49 failed 16.097276500180566 [ 0.18350292  2.74104139  

iteration 30 accepted 15.205856888689862 [ 3.60215695  2.72770904  0.02243766  0.07369059 -5.06951977  1.12678279
  0.          0.        ]
iteration 31 failed 15.205856888689864 [ 3.60215695  2.72770904  0.02243766  0.07369059 -5.06951977  1.12678279
  0.          0.        ]
iteration 32 failed 15.205856888689864 [ 3.60215695  2.72770904  0.02243766  0.07369059 -5.06951977  1.12678279
  0.          0.        ]
iteration 33 failed 15.205856888689864 [ 3.60215695  2.72770904  0.02243766  0.07369059 -5.06951977  1.12678279
  0.          0.        ]
iteration 34 failed 15.205856888689864 [ 3.60215695  2.72770904  0.02243766  0.07369059 -5.06951977  1.12678279
  0.          0.        ]
iteration 35 failed 15.205856888689864 [ 3.60215695  2.72770904  0.02243766  0.07369059 -5.06951977  1.12678279
  0.          0.        ]
iteration 36 failed 15.205856888689864 [ 3.60215695  2.72770904  0.02243766  0.07369059 -5.06951977  1.12678279
  0.          0.        ]
iteration 37 accepted 14.2351514

# Receding horizon

In [None]:
from ilqr.controller import RecedingHorizonController

controller = RecedingHorizonController(x0, analytic_ilqr)  # can also use finite_ilqr
rew = []
receding_xs = []
receding_us = []
for x, u in controller.control(us_init, subsequent_n_iterations=10):
    ob, r, done, info = env.step(u)
    receding_xs.append(x)
    receding_us.append(u)
    rew.append(r)
    print('iteration', len(rew), r, x, u)
    if len(rew) == 50:
        break

# Rollouts

In [6]:
import time

def rollout(env, dynamics, x0, us, render=False):
    dynamics.set_state(x0)
    if render:
        env.render()
    rew = []
    actual_xs = []
    for u in us:
        _obs, r, done, info = env.step(u)
        rew.append(r)
        actual_xs.append(dynamics.get_state())
        assert not done
        if render:
            env.render()
            time.sleep(0.05)
    return rew, actual_xs

In [7]:
rews = {}
actual_xs = {}
for k, solved_us in us.items():
    rews[k], actual_xs[k] = rollout(env.unwrapped, dynamics[k], x0s[k], solved_us, render=False)
summary = {k: sum(r) for k, r in rews.items()}
print(summary)

{'mujoco_py': -11.768968149064781, 'my_all': -19.136440316727295, 'my_recommended': -19.136440316727295, 'my_basic_plus': -15.940242778681176, 'my_basic': -11.768968149064781}
