In [1]:
import sapien.core as sapien
from sapien.core import Pose
import jax.numpy as np
import numpy as onp
from jax import grad, jacfwd, jacrev, random, jit
import jax
import time
import matplotlib.pyplot as plt
from tqdm.notebook import trange, tqdm
from multiprocessing import Queue, Process, Array, Value
import ctypes as c
from transforms3d.quaternions import axangle2quat as aa
import transforms3d.euler as euler
import transforms3d.quaternions as quat

# from jax.config import config
# config.update("jax_enable_x64", True)

%matplotlib notebook

Using default glsl path /home/zack/anaconda3/envs/ml/lib/python3.7/site-packages/sapien/glsl_shader/130


# Construct Scene

In [2]:
# construct poses
ant_poses = {
    'j1': (
        Pose([0.282, 0, 0], [0.7071068, 0, 0.7071068, 0]),
        Pose([0.141, 0, 0], [-0.7071068, 0, 0.7071068, 0])),
    'j2': (
        Pose([-0.282, 0, 0], [0, -0.7071068, 0, 0.7071068]),
        Pose([0.141, 0, 0], [-0.7071068, 0, 0.7071068, 0])),
    'j3': (
        Pose([0, 0.282, 0], [0.5, -0.5, 0.5, 0.5]),
        Pose([0.141, 0, 0], [0.7071068, 0, -0.7071068, 0])),
    'j4': (
        Pose([0, -0.282, 0], [0.5, 0.5, 0.5, -0.5]),
        Pose([0.141, 0, 0], [0.7071068, 0, -0.7071068, 0])),
    'j11': (
        Pose([-0.141, 0, 0], [0, 0.7071068, 0.7071068, 0]),
        Pose([0.282, 0, 0], [0, 0.7071068, 0.7071068, 0])),
    'j21': (
        Pose([-0.141, 0, 0], [0, 0.7071068, 0.7071068, 0]),
        Pose([0.282, 0, 0], [0, 0.7071068, 0.7071068, 0])),
    'j31': (
        Pose([-0.141, 0, 0], [0, 0.7071068, 0.7071068, 0]),
        Pose([0.282, 0, 0], [0, 0.7071068, 0.7071068, 0])),
    'j41': (
        Pose([-0.141, 0, 0], [0, 0.7071068, 0.7071068, 0]),
        Pose([0.282, 0, 0], [0, 0.7071068, 0.7071068, 0])),
}

In [3]:
copper = sapien.PxrMaterial()
copper.set_base_color([0.875, 0.553, 0.221, 1])
copper.metallic = 1
copper.roughness = 0.2

def create_ant_builder(scene):
    builder = scene.create_articulation_builder()
    body = builder.create_link_builder()
    body.add_sphere_shape(Pose(), 0.25)
    body.add_sphere_visual_complex(Pose(), 0.25, copper)
    body.add_capsule_shape(Pose([0.141, 0, 0]), 0.08, 0.141)
    body.add_capsule_visual_complex(Pose([0.141, 0, 0]), 0.08, 0.141, copper)
    body.add_capsule_shape(Pose([-0.141, 0, 0]), 0.08, 0.141)
    body.add_capsule_visual_complex(Pose([-0.141, 0, 0]), 0.08, 0.141, copper)
    body.add_capsule_shape(Pose([0, 0.141, 0], aa([0, 0, 1], np.pi / 2)), 0.08, 0.141)
    body.add_capsule_visual_complex(Pose([0, 0.141, 0], aa([0, 0, 1], np.pi / 2)), 0.08, 0.141, copper)
    body.add_capsule_shape(Pose([0, -0.141, 0], aa([0, 0, 1], np.pi / 2)), 0.08, 0.141)
    body.add_capsule_visual_complex(Pose([0, -0.141, 0], aa([0, 0, 1], np.pi / 2)), 0.08, 0.141, copper)
    body.set_name("body")

    l1 = builder.create_link_builder(body)
    l1.set_name("l1")
    l1.set_joint_name("j1")
    l1.set_joint_properties(sapien.ArticulationJointType.REVOLUTE, [[-0.5236, 0.5236]],
                            ant_poses['j1'][0], ant_poses['j1'][1], 0.1)
    l1.add_capsule_shape(Pose(), 0.08, 0.141)
    l1.add_capsule_visual_complex(Pose(), 0.08, 0.141, copper)

    l2 = builder.create_link_builder(body)
    l2.set_name("l2")
    l2.set_joint_name("j2")
    l2.set_joint_properties(sapien.ArticulationJointType.REVOLUTE, [[-0.5236, 0.5236]],
                            ant_poses['j2'][0], ant_poses['j2'][1], 0.1)
    l2.add_capsule_shape(Pose(), 0.08, 0.141)
    l2.add_capsule_visual_complex(Pose(), 0.08, 0.141, copper)

    l3 = builder.create_link_builder(body)
    l3.set_name("l3")
    l3.set_joint_name("j3")
    l3.set_joint_properties(sapien.ArticulationJointType.REVOLUTE, [[-0.5236, 0.5236]],
                            ant_poses['j3'][0], ant_poses['j3'][1], 0.1)
    l3.add_capsule_shape(Pose(), 0.08, 0.141)
    l3.add_capsule_visual_complex(Pose(), 0.08, 0.141, copper)

    l4 = builder.create_link_builder(body)
    l4.set_name("l4")
    l4.set_joint_name("j4")
    l4.set_joint_properties(sapien.ArticulationJointType.REVOLUTE, [[-0.5236, 0.5236]],
                            ant_poses['j4'][0], ant_poses['j4'][1], 0.1)
    l4.add_capsule_shape(Pose(), 0.08, 0.141)
    l4.add_capsule_visual_complex(Pose(), 0.08, 0.141, copper)

    f1 = builder.create_link_builder(l1)
    f1.set_name("f1")
    f1.set_joint_name("j11")
    f1.set_joint_properties(sapien.ArticulationJointType.REVOLUTE, [[0.5236, 1.222]],
                            ant_poses['j11'][0], ant_poses['j11'][1], 0.1)
    f1.add_capsule_shape(Pose(), 0.08, 0.282)
    f1.add_capsule_visual_complex(Pose(), 0.08, 0.282, copper)

    f2 = builder.create_link_builder(l2)
    f2.set_name("f2")
    f2.set_joint_name("j21")
    f2.set_joint_properties(sapien.ArticulationJointType.REVOLUTE, [[0.5236, 1.222]],
                            ant_poses['j21'][0], ant_poses['j21'][1], 0.1)
    f2.add_capsule_shape(Pose(), 0.08, 0.282)
    f2.add_capsule_visual_complex(Pose(), 0.08, 0.282, copper)

    f3 = builder.create_link_builder(l3)
    f3.set_name("f3")
    f3.set_joint_name("j31")
    f3.set_joint_properties(sapien.ArticulationJointType.REVOLUTE, [[0.5236, 1.222]],
                            ant_poses['j31'][0], ant_poses['j31'][1], 0.1)
    f3.add_capsule_shape(Pose(), 0.08, 0.282)
    f3.add_capsule_visual_complex(Pose(), 0.08, 0.282, copper)

    f4 = builder.create_link_builder(l4)
    f4.set_name("f4")
    f4.set_joint_name("j41")
    f4.set_joint_properties(sapien.ArticulationJointType.REVOLUTE, [[0.5236, 1.222]],
                            ant_poses['j41'][0], ant_poses['j41'][1], 0.1)
    f4.add_capsule_shape(Pose(), 0.08, 0.282)
    f4.add_capsule_visual_complex(Pose(), 0.08, 0.282, copper)

    return builder

In [4]:
sim = sapien.Engine()
renderer = sapien.OptifuserRenderer()
sim.set_renderer(renderer)
render_controller = sapien.OptifuserController(renderer)

stabled = False 

def create_scene(timestep, visual):
    s = sim.create_scene()
    s.add_ground(-1)
    s.set_timestep(timestep)

    loader = s.create_urdf_loader()
    loader.fix_root_link = 0
    if visual:
        loader.collision_is_visual = True
        s.set_ambient_light([0.5, 0.5, 0.5])
        s.set_shadow_light([0, 1, -1], [0.5, 0.5, 0.5])
    
    # build
    ant_builder = create_ant_builder(s)   
    ant = ant_builder.build()
    
    return s, ant

sim_timestep = 1/30
optim_timestep = 1/20
s0, ant = create_scene(sim_timestep, True)

render_controller.set_camera_position(-5, 0, 0)
render_controller.set_current_scene(s0)

# Wrap IO

## State
    - COM pos (x y)
        - 13 * 2
    - foot pos (x y z)
        - additional 1
    - torso pos (x y z)
        - additional 1
    - COM velo?
        - additional 13 * 3
## Action
    - 21 params

In [5]:
links_cache = {}
joints_cache = {}

def get_links(robot, sig_mass=1):
    '''
        Return all the link that have significant mass for the robot
    '''
    if robot in links_cache:
        return links_cache[robot]
    
    links = {}
    
    for l in robot.get_links():
        name = l.get_name()
        mass = l.get_mass()
        if mass > sig_mass:
            links[name] = l
    
    links_cache[robot] = links
    return links

def get_joints(robot):
    '''
        Return all the joints
    '''
    if robot in joints_cache:
        return joints_cache[robot]
    
    joints = {}
    
    for j in robot.get_joints():
        if j.get_dof() > 0:
            name = j.get_name()
            joints[name] = j
    
    joints_cache[robot] = joints
    return joints

def get_target_link(links, target_body_names):
    target_body_parts = {}
    
    for name, l in links.items():
        if name in target_body_names:
            target_body_parts[name] = l
    
    return target_body_parts

def report_link(l):
    '''
        Report name, mass, velocity, and position of the link.
    '''
    name = l.get_name()
    mass = l.get_mass()
    velo = l.get_velocity()
    pos = l.get_pose().p
    print(f"{name}\n\tmass: {mass:.6f}\n\tvelo: {velo}\n\tpos:  {pos}\n")

def report_all_links(links):
    '''
        Repot info for all links.
    '''
    for _, l in links.items():
        report_link(l)
        
def get_mass(links):
    '''
        Input:
            links: the links dict
        
        Output:
            array of masses for each link
    '''
    mass = []
    for name, l in links.items():
        mass.append(l.get_mass())
    return np.array(mass)

def get_link_state(robot):
    '''
        Input:
            robot
            
        Output:
            array of shape (n_l, 6)
    '''
    links = get_links(robot)
    state = []
    
    for name, l in links.items():
        state.append(np.concatenate((l.get_pose().p, l.get_velocity())))
    
    return np.array(state)

def get_joint_state(robot):
    '''
        Input:
            robot
            
        Output:
            array of shape (n_j, 7)
    '''
    joints = get_joints(robot)
    state = []
    
    for name, j in joints.items():
        pose = j.get_global_pose()
        state.append(np.concatenate((pose.p, pose.q)))
    
    return np.array(state)

def get_state(robot):
    '''
        dof + 3 + 4
        [qpos(8), pos(3), quat(4)]
    '''
    ant_pos = ant.get_pose()
    return np.concatenate((robot.get_qpos(), ant_pos.p, ant_pos.q))

In [6]:
def compose_p(pose):
    '''
        Get the affine matrix from pose
    '''
    
    return compose(pose.p, pose.q)


def m_quat2mat(q):
    '''
        Borrowed from transform3d, altered a bit
    '''
    float_eps = 2.220446049250313e-16
    
    w, x, y, z = q
    Nq = w*w + x*x + y*y + z*z
#     if Nq < float_eps:
#         return np.eye(3)
    s = 2.0/Nq
    X = x*s
    Y = y*s
    Z = z*s
    wX = w*X; wY = w*Y; wZ = w*Z
    xX = x*X; xY = x*Y; xZ = x*Z
    yY = y*Y; yZ = y*Z; zZ = z*Z
    return np.array(
           [[ 1.0-(yY+zZ), xY-wZ, xZ+wY ],
            [ xY+wZ, 1.0-(xX+zZ), yZ-wX ],
            [ xZ-wY, yZ+wX, 1.0-(xX+yY) ]])


def m_mat2quat(M):
    '''
        Borrowed from transform3d, altered a bit
    '''
    Qxx, Qyx, Qzx, Qxy, Qyy, Qzy, Qxz, Qyz, Qzz = M.flatten()
    # Fill only lower half of symmetric matrix
    K = np.array([
        [Qxx - Qyy - Qzz, 0,               0,               0              ],
        [Qyx + Qxy,       Qyy - Qxx - Qzz, 0,               0              ],
        [Qzx + Qxz,       Qzy + Qyz,       Qzz - Qxx - Qyy, 0              ],
        [Qyz - Qzy,       Qzx - Qxz,       Qxy - Qyx,       Qxx + Qyy + Qzz]]
        ) / 3.0
    # Use Hermitian eigenvectors, values for speed
    vals, vecs = np.linalg.eigh(K)
    # Select largest eigenvector, reorder to w,x,y,z quaternion
    q = vecs[[3, 0, 1, 2], np.argmax(vals)]
    # Prefer quaternion with positive w
    # (q * -1 corresponds to same rotation as q)
#     if q[0] < 0:
#         q *= -1
    return q * q[0] / np.abs(q[0])


def decompose(m):
    '''
        Get the pose
    '''
    p = m[:3, -1]
    q = m_mat2quat(m[:3,:3])
#     q = quat.mat2quat(m[:3,:3])
    
    return np.concatenate((p,q))


def compose(p, q):
    '''
        Get the affine matrix from pose
    '''
    rot = m_quat2mat(q)
#     rot = quat.quat2mat(q)
    
    affine_matrix = np.eye(4)
    
    affine_matrix = jax.ops.index_update(affine_matrix, jax.ops.index[:3, :3], rot)
    affine_matrix = jax.ops.index_update(affine_matrix, jax.ops.index[:3, -1], p)
    
    return np.array(affine_matrix)

In [7]:
def pose2mat(poses_dict):
    mat_dic = {}
    
    for name, pose in poses_dict.items():
        m1 = compose_p(pose[0]).tolist()
        m2 = compose_p(pose[1]).tolist()
        mat_dic[name] = (np.array(m1), np.array(onp.linalg.inv(m2)))
        
    return mat_dic

ant_pos_mat = pose2mat(ant_poses)

In [8]:

def m_axangle2mat(axis, angle):
    '''
        Borrowed from transform3d
    '''
    x, y, z = axis
    n = np.sqrt(x*x + y*y + z*z)
    x = x/n
    y = y/n
    z = z/n
    c = np.cos(angle); s = np.sin(angle); C = 1-c
    xs = x*s;   ys = y*s;   zs = z*s
    xC = x*C;   yC = y*C;   zC = z*C
    xyC = x*yC; yzC = y*zC; zxC = z*xC
    return np.array([
            [ x*xC+c,   xyC-zs,   zxC+ys ],
            [ xyC+zs,   y*yC+c,   yzC-xs ],
            [ zxC-ys,   yzC+xs,   z*zC+c ]])



def kforward_step(parent_mat, pose_mats, theta):
    '''
        pp:          parent frame pos
        pq:          parent frame quat
        pose_mast:   joint pose mat
        theta:       the revol joint theta
    '''
#     t_cos = np.cos(theta)
#     t_sin = np.sin(theta)
#     joint_mat = np.array([
#         [1, 0, 0, 0],
#         [0, t_cos, t_sin, 0],
#         [0, -t_sin, t_cos, 0],
#         [0, 0, 0, 1]])
    joint_mat = np.eye(4)
    m = m_axangle2mat([1,0,0], theta)
    joint_mat = jax.ops.index_update(joint_mat, jax.ops.index[:3, :3], m)
    
    
    jp_mat = pose_mats[0]
    jc_mat = pose_mats[1]
    
    return parent_mat @ jp_mat @ joint_mat @ jc_mat

In [9]:
dof = ant.dof

def forward_kinematics(state):
    '''
        dof + 3 + 4
        [qpos(8), pos(3), quat(4)]
        
        return the poses
    '''
    qs = state[:dof]
    
    ant_pos = state[dof:dof+3]
    ant_quat = state[-4:]
    
    ant_mat = compose(ant_pos, ant_quat)
    
    # forwarding
    l1 = kforward_step(ant_mat, ant_pos_mat['j1'], qs[0])
    l2 = kforward_step(ant_mat, ant_pos_mat['j2'], qs[1])
    l3 = kforward_step(ant_mat, ant_pos_mat['j3'], qs[2])
    l4 = kforward_step(ant_mat, ant_pos_mat['j4'], qs[3])
    
    l11 = kforward_step(l1, ant_pos_mat['j11'], qs[4])
    l21 = kforward_step(l2, ant_pos_mat['j21'], qs[5])
    l31 = kforward_step(l3, ant_pos_mat['j31'], qs[6])
    l41 = kforward_step(l4, ant_pos_mat['j41'], qs[7])
    
    links = (l1, l2, l3, l4, l11, l21, l31, l41)
    
    return np.concatenate([decompose(m)[:3] for m in links])

In [10]:
class Worker(Process):
    def __init__(self, request_queue, ans_arr, num_task, timestep, workId):
        super(Worker, self).__init__()
        self.request_queue = request_queue
        self.ans_arr = ans_arr
        self.num_task = num_task
        self.workId = workId
        
        # init scene
        self.scene, self.robot = create_scene(timestep, False)
        self.dof = self.robot.dof
        
#         self.links = self.get_links()
    
#     def get_links(self):
#         '''
#             Return all the link that have significant mass for the robot

#             target_body_parts : [str body parts]
#         '''
#         links = {}

#         for l in self.robot.get_links():
#             name = l.get_name()
#             mass = l.get_mass()
#             if mass > 1:
#                 links[name] = l

#         return links    
    
    def get_state(self):
        qpos = self.robot.get_qpos()
        pose = self.robot.get_pose()
        
        return onp.concatenate((qpos, pose.p, pose.q))
    
    def fu(self, ini_pack, u, index):
        self.robot.unpack(ini_pack)

        # simulate
        self.robot.set_qf(u)

        self.scene.step()
        new_state = self.get_state()

        self.ans_arr[index] = new_state
        
    def fx(self, ini_pack, u, state, new_val, index):
        self.robot.unpack(ini_pack)
        
        if state == 'qpos':
            self.robot.set_qpos(new_val)
        elif state == 'robo_pos':
            p, q = new_val
            self.robot.set_pose(Pose(p,q))

        # simulate
        self.robot.set_qf(u)
        self.scene.step()

        self.ans_arr[index] = self.get_state()
    
    '''
        Args:
            - task_arg
            - i
    '''
    def run(self):
        print(f'Worker-{self.workId} started')
        
        for task_arg, index, workType in iter(self.request_queue.get, None):
            if workType == 'fu':
                ini_pack, new_u = task_arg
                self.fu(ini_pack, new_u, index)
            elif workType == 'fx':
                ini_pack, u, state, new_val = task_arg
                self.fx(ini_pack, u, state, new_val, index)
                
            with self.num_task.get_lock():
                self.num_task.value -= 1
                
        print(f'Worker-{self.workId} exits')

In [11]:
def create_workers(num_workers, m, n, optim_timestep):
    request_queue = Queue()
    mp_arr = Array(c.c_double, m * n)
    ans_arr = onp.frombuffer(mp_arr.get_obj())
    ans_arr = ans_arr.reshape(m, n)
    num_task = Value('i', 1)
    
    # create worker
    for i in range(num_workers):
        wok = Worker(request_queue, ans_arr, num_task, optim_timestep, i)
        wok.start()
        
    return request_queue, ans_arr, num_task

num_x = len(get_state(ant))
num_u = len(ant.get_qf())

def mp_num_fu(request_queue, ans_arr, u, robot, num_task, eps=1e-3):
    '''
        Only doing pos now, return (n_x, n_u)
    '''
    
    ini_pack = robot.pack()
    ini_state = get_state(robot)
    
    res = []
#     passive_force = robot.compute_passive_force()
    u = u.tolist()

    with num_task.get_lock():
        num_task.value = num_u

    # dispatch work
    for i in range(num_u):
        # prep args
        new_u = u.copy()
        new_u[i] += eps
        
        ini_pack = ini_pack
        task_arg = (ini_pack, new_u)

        wok_args = (task_arg, i, 'fu')
        request_queue.put(wok_args)

    # block until done
    while num_task.value > 0:
        time.sleep(0.00001)
    
    res = (ans_arr[:num_u] - ini_state) / eps
    res = res.T
    
    return res

def mp_num_fx(request_queue, ans_arr, u, robot, num_task, eps=1e-3):
    '''
        Only doing pos now, return (n_x, n_x)
    '''
    ini_pack = robot.pack()
    ini_state = get_state(robot)
    
    links = get_links(robot)
    
    u = u.tolist()
    
    with num_task.get_lock():
        num_task.value = num_x
    
    for i in range(num_x):
        # prep args
        state = None
        new_val = None
        
        # ini_pack, u, state, new_val, index
        if i < dof:
            state = 'q_pos'
            new_val = ini_state[:dof].tolist()
            new_val[i] += eps
        else:
            state = 'robo_pos'
            p = ini_state[dof: dof+3].tolist()
            q = ini_state[-4:].tolist()
            
            j = i - dof
            
            if j < 3:
                p[j] += eps
            else:
                j -= 3
                q[j] += eps
                
            new_val = (p, q)
        
        task_args = (ini_pack, u, state, new_val)
        wok_args = (task_args, i, 'fx')
        
        request_queue.put(wok_args)
            
        
    while num_task.value > 0:
        time.sleep(0.00001)
    
    res = (ans_arr[:num_x] - ini_state) / eps
    res = np.array(res).T
    
    return res

# ILQR

In [12]:
class ILQR:
    def __init__(self, final_cost, running_cost, model, u_range, horizon, per_iter, model_der=None):
        '''
            final_cost:     v(x)    ->  cost, float
            running_cost:   l(x, u) ->  cost, float
            model:          f(x, u) ->  new state, [n_x]
        '''
        self.f = model
        self.v = final_cost
        self.l = running_cost

        self.u_range = u_range
        self.horizon = horizon
        self.per_iter = per_iter

        # specify derivatives
        self.l_x = grad(self.l, 0)
        self.l_u = grad(self.l, 1)
        self.l_xx = jacfwd(self.l_x, 0)
        self.l_uu = jacfwd(self.l_u, 1)
        self.l_ux = jacrev(self.l_u, 0)

        self.v_x = jacrev(self.v)
        self.v_xx = jacfwd(self.v_x)

        if model_der == None:
            self.f_x = jacrev(self.f, 0)
            self.f_u = jacfwd(self.f, 1)
            
            (self.f, self.f_u, self.f_x,) = [jit(e) for e in [self.f, self.f_u, self.f_x,]]
        else:
            # using provided function for step
            self.f_x = model_der['f_x']
            self.f_u = model_der['f_u']
            

        # speed up
        (self.l, self.l_u, self.l_uu, self.l_ux, self.l_x, self.l_xx,
         self.v, self.v_x, self.v_xx) = \
            [jit(e) for e in [self.l, self.l_u, self.l_uu, self.l_ux, self.l_x, self.l_xx,
                              self.v, self.v_x, self.v_xx]]


    def cal_K(self, x_seq, u_seq):
        '''
            Calculate all the necessary derivatives, and compute the Ks
        '''
        state_dim = x_seq[0].shape[-1]
#         v_seq = [None] * self.horizon
        v_x_seq = [None] * self.horizon
        v_xx_seq = [None] * self.horizon

        last_x = x_seq[-1]
#         v_seq[-1] = self.v(last_x)
        v_x_seq[-1] = self.v_x(last_x)
        v_xx_seq[-1] = self.v_xx(last_x)

        k_seq = [None] * self.horizon
        kk_seq = [None] * self.horizon

        for i in tqdm(range(self.horizon - 2, -1, -1), desc='backward', leave=False):
            x, u = x_seq[i], u_seq[i]

            # get all grads
            lx = self.l_x(x, u)
            lu = self.l_u(x, u)
            lxx = self.l_xx(x, u)
            luu = self.l_uu(x, u)
            lux = self.l_ux(x, u)

            fx = self.f_x(x, u)
            fu = self.f_u(x, u)
#             fxx = self.f_xx(x, u)
#             fuu = self.f_uu(x, u)
#             fux = self.f_ux(x, u)

            vx = v_x_seq[i+1]
            vxx = v_xx_seq[i+1]
        
            # cal Qs
            q_x = lx + fx.T @ vx
            q_u = lu + fu.T @ vx
            q_xx = lxx + fx.T @ vxx @ fx
            q_uu = luu + fu.T @ vxx @ fu
            q_ux = lux + fu.T @ vxx @ fx
#             q_xx = lxx + fx.T @ vxx @ fx + vx @ fxx
#             q_uu = luu + fu.T @ vxx @ fu + (fuu.T @ vx).T
#             q_ux = lux + fu.T @ vxx @ fx + (fux.T @ vx).T
    
            # cal Ks
            inv_quu = np.linalg.inv(q_uu)
            
            k = - inv_quu @ q_u
            kk = - inv_quu @ q_ux
            
#             if i == 0:
            names = ['k', 'kk', 'inv_qq','lx', 'lu', 'lxx', 'luu', 'lux', 'fx', 'fu', 'vx', 'vxx', 'qx', 'qu', 'qxx', 'quu', 'qux']
            Ms = [k, kk, inv_quu, lx, lu, lxx, luu, lux, fx, fu, vx, vxx, q_x, q_u, q_xx, q_uu, q_ux]

            print(f"\n\n-------------ITER {i}------------------------------")
            for n, m in zip(names, Ms):
                print(f"{n}\n\t{np.max(m)}\n")
            
#             print(f"INVQUU:\n\t {inv_quu}")
#             print(f"QU:\n\t {inv_quu}")
#             print(f"k:\n\t {k}\n\n\n")

            # cal Vs
            new_v = q_u @ k / 2
            new_vx = q_x + q_u @ kk
            new_vxx = q_xx + q_ux.T @ kk

            # record
            k_seq[i] = k
            kk_seq[i] = kk
            v_x_seq[i] = new_vx
            v_xx_seq[i] = new_vxx

        return k_seq, kk_seq

    def forward(self, x_seq, u_seq, k_seq, kk_seq):
        new_x_seq = [None] * self.horizon
        new_u_seq = [None] * self.horizon

        new_x_seq[0] = x_seq[0]  # copy

        for i in trange(self.horizon - 1, desc='forward', leave=False):
            x = new_x_seq[i]

            new_u = u_seq[i] + k_seq[i] + kk_seq[i] @ (x - x_seq[i])
            new_x = self.f(x, new_u)

            new_u_seq[i] = new_u
            new_x_seq[i+1] = new_x

        return new_x_seq, new_u_seq

    def predict(self, x_seq, u_seq):
        for _ in trange(self.per_iter, desc='ILQR', leave=False):
            k_seq, kk_seq = self.cal_K(x_seq, u_seq)
            
            x_seq, u_seq = self.forward(x_seq, u_seq, k_seq, kk_seq)
        
        u_seq[-1] = u_seq[-2] # filling
        return np.array(x_seq), np.array(u_seq)


In [13]:
def sim_step(scene, robot, action):
    robot.set_qf(action)
    scene.step()
        
    return get_state(robot)


masses = get_mass(get_links(ant))
masses = np.expand_dims(masses, axis=1)
mass_sum = np.sum(masses)

def final_cost(x, alpha=0.2):
    '''
        goal: stand
            - com close to body center
            - body points upward
            - body center 0.3 above ground
    '''
    smooth_abs = lambda x : np.sum(np.sqrt(x**2 + alpha**2) - alpha)
    
    pose = forward_kinematics(x).reshape(-1, 3)
    pos = pose[:, :3]
    body_center = x[num_u: num_u+3]
    body_quat = x[-4:]
    
    com = (np.sum(pos * masses[1:]) + body_center * masses[0]) / mass_sum

    body_up = m_quat2mat(body_quat) @ np.array([0,0,1])
    body_up = body_up / np.linalg.norm(body_up)
    body_up_theta = np.arccos(body_up[2])
    
    # calculate terms
    term1 = smooth_abs(com - body_center)
    term2 = smooth_abs(body_up_theta)
    term3 = smooth_abs(body_center[2] - 0.3)
    
    return term1 + term2 + term3
    
@jit
def running_cost(x, u, alpha=0.3):
    return np.sum((alpha ** 2) * (np.cosh(u/alpha) - 1))

In [14]:
bigger_dim = max(num_u, num_x)

request_queue, ans_arr, num_task = create_workers(4, bigger_dim, bigger_dim, optim_timestep)

Worker-0 started
Worker-1 started
Worker-2 started
Worker-3 started


In [15]:
u_range = np.array([[-1000] * ant.dof, [1000] * ant.dof])
pred_time = 5
# horizon = int(pred_time / optim_timestep) + 1
horizon = 30
per_iter = 1

eps = 1e-6
model_der = {
    'f_x' : lambda x, u : mp_num_fx(request_queue, ans_arr, u, ant, num_task, eps=1e-3),
    'f_u' : lambda x, u : mp_num_fu(request_queue, ans_arr, u, ant, num_task, eps=1e-3)
}

ilqr = ILQR(final_cost, running_cost, lambda x, u : sim_step(s0, ant, u), u_range, horizon, per_iter, model_der)

In [16]:
# prepare simulation

ant.set_pose(Pose([0,0,0]))
# run to stable
if not stabled:
    for i in range(1000):
        s0.step()
    stabled = True

ini_pack = ant.pack()
    
u_seq = np.zeros((horizon, num_u))
x_seq = []
x_seq.append(get_state(ant))
for i in range(horizon - 1):
    state = sim_step(s0, ant, u_seq[i])
    x_seq.append(state)
x_seq = np.array(x_seq)

In [17]:
ant.get_pose()

Pose([0.0260843, 0.140073, -0.659846], [0.9996, -0.000545851, -0.00262666, -0.0281518])

In [None]:
pack = ant.pack()
x_seq, u_seq = ilqr.predict(x_seq, u_seq)
ant.unpack(pack)

HBox(children=(FloatProgress(value=0.0, description='ILQR', max=1.0, style=ProgressStyle(description_width='in…

HBox(children=(FloatProgress(value=0.0, description='backward', max=29.0, style=ProgressStyle(description_widt…



-------------ITER 28------------------------------
k
	-0.06016075611114502

kk
	-0.10692329704761505

inv_qq
	0.8755161166191101

lx
	0.0

lu
	0.0

lxx
	0.0

luu
	1.0

lux
	0.0

fx
	1.2053399085998535

fu
	0.20535288751125336

vx
	0.08337128907442093

vxx
	19.763029098510742

qx
	14.85741138458252

qu
	14.654681205749512

qxx
	33.136016845703125

quu
	31.269594192504883

qux
	31.22096824645996



-------------ITER 27------------------------------
k
	0.002810955047607422

kk
	-0.11681552976369858

inv_qq
	0.8750636577606201

lx
	0.0

lu
	0.0

lxx
	0.0

luu
	1.0

lux
	0.0

fx
	1.2053399085998535

fu
	0.20535288751125336

vx
	2.3192644119262695

vxx
	1.3640670776367188

qx
	-7.085445880889893

qu
	-7.812483787536621

qxx
	365.8864440917969

quu
	366.8847351074219

qux
	365.8855895996094



-------------ITER 26------------------------------
k
	0.00038461387157440186

kk
	-0.11686136573553085

inv_qq
	0.8750731348991394

lx
	0.0

lu
	0.0

lxx
	0.0

luu
	1.0

lux
	0.0

fx
	1.20533990859985



-------------ITER 3------------------------------
k
	nan

kk
	nan

inv_qq
	nan

lx
	0.0

lu
	0.0

lxx
	0.0

luu
	1.0

lux
	0.0

fx
	1.2053399085998535

fu
	0.20535288751125336

vx
	nan

vxx
	nan

qx
	nan

qu
	nan

qxx
	nan

quu
	nan

qux
	nan



-------------ITER 2------------------------------
k
	nan

kk
	nan

inv_qq
	nan

lx
	0.0

lu
	0.0

lxx
	0.0

luu
	1.0

lux
	0.0

fx
	1.2053399085998535

fu
	0.20535288751125336

vx
	nan

vxx
	nan

qx
	nan

qu
	nan

qxx
	nan

quu
	nan

qux
	nan



-------------ITER 1------------------------------
k
	nan

kk
	nan

inv_qq
	nan

lx
	0.0

lu
	0.0

lxx
	0.0

luu
	1.0

lux
	0.0

fx
	1.2053399085998535

fu
	0.20535288751125336

vx
	nan

vxx
	nan

qx
	nan

qu
	nan

qxx
	nan

quu
	nan

qux
	nan



-------------ITER 0------------------------------
k
	nan

kk
	nan

inv_qq
	nan

lx
	0.0

lu
	0.0

lxx
	0.0

luu
	1.0

lux
	0.0

fx
	1.2053399085998535

fu
	0.20535288751125336

vx
	nan

vxx
	nan

qx
	nan

qu
	nan

qxx
	nan

quu
	nan

qux
	nan



HBox(children=(FloatProgress(value=0.0, description='forward', max=29.0, style=ProgressStyle(description_width…

In [None]:
# ctrl = []

# render_controller.show_window()

# from tqdm.notebook import tqdm

# l0 = ant.get_links()[0]

# # use another thread for rendering
# # def thread_render():
# #     render_controller.focus(l0)
# #     render_controller.render()

# # thread = threading.Thread(target=thread_render)

# # s0.update_render()
# # thread.start()
# for i in trange(1):
#     x_seq, u_range = ilqr.predict(x_seq, u_seq)
    
#     u = u_seq[0]
#     ctrl.append(u.tolist())
#     print(u)
    
#     ant.set_qf(u)
#     s0.step()
#     s0.update_render()
    
#     render_controller.focus(l0)
#     render_controller.render()

    
# print(ctrl)
# np.save('ctrl', ctrl)
