In [64]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

"""
Ascento-style LQR gain scheduling by height h:
  - Precompute K(h) on a grid of heights.
  - Runtime: interpolate K at estimated height h_hat.
  - Control: u = -K(h_hat) * x
    where x = [theta, theta_dot, v_err, yaw_err]^T
          v_err = v - v_ref
          yaw_err = yaw_rate - yaw_rate_ref
  - Output torque (Nm) for left/right wheels: tau_L, tau_R

Reference:
- Ascento paper: multiple linearizations & linear interpolation of gain vs height,
  using u = -K(h_hat) x_hat, and wheel torques as input. :contentReference[oaicite:3]{index=3}
- Ambula workflow / report confirm torque control (FOC) at wheel layer. :contentReference[oaicite:4]{index=4} :contentReference[oaicite:5]{index=5}
"""

from dataclasses import dataclass
import numpy as np
from scipy.linalg import solve_continuous_are


# ===================== Model =====================
@dataclass
class ModelParams:
    # --- physical ---
    g: float = 9.81
    M_eff: float = 18.0        # effective mass (kg) for longitudinal accel in simplified model
    Iyy: float = 0.665         # pitch inertia about wheel axle (kg*m^2)
    Iz: float = 0.7            # yaw inertia (kg*m^2)
    r_wheel: float = 0.075     # wheel radius (m)
    wheel_sep: float = 0.27    # track width (m)

    # --- height -> COM height mapping (simple) ---
    # l(h) = l0 + alpha_l*(h - h0)
    l0: float = 0.26           # COM height above axle at reference height h0 (m)
    h0: float = 0.26           # reference stance height (m)
    alpha_l: float = 1.0       # slope (start with 1.0 if COM tracks stance height)

    # --- damping (tune) ---
    pitch_damping: float = 0.4
    yaw_damping: float = 0.5


def l_of_h(p: ModelParams, h: float) -> float:
    return float(p.l0 + p.alpha_l * (h - p.h0))


def build_AB(p: ModelParams, h: float):
    """
    State x = [theta, theta_dot, v_err, yaw_err]^T
      - theta     : pitch (rad)
      - theta_dot : pitch rate (rad/s)
      - v_err     : (v - v_ref) (m/s)
      - yaw_err   : (yaw_rate - yaw_rate_ref) (rad/s)

    Input u = [tau_sum, tau_diff]^T (Nm)
      - tau_sum  = tau_L + tau_R
      - tau_diff = tau_R - tau_L

    This is a compact engineering linear model (starting point).
    """
    l = l_of_h(p, h)

    A = np.zeros((4, 4), dtype=float)
    A[0, 1] = 1.0

    # inverted-pendulum dominant term + pitch damping
    A[1, 0] = (p.M_eff * p.g * l) / p.Iyy
    A[1, 1] = -(p.pitch_damping) / p.Iyy

    # v_err dynamics (very simplified): v_dot driven by tau_sum
    A[2, 2] = 0.0

    # yaw_err dynamics: yaw_rate_dot damped
    A[3, 3] = -(p.yaw_damping) / p.Iz

    B = np.zeros((4, 2), dtype=float)
    # theta_ddot driven by tau_sum about axle
    B[1, 0] = 1.0 / p.Iyy

    # v_dot driven by longitudinal force from wheel torque: F = tau_sum / r, a = F / M_eff
    B[2, 0] = 1.0 / (p.M_eff * p.r_wheel)

    # yaw_rate_dot from differential wheel force moment
    B[3, 1] = (p.wheel_sep / 2.0) * (1.0 / p.r_wheel) * (1.0 / p.Iz)

    return A, B


# ===================== LQR =====================
def lqr_continuous(A, B, Q, R):
    """
    Continuous-time LQR:
      minimize ∫ (x^T Q x + u^T R u) dt
      u = -Kx
    """
    P = solve_continuous_are(A, B, Q, R)
    K = np.linalg.solve(R, B.T @ P)
    K = K * 0.05
    return K, P


# ===================== Scheduling =====================
def build_K_table(p: ModelParams, Q: np.ndarray, R: np.ndarray, h_grid: np.ndarray):
    """
    Precompute K(h) for each height in h_grid.
    Returns K_table shape (Nh, 2, 4)
    """
    K_table = np.zeros((len(h_grid), 2, 4), dtype=float)
    for i, h in enumerate(h_grid):
        A, B = build_AB(p, float(h))
        K, _ = lqr_continuous(A, B, Q, R)
        K_table[i] = K
    return K_table


def interp_K(h: float, h_grid: np.ndarray, K_table: np.ndarray) -> np.ndarray:
    """
    Linear interpolation of K across height.
    """
    if h <= h_grid[0]:
        return K_table[0]
    if h >= h_grid[-1]:
        return K_table[-1]
    i = int(np.searchsorted(h_grid, h) - 1)
    t = (h - h_grid[i]) / (h_grid[i + 1] - h_grid[i])
    return (1.0 - t) * K_table[i] + t * K_table[i + 1]


# ===================== Config =====================
@dataclass
class LqrWeights:
    Q_theta: float = 250.0
    Q_theta_dot: float = 25.0
    Q_v_err: float = 5.0
    Q_yaw_err: float = 2.0
    R_tau_sum: float = 0.08
    R_tau_diff: float = 0.12


@dataclass
class ControlLimits:
    tau_max_per_wheel: float = 10.0     # Nm (จริง ๆ ตั้งตาม motor+FOC)
    tau_sum_max: float = 30.0           # Nm
    tau_diff_max: float = 20.0          # Nm
    theta_safe_rad: float = 0.7         # ตัดถ้าเอียงเกิน ~40deg


@dataclass
class SignConfig:
    # state sign normalization
    theta_sign: float = 1.0
    theta_dot_sign: float = 1.0
    v_sign: float = 1.0
    yaw_rate_sign: float = 1.0
    # output sign for motors (flip if motor torque direction is reversed)
    tau_left_sign: float = 1.0
    tau_right_sign: float = 1.0


# ===================== Controller =====================
class AscentoStyleLqrController:
    def __init__(self,
                 p: ModelParams,
                 weights: LqrWeights,
                 limits: ControlLimits,
                 sign: SignConfig,
                 h_grid: np.ndarray):
        self.p = p
        self.weights = weights
        self.limits = limits
        self.sign = sign

        self.h_grid = np.array(h_grid, dtype=float)

        Q = np.diag([weights.Q_theta, weights.Q_theta_dot, weights.Q_v_err, weights.Q_yaw_err])
        R = np.diag([weights.R_tau_sum, weights.R_tau_diff])

        self.Q = Q
        self.R = R

        # Precompute K-table (Ascento-style)
        self.K_table = build_K_table(self.p, self.Q, self.R, self.h_grid)

    def compute_wheel_torque(self,
                             theta: float,
                             theta_dot: float,
                             v: float,
                             yaw_rate: float,
                             v_ref: float = 0.0,
                             yaw_rate_ref: float = 0.0,
                             h_hat: float = 0.26):
        """
        Return (tau_L, tau_R) in Nm.
        """
        # normalize signs (important!)
        th = self.sign.theta_sign * float(theta)
        thd = self.sign.theta_dot_sign * float(theta_dot)
        vv = self.sign.v_sign * float(v)
        yr = self.sign.yaw_rate_sign * float(yaw_rate)

        # safety
        if abs(th) > self.limits.theta_safe_rad:
            return 0.0, 0.0

        # Ascento-style: reference injected as errors in state
        v_err = vv - float(v_ref)
        yaw_err = yr - float(yaw_rate_ref)

        x = np.array([th, thd, v_err, yaw_err], dtype=float)

        # interpolate scheduled K(h_hat)
        K_use = interp_K(float(h_hat), self.h_grid, self.K_table)

        # u = [tau_sum, tau_diff]
        u = -K_use @ x
        print(u)
        tau_sum = float(np.clip(u[0], -self.limits.tau_sum_max, self.limits.tau_sum_max))
        tau_diff = float(np.clip(u[1], -self.limits.tau_diff_max, self.limits.tau_diff_max))

        # convert to wheel torques
        tau_L = 0.5 * (tau_sum - tau_diff)
        tau_R = 0.5 * (tau_sum + tau_diff)

        # clamp per wheel
        tau_L = float(np.clip(tau_L, -self.limits.tau_max_per_wheel, self.limits.tau_max_per_wheel))
        tau_R = float(np.clip(tau_R, -self.limits.tau_max_per_wheel, self.limits.tau_max_per_wheel))

        # output sign fix
        tau_L *= self.sign.tau_left_sign
        tau_R *= self.sign.tau_right_sign

        return tau_L, tau_R

    def debug_print_K(self):
        np.set_printoptions(precision=4, suppress=True)
        for i, h in enumerate(self.h_grid):
            print(f"\n--- K at h={h:.3f} ---")
            print(self.K_table[i])


# ===================== YOUR IO (replace these) =====================
def get_measurements():
    """
    Replace this with your estimator outputs.
    Returns:
      theta (rad), theta_dot (rad/s), v (m/s), yaw_rate (rad/s), h_hat (m)
    """
    # Example dummy values
    theta = 0.34906585
    theta_dot = 0.5
    v = 0.0
    yaw_rate = 0.0
    h_hat = 0.16
    return theta, theta_dot, v, yaw_rate, h_hat


def send_torque_left(tau_L: float):
    # Replace with: ODrive torque command / Teensy CAN message
    print(f"[CMD] tau_L = {tau_L:+.3f} Nm")


def send_torque_right(tau_R: float):
    # Replace with: ODrive torque command / Teensy CAN message
    print(f"[CMD] tau_R = {tau_R:+.3f} Nm")


# ===================== Main =====================
def main():
    # ---- params ----
    p = ModelParams(
        M_eff=18.0,
        Iyy=0.665,
        Iz=0.7,
        r_wheel=0.07,
        wheel_sep=0.27,
        l0=0.16,
        h0=0.16,
        alpha_l=1.0,
        pitch_damping=0.4,
        yaw_damping=0.5
    )

    weights = LqrWeights(
        Q_theta=100.0, Q_theta_dot=10.0, Q_v_err=5.0, Q_yaw_err=2.0,
        R_tau_sum=1.0, R_tau_diff=2.0
    )

    limits = ControlLimits(
        tau_max_per_wheel=10.0,
        tau_sum_max=30.0,
        tau_diff_max=20.0,
        theta_safe_rad=0.7
    )

    sign = SignConfig(
        theta_sign=1.0, theta_dot_sign=1.0, v_sign=1.0, yaw_rate_sign=1.0,
        tau_left_sign=1.0, tau_right_sign=1.0
    )

    # Ascento-style: use multiple heights; paper mentions using many height points and interpolate. :contentReference[oaicite:6]{index=6}
    # Choose h_grid spanning your real stance range.
    h_grid = np.array([0.16, 0.23, 0.26, 0.29, 0.32, 0.36], dtype=float)

    ctrl = AscentoStyleLqrController(p, weights, limits, sign, h_grid=h_grid)

    # optional debug
    print("Precomputed K-table (2x4) for each h in h_grid.")
    ctrl.debug_print_K()

    # ---- example command ----
    v_ref = 0.0       # m/s
    yaw_ref = 0.0     # rad/s

    # ---- one-step demo (replace with your control loop @ 100-400 Hz) ----
    theta, theta_dot, v, yaw_rate, h_hat = get_measurements()

    tau_L, tau_R = ctrl.compute_wheel_torque(
        theta=theta,
        theta_dot=theta_dot,
        v=v,
        yaw_rate=yaw_rate,
        v_ref=v_ref,
        yaw_rate_ref=yaw_ref,
        h_hat=h_hat
    )

    send_torque_left(tau_L)
    send_torque_right(tau_R)


if __name__ == "__main__":
    main()


Precomputed K-table (2x4) for each h in h_grid.

--- K at h=0.160 ---
[[ 3.7213  0.5624 -0.1118  0.    ]
 [ 0.      0.     -0.      0.0387]]

--- K at h=0.230 ---
[[ 5.092   0.6433 -0.1118 -0.    ]
 [-0.     -0.      0.      0.0387]]

--- K at h=0.260 ---
[[ 5.6746  0.6745 -0.1118  0.    ]
 [-0.      0.      0.      0.0387]]

--- K at h=0.290 ---
[[ 6.2549  0.7042 -0.1118 -0.    ]
 [-0.     -0.      0.      0.0387]]

--- K at h=0.320 ---
[[ 6.8331  0.7325 -0.1118  0.    ]
 [ 0.      0.     -0.      0.0387]]

--- K at h=0.360 ---
[[ 7.6012  0.7684 -0.1118 -0.    ]
 [-0.     -0.     -0.      0.0387]]
[-1.5802 -0.    ]
[CMD] tau_L = -0.790 Nm
[CMD] tau_R = -0.790 Nm
