In [1]:
import jax.numpy as jnp
import jax

In [2]:
class Centroidal_Model_JAX:
    """Centroidal (single-rigid-body) prediction model used by the NMPC.

    State vector layout (``state_dim`` = 24 entries):
        [0:3]   CoM position
        [3:6]   CoM linear velocity
        [6:9]   base Euler angles (roll, pitch, yaw); the rotation matrix built
                in ``fd`` follows a Z-Y-X convention
        [9:12]  base angular velocity, expressed in the base frame (the
                contact torque is rotated by ``b_R_w`` before it is applied)
        [12:24] foot positions FL, FR, RL, RR (3 entries each)

    Input vector layout (24 entries):
        [0:12]  foot positions (not read by ``fd``)
        [12:24] foot contact forces FL, FR, RL, RR (3 entries each)
    """

    def __init__(self, dt: float = 0.02, mass: float = 24.64, inertia=None) -> None:
        """
        Initialize the foothold generator Centroidal_Model, which creates the
        prediction model of the NMPC.

        Args:
            dt: integration time step used by ``integrate_jax``.
            mass: robot mass (robot dependent).
            inertia: 3x3 base inertia matrix (robot dependent). ``None``
                selects the default matrix hard-coded below.

        Raises:
            ValueError: if ``inertia`` is not 3x3.
        """
        self.state_dim = 24
        self.dt = dt

        # Prefer the first GPU, but fall back to the default backend so the
        # model can still be built on CPU-only hosts: jax.devices('gpu')
        # raises RuntimeError when no GPU backend is available.
        try:
            self.device = jax.devices('gpu')[0]
        except RuntimeError:
            self.device = jax.devices()[0]

        # Mass and inertia are robot dependent.
        self.mass = mass
        if inertia is None:
            inertia = [[ 0.2310941359705289,   -0.0014987128245817424, -0.021400468992761768 ],
                       [-0.0014987128245817424, 1.4485084687476608,     0.0004641447134275615],
                       [-0.021400468992761768,  0.0004641447134275615,  1.503217877350808    ]]
        self.inertia = jnp.asarray(inertia)  # shape (3, 3)
        if self.inertia.shape != (3, 3):
            raise ValueError(f"inertia must be 3x3, got shape {self.inertia.shape}")

        # Precompute the inverse of the inertia; fd uses it at every call.
        self.inertia_inv = self.calculate_inverse(self.inertia)

        # Diagonal weights of the quadratic state cost e^T Q e. Only the first
        # 12 (base) states are weighted; foot positions (12..23) cost nothing.
        q_base = jnp.array([
            0.0, 0.0, 111500.0,     # com_x, com_y, com_z
            5000.0, 5000.0, 200.0,  # com_vel_x, com_vel_y, com_vel_z
            11200.0, 11200.0, 0.0,  # base_angle_roll, base_angle_pitch, base_angle_yaw
            20.0, 20.0, 600.0,      # base_angle_rates_x, base_angle_rates_y, base_angle_rates_z
        ])
        self.Q = jnp.diag(jnp.zeros(self.state_dim).at[:12].set(q_base))

    def calculate_inverse(self, A):
        """Return the inverse of a 3x3 matrix using the closed-form adjugate formula.

        Args:
            A: 3x3 matrix (indexed as ``A[i, j]``).

        Returns:
            adj(A) / det(A), a 3x3 array.

        No singularity check is performed: a singular ``A`` (determinant 0)
        yields inf/nan entries. For the Euler-rate matrix built in ``fd`` this
        happens at pitch = +/- pi/2, where its determinant cos(pitch) vanishes.
        """
        a11 = A[0, 0]
        a12 = A[0, 1]
        a13 = A[0, 2]
        a21 = A[1, 0]
        a22 = A[1, 1]
        a23 = A[1, 2]
        a31 = A[2, 0]
        a32 = A[2, 1]
        a33 = A[2, 2]

        # Determinant via cofactor expansion along the first column.
        DET = a11 * (a33 * a22 - a32 * a23) - a21 * (a33 * a12 - a32 * a13) + a31 * (a23 * a12 - a22 * a13)

        # Adjugate (transposed cofactor matrix) divided by the determinant.
        return jnp.array([
            [(a33 * a22 - a32 * a23), -(a33 * a12 - a32 * a13), (a23 * a12 - a22 * a13)],
            [-(a33 * a21 - a31 * a23), (a33 * a11 - a31 * a13), -(a23 * a11 - a21 * a13)],
            [(a32 * a21 - a31 * a22), -(a32 * a11 - a31 * a12), (a22 * a11 - a21 * a12)]
        ]) / DET

    def fd(self, states: jnp.ndarray, inputs: jnp.ndarray, contact_status: jnp.ndarray):
        """
        Compute the time derivative of the first 12 (base) states.

        Args:
            states: state vector (layout in the class docstring).
            inputs: input vector; only ``inputs[12:]`` (foot forces) is read.
            contact_status: per-leg stance flags (FL, FR, RL, RR). Each leg's
                force and torque contribution is multiplied by its flag.

        Returns:
            Array of 12 entries: [CoM velocity, CoM acceleration,
            Euler-angle rates, base angular acceleration].
        """

        def skew(v):
            # Cross-product matrix: skew(v) @ u == cross(v, u).
            return jnp.array([[0, -v[2], v[1]],
                              [v[2], 0, -v[0]],
                              [-v[1], v[0], 0]])

        foot_positions = jnp.split(states[12:], 4)
        foot_forces = jnp.split(inputs[12:], 4)
        stanceFL, stanceFR, stanceRL, stanceRR = contact_status[:4]
        stances = (stanceFL, stanceFR, stanceRL, stanceRR)

        com_position = states[:3]
        linear_com_vel = states[3:6]
        roll, pitch, yaw = states[6:9]
        w = states[9:12]

        # Net contact force and net contact moment about the CoM, summed over
        # the legs in FL, FR, RL, RR order; swing legs (flag 0) contribute nothing.
        legs = list(zip(foot_positions, foot_forces, stances))
        total_force = sum(force * stance for _, force, stance in legs)
        total_torque = sum(jnp.dot(skew(pos - com_position), force) * stance
                           for pos, force, stance in legs)

        # Linear dynamics: m * a = sum(F) + m * g.
        gravity = jnp.array([jnp.float32(0), jnp.float32(0), jnp.float32(-9.81)])
        linear_com_acc = total_force / self.mass + gravity

        # Trigonometric terms shared by both matrices below.
        sr, cr = jnp.sin(roll), jnp.cos(roll)
        sp, cp = jnp.sin(pitch), jnp.cos(pitch)
        sy, cy = jnp.sin(yaw), jnp.cos(yaw)

        # Maps Euler-angle rates to base angular velocity (w = E @ euler_rates);
        # it is inverted to recover the Euler-angle rates from w. Its
        # determinant is cos(pitch), so it is singular at pitch = +/- pi/2.
        conj_euler_rates = jnp.array([
            [1.0, 0.0, -sp],
            [0.0, cr, cp * sr],
            [0.0, -sr, cp * cr],
        ])
        euler_rates_base = jnp.dot(self.calculate_inverse(conj_euler_rates), w)

        # World-to-base rotation (transpose of Rz(yaw) @ Ry(pitch) @ Rx(roll)).
        b_R_w = jnp.array([
            [cp * cy, cp * sy, -sp],
            [sr * sp * cy - cr * sy, sr * sp * sy + cr * cy, sr * cp],
            [cr * sp * cy + sr * sy, cr * sp * sy - sr * cy, cr * cp],
        ])

        # Euler's rotation equation in the base frame:
        # I * w_dot = b_R_w @ tau - w x (I @ w)
        angular_acc_base = (-jnp.dot(self.inertia_inv, jnp.dot(skew(w), jnp.dot(self.inertia, w)))
                            + jnp.dot(self.inertia_inv, jnp.dot(b_R_w, total_torque)))

        return jnp.concatenate([linear_com_vel, linear_com_acc, euler_rates_base, angular_acc_base])

    def integrate_jax(self, state, inputs, contact_status):
        """
        Advance the state by one explicit (forward) Euler step of length ``self.dt``.

        Only the first 12 (base) states are integrated; the foot positions
        ``state[12:]`` are carried over unchanged.

        Returns:
            The next state vector, with the same layout as ``state``.
        """
        state_dot = self.fd(state, inputs, contact_status)
        new_base_state = state[0:12] + state_dot * self.dt
        return jnp.concatenate([new_base_state, state[12:]])

In [8]:
# Test 1: the base starts at rest and level, with all four legs in stance and a
# vertical force of 50 on each foot, plus an initial yaw rate of 0.3. The model
# is rolled out for `horizon` steps; the state and regulation cost are printed
# at every step.
model = Centroidal_Model_JAX()
num_samples = 1
num_legs = 4
horizon = 10
# Reuse the device chosen by the model instead of calling jax.devices('gpu')
# directly, which raises on CPU-only machines.
device = model.device


pos_com_lw      = jnp.array((0.0, 0.0, 0.35))
lin_com_vel_lw  = jnp.array((0.0, 0.0, 0.0))
euler_xyz_angle = jnp.array((0.0, 0.0, 0.0))
ang_vel_com_b   = jnp.array((0.0, 0.0, 0.30))
p_lw            = jnp.array([0.2, 0.1, 0.0, 0.2, -0.1, 0.0, -0.2, 0.1, 0.0, -0.2, -0.1, 0.0])

pos_com_lw_ref      = jnp.array((0.0, 0.0, 0.38))
lin_com_vel_lw_ref  = jnp.array((0.0, 0.0, 0.0))
euler_xyz_angle_ref = jnp.array((0.0, 0.0, 0.0))
ang_vel_com_b_ref   = jnp.array((0.0, 0.0, 0.00))
p_lw_ref            = jnp.array([1000.2, 1000.1, 110.0, 10.2, -110.1, 110.0, -110.2, 110.1, 110.0, -10.2, -110.1, 110.0])

state = jnp.concatenate((pos_com_lw,
                         lin_com_vel_lw,
                         euler_xyz_angle,
                         ang_vel_com_b,
                         p_lw))

reference = jnp.concatenate((pos_com_lw_ref,
                             lin_com_vel_lw_ref,
                             euler_xyz_angle_ref,
                             ang_vel_com_b_ref,
                             p_lw_ref))


input_p_lw = jnp.array([0.2, 0.1, 0.0, 0.2, -0.1, 0.0, -0.2, 0.1, 0.0, -0.2, -0.1, 0.0])
input_F_lw = jnp.array([0.0, 0.0, 50.0, 0.0, 0.0, 50.0, 0.0, 0.0, 50.0, 0.0, 0.0, 50.0])

# Named `inputs` so the builtin `input` is not shadowed in the kernel.
inputs = jnp.concatenate((input_p_lw,
                          input_F_lw))

contact = jnp.array((1.0, 1.0, 1.0, 1.0))

for i in range(horizon):
    new_state = model.integrate_jax(state, inputs, contact)

    # Quadratic regulation cost of the state error.
    state_error = new_state - reference[0:model.state_dim]
    error_cost = state_error.T @ model.Q @ state_error

    print(new_state[0:3])
    print(new_state[3:6])
    print(new_state[6:9])
    print(new_state[9:12])
    print(new_state[12:])
    print(state_error)
    print(error_cost)
    print()

    state = new_state


[0.   0.   0.35]
[ 0.          0.         -0.03386234]
[0.    0.    0.006]
[3.7919692e-06 2.6597363e-05 3.0000007e-01]
[ 0.2  0.1  0.   0.2 -0.1  0.  -0.2  0.1  0.  -0.2 -0.1  0. ]
[ 0.0000000e+00  0.0000000e+00 -3.0000001e-02  0.0000000e+00
  0.0000000e+00 -3.3862341e-02  0.0000000e+00  0.0000000e+00
  6.0000001e-03  3.7919692e-06  2.6597363e-05  3.0000007e-01
 -1.0000000e+03 -1.0000000e+03 -1.1000000e+02 -1.0000000e+01
  1.1000000e+02 -1.1000000e+02  1.1000000e+02 -1.1000000e+02
 -1.1000000e+02  1.0000000e+01  1.1000000e+02 -1.1000000e+02]
154.57938

[0.         0.         0.34932274]
[ 0.          0.         -0.06772468]
[7.5839381e-08 5.3194725e-07 1.2000002e-02]
[7.5458811e-06 5.3214848e-05 3.0000010e-01]
[ 0.2  0.1  0.   0.2 -0.1  0.  -0.2  0.1  0.  -0.2 -0.1  0. ]
[ 0.0000000e+00  0.0000000e+00 -3.0677259e-02  0.0000000e+00
  0.0000000e+00 -6.7724682e-02  7.5839381e-08  5.3194725e-07
  1.2000002e-02  7.5458811e-06  5.3214848e-05  3.0000010e-01
 -1.0000000e+03 -1.0000000e+03 -1.1

In [4]:
# Test 2: perturbed initial state (CoM offset and velocity, non-zero attitude
# and angular velocity) with the RL leg out of contact. Its contact flag is 0,
# so its force is ignored. The regulation cost is accumulated over `horizon`
# steps, and the final state is printed.
model = Centroidal_Model_JAX()
num_samples = 1
horizon = 20
num_legs = 4
# Reuse the device chosen by the model instead of calling jax.devices('gpu')
# directly, which raises on CPU-only machines.
device = model.device


pos_com_lw      = jnp.array((0.1, -0.23, 0.35))
lin_com_vel_lw  = jnp.array((0.04, -0.02, 0.09))
euler_xyz_angle = jnp.array((0.2, 0.1, -0.01))
ang_vel_com_b   = jnp.array((-0.1, 0.1, 0.30))
p_lw            = jnp.array([0.2, 0.1, 0.0, 0.2, -0.1, 0.0, -0.2, 0.1, 0.0, -0.2, -0.1, 0.0])

pos_com_lw_ref      = jnp.array((0.0, 0.0, 0.38))
lin_com_vel_lw_ref  = jnp.array((0.0, 0.0, 0.0))
euler_xyz_angle_ref = jnp.array((0.0, 0.0, 0.0))
ang_vel_com_b_ref   = jnp.array((0.0, 0.0, 0.00))
p_lw_ref            = jnp.array([1000.2, 1000.1, 110.0, 10.2, -110.1, 110.0, -110.2, 110.1, 110.0, -10.2, -110.1, 110.0])

state = jnp.concatenate((pos_com_lw,
                         lin_com_vel_lw,
                         euler_xyz_angle,
                         ang_vel_com_b,
                         p_lw))

reference = jnp.concatenate((pos_com_lw_ref,
                             lin_com_vel_lw_ref,
                             euler_xyz_angle_ref,
                             ang_vel_com_b_ref,
                             p_lw_ref))


input_p_lw = jnp.array([0.2, 0.1, 0.0, 0.2, -0.1, 0.0, -0.2, 0.1, 0.0, -0.2, -0.1, 0.0])
input_F_lw = jnp.array([100.0, 5.0, 40.0, -5.0, -5.0, 50.0, 10.0, -10.0, 60.0, 0.5, 10.0, -10.0])

# Named `inputs` so the builtin `input` is not shadowed in the kernel.
inputs = jnp.concatenate((input_p_lw,
                          input_F_lw))

contact = jnp.array((1.0, 1.0, 0.0, 1.0))
cost = 0

for i in range(horizon):
    new_state = model.integrate_jax(state, inputs, contact)

    # Quadratic regulation cost of the state error, accumulated over the horizon.
    state_error = new_state - reference[0:model.state_dim]
    error_cost = state_error.T @ model.Q @ state_error

    cost += error_cost
    print(error_cost)

    state = new_state

print(new_state[0:3])
print(new_state[3:6])
print(new_state[6:9])
print(new_state[9:12])
print(new_state[12:])
print(state_error)
print(cost)
print()



810.0715
1461.3267
2749.7078
4843.52
8126.4937
13401.146
22155.031
36753.47
60249.29
95377.83
142510.75
197290.9
250453.1
293324.88
328614.1
376739.12
465682.2
608797.94
788790.25
950610.0
[ 0.41056162 -0.20715585 -0.11280669]
[ 1.5903242   0.14233768 -2.5352986 ]
[ 7.0926466  1.1169364 -2.8043485]
[ 25.409197    1.7612581 -23.062546 ]
[ 0.2  0.1  0.   0.2 -0.1  0.  -0.2  0.1  0.  -0.2 -0.1  0. ]
[ 4.1056162e-01 -2.0715585e-01 -4.9280667e-01  1.5903242e+00
  1.4233768e-01 -2.5352986e+00  7.0926466e+00  1.1169364e+00
 -2.8043485e+00  2.5409197e+01  1.7612581e+00 -2.3062546e+01
 -1.0000000e+03 -1.0000000e+03 -1.1000000e+02 -1.0000000e+01
  1.1000000e+02 -1.1000000e+02  1.1000000e+02 -1.1000000e+02
 -1.1000000e+02  1.0000000e+01  1.1000000e+02 -1.1000000e+02]
4648741.0



### Test 3

In [3]:
# Test 3: large initial offsets (Euler angles outside [-pi, pi] are used as-is;
# nothing wraps them) and an arbitrary `input_p_lw`. That input is not read by
# `fd`: only the foot positions stored in `state` enter the dynamics. The RL
# leg is in swing. The state is printed after every step.
model = Centroidal_Model_JAX()
num_samples = 1
horizon = 4
num_legs = 4
# Reuse the device chosen by the model instead of calling jax.devices('gpu')
# directly, which raises on CPU-only machines.
device = model.device


pos_com_lw      = jnp.array((1.1, -0.23, 0.38))
lin_com_vel_lw  = jnp.array((0.24, -1.02, 0.19))
euler_xyz_angle = jnp.array((3.2, 0.1, -6.0))
ang_vel_com_b   = jnp.array((-1.1, 0.5, 0.45))
p_lw            = jnp.array([0.2, 0.2, 0.5, 0.25, -0.1, -0.3, -0.2, 0.1, 1.0, -2.2, -0.1, 0.0])

pos_com_lw_ref      = jnp.array((10.0, 0.0, 0.38))
lin_com_vel_lw_ref  = jnp.array((0.1, -0.3, 0.1))
euler_xyz_angle_ref = jnp.array((0.2, -0.2, 0.05))
ang_vel_com_b_ref   = jnp.array((0.3, -0.1, 0.10))
p_lw_ref            = jnp.array([1000.2, 1000.1, 110.0, 10.2, -110.1, 110.0, -110.2, 110.1, 110.0, -10.2, -110.1, 110.0])

state = jnp.concatenate((pos_com_lw,
                         lin_com_vel_lw,
                         euler_xyz_angle,
                         ang_vel_com_b,
                         p_lw))

reference = jnp.concatenate((pos_com_lw_ref,
                             lin_com_vel_lw_ref,
                             euler_xyz_angle_ref,
                             ang_vel_com_b_ref,
                             p_lw_ref))


input_p_lw = jnp.array([101.0, 0.0, 0.0, 0.0, 101.0, 0.0, 0.0, 0.0, 90.0, 10.0, 20.0, 30.0])
input_F_lw = jnp.array([100.0, 5.0, 40.0, -5.0, -5.0, 50.0, 10.0, -10.0, 60.0, 0.5, 10.0, -20.0])

# Named `inputs` so the builtin `input` is not shadowed in the kernel.
inputs = jnp.concatenate((input_p_lw,
                          input_F_lw))

contact = jnp.array((1.0, 1.0, 0.0, 1.0))
cost = 0

for i in range(horizon):
    new_state = model.integrate_jax(state, inputs, contact)

    # Quadratic regulation cost of the state error, accumulated over the horizon.
    state_error = new_state - reference[0:model.state_dim]
    error_cost = state_error.T @ model.Q @ state_error

    cost += error_cost
    print(error_cost)

    state = new_state

    print('pos ', new_state[0:3])
    print('vel ', new_state[3:6])
    print('ang ', new_state[6:9])
    print('ome ', new_state[9:12])
    print(new_state[12:])
print(state_error)
print(cost)
print()



104191.67
pos  [ 1.1048 -0.2504  0.3838]
vel  [ 0.3175162  -1.0118831   0.05061817]
ang  [ 3.17704     0.09054242 -6.0096164 ]
ome  [2.0406623  0.26612943 1.4828725 ]
[ 0.2   0.2   0.5   0.25 -0.1  -0.3  -0.2   0.1   1.   -2.2  -0.1   0.  ]
109671.1
pos  [ 1.1111503  -0.27063766  0.38481236]
vel  [ 0.39503244 -1.0037663  -0.08876365]
ang  [ 3.2151453   0.08627424 -6.0395665 ]
ome  [5.242951  0.0738503 2.5228996]
[ 0.2   0.2   0.5   0.25 -0.1  -0.3  -0.2   0.1   1.   -2.2  -0.1   0.  ]
121502.36
pos  [ 1.1190509  -0.29071298  0.3830371 ]
vel  [ 0.47254866 -0.9956494  -0.22814548]
ang  [ 3.3156428  0.0885092 -6.0901847]
ome  [8.509395   0.08533783 3.603424  ]
[ 0.2   0.2   0.5   0.25 -0.1  -0.3  -0.2   0.1   1.   -2.2  -0.1   0.  ]
140386.19
pos  [ 1.1285019  -0.31062597  0.37847418]
vel  [ 0.5500649 -0.9875325 -0.3675273]
ang  [ 3.4795058   0.09930853 -6.16174   ]
ome  [11.842703   0.4794737  4.724822 ]
[ 0.2   0.2   0.5   0.25 -0.1  -0.3  -0.2   0.1   1.   -2.2  -0.1   0.  ]
[-8.871498