In [1]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
import osqp
import scipy.sparse as sp
import mediapy as media
from utils import quat_to_axis_angle, axis_angle_to_quat, L_mult, state_error, apply_delta_x, state_from_mujoco
import mujoco

# Build the MuJoCo model and its simulation state from the Go2 scene file
model = mujoco.MjModel.from_xml_path("models/go2_scene.xml")
data = mujoco.MjData(model)

# Physics step of 1/2000 s: the low-level PD controller is simulated at 2 kHz,
# while the MPC controller runs at 100 Hz in sim
model.opt.timestep = 1.0 / 2000.0

In [2]:
### Load problem data
# The QP consists of the cost hessian (H), cost gradient (g), constraint jacobian (A), and constraint bounds (lb, ub)
# The constraint jacobian (A) is fixed based on the dynamics
# The cost gradient and constraint bounds (g, lb, ub) change in time with the reference (X_ref, U_ref) and initial condition
#
# Every dataset is copied into memory inside the `with` block, so the HDF5 file is
# closed as soon as loading finishes and later cells work on plain NumPy arrays
# instead of re-reading lazy h5py datasets from disk on every access.
# `[()]` reads a dataset's full contents regardless of its rank.
with h5py.File("quad_walking_data.h5", "r") as prob_data:
    # Cost hessian and constraint jacobian (A is stored transposed — presumably
    # because the file was written from column-major Julia; see the index note below)
    H, A = prob_data["H"][()], prob_data["A"][:].T
    gs, ls, us = prob_data["gs"][:], prob_data["ls"][:], prob_data["us"][:] # gradient and bounds (vary in time)
    X_ref, U_ref = prob_data['X_ref'][:], prob_data['U_ref'][:] # Reference
    x_lin, u_lin = prob_data['x_lin'][()], prob_data['u_lin'][()] # Linearization point
    x_idx, u_idx = prob_data['x_idx'][()], prob_data['u_idx'][()] # Indexing helpers (NOTE these are 1-indexed because of Julia)

# Number of reference steps the control loop cycles through (two fewer than the rows of X_ref)
N_ref = len(X_ref) - 2

In [4]:
# Setup the convex linear MPC QP in OSQP
# Solver options are gathered in one place; the iteration budget is small and fixed
osqp_settings = dict(
    max_iter=20,
    check_termination=100,  # interval is larger than max_iter — presumably so no early-exit check fires
    eps_abs=1e-05,
    eps_rel=0,
    eps_prim_inf=1e-9,
    eps_dual_inf=1e-9,
    warm_start=True,
    scaling=False,
    adaptive_rho=False,  # disables refactorization online
    verbose=False,
)

# Bounds start as zeros of the right shape; the actual l/u are passed through
# prob.update() before each solve
zero_lower = np.zeros_like(ls[0, :])
zero_upper = np.zeros_like(us[0, :])

prob = osqp.OSQP()
prob.setup(
    P=sp.csc_matrix(H),
    q=gs[0, :],
    A=sp.csc_matrix(A),
    l=zero_lower,
    u=zero_upper,
    **osqp_settings,
)

# PD controller terms
Kp = 30*np.eye(12)
Kd = 2*np.eye(12)
# K_pd acts on a full state vector [qpos; qvel] (it is multiplied with rows of X_ref):
# zero blocks cover the first nq-12 position and nv-12 velocity entries, and Kp / Kd
# act on the remaining 12 entries of each (assumed to be the actuated joints — TODO confirm)
K_pd = np.hstack([np.zeros((12, model.nq - 12)), Kp, np.zeros((12, model.nv - 12)), Kd])

# Start from a clean simulation (state and time are reset)
mujoco.mj_resetData(model, data)

# Initial configuration: first reference pose, then perturbed so the controller
# has an error to correct
start_quat = axis_angle_to_quat(np.array([0.0, 0.0, np.pi/2]))  # rotation of pi/2 about the z axis
data.qpos = X_ref[0, :19]
data.qpos[0] = -0.2          # offset the first position coordinate (presumably base x — TODO confirm)
data.qpos[3:7] = start_quat  # overwrite the orientation quaternion entries

# Propagate the new qpos through the model so derived quantities are consistent
mujoco.mj_forward(model, data)

def _limit_norm(vec, max_norm, min_norm=1e-3):
    """Return `vec` rescaled so that its norm is at most `max_norm`.

    Args:
        vec: scalar or array whose magnitude should be capped.
        max_norm: largest norm allowed for the result.
        min_norm: inputs whose norm is not greater than this are returned
            unchanged (this also avoids dividing by a near-zero norm).

    Returns:
        `vec` itself when its norm is <= `min_norm`, otherwise
        `vec * min(norm, max_norm) / norm`.
    """
    norm = np.linalg.norm(vec)
    if norm > min_norm:
        return vec * min(norm, max_norm) / norm
    return vec


# Error-shaping limits applied to the MPC initial-condition error
ERR5_MAX = 0.2          # cap on |delta_x[5]| (presumably the yaw error — TODO confirm against state_error)
ERR5_ALIGNED = 1e-1     # below this |delta_x[5]| the larger planar cap is used
XY_MAX_ALIGNED = 0.18   # cap on ||delta_x[0:2]|| when delta_x[5] is small
XY_MAX_TURNING = 0.05   # cap on ||delta_x[0:2]|| otherwise

PD_STEPS_PER_MPC_STEP = 20  # physics/PD substeps simulated per MPC solve
FRAME_STRIDE = 2            # render every second MPC step

# Loop-invariant lookups, done once instead of on every iteration
ctrl_idx = u_idx[0] - 1     # u_idx is 1-indexed (Julia), shift to 0-indexed
u_nominal = u_lin[:12]

frames = []
with mujoco.Renderer(model) as renderer:
    for k in range(5*N_ref):  # five passes through the periodic reference
        # Gradient and bounds for this step of the reference (copied: the bounds are modified below)
        ref_step = k % N_ref
        g, lb, ub = gs[ref_step, :].copy(), ls[ref_step, :].copy(), us[ref_step, :].copy()

        # Current state error relative to the linearization point
        x = (np.concatenate([data.qpos, data.qvel]))
        delta_x = state_error(x, x_lin)

        # Modify state error to encourage good tracking without going too fast
        delta_x[5] = _limit_norm(delta_x[5], ERR5_MAX)
        # The planar cap depends on the (already limited) value of delta_x[5]
        if np.linalg.norm(delta_x[5]) < ERR5_ALIGNED:
            xy_cap = XY_MAX_ALIGNED
        else:
            xy_cap = XY_MAX_TURNING
        delta_x[0:2] = _limit_norm(delta_x[0:2], xy_cap)

        # Set initial condition: the first 36 constraint rows are pinned (l == u) to the state error
        lb[:36] = delta_x
        ub[:36] = delta_x

        # Update
        prob.update(q=g, l=lb, u=ub)

        # Solve
        result = prob.solve()

        # Get control
        delta_u = result.x[ctrl_idx][:12]

        # Remove PD terms that are in MPC
        x_d = X_ref[(k+1) % N_ref, :]
        u_ffwd = u_nominal + delta_u - K_pd@x_d

        for _ in range(PD_STEPS_PER_MPC_STEP): # Run PD controller at 2 kHz
            # Update PD term using the latest simulated state
            x = (np.concatenate([data.qpos, data.qvel]))
            delta_x = state_error(x, x_d)
            u = u_ffwd - Kp @ delta_x[6:18] - Kd @ delta_x[24:36]

            # Apply control to generalized-force entries 6..17
            data.qfrc_applied[6:18] = u

            mujoco.mj_step(model, data)

        if k % FRAME_STRIDE == 0: # Save frame for visualization
            renderer.update_scene(data)
            pixels = renderer.render()
            frames.append(pixels)

# One frame per two MPC steps; at a 100 Hz control rate this plays back in real time at 50 fps
media.show_video(frames, fps=50)

0
This browser does not support the video tag.
