In [3]:
from ilqr import iLQR
import gym
import numpy as np
import pandas as pd
import time

from aprl.agents import MujocoFiniteDiffDynamics, MujocoFiniteDiffCost
from aprl.agents.mujoco_control import MujocoFiniteDiffDynamicsLowLevel, MujocoFiniteDiffDynamicsPerformance

In [2]:
# Environment setup
env = gym.make('Reacher-v2').unwrapped
env.frame_skip = 1
env.seed(42)
_obs = env.reset()

In [4]:
# Planning setup
N = 100  # planning horizon
us_init = np.array([env.action_space.sample() for _ in range(N)])

dynamics = {
    # Uses mujoco_py's MjSimState. This saves time, qpos, qvel, act and udd_state.
    #'mujoco_py': MujocoFiniteDiffDynamics(env),
    # Uses my MujocoRelevantSimState, which contains all fields that MuJoCo's derivative.cpp copies.
    #'my_all': MujocoFiniteDiffDynamicsLowLevel(env, kind='all'),
    # All fields I think matter; excludes qfrc and xfrc_applied
    #'my_recommended': MujocoFiniteDiffDynamicsLowLevel(env, kind='recommended'),
    # qpos, qvel and qacc; no warmstart.
    #'my_basic_plus': MujocoFiniteDiffDynamicsLowLevel(env, kind='basic_plus'),
    # As above, but restricts to fields qpos & qvel. 
    # I expect this to match mujoco_py, since time does not matter, and act and udd_state are blank for Reacher.
    'my_basic': MujocoFiniteDiffDynamicsLowLevel(env, kind='basic'),
    # Like my_basic, but saving qacc_warmstart.
    'my_performance': MujocoFiniteDiffDynamicsPerformance(env),
}
x0s = {k: dyn.get_state() for k, dyn in dynamics.items()}

# Finite difference cost

In [None]:
finite_cost = MujocoFiniteDiffCost(env)
finite_ilqr = iLQR(dynamics, finite_cost, N)

In [None]:
finite_xs, finite_us = finite_ilqr.fit(x0, us_init, n_iterations=100, on_iteration=on_iteration)

# Analytic cost

In [5]:
import theano
from theano import tensor as T
from ilqr.cost import AutoDiffCost

# Reacher, Gym observation:
# obs[0:1]: xs; np.cos(qpos[0:2]) (qpos[0] is joint0, qpos[1] is joint1)
# obs[2:3]: ys; np.sin(qpos[0:2]);
# obs[4:5]: goal x and y; qpos[2:]; (target_x and target_y)
# obs[6:7]: theta dot
# obs[8:9]: xy of fingertip - target

def make_reacher_cost(kind, control_weight=1.0):
    # qpos[0:3]: theta of joint 0, theta of joint 1; target x and y.
    qpos_inputs = [T.dscalar('theta'), T.dscalar('phi'), T.dscalar('targetx'), T.dscalar('targety')]
    # qvel: derivatives of the above; note target x and y are constant so have derivative zero.
    qvel_inputs = [T.dscalar('thetadot'), T.dscalar('phidot'), T.dscalar('_zero1'), T.dscalar('_zero2')]
    # qacc: second derivatives of qpos. We don't actually use these in the cost.
    qacc_inputs = [T.dscalar('_acc{}'.format(i)) for i in range(len(qpos_inputs))]
    # qacc_warmstart: same shape as qacc
    qacc_warmstart_inputs = [T.dscalar('_accwarm{}'.format(i)) for i in range(len(qpos_inputs))]
    # qfrc_applied: same shape as qacc
    qfrc_applied_inputs = [T.dscalar('_qfrc_applied{}'.format(i)) for i in range(len(qpos_inputs))]
    # xfrc_applied: (5,6)
    xfrc_applied_inputs = [T.dscalar('_xfrc_applied{}'.format(i)) for i in range(5 * 6)]
    if kind == 'mujoco_py':
        # Reacher, MJSimState.flatten():
        # obs[0]: time step, obs[1:4]: qpos[0:3]; obs[5:8]: qvel[0:3]
        # In general might include action and udd_state, but not for Reacher.
        x_inputs = [T.dscalar('_time')] + qpos_inputs + qvel_inputs
    elif kind == 'my_all':
        # Reacher, MujocoRelevantState.flatten()
        x_inputs = qpos_inputs + qvel_inputs + qacc_inputs + qacc_warmstart_inputs + qfrc_applied_inputs + xfrc_applied_inputs
    elif kind == 'my_recommended':
        x_inputs = qpos_inputs + qvel_inputs + qacc_inputs + qacc_warmstart_inputs
    elif kind == 'my_basic_plus':
        x_inputs = qpos_inputs + qvel_inputs + qacc_inputs
    elif kind in ['my_basic', 'my_warmstart', 'my_performance']:
        x_inputs = qpos_inputs + qvel_inputs
    else:
        raise ValueError("Unrecognised kind: '{}'".format(kind))
    u_inputs = [T.dscalar('thetadotdot'), T.dscalar('phidotdot')]
    qpos = T.stack(qpos_inputs)
    u = T.stack(u_inputs)
    
    control_cost = T.dot(u, u)
    target_xpos = qpos[2:4]
    body1_xpos = 0.1 * T.stack([T.cos(qpos[0]), T.sin(qpos[0])])
    fingertip_xpos_delta = 0.11 * T.stack([T.cos(qpos[1]), T.sin(qpos[1])])
    fingertip_xpos = body1_xpos + fingertip_xpos_delta
    delta = fingertip_xpos - target_xpos
    state_cost = T.sqrt(T.dot(delta, delta))
    l = state_cost + control_weight * control_cost
    l_terminal = T.zeros(())
    return AutoDiffCost(l, l_terminal, x_inputs, u_inputs)

In [16]:
def on_iteration(iteration_count, xs, us, J_opt, accepted, converged):
    info = "converged" if converged else ("accepted" if accepted else "failed")
    print("iteration", iteration_count, info, J_opt, xs[-1])

costs = {k: make_reacher_cost(k) for k in dynamics.keys()}
ilqrs = {k: iLQR(dyn, costs[k], N) for k, dyn in dynamics.items()}
xs = {}
us = {}
print(ilqrs.keys())
for k, ilqr in ilqrs.items():
    start = time.time()
    print('*** Fitting {} ***'.format(k))
    x0 = x0s[k]
    xs[k], us[k] = ilqr.fit(x0, us_init, n_iterations=100, on_iteration=on_iteration)
    end = time.time()
    print('*** Fitted {} in {}s ***'.format(k, end - start))

dict_keys(['my_basic', 'my_performance'])
*** Fitting my_basic ***
iteration -1 converged 77.0672963001787 [ 2.32864069 -1.19993367  0.02243766  0.07369059  3.22880925 -4.04084829
  0.          0.        ]
iteration 0 accepted 16.85489457400974 [ 2.6949496  -0.24070435  0.02243766  0.07369059  3.40360537 -3.92220608
  0.          0.        ]
iteration 1 accepted 13.977863492838418 [ 2.84173297  0.20747149  0.02243766  0.07369059  3.52506559 -3.71466247
  0.          0.        ]
iteration 2 accepted 12.170881768875319 [ 2.85318372 -0.05022272  0.02243766  0.07369059  3.6027983  -3.36014486
  0.          0.        ]
iteration 3 failed 12.170881768875327 [ 2.85318372 -0.05022272  0.02243766  0.07369059  3.6027983  -3.36014486
  0.          0.        ]
iteration 4 accepted 12.129830429913762 [ 2.75884007  0.05266829  0.02243766  0.07369059  3.51072004 -3.25887641
  0.          0.        ]
iteration 5 accepted 12.124607048985949 [ 2.56258972  0.37131624  0.02243766  0.07369059  3.39298188 -

iteration 44 accepted 7.553103912646529 [-0.15547401  2.27696798  0.02243766  0.07369059  1.6043329  -1.70183766
  0.          0.        ]
iteration 45 failed 7.553103912646528 [-0.15547401  2.27696798  0.02243766  0.07369059  1.6043329  -1.70183766
  0.          0.        ]
iteration 46 failed 7.553103912646528 [-0.15547401  2.27696798  0.02243766  0.07369059  1.6043329  -1.70183766
  0.          0.        ]
iteration 47 accepted 7.552928529182998 [-0.15664361  2.28831727  0.02243766  0.07369059  1.60988271 -1.68860811
  0.          0.        ]
iteration 48 failed 7.552928529183002 [-0.15664361  2.28831727  0.02243766  0.07369059  1.60988271 -1.68860811
  0.          0.        ]
iteration 49 failed 7.552928529183002 [-0.15664361  2.28831727  0.02243766  0.07369059  1.60988271 -1.68860811
  0.          0.        ]
iteration 50 accepted 7.551963435468378 [-0.1568755   2.29686283  0.02243766  0.07369059  1.6168613  -1.68131666
  0.          0.        ]
iteration 51 failed 7.5519634354683

In [None]:
res = {k: dyn.f_u(x0s[k], us_init[0], 0) for k, dyn in dynamics.items()}

In [None]:
res['my_performance']

In [None]:
res['my_basic']

# Receding horizon

In [10]:
from ilqr.controller import RecedingHorizonController

def receding(underlying):
    k = 'receding_' + underlying
    dynamics[k] = dynamics[underlying]
    x0s[k] = x0s[underlying]
    controller = RecedingHorizonController(x0s[k], ilqrs[underlying])
    rew = []
    xs[k] = []
    us[k] = []
    for x, u in controller.control(us_init, subsequent_n_iterations=10):
        ob, r, done, info = env.step(u)
        xs[k].append(x)
        us[k].append(u)
        rew.append(r)
        print('iteration', len(rew), r, x, u)
        if len(rew) == N:
            break

In [11]:
receding('my_basic')
receding('my_performance')

iteration 1 -0.39631516072175293 [[-0.02517132 -0.00313229  0.02243766  0.07369059 -0.00359075 -0.00120678
   0.          0.        ]
 [-0.02571403  0.00114664  0.02243766  0.07369059 -0.10478231  0.8555652
   0.          0.        ]] [[-0.05075596  0.43057972]]
iteration 2 -0.38107781025698595 [[-2.57140348e-02  1.14664135e-03  2.24376640e-02  7.36905865e-02
  -1.04782306e-01  8.55565199e-01  0.00000000e+00  0.00000000e+00]
 [-2.72516856e-02  1.37047499e-02  2.24376640e-02  7.36905865e-02
  -2.02584508e-01  1.65472474e+00  0.00000000e+00  0.00000000e+00]] [[-0.04956816  0.40590877]]
iteration 3 -0.3611488605643892 [[-0.02725169  0.01370475  0.02243766  0.07369059 -0.20258451  1.65472474
   0.          0.        ]
 [-0.02973687  0.03396746  0.02243766  0.07369059 -0.29429858  2.39658014
   0.          0.        ]] [[-0.04700576  0.38110525]]
iteration 4 -0.34335625998103986 [[-0.02973687  0.03396746  0.02243766  0.07369059 -0.29429858  2.39658014
   0.          0.        ]
 [-0.0331164

iteration 32 -0.026378726825605173 [[-0.17205307  1.8107782   0.02243766  0.07369059  0.10203078  5.56683072
   0.          0.        ]
 [-0.17061144  1.86505248  0.02243766  0.07369059  0.18615388  5.28848917
   0.          0.        ]] [[ 0.04275744 -0.11205284]]
iteration 33 -0.033825111462525315 [[-0.17061144  1.86505248  0.02243766  0.07369059  0.18615388  5.28848917
   0.          0.        ]
 [-0.16831064  1.91634692  0.02243766  0.07369059  0.27385944  4.97092693
   0.          0.        ]] [[ 0.04498076 -0.13315674]]
iteration 34 -0.036584250502927046 [[-0.16831064  1.91634692  0.02243766  0.07369059  0.27385944  4.97092693
   0.          0.        ]
 [-0.16520092  1.96455436  0.02243766  0.07369059  0.34796051  4.67106066
   0.          0.        ]] [[ 0.03858398 -0.12585177]]
iteration 35 -0.040500118455699036 [[-0.16520092  1.96455436  0.02243766  0.07369059  0.34796051  4.67106066
   0.          0.        ]
 [-0.16138021  2.00974812  0.02243766  0.07369059  0.41606519  4.3

iteration 63 -0.015649177482454067 [[-0.0399126   2.35671581  0.02243766  0.07369059 -0.02502422 -0.02810448
   0.          0.        ]
 [-0.04018372  2.35635046  0.02243766  0.07369059 -0.02919287 -0.04493661
   0.          0.        ]] [[-0.00222082 -0.00860023]]
iteration 64 -0.016933266010436872 [[-0.04018372  2.35635046  0.02243766  0.07369059 -0.02919287 -0.04493661
   0.          0.        ]
 [-0.0405522   2.35583869  0.02243766  0.07369059 -0.04447742 -0.05739663
   0.          0.        ]] [[-0.00782948 -0.00648711]]
iteration 65 -0.01197646459018898 [[-0.0405522   2.35583869  0.02243766  0.07369059 -0.04447742 -0.05739663
   0.          0.        ]
 [-0.04103096  2.35522585  0.02243766  0.07369059 -0.05126357 -0.06515941
   0.          0.        ]] [[-0.0036338  -0.00418854]]
iteration 66 -0.02245591758021152 [[-0.04103096  2.35522585  0.02243766  0.07369059 -0.05126357 -0.06515941
   0.          0.        ]
 [-0.04156442  2.35458921  0.02243766  0.07369059 -0.05542053 -0.062

iteration 92 -0.0377859790822442 [[-0.04569025  2.35245221  0.02243766  0.07369059  0.01281906  0.00946438
   0.          0.        ]
 [-0.04541617  2.35237602  0.02243766  0.07369059  0.04194882 -0.02464468
   0.          0.        ]] [[ 0.01470706 -0.0170953 ]]
iteration 93 -0.03302751330289988 [[-0.04541617  2.35237602  0.02243766  0.07369059  0.04194882 -0.02464468
   0.          0.        ]
 [-0.04506454  2.35239318  0.02243766  0.07369059  0.02840024  0.02798788
   0.          0.        ]] [[-0.00660052  0.02632928]]
iteration 94 -0.019523570259498326 [[-0.04506454  2.35239318  0.02243766  0.07369059  0.02840024  0.02798788
   0.          0.        ]
 [-0.04482195  2.35258811  0.02243766  0.07369059  0.02013119  0.01102713
   0.          0.        ]] [[-0.00401494 -0.00838445]]
iteration 95 -0.009182775421319848 [[-0.04482195  2.35258811  0.02243766  0.07369059  0.02013119  0.01102713
   0.          0.        ]
 [-0.04450775  2.3526242   0.02243766  0.07369059  0.04267095 -0.0037

iteration 23 -0.018231034815700646 [[-0.15357488  1.23365383  0.02243766  0.07369059 -0.55372756  7.46282478
   0.          0.        ]
 [-0.15882107  1.30761317  0.02243766  0.07369059 -0.49560868  7.32926603
   0.          0.        ]] [[ 0.02639501 -0.02980406]]
iteration 24 -0.009782225531320098 [[-0.15882107  1.30761317  0.02243766  0.07369059 -0.49560868  7.32926603
   0.          0.        ]
 [-0.16345782  1.38013937  0.02243766  0.07369059 -0.43184756  7.17622995
   0.          0.        ]] [[ 0.02952118 -0.04026086]]
iteration 25 -0.01723921644442882 [[-0.16345782  1.38013937  0.02243766  0.07369059 -0.43184756  7.17622995
   0.          0.        ]
 [-0.16734325  1.45098539  0.02243766  0.07369059 -0.34538268  6.99328001
   0.          0.        ]] [[ 0.04125494 -0.05605868]]
iteration 26 -0.012111675439781484 [[-0.16734325  1.45098539  0.02243766  0.07369059 -0.34538268  6.99328001
   0.          0.        ]
 [-0.17041669  1.519864    0.02243766  0.07369059 -0.26943129  6.78

iteration 54 -0.026784340283487328 [[-0.04935378  2.35090059  0.02243766  0.07369059  0.32462277  0.17304299
   0.          0.        ]
 [-0.0464413   2.35240663  0.02243766  0.07369059  0.25798353  0.12823986
   0.          0.        ]] [[-0.03187637 -0.02165298]]
iteration 55 -0.022128511636990048 [[-0.0464413   2.35240663  0.02243766  0.07369059  0.25798353  0.12823986
   0.          0.        ]
 [-0.04424369  2.35346872  0.02243766  0.07369059  0.18166558  0.08425128
   0.          0.        ]] [[-0.0370749  -0.02146772]]
iteration 56 -0.02209108334987302 [[-0.04424369  2.35346872  0.02243766  0.07369059  0.18166558  0.08425128
   0.          0.        ]
 [-0.04290126  2.3543914   0.02243766  0.07369059  0.08697874  0.10025738
   0.          0.        ]] [[-0.04668977  0.00846491]]
iteration 57 -0.025263731582290165 [[-0.04290126  2.3543914   0.02243766  0.07369059  0.08697874  0.10025738
   0.          0.        ]
 [-0.04207428  2.35510024  0.02243766  0.07369059  0.0784325   0.04

iteration 84 -0.012378652281327746 [[-0.04443881  2.35183238  0.02243766  0.07369059  0.04750288  0.01655863
   0.          0.        ]
 [-0.04411101  2.35200321  0.02243766  0.07369059  0.01810578  0.01760641
   0.          0.        ]] [[-0.01454014  0.00060912]]
iteration 85 -0.005070997972735085 [[-0.04411101  2.35200321  0.02243766  0.07369059  0.01810578  0.01760641
   0.          0.        ]
 [-0.04420554  2.35238962  0.02243766  0.07369059 -0.03691972  0.05960406
   0.          0.        ]] [[-0.02756992  0.02119514]]
iteration 86 -0.031002373696067582 [[-0.04420554  2.35238962  0.02243766  0.07369059 -0.03691972  0.05960406
   0.          0.        ]
 [-0.04461591  2.35295418  0.02243766  0.07369059 -0.04514032  0.05331957
   0.          0.        ]] [[-0.00431708 -0.00286058]]
iteration 87 -0.015494905416412625 [[-0.04461591  2.35295418  0.02243766  0.07369059 -0.04514032  0.05331957
   0.          0.        ]
 [-0.04492685  2.35332716  0.02243766  0.07369059 -0.01709356  0.0

# Rollouts

In [7]:
import time

def rollout(env, dynamics, x0, us, render=False):
    dynamics.set_state(x0)
    if render:
        env.render()
    rew = []
    actual_xs = []
    for u in us:
        _obs, r, done, info = env.step(u)
        rew.append(r)
        actual_xs.append(dynamics.get_state())
        assert not done
        if render:
            env.render()
            time.sleep(0.02)
    return rew, actual_xs

In [12]:
rews = {}
actual_xs = {}
for k, solved_us in us.items():
    print(k)
    rews[k], actual_xs[k] = rollout(env.unwrapped, dynamics[k], x0s[k], solved_us, render=True)
rewards = {k: sum(r) for k, r in rews.items()}
lengths = {k: len(r) for k, r in rews.items()}
pd.DataFrame({'rewards': rewards, 'lengths': lengths})

my_basic
my_performance
receding_my_basic
receding_my_performance


Unnamed: 0,rewards,lengths
my_basic,-12.291561,100
my_performance,-7.497279,100
receding_my_basic,-7.410396,100
receding_my_performance,-7.236362,100
