In [1]:
import sapien.core as sapien
from Tools import misc, ForwardKinematics, SimWorker, NumForwardDynamicsDer
import numpy as onp
import jax.numpy as np
import jax
from jax import jit, jacfwd, jacrev, grad
from ilqr import ILQR
from tqdm.notebook import trange

# from jax.config import config
# config.update("jax_debug_nans", True)

Using default glsl path /home/zack/anaconda3/envs/ml/lib/python3.7/site-packages/sapien/glsl_shader/130


In [2]:
# Create the SAPIEN engine and attach an Optifuser renderer to it.
sim = sapien.Engine()
renderer = sapien.OptifuserRenderer()
sim.set_renderer(renderer)
# Controller used further down to place the camera, select the scene and
# draw frames.
render_controller = sapien.OptifuserController(renderer)

# Guard flag read by the settling cell: once that cell has stepped the scene
# it sets this to True so the stepping is skipped on a re-run.
# Re-running *this* cell resets the flag to False.
stabled = False

def create_scene(timestep, visual):
    """Create a scene on the module-level engine ``sim`` and load the Panda URDF.

    Args:
        timestep: value forwarded to the scene's ``set_timestep``.
        visual: when truthy, the URDF loader is configured with
            ``collision_is_visual = True`` and the scene receives an ambient
            light and a shadow light.

    Returns:
        ``(scene, robot)`` -- the newly created scene and the object returned
        by the URDF loader for ``panda.urdf``.
    """
    # [0, 0, 0] is presumably the gravity vector and -1 the ground height --
    # TODO confirm against the SAPIEN version in use.
    scene = sim.create_scene([0, 0, 0])
    scene.add_ground(-1)
    scene.set_timestep(timestep)

    urdf_loader = scene.create_urdf_loader()
    urdf_loader.fix_root_link = True

    if visual:
        urdf_loader.collision_is_visual = True
        scene.set_ambient_light([0.5, 0.5, 0.5])
        scene.set_shadow_light([0, 1, -1], [0.5, 0.5, 0.5])

    # Load the arm; the path is relative to the notebook's working directory.
    panda = urdf_loader.load("../../../assets/Arm/panda.urdf")

    return scene, panda

# sim_timestep is used for the main scene s0 (and later passed to
# NumForwardDynamicsDer); optim_timestep is later passed to SimWorker.
sim_timestep = 1/60
optim_timestep = 1/60
# Main scene and robot, built with the visual options enabled.
s0, robot = create_scene(sim_timestep, True)

# Point the render controller at the main scene.
render_controller.set_camera_position(-5, 0, 0)
render_controller.set_current_scene(s0)

In [3]:
# Step the main scene 3000 times, once. The `stabled` flag makes re-running
# this cell a no-op (presumably the stepping lets the robot settle before
# planning starts -- confirm).
if not stabled:
    for _ in range(3000):
        s0.step()
    stabled = True


In [4]:
def smooth_abs(x, alpha):
    """Smooth penalty ``sum(alpha**2 * (cosh(x / alpha) - 1))``.

    The value is computed through the identity
    ``cosh(z) - 1 == 2 * sinh(z / 2) ** 2``. Mathematically this is the same
    function, but it avoids the catastrophic cancellation of ``cosh(z) - 1``
    for small ``z``: in float32 ``cosh(1e-4)`` rounds to exactly 1, so the
    direct form returns 0 where the true value is about 5e-9.

    The penalty still grows exponentially in ``|x| / alpha``; arguments
    around 89 and above overflow float32 to ``inf``, as with ``cosh`` itself.

    Args:
        x: scalar or array of residuals.
        alpha: scale; smaller values make the penalty sharper.

    Returns:
        Scalar: the penalty summed over every element of ``x``.
    """
    half_arg = x / (2 * alpha)
    return np.sum(2 * (alpha ** 2) * np.sinh(half_arg) ** 2)

# Root pose of the robot, read once here; final_cost appends its `.p` and
# `.q` fields to the state before running forward kinematics.
robo_pose = robot.get_root_pose()

@jit
def final_cost(x, alpha=0.3):
    """Terminal cost on the height of one forward-kinematics point.

    The state is extended with the root pose (``robo_pose.p`` then
    ``robo_pose.q``), passed through ``fk.fk`` and viewed as rows of three
    coordinates. The third-from-last row is treated as the end effector; its
    z coordinate is compared with a target height and the doubled difference
    goes through ``smooth_abs``.

    Args:
        x: state vector (presumably 1-D, since it is concatenated with the
            pose fields).
        alpha: scale forwarded to ``smooth_abs``.

    Returns:
        Scalar cost, ten times the smooth penalty on the height error.
    """
    # Append the base pose to the state.
    augmented = np.concatenate((x, robo_pose.p, robo_pose.q))

    # One row of (x, y, z) per point returned by the FK helper.
    points = fk.fk(augmented).reshape(-1, 3)
    tracked_point = points[-3]

    # NOTE(review): augmented[-1] is the last entry of robo_pose.q, not a
    # position coordinate -- confirm this is the intended height reference.
    desired_z = augmented[-1] + 1

    height_error = (desired_z - tracked_point[2]) * 2
    return smooth_abs(height_error, alpha) * 10

@jit
def running_cost(x, u, alpha=0.3):
    """Per-step cost on control magnitude.

    Args:
        x: state; accepted as part of the cost signature but not used here.
        u: control vector; divided by 300 before being penalised.
        alpha: scale forwarded to ``smooth_abs``.

    Returns:
        Scalar smooth penalty on ``u / 300``.
    """
    scaled_u = u / 300
    effort = smooth_abs(scaled_u, alpha)
    # smooth_abs already reduces to a scalar; the outer sum is kept as-is.
    return np.sum(effort)


# Gradient and Hessian of the terminal cost with respect to the state.
# final_cost returns a scalar, so reverse mode (grad) yields the whole
# gradient in a single pass. The Hessian is then taken forward-over-reverse,
# the composition JAX recommends and uses for jax.hessian.
vx = grad(final_cost)
vxx = jacfwd(vx)

In [5]:
# Problem dimensions: length of the state vector returned by misc.get_state,
# and one control per robot degree of freedom.
state = misc.get_state(robot)
num_x = len(state)
num_u = robot.dof
dof = robot.dof

# Helper objects from the project's Tools module. num_deri and sim_worker are
# handed to the ILQR constructor below; fk is used inside final_cost.
# (Presumably: numerical dynamics derivatives, forward kinematics, and a
# background simulation worker -- confirm in Tools.)
num_deri = NumForwardDynamicsDer(robot, sim_timestep)
fk = ForwardKinematics(robot)
sim_worker = SimWorker(robot, create_scene, optim_timestep,True)

Worker-0 started


In [6]:
# Control bounds: row 0 is -100 and row 1 is +100 for every DoF
# (presumably lower/upper limits on the applied joint force -- confirm in ILQR).
u_range = np.array([[-100] * robot.dof, [100] * robot.dof])

# pred_time is kept for reference only. The horizon used to be derived as
# int(pred_time / optim_timestep) + 1 and was then immediately overwritten,
# so only the fixed value below ever took effect. The dead assignment has
# been removed.
pred_time = 1
horizon = 100
per_iter = 3

ilqr = ILQR(final_cost, running_cost, None, u_range, horizon, per_iter, num_deri, sim_worker, True)

In [7]:
# Prepare the initial sequences for the solver: roll the main scene forward
# under a constant all-ones control, recording the state and packed snapshot
# at every step, then put the robot back.
u_seq = list(np.ones((horizon, dof)))
x_seq, pack_seq = [], []

# Snapshot taken before the rollout so it can be undone afterwards.
bak_pack = robot.pack()

for i, u in enumerate(u_seq):
    # Record state and snapshot *before* this step's control is applied.
    x = misc.get_state(robot)
    pack = robot.pack()
    x_seq.append(x)
    pack_seq.append(pack)

    robot.set_qf(u)
    s0.step()

# Restore the snapshot taken before the rollout.
robot.unpack(bak_pack)

In [8]:
# ilqr.predict(x_seq, u_seq, pack_seq)

In [9]:
# Receding-horizon control loop: re-plan, apply the first control of the plan
# to the main scene, render, then overwrite the head of the sequences with
# the state actually reached.
ctrl = []

render_controller.show_window()
try:
    for i in trange(1000):
        x_seq, u_seq, pack_seq = ilqr.predict(x_seq, u_seq, pack_seq)

        # Only the first control of the optimised sequence is applied.
        u = u_seq[0]
        ctrl.append(u)
        print(u)

        robot.set_qf(u)
        s0.step()
        s0.update_render()
        render_controller.render()

        # Feed the measured state back in as the start of the next plan.
        new_x = misc.get_state(robot)
        new_pack = robot.pack()

        x_seq[0] = new_x
        pack_seq[0] = new_pack
finally:
    # Save whatever controls were applied even if the solver raises (or the
    # run is interrupted) part-way through; previously the save was skipped
    # whenever the loop did not finish. The exception still propagates.
    np.save('ctrl', ctrl)

HBox(children=(FloatProgress(value=0.0, max=1000.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='ILQR', max=3.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

[0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0.]
[ 3.7000727e-22  1.4743888e-21 -4.0099030e-21 -3.1643742e-23
  6.6991665e-21  2.5091144e-20  3.1091454e-20 -1.2683403e-19
  4.2728166e-20]
[0. 0. 0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0.]


-------------ITER 37------------------------------
k
	nan

kk
	nan

inv_qq
	nan

lx
	0.0

lu
	8.407790785948902e-45

lxx
	0.0

luu
	1.1111112144135404e-05

lux
	0.0

fx
	1.556087851524353

fu
	0.004467773716896772

vx
	371568995729408.0

vxx
	inf

qx
	576692976353280.0

qu
	296410218496.0

qxx
	nan

quu
	nan

qux
	nan

x
	3.3093738555908203

u
	7.346839692639297e-40

last_x
	3.3093738555908203




HBox(children=(FloatProgress(value=0.0, description='ILQR', max=3.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='ILQR', max=3.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='ILQR', max=3.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='ILQR', max=3.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

HBox(children=(FloatProgress(value=0.0, description='forward', max=99.0, style=ProgressStyle(description_width…

HBox(children=(FloatProgress(value=0.0, description='ILQR', max=3.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='backward', max=99.0, style=ProgressStyle(description_widt…

Exception: ILQR Invalid

In [None]:
# Replay the recorded controls on the main scene without re-planning,
# rendering after every step.
# NOTE(review): the robot is not reset before the replay, so it starts from
# whatever state the previous cell left it in -- confirm this is intended.
for u in ctrl:
   robot.set_qf(u)
   s0.step()
   s0.update_render()

   render_controller.render()
