In [60]:
import gym
import sys
import time
import jax.numpy as np
import jax
from tqdm.notebook import trange

In [3]:
import jax.numpy as np
from jax import grad, jacfwd, jacrev, random, jit


class ILQR:
    '''Iterative LQR trajectory optimiser using JAX autodiff for cost derivatives.

    Each call to `predict` performs `per_iter` sweeps of:
      1. `cal_K`   - backward pass: linearise the dynamics and quadratise the
                     costs around the current trajectory, producing a
                     feed-forward term k_i and feedback gain K_i per step;
      2. `forward` - forward pass: roll the dynamics out under the updated,
                     clipped controls.

    Only first-order dynamics derivatives (f_x, f_u) enter the Q-function
    expansion (the Gauss-Newton iLQR approximation, not full DDP), so a
    black-box model only has to supply Jacobians.
    '''

    def __init__(self, final_cost, running_cost, model, u_range, horizon, per_iter, model_der=None):
        '''
            final_cost:     v(x)    ->  cost, float
            running_cost:   l(x, u) ->  cost, float
            model:          f(x, u) ->  new state, [n_x]
            u_range:        (low, high) bounds passed to np.clip on every control
            horizon:        number of states in a trajectory (length of x_seq)
            per_iter:       backward/forward sweeps performed per `predict` call
            model_der:      optional object with f_x(x, u) -> [n_x, n_x] and
                            f_u(x, u) -> [n_x, n_u]; when given, `model` is used
                            as-is (not jitted) and these replace autodiff
        '''
        self.f = model
        self.v = final_cost
        self.l = running_cost

        self.u_range = u_range
        self.horizon = horizon
        self.per_iter = per_iter

        # cost derivatives via autodiff
        self.l_x = grad(self.l, 0)
        self.l_u = grad(self.l, 1)
        self.l_xx = jacfwd(self.l_x, 0)
        self.l_uu = jacfwd(self.l_u, 1)
        self.l_ux = jacrev(self.l_u, 0)     # [n_u, n_x]

        self.v_x = grad(self.v)
        self.v_xx = jacfwd(self.v_x)

        if model_der is None:
            # differentiable model: derive the dynamics Jacobians and jit everything
            self.f_x = jacrev(self.f, 0)    # [n_x, n_x]
            self.f_u = jacfwd(self.f, 1)    # [n_x, n_u]

            (self.f, self.f_u, self.f_x,) = [jit(e) for e in [self.f, self.f_u, self.f_x,]]
        else:
            # black-box model (e.g. a simulator): use caller-supplied Jacobians
            self.f_x = model_der.f_x
            self.f_u = model_der.f_u

        # jit the cost functions and their derivatives for speed
        (self.l, self.l_u, self.l_uu, self.l_ux, self.l_x, self.l_xx,
         self.v, self.v_x, self.v_xx) = \
            [jit(e) for e in [self.l, self.l_u, self.l_uu, self.l_ux, self.l_x, self.l_xx,
                              self.v, self.v_x, self.v_xx]]

    def cal_K(self, x_seq, u_seq):
        '''
            Backward pass around the trajectory (x_seq, u_seq).

            Returns (k_seq, kk_seq), both of length `horizon`. Entry i
            (i < horizon - 1) holds the feed-forward term k_i [n_u] and the
            feedback gain K_i [n_u, n_x]. The last entry is None because no
            control acts on the terminal state.
        '''
        v_x_seq = [None] * self.horizon
        v_xx_seq = [None] * self.horizon

        # the value function at the terminal state is the final cost
        last_x = x_seq[-1]
        v_x_seq[-1] = self.v_x(last_x)
        v_xx_seq[-1] = self.v_xx(last_x)

        k_seq = [None] * self.horizon
        kk_seq = [None] * self.horizon

        for i in range(self.horizon - 2, -1, -1):
            x, u = x_seq[i], u_seq[i]

            # local cost expansion
            lx = self.l_x(x, u)
            lu = self.l_u(x, u)
            lxx = self.l_xx(x, u)
            luu = self.l_uu(x, u)
            lux = self.l_ux(x, u)

            # local dynamics linearisation (second-order terms intentionally omitted)
            fx = self.f_x(x, u)
            fu = self.f_u(x, u)

            vx = v_x_seq[i + 1]
            vxx = v_xx_seq[i + 1]

            # Q-function expansion
            q_x = lx + fx.T @ vx
            q_u = lu + fu.T @ vx
            q_xx = lxx + fx.T @ vxx @ fx
            q_uu = luu + fu.T @ vxx @ fu
            q_ux = lux + fu.T @ vxx @ fx

            # k = -Quu^-1 Qu, K = -Quu^-1 Qux; solve is more stable than an explicit inverse
            k = -np.linalg.solve(q_uu, q_u)
            kk = -np.linalg.solve(q_uu, q_ux)

            # value-function derivatives at step i (valid because Quu is symmetric)
            k_seq[i] = k
            kk_seq[i] = kk
            v_x_seq[i] = q_x + q_u @ kk
            v_xx_seq[i] = q_xx + q_ux.T @ kk

        return k_seq, kk_seq

    def forward(self, x_seq, u_seq, k_seq, kk_seq):
        '''
            Forward pass: roll out u_i' = clip(u_i + k_i + K_i (x_i' - x_i))
            from the original initial state x_seq[0].

            Returns (new_x_seq, new_u_seq), both of length `horizon`; the last
            entry of new_u_seq is left as None.
        '''
        new_x_seq = [None] * self.horizon
        new_u_seq = [None] * self.horizon

        new_x_seq[0] = x_seq[0]  # copy

        for i in range(self.horizon - 1):
            x = new_x_seq[i]

            new_u = u_seq[i] + k_seq[i] + kk_seq[i] @ (x - x_seq[i])
            new_u = np.clip(new_u, self.u_range[0], self.u_range[1])
            new_x = self.f(x, new_u)

            new_u_seq[i] = new_u
            new_x_seq[i + 1] = new_x

        return new_x_seq, new_u_seq

    def predict(self, x_seq, u_seq):
        '''Run `per_iter` backward/forward sweeps and return the optimised controls.'''
        for _ in range(self.per_iter):
            k_seq, kk_seq = self.cal_K(x_seq, u_seq)
            x_seq, u_seq = self.forward(x_seq, u_seq, k_seq, kk_seq)

        return u_seq


In [93]:
env = gym.make('Humanoid-v3')
obs = env.reset()


# Task-space state used by the planner: a flat 13-element vector, built by
# sim_step, split by array2info and re-flattened by info2array:
#   [0:2]    com_pos    x/y COM position  (info 'x_position', 'y_position')
#   [2:4]    com_v      x/y COM velocity  (info 'x_velocity', 'y_velocity')
#   [4:7]    torso_pos  get_body_xpos('torso')
#   [7:10]   lfoot_pos  get_body_xpos('left_foot')
#   [10:13]  rfoot_pos  get_body_xpos('right_foot')

def info2array(info):
    '''Flatten a dict of sequences into one array, following the dict's iteration order.

    Given the output of array2info, this rebuilds the original flat vector.
    '''
    flat = [value for part in info.values() for value in part]
    return np.array(flat)

def array2info(arr):
    '''Split the flat 13-element state vector into its named slices (see info2array).'''
    layout = (
        ('com_pos', 0, 2),
        ('com_v', 2, 4),
        ('torso_pos', 4, 7),
        ('lfoot_pos', 7, 10),
        ('rfoot_pos', 10, 13),
    )
    return {name: arr[start:stop] for name, start, stop in layout}

def sim_step(env, action):
    '''Step `env` with `action` and return the 13-element task-space state.

    Layout: COM x/y position, COM x/y velocity (both from the step info dict),
    then the torso, left-foot and right-foot body positions.
    '''
    _obs, _reward, _done, info = env.step(action)

    body_pos = lambda body: env.env.data.get_body_xpos(body).tolist()
    pieces = (
        [info['x_position'], info['y_position']],
        [info['x_velocity'], info['y_velocity']],
        body_pos('torso'),
        body_pos('left_foot'),
        body_pos('right_foot'),
    )
    return np.array([value for piece in pieces for value in piece])


def final_cost(x, alpha=0.2):
    '''Terminal cost v(x) on the 13-element state (layout as in array2info).

    Terms:
      term1 - horizontal offset of the COM from the mid-point of the feet
      term2 - horizontal offset between the COM and the torso
      term3 - offset of the torso from the feet mid-point raised by 1.3 along
              index 2 (presumably world z - TODO confirm)
      term4 - COM speed, ||com_v||
    term1-3 use the smooth L1 penalty sum(sqrt(d^2 + alpha^2) - alpha), which
    is differentiable at d = 0.
    '''
    info = array2info(x)
    com_pos, com_v = info['com_pos'], info['com_v']
    torso_pos = info['torso_pos']
    lfoot_pos, rfoot_pos = info['lfoot_pos'], info['rfoot_pos']

    def smooth_abs(d):
        return np.sum(np.sqrt(d ** 2 + alpha ** 2) - alpha)

    mean_foot = (lfoot_pos + rfoot_pos) / 2
    term1 = smooth_abs(com_pos - mean_foot[:2])
    term2 = smooth_abs(com_pos - torso_pos[:2])

    # jax.ops.index_add was removed from JAX; adding an offset vector gives the
    # same value and works for any array type.
    mean_foot_air = mean_foot + np.array([0.0, 0.0, 1.3])
    term3 = smooth_abs(torso_pos - mean_foot_air)

    # Same value as np.linalg.norm(com_v), but the plain norm has a NaN
    # gradient/Hessian at com_v == 0, which would poison v_x / v_xx in ILQR.
    # The inner `where` keeps sqrt away from 0 on the unselected branch.
    speed_sq = np.sum(com_v ** 2)
    moving = speed_sq > 0
    term4 = np.where(moving, np.sqrt(np.where(moving, speed_sq, 1.0)), 0.0)

    return term1 + term2 + term3 + term4
    


def running_cost(x, u, alpha=0.3):
    '''Per-step control cost l(x, u). x is accepted to match l(x, u) but is not used.

    Sums alpha^2 * (cosh(u / alpha) - 1) over the control components. This is
    about u^2 / 2 for |u| << alpha and grows exponentially beyond that.
    '''
    scale = alpha ** 2
    excess = np.cosh(u / alpha) - 1
    return np.sum(scale * excess)


def model(_, u):
    '''Dynamics callback f(x, u) backed by the live gym env.

    The state argument is ignored. The global `env` is stepped from whatever
    state the simulator is currently in, and the 13-element state built by
    sim_step is returned.

    NOTE(review): because x is ignored, results only match f(x, u) when the env
    is already at x. Stepping a real env is likely not traceable by jit/grad, so
    ILQR presumably needs `model_der` when using this. TODO confirm.
    '''
    next_state = sim_step(env, u)
    return next_state

def model_deri_fx(x, u, eps=1e-3):
    '''Forward-difference estimate of the dynamics Jacobian d model(x, u) / dx.

    Every evaluation starts from the same snapshot taken with
    env.sim.get_state(), and that snapshot is restored before returning. This
    holds even if `model` raises.

    Returns an [n_x, n_x] array whose column i is
    (model(x + eps * e_i, u) - model(x, u)) / eps. This matches the
    jacrev(f, 0) layout that ILQR expects from `f_x`.

    NOTE(review): `model` currently ignores its state argument, so this returns
    zeros until the perturbed x can be written back into the simulator.
    Restoring sim state may also not rewind gym wrapper bookkeeping such as
    step counters. TODO confirm.
    '''
    snapshot = env.sim.get_state()
    x = np.asarray(x)
    basis = np.eye(len(x))
    try:
        base = model(x, u)
        columns = []
        for i in range(len(x)):
            # re-run every perturbed step from the same starting sim state
            env.sim.set_state(snapshot)
            columns.append((model(x + eps * basis[i], u) - base) / eps)
    finally:
        env.sim.set_state(snapshot)
    return np.stack(columns, axis=1)
# Scratch check: display the qpos field of the current simulator state (output below).
env.sim.get_state().qpos

array([ 8.18700352e-03, -2.65489280e-03,  1.40420801e+00,  9.99955550e-01,
        3.39517384e-03, -8.08512679e-03, -3.46426057e-03,  6.68744986e-03,
        7.94259070e-03, -3.79750563e-04, -7.75025414e-03, -3.85111780e-03,
       -3.22911379e-03, -4.08617108e-03,  6.39850009e-03, -7.13986900e-03,
       -3.23650759e-03, -6.47747199e-03, -8.43558847e-03,  1.63022500e-03,
       -3.41619092e-03, -8.31960013e-03,  3.16487514e-03, -5.24772930e-03])

In [5]:
env.reset()

s = env.sim.get_state()

cnt = 0
while True:
    if cnt > 500:
        cnt = 0
        env.sim.set_state(s)
    cnt+=1
    
    env.step(env.action_space.sample())
    env.render()

Creating window glfw


SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [8]:
# Scratch check: shape of the model's body_mass array (output below).
np.shape(env.env.model.body_mass)

(14,)

In [12]:
# Scratch check: shape of the simulator's xipos array (output below).
np.shape(env.env.sim.data.xipos)

(14, 3)

In [30]:
# Scratch: `+` on lists concatenates. info2array and sim_step rely on the
# in-place form (`concat += l`). The trailing expression produces the saved output.
a = [1,2]
b = [2,3]
a + b



[1, 2, 2, 3]

In [62]:
# Scratch: functional index update. jax.ops.index_add was removed from JAX;
# `.at[idx].add(val)` replaces it and returns a new array, leaving `a` unchanged.
a = np.ones(3)
a.at[1].add(6.)

DeviceArray([1., 7., 1.], dtype=float32)

In [63]:
def f(x):
    '''Scratch helper: return x + 1.'''
    return x + 1


f(1)

2