In [None]:
import os
import time
from copy import deepcopy
from tqdm import tqdm
from typing import Sequence, Any, cast, Dict, List, Optional, Tuple
from functools import partial

import jax
from jax.lib import xla_bridge
import jax.numpy as jnp

import brax
from brax.io import file, html, image
from brax.physics import config_pb2

from google.protobuf import text_format

import gc
import numpy as np
from IPython.display import HTML, Image
from scipy.spatial.transform import Rotation as R
import matplotlib.pyplot as plt
import pybullet as p
from mpl_toolkits.mplot3d import Axes3D

# Report which backend JAX picked and which GPUs are visible.
print('Jax now using: {}'.format(jax.default_backend()))
# jax.devices('gpu') raises RuntimeError when no GPU backend is present; fall back
# to an empty list so the notebook still runs (slowly) on a CPU-only machine.
try:
    devices = jax.devices('gpu')
except RuntimeError:
    devices = []
print('Found {} GPU devices: {}'.format(len(devices), devices))

In [None]:
class BraxKukaObstacle():
    """Kuka arm scene in Brax (target, obstacles, optional shelf) plus a PyBullet twin.

    Brax simulates the scene. A headless PyBullet client loads the same robot URDF
    and is used only to compute controller quantities (mass matrix, Jacobian,
    inverse dynamics) in `compute_robot_dynamics`.

    Call order used in this notebook: `load_robot()`, then `load_shelf(...)` or
    `load_target_obstacles(...)`, then `create_sys()`, then
    `put_obstacle_and_target(...)`.

    Args:
        args: settings dict; reads 'root' (directory containing `gyms/` and
            `brax_config/`) and the 'robot_*' keys.
    """

    def __init__(self, args):
        self.args = args
        self.brax_config = None
        self.brax_sys = None
        self.fruit_link_names = []

        # Scene bookkeeping updated by the load_* / put_* methods. Initialised here so
        # put_shelf() (which reads curr_shelf_position) works before any other call.
        self.n_obstacle = 0
        self.curr_offset = jnp.zeros(3)
        self.curr_shelf_position = jnp.zeros(3)

        # Setup PyBullet for controller compute (DIRECT = headless; SHARED_MEMORY also possible)
        self._physics_client_id = p.connect(p.DIRECT)
        p.setRealTimeSimulation(False)
        robot_urdf = os.path.join(self.args['root'], "gyms/robot/kuka_with_gripper_fix_brax.urdf")
        flags = p.URDF_ENABLE_CACHED_GRAPHICS_SHAPES | p.URDF_USE_INERTIA_FROM_FILE | p.URDF_USE_SELF_COLLISION
        self.kukaUid = p.loadURDF(robot_urdf, basePosition=[0, 0, 0], baseOrientation=[0, 0, 0, 1],
                                  flags=flags, physicsClientId=self._physics_client_id)
        self.numJoints = p.getNumJoints(self.kukaUid, physicsClientId=self._physics_client_id)

        # Revolute joints are the arm joints: joint index -> joint name (bytes).
        self.armJointsInfo = {}
        # Index of the 'grasp_target' link (PyBullet link index == joint index);
        # stays None if the URDF has no such link.
        self.graspTargetLink = None
        for i in range(self.numJoints):
            info = p.getJointInfo(self.kukaUid, i, physicsClientId=self._physics_client_id)
            # info[1] = joint name (bytes), info[2] = joint type
            if info[2] == p.JOINT_REVOLUTE:
                self.armJointsInfo[i] = info[1]
            elif info[1].decode('utf-8') == 'grasp_target':
                self.graspTargetLink = i
        self.armJoints = list(self.armJointsInfo.keys())

        # Task-space (6-D) PD gains used by compute_torque_from_eef_offset.
        # Earlier tuning: Kp = diag(2631, 1331, 1865, 1740, 2451, 2601),
        #                 Kd = diag(52, 32, 62, 46, 74, 74)
        self.Kp = jnp.diag(jnp.array([2450, 3993, 2937, 5778, 5663, 4219]))
        self.Kd = jnp.diag(jnp.array([38, 75, 45, 43, 47, 64]))
        # Joint-space (7-D) gains applied through the null-space projector.
        self.Kqp = jnp.diag(jnp.array([50, 50, 50, 50, 50, 50, 50]))
        self.Kqd = jnp.diag(jnp.array([8, 8, 8, 8, 8, 8, 8]))
        # Joint rest configuration in rad (alternative: [0, -0.523599, 0, -2.0944, 0, 1.5708, 0]).
        self.restPose = jnp.array([0.0, 0.0, 0.0, -1.5708, 0.0, 1.5708, 0.0])

    ###################################################################################
    ###################################### Robot ######################################
    ###################################################################################

    def load_robot(self, complexity='capsule'):
        """Parse the Kuka Brax config and apply the 'robot_*' joint/actuator args.

        Only the 'capsule' collision model is supported.
        """
        assert complexity == 'capsule', f'Robot complexity {complexity} not support'

        config_file = os.path.join(self.args['root'], f'brax_config/kuka_{complexity}_rigid_config.txt')
        with open(config_file, 'r') as f:
            kuka_config_text = f.read()
        self.brax_config = text_format.Parse(kuka_config_text, brax.Config())

        # Override the joint and actuator parameters from args.
        for joint in self.brax_config.joints:
            joint.stiffness = self.args['robot_joint_stiffness']
            joint.angular_damping = self.args['robot_joint_angular_damping']
            joint.spring_damping = self.args['robot_joint_spring_damping']
            joint.limit_strength = self.args['robot_joint_limit_strength']

        for actuator in self.brax_config.actuators:
            actuator.strength = self.args['robot_actuator_strength']

        # Robot body names, skipping the first body of the config.
        self.kuka_bodies = [body.name for body in self.brax_config.bodies][1:]
        # Index of the first body added after the robot; body (first_node_index - 1)
        # is treated as link 7 in _get_kuka_target_pose.
        # NOTE(review): hard-coded; assumes the robot config defines 9 bodies.
        self.first_node_index = 9

    ###################################################################################
    ################################## Shelf Picking ##################################
    ###################################################################################

    def _add_unit_body(self, name):
        """Append a body `name` with mass 1 and unit inertia to the config; return it."""
        body = self.brax_config.bodies.add(name=name, mass=1)
        body.inertia.x = 1
        body.inertia.y = 1
        body.inertia.z = 1
        return body

    def _add_target(self, radius, color=None):
        """Append the frozen spherical 'target' body, optionally coloured."""
        target = self.brax_config.bodies.add(name='target')
        target.frozen.all = True
        collider = target.colliders.add()
        if color is not None:
            collider.color = color
        collider.sphere.radius = radius

    def _add_obstacles(self, n_obstacle, obstacle_rad, obstacle_fix, color=None, shape='sphere'):
        """Append bodies 'shelf_obstacle_0' .. 'shelf_obstacle_{n-1}'.

        A 'sphere' is modelled as a capsule whose length equals its diameter; any
        other shape is a cube of half-size `obstacle_rad`. Obstacles are frozen
        when `obstacle_fix` is true.
        """
        for i in range(n_obstacle):
            obstacle = self._add_unit_body(f'shelf_obstacle_{i}')
            collider = obstacle.colliders.add()
            if color is not None:
                collider.color = color
            if shape == 'sphere':
                collider.capsule.radius = obstacle_rad
                collider.capsule.length = obstacle_rad * 2
            else:
                collider.box.halfsize.x = obstacle_rad
                collider.box.halfsize.y = obstacle_rad
                collider.box.halfsize.z = obstacle_rad
            if obstacle_fix:
                obstacle.frozen.all = True

    def load_shelf(self, shelf_height=0.45, shelf_length=0.9, shelf_width=0.9,
                   shelf_thickness=0.04, n_obstacle=2, obstacle_rad=0.07,
                   first_link_in_col=2, block_target=False, obstacle_fix=False):
        """Add a frozen shelf (two plates, four columns), a target and obstacles.

        shelf_length is along y and shelf_width along x. Robot links from
        `kuka_bodies[first_link_in_col:]` get collisions with every body whose name
        contains 'shelf' (this includes the obstacles). The obstacles also collide
        with both plates. With `block_target`, column 0 is moved to put a column in
        front of the robot to block the target. Must be called after `load_robot`.
        """
        self.n_obstacle = n_obstacle

        # Plates: frozen boxes
        for plate_name in ['shelf_bottom_plate', 'shelf_upper_plate']:
            plate = self._add_unit_body(plate_name)
            collider = plate.colliders.add()
            collider.box.halfsize.x = shelf_width / 2
            collider.box.halfsize.y = shelf_length / 2
            collider.box.halfsize.z = shelf_thickness / 2
            plate.frozen.all = True

        # Columns: frozen capsules
        for i in range(4):
            column = self._add_unit_body(f'shelf_column_{i}')
            collider = column.colliders.add()
            collider.capsule.radius = shelf_thickness / 2
            collider.capsule.length = shelf_height
            column.frozen.all = True

        # Attach columns and the upper plate to the bottom plate; each offset is the
        # child's placement relative to the bottom plate (negated into child_offset).
        childs = ['shelf_column_0', 'shelf_column_1', 'shelf_column_2', 'shelf_column_3', 'shelf_upper_plate']
        half_x = (shelf_width - shelf_thickness) / 2
        half_y = (shelf_length - shelf_thickness) / 2
        offsets = [(half_x, half_y, shelf_height / 2),
                   (-half_x, half_y, shelf_height / 2),
                   (half_x, -half_y, shelf_height / 2),
                   (-half_x, -half_y, shelf_height / 2),
                   (0, 0, shelf_height)]

        if block_target:  # Put a column in the front of the robot to block the target
            offsets[0] = (-(shelf_width / 2) / 2, 0, shelf_height / 2)

        for child, offset in zip(childs, offsets):
            joint = self.brax_config.joints.add(name=f'bottom_plate_to_{child}', parent='shelf_bottom_plate', child=child)
            joint.child_offset.x = - offset[0]
            joint.child_offset.y = - offset[1]
            joint.child_offset.z = - offset[2]

        self._add_target(radius=shelf_thickness)
        self._add_obstacles(n_obstacle, obstacle_rad, obstacle_fix)

        # Create collisions for robot-shelf and obstacle-shelf
        self.shelf_bodies = ['shelf_bottom_plate', 'shelf_upper_plate']
        for body in self.brax_config.bodies:
            if 'shelf' in body.name:  # robot-shelf (and robot-obstacle)
                for robot_link in self.kuka_bodies[first_link_in_col:]:
                    c = self.brax_config.collide_include.add()
                    c.first = robot_link
                    c.second = body.name
            if 'shelf_obstacle' in body.name:  # obstacle-shelf
                for shelf_link in self.shelf_bodies:
                    c = self.brax_config.collide_include.add()
                    c.first = shelf_link
                    c.second = body.name

    def load_target_obstacles(self, n_obstacle=2, obstacle_rad=0.07,
                              first_link_in_col=2, obstacle_fix=False):
        """Add a frozen target (radius 0.04) and `n_obstacle` obstacles, no shelf.

        Robot links from `kuka_bodies[first_link_in_col:]` get collisions with every
        body whose name contains 'shelf' (here: the obstacles). Must be called after
        `load_robot`.
        """
        self.n_obstacle = n_obstacle

        self._add_target(radius=0.04, color='#365963')
        self._add_obstacles(n_obstacle, obstacle_rad, obstacle_fix, color='#466336')

        # Create collisions for robot-obstacles
        for body in self.brax_config.bodies:
            if 'shelf' in body.name:
                for robot_link in self.kuka_bodies[first_link_in_col:]:
                    c = self.brax_config.collide_include.add()
                    c.first = robot_link
                    c.second = body.name

    def put_obstacle_and_target(self, qp, target_position, obstacle_positions, init_offset=(0, 0, 0)):
        """Place the target and obstacles, then shift bodies by `init_offset`.

        The target goes to `target_position` and the first `n_obstacle` obstacles to
        `obstacle_positions`. Then `init_offset` is added to every body with index
        >= first_node_index (NOTE(review): presumably all non-robot bodies). The
        offset is recorded in `curr_offset`. Returns a new qp; the input is not modified.
        """
        assert self.brax_sys is not None, "Brax simulation not initialized"
        self.target_idx = self.brax_sys.body.index['target']
        self.obstacle_idxs = [self.brax_sys.body.index[f'shelf_obstacle_{i}'] for i in range(self.n_obstacle)]

        pos = brax.jumpy.index_update(qp.pos, self.target_idx, target_position)
        for obstacle_pos, obstacle_idx in zip(obstacle_positions[0:self.n_obstacle], self.obstacle_idxs):
            pos = brax.jumpy.index_update(pos, obstacle_idx, obstacle_pos)

        offset = jnp.array(init_offset)
        for link_idx in range(self.first_node_index, pos.shape[0]):
            pos = brax.jumpy.index_update(pos, link_idx, pos[link_idx] + offset)

        self.curr_offset = offset
        return qp.replace(pos=pos)

    def put_shelf(self, qp, shelf_position):
        """Translate bodies with index >= first_node_index to move the shelf.

        The shift is `shelf_position - curr_shelf_position` (zeros until the first
        call). The new position is recorded in `curr_shelf_position`. Returns a new qp.
        """
        pos = qp.pos
        offset = jnp.array(shelf_position) - self.curr_shelf_position
        for link_idx in range(self.first_node_index, pos.shape[0]):
            pos = brax.jumpy.index_update(pos, link_idx, pos[link_idx] + offset)
        self.curr_shelf_position = jnp.array(shelf_position)
        return qp.replace(pos=pos)

    ###################################################################################
    ######################### Dynamics & Controller & Helpers #########################
    ###################################################################################

    def create_sys(self):
        """Build the Brax System from the config, JIT helpers, and return the default qp."""
        assert self.brax_config, 'No config exist yet'
        self.brax_sys = brax.System(self.brax_config)

        # Jit some useful functions
        self.step = jax.jit(self.brax_sys.step)
        if self.first_node_index == 9:
            self.get_kuka_joints = jax.jit(self._get_kuka_joints)  # in rad
            self.get_kuka_target_pose = jax.jit(self._get_kuka_target_pose)
        if len(self.fruit_link_names) > 0:
            # NOTE(review): _get_fruit_pose is not defined in this class.
            self.get_fruit_pose = jax.jit(self._get_fruit_pose)

        return self.brax_sys.default_qp()

    def _get_kuka_joints(self, qp):
        """Return (angles, velocities) of the first joint group of the Brax system."""
        joint_angle, joint_vel = self.brax_sys.joints[0].angle_vel(qp)
        return joint_angle, joint_vel

    def _get_kuka_target_pose(self, qp):
        """Return grasp-target (pos, quat xyzw) and link-7 (vel, ang) from a qp.

        The grasp target is 0.2975 along link 7's local z-axis and shares link 7's
        orientation. For reference, PyBullet gives the grasp target at all-zero joint
        angles as pos [7.9e-13, 3.0e-07, 1.5585] and quat ~[0, 0, 0, 1].
        Link 7's velocities stand in for the grasp target's; the angular one is wrong.
        """
        link_7_pos = qp.pos[self.first_node_index - 1]
        link_7_quat = qp.rot[self.first_node_index - 1]  # in (w, x, y, z)
        link_7_rot_z = self.quat2matrix(link_7_quat, z_only=True)
        grasp_target_pos = link_7_pos + 0.2975 * link_7_rot_z
        grasp_target_quat = jnp.roll(link_7_quat, -1)  # in (x, y, z, w)

        link_7_vel = qp.vel[self.first_node_index - 1]
        link_7_ang = qp.ang[self.first_node_index - 1]

        return grasp_target_pos, grasp_target_quat, link_7_vel, link_7_ang

    def quat2matrix(self, quat, z_only=False, wxyz=True):
        """Rotation matrix of a (not necessarily unit) quaternion.

        Args:
            quat: 4 components, ordered (w, x, y, z) if `wxyz` else (x, y, z, w).
            z_only: return only the third column (the rotated z-axis), shape (3,).

        Returns:
            (3, 3) rotation matrix, or its third column when `z_only`.
        """
        if wxyz:
            qr, qi, qj, qk = quat
        else:
            qi, qj, qk, qr = quat
        s = 1 / (jnp.linalg.norm(quat) ** 2)  # normalises non-unit quaternions

        r02 = 2 * s * (qi * qk + qj * qr)
        r12 = 2 * s * (qj * qk - qi * qr)
        r22 = 1 - 2 * s * (qi * qi + qj * qj)

        if z_only:
            return jnp.array([r02, r12, r22])

        r00 = 1 - 2 * s * (qj * qj + qk * qk)
        r10 = 2 * s * (qi * qj + qk * qr)
        r20 = 2 * s * (qi * qk - qj * qr)

        r01 = 2 * s * (qi * qj - qk * qr)
        r11 = 1 - 2 * s * (qi * qi + qk * qk)
        r21 = 2 * s * (qj * qk + qi * qr)

        return jnp.array([[r00, r01, r02],
                          [r10, r11, r12],
                          [r20, r21, r22]])

    def compute_robot_dynamics(self, q, dq):
        """Compute operational-space quantities with PyBullet at joint state (q, dq).

        Returns:
            M: joint-space mass matrix.
            M_task: task-space inertia, pinv(J M^-1 J^T).
            J: grasp-target Jacobian, linear rows stacked on angular rows (6 x dof).
            J_inv: dynamically consistent inverse, M^-1 J^T M_task.
            g_plus_c: inverse dynamics at zero acceleration (gravity + Coriolis).
            N: null-space projector I - J_inv J.

        Raises:
            RuntimeError: if the URDF had no 'grasp_target' link.
        """
        if self.graspTargetLink is None:
            raise RuntimeError("Robot URDF has no 'grasp_target' link; cannot compute the Jacobian")
        q = q.tolist() if not isinstance(q, list) else q
        dq = dq.tolist() if not isinstance(dq, list) else dq
        zero = [0] * len(q)

        M = jnp.array(p.calculateMassMatrix(bodyUniqueId=self.kukaUid, objPositions=q,
                                            physicsClientId=self._physics_client_id))

        J = p.calculateJacobian(bodyUniqueId=self.kukaUid, linkIndex=self.graspTargetLink,
                                localPosition=[0, 0, 0], objPositions=q, objVelocities=zero,
                                objAccelerations=zero, physicsClientId=self._physics_client_id)
        J = jnp.vstack(J)

        g_plus_c = jnp.asarray(p.calculateInverseDynamics(bodyUniqueId=self.kukaUid, objPositions=q,
                                                          objVelocities=dq, objAccelerations=zero,
                                                          physicsClientId=self._physics_client_id))

        M_inv = jnp.linalg.inv(M)
        M_task = np.linalg.pinv(J @ M_inv @ J.T)
        J_inv = M_inv @ J.T @ M_task

        # NOTE(review): assumes the number of revolute joints equals PyBullet's DOF count.
        N = jnp.eye(len(self.armJoints)) - J_inv @ J

        return M, M_task, J, J_inv, g_plus_c, N

    def compute_torque_from_eef_offset(self, x_err, q_err, constants):
        """Operational-space torque for a task-space and a joint-space error.

        tau = J^T M_task (-Kp x_err - Kd dx_err) + N^T (-Kqp q_err - Kqd dq_err) + g_plus_c,
        divided by args['robot_actuator_strength'] because Brax scales actuator
        inputs by that strength.

        Args:
            x_err: 6-D task-space error.
            q_err: joint-space error.
            constants: (dx_err, dq_err, J, M_task, N, g_plus_c).

        Returns:
            1-D scaled torque vector.
        """
        dx_err, dq_err, J, M_task, N, g_plus_c = constants
        tau_task = J.T @ M_task @ (- self.Kp @ x_err.reshape(-1, 1) - self.Kd @ dx_err.reshape(-1, 1))
        tau_joint = N.T @ (- self.Kqp @ q_err.reshape(-1, 1) - self.Kqd @ dq_err.reshape(-1, 1))
        tau = tau_task + tau_joint + g_plus_c.reshape(-1, 1)
        # Seems scale the torque here is very important
        return tau.reshape(-1,) / self.args['robot_actuator_strength']

In [None]:
# Scene / simulation parameters. 'root' is the parent of the notebook's working
# directory; BraxKukaObstacle loads its URDF and Brax config relative to it.
sys_args = {'root': os.path.abspath(os.path.join(os.getcwd(), os.pardir)),

            # Robot joint / actuator overrides applied by load_robot().
            'robot_joint_stiffness': 0, # Use pbd simulation for the shelf picking env
            'robot_joint_spring_damping': 0.0, # Velocity between parent/child bodies
            'robot_joint_limit_strength': 0.0,
            'robot_joint_angular_damping': 50.0, # Control the angular velocity between parent/child bodies
            'robot_actuator_strength': 100.0, # Used to scale the input torque

            # Shelf geometry and obstacle settings. Only obstacle_rad and obstacle_fix
            # are read in this cell (the shelf_* keys and block_target are for load_shelf).
            'shelf_height': 0.6, 
            'shelf_length': 0.9, 
            'shelf_width': 0.9, 
            'shelf_thickness': 0.04,
            'obstacle_rad': 0.07,
            'obstacle_fix': True,
            'block_target': False,
}

# Target and obstacle positions used while building the scene.
target_position = [0.3, 0.4, 0.4]
obstacle_positions = [[0.15, 0.15, 0.75]]

sys_args['n_obstacle'] = len(obstacle_positions)
# NOTE(review): no np.random call is visible in this notebook; this seed may be unused.
np.random.seed(428)

# Build the Brax system: robot + frozen target + obstacles.
KukaShelf = BraxKukaObstacle(sys_args)
KukaShelf.load_robot()
KukaShelf.load_target_obstacles(n_obstacle=sys_args['n_obstacle'], obstacle_rad=sys_args['obstacle_rad'], 
                                obstacle_fix=sys_args['obstacle_fix'])

qp = KukaShelf.create_sys()
# Place target/obstacles, then shift all bodies from first_node_index on by [2, 0, 0]
# (presumably to keep them out of the way while the arm settles).
qp = KukaShelf.put_obstacle_and_target(qp, target_position, obstacle_positions, init_offset=[2, 0, 0])

# Control the robot to the rest pose: follow n_waypoints linearly interpolated joint
# targets, then keep tracking the last one for the remaining steps.
n_waypoints = 100
q, dq = KukaShelf.get_kuka_joints(qp)
waypoints = np.linspace(start=q, stop=KukaShelf.restPose, num=n_waypoints)
Kqp = jnp.diag(jnp.array([1, 1, 1, 1, 1, 1, 1])) * 2
Kqd = jnp.diag(jnp.array([1, 1, 1, 1, 1, 1, 1])) * 0.7
for step in range(300):
    next_goal = waypoints[min(step, n_waypoints - 1)]
    q, dq = KukaShelf.get_kuka_joints(qp)
    # Joint error is converted from rad to degrees before the gain is applied.
    err = jnp.rad2deg(next_goal - q).reshape(-1, 1)
    # NOTE(review): the velocity term is *added* (+Kqd @ dq); it only damps if dq's
    # sign convention is opposite to q's. Confirm against Brax's angle_vel.
    tau = Kqp @ err + Kqd @ dq.reshape(-1, 1)
    qp, info = KukaShelf.step(qp, tau.reshape(-1, ))

# Put the target/obstacles back at their positions (zero offset); keep this settled state.
qp = KukaShelf.put_obstacle_and_target(qp, target_position, obstacle_positions, init_offset=[0, 0, 0])
default_qp = qp

In [None]:
# Visualise the settled rest-pose scene (single frame).
rest_pose_html = html.render(KukaShelf.brax_sys, [default_qp], height=480)
HTML(rest_pose_html)

# CMA-ES optimizer

In [None]:
_EPS = 1e-8  # floor substituted for negative eigenvalues of C
_MEAN_MAX = 1e32
_SIGMA_MAX = 1e32  # upper clamp on the CMA-ES step size

@jax.jit
def _eigen_decomposition(C):
    """Eigendecompose a covariance matrix as C = B diag(D**2) B^T.

    C is symmetrised first. Eigenvalues below zero (numerical noise) are replaced
    by _EPS before taking the square root.

    Returns:
        (B, D): eigenvectors as columns of B, and square roots of the eigenvalues
        in the ascending order produced by `jnp.linalg.eigh`.
    """
    C = (C + C.T) / 2
    D2, B = jnp.linalg.eigh(C)
    D = jnp.sqrt(jnp.where(D2 < 0, _EPS, D2))
    return B, D

@jax.jit
def sample_white_gaussian(mean, sigma, B, D, z, n_dim):
    """Map a standard-normal vector z to a sample from N(mean, sigma^2 C).

    B and D are the factors of C = B diag(D**2) B^T (as produced by
    `_eigen_decomposition`). y = B diag(D) z is distributed N(0, C), and the result
    is mean + sigma * y. `n_dim` is accepted for interface compatibility only.
    """
    scale = B.dot(jnp.diag(D))  # factor A with A A^T = C
    return mean + sigma * scale.dot(z)

# One sample per row of z; mean, sigma, B, D and n_dim are shared across the batch.
batch_sample_white_gaussian = jax.jit(
    jax.vmap(sample_white_gaussian, in_axes=(None, None, None, None, 0, None))
)

class CMAES():
    """CMA-ES optimizer with an ask/tell interface and clipping-based repair.

    The eq. numbers in the comments appear to follow Hansen's "The CMA Evolution
    Strategy: A Tutorial" (NOTE(review): confirm). `ask` clips samples to
    [-bounds, bounds]. `tell` penalises each candidate by its distance to its
    clipped version.

    Args:
        mean: initial mean; its length is the problem dimension n_dim.
        sigma: initial step size.
        key: JAX PRNG key used by `ask`.
        bounds: symmetric box bound; samples are clipped to [-bounds, bounds].
        n_max_resampling: stored only; not used by this implementation.
    """

    def __init__(self, mean, sigma, key, bounds, n_max_resampling=100):
        n_dim = len(mean)
        population_size = int(4 + np.floor(3 * np.log(n_dim)))  # (eq. 48): lambda
        mu = population_size // 2  # number of top candidates used to update the mean

        # (eq. 49)
        tmp = jnp.log((population_size + 1) / 2)
        weights_prime = jnp.array([tmp - np.log(i + 1) for i in range(population_size)])
        mu_eff = (jnp.sum(weights_prime[:mu]) ** 2) / jnp.sum(weights_prime[:mu] ** 2)
        mu_eff_minus = (jnp.sum(weights_prime[mu:]) ** 2) / jnp.sum(weights_prime[mu:] ** 2)

        # learning rate for the rank-one update
        alpha_cov = 2
        c1 = alpha_cov / ((n_dim + 1.3) ** 2 + mu_eff)
        # learning rate for the rank-μ update
        cmu = min(1 - c1 - 1e-8,  # 1e-8 is for large popsize
                  alpha_cov * (mu_eff - 2 + 1 / mu_eff) / ((n_dim + 2) ** 2 + alpha_cov * mu_eff / 2))

        min_alpha = min(1 + c1 / cmu,  # eq.50
                        1 + (2 * mu_eff_minus) / (mu_eff + 2),  # eq.51
                        (1 - c1 - cmu) / (n_dim * cmu))  # eq.52

        # (eq.53) positive weights sum to 1, negative weights sum to -min_alpha
        positive_sum = jnp.sum(weights_prime[weights_prime > 0])
        negative_sum = jnp.sum(jnp.abs(weights_prime[weights_prime < 0]))
        weights = jnp.where(weights_prime >= 0,
                            1 / positive_sum * weights_prime,
                            min_alpha / negative_sum * weights_prime)
        cm = 1  # (eq. 54) learning rate for the mean

        # (eq.55) step-size control
        c_sigma = (mu_eff + 2) / (n_dim + mu_eff + 5)
        d_sigma = 1 + 2 * max(0, jnp.sqrt((mu_eff - 1) / (n_dim + 1)) - 1) + c_sigma

        # (eq.56) covariance evolution path
        cc = (4 + mu_eff / n_dim) / (n_dim + 4 + 2 * mu_eff / n_dim)

        self._n_dim = n_dim
        self._popsize = population_size
        self._mu = mu
        self._mu_eff = mu_eff

        self._cc = cc
        self._c1 = c1
        self._cmu = cmu
        self._c_sigma = c_sigma
        self._d_sigma = d_sigma
        self._cm = cm

        # E||N(0, I)|| (p.28)
        self._chi_n = jnp.sqrt(self._n_dim) * (1.0 - (1.0 / (4.0 * self._n_dim)) + 1.0 / (21.0 * (self._n_dim ** 2)))

        self._weights = weights

        # evolution paths
        self._p_sigma = jnp.zeros(self._n_dim)
        self._pc = jnp.zeros(self._n_dim)

        self._mean = mean
        self._C = jnp.eye(self._n_dim)
        self._sigma = sigma

        self._bounds = bounds
        self._n_max_resampling = n_max_resampling

        self._g = 0  # generation counter
        self._rng = key

        # Termination criteria
        self._tolx = 1e-5 * sigma
        self._tolxup = 1e5
        self._tolfun = 1e-5
        self._tolconditioncov = 1e14

        # Best/worst fitness of the last `_funhist_term` generations, interleaved.
        # (jnp.empty already returns zeros in JAX; zeros makes that explicit.)
        self._funhist_term = int(10 + jnp.ceil(30 * n_dim / population_size))
        self._funhist_values = jnp.zeros(self._funhist_term * 2)

    def ask(self):
        """Sample one generation.

        Returns:
            (xs, xs_repair): raw samples of shape (popsize, n_dim) and the same
            samples clipped to [-bounds, bounds].
        """
        B, D = _eigen_decomposition(self._C)
        self._rng, split = jax.random.split(self._rng)
        zs = jax.random.normal(split, shape=(self._popsize, self._n_dim))  # ~ N(0, I)
        xs = batch_sample_white_gaussian(self._mean, self._sigma, B, D, zs, self._n_dim)

        # repair the samples to the feasible set
        xs_repair = xs.clip(-self._bounds, self._bounds)
        return xs, xs_repair

    @staticmethod
    def _repair_penalty_weight(cost, dist, default_alpha):
        """Weight of the repair penalty `weight * ||x - x_repair||` (eq.62).

        Returns 0 when the sample was not repaired (dist <= 0). Otherwise it returns
        a power of ten putting the penalty at roughly the same order of magnitude
        as |cost|. It falls back to `default_alpha` when |cost| is zero or
        non-finite, where log10 is undefined.
        """
        if not dist > 0:
            return 0
        cost_mag = jnp.abs(cost)
        if not (jnp.isfinite(cost_mag) and cost_mag > 0):
            return default_alpha
        return 10 ** (int(jnp.log10(cost_mag)) - int(jnp.log10(dist)))

    def tell(self, repaired_solutions, alpha=1):
        """Update the search distribution from one evaluated generation.

        Args:
            repaired_solutions: iterable of (x, x_repair, cost): the raw sample from
                `ask`, its clipped version, and the cost evaluated at x_repair.
            alpha: fallback repair-penalty weight. It is used only for a repaired
                sample whose cost is zero or non-finite; otherwise the weight is
                chosen automatically by `_repair_penalty_weight`.
        """
        self._g += 1

        # (eq.62) penalise repaired samples by their distance to the feasible set
        solutions = []
        for x, x_repair, c in repaired_solutions:
            dist = jnp.linalg.norm(x - x_repair)
            solutions.append((x, c + self._repair_penalty_weight(c, dist, alpha) * dist))
        solutions.sort(key=lambda s: s[1])

        # Record best and worst fitness for the tolfun criterion. JAX arrays are
        # immutable: `.at[].set()` returns a new array, which must be stored.
        funhist_idx = 2 * (self._g % self._funhist_term)
        self._funhist_values = (self._funhist_values
                                .at[funhist_idx].set(solutions[0][1])
                                .at[funhist_idx + 1].set(solutions[-1][1]))

        B, D = _eigen_decomposition(self._C)

        x_k = jnp.array([s[0] for s in solutions])  # ~ N(m, σ^2 C)
        y_k = (x_k - self._mean) / self._sigma  # ~ N(0, C)

        # Selection and recombination (eq.41); rebinding avoids mutating the caller's mean
        y_w = jnp.sum(y_k[: self._mu].T * self._weights[: self._mu], axis=1)
        self._mean = self._mean + self._cm * self._sigma * y_w

        # Step-size control; C^(-1/2) = B D^(-1) B^T
        C_2 = B.dot(jnp.diag(1 / D)).dot(B.T)
        self._p_sigma = (1 - self._c_sigma) * self._p_sigma + jnp.sqrt(
            self._c_sigma * (2 - self._c_sigma) * self._mu_eff
        ) * C_2.dot(y_w)

        norm_p_sigma = jnp.linalg.norm(self._p_sigma)
        self._sigma = self._sigma * jnp.exp(
            (self._c_sigma / self._d_sigma) * (norm_p_sigma / self._chi_n - 1)
        )
        self._sigma = min(self._sigma, _SIGMA_MAX)

        # Covariance matrix adaptation
        h_sigma_cond_left = norm_p_sigma / jnp.sqrt(
            1 - (1 - self._c_sigma) ** (2 * (self._g + 1))
        )
        h_sigma_cond_right = (1.4 + 2 / (self._n_dim + 1)) * self._chi_n
        h_sigma = 1.0 if h_sigma_cond_left < h_sigma_cond_right else 0.0  # (p.28)

        # (eq.45)
        self._pc = (1 - self._cc) * self._pc + h_sigma * jnp.sqrt(
            self._cc * (2 - self._cc) * self._mu_eff
        ) * y_w

        # (eq.46) rescale negative weights by the Mahalanobis norm of their y
        w_io = self._weights * jnp.where(
            self._weights >= 0,
            1,
            self._n_dim / (jnp.linalg.norm(C_2.dot(y_k.T), axis=0) ** 2 + _EPS),
        )

        delta_h_sigma = (1 - h_sigma) * self._cc * (2 - self._cc)  # (p.28)

        # (eq.47); rank_mu = sum_i w_io[i] * outer(y_k[i], y_k[i])
        rank_one = jnp.outer(self._pc, self._pc)
        rank_mu = jnp.einsum('i,ij,ik->jk', w_io, y_k, y_k)
        self._C = ((1 + self._c1 * delta_h_sigma - self._c1 - self._cmu * jnp.sum(self._weights)) * self._C
                   + self._c1 * rank_one + self._cmu * rank_mu)

    def should_stop(self):
        """Return True if any termination criterion is met."""
        B, D = _eigen_decomposition(self._C)
        dC = jnp.diag(self._C)

        # Stop if the range of function values of the recent generations is below tolfun.
        if (
            self._g > self._funhist_term
            and jnp.max(self._funhist_values) - jnp.min(self._funhist_values)
            < self._tolfun
        ):
            return True

        # Stop if the std of the normal distribution is smaller than tolx
        # in all coordinates and pc is smaller than tolx in all components.
        if jnp.all(self._sigma * dC < self._tolx) and jnp.all(
            self._sigma * self._pc < self._tolx
        ):
            return True

        # Stop if detecting divergent behavior.
        if self._sigma * jnp.max(D) > self._tolxup:
            return True

        # No effect coordinates: stop if adding 0.2-standard deviations
        # in any single coordinate does not change m.
        if jnp.any(self._mean == self._mean + (0.2 * self._sigma * jnp.sqrt(dC))):
            return True

        # No effect axis: stop if adding 0.1-standard deviation vector in
        # any principal axis direction of C does not change m. One axis is
        # checked per generation, cycling through them.
        i = self._g % self._n_dim
        if jnp.all(self._mean == self._mean + (0.1 * self._sigma * D[i] * B[:, i])):
            return True

        # Stop if the condition number of the covariance matrix exceeds 1e14.
        condition_cov = jnp.max(D) / jnp.min(D)
        if condition_cov > self._tolconditioncov:
            return True

        return False

# Solve: model-predictive control with CMA-ES

In [None]:
def get_args_constants(state):
    """Precompute the controller constants for `state`.

    Returns:
        list [dx_err, dq_err, J, M_task, N, g_plus_c]. dx_err stacks the
        grasp-target linear and angular velocities; dq_err is the joint velocity.
        A list is returned so callers can append cost weights and the goal.
    """
    joint_pos, joint_vel = KukaShelf.get_kuka_joints(state)
    _, M_task, J, _, g_plus_c, N = KukaShelf.compute_robot_dynamics(joint_pos, joint_vel)
    _, _, eef_lin_vel, eef_ang_vel = KukaShelf.get_kuka_target_pose(state)
    return [jnp.hstack((eef_lin_vel, eef_ang_vel)), joint_vel, J, M_task, N, g_plus_c]

@jax.jit
def dynamics(state, u, args):
    """Advance the simulator one step under the operational-space controller.

    The first 6 entries of `u` are the task-space offset and the rest are the
    joint-space offset. `args[0:6]` are the controller constants
    (dx_err, dq_err, J, M_task, N, g_plus_c). Returns (next_state, info) from
    `KukaShelf.step`.
    """
    task_offset, joint_offset = u[:6], u[6:]
    controller_constants = args[0:6]
    torque = KukaShelf.compute_torque_from_eef_offset(task_offset, joint_offset, controller_constants)
    next_state, info = KukaShelf.step(state, torque)
    return next_state, info

# Batched over the leading axis of `u`; state and args are shared.
batch_dynamics = jax.jit(jax.vmap(dynamics, in_axes=(None, 0, None)))

@jax.jit
def cost(state, us, args):
    """Roll out the controls `us` (horizon x (6+7)) from `state` and score them.

    args[-2] holds the weights (contact, per-step goal distance, distance travelled)
    and args[-1] the goal position. Distances are multiplied by 100 before the norm.
    The result is the sum of the three weighted per-step means plus the final
    distance to the goal.
    """
    weights, goal_pos = args[-2], args[-1]
    eef_pos_init = KukaShelf.get_kuka_target_pose(state)[0]

    goal_dists, contact_sq_sums, travel_dists = [], [], []
    for u in us:
        state, info = dynamics(state, u, args)
        contact_force = info.contact.vel[1: KukaShelf.first_node_index]  # can clip with jnp.clip(f, -1, 1)
        eef_pos = KukaShelf.get_kuka_target_pose(state)[0]
        goal_dists.append(jnp.linalg.norm((eef_pos - goal_pos) * 100))
        contact_sq_sums.append(jnp.sum(jnp.square(contact_force)))
        travel_dists.append(jnp.linalg.norm((eef_pos - eef_pos_init) * 100))

    eef_pos_final = KukaShelf.get_kuka_target_pose(state)[0]
    dist_to_goal_end = jnp.linalg.norm((eef_pos_final - goal_pos) * 100)

    dist_to_goal_step = jnp.array(goal_dists).mean()
    contact_force_step = jnp.array(contact_sq_sums).mean()
    dist_so_far_step = jnp.array(travel_dists).mean()

    return weights[0] * contact_force_step + weights[1] * dist_to_goal_step + weights[2] * dist_so_far_step + dist_to_goal_end

# Cost of a batch of control sequences from one start state, and its gradient w.r.t. us.
batch_cost = jax.jit(jax.vmap(cost, in_axes=[None, 0, None]))
d_cost = jax.jit(jax.grad(cost, argnums=1))

def cmaes_mpc(key, mean, sigma, bound, max_step, horizon, state, args):
    """Optimise a control sequence over `horizon` steps with CMA-ES.

    Args:
        key: JAX PRNG key for the CMA-ES sampler.
        mean: initial CMA-ES mean (flattened horizon x control_dim).
        sigma: initial CMA-ES step size.
        bound: per-dimension symmetric bound; samples and the result are clipped to it.
        max_step: maximum number of CMA-ES generations.
        horizon: number of control steps per candidate.
        state: simulator state the rollouts start from.
        args: controller constants, cost weights and goal, forwarded to `batch_cost`.

    Returns:
        (u, cmaes_step): the clipped CMA-ES mean reshaped to (horizon, -1), and the
        index of the last generation run. If `max_step <= 0`, cmaes_step is -1 and
        u is the clipped initial mean.
    """
    solver = CMAES(mean, sigma, key, bound)
    cmaes_step = -1  # stays defined even when no generation runs
    for cmaes_step in range(max_step):
        xs, xs_repair = solver.ask()
        # Costs are evaluated on the clipped (feasible) samples.
        controls = xs_repair.reshape(solver._popsize, horizon, -1)
        costs = batch_cost(state, controls, args)
        solver.tell(list(zip(xs, xs_repair, costs)))
        if solver.should_stop():
            break
    u = solver._mean.clip(-bound, bound).reshape(horizon, -1)
    return u, cmaes_step

In [None]:
# Seed the JAX PRNG explicitly so the MPC rollout is reproducible on re-run
# (previously int(time.time())). Change SEED for a different random run.
SEED = 428
random_key = jax.random.PRNGKey(SEED)

robot_step = 0.005
state = default_qp

# Set the target and the obstacle positions
target_position = [0.3, 0.3, 0.2]
obstacle_positions = [[0.06, 0.15, 0.55]]
state = KukaShelf.put_obstacle_and_target(state, target_position, obstacle_positions)

states = [state]
logs = []

horizon = 3
weight_contact = 5  # weight for the contact force at each step
weight_dist_to_goal = 0  # weight for distance to goal at each step
weight_dist_so_far = 0  # weight for distance so far at each step, negative

goal_pos = state.pos[KukaShelf.target_idx]

# Each control is [6 task-space offsets | 7 joint-space offsets] (x_curr - x_target).
TASK_DIM, JOINT_DIM = 6, 7
control_dim = TASK_DIM + JOINT_DIM
mean = jnp.zeros(control_dim * horizon)
sigma = 0.2  # robot_step * 10
# Max offset per control entry (max distance the robot can move), tiled over the horizon.
bound = jnp.array([robot_step * 1.5] * TASK_DIM + [jnp.pi / 2] * JOINT_DIM)
bound = jnp.hstack([bound] * horizon)
max_step = 1000

KukaShelf.Kqp = jnp.diag(jnp.array([1, 1, 1, 1, 1, 1, 1])) * 10
KukaShelf.Kqd = jnp.diag(jnp.array([1, 1, 1, 1, 1, 1, 1]))

for step in range(200):
    # Controller constants for this state, plus cost weights and the goal.
    args = get_args_constants(state)  # [dx_err, dq_err, J, M_task, N, g_plus_c]
    args += [(weight_contact, weight_dist_to_goal, weight_dist_so_far), goal_pos]

    # Optimise the whole horizon with CMA-ES; apply only the first control (MPC).
    random_key, key = jax.random.split(random_key)
    u, cmaes_step = cmaes_mpc(key, mean, sigma, bound, max_step, horizon, state, args)

    state, info = dynamics(state, u[0], args)

    states.append(state)
    contact_force = info.contact.vel[1: KukaShelf.first_node_index]
    contact_force_magnitude = jnp.sum(jnp.square(contact_force))
    logs.append([contact_force_magnitude, cmaes_step, state, u, args])

In [None]:
# Visualise the full MPC rollout (one frame per simulated step).
mpc_rollout_html = html.render(KukaShelf.brax_sys, states, height=480)
HTML(mpc_rollout_html)