# Hybrid Trajectory Optimization: Single Leg Hopper

In [1]:
import matplotlib.pyplot as plt
import mujoco
import mujoco.viewer
import numpy as np
import os
from pathlib import Path
import time
import control
import cyipopt
import jax
import jax.numpy as jnp


### Parameters

In [None]:
# -----------------------------------------------------------------------------
# Global Parameters (consistent with the Sympy derivation)
# -----------------------------------------------------------------------------
# JAX computes in float32 unless 64-bit mode is switched on. The NLP solver
# works in double precision, so enable it before the first JAX array is made.
jax.config.update("jax_enable_x64", True)

Nx = 8     # number of states
Nu = 2     # number of controls
Tfinal = 4.4
h = 0.1    # integration step
Nm = 5     # number of steps in each mode
Nt = int(np.ceil(Tfinal/h) + 1)   # number of knot points
Nmodes = int(np.ceil(Nt / Nm))    # number of modes (the last one may be partial)
# Multiply an integer range by h: a float-step arange can gain or lose a
# sample through rounding, this form always has exactly Nt entries.
thist = h * np.arange(Nt)
n_nlp = Nx * Nt + Nu * (Nt - 1)

# Number of constraints in each block
n_init = Nx
n_term = Nx
n_dyn = Nx*(Nt-1)
# One row per knot point that lies in an even-numbered (stance) mode. Counting
# the knot points directly stays exact when Nt is not a multiple of Nm.
n_stance = int(np.sum((np.arange(Nt) // Nm) % 2 == 0))
n_length = Nt

# Starting offsets
off_init   = 0
off_term   = off_init + n_init
off_dyn    = off_term + n_term
off_stance = off_dyn + n_dyn
off_length = off_stance + n_stance

# Build index arrays (all contiguous, row-major)
c_init_inds   = jnp.arange(off_init,   off_init + n_init)
c_term_inds   = jnp.arange(off_term,   off_term + n_term)
c_dyn_inds    = jnp.arange(off_dyn,    off_dyn  + n_dyn)
c_stance_inds = jnp.arange(off_stance, off_stance + n_stance)
c_length_inds = jnp.arange(off_length, off_length + n_length)
m_nlp = off_length + n_length   # total number of constraint rows

### Dynamics (Using Jax)

In [None]:
# Dynamics parameters
# (units are not stated in this file; the values look like SI -- TODO confirm)
g = 9.81   # gravitational acceleration, subtracted from the vertical accelerations
m1 = 5.0   # mass of the point whose position is x[0:2]
m2 = 1.0   # mass of the point whose position is x[2:4]


def flight_dynamics_jax(x, u):
    """Continuous-time state derivative used for the flight mode.

    Args:
        x: state vector; x[0:2] and x[2:4] are the positions of the two
            masses and x[4:8] the matching velocities.
        u: two-element control. u[0] acts along the line joining the two
            masses, u[1] perpendicular to it; both act with opposite sign
            on the two masses.

    Returns:
        Length-8 array ``[velocities, accelerations]``.
    """
    pos_a = x[:2]
    pos_b = x[2:4]
    vel = x[4:8]

    # Components of the unit vector pointing from the second mass to the first.
    sep = pos_a - pos_b
    dist = jnp.linalg.norm(sep)
    ex = sep[0] / dist
    ey = sep[1] / dist

    input_map = jnp.array([
        [ex, ey],
        [ey, -ex],
        [-ex, -ey],
        [-ey, ex],
    ])
    inv_mass = jnp.diag(jnp.array([1/m1, 1/m1, 1/m2, 1/m2]))
    # Gravity pulls on the second coordinate of both masses.
    gravity = jnp.array([0., -g, 0., -g])
    accel = gravity + inv_mass @ (input_map @ u)
    return jnp.concatenate([vel, accel])

def stance_dynamics_jax(x, u):
    """Continuous-time state derivative used for the stance mode.

    Same state and control layout as the flight dynamics, but the second
    mass (position x[2:4]) gets zero acceleration: neither gravity nor the
    controls act on it.

    Args:
        x: state vector ``[r1, r2, v1, v2]`` with two entries per group.
        u: two-element control.

    Returns:
        Length-8 array ``[velocities, accelerations]``.
    """
    pos_a = x[:2]
    pos_b = x[2:4]
    vel = x[4:8]

    # Components of the unit vector pointing from the second mass to the first.
    sep = pos_a - pos_b
    dist = jnp.linalg.norm(sep)
    ex = sep[0] / dist
    ey = sep[1] / dist

    # Only the first mass receives the control forces.
    input_map = jnp.array([
        [ex, ey],
        [ey, -ex],
        [0., 0.],
        [0., 0.],
    ])
    inv_mass = jnp.diag(jnp.array([1/m1, 1/m1, 1/m2, 1/m2]))
    gravity = jnp.array([0., -g, 0., 0.])
    accel = gravity + inv_mass @ (input_map @ u)
    return jnp.concatenate([vel, accel])

def jump_map_jax(x):
    """Jump map: keep x[0:6] and replace the last two entries with zeros.

    With the state layout used by the dynamics, the zeroed entries x[6:8]
    are the velocity of the second mass.
    """
    kept = x[:6]
    zeroed = jnp.zeros(2)
    return jnp.concatenate([kept, zeroed])

# --- RK4 integrators ---
def flight_rk4(x, u):
    """Advance the flight dynamics by one step of length ``h`` with classic RK4.

    The control ``u`` is held constant over the step.
    """
    half_step = 0.5*h
    slope_a = flight_dynamics_jax(x, u)
    slope_b = flight_dynamics_jax(x + half_step*slope_a, u)
    slope_c = flight_dynamics_jax(x + half_step*slope_b, u)
    slope_d = flight_dynamics_jax(x + h*slope_c, u)
    weighted = slope_a + 2*slope_b + 2*slope_c + slope_d
    return x + (h/6)*weighted

def stance_rk4(x, u):
    """Advance the stance dynamics by one step of length ``h`` with classic RK4.

    The control ``u`` is held constant over the step.
    """
    half_step = 0.5*h
    slope_a = stance_dynamics_jax(x, u)
    slope_b = stance_dynamics_jax(x + half_step*slope_a, u)
    slope_c = stance_dynamics_jax(x + half_step*slope_b, u)
    slope_d = stance_dynamics_jax(x + h*slope_c, u)
    weighted = slope_a + 2*slope_b + 2*slope_c + slope_d
    return x + (h/6)*weighted

# --- JAX Jacobians ---
# Forward-mode Jacobians of the discrete (RK4) steps: *_dx differentiates with
# respect to the state argument, *_du with respect to the control argument.
flight_jac_dx, flight_jac_du = (
    jax.jacfwd(flight_rk4, argnums=arg) for arg in (0, 1)
)
stance_jac_dx, stance_jac_du = (
    jax.jacfwd(stance_rk4, argnums=arg) for arg in (0, 1)
)
# The jump map has a single argument, so its Jacobian does too.
jump_jacobian = jax.jacfwd(jump_map_jax)

# --- JIT & Vmap versions ---
# Compiled variants; flight_jac / stance_jac return the pair (d/dx, d/du).
flight_jac, stance_jac = (
    jax.jit(jax.jacfwd(step, argnums=(0, 1))) for step in (flight_rk4, stance_rk4)
)
jump_jac = jax.jit(jump_jacobian)

def jump_jac_wrapper(x, u):
    """Jacobians of one flight step followed by the jump map.

    Applies the chain rule: the jump-map Jacobian, evaluated at the state
    reached by the flight step, multiplies the flight-step Jacobians.

    Returns:
        Tuple ``(d/dx, d/du)`` of the composed map.
    """
    step_dx, step_du = flight_jac(x, u)
    landed_state = flight_rk4(x, u)
    reset_jac = jump_jac(landed_state)
    return (reset_jac @ step_dx, reset_jac @ step_du)


### Reference Trajectory

In [None]:
# -----------------------------------------------------------------------------
# Reference Trajectory
# -----------------------------------------------------------------------------

# Cost weights: identity state weight, scalar control weight, and a terminal
# weight that starts out as a copy of the state weight.
Q = jnp.diag(jnp.ones(Nx))
R = 0.001
Qn = Q.copy()

# Reference control: u[0] equal to the weight of the first mass, u[1] zero.
uref = jnp.array([m1 * g, 0.0])

# Reference state, one column per knot point. JAX arrays are immutable, so the
# rows are filled in with chained .at[...].set(...) updates.
x_sweep = jnp.linspace(-1.0, 1.0, Nt)
y_bounce = 1.0 + 0.5 * jnp.sin(2 * jnp.pi / 10.0 * jnp.arange(Nt))
cruise_speed = (2.0 / Tfinal) * jnp.ones(Nt - 2)
xref = (
    jnp.zeros((Nx, Nt))
    .at[0, :].set(x_sweep)             # first mass, first coordinate
    .at[1, :].set(y_bounce)            # first mass, second coordinate
    .at[2, :].set(x_sweep)             # second mass, first coordinate
    .at[4, 1:-1].set(cruise_speed)     # first mass velocity, interior knots only
    .at[6, 1:-1].set(cruise_speed)     # second mass velocity, interior knots only
)

### Cost Function and indexing utilities

In [None]:
def stage_cost(x, u, k):
    """Quadratic tracking cost at knot point ``k``.

    Args:
        x: state at step k.
        u: control at step k.
        k: column of ``xref`` to track.

    Returns:
        ``0.5*dx'Q dx + 0.5*R*du'du`` with ``dx = x - xref[:, k]`` and
        ``du = u - uref``.
    """
    state_err = x - xref[:, k]
    input_err = u - uref
    state_term = 0.5 * state_err @ Q @ state_err
    input_term = 0.5 * input_err @ (R * input_err)
    return state_term + input_term

def terminal_cost(x):
    """Quadratic cost on the deviation of ``x`` from the last reference state.

    Returns ``0.5*dx'Qn dx`` with ``dx = x - xref[:, -1]``.
    """
    final_err = x - xref[:, -1]
    weighted_err = 0.5 * final_err @ Qn
    return weighted_err @ final_err

def cost(ztraj):
    """Total objective: sum of the stage costs plus the terminal cost.

    Args:
        ztraj: flat decision vector in the layout understood by ``unpack``.

    Returns:
        Scalar cost.
    """
    stages, final_state = unpack(ztraj)
    # Stack the Nt-1 stage states and the terminal state into one (Nt, Nx) array.
    states = jnp.vstack([stages[:, :Nx], final_state[jnp.newaxis, :]])
    controls = stages[:, Nx:]
    knots = jnp.arange(Nt - 1)
    # Evaluate every stage in one vectorised call, then add the stages up.
    per_stage = jax.vmap(stage_cost, in_axes=(0, 0, 0))(states[:-1], controls, knots)
    return jnp.sum(per_stage) + terminal_cost(states[-1])

def unpack(ztraj):
    """Split the flat decision vector into stage rows and the terminal state.

    The first ``(Nt-1)*(Nx+Nu)`` entries hold ``[x_k, u_k]`` for each stage,
    one stage after the other; whatever follows is the terminal state.

    Return:
      z_mat   -- shape (Nt-1, Nx+Nu), each row = [x_k, u_k]
      x_term  -- the remaining entries (the terminal state)
    """
    n_stages = Nt - 1
    row_width = Nx + Nu
    split = n_stages * row_width
    stage_rows = jnp.reshape(ztraj[:split], (n_stages, row_width))
    return stage_rows, ztraj[split:]

def build_traj(ztraj):
    """Return the state and control histories stored in the decision vector.

    Returns:
        x_traj -- shape (Nt, Nx): the Nt-1 stage states followed by the
                  terminal state
        u_traj -- shape (Nt-1, Nu)
    """
    stages, final_state = unpack(ztraj)
    stage_states = stages[:, :Nx]
    controls = stages[:, Nx:]
    # Append the terminal state as one extra row.
    states = jnp.concatenate([stage_states, final_state[jnp.newaxis, :]], axis=0)
    return states, controls


### Dynamics Constraints

In [None]:
## Two methods of calculating constraints with Jax
# Method 1: Vectorization
def dynamics_constraint_vectorized(ztraj):
    """Dynamics defects for every step, computed in one batched pass.

    Every step is integrated with both the stance and the flight integrator;
    the result that belongs to the step's mode is then selected:

    * even-numbered modes (``(k // Nm) % 2 == 0``) use the stance step;
    * odd-numbered modes use the flight step, followed by the jump map on
      the last step of the mode (``k % Nm == Nm - 1``).

    Returns:
        Length-``m_nlp`` vector that is zero except in the ``c_dyn_inds``
        rows, which hold ``f(x_k, u_k) - x_{k+1}`` flattened row-major.
    """
    states, controls = build_traj(ztraj)   # (Nt, Nx) and (Nt-1, Nu)
    x_now = states[:-1]
    x_after = states[1:]

    steps = jnp.arange(Nt - 1)
    is_stance = ((steps // Nm) % 2) == 0
    ends_mode = (steps % Nm) == (Nm - 1)

    # Batch both integrators over the time axis.
    flight_next = jax.vmap(flight_rk4)(x_now, controls)
    stance_next = jax.vmap(stance_rk4)(x_now, controls)
    landed_next = jax.vmap(jump_map_jax)(flight_next)

    # Pick, per step, the prediction that matches the step's mode.
    airborne_next = jnp.where(ends_mode[:, None], landed_next, flight_next)
    predicted = jnp.where(is_stance[:, None], stance_next, airborne_next)

    defects = predicted - x_after          # (Nt-1, Nx)
    return jnp.zeros(m_nlp).at[c_dyn_inds].set(defects.reshape(-1))

# Method 2: Loop
def dynamics_constraint(ztraj):
    """Dynamics defects for every step, computed with a ``fori_loop``.

    For step ``s`` the mode is ``(s // Nm) % 2``: mode 0 integrates the
    stance dynamics, mode 1 the flight dynamics, with the jump map applied
    after the last step of a flight mode (``s % Nm == Nm - 1``).

    Returns:
        Length-``m_nlp`` vector that is zero except in the ``c_dyn_inds``
        rows, which hold ``f(x_s, u_s) - x_{s+1}`` flattened row-major.
    """
    states, controls = build_traj(ztraj)

    def flight_then_jump(x, u):
        return jump_map_jax(flight_rk4(x, u))

    def fill_row(s, defects):
        x_s = jax.lax.dynamic_index_in_dim(states, s, axis=0, keepdims=False)
        u_s = jax.lax.dynamic_index_in_dim(controls, s, axis=0, keepdims=False)
        in_stance = (s // Nm) % 2 == 0
        ends_mode = (s % Nm) == (Nm - 1)

        def airborne(x, u):
            # Inner branch: only the last step of a flight mode jumps.
            return jax.lax.cond(ends_mode, flight_then_jump, flight_rk4, x, u)

        predicted = jax.lax.cond(in_stance, stance_rk4, airborne, x_s, u_s)
        return defects.at[s].set(predicted - states[s + 1])

    defects = jax.lax.fori_loop(0, Nt - 1, fill_row, jnp.zeros((Nt - 1, Nx)))
    # Place defects into the constraint vector at the appropriate indices.
    return jnp.zeros(m_nlp).at[c_dyn_inds].set(defects.reshape(-1))

def stance_constraint(ztraj):
    """Stance constraint rows: state entry 3 at every stance knot point.

    Knot point ``s`` belongs to mode ``s // Nm``; even-numbered modes are the
    stance modes (the same convention the dynamics constraint uses). For each
    such knot point, in increasing order of ``s``, the value ``xtraj[s, 3]``
    is written to the next row of ``c_stance_inds``.

    Args:
        ztraj: flat decision vector.

    Returns:
        Length-``m_nlp`` vector that is zero outside the ``c_stance_inds``
        rows.
    """
    xtraj, _ = build_traj(ztraj)
    c = jnp.zeros(m_nlp)

    def mode_body(mode, val):
        def inner_body(j, inner_val):
            c_inner, idx_inner = inner_val
            s = mode * Nm + j
            # Write a row only for knot points that exist (s < Nt) AND lie in
            # an even-numbered mode. Without the parity test every knot point
            # would consume a stance row, and odd modes would be constrained.
            is_stance_knot = jnp.logical_and(mode % 2 == 0, s < Nt)
            c_inner = jax.lax.cond(
                is_stance_knot,
                lambda _: c_inner.at[c_stance_inds[idx_inner]].set(xtraj[s, 3]),
                lambda _: c_inner,
                operand=None
            )
            # Advance the output row only when a value was written.
            idx_inner = jax.lax.cond(
                is_stance_knot,
                lambda _: idx_inner + 1,
                lambda _: idx_inner,
                operand=None
            )
            return (c_inner, idx_inner)
        return jax.lax.fori_loop(0, Nm, inner_body, val)

    c, _ = jax.lax.fori_loop(0, Nmodes, mode_body, (c, 0))
    return c

# Knot points that lie in a stance (even-numbered) mode. They depend only on
# Nt and Nm, so they are computed once with NumPy and kept as a constant JAX
# array for the vectorized stance constraint.
s_np = np.arange(Nt)
stance_mask_np = (s_np // Nm) % 2 == 0
stance_indices_np = np.nonzero(stance_mask_np)[0]
stance_indices = jnp.array(stance_indices_np)

def stance_constraint_vectorized(ztraj):
    """Stance constraint rows, gathered in one indexing operation.

    Reads state entry 3 at the precomputed ``stance_indices`` and writes the
    values to the leading rows of ``c_stance_inds``.

    Returns:
        Length-``m_nlp`` vector that is zero outside the stance rows.
    """
    states, _ = build_traj(ztraj)
    stance_vals = states[stance_indices, 3]
    # Use only as many stance rows as there are values to store.
    rows = c_stance_inds[:stance_vals.shape[0]]
    return jnp.zeros(m_nlp).at[rows].set(stance_vals)

def length_constraint(ztraj):
    """Distance between the two masses at every knot point.

    Returns:
        Length-``m_nlp`` vector that is zero except in the ``c_length_inds``
        rows, which hold ``||x[0:2] - x[2:4]||`` for each of the Nt states.
    """
    states, _ = build_traj(ztraj)
    first_mass = states[:, :2]
    second_mass = states[:, 2:4]
    separation = jnp.linalg.norm(first_mass - second_mass, axis=1)
    return jnp.zeros(m_nlp).at[c_length_inds].set(separation)


## Assemble constraint
def con(ztraj):
    """Assemble the full constraint vector of the NLP.

    The vector has ``m_nlp`` rows. Each group of rows is written by exactly
    one of the pieces below, and every piece returns zeros elsewhere, so the
    pieces can simply be added together.

    Args:
        ztraj: flat decision vector.

    Returns:
        Constraint values: initial-state and terminal-state errors, dynamics
        defects, stance values and lengths between the two masses.
    """
    c = jnp.zeros(m_nlp)
    # Initial constraint: the first Nx entries are the first state.
    c = c.at[c_init_inds].set(ztraj[:Nx] - xref[:, 0])
    # Terminal constraint: the last Nx entries are the terminal state.
    c = c.at[c_term_inds].set(ztraj[-Nx:] - xref[:, -1])
    c = c + dynamics_constraint(ztraj)
    c = c + stance_constraint(ztraj)
    c = c + length_constraint(ztraj)
    return c


## Compile and get gradients and jacobians
# Compiled constraint and cost evaluations.
con_jit = jax.jit(con)
cost_jit = jax.jit(cost)
# The cost is a scalar, so reverse mode (jax.grad) yields the whole gradient
# in one backward pass; forward mode would need one pass per decision variable.
cost_gradient_jit = jax.jit(jax.grad(cost))
# The constraint Jacobian has about as many rows as columns; forward mode is kept.
con_jacobian_jit = jax.jit(jax.jacfwd(con))
hessian_cost = jax.jit(jax.hessian(cost))




### Problem interface and solve

In [5]:


# Total decision variable dimension: Nt states of size Nx plus Nt-1 controls
# of size Nu. Same expression as in the Parameters cell, so the value does not change.
n_nlp = Nx * Nt + Nu * (Nt - 1)

# -----------------------------------------------------------------------------
# Define the cyipopt Problem Class
# -----------------------------------------------------------------------------
class TrajectoryOptimizationProblem(object):
    """Callback object handed to ``cyipopt.Problem``.

    The objective, gradient, constraints and Jacobian are delegated to the
    compiled JAX functions and converted to double-precision NumPy values
    before they are returned to the solver. Both sparsity patterns are
    reported as dense.

    Attributes are set in ``__init__``:
        n -- number of decision variables
        m -- number of constraints
    """

    def __init__(self, n, m):
        self.n = n  # decision variable dimension
        self.m = m  # constraint dimension

    def objective(self, x):
        """Return the cost at ``x`` as a Python float."""
        return float(cost_jit(x))

    def gradient(self, x):
        """Return the cost gradient at ``x`` as a flat float64 array."""
        return np.asarray(cost_gradient_jit(x), dtype=np.float64).ravel()

    def constraints(self, x):
        """Return the constraint values at ``x`` as a flat float64 array."""
        return np.asarray(con_jit(x), dtype=np.float64).ravel()

    def jacobian(self, x):
        """Return the dense constraint Jacobian at ``x``, flattened row-major.

        The ordering matches the (row, col) pairs of ``jacobianstructure``.
        """
        return np.asarray(con_jacobian_jit(x), dtype=np.float64).ravel()

    def jacobianstructure(self):
        """Return (rows, cols) of a fully dense m-by-n Jacobian, row-major."""
        rows, cols = np.nonzero(np.ones((self.m, self.n)))
        return (rows, cols)

    def hessianstructure(self):
        """Return (rows, cols) of the lower triangle of a dense n-by-n Hessian.

        The Hessian is symmetric and the solver expects one triangle only;
        reporting all n*n entries would list every off-diagonal term twice.
        """
        rows, cols = np.nonzero(np.tril(np.ones((self.n, self.n))))
        return (rows, cols)

    def hessian(self, x, lagrange, obj_factor):
        """Placeholder Hessian: zeros, one per entry of ``hessianstructure``.

        NOTE: this is NOT the Hessian of the Lagrangian. It is only safe while
        the solver runs with ``hessian_approximation='limited-memory'``, in
        which case the values are not needed. An exact version would have to
        combine ``obj_factor`` times the cost Hessian with the
        ``lagrange``-weighted constraint Hessians.
        """
        nnz = self.n * (self.n + 1) // 2
        return np.zeros(nnz)

# -----------------------------------------------------------------------------
# Set Bounds and Initial Guess
# -----------------------------------------------------------------------------
# No bounds on the decision variables.
lb = -np.inf * np.ones(n_nlp)
ub = np.inf * np.ones(n_nlp)
# Constraints are equalities (cl == cu == 0) unless overridden below.
cl = np.zeros(m_nlp)
cu = np.zeros(m_nlp)
# For length constraints (indices in c_length_inds), set lower bound = ell_min, upper bound = ell_max.
ell_min = 0.5
ell_max = 1.5
# Convert the JAX index array before using it to index NumPy arrays.
length_rows = np.asarray(c_length_inds)
cl[length_rows] = ell_min
cu[length_rows] = ell_max

# Work on float64 NumPy copies of the references so that the guess is built
# entirely in NumPy and keeps double precision.
xref_np = np.asarray(xref, dtype=np.float64)
uref_np = np.asarray(uref, dtype=np.float64)
# xguess: state guess with added Gaussian noise, shape (Nx, Nt)
xguess = xref_np + 0.1 * np.random.randn(Nx, Nt)
# uguess: control guess with noise; uref repeated as a column for each of the Nt-1 stages
uguess = np.tile(uref_np.reshape(Nu, 1), (1, Nt-1)) + 0.1 * np.random.randn(Nu, Nt-1)

# Assemble the decision variable in the layout that unpack() expects:
# one row [x_k, u_k] per stage k = 0..Nt-2, then the terminal state.
z_mat = np.hstack((xguess[:, :-1].T, uguess.T))   # shape (Nt-1, Nx+Nu)
x_term = xguess[:, -1]                            # shape (Nx,)

# Flatten row-major and append terminal state
z0 = np.concatenate((z_mat.reshape(-1), x_term))



# -----------------------------------------------------------------------------
# Instantiate and Solve the Problem with cyipopt
# -----------------------------------------------------------------------------
# Wrap the callbacks, bounds and problem sizes into an Ipopt problem.
problem = TrajectoryOptimizationProblem(n=n_nlp, m=m_nlp)
nlp = cyipopt.Problem(
    n=n_nlp, m=m_nlp, problem_obj=problem,
    lb=lb, ub=ub, cl=cl, cu=cu,
)

# Solver options, applied in order.
solver_options = (
    ('tol', 1e-6),
    ('print_level', 5),
    ('hessian_approximation', 'limited-memory'),
    ('check_derivatives_for_naninf', 'yes'),
)
for option_name, option_value in solver_options:
    nlp.add_option(option_name, option_value)

# Solve starting from the noisy initial guess.
x_sol, info = nlp.solve(z0)

# print("Solution:")
# print(x_sol)
# print("\nSolver Information:")
# print(info)



******************************************************************************
This program contains Ipopt, a library for large-scale nonlinear optimization.
 Ipopt is released as open source code under the Eclipse Public License (EPL).
         For more information visit https://github.com/coin-or/Ipopt
******************************************************************************

This is Ipopt version 3.14.17, running with linear solver MUMPS 5.7.3.

Number of nonzeros in equality constraint Jacobian...:   176064
Number of nonzeros in inequality constraint Jacobian.:    20160
Number of nonzeros in Lagrangian Hessian.............:        0

Total number of variables............................:      448
                     variables with only lower bounds:        0
                variables with lower and upper bounds:        0
                     variables with only upper bounds:        0
Total number of equality constraints.................:      393
Total number of inequality c

### Animate

In [None]:
# Split the solution into per-stage rows [x_k, u_k] and the terminal state.
z_sol, x_term = unpack(x_sol)
model_name = "hopper"
model_path = Path("mujoco_models") / f"{model_name}.xml"
# Load the model and data
model = mujoco.MjModel.from_xml_path(str(model_path.absolute()))


data = mujoco.MjData(model)
print(data.qpos)
# Put both bodies at their first solved positions: state entries 0/1 go to
# qpos[0]/qpos[2] and state entries 2/3 go to qpos[7]/qpos[9].
data.qpos[0] = z_sol[0,0]
data.qpos[2] = z_sol[0,1]
data.qpos[7] = z_sol[0,2]
data.qpos[9] = z_sol[0,3]
# data.xpos is only filled in by a kinematics pass; without this call the
# positions read below would still be the zeros of a freshly created MjData.
mujoco.mj_forward(model, data)

# Get the body IDs
mass1_id = data.body("mass1").id
mass2_id = data.body("mass2").id
pos_mass1 = data.xpos[mass1_id].copy()
pos_mass2 = data.xpos[mass2_id].copy()


# End points and colour of the capsule drawn between the two bodies.
new_fromto = np.concatenate([pos_mass1, pos_mass2])
rgba_blue = np.array([0.7, 0.7, 0.7, 1.0])

# this function is from the tutorial notebook
def add_visual_capsule(scene, point1, point2, radius, rgba):
    """Append one capsule running from ``point1`` to ``point2`` to an mjvScene.

    Does nothing when the scene already holds ``maxgeom`` geoms.

    Args:
        scene: scene whose ``geoms`` array and ``ngeom`` counter are updated.
        point1: first end point of the capsule.
        point2: second end point of the capsule.
        radius: capsule width passed to ``mjv_connector``.
        rgba: colour as an array; cast to float32 before use.
    """
    slot = scene.ngeom
    if slot >= scene.maxgeom:
        return
    # Claim the next free geom slot.
    scene.ngeom = slot + 1
    capsule = scene.geoms[slot]
    # Initialise the geom with placeholder size/pose; the connector call below
    # then shapes it to span the two points.
    mujoco.mjv_initGeom(
        capsule,
        mujoco.mjtGeom.mjGEOM_CAPSULE,
        np.zeros(3),
        np.zeros(3),
        np.zeros(9),
        rgba.astype(np.float32),
    )
    mujoco.mjv_connector(
        capsule,
        mujoco.mjtGeom.mjGEOM_CAPSULE,
        radius,
        point1,
        point2,
    )

# mujoco.set_mjcb_control(controller)
# Full state history as a plain float64 array: one row per knot point,
# INCLUDING the terminal state, which the stage rows of z_sol do not contain.
x_traj = np.asarray(build_traj(x_sol)[0], dtype=np.float64)
last_frame = x_traj.shape[0] - 1

start_time = time.time()
with mujoco.viewer.launch_passive(model, data, show_left_ui=False, show_right_ui=False) as viewer:
    viewer.cam.azimuth = 140
    viewer.cam.elevation = -20
    viewer.cam.distance =  1
    viewer.cam.lookat = np.array([1.0 , -0.8 , 1.1])
    first_time = time.time()
    # One user geom is created up front; the loop only moves its end points.
    add_visual_capsule(viewer.user_scn, pos_mass1, pos_mass2, 0.01, rgba_blue)
    i = 0
    while viewer.is_running():
        # State entries 0/1 and 2/3 are drawn in the world x/z plane.
        pos_mass1 = np.array([x_traj[i, 0], 0.0, x_traj[i, 1]])
        pos_mass2 = np.array([x_traj[i, 2], 0.0, x_traj[i, 3]])
        mujoco.mjv_connector(
            viewer.user_scn.geoms[viewer.user_scn.ngeom - 1],
            mujoco.mjtGeom.mjGEOM_CAPSULE,
            0.01,
            pos_mass1,
            pos_mass2,
        )
        data.qpos[0] = x_traj[i, 0]
        data.qpos[2] = x_traj[i, 1]
        data.qpos[7] = x_traj[i, 2]
        data.qpos[9] = x_traj[i, 3]

        # This is kinematic playback of a precomputed trajectory: update the
        # derived quantities for the prescribed pose without integrating the
        # physics (mj_step would move the bodies away from the solution).
        mujoco.mj_forward(model, data)
        time.sleep(0.1)
        # Pick up changes to the physics state, apply perturbations, update options from GUI.
        viewer.sync()

        # Advance one knot point per frame and hold the final one.
        i = min(i + 1, last_frame)

[0.  0.  1.  1.  0.  0.  0.  0.  0.  0.5 1.  0.  0.  0. ]


2025-03-20 11:17:10.567 mjpython[14278:20185932] +[IMKClient subclass]: chose IMKClient_Modern
2025-03-20 11:17:10.567 mjpython[14278:20185932] +[IMKInputSession subclass]: chose IMKInputSession_Modern
