In [None]:
# === SQP with Riccati inner loop (Lecture 5 Cases I–IV translated to code) ===
def cost_terms(z, N):
    """Evaluate the quadratic tracking cost, its gradient, and per-stage gradients.

    The decision vector is laid out as z = [x_0, ..., x_N, u_0, ..., u_{N-1}],
    with DIM_X entries per state and DIM_U entries per control (see the
    reshapes below).

    Stage cost (k < N):
        e_kᵀ Q e_k + (u_k - u_hover)ᵀ R (u_k - u_hover),  e_k = x_k - ref_k
    Terminal cost:
        e_Nᵀ Q_f e_N, where e_N = x_N - ref[-1] with its index-4 (angle)
        component replaced by wrap_angle(e_N[4] - 2π).

    Returns
    -------
    cost : float
    grad : ndarray, same layout as z
        d(cost)/dz using 2·Q·e_k, 2·R·(u_k - u_hover), 2·Q_f·e_N (exact when
        Q, R, Q_f are symmetric — assumed here; the wrap is treated as having
        unit slope).
    g_x_list, g_u_list : lists of length N with the per-stage state / control
        gradients.
    g_xN : terminal-state gradient 2·Q_f·e_N.
    x_blocks, u_blocks : (N+1, DIM_X) and (N, DIM_U) reshaped views of z.
    """
    n_x = (N + 1) * DIM_X
    x_blocks = z[:n_x].reshape(N + 1, DIM_X)
    u_blocks = z[n_x:].reshape(N, DIM_U)

    ref = reference_trajectory(N)
    cost = 0.0
    grad = np.zeros_like(z)

    g_x_list = []
    g_u_list = []

    for k in range(N):
        x_err = x_blocks[k] - ref[k]
        u_err = u_blocks[k] - u_hover

        cost += x_err @ Q @ x_err + u_err @ R @ u_err

        g_x = 2.0 * Q @ x_err
        g_u = 2.0 * R @ u_err
        g_x_list.append(g_x)
        g_u_list.append(g_u)

        grad[k * DIM_X : (k + 1) * DIM_X] += g_x
        grad[n_x + k * DIM_U : n_x + (k + 1) * DIM_U] += g_u

    # Subtraction yields a fresh array, so editing xN_err does not touch z.
    xN_err = x_blocks[-1] - ref[-1]
    # NOTE(review): the intended terminal target is presumably one full loop
    # (θ_N ≈ 2π); confirm how the extra -2π composes with ref[-1][4] and with
    # wrap_angle's range.
    xN_err[4] = wrap_angle(xN_err[4] - 2.0 * np.pi)
    cost += xN_err @ Q_f @ xN_err
    g_xN = 2.0 * Q_f @ xN_err
    # x_N occupies z[N*DIM_X:(N+1)*DIM_X]; the tail of z belongs to the
    # controls, so the terminal gradient must be placed by explicit offset.
    grad[N * DIM_X : n_x] += g_xN

    return cost, grad, g_x_list, g_u_list, g_xN, x_blocks, u_blocks


def linearize_dynamics(x_blocks, u_blocks):
    """Linearize the one-step dynamics around every (x_k, u_k) pair of a trajectory.

    State ordering (as labelled in the plotting cell): [p_x, v_x, p_y, v_y, θ, ω].

    Returns
    -------
    A_list, B_list, d_list : lists with one entry per row of u_blocks
        A_k (DIM_X × DIM_X) and B_k (DIM_X × DIM_U) are hand-coded Jacobians;
        d_k = step_dynamics(x_k, u_k) - x_{k+1} is the dynamics defect of the
        current guess, so the linearized constraint reads
        dx_{k+1} = A_k dx_k + B_k du_k + d_k.

    NOTE(review): A_k / B_k are written out by hand rather than derived from
    step_dynamics; they look like an explicit-Euler step with total thrust
    u_1 + u_2 — confirm they match step_dynamics.
    """
    A_list, B_list, d_list = [], [], []

    for k, u_k in enumerate(u_blocks):
        x_k = x_blocks[k]

        theta = x_k[4]
        s_th = np.sin(theta)
        c_th = np.cos(theta)
        thrust_over_mass = u_k.sum() / m

        A = np.eye(DIM_X)
        # Each position / angle integrates its rate over one step.
        for pos_idx, rate_idx in ((0, 1), (2, 3), (4, 5)):
            A[pos_idx, rate_idx] = Δt
        # Sensitivity of the velocity update to the attitude angle.
        A[1, 4] = -Δt * thrust_over_mass * c_th
        A[3, 4] = -Δt * thrust_over_mass * s_th

        B = np.zeros((DIM_X, DIM_U))
        B[1, :] = -Δt * s_th / m
        B[3, :] = Δt * c_th / m
        # Opposite-signed torque contributions from the two rotors.
        B[5, 0] = Δt * r / I
        B[5, 1] = -Δt * r / I

        defect = step_dynamics(x_k, u_k) - x_blocks[k + 1]

        A_list.append(A)
        B_list.append(B)
        d_list.append(defect)

    return A_list, B_list, d_list


def solve_stage_qp(z, g_x_list, g_u_list, g_xN, x_blocks, u_blocks):
    """Compute the SQP search direction by solving the stage QP with a Riccati recursion.

    The subproblem minimizes the quadratic cost model (Hessians 2Q, 2R, 2Q_f;
    gradients g_x_list, g_u_list, g_xN) subject to the linearized dynamics
    dx_{k+1} = A_k dx_k + B_k du_k + d_k from linearize_dynamics, with
    dx_0 = 0. A backward sweep builds the affine policy du_k = K_k dx_k + k_k;
    a forward rollout of that policy produces the step.

    ``z`` is accepted for interface compatibility but is not read.

    Returns
    -------
    step : ndarray laid out like z, i.e. [dx_0..dx_N, du_0..du_{N-1}], with the
        dx_0 block zeroed.
    """
    horizon = len(u_blocks)
    A_seq, B_seq, d_seq = linearize_dynamics(x_blocks, u_blocks)

    gains = [None] * horizon
    feedforwards = [None] * horizon

    # Terminal value function: V_N(dx) = ½ dxᵀ S dx + sᵀ dx.
    S = 2.0 * Q_f
    s = g_xN.copy()

    # ---- Backward sweep: t = N-1, ..., 0 ----
    for t in range(horizon - 1, -1, -1):
        A_t = A_seq[t]
        B_t = B_seq[t]
        # Slope of the next value function evaluated at the dynamics defect.
        v = s + S @ d_seq[t]

        q_x = g_x_list[t] + A_t.T @ v
        q_u = g_u_list[t] + B_t.T @ v
        q_xx = 2.0 * Q + A_t.T @ S @ A_t
        q_ux = B_t.T @ S @ A_t
        q_uu = 2.0 * R + B_t.T @ S @ B_t

        # q_uu = L Lᵀ; solving with L then Lᵀ applies q_uu⁻¹ to the rhs.
        chol = np.linalg.cholesky(q_uu)
        gain = -np.linalg.solve(chol.T, np.linalg.solve(chol, q_ux))
        ff = -np.linalg.solve(chol.T, np.linalg.solve(chol, q_u))
        gains[t] = gain
        feedforwards[t] = ff

        S_new = q_xx + gain.T @ q_uu @ gain + gain.T @ q_ux + q_ux.T @ gain
        s = q_x + gain.T @ q_uu @ ff + gain.T @ q_u + q_ux.T @ ff
        # Re-symmetrize so round-off does not accumulate across stages.
        S = 0.5 * (S_new + S_new.T)

    # ---- Forward rollout of the affine policy from dx_0 = 0 ----
    dx_traj = np.zeros_like(x_blocks)
    du_traj = np.zeros_like(u_blocks)

    dx = np.zeros(DIM_X)
    dx_traj[0] = dx
    for t in range(horizon):
        du = gains[t] @ dx + feedforwards[t]
        dx = A_seq[t] @ dx + B_seq[t] @ du + d_seq[t]
        du_traj[t] = du
        dx_traj[t + 1] = dx

    step = np.concatenate([dx_traj.reshape(-1), du_traj.reshape(-1)])
    step[:DIM_X] = 0.0  # keep initial state fixed via projection
    return step


def constraint_violation(z, N, x0=None, u_min=0.0, u_max=10.0):
    """Return the total (L1) constraint violation of the decision vector z.

    The sum covers four groups:
      * |x_0 - x0|                        (initial-state constraint)
      * |dynamics_residual(z, N)|         (dynamics defects)
      * max(0, u_min - u) elementwise     (lower thrust bound)
      * max(0, u - u_max) elementwise     (upper thrust bound)

    Parameters
    ----------
    z : stacked decision vector [x_0..x_N, u_0..u_{N-1}].
    N : number of control stages.
    x0 : initial state to compare x_0 against. Defaults to the module-level
        ``x_init``; pass it explicitly when optimizing from a different
        initial state, otherwise the violation can never reach zero.
    u_min, u_max : thrust bounds (defaults match the original hard-coded
        [0, 10]).

    Returns
    -------
    Scalar sum of all violation magnitudes.
    """
    if x0 is None:
        x0 = x_init
    u_blocks = z[(N + 1) * DIM_X :].reshape(N, DIM_U)

    violations = [
        np.abs(z[:DIM_X] - x0),
        np.abs(dynamics_residual(z, N)),
        np.maximum(0.0, u_min - u_blocks).ravel(),
        np.maximum(0.0, u_blocks - u_max).ravel(),
    ]
    return np.sum(np.concatenate(violations))


def filter_line_search(z, step, cost_curr, violation_curr, c_best, f_best,
                       metrics_fn, project_fn, ρ=0.5, α0=1.0, α_min=1e-6):
    """Backtracking line search with a two-number filter acceptance test.

    Trial points ``project_fn(z + α * step)`` are evaluated for
    α = α0, α0·ρ, α0·ρ², ... while α > α_min. The first trial whose cost is
    below the best cost so far (``f_best``) OR whose violation is below the
    best violation so far (``c_best``) is accepted.

    Parameters
    ----------
    z, step : current iterate and search direction.
    cost_curr, violation_curr : metrics at z (returned unchanged on failure).
    c_best, f_best : best violation / cost recorded so far.
    metrics_fn : callable mapping a candidate to (cost, violation).
    project_fn : callable applied to every trial point.
    ρ : backtracking factor; must lie strictly in (0, 1).
    α0 : initial step length.
    α_min : the search gives up once α drops to this value or below.

    Returns
    -------
    (z_new, α, f_best_new, c_best_new, cost_new, violation_new)
        On failure: (z, 0.0, f_best, c_best, cost_curr, violation_curr).

    Raises
    ------
    ValueError
        If ρ is not in (0, 1); otherwise α would never shrink below α_min and
        the loop would not terminate.
    """
    # Written as `not (...)` so a NaN ρ is rejected too.
    if not 0.0 < ρ < 1.0:
        raise ValueError(f"backtracking factor ρ must lie in (0, 1); got {ρ!r}")

    α = α0
    while α > α_min:
        z_trial = project_fn(z + α * step)
        cost_trial, violation_trial = metrics_fn(z_trial)
        # Comparisons with NaN are False, so a NaN trial is simply rejected.
        if cost_trial < f_best or violation_trial < c_best:
            return (z_trial, α,
                    min(f_best, cost_trial),
                    min(c_best, violation_trial),
                    cost_trial, violation_trial)
        α *= ρ
    return z, 0.0, f_best, c_best, cost_curr, violation_curr


def solve_sqp(x_init, N, max_iters=80, tol_grad=1e-4, tol_constr=1e-4, tol_step=1e-6):
    """Solve the trajectory-optimization NLP with an SQP / Riccati iteration.

    Decision vector layout: z = [x_0, ..., x_N, u_0, ..., u_{N-1}].

    Initial guess:
      * x_0 = x_init;
      * x_1..x_N taken from reference_trajectory(N);
      * every control set to u_hover.

    Each iteration solves the stage QP (solve_stage_qp) and globalizes the step
    with filter_line_search. Every trial point is projected so that
    x_0 = x_init and all thrusts lie in [0, 10].

    Parameters
    ----------
    x_init : initial state, pinned by the projection.
    N : number of control stages.
    max_iters : iteration cap.
    tol_grad : threshold on the inf-norm of the cost gradient over the free
        variables (everything except x_0).
    tol_constr : threshold on constraint_violation.
    tol_step : threshold on the inf-norm of the projected step
        ``project(z + step) - z``. For a constrained problem the cost gradient
        generally does not vanish at the optimum, so this step-based test is
        what normally detects convergence.

    Returns
    -------
    z : final iterate.
    history : list of per-iteration dicts. Keys:
        * "iter", "cost", "violation", "grad_inf", "step_inf", "alpha";
        * "accepted_cost" and "accepted_violation", for iterations that ran
          the line search.
    """
    n_x = (N + 1) * DIM_X
    z = np.zeros(n_x + N * DIM_U)
    z[:DIM_X] = x_init
    ref = reference_trajectory(N)
    for k in range(1, N + 1):
        z[k * DIM_X : (k + 1) * DIM_X] = ref[k]
    for k in range(N):
        z[n_x + k * DIM_U : n_x + (k + 1) * DIM_U] = u_hover

    def project(vec):
        """Pin x_0 to x_init and clip all controls into [0, 10]."""
        return np.concatenate([
            x_init,
            vec[DIM_X:n_x],
            np.clip(vec[n_x:], 0.0, 10.0),
        ])

    def metrics_fn(cand):
        """(cost, violation) pair used by the filter."""
        # NOTE(review): constraint_violation measures x_0 against the
        # module-level x_init, not this function's argument; if they differ
        # the violation cannot fall below tol_constr.
        return cost_terms(cand, N)[0], constraint_violation(cand, N)

    f_best = np.inf
    c_best = np.inf
    history = []

    for it in range(max_iters):
        cost, grad, g_x_list, g_u_list, g_xN, x_blocks, u_blocks = cost_terms(z, N)
        violation = constraint_violation(z, N)
        # x_0 is pinned by the projection, so its gradient block is a constant
        # that can never shrink; measure stationarity over the free variables.
        free_grad = grad[DIM_X:]
        grad_inf = np.linalg.norm(free_grad, np.inf) if free_grad.size else 0.0

        # Solve the QP before the convergence test: the length of the projected
        # step is the stationarity measure that actually goes to zero.
        step = solve_stage_qp(z, g_x_list, g_u_list, g_xN, x_blocks, u_blocks)
        step_inf = np.linalg.norm(project(z + step) - z, np.inf)

        history.append({
            "iter": it,
            "cost": cost,
            "violation": violation,
            "grad_inf": grad_inf,
            "step_inf": step_inf,
        })

        if violation < tol_constr and (grad_inf < tol_grad or step_inf < tol_step):
            history[-1]["alpha"] = 0.0
            break

        z_new, α, f_best, c_best, cost_new, violation_new = filter_line_search(
            z, step, cost, violation, c_best, f_best, metrics_fn, project
        )
        history[-1]["alpha"] = α
        history[-1]["accepted_cost"] = cost_new
        history[-1]["accepted_violation"] = violation_new

        if α == 0.0:
            print("Filter line search failed; terminating SQP early.")
            break

        z = z_new

    return z, history


# === Run SQP, verify looping conditions, and plot diagnostics ===
# Solve over a 100-step horizon, then split the optimal decision vector back
# into its state and control trajectories.
N = 100
z_opt, log = solve_sqp(x_init, N)

n_state_entries = (N + 1) * DIM_X
states = z_opt[:n_state_entries].reshape(N + 1, DIM_X)
controls = z_opt[n_state_entries:].reshape(N, DIM_U)
time_grid = np.arange(N + 1) * Δt

# Loop checks: terminal angle relative to one full revolution, terminal rate,
# and whether every thrust sample lies in [0, 10] up to a 1e-6 tolerance.
theta_final = wrap_angle(states[-1, 4] - 2.0 * np.pi)
omega_final = states[-1, 5]
above_floor = controls >= -1e-6
below_ceiling = controls <= 10.0 + 1e-6
thrust_ok = np.all(above_floor & below_ceiling)

print(f"Terminal θ error (rad): {theta_final:.3e}")
print(f"Terminal ω (rad/s):     {omega_final:.3e}")
print(f"Thrust bounds satisfied: {thrust_ok}")
print(f"Final position (p_x, p_y): {states[-1,0]:.3f}, {states[-1,2]:.3f}")

# --- State trajectories on a 3×2 grid (row-major over the state vector) ---
fig, axes = plt.subplots(3, 2, figsize=(10, 8), sharex=True)
labels = ["p_x", "v_x", "p_y", "v_y", "θ", "ω"]
for idx, label in enumerate(labels):
    row, col = divmod(idx, 2)
    ax = axes[row, col]
    ax.plot(time_grid, states[:, idx])
    ax.set_ylabel(label)
for bottom_ax in axes[-1]:
    bottom_ax.set_xlabel("time [s]")
axes[0, 0].set_title("State trajectories")
plt.tight_layout()

# --- Rotor thrusts (one sample per control interval) with bound lines ---
plt.figure(figsize=(6, 3))
for channel, channel_label in enumerate(["u₁", "u₂"]):
    plt.plot(time_grid[:-1], controls[:, channel], label=channel_label)
for bound in (0.0, 10.0):
    plt.axhline(bound, color="k", linestyle="--", linewidth=0.8)
plt.xlabel("time [s]")
plt.ylabel("thrust")
plt.legend()
plt.title("Control inputs")
plt.tight_layout()

# --- SQP convergence diagnostics ---
iters = [rec["iter"] for rec in log]
costs = [rec["cost"] for rec in log]
violations = [rec["violation"] for rec in log]
alphas = [rec.get("alpha", 0.0) for rec in log]

diagnostic_panels = (
    (costs, "Cost per SQP iter", False),
    (violations, "Constraint violation", True),
    (alphas, "Filter α", False),
)
plt.figure(figsize=(10, 3))
for slot, (series, panel_title, log_scale) in enumerate(diagnostic_panels, start=1):
    plt.subplot(1, 3, slot)
    plt.plot(iters, series, marker="o")
    if log_scale:
        plt.yscale("log")
    plt.title(panel_title)
    plt.xlabel("iteration")
plt.tight_layout()

quadrotor.animate_robot(states.T, controls.T)