In [None]:
import os
import time
from copy import deepcopy
from tqdm import tqdm
from typing import Sequence, Any, cast, Dict, List, Optional, Tuple
from functools import partial
import json

import jax
from jax.lib import xla_bridge
import jax.numpy as jnp

import brax
from brax.io import file, html, image
from brax.physics import config_pb2

from google.protobuf import text_format

import gc
import numpy as np
from IPython.display import HTML, Image
from scipy.spatial.transform import Rotation as R
import matplotlib.pyplot as plt
import pybullet as p
from mpl_toolkits.mplot3d import Axes3D

# Report the active JAX backend through the public API and list the GPUs, if any.
print('Jax now using: {}'.format(jax.default_backend()))
try:
    devices = jax.devices('gpu')
except RuntimeError:
    # jax.devices('gpu') raises when no GPU backend is available; fall back to an
    # empty list so the notebook still runs top-to-bottom on CPU-only hosts.
    devices = []
print('Found {} GPU devices: {}'.format(len(devices), devices))

In [None]:
class BraxKukaObstacle():
    def __init__(self, args):
        """Store the settings and build a headless PyBullet copy of the Kuka arm.

        The PyBullet body is only queried by ``compute_robot_dynamics`` (mass
        matrix, Jacobian, inverse dynamics). The Brax config and system are
        created later by the ``load_*`` methods and ``create_sys``.

        Args:
            args: Settings dict. ``args['root']`` must point to the directory that
                contains ``gyms/robot/kuka_with_gripper_fix_brax.urdf``; the
                remaining keys are read by the loader and controller methods.
        """
        self.args = args
        self.brax_config = None
        self.brax_sys = None
        self.fruit_link_names = []

        # Setup PyBullet for controller compute
        self._physics_client_id = p.connect(p.DIRECT)
        # Address our own client explicitly: the default physicsClientId is 0, which
        # is a different client as soon as a second instance has been created.
        p.setRealTimeSimulation(False, physicsClientId=self._physics_client_id)
        robot_urdf = os.path.join(self.args['root'], "gyms/robot/kuka_with_gripper_fix_brax.urdf")
        flags = p.URDF_ENABLE_CACHED_GRAPHICS_SHAPES | p.URDF_USE_INERTIA_FROM_FILE | p.URDF_USE_SELF_COLLISION
        self.kukaUid = p.loadURDF(robot_urdf, basePosition=[0, 0, 0], baseOrientation=[0, 0, 0, 1],
                                  flags=flags, physicsClientId=self._physics_client_id)
        self.numJoints = p.getNumJoints(self.kukaUid, physicsClientId=self._physics_client_id)

        # Map revolute joint index -> joint name, and remember the index of the
        # non-revolute joint named 'grasp_target'.
        self.armJointsInfo = {}
        self.graspTargetLink = None  # stays None when the URDF has no 'grasp_target' joint
        for i in range(self.numJoints):
            info = p.getJointInfo(self.kukaUid, i, physicsClientId=self._physics_client_id)
            if info[2] == p.JOINT_REVOLUTE:
                self.armJointsInfo[i] = info[1]
            elif info[1].decode('utf-8') == 'grasp_target':
                self.graspTargetLink = i # Index of the grasp target link
        self.armJoints = list(self.armJointsInfo.keys())

        # Gains used by compute_torque_from_eef_offset: Kp/Kd multiply the 6-D
        # end-effector error and its rate, Kqp/Kqd the 7-D joint error and joint rate.
        self.Kp = jnp.diag(jnp.array([2450, 3993, 2937, 5778, 5663, 4219]))
        self.Kd = jnp.diag(jnp.array([38, 75, 45, 43, 47, 64]))
        self.Kqp = jnp.diag(jnp.array([50, 50, 50, 50, 50, 50, 50]))
        self.Kqd = jnp.diag(jnp.array([8, 8, 8, 8, 8, 8, 8]))
        # Joint configuration used as the reference (rest) posture.
        # self.restPose = jnp.array([0.006418, 0.413184, -0.011401, -1.589317, 0.005379, 1.137684, -0.006539])
        self.restPose = jnp.array([0.0, 0.0, 0.0, -1.5708, 0.0, 1.5708, 0.0])

    ###################################################################################
    ###################################### Robot ######################################
    ###################################################################################

    def load_robot(self, complexity='capsule'):
        assert complexity in ['exact', 'simple', 'capsule'], f'Robot complexity {complexity} not support'

        # Read the config file for the robot
        config_file = os.path.join(self.args['root'], f'brax_config/kuka_{complexity}_config.txt')
        with open(config_file, 'r') as f:
            data = f.readlines()
        _KUKA_CONFIG = ''.join(data)
        self.brax_config = text_format.Parse(_KUKA_CONFIG, brax.Config())

        # Change the parameters for the robot
        for joint in self.brax_config.joints:
            joint.stiffness = self.args['robot_joint_stiffness']
            joint.angular_damping = self.args['robot_joint_angular_damping']
            joint.spring_damping = self.args['robot_joint_spring_damping']
            joint.limit_strength = self.args['robot_joint_limit_strength']

        for actuator in self.brax_config.actuators:
            actuator.strength = self.args['robot_actuator_strength']
        
        # Get some useful parameters
        self.kuka_bodies = [body.name for body in self.brax_config.bodies][1:]
        self.first_node_index = 9

    ###################################################################################
    ###################################### Ball #######################################
    ###################################################################################

    def get_ball_obstacle(self, rad=0.15, n_fix=10):
        assert rad in [0.1, 0.15, 0.2, 0.3], f'Ball dimension {rad} not support, need to generate with GMsh'

        # Read the .mesh file, compute ball vertices and unique edges
        mesh_file = os.path.join(self.args['root'], f'ball/ball_{int(rad*100)}cm.mesh') 
        with open(mesh_file, "r") as f:
            mesh_lines = f.readlines()

        mesh_lines = [line.strip('\n') for line in mesh_lines]

        vertices_start = mesh_lines.index(' Vertices')
        num_vertices = mesh_lines[vertices_start + 1]
        vertices = mesh_lines[vertices_start + 2: vertices_start + 2 + int(num_vertices)]
        vertices = np.array([[float(i) for i in vertice.split()[:-1]] for vertice in vertices])

        # Get all edges, triangles, and tetrahedras
        try:
            edge_start = mesh_lines.index(' Edges')
            num_edge = mesh_lines[edge_start + 1]
            edges = mesh_lines[edge_start + 2 : edge_start + 2 + int(num_edge)]
            edges = np.array([[int(i) for i in edge.split()[:-1]] for edge in edges])
        except:
            edges = np.empty((0, 2)).astype(int)
            print("No edge found")

        try:
            triangle_start = mesh_lines.index(' Triangles')
            num_triangle = mesh_lines[triangle_start + 1]
            triangles = mesh_lines[triangle_start + 2 : triangle_start + 2 + int(num_triangle)]
            triangles = np.array([[int(i) for i in triangle.split()[:-1]] for triangle in triangles])
        except:
            triangles = np.empty((0, 3)).astype(int)
            print("No triangle found")

        tetrahedra_start = mesh_lines.index(' Tetrahedra')
        num_tetrahedra = mesh_lines[tetrahedra_start + 1]
        tetrahedras = mesh_lines[tetrahedra_start + 2 : tetrahedra_start + 2 + int(num_tetrahedra)]
        tetrahedras = np.array([[int(i) for i in tetrahedra.split()[:-1]] for tetrahedra in tetrahedras])

        # Gather all edges
        connections = np.vstack((edges,
                                 np.vstack((triangles[:, 0], triangles[:, 1])).T,
                                 np.vstack((triangles[:, 0], triangles[:, 2])).T,
                                 np.vstack((triangles[:, 1], triangles[:, 2])).T,
                                 np.vstack((tetrahedras[:, 0], tetrahedras[:, 1])).T,
                                 np.vstack((tetrahedras[:, 0], tetrahedras[:, 2])).T,
                                 np.vstack((tetrahedras[:, 0], tetrahedras[:, 3])).T,
                                 np.vstack((tetrahedras[:, 1], tetrahedras[:, 2])).T,
                                 np.vstack((tetrahedras[:, 1], tetrahedras[:, 3])).T,
                                 np.vstack((tetrahedras[:, 2], tetrahedras[:, 3])).T,
                                 ))
        connections.sort(axis=1)
        tmp = [tuple(connection) for connection in connections]
        connections = np.unique(tmp, axis=0)

        dist = np.linalg.norm(vertices, axis=1)
        center_node_idxs = np.argpartition(dist, n_fix)[:n_fix].tolist()
        surface_node_idxs = np.where(np.isclose(dist, rad, atol=1e-3))[0]

        return vertices, connections, center_node_idxs, surface_node_idxs

    def load_ball_obstacle(self, vertices, connections, center_node_idxs=[], surface_node_idxs=[]):
        if self.brax_config is None:
            print('No robot loaded, only load the obstacle')
            self.brax_config = brax.Config(dt=0.01, substeps=10, dynamics_mode='legacy_spring')
            ground = self.brax_config.bodies.add(name='floor')
            ground.frozen.all = True
            plane = ground.colliders.add().plane
            plane.SetInParent()
            self.first_node_index = 1
            surface_node_idxs = [] # No need to enable collisions

        # Add all vertices
        for i in range(vertices.shape[0]):
            node = self.brax_config.bodies.add(name=f'node_{i+1}', mass=self.args['obstacle_node_mass'])
            cap = node.colliders.add().capsule
            cap.radius = self.args['obstacle_node_rad']
            cap.length = 2 * self.args['obstacle_node_rad']

            # Fix center nodes 
            if i in center_node_idxs: 
                node.frozen.all = True
            
            # Allow collisions between surface nodes and the robot
            if i in surface_node_idxs:
                for kuka_body in self.kuka_bodies[2:]:
                    self.brax_config.collide_include.add(first=kuka_body, second=f'node_{i+1}')

        # Add all edges
        for i, connection in enumerate(connections):
            parent, child = connection
            joint = self.brax_config.joints.add(name=f'edge_{i+1}',
                                                parent=f'node_{parent}',
                                                child=f'node_{child}', 
                                                angular_damping=self.args['obstacle_joint_angular_damping'],
                                                stiffness=self.args['obstacle_joint_stiffness'], 
                                                spring_damping=self.args['obstacle_joint_spring_damping'])
            joint.angle_limit.add(min=-180, max=180)
            joint.angle_limit.add(min=-180, max=180)
            joint.angle_limit.add(min=-180, max=180)
            offset = vertices[parent-1] - vertices[child-1] # connections are index from 1
            joint.child_offset.x = offset[0]
            joint.child_offset.y = offset[1]
            joint.child_offset.z = offset[2]

    ###################################################################################
    #################################### Strawberry ###################################
    ###################################################################################

    def get_strawberry(self, radius=0.005, link_height=0.05, mass=0.5, root_pos=[0, 0, 0]):
        link_infos = {}
        
        # Sample layers and directions for each branch
        n_layer = np.random.randint(2, 5)
        n_stem = np.random.randint(3, 5)
        n_branch = 4 * n_layer + n_stem

        init_rz = np.random.uniform(-np.pi / 4, np.pi / 4)
        rxs = np.deg2rad(np.linspace(15, 65, n_branch) + np.random.uniform(-5, 5, n_branch))

        # Select non-bottom branches to add fruits
        idx_branch_for_fruits = np.random.choice(n_branch - 4, (n_stem,), replace=False)

        # Sample properties for each branch, from the most vertical one
        for idx_branch in range(n_branch):
            layer = int(idx_branch / 4)
            n_links = np.random.randint(4 - 0.5 * layer, 7 - 0.5 * layer)
            length = link_height * n_links

            # Tilt down angle from vertical
            rx = rxs[idx_branch] 
            # Rotation among world's Z axis
            rz = init_rz + idx_branch * np.pi / 2 + np.deg2rad(np.random.uniform(-10, 10))
            rot = [rx, 0.0, rz]
            
            # Add fruits to the stems
            add_fruit = True if idx_branch in idx_branch_for_fruits else False

            link_infos[idx_branch] = [mass, radius, link_height, n_links, root_pos, rot, add_fruit]

        return link_infos

    def create_strawberry_branch(self, config, idx_branch, info, joint_stiffness, joint_spring_damping, 
                                 joint_angular_damping, joint_limit_strength, fruit_rad=None, fruit_mass=None,
                                 leaf_angle=60, leaf_rad=0.01, leaf_height=0.07, leaf_mass=None):
        mass, radius, link_height, n_links, root_pos, rot, add_fruit = info
        fruit_rad = 3 * radius if fruit_rad is None else fruit_rad
        fruit_mass = 4 * mass if fruit_mass is None else fruit_mass
        leaf_mass = mass / 6 if leaf_mass is None else leaf_mass

        # Compute the angle of the branch
        parent = 'crown'
        rotation = R.from_euler('xyz', rot)
        link_rotation = np.rad2deg(rotation.as_euler('zyx'))
        joint_offset = link_height * np.array([0, 0, 1])

        # Add all links
        n_links += 1 # Add addition link for fruit or leaf

        for idx_link in range(n_links):
            child = f'branch_{idx_branch}_link_{idx_link}'

            # Create a link
            link = config.bodies.add(name=child, mass=mass)
            link.inertia.x = 0.1
            link.inertia.y = 0.1
            link.inertia.z = 0.1
            collider = link.colliders.add()
            collider.capsule.radius = radius
            collider.capsule.length = link_height + 2 * radius

            # Create joint to connect parent and child
            joint = config.joints.add(name=f'branch_{idx_branch}_joint_{idx_link}', parent=parent,
                                      child=child, angular_damping=joint_angular_damping,
                                      stiffness=joint_stiffness, spring_damping=joint_spring_damping,
                                      limit_strength=joint_limit_strength)
            joint.angle_limit.add(min=0, max=0)
            joint.angle_limit.add(min=0, max=0)
            joint.angle_limit.add(min=0, max=0)

            if parent == 'crown':
                link.frozen.all = True
                joint.child_offset.x = 0
                joint.child_offset.y = 0
                joint.child_offset.z = - joint_offset[2] * 0.5
                joint.reference_rotation.x = link_rotation[2]
                joint.reference_rotation.y = link_rotation[1]
                joint.reference_rotation.z = link_rotation[0]
            else:
                joint.child_offset.x = - joint_offset[0] * 0.5
                joint.child_offset.y = - joint_offset[1] * 0.5
                joint.child_offset.z = - joint_offset[2] * 0.5
                joint.parent_offset.x = joint_offset[0] * 0.5
                joint.parent_offset.y = joint_offset[1] * 0.5
                joint.parent_offset.z = joint_offset[2] * 0.5
            
            # Add leaf or fruit at the end of the branch
            if idx_link == n_links - 1:
                if add_fruit: # The fruit branch
                    link.mass = fruit_mass
                    collider.capsule.radius = fruit_rad
                    collider.capsule.length = 2 * fruit_rad
                    collider.color = '#e71c25'
                    joint.child_offset.x = 0
                    joint.child_offset.y = 0
                    joint.child_offset.z = 0
                    self.fruit_link_names.append(child)
                else: # The leaf branch
                    link.mass = fruit_mass if n_links <= 5 else fruit_mass / 2
                    collider.capsule.radius = leaf_rad
                    collider.capsule.length = leaf_height * 2
                    collider.color = '#466336'

                    collider = link.colliders.add()
                    collider.capsule.radius = leaf_rad
                    collider.capsule.length = leaf_height
                    collider.rotation.y = leaf_angle
                    collider.position.x = leaf_height * np.cos(np.rad2deg(leaf_angle)) / 2
                    collider.position.z = - leaf_height * np.sin(np.rad2deg(leaf_angle)) / 2
                    collider.color = '#466336'

                    collider = link.colliders.add()
                    collider.capsule.radius = leaf_rad
                    collider.capsule.length = leaf_height
                    collider.rotation.y = - leaf_angle
                    collider.position.x = - leaf_height * np.cos(np.rad2deg(leaf_angle)) / 2
                    collider.position.z = - leaf_height * np.sin(np.rad2deg(leaf_angle)) / 2
                    collider.color = '#466336'

            parent = child

        return config

    def load_strawberry(self, link_infos=None, joint_stiffness=30000, joint_angular_damping=1, 
                        joint_spring_damping=20.0, joint_limit_strength=20.0,
                        first_link_in_col=3):
        if link_infos is None:
            link_infos = self.get_strawberry(radius=0.005, link_height=0.05, mass=0.5, root_pos=[0, 0, 0])
        
        # Add the root crown
        crown = self.brax_config.bodies.add(name='crown', mass=1)
        crown_cap = crown.colliders.add().capsule
        crown_cap.radius = 0.02
        crown_cap.length = 2 * crown_cap.radius
        crown.frozen.all = True

        # Add all branches
        for idx_branch, info in link_infos.items():
            self.brax_config = self.create_strawberry_branch(self.brax_config, idx_branch, info, joint_stiffness, 
                                                             joint_spring_damping, joint_angular_damping, 
                                                             joint_limit_strength)

        # Enable robot-strawberry collisions
        for body in self.brax_config.bodies:
            if 'branch' in body.name:
                idx_link = int(body.name.split('_')[-1])
                if idx_link > 0:
                    for robot_link in self.kuka_bodies[first_link_in_col:]:
                        c = self.brax_config.collide_include.add()
                        c.first = robot_link
                        c.second = body.name

        # Log the fruit link position
        self.fruit_info = {}
        for i, b in enumerate(self.brax_config.bodies):
            if b.name in self.fruit_link_names:
                self.fruit_info[i] = b.name

    ###################################################################################
    ######################### Dynamics & Controller & Helpers #########################
    ###################################################################################

    def create_sys(self):
        assert self.brax_config, 'No config exist yet'
        self.brax_sys = brax.System(self.brax_config)
        
        # Jit some useful functions
        self.step = jax.jit(self.brax_sys.step)
        if self.first_node_index == 9:
            self.get_kuka_joints = jax.jit(self._get_kuka_joints) # in rad
            self.get_kuka_target_pose = jax.jit(self._get_kuka_target_pose)
        if len(self.fruit_link_names) > 0:
            self.get_fruit_pose = jax.jit(self._get_fruit_pose)

        qp = self.brax_sys.default_qp()
        return qp
    
    def _get_kuka_joints(self, qp):
        joint_angle, joint_vel = self.brax_sys.joints[0].angle_vel(qp)
        return joint_angle, joint_vel
    
    def _get_kuka_target_pose(self, qp):
        """Return position, quaternion and velocities of the grasp target for state `qp`.

        The grasp target is placed 0.2975 along the local Z axis of the last robot
        body (index ``first_node_index - 1``).

        Returns:
            (position, quaternion in (x, y, z, w), linear velocity, angular velocity).
            Both velocities are those of the last robot body, not of the grasp
            target itself.
        """
        # GT pose of the grasp target in PyBullet at [0] * 7
        # [7.89874410e-13 2.97529613e-07 1.55850000e+00]
        # [1.11022302e-16 2.80692743e-16 0.00000000e+00 1.00000000e+00]
        last_body = self.first_node_index - 1
        body_pos = qp.pos[last_body]
        body_quat = qp.rot[last_body] # in (w, x, y, z)

        z_axis = _quat2matrix(body_quat, z_only=True)
        grasp_target_pos = body_pos + 0.2975 * z_axis
        grasp_target_quat = jnp.roll(body_quat, -1) # in (x, y, z, w)

        # use link 7 velocity to represent grasp target velocity, angular velocity is wrong
        return grasp_target_pos, grasp_target_quat, qp.vel[last_body], qp.ang[last_body]

    def _get_fruit_pose(self, qp):
        fruits_pos = qp.pos[list(self.fruit_info.keys()), :]
        fruits_quat = qp.rot[list(self.fruit_info.keys()), :] # in (w, x, y, z)
        fruits_quat = jnp.roll(fruits_quat, -1, axis=1) # in (x, y, z, w)
        return fruits_pos, fruits_quat
        
    def compute_robot_dynamics(self, q, dq):
        """Query PyBullet for the dynamics terms used by the operational-space controller.

        Args:
            q: Joint positions (list, or any object with ``tolist()``).
            dq: Joint velocities (list, or any object with ``tolist()``).

        Returns:
            Tuple ``(M, M_task, J, J_inv, g_plus_c, N)``: joint-space mass matrix,
            task-space inertia ``pinv(J M^-1 J^T)``, stacked Jacobian of the grasp
            target, inertia-weighted Jacobian inverse ``M^-1 J^T M_task``,
            inverse-dynamics torques at zero acceleration, and null-space
            projector ``I - J_inv J``.
        """
        if not isinstance(q, list):
            q = q.tolist()
        if not isinstance(dq, list):
            dq = dq.tolist()
        zero = [0] * len(q)
        client = self._physics_client_id
        body = self.kukaUid

        # Compute inertial
        M = jnp.array(p.calculateMassMatrix(bodyUniqueId=body, objPositions=q, physicsClientId=client))

        # Compute jacobian; PyBullet returns the linear and angular parts separately
        jacobian_parts = p.calculateJacobian(bodyUniqueId=body, linkIndex=self.graspTargetLink,
                                             localPosition=[0, 0, 0], objPositions=q, objVelocities=zero,
                                             objAccelerations=zero, physicsClientId=client)
        J = jnp.vstack(jacobian_parts)

        # Compute drift
        g_plus_c = jnp.asarray(p.calculateInverseDynamics(bodyUniqueId=body, objPositions=q,
                                                          objVelocities=dq, objAccelerations=zero,
                                                          physicsClientId=client))

        # Compute task space inertial
        M_inv = jnp.linalg.inv(M)
        M_task = np.linalg.pinv(J @ M_inv @ J.T)

        # Compute jacobian inverse
        J_inv = M_inv @ J.T @ M_task

        # Compute the full nullspace
        N = jnp.eye(len(self.armJoints)) - J_inv @ J
        return M, M_task, J, J_inv, g_plus_c, N

    def compute_torque_from_eef_offset(self, x_err, constants):
        dx_err, q_err, dq_err, J, M_task, N, g_plus_c = constants
        tau_task = J.T @ M_task @ (- self.Kp @ x_err.reshape(-1, 1) - self.Kd @ dx_err.reshape(-1, 1))
        tau_joint = N.T @ (- self.Kqp @ q_err.reshape(-1, 1) - self.Kqd @ dq_err.reshape(-1, 1))
        tau = tau_task + tau_joint + g_plus_c.reshape(-1, 1)
        # turns out scale the torque here is very important
        return tau.reshape(-1,) / self.args['robot_actuator_strength']

def _quat2matrix(quat, z_only=False, wxyz=True):
    """Convert a quaternion (not necessarily unit length) to a rotation matrix.

    Args:
        quat: Four components, ordered (w, x, y, z) when `wxyz` is true and
            (x, y, z, w) otherwise.
        z_only: When true, return only the third column of the matrix.
        wxyz: Component order of `quat`.

    Returns:
        A 3x3 array, or a length-3 array when `z_only` is true.
    """
    if wxyz:
        w, x, y, z = quat
    else:
        x, y, z, w = quat
    # 2 / |q|^2; the division normalises non-unit quaternions
    two_s = 2 * (1 / (jnp.linalg.norm(quat) ** 2))

    z_col = [two_s * (x * z + y * w),
             two_s * (y * z - x * w),
             1 - two_s * (x * x + y * y)]
    if z_only:
        return jnp.array(z_col)

    x_col = [1 - two_s * (y * y + z * z),
             two_s * (x * y + z * w),
             two_s * (x * z - y * w)]
    y_col = [two_s * (x * y - z * w),
             1 - two_s * (x * x + z * z),
             two_s * (y * z + x * w)]

    return jnp.array([[x_col[row], y_col[row], z_col[row]] for row in range(3)])


# Jitted single and batched variants; both use the default (w, x, y, z) order
quat2matrix = jax.jit(_quat2matrix)
batch_quat2matrix = jax.jit(jax.vmap(quat2matrix, in_axes=(0,)))

@jax.jit
def euler2matrix(angles):
    """Build the 4x4 homogeneous transform ``Ry(y) @ Rx(x)`` with zero translation.

    Args:
        angles: Pair ``(x, y)`` of rotation angles in radians about the X and Y axes.

    Returns:
        4x4 array whose upper-left 3x3 block is the rotation.
    """
    x, y = angles
    cos_x, sin_x = jnp.cos(x), jnp.sin(x)
    cos_y, sin_y = jnp.cos(y), jnp.sin(y)

    Rx = jnp.array([[1, 0, 0],
                    [0, cos_x, -sin_x],
                    [0, sin_x, cos_x]])
    Ry = jnp.array([[cos_y, 0, sin_y],
                    [0, 1, 0],
                    [-sin_y, 0, cos_y]])

    # Pad the rotation with a zero translation column and the [0, 0, 0, 1] row
    top_rows = jnp.hstack((Ry @ Rx, jnp.array([0, 0, 0]).reshape(3, 1)))
    return jnp.vstack((top_rows, jnp.array([0, 0, 0, 1]).reshape(4,)))

batch_euler2matrix = jax.jit(jax.vmap(euler2matrix, in_axes=(0,)))

# Robot + Strawberry

## Initialize the robot and the environment

In [None]:
# Hand-specified plant: one entry per branch, in the format consumed by
# BraxKukaObstacle.load_strawberry.
mass = 0.5
radius = 0.005
link_height = 0.05
root_pos = [0, 0, 0]

n_links = [3, 2, 3, 2, 4, 5, 3, 4, 3, 5, 4, 5, 4, 4, 4]
# Only the last three branches carry fruit
add_fruits = [False] * 12 + [True] * 3
# Per-branch Euler angles in degrees; converted to radians below
rots = [[60, 0, 0], [60, 0, 85], [70, 0, 195], [70, 0, 270],
        [40, 0, 0], [37, 0, 75], [57, 0, 170], [55, 0, 265],
        [25, 0, 0], [15, 0, 90], [40, 0, 190], [30, 0, 260],
        [45, 0, -70], [70, 0, 150], [45, 0, 225]]
rots = np.deg2rad(rots).tolist()

link_infos = {}
for i, (n_link, rot, add_fruit) in enumerate(zip(n_links, rots, add_fruits)):
    # Turn the whole plant by 45 degrees about Z
    rot[2] += np.pi / 4
    link_infos[i] = [mass, radius, link_height, n_link, root_pos, rot, add_fruit]

In [None]:
# Settings consumed by BraxKukaObstacle. 'root' is the parent of the current
# working directory; the robot_* values are written onto the joints/actuators
# of the loaded robot config by load_robot.
sys_args = {'root': os.path.abspath(os.path.join(os.getcwd(), os.pardir)),
        
            'robot_joint_stiffness': 80000.0, # Make the robot very stiff by design
            'robot_joint_spring_damping': 800.0, # Velocity between parent/child bodies
            'robot_joint_limit_strength': 300.0,
            'robot_joint_angular_damping': 50.0, # Control the angular velocity between parent/child bodies
            'robot_actuator_strength': 100.0, # Used to scale the input torque

            'link_infos': link_infos, # link_infos
            'plant_root': jnp.array([0.4, 0, 0.0]),
}

# Fixed seeds for JAX and NumPy
random_key = jax.random.PRNGKey(0)
np.random.seed(428)

# Build the scene: robot first, then the plant
KukaStrawberry = BraxKukaObstacle(sys_args)
KukaStrawberry.load_robot()
KukaStrawberry.load_strawberry(link_infos=sys_args['link_infos'])

# Shift every body from first_node_index onwards (the plant) to the plant root.
# NOTE(review): the in-place `+=` requires qp.pos to be a mutable (NumPy) array;
# a jax array would not support item assignment. Confirm what default_qp returns.
qp = KukaStrawberry.create_sys()
qp.pos[KukaStrawberry.first_node_index: ] += sys_args['plant_root']

# Control the robot to the rest pose
# The joint target moves linearly from the initial joint angles to restPose over
# the first n_waypoints steps and then stays at restPose for the remaining steps.
n_waypoints = 100
q, dq = KukaStrawberry.get_kuka_joints(qp)
waypoints = np.linspace(start=q, stop=KukaStrawberry.restPose, num=n_waypoints)
Kqp = jnp.diag(jnp.array([1, 1, 1, 1, 1, 1, 1])) * 2.0
Kqd = jnp.diag(jnp.array([1, 1, 1, 1, 1, 1, 1])) * 0.5
for step in range(2000):
    next_goal = waypoints[min(step, n_waypoints - 1)]
    q, dq = KukaStrawberry.get_kuka_joints(qp)
    # Joint error is converted to degrees before applying the gain
    err = jnp.rad2deg(next_goal - q).reshape(-1, 1)
    # NOTE(review): the velocity term is added rather than subtracted; confirm this
    # matches the sign convention of dq returned by get_kuka_joints.
    tau = Kqp @ err + Kqd @ dq.reshape(-1, 1)
    qp, info = KukaStrawberry.step(qp, tau.reshape(-1, ))

# Settled state used as the starting point by the planning cells
default_qp = qp

In [None]:
# Render the settled rest-pose state with the Brax HTML viewer
HTML(html.render(KukaStrawberry.brax_sys, [default_qp], height=480))

## Some useful functions

In [None]:
def get_args_constants(state):
    """Collect the controller quantities that stay fixed while evaluating one planning step.

    Args:
        state: Current Brax state (qp).

    Returns:
        List ``[dx_err, q_err, dq_err, J, M_task, N, g_plus_c]`` where ``dx_err``
        stacks the end-effector linear and angular velocity, ``q_err`` is the joint
        offset from the rest pose and ``dq_err`` is the joint velocity.
    """
    q, dq = KukaStrawberry.get_kuka_joints(state)
    _, M_task, J, _, g_plus_c, N = KukaStrawberry.compute_robot_dynamics(q, dq)
    _, _, eef_vel, eef_ang = KukaStrawberry.get_kuka_target_pose(state)

    eef_twist = jnp.hstack((eef_vel, eef_ang))
    rest_offset = q - KukaStrawberry.restPose
    return [eef_twist, rest_offset, dq, J, M_task, N, g_plus_c]

@jax.jit
def dynamics(state, u, args):
    """Apply one control `u` (an end-effector error) and advance the simulation by one step.

    Args:
        state: Current Brax state (qp).
        u: End-effector error passed to the torque controller.
        args: List whose first seven entries are the controller constants.

    Returns:
        ``(next_state, info)`` as returned by the jitted Brax step.
    """
    controller_constants = args[0:7]
    torque = KukaStrawberry.compute_torque_from_eef_offset(u, controller_constants)
    # Controller gains were tuned for repeating the same action over 3 substeps;
    # a single step is taken here.
    next_state, step_info = KukaStrawberry.step(state, torque)
    return next_state, step_info

# Same state and args, a batch of controls
batch_dynamics = jax.jit(jax.vmap(dynamics, in_axes=(None, 0, None)))

@jax.jit
def cost(state, us, args):
    """Roll out a control sequence and score it.

    Args:
        state: Start state (qp).
        us: Control sequence, one row per step. Rows with three columns are padded
            with three zeros.
        args: Controller constants followed by ``[weights, goal_pos]``.
            ``weights`` is indexed at 0, 1 and 2 (contact, distance to goal,
            distance travelled), so it must have at least three entries.

    Returns:
        ``weights[0] * mean contact cost + weights[1] * mean distance to goal
        + weights[2] * mean distance travelled + final distance to goal``.
        Distances are scaled by 100.
    """
    if us.shape[1] == 3:
        us = jnp.hstack((us, jnp.zeros_like(us)))

    goal = args[-1]
    start_pos = KukaStrawberry.get_kuka_target_pose(state)[0]

    # step the system, recording the per-step terms
    goal_dists, contact_costs, travelled = [], [], []
    for u in us:
        state, info = dynamics(state, u, args)
        contact_force = info.contact.vel[1: KukaStrawberry.first_node_index] # can clip with jnp.clip(f, -1, 1)
        eef_pos = KukaStrawberry.get_kuka_target_pose(state)[0]
        goal_dists.append(jnp.linalg.norm((eef_pos - goal) * 100))
        contact_costs.append(jnp.sum(jnp.square(contact_force)))
        travelled.append(jnp.linalg.norm((eef_pos - start_pos) * 100))
    end_pos = KukaStrawberry.get_kuka_target_pose(state)[0]

    # compute cost
    dist_to_goal_end = jnp.linalg.norm((end_pos - goal) * 100)
    dist_to_goal_step = jnp.array(goal_dists).mean()
    contact_force_step = jnp.array(contact_costs).mean()
    dist_so_far_step = jnp.array(travelled).mean()

    weights = args[-2]
    return weights[0] * contact_force_step + weights[1] * dist_to_goal_step + weights[2] * dist_so_far_step + dist_to_goal_end

# Batched over candidate control sequences, and gradient w.r.t. the controls
batch_cost = jax.jit(jax.vmap(cost, in_axes=(None, 0, None)))
d_cost = jax.jit(jax.grad(cost, argnums=1))

def sample_translate_in_ball(random_key, rad=0.02, d=3, n_sample=32, horizon=3):
    """Sample translations uniformly from a `d`-dimensional ball of radius `rad`.

    Returns:
        Array of shape (n_sample, horizon, d).
    """
    unit_ball_samples = jax.random.ball(random_key, d=d, shape=(n_sample, horizon))
    return unit_ball_samples * rad

def random_shooting(cost, state, u, args, n_best=10):
    """Score all candidate control sequences and average the `n_best` cheapest ones.

    Args:
        cost: Batched cost function called as ``cost(state, u, args)``.
        state: Current state.
        u: Candidate control sequences, first axis indexes the candidate.
        args: Extra arguments forwarded to `cost`.
        n_best: Number of lowest-cost candidates to average.

    Returns:
        ``(mean of the selected candidates, mean of their costs)``.
    """
    costs = cost(state, u, args)
    elite = jnp.argsort(costs)[:n_best]
    elite_mean = jnp.mean(u[elite, :], axis=0)
    return elite_mean, jnp.mean(costs[elite])

def gradient_descent(d_cost, batch_cost, state, u, args, alphas, n_best_grad=1):
    """Take one gradient step on `u`, picking the step size by line search over `alphas`.

    Args:
        d_cost: Gradient of the cost w.r.t. the controls, called as ``d_cost(state, u, args)``.
        batch_cost: Batched cost, called as ``batch_cost(state, candidates, args)``.
        state: Current state.
        u: 2-D control sequence to refine.
        args: Extra arguments forwarded to both functions.
        alphas: 1-D array of candidate step sizes.
        n_best_grad: Number of lowest-cost candidates to average.

    Returns:
        ``(mean of the selected candidates, mean of their costs)``.
    """
    grad = d_cost(state, u, args)
    n_alpha = len(alphas)

    # One copy of u and of the gradient per step size
    tiled_u = jnp.repeat(u[np.newaxis, :, :], n_alpha, axis=0)
    tiled_grad = jnp.repeat(grad[np.newaxis, :, :], n_alpha, axis=0)
    candidates = tiled_u - alphas[:, None, None] * tiled_grad

    costs = batch_cost(state, candidates, args)
    elite = jnp.argsort(costs)[:n_best_grad]
    refined_u = jnp.mean(candidates[elite, :], axis=0)
    return refined_u, costs[elite].mean()

## Motion Planning by Random Shooting

In [None]:
# Seeded from the wall clock, so the sampled controls differ between runs
random_key = jax.random.PRNGKey(int(time.time()))

robot_step = 0.005 # radius of the ball the translations are sampled from
state = default_qp
states = [state]
logs = []

horizon = 3
n_translate = 10 # number of bins for translation of each step
n_best = 10
n_best_grad = 1
weight_contact = 19 # weight for the contact force
smoothing_coef = 0.97 # smooth coefficient for the control samples
# Goal: position of the fruit at index 2 of the fruit pose array
goal_pos = KukaStrawberry.get_fruit_pose(state)[0][2]

# Controls are x_curr - x_target
# Sample translations, pad them with three zero components to 6-D, blend each
# step with the previous one along the horizon, and append one all-zero candidate.
controls = sample_translate_in_ball(random_key, rad=robot_step, d=3, horizon=horizon,
                                    n_sample=min(n_translate**horizon, 10000))
controls = jnp.dstack((controls, jnp.zeros(controls.shape)))
def body_fun(t, noises):
    return noises.at[:, t].set(smoothing_coef * noises[:, t - 1] + np.sqrt(1 - smoothing_coef**2) * noises[:, t])
controls = jax.lax.fori_loop(1, horizon, body_fun, controls)
controls = jnp.vstack((controls, jnp.zeros((1, horizon, 6))))

# Step sizes for the gradient-descent refinement (currently commented out below)
alphas = jax.numpy.linspace(0, 0.01, num=1000)

# Controller gains; same values as set in BraxKukaObstacle.__init__
KukaStrawberry.Kp = jnp.diag(jnp.array([2450, 3993, 2937, 5778, 5663, 4219]))
KukaStrawberry.Kd = jnp.diag(jnp.array([38, 75, 45, 43, 47, 64]))
KukaStrawberry.Kqp = jnp.diag(jnp.array([50, 50, 50, 50, 50, 50, 50]))
KukaStrawberry.Kqd = jnp.diag(jnp.array([8, 8, 8, 8, 8, 8, 8]))

for step in range(300):

    ############################ Find the best u ############################
    # Prepare the arguments to avoid repeat computation
    # NOTE(review): cost() reads args[-2] as `weights` and indexes weights[0..2],
    # but a scalar weight_contact is appended here. Confirm whether a 3-element
    # weight vector is intended.
    args = get_args_constants(state)
    args += [weight_contact, goal_pos]
    
    # Random shooting to find the best initial action
    u, best_cost = random_shooting(batch_cost, state, controls, args=args, n_best=n_best)

    # # Apply gradient descent to improve the action locally
    # u, best_cost_after_grad = gradient_descent(d_cost, batch_cost, state, u, args, alphas, n_best_grad=n_best_grad)

    ############################ Step the system ############################
    # Only the first control of the averaged sequence is executed
    state, info = dynamics(state, u[0], args)

    # if jnp.linalg.norm(KukaStrawberry.get_kuka_target_pose(state)[0] - goal_pos) < robot_step:
    #     break
    ############################ Log ############################    
    states.append(state)

    # Sum of squared contact terms over bodies 1 .. first_node_index - 1
    contact_force = info.contact.vel[1: KukaStrawberry.first_node_index]
    contact_force_magnitude = jnp.sum(jnp.square(contact_force))
    logs.append([contact_force_magnitude, best_cost])

## Motion Planning by CMA-ES

In [None]:
# Numerical constants for the CMA-ES implementation
_EPS = 1e-8
_MEAN_MAX = 1e32
_SIGMA_MAX = 1e32

@jax.jit
def _eigen_decomposition(C):
    """Eigendecompose the symmetrised covariance matrix.

    Args:
        C: Square matrix; it is symmetrised as ``(C + C.T) / 2`` first.

    Returns:
        ``(B, D)``: the eigenvectors as columns of ``B`` and the square roots of the
        eigenvalues in ``D``. Negative eigenvalues are replaced by ``_EPS`` before
        the square root.
    """
    symmetric = (C + C.T) / 2
    eigenvalues, B = jnp.linalg.eigh(symmetric)
    D = jnp.sqrt(jnp.where(eigenvalues < 0, _EPS, eigenvalues))
    return B, D

@jax.jit
def sample_white_gaussian(mean, sigma, B, D, z, n_dim):
    """Map a standard-normal draw ``z ~ N(0, I)`` to ``x ~ N(mean, sigma^2 C)``.

    ``B`` and ``D`` are the eigenvectors and square-rooted eigenvalues of ``C``;
    ``n_dim`` is accepted but not used.
    """
    shaped = jnp.dot(jnp.dot(B, jnp.diag(D)), z)  # ~ N(0, C)
    return mean + sigma * shaped  # ~ N(m, σ^2 C)

# Batched over the standard-normal draws only
batch_sample_white_gaussian = jax.jit(jax.vmap(sample_white_gaussian, in_axes=(None, None, None, None, 0, None)))

class CMAES():
    """CMA-ES optimizer with an ask/tell interface built on JAX arrays.

    Typical use: call `ask()` to draw a population, evaluate the repaired
    (bound-clipped) candidates, hand the results to `tell()`, and poll
    `should_stop()` for the termination criteria.

    The "(eq. N)" / "(p.N)" markers in the comments were carried over from the
    original code; they presumably refer to Hansen's CMA-ES tutorial -- confirm.

    Relies on the module-level helpers `_eigen_decomposition` and
    `batch_sample_white_gaussian` and on the constants `_EPS` and `_SIGMA_MAX`.
    """

    def __init__(self, mean, sigma, key, bounds, n_max_resampling=100):
        """Set up strategy parameters and the initial search distribution.

        Args:
            mean: initial mean vector; its length fixes the problem dimension.
            sigma: initial global step size.
            key: JAX PRNG key; it is split on every call to `ask()`.
            bounds: symmetric box limit; `ask()` clips samples to
                ``[-bounds, bounds]``.
            n_max_resampling: stored on the instance but not used by any
                method of this class (infeasible samples are clipped instead).
        """
        n_dim = len(mean)
        population_size = int(4 + np.floor(3 * np.log(n_dim)))  # (eq. 48): lambda
        mu = population_size // 2  # number of top candidates to update

        # (eq. 49)
        tmp = jnp.log((population_size + 1) / 2)
        weights_prime = jnp.array([tmp - np.log(i + 1) for i in range(population_size)])
        mu_eff = (jnp.sum(weights_prime[:mu]) ** 2) / jnp.sum(weights_prime[:mu] ** 2)
        mu_eff_minus = (jnp.sum(weights_prime[mu:]) ** 2) / jnp.sum(weights_prime[mu:] ** 2)

        alpha_cov = 2
        c1 = alpha_cov / ((n_dim + 1.3) ** 2 + mu_eff)
        # learning rate for the rank-mu update
        cmu = min(1 - c1 - 1e-8,  # 1e-8 is for large popsize
                  alpha_cov * (mu_eff - 2 + 1 / mu_eff) / ((n_dim + 2) ** 2 + alpha_cov * mu_eff / 2))

        min_alpha = min(1 + c1 / cmu,  # eq.50
                        1 + (2 * mu_eff_minus) / (mu_eff + 2),  # eq.51
                        (1 - c1 - cmu) / (n_dim * cmu))  # eq.52

        # (eq.53) positive weights sum to 1, negative weights sum to -min_alpha
        positive_sum = jnp.sum(weights_prime[weights_prime > 0])
        negative_sum = jnp.sum(jnp.abs(weights_prime[weights_prime < 0]))
        weights = jnp.where(weights_prime >= 0,
                            1 / positive_sum * weights_prime,
                            min_alpha / negative_sum * weights_prime)
        cm = 1  # (eq. 54)

        # (eq.55)
        c_sigma = (mu_eff + 2) / (n_dim + mu_eff + 5)
        d_sigma = 1 + 2 * max(0, jnp.sqrt((mu_eff - 1) / (n_dim + 1)) - 1) + c_sigma

        # (eq.56)
        cc = (4 + mu_eff / n_dim) / (n_dim + 4 + 2 * mu_eff / n_dim)

        self._n_dim = n_dim
        self._popsize = population_size
        self._mu = mu
        self._mu_eff = mu_eff

        self._cc = cc
        self._c1 = c1
        self._cmu = cmu
        self._c_sigma = c_sigma
        self._d_sigma = d_sigma
        self._cm = cm

        # E||N(0, I)|| (p.28)
        self._chi_n = jnp.sqrt(self._n_dim) * (1.0 - (1.0 / (4.0 * self._n_dim)) + 1.0 / (21.0 * (self._n_dim ** 2)))

        self._weights = weights

        # evolution paths
        self._p_sigma = jnp.zeros(self._n_dim)
        self._pc = jnp.zeros(self._n_dim)

        # Keep the mean as a JAX array so later updates rebind a new array
        # instead of mutating a caller-owned (e.g. NumPy) buffer in place.
        self._mean = jnp.asarray(mean)

        self._C = jnp.eye(self._n_dim)

        self._sigma = sigma

        self._bounds = bounds
        self._n_max_resampling = n_max_resampling

        self._g = 0  # generation counter
        self._rng = key

        # Termination criteria
        self._tolx = 1e-5 * sigma
        self._tolxup = 1e5
        self._tolfun = 1e-5
        self._tolconditioncov = 1e14

        # Ring buffer holding (best, worst) fitness for the last
        # `_funhist_term` generations. jnp.empty would also yield zeros in JAX;
        # jnp.zeros states that explicitly.
        self._funhist_term = int(10 + jnp.ceil(30 * n_dim / population_size))
        self._funhist_values = jnp.zeros(self._funhist_term * 2)

    def ask(self):
        """Draw one population from the current search distribution.

        Returns:
            (xs, xs_repair): the raw samples, shape ``(popsize, n_dim)``, and
            the same samples clipped to ``[-bounds, bounds]``. Evaluate the
            cost on ``xs_repair`` and pass both to `tell()`.
        """
        B, D = _eigen_decomposition(self._C)
        self._rng, split = jax.random.split(self._rng)
        zs = jax.random.normal(split, shape=(self._popsize, self._n_dim))  # ~ N(0, I)
        xs = batch_sample_white_gaussian(self._mean, self._sigma, B, D, zs, self._n_dim)

        # repair the samples to the feasible set
        xs_repair = xs.clip(-self._bounds, self._bounds)
        return xs, xs_repair

    def _penalize_repairs(self, repaired_solutions, alpha):
        """Add a boundary penalty to each cost; return ``[(x, penalized_cost)]``.

        (eq.62) The penalty is ``scale * ||x - x_repair||`` where ``scale`` is a
        power of ten chosen so both terms have a similar order of magnitude.
        When that power cannot be computed (cost not strictly positive, or a
        non-finite cost/distance) the caller-supplied ``alpha`` is used as the
        scale; the original code raised in that situation.
        """
        solutions = []
        for x, x_repair, c in repaired_solutions:
            dist = jnp.linalg.norm(x - x_repair)
            if dist > 0:
                if c > 0 and jnp.isfinite(c) and jnp.isfinite(dist):
                    scale = 10 ** (int(jnp.log10(c)) - int(jnp.log10(dist)))
                else:
                    scale = alpha
            else:
                scale = 0  # sample was already feasible: no penalty
            solutions.append((x, c + scale * dist))
        return solutions

    def _record_fitness_range(self, best, worst):
        """Store this generation's best/worst fitness in the history ring buffer.

        JAX arrays are immutable: ``.at[i].set(v)`` returns a new array, so the
        result must be assigned back. The original code discarded it, which
        left the buffer at all zeros and made the `tolfun` test in
        `should_stop()` succeed as soon as it became active.
        """
        funhist_idx = 2 * (self._g % self._funhist_term)
        history = self._funhist_values.at[funhist_idx].set(best)
        self._funhist_values = history.at[funhist_idx + 1].set(worst)

    def tell(self, repaired_solutions, alpha=1):
        """Update the search distribution from one evaluated population.

        Args:
            repaired_solutions: iterable of ``(x, x_repair, cost(x_repair))``,
                one entry per population member as returned by `ask()`.
            alpha: fallback penalty scale, used only when the automatic
                magnitude-matching scale cannot be computed for an entry.

        Raises:
            ValueError: if the number of solutions differs from the population
                size (the recombination weights have one entry per member).
        """
        solutions = self._penalize_repairs(repaired_solutions, alpha)
        if len(solutions) != self._popsize:
            raise ValueError(
                "tell() expects {} solutions (one per population member), got {}".format(
                    self._popsize, len(solutions)))

        self._g += 1
        solutions.sort(key=lambda s: s[1])

        # Stores 'best' and 'worst' values of the last 'self._funhist_term' generations.
        self._record_fitness_range(solutions[0][1], solutions[-1][1])

        B, D = _eigen_decomposition(self._C)

        x_k = jnp.array([s[0] for s in solutions])  # ~ N(m, sigma^2 C)
        y_k = (x_k - self._mean) / self._sigma  # ~ N(0, C)

        # Selection and recombination
        y_w = jnp.sum(y_k[: self._mu].T * self._weights[: self._mu], axis=1)  # eq.41
        self._mean = self._mean + self._cm * self._sigma * y_w

        # Step-size control
        C_2 = B.dot(jnp.diag(1 / D)).dot(B.T)  # C^(-1/2) = B D^(-1) B^T
        self._p_sigma = (1 - self._c_sigma) * self._p_sigma + jnp.sqrt(
            self._c_sigma * (2 - self._c_sigma) * self._mu_eff
        ) * C_2.dot(y_w)

        norm_p_sigma = jnp.linalg.norm(self._p_sigma)
        self._sigma *= jnp.exp(
            (self._c_sigma / self._d_sigma) * (norm_p_sigma / self._chi_n - 1)
        )
        self._sigma = min(self._sigma, _SIGMA_MAX)

        # Covariance matrix adaption
        h_sigma_cond_left = norm_p_sigma / jnp.sqrt(
            1 - (1 - self._c_sigma) ** (2 * (self._g + 1))
        )
        h_sigma_cond_right = (1.4 + 2 / (self._n_dim + 1)) * self._chi_n
        h_sigma = 1.0 if h_sigma_cond_left < h_sigma_cond_right else 0.0  # (p.28)

        # (eq.45)
        self._pc = (1 - self._cc) * self._pc + h_sigma * jnp.sqrt(
            self._cc * (2 - self._cc) * self._mu_eff
        ) * y_w

        # (eq.46) rescale the negative weights by the Mahalanobis length of y
        w_io = self._weights * jnp.where(
            self._weights >= 0,
            1,
            self._n_dim / (jnp.linalg.norm(C_2.dot(y_k.T), axis=0) ** 2 + _EPS),
        )

        delta_h_sigma = (1 - h_sigma) * self._cc * (2 - self._cc)  # (p.28)

        # (eq.47)
        rank_one = jnp.outer(self._pc, self._pc)
        # sum_k w_k * y_k y_k^T as a single matrix product instead of a Python
        # loop of outer products.
        rank_mu = (y_k.T * w_io).dot(y_k)
        self._C = ((1 + self._c1 * delta_h_sigma - self._c1 - self._cmu * jnp.sum(self._weights)) * self._C
                   + self._c1 * rank_one + self._cmu * rank_mu)

    def should_stop(self):
        """Return True when any of the termination criteria is met."""
        B, D = _eigen_decomposition(self._C)
        dC = jnp.diag(self._C)

        # Stop if the range of function values of the recent generation is below tolfun.
        if (
            self._g > self._funhist_term
            and jnp.max(self._funhist_values) - jnp.min(self._funhist_values)
            < self._tolfun
        ):
            return True

        # Stop if the std of the normal distribution is smaller than tolx
        # in all coordinates and pc is smaller than tolx in all components.
        if jnp.all(self._sigma * dC < self._tolx) and jnp.all(
            self._sigma * self._pc < self._tolx
        ):
            return True

        # Stop if detecting divergent behavior.
        if self._sigma * jnp.max(D) > self._tolxup:
            return True

        # No effect coordinates: stop if adding 0.2-standard deviations
        # in any single coordinate does not change m.
        if jnp.any(self._mean == self._mean + (0.2 * self._sigma * jnp.sqrt(dC))):
            return True

        # No effect axis: stop if adding 0.1-standard deviation vector in
        # any principal axis direction of C does not change m. "pycma" check
        # axis one by one at each generation.
        i = self._g % self._n_dim
        if jnp.all(self._mean == self._mean + (0.1 * self._sigma * D[i] * B[:, i])):
            return True

        # Stop if the condition number of the covariance matrix exceeds 1e14.
        condition_cov = jnp.max(D) / jnp.min(D)
        if condition_cov > self._tolconditioncov:
            return True

        return False

def cmaes_mpc(key, mean, sigma, bound, max_step, horizon):
    """Optimise a control sequence over `horizon` steps with CMA-ES.

    NOTE: besides its parameters this function reads the notebook-level
    globals `state`, `args` and `batch_cost`; they must be set before calling.

    Args:
        key: JAX PRNG key handed to the CMAES solver.
        mean: initial mean of the flattened control sequence.
        sigma: initial CMA-ES step size.
        bound: symmetric box limit applied to every control component.
        max_step: maximum number of CMA-ES generations.
        horizon: number of time steps in the planned control sequence.

    Returns:
        (u, cmaes_step): ``u`` is the final mean reshaped to ``(horizon, -1)``,
        clipped to ``[-bound, bound]`` and padded on the right with an equal
        number of zero columns; ``cmaes_step`` is the 0-based index of the last
        generation that ran, or -1 if ``max_step <= 0`` and none ran.
    """
    solver = CMAES(mean, sigma, key, bound)
    # Pre-bind the loop variable so the return below is defined even when the
    # loop body never executes (previously an UnboundLocalError).
    cmaes_step = -1
    for cmaes_step in range(max_step):
        xs, xs_repair = solver.ask()
        # Costs are evaluated on the repaired (feasible) candidates only.
        controls = xs_repair.reshape(solver._popsize, horizon, -1)
        costs = batch_cost(state, controls, args)
        repaired_solutions = list(zip(xs, xs_repair, costs))
        solver.tell(repaired_solutions)
        if solver.should_stop():
            break
    u = solver._mean.reshape(horizon, -1).clip(-bound, bound)
    # Append zero columns so each row is twice as wide as the optimised part.
    u = jnp.hstack((u, jnp.zeros_like(u)))
    return u, cmaes_step

In [None]:
# NOTE(review): the PRNG seed is taken from the wall clock, so every run of
# this cell draws different samples; use a fixed seed if reproducibility matters.
random_key = jax.random.PRNGKey(int(time.time()))

robot_step = 0.005
# Start from `default_qp` (defined earlier in the notebook) and keep every
# visited state for rendering.
state = default_qp
states = [state]
logs = []

horizon = 3
weight_contact = 20 # weight for the contact force
weight_dist_to_goal = 1 # weight for distance to goal
weight_dist_so_far = 1 # weight for distance so far

# Element [0][1] of the value returned by get_fruit_pose.
# Presumably the position of one particular fruit -- TODO confirm.
goal_pos = KukaStrawberry.get_fruit_pose(state)[0][1]

# Controls are x_curr - x_target
# Three control components per horizon step, flattened into one vector.
mean = jnp.zeros(3 * horizon)
sigma = robot_step
bound = robot_step * 1.5 # max distance the robot can move
# Upper limit on CMA-ES generations per planning step.
max_step = 1000

# Controller gains; same values as in the random-shooting cell above.
KukaStrawberry.Kp = jnp.diag(jnp.array([2450, 3993, 2937, 5778, 5663, 4219]))
KukaStrawberry.Kd = jnp.diag(jnp.array([38, 75, 45, 43, 47, 64]))
KukaStrawberry.Kqp = jnp.diag(jnp.array([50, 50, 50, 50, 50, 50, 50]))
KukaStrawberry.Kqd = jnp.diag(jnp.array([8, 8, 8, 8, 8, 8, 8]))

# Receding-horizon loop: plan with CMA-ES, apply the first action, log, repeat.
for step in range(150):

    ############################ Find the best u ############################
    # Prepare the arguments to avoid repeat computation
    args = get_args_constants(state)
    args += [(weight_contact, weight_dist_to_goal, weight_dist_so_far), goal_pos]
    
    # CMAES to find the best action
    # cmaes_mpc reads the globals `state` and `args` set just above, in
    # addition to the parameters passed here.
    random_key, key = jax.random.split(random_key)
    u, cmaes_step = cmaes_mpc(key, mean, sigma, bound, max_step, horizon)

    ############################ Step the system ############################
    state, info = dynamics(state, u[0], args)
    # NOTE(review): this cost is evaluated on the state *after* the step, with
    # `args` built from the pre-step state, and with the zero-padded `u`
    # (cmaes_mpc scores unpadded candidates) -- confirm this is intended.
    best_cost = batch_cost(state, u.reshape(1, horizon, -1), args)

    ############################ Log ############################    
    states.append(state)

    # NOTE(review): reads the `vel` field of info.contact although the
    # variable is named "force" -- confirm this is the intended quantity.
    contact_force = info.contact.vel[1: KukaStrawberry.first_node_index]
    # Sum of squared entries (a squared magnitude; no square root is taken).
    contact_force_magnitude = jnp.sum(jnp.square(contact_force))
    logs.append([contact_force_magnitude, best_cost, u, cmaes_step])

In [None]:
# Render the recorded states with brax's HTML renderer and display the result inline.
HTML(html.render(KukaStrawberry.brax_sys, states, height=480))