In [1]:
import sys
import os

sys.path.insert(0, os.path.abspath("."))
sys.path.append(os.path.abspath("../../"))

import desc
from desc.objectives import *
from desc.basis import ZernikePolynomial, FourierZernikeBasis, DoubleFourierSeries
from desc.transform import Transform
from desc.equilibrium import EquilibriaFamily, Equilibrium
from desc.plotting import (
    plot_1d,
    plot_2d,
    plot_3d,
    plot_comparison, 
    plot_section, 
    plot_fsa, 
    plot_surfaces,
)
from desc.optimize import Optimizer
from desc.perturbations import *
import numpy as np
np.set_printoptions(linewidth=np.inf)

from desc.geometry import (
    FourierRZToroidalSurface, 
    ZernikeRZToroidalSection, 
    SplineXYZCurve, 
    FourierXYZCurve,
)
%matplotlib inline
from desc.utils import copy_coeffs
import matplotlib.pyplot as plt
from desc.examples import get
from desc.objectives.getters import (
    get_fixed_boundary_constraints, 
    maybe_add_self_consistency,
)
from desc.grid import LinearGrid, QuadratureGrid, Grid
import plotly.graph_objects as go
from desc.profiles import PowerSeriesProfile
from desc.__main__ import main
from desc.vmec_utils import vmec_boundary_subspace
from desc.compute import data_index
from desc.coils import SplineXYZCoil
from desc.equilibrium import Equilibrium
from desc.geometry import FourierRZToroidalSurface
from desc.objectives import (
    ObjectiveFunction,
    ForceBalance,
    get_fixed_boundary_constraints,
)
from desc.optimize import Optimizer
from desc.plotting import plot_1d, plot_section, plot_surfaces
from desc.profiles import PowerSeriesProfile
from desc.examples import get
from desc.grid import LinearGrid
import plotly.graph_objects as go

from scipy.linalg import qr_insert, qr

import jax.numpy as jnp
from jax.lax import fori_loop
from jax.lax import rsqrt
import jax

import functools


DESC version 0.12.1+5.g28e391e1e.dirty,using JAX backend, jax version=0.4.28, jaxlib version=0.4.28, dtype=float64
Using device: CPU, with 9.95 GB available memory


# The one was on the PR but didn't merged in

In [None]:
# TODO: add references to the docstrings
def _givens_jax(a, b):
    """Compute Givens rotation matrix.

    Compute the Givens rotation matrix G2 that zeros out the second element
    of a 2-vector.
        G2*[a; b] = [r; 0]
        where r = sqrt(a^2 + b^2)
        G2 = [[c, -s], [s, c]]
    """
    # Taken from jax._src.scipy.sparse.linalg._givens_rotation
    b_zero = abs(b) == 0
    a_lt_b = abs(a) < abs(b)
    t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
    r = rsqrt(1 + abs(t) ** 2).astype(t.dtype)
    cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
    sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
    G2 = jnp.array([[cs, -sn], [sn, cs]])
    return G2.astype(float)


@jit
def update_qr_jax(A, w, q, r):
    """Update QR factorization with a diagonal matrix w at the bottom."""
    m, n = A.shape
    Q = jnp.eye(m + n)
    Q = Q.at[:m, :m].set(q)

    R = jnp.vstack([r, w])

    def body_inner(i, jQR):
        j, Q, R = jQR
        i = m + j - i
        a, b = R[i - 1, j], R[i, j]
        G2 = _givens_jax(a, b)
        R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])])
        Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T)
        return j, Q, R

    def body(j, QR):
        Q, R = QR
        j, Q, R = fori_loop(0, m, body_inner, (j, Q, R))
        return Q, R

    Q, R = fori_loop(0, n, body, (Q, R))
    R = jnp.where(jnp.abs(R) < 1e-10, 0, R)

    return Q, R


@jit
def update_qr_jax_eco(A, w, q, r):
    """Update QR factorization with a diagonal matrix w at the bottom."""
    m, n = A.shape
    Q = jnp.eye(m + n)
    Q = Q.at[:m, :m].set(q)

    R = jnp.vstack([r, w])

    def body_inner(i, jQR):
        j, Q, R = jQR
        i = m + j - i
        a, b = R[i - 1, j], R[i, j]
        G2 = _givens_jax(a, b)
        R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])])
        Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T)
        return j, Q, R

    def body(j, QR):
        Q, R = QR
        j, Q, R = fori_loop(0, m, body_inner, (j, Q, R))
        return Q, R

    Q, R = fori_loop(0, n, body, (Q, R))
    R = jnp.where(jnp.abs(R) < 1e-10, 0, R)

    Re = R.at[: R.shape[1], : R.shape[1]].get()
    Qe = Q.at[:, : R.shape[1]].get()

    return Qe, Re

In [3]:

# def _givens_jax(a, b):
#     """Compute Givens rotation matrix.

#     Compute the Givens rotation matrix G2 that zeros out the second element
#     of a 2-vector.
#         G2*[a; b] = [r; 0]
#         where r = sqrt(a^2 + b^2)
#         G2 = [[c, -s], [s, c]]
#     """
#     r = jnp.sqrt(a**2 + b**2)
#     c = a / r
#     s = -b / r

#     G2 = jnp.array([[c, -s], [s, c]])
#     return G2.astype(float)

def _givens_jax(a, b):
    b_zero = abs(b) == 0
    a_lt_b = abs(a) < abs(b)
    t = -jnp.where(a_lt_b, a, b) / jnp.where(a_lt_b, b, a)
    r = rsqrt(1 + abs(t) ** 2).astype(t.dtype)
    cs = jnp.where(b_zero, 1, jnp.where(a_lt_b, r * t, r))
    sn = jnp.where(b_zero, 0, jnp.where(a_lt_b, r, r * t))
    G2 = jnp.array([[cs, -sn], [sn, cs]])
    return G2.astype(float)

@jax.jit
def update_qr_jax(A, w, q, r):
    """Update QR factorization with a diagonal matrix w at the bottom."""
    m, n = A.shape
    Q = jnp.eye(m + n)
    Q = Q.at[:m, :m].set(q)

    R = jnp.vstack([r, w])

    def body_inner(i, jQR):
        j, Q, R = jQR
        i = m + j - i
        a, b = R[i - 1, j], R[i, j]
        G2 = _givens_jax(a, b)
        R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])])
        Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T)
        return j, Q, R

    def body(j, QR):
        Q, R = QR
        j, Q, R = fori_loop(0, m, body_inner, (j, Q, R))
        return Q, R

    Q, R = fori_loop(0, n, body, (Q, R))
    R = jnp.where(jnp.abs(R) < 1e-10, 0, R)

    return Q, R


@jax.jit
def update_qr_jax_eco(A, w, q, r):
    """Update QR factorization with a diagonal matrix w at the bottom."""
    m, n = A.shape
    Q = jnp.eye(m + n)
    Q = Q.at[:m, :m].set(q)

    R = jnp.vstack([r, w])

    def body_inner(i, jQR):
        j, Q, R = jQR
        i = m + j - i
        a, b = R[i - 1, j], R[i, j]
        G2 = _givens_jax(a, b)
        R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])])
        Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T)
        return j, Q, R

    def body(j, QR):
        Q, R = QR
        j, Q, R = fori_loop(0, m, body_inner, (j, Q, R))
        return Q, R

    Q, R = fori_loop(0, n, body, (Q, R))
    R = jnp.where(jnp.abs(R) < 1e-10, 0, R)

    Re = R.at[: R.shape[1], : R.shape[1]].get()
    Qe = Q.at[:, : R.shape[1]].get()

    return Qe, Re


@jax.jit
def update_qr_jax_eco_for(A, w, q, r):
    """Update QR factorization with a diagonal matrix w at the bottom."""
    m, n = A.shape
    Q = jnp.eye(m + n)
    Q = Q.at[:m, :m].set(q)

    R = jnp.vstack([r, w])


    for j in range(n):
        for i in range(m):
            i = m + j - i
            a, b = R[i - 1, j], R[i, j]
            G2 = _givens_jax(a, b)
            R = R.at[jnp.array([i - 1, i])].set(G2 @ R[jnp.array([i - 1, i])])
            Q = Q.at[:, jnp.array([i - 1, i])].set(Q[:, jnp.array([i - 1, i])] @ G2.T)

    R = jnp.where(jnp.abs(R) < 1e-10, 0, R)

    Re = R.at[: R.shape[1], : R.shape[1]].get()
    Qe = Q.at[:, : R.shape[1]].get()

    return Qe, Re


import functools

@jax.jit
def qr_house_jax(A):
    """Compute QR factorization of A using Householder reflections."""
    m, n = A.shape
    Q = np.eye(m)
    R = A.copy()

    @functools.partial(jax.jit, static_argnums=(0,))
    def body(j, QR):
        q, r = QR
        v = jnp.zeros(m)
        v = v.at[j:].set(r[:-m+j, j])
        v = v.at[j].add(jnp.sign(v[j]) * jnp.linalg.norm(v))
        v = v / jnp.linalg.norm(v)
        r = r.at[j:, j:].sub(2 * jnp.outer(v[j:], v[j:] @ r[j:, j:]))

        q = q.at[j:, :].sub(2 * jnp.outer(v[j:], v[j:] @ q[j:, :]))

    Q, R = fori_loop(0, n, body, (Q, R))
    return Q, R

def qr_householder(A):
    """Compute QR factorization of A using Householder reflections."""
    m, n = A.shape
    Q = np.eye(m)
    R = A.copy()

    for j in range(n):
        v = np.zeros(m)
        v[j:] = R[j:, j]
        v[j] += np.sign(v[j]) * np.linalg.norm(v)
        v = v / np.linalg.norm(v)
        R[j:, j:] -= 2 * np.outer(v[j:], v[j:] @ R[j:, j:])
        Q[j:, :] -= 2 * np.outer(v[j:], v[j:] @ Q[j:, :])
        R[abs(R)<1e-10] = 0
        print(R)

    return Q, R

def qr_householder_desc(A):
    """Compute QR factorization of A using Householder reflections."""
    m, n = A.shape
    Q = np.eye(m)
    R = A.copy()

    for j in range(n):
        v = np.zeros(m)
        v[n:] = R[n:, j]
        v[n] += np.sign(v[n]) * np.linalg.norm(v)
        v = v / np.linalg.norm(v)
        R[j:, j:] -= 2 * np.outer(v[j:], v[j:] @ R[j:, j:])
        Q[j:, :] -= 2 * np.outer(v[j:], v[j:] @ Q[j:, :])
        R[abs(R)<1e-10] = 0
        print(R)

    return Q, R

import numpy as np

def zero_last_n_rows_householder(At):
    """Use Householder reflections to zero out the last n rows of At."""
    m_plus_n, n = At.shape
    Q = np.eye(m_plus_n)
    R = At.copy()

    for j in range(n):
        # Create the Householder vector
        x = R[j:, j]
        e = np.zeros_like(x)
        e[0] = np.copysign(np.linalg.norm(x), -R[j, j])
        v = x - e
        v = v / np.linalg.norm(v)
        
        # Apply the Householder transformation to R
        R[j:, :] -= 2 * np.outer(v, v @ R[j:, :])
        
        # Apply the Householder transformation to Q
        Q[:, j:] -= 2 * np.outer(Q[:, j:] @ v, v)
        R[abs(R)<1e-10] = 0
        print(R)
    
    return Q, R

# # Example usage
# m, n = 5, 3
# A = 100 * np.random.rand(m, n)
# A = np.triu(A)
# w = 2 * np.eye(n)
# At = np.vstack([A, w])

# Q, R = zero_last_n_rows_householder(At)
# print("Q:", Q)
# print("R:", R)


import jax
import jax.numpy as jnp

# @jax.jit
# def zero_last_n_rows_householder(At):
#     """Use Householder reflections to zero out the last n rows of At."""
#     m_plus_n, n = At.shape
#     Q = jnp.eye(m_plus_n)
#     R = At.copy()

#     for j in range(n):
#         # Create the Householder vector
#         x = R[j:, j]
#         e = jnp.zeros_like(x)
#         e = e.at[0].set(jnp.copysign(jnp.linalg.norm(x), -R[j, j]))
#         v = x - e
#         v = v / jnp.linalg.norm(v)
        
#         # Apply the Householder transformation to R
#         R = R.at[j:, :].set(R[j:, :] - 2 * jnp.outer(v, v @ R[j:, :]))
        
#         # Apply the Householder transformation to Q
#         Q = Q.at[:, j:].set(Q[:, j:] - 2 * jnp.outer(Q[:, j:] @ v, v))
#         R = jnp.where(jnp.abs(R) < 1e-10, 0, R)
#         # jax.debug.print("{R}", R=R)
    
#     return Q, R


@jax.jit
def zero_last_n_rows_householder(At):
    """Use Householder reflections to zero out the last n rows of At."""
    m_plus_n, n = At.shape
    Q = jnp.eye(m_plus_n)
    R = At.copy()

    def body(j, QR):
        Q, R = QR
        # Create the Householder vector
        x = jax.lax.dynamic_slice(R, (j, j), (m_plus_n - j, 1)).reshape(-1)
        e = jnp.zeros_like(x)
        e = e.at[0].set(jnp.copysign(jnp.linalg.norm(x), -R[j, j]))
        v = x - e
        v = v / jnp.linalg.norm(v)

        # Apply the Householder transformation to R
        R_sub = jax.lax.dynamic_slice(R, (j, j), (m_plus_n - j, n - j))
        R_sub = R_sub - 2 * jnp.outer(v, v @ R_sub)
        R = jax.lax.dynamic_update_slice(R, R_sub, (j, j))

        # Apply the Householder transformation to Q
        Q_sub = jax.lax.dynamic_slice(Q, (0, j), (m_plus_n, m_plus_n - j))
        Q_sub = Q_sub - 2 * jnp.outer(Q_sub @ v, v)
        Q = jax.lax.dynamic_update_slice(Q, Q_sub, (0, j))

        R = jnp.where(jnp.abs(R) < 1e-10, 0, R)
        return Q, R
    
    Q, R = fori_loop(0, n, body, (Q, R))
    return Q, R

# # Example usage
# m, n = 5, 3
# A = 100 * np.random.rand(m, n)
# A = jnp.triu(A)
# w = 2 * jnp.eye(n)
# At = jnp.vstack([A, w])

# Q, R = zero_last_n_rows_householder(At)
# print("Q:", Q)
# print("R:", R)

In [5]:
m, n = 10, 10

A = 100*np.random.rand(m, n)
A = np.triu(A)
w = 2*np.eye(n)
At = np.vstack([A, w])
Q, R = qr(At)
Qh, Rh = qr_householder(At)

print((np.abs(Q)-np.abs(Qh)).sum())
print((np.abs(R)-np.abs(Rh)).sum())

# q, r = qr(At, mode="economic")
# qj, rj = update_qr_jax_eco(A, w, Q, R)

# %timeit _, _ = qr(At, mode="economic")
# %timeit _, _ = qr_householder(At)
# %timeit _, _ = qr_house_jax(A)
# %timeit qj, rj = update_qr_jax_eco(A, w, Q, R)
# %timeit qj, rj = update_qr_jax_eco_for(A, w, Q, R)


# print(np.allclose(q@r, qj@rj))
# print(f"{q.shape=}, {r.shape=}")
# print(f"{qj.shape=}, {rj.shape=}")
# print(rj)
# print(r)
# print(qj)
# print(q)
# print((np.abs(q)-np.abs(qj)).sum())
# print((np.abs(r)-np.abs(rj)).sum())

[[ -2.7203903  -19.85354657 -43.00839413 -66.02147253 -10.72869863  -5.62531529 -64.47729787 -47.01398558 -39.6249137  -55.12738198]
 [  0.          28.58766309  36.1209074   98.4054998   29.63933786  88.16519209  33.22541632  29.55056901  69.97451965  50.92707038]
 [  0.           0.          49.28284417  24.05125309  75.28931787  52.06345881  77.16379404  94.31678695  72.15533443  46.00944234]
 [  0.           0.           0.           9.59209258  10.50175975  60.11156289  16.88977386  33.7254631   93.40056901  26.78074486]
 [  0.           0.           0.           0.          43.45930937  15.2231329   79.25069933  69.02497985  79.49674123  50.58805458]
 [  0.           0.           0.           0.           0.          87.65163402  91.85787165  63.27447278  64.02809586  55.95042079]
 [  0.           0.           0.           0.           0.           0.          30.38670856  38.73413928  19.57005006  72.60234899]
 [  0.           0.           0.           0.           0.           

In [None]:
def givens(a, b):
    if np.abs(b) < 1e-6:
        s = 0
        c = np.sign(a)
    elif np.abs(a) < 1e-6:
        c = 0
        s = -np.sign(b)
    elif np.abs(a) > np.abs(b):
        t = b / a
        u = np.sign(a) * np.sqrt(1 + t * t)
        c = 1 / u
        s = -c * t
    else:
        t = a / b
        u = np.sign(b) * np.sqrt(1 + t * t)
        s = -1 / u
        c = t / u
        
    G2 = np.array([[c, -s], 
                  [s, c]])
    return G2.astype(float)

In [None]:
if 0:
    A = np.array([[6, 5, 0, 4], 
                [5, 1, 4, 3], 
                [0, 4, 3, 2]])
    A_org = np.array([[6, 5, 0, 4], 
                [5, 1, 4, 3], 
                [0, 4, 3, 2]])
    print(A)
elif 1:
    A = np.array([[6, 5, 0], 
                [5, 1, 4], 
                [0, 4, 3],
                [2, 4, 3]])
    A_org = np.array([[6, 5, 0], 
                [5, 1, 4], 
                [0, 4, 3],
                [2, 4, 3]])
    print(A)
else:
    A = np.array([[6, 5, 0], 
                [5, 1, 4], 
                [0, 4, 3]])
    A_org = np.array([[6, 5, 0], 
                [5, 1, 4], 
                [0, 4, 3]])
    print(A)

A = A.astype(float)
m, n = A.shape
Q = np.eye(m)
R = A.copy()
for j in range(n):
    for i in range(m-1, j, -1):
        if R[i, j] != 0:
            print(f"There is none-zero element at ({i}, {j}) with value {R[i, j]}")
            a, b = R[i-1, j], R[i, j]
            G2 = givens(a, b)
            R[[i-1, i], j:] = G2@R[[i-1, i], j:]
            Q[:, [i-1, i]] = Q[:, [i-1, i]] @ G2.T
            # print(R)
            # print(Q)
print(np.dot(Q, R))           
q, r = qr(A_org)
# print(q@r)
print(np.allclose(A, r))


In [None]:
if 0:
    A = np.array([[6, 5, 0, 4], 
                [5, 1, 4, 3], 
                [0, 4, 3, 2]])
    A_org = np.array([[6, 5, 0, 4], 
                [5, 1, 4, 3], 
                [0, 4, 3, 2]])
    print(A)
elif 1:
    A = np.array([[6, 5, 0], 
                [5, 1, 4], 
                [0, 4, 3],
                [2, 4, 3]])
    A_org = np.array([[6, 5, 0], 
                [5, 1, 4], 
                [0, 4, 3],
                [2, 4, 3]])
    print(A)
else:
    A = np.array([[6, 5, 0], 
                [5, 1, 4], 
                [0, 4, 3]])
    A_org = np.array([[6, 5, 0], 
                [5, 1, 4], 
                [0, 4, 3]])
    print(A)

A = A.astype(float)

w = np.eye(A.shape[1])
A_t = np.vstack([A, w])
q_org, r_org = qr(A_org)

H = np.vstack([r_org, w])

J = np.eye(A_t.shape[0])
J[:A.shape[0], :A.shape[0]] = q_org

for j in range(H.shape[1]):
    for i in range(H.shape[0]-1, j, -1):
        if H[i, j] != 0:
            print(f"There is none-zero element at ({i}, {j}) with value {H[i, j]}")
            a, b = H[i-1, j], H[i, j]
            G2 = givens(a, b)
            H[[i-1, i], j:] = G2@H[[i-1, i], j:]
            J[:, [i-1, i]] = J[:, [i-1, i]] @ G2.T
            # print(H)
R = H
q, r = qr(A_t)

A_t_comp = J@R
A_t_comp[A_t_comp < 1e-10] = 0
print(A_t_comp)

In [None]:
import jax.numpy as jnp
from jax.lax import fori_loop
from jax import jit
from jax.scipy.linalg import block_diag
from scipy.linalg import qr

def givens_jax(a, b):
    r = jnp.sqrt(a**2 + b**2)
    c = a / r
    s = -b / r
        
    G2 = jnp.array([[c, -s], 
                  [s, c]])
    return G2.astype(float)

@jit
def update_qr_jax(A, w, q, r):
    m, n = A.shape
    Q = jnp.eye(m+n)
    Q = Q.at[:m, :m].set(q)

    R = jnp.vstack([r, w])

    def body_inner(i, jQR):
        j, Q, R = jQR
        i = m+j-i
        a, b = R[i-1, j], R[i, j]
        G2 = givens_jax(a, b)
        R = R.at[jnp.array([i-1, i])].set(G2@R[jnp.array([i-1, i])])
        Q = Q.at[:, jnp.array([i-1, i])].set(Q[:, jnp.array([i-1, i])] @ G2.T)
        return j, Q, R
    
    def body(j, QR):
        Q, R = QR
        j, Q, R = fori_loop(0, m, body_inner, (j, Q, R))
        return Q, R

    Q, R = fori_loop(0, n, body, (Q, R)) 

    return Q, R

A = jnp.array([[6, 5, 0], 
                [5, 1, 4], 
                [0, 4, 3]])
print(A)

A = A.astype(float)

w = jnp.eye(A.shape[1])
q, r = qr(A)
Q, R = update_qr_jax(A, w, q, r)
print(Q@R)
# print(q@r)

%timeit -n 1000 Q, R = update_qr_jax(A, w, q, r)
%timeit -n 1000 q, r = qr(jnp.vstack([A, w]))

[[6 5 0]
 [5 1 4]
 [0 4 3]]
[[ 6.00000000e+00  5.00000000e+00 -2.03266369e-18]
 [ 5.00000000e+00  1.00000000e+00  4.00000000e+00]
 [ 1.54074396e-33  4.00000000e+00  3.00000000e+00]
 [ 1.00000000e+00  1.73205888e-17  5.55111512e-17]
 [ 7.70371978e-34  1.00000000e+00 -3.46510183e-17]
 [-1.92592994e-34 -2.40741243e-35  1.00000000e+00]]
18.3 μs ± 1.3 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
143 μs ± 26.8 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## Some old version with for loop

In [None]:
import jax.numpy as jnp
from jax.lax import fori_loop
from jax import jit
from jax.scipy.linalg import block_diag
from scipy.linalg import qr

def givens_jax(a, b):
    r = jnp.sqrt(a**2 + b**2)
    c = a / r
    s = -b / r
        
    G2 = jnp.array([[c, -s], 
                  [s, c]])
    return G2.astype(float)

@jit
def update_qr_jax(A, w, q, r):
    m, n = A.shape
    Q = jnp.eye(m+n)
    Q = Q.at[:m, :m].set(q)

    R = jnp.vstack([r, w])

    for j in range(n):
        for i in range(m+j, j, -1):
            a, b = R[i-1, j], R[i, j]
            G2 = givens_jax(a, b)
            R = R.at[[i-1, i], j:].set(G2@R[[i-1, i], j:])
            Q = Q.at[:, [i-1, i]].set(Q[:, [i-1, i]] @ G2.T)
            
    return Q, R

A = jnp.array([[6, 5, 0], 
                [5, 1, 4], 
                [0, 4, 3]])
print(A)

A = A.astype(float)

w = jnp.eye(A.shape[1])
q, r = qr(A)
Q, R = update_qr_jax(A, w, q, r)

%timeit -n 100 Q, R = update_qr_jax(A, w, q, r)
%timeit -n 100 q, r = qr(jnp.vstack([A, w]))

In [None]:
@jit
def trust_region_step_exact_qr(
    Q, R, p_newton, f, J, trust_radius, initial_alpha=None, rtol=0.01, max_iter=10
):
    """Solve a trust-region problem using a semi-exact method.

    Solves problems of the form
        min_p ||J*p + f||^2,  ||p|| < trust_radius

    Parameters
    ----------
    f : ndarray
        Vector of residuals.
    J : ndarray
        Jacobian matrix.
    trust_radius : float
        Radius of a trust region.
    initial_alpha : float, optional
        Initial guess for alpha, which might be available from a previous
        iteration. If None, determined automatically.
    rtol : float, optional
        Stopping tolerance for the root-finding procedure. Namely, the
        solution ``p`` will satisfy
        ``abs(norm(p) - trust_radius) < rtol * trust_radius``.
    max_iter : int, optional
        Maximum allowed number of iterations for the root-finding procedure.

    Returns
    -------
    p : ndarray, shape (n,)
        Found solution of a trust-region problem.
    hits_boundary : bool
        True if the proposed step is on the boundary of the trust region.
    alpha : float
        Positive value such that (J.T*J + alpha*I)*p = -J.T*f.
        Sometimes called Levenberg-Marquardt parameter.

    """

    # def truefun(*_):
    #     print("TrueFun")
    #     return p_newton, False, 0.0

    # def falsefun(*_):
    #     print("FalseFun")
    #     alpha_upper = jnp.linalg.norm(J.T @ f) / trust_radius
    #     alpha_lower = 0.0
    #     alpha = setdefault(
    #         initial_alpha,
    #         jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
    #     )
    #     k = 0
    #     # algorithm 4.3 from Nocedal & Wright
    #     fp = jnp.pad(f, (0, J.shape[1]))

    #     def loop_cond(state):
    #         alpha, alpha_lower, alpha_upper, phi, k = state
    #         return (jnp.abs(phi) > rtol * trust_radius) & (k < max_iter)

    #     def loop_body(state):
    #         print("LoopBody")
    #         alpha, alpha_lower, alpha_upper, phi, k = state

    #         alpha = jnp.where(
    #             (alpha < alpha_lower) | (alpha > alpha_upper),
    #             jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
    #             alpha,
    #         )
    #         Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R)
    #         print(f"{Q2.shape=}, {R2.shape=} {J.shape=}")
    #         print("First triangular regularized")
    #         p = solve_triangular_regularized(R2, -Q2.T @ fp)
    #         p_norm = jnp.linalg.norm(p)
    #         phi = p_norm - trust_radius
    #         alpha_upper = jnp.where(phi < 0, alpha, alpha_upper)
    #         alpha_lower = jnp.where(phi > 0, alpha, alpha_lower)

    #         print("Second triangular regularized")

    #         q = solve_triangular_regularized(R2.T, p, lower=True)
    #         q_norm = jnp.linalg.norm(q)

    #         alpha += (p_norm / q_norm) ** 2 * phi / trust_radius
    #         alpha = jnp.where(
    #             (alpha < alpha_lower) | (alpha > alpha_upper),
    #             jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
    #             alpha,
    #         )
    #         k += 1
    #         return alpha, alpha_lower, alpha_upper, phi, k

    #     alpha, *_ = while_loop(
    #         loop_cond, loop_body, (alpha, alpha_lower, alpha_upper, jnp.inf, k)
    #     )
    #     print(f"Final QR update, {J.shape=}, {Q.shape=}, {R.shape=}")
    #     Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R)
    #     print(f"{Q2.shape=}, {R2.shape=} {J.shape=} {(-Q2.T @ fp).shape}")
    #     print("Final triangular solve")
    #     p = solve_triangular(R2, -Q2.T @ fp)

    #     # Make the norm of p equal to trust_radius; p is changed only slightly.
    #     # This is done to prevent p from lying outside the trust region
    #     # (which can cause problems later).
    #     p *= trust_radius / jnp.linalg.norm(p)

    #     return p, True, alpha

    # return cond(jnp.linalg.norm(p_newton) <= trust_radius, truefun, falsefun, None)

    if jnp.linalg.norm(p_newton) <= trust_radius:
        print("TrueFun")
        return p_newton, False, 0.0

    else:
        print("FalseFun")
        alpha_upper = jnp.linalg.norm(J.T @ f) / trust_radius
        alpha_lower = 0.0
        alpha = setdefault(
            initial_alpha,
            jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
        )
        k = 0
        # algorithm 4.3 from Nocedal & Wright
        fp = jnp.pad(f, (0, J.shape[1]))
        phi = jnp.inf
        while (jnp.abs(phi) > rtol * trust_radius) & (k < max_iter):
            print("LoopBody")
            alpha = jnp.where(
                (alpha < alpha_lower) | (alpha > alpha_upper),
                jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
                alpha,
            )
            print("Updating QR")
            # Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R)
            Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])])
            Q2, R2 = qr(Ji, mode="economic")
            print(f"{Q2.shape=}, {R2.shape=} {J.shape=}")
            print(R2)

            print("First triangular regularized")
            p = solve_triangular_regularized(R2, -Q2.T @ fp)
            p_norm = jnp.linalg.norm(p)
            phi = p_norm - trust_radius
            alpha_upper = jnp.where(phi < 0, alpha, alpha_upper)
            alpha_lower = jnp.where(phi > 0, alpha, alpha_lower)

            print("Second triangular regularized")

            q = solve_triangular_regularized(R2.T, p, lower=True)
            q_norm = jnp.linalg.norm(q)

            alpha += (p_norm / q_norm) ** 2 * phi / trust_radius
            alpha = jnp.where(
                (alpha < alpha_lower) | (alpha > alpha_upper),
                jnp.maximum(0.001 * alpha_upper, (alpha_lower * alpha_upper) ** 0.5),
                alpha,
            )
            k += 1

        print(f"Final QR update, {J.shape=}, {Q.shape=}, {R.shape=}")
        # Q2, R2 = update_qr_jax_eco(J, jnp.sqrt(alpha) * jnp.eye(J.shape[1]), Q, R)

        Ji = jnp.vstack([J, jnp.sqrt(alpha) * jnp.eye(J.shape[1])])
        Q2, R2 = qr(Ji, mode="economic")
        print(f"{Q2.shape=}, {R2.shape=} {J.shape=} {(-Q2.T @ fp).shape}")
        print("Final triangular solve")
        print(R2)
        p = solve_triangular(R2, -Q2.T @ fp)

        # Make the norm of p equal to trust_radius; p is changed only slightly.
        # This is done to prevent p from lying outside the trust region
        # (which can cause problems later).
        p *= trust_radius / jnp.linalg.norm(p)

        return p, True, alpha