In [43]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Ascento-style LQR Gain Scheduling by stance height h

Compute K(h) on a grid, then use linear interpolation at runtime:
  u = -K(h_hat) x
  x = [theta, theta_dot, v_err, yaw_err]^T   (yaw_err is used as the yaw-rate error [rad/s])
  u = [tau_fwd, tau_yaw]^T

Notes:
- This is a *control-oriented* linear model (simple but practical).
- You should refine M_eff, Iyy (constant here; ideally scheduled as Iyy(h)), Iz, and the damping terms to match your robot.

Refs:
- Ascento paper (stabilizing torque controller & scheduling idea)  (see arXiv:2005.11435)
- Gain scheduling for Ascento in balance-control literature
- Bryson's rule for initial Q,R scaling
"""

from dataclasses import dataclass
import numpy as np
from scipy.linalg import solve_continuous_are


# -----------------------------
# Robot/model parameters
# -----------------------------
@dataclass
class ModelParams:
    """Lumped physical parameters of the control-oriented balancing model.

    The values are validated on construction: every quantity that
    ``build_linear_model`` divides by (``M_eff``, ``Iyy``, ``Iz``,
    ``r_wheel`` and the clipped stance height) must be strictly positive,
    and the scheduling range must satisfy ``0 < h_min <= h_max``.

    Raises:
        ValueError: if one of the constraints above is violated (NaN values
            are rejected as well).
    """

    g: float = 9.81

    # Effective lumped parameters (tune / identify)
    M_eff: float = 18.0         # effective mass for forward dynamics [kg]
    Iyy: float = 0.65           # pitch inertia about wheel axle [kg*m^2]
    Iz: float  = 0.70           # yaw inertia [kg*m^2]

    # Geometry
    r_wheel: float = 0.075      # wheel radius [m]
    wheel_sep: float = 0.27     # track width [m]

    # Height scheduling (stance height is clipped to [h_min, h_max])
    h_min: float = 0.18
    h_max: float = 0.34

    # Damping terms (small, to match reality & prevent aggressive oscillations)
    pitch_damping: float = 0.35   # N*m*s/rad (effective)
    v_damping: float = 0.20       # 1/s
    yaw_damping: float = 0.40     # N*m*s/rad (effective)

    def __post_init__(self):
        # These parameters appear as denominators in the linear model, so a
        # zero value would raise (or produce inf) far away from the real
        # mistake, and a negative one would silently flip the input sign.
        for field_name in ("M_eff", "Iyy", "Iz", "r_wheel"):
            value = getattr(self, field_name)
            # `not (value > 0)` also rejects NaN.
            if not value > 0:
                raise ValueError(
                    f"{field_name} must be strictly positive, got {value!r}"
                )

        # The clipped height is used as the divisor in g / h.
        if not 0 < self.h_min <= self.h_max:
            raise ValueError(
                "height range must satisfy 0 < h_min <= h_max, got "
                f"h_min={self.h_min!r}, h_max={self.h_max!r}"
            )


# -----------------------------
# Linear model A(h), B(h)
# x = [theta, theta_dot, v_err, yaw_err]
# u = [tau_fwd, tau_yaw]
# -----------------------------
def build_linear_model(p: ModelParams, h: float):
    """
    Build the continuous-time linear model (A, B) around upright theta≈0.

    State  x = [theta, theta_dot, v_err, yaw_err]  (yaw_err is used as the
    yaw-rate error), input u = [tau_fwd, tau_yaw]. The stance height h is
    clipped to [p.h_min, p.h_max] before use.

    theta_ddot   ≈ (g/h)*theta - (pitch_damping/Iyy)*theta_dot - tau_fwd/(Iyy*r)
    v_err_dot    ≈ -v_damping*v_err + tau_fwd/(M_eff*r)
    yaw_err_dot  ≈ -(yaw_damping/Iz)*yaw_err + tau_yaw/Iz

    Returns the pair (A, B) with shapes (4, 4) and (4, 2).
    """
    height = float(np.clip(h, p.h_min, p.h_max))

    # Pitch channel: inverted-pendulum stiffness, damping and input gain.
    pitch_stiffness = p.g / height
    pitch_rate_coeff = -(p.pitch_damping / p.Iyy)
    pitch_input_gain = -(1.0 / (p.Iyy * p.r_wheel))

    # Forward-velocity channel.
    velocity_coeff = -p.v_damping
    velocity_input_gain = (1.0 / (p.M_eff * p.r_wheel))

    # Yaw-rate channel (decoupled from the others, driven by tau_yaw only).
    yaw_coeff = -(p.yaw_damping / p.Iz)
    yaw_input_gain = (1.0 / p.Iz)

    A = np.array(
        [
            [0.0,             1.0,              0.0,            0.0],
            [pitch_stiffness, pitch_rate_coeff, 0.0,            0.0],
            [0.0,             0.0,              velocity_coeff, 0.0],
            [0.0,             0.0,              0.0,            yaw_coeff],
        ],
        dtype=float,
    )
    B = np.array(
        [
            [0.0,                 0.0],
            [pitch_input_gain,    0.0],
            [velocity_input_gain, 0.0],
            [0.0,                 yaw_input_gain],
        ],
        dtype=float,
    )
    return A, B


# -----------------------------
# Bryson's rule helper
# -----------------------------
def bryson_QR(x_max, u_max, q_scale=1.0, r_scale=1.0):
    """
    Diagonal LQR weights from Bryson's rule.

    Each diagonal entry is the inverse square of the acceptable maximum:
      Q_ii = q_scale / (x_i,max)^2
      R_jj = r_scale / (u_j,max)^2

    Returns the pair (Q, R).
    """

    def _inverse_square_diag(limits, scale):
        # Build diag(1 / limit_i^2) first, then apply the global scale.
        limits = np.asarray(limits, dtype=float)
        return np.diag(1.0 / (limits**2)) * float(scale)

    state_weight = _inverse_square_diag(x_max, q_scale)
    input_weight = _inverse_square_diag(u_max, r_scale)
    return state_weight, input_weight


# -----------------------------
# Solve continuous-time LQR
# -----------------------------
def lqr(A, B, Q, R):
    """
    Continuous-time LQR for the cost ∫ (x^T Q x + u^T R u) dt.

    Solves the continuous algebraic Riccati equation for P and forms the
    state-feedback gain K = R^{-1} B^T P, to be applied as u = -K x.

    Returns the pair (K, P).
    """
    riccati_solution = solve_continuous_are(A, B, Q, R)
    # Solve R K = B^T P rather than inverting R explicitly.
    rhs = B.T @ riccati_solution
    gain = np.linalg.solve(R, rhs)
    return gain, riccati_solution


# -----------------------------
# Build K(h) table
# -----------------------------
def build_gain_table(
    p: ModelParams,
    h_grid: np.ndarray,
    x_max=(0.35, 3.0, 1.5, 2.0),     # [rad, rad/s, m/s, rad/s] acceptable maxima
    u_max=(10.0, 6.0),               # [Nm, Nm] acceptable maxima for [tau_fwd, tau_yaw]
    q_scale=1.0,
    r_scale=1.0,
    extra_Q=None,
    pitch_rate_gain_scale=0.5,
):
    """
    Compute the scheduled gain table K(h) over a grid of stance heights.

    Q and R come from Bryson's rule (``bryson_QR``); for every height in
    ``h_grid`` the model is linearized with ``build_linear_model`` and an
    LQR gain is solved with ``lqr``.

    Args:
        p: model parameters.
        h_grid: iterable of stance heights, one linearization per entry.
        x_max, u_max: acceptable state / input maxima for Bryson's rule.
        q_scale, r_scale: global multipliers applied to Q and R.
        extra_Q: optional matrix added to Q; must have the same shape as Q.
        pitch_rate_gain_scale: factor applied to K[0, 1] (tau_fwd gain on
            theta_dot) after the LQR solve. The default 0.5 halves that
            gain; pass 1.0 to keep the unmodified LQR gain. With any value
            other than 1.0 the stored gains are no longer the LQR-optimal
            ones, so closed-loop behaviour should be checked separately.

    Returns:
        (Q, R, Ks) where Ks stacks the per-height gains along axis 0,
        i.e. shape (len(h_grid), 2, 4) for this model.

    Raises:
        ValueError: if ``extra_Q`` does not have the same shape as Q.
    """
    Q, R = bryson_QR(x_max, u_max, q_scale=q_scale, r_scale=r_scale)
    if extra_Q is not None:
        extra_Q = np.asarray(extra_Q, dtype=float)
        # Without this check a 1-D vector would broadcast across rows and
        # silently produce a non-diagonal, asymmetric Q.
        if extra_Q.shape != Q.shape:
            raise ValueError(
                f"extra_Q must have shape {Q.shape}, got {extra_Q.shape}"
            )
        Q = Q + extra_Q

    rate_scale = float(pitch_rate_gain_scale)

    Ks = []
    for h in h_grid:
        A, B = build_linear_model(p, float(h))
        K, _ = lqr(A, B, Q, R)
        # Hand-tuned de-rating of the pitch-rate feedback on tau_fwd.
        K[0, 1] *= rate_scale
        Ks.append(K)

    Ks = np.stack(Ks, axis=0)  # shape: (N, 2, 4)
    return Q, R, Ks


def interp_K(h_hat: float, h_grid: np.ndarray, Ks: np.ndarray):
    """
    Linearly interpolate the gain table over the height grid.

    Args:
        h_hat: current stance-height estimate.
        h_grid: 1-D grid of heights, assumed sorted in increasing order
            with distinct entries (not checked).
        Ks: gain table with one gain matrix per grid point along axis 0,
            e.g. shape (N, 2, 4).

    Returns:
        A new array holding the interpolated gain. Heights outside the grid
        saturate to the first / last table entry. The result never aliases
        ``Ks``, so callers may modify it without corrupting the table.

    Raises:
        ValueError: if ``h_hat`` is NaN, if ``h_grid`` is not a non-empty
            1-D array, or if ``Ks`` does not have one entry per grid point.
    """
    h_hat = float(h_hat)
    h_grid = np.asarray(h_grid, dtype=float)
    Ks = np.asarray(Ks, dtype=float)

    if h_grid.ndim != 1 or h_grid.size == 0:
        raise ValueError("h_grid must be a non-empty 1-D array")
    if Ks.ndim < 1 or Ks.shape[0] != h_grid.size:
        raise ValueError(
            f"Ks must have {h_grid.size} entries along axis 0, "
            f"got shape {Ks.shape}"
        )
    # NaN fails both saturation comparisons below and would otherwise index
    # past the end of the grid.
    if np.isnan(h_hat):
        raise ValueError("h_hat must not be NaN")

    # Saturate outside the grid; copy so the table is never handed out.
    if h_hat <= h_grid[0]:
        return Ks[0].copy()
    if h_hat >= h_grid[-1]:
        return Ks[-1].copy()

    # h_grid[i] < h_hat <= h_grid[i + 1] for a sorted grid.
    i = int(np.searchsorted(h_grid, h_hat)) - 1
    h0, h1 = h_grid[i], h_grid[i + 1]
    a = (h_hat - h0) / (h1 - h0)

    return (1.0 - a) * Ks[i] + a * Ks[i + 1]


# -----------------------------
# Export to C header (Teensy-friendly)
# -----------------------------
def export_c_header(path: str, h_grid: np.ndarray, Ks: np.ndarray, name_prefix="LQR"):
    """
    Write the gain-scheduling table as a C header.

    The header defines ``<prefix>_N`` (number of grid points),
    ``<prefix>_H[N]`` (height grid) and ``<prefix>_K[N][rows][cols]`` (gain
    matrices), all as ``static const`` with values printed to 6 decimals.
    The array dimensions are taken from ``Ks.shape``.

    Args:
        path: output file path; an existing file is overwritten.
        h_grid: 1-D height grid with N entries.
        Ks: gain table of shape (N, rows, cols).
        name_prefix: prefix for the generated C identifiers.

    Raises:
        ValueError: if ``h_grid`` is not a non-empty 1-D array, if ``Ks`` is
            not 3-D with one matrix per grid point, or if either contains
            NaN/inf (which cannot be written as valid C float literals).
    """
    h_grid = np.asarray(h_grid, dtype=float)
    Ks = np.asarray(Ks, dtype=float)

    # Validate everything before touching the file so a bad table never
    # leaves a truncated header behind.
    if h_grid.ndim != 1 or h_grid.size == 0:
        raise ValueError("h_grid must be a non-empty 1-D array")
    if Ks.ndim != 3 or Ks.shape[0] != h_grid.size:
        raise ValueError(
            f"Ks must have shape ({h_grid.size}, rows, cols), got {Ks.shape}"
        )
    if not (np.all(np.isfinite(h_grid)) and np.all(np.isfinite(Ks))):
        raise ValueError("h_grid and Ks must contain only finite values")

    n_points, n_rows, n_cols = Ks.shape
    h_values = ", ".join(f"{h:.6f}f" for h in h_grid)

    lines = [
        "// Auto-generated LQR gain scheduling table\n",
        "#pragma once\n\n",
        f"static const int {name_prefix}_N = {n_points};\n",
        f"static const float {name_prefix}_H[{n_points}] = {{{h_values}}};\n\n",
        f"static const float {name_prefix}_K[{n_points}][{n_rows}][{n_cols}] = {{\n",
    ]
    for gain_matrix in Ks:
        lines.append("  {\n")
        for gain_row in gain_matrix:
            row = ", ".join(f"{value:.6f}f" for value in gain_row)
            lines.append(f"    {{{row}}},\n")
        lines.append("  },\n")
    lines.append("};\n")

    with open(path, "w", encoding="utf-8") as f:
        f.writelines(lines)


# -----------------------------
# Example usage
# -----------------------------
if __name__ == "__main__":
    # Example robot: every model parameter is spelled out explicitly so the
    # run does not depend on the dataclass defaults.
    p = ModelParams(
        # inertial
        M_eff=18.0,
        Iyy=0.65,
        Iz=0.70,
        # geometry
        r_wheel=0.07,
        wheel_sep=0.27,
        # scheduling range
        h_min=0.16,
        h_max=0.36,
        # damping
        pitch_damping=0.35,
        v_damping=0.20,
        yaw_damping=0.40,
    )

    # One linearization per grid height; gains are interpolated in between.
    h_grid = np.linspace(p.h_min, p.h_max, num=9)

    # Optional extra penalty on the pitch angle to reduce drift/lean.
    extra_Q = np.diag([3.0, 0.0, 0.0, 0.0])

    Q, R, Ks = build_gain_table(
        p,
        h_grid,
        x_max=(0.70, 4.0, 1.0, 2.5),
        u_max=(1.25, 0.5),
        q_scale=1.0,
        r_scale=1.0,
        extra_Q=extra_Q,
    )

    # Show the weights and the gains at both ends of the schedule.
    print(f"Q=\n {Q}")
    print(f"R=\n {R}")
    print(f"K(h_min)=\n {Ks[0]}")
    print(f"K(h_max)=\n {Ks[-1]}")

    # Runtime example: look up the gain for a measured stance height.
    h_hat = 0.16
    K_hat = interp_K(h_hat, h_grid, Ks)
    print(f"K(h_hat={h_hat:.3f})=\n {K_hat}")

    # To export the table for the firmware, uncomment:
    # export_c_header("lqr_gain_table.h", h_grid, Ks, name_prefix="AMBULA")
    # print("Exported: lqr_gain_table.h")


Q=
 [[5.04081633 0.         0.         0.        ]
 [0.         0.0625     0.         0.        ]
 [0.         0.         1.         0.        ]
 [0.         0.         0.         0.16      ]]
R=
 [[0.64 0.  ]
 [0.   4.  ]]
K(h_min)=
 [[-7.20716540e+00 -4.38514387e-01 -9.01069566e-01  5.05413402e-17]
 [-3.00091194e-18 -3.36234355e-19  5.24480078e-18  4.72135955e-02]]
K(h_max)=
 [[-4.50028241e+00 -3.57729496e-01 -7.34247123e-01 -2.74647862e-16]
 [ 1.94910151e-17  8.03196629e-18  1.43325098e-16  4.72135955e-02]]
K(h_hat=0.160)=
 [[-7.20716540e+00 -4.38514387e-01 -9.01069566e-01  5.05413402e-17]
 [-3.00091194e-18 -3.36234355e-19  5.24480078e-18  4.72135955e-02]]
