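Reference paper: <https://arxiv.org/pdf/2502.16020>

This notebook implements JAX-differentiable matrix square root (`sqrtm`) and matrix logarithm (`logm`) routines, uses them to build the barrier function `ope_barrier`, and times its value, Jacobian, and Hessian.

<hr>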

In [None]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.scipy as jsp
from toqito.rand import random_psd_operator

In [None]:
def sqrtm(A, num_iters=10):
    """Matrix square root via a fixed number of Denman–Beavers steps.

    Runs the product form of the Denman–Beavers iteration

        Y_0 = M_0 = A
        Y_{k+1} = Y_k (I + M_k^{-1}) / 2
        M_{k+1} = (I + (M_k + M_k^{-1}) / 2) / 2

    for exactly ``num_iters`` steps inside ``jax.lax.scan``. There is no
    convergence test, so the routine is jit-compatible and differentiable
    (it only uses matmul and inverse). Y_k tends to the principal square
    root for matrices with no eigenvalues on the closed negative real axis
    (see Higham, *Functions of Matrices*, ch. 6).

    Args:
        A: square matrix of floating or complex dtype, shape (n, n).
        num_iters: number of iteration steps. It is static under ``jax.jit``.
            The default of 10 is the previously hard-coded value.

    Returns:
        The approximation Y_{num_iters} of A^{1/2}, with the same shape and
        dtype as ``A``.
    """
    A = jnp.asarray(A)
    # Build the identity in A's dtype. With jax_enable_x64 on, a default
    # (float64) eye would promote float32/complex64 inputs inside the scan
    # body, and lax.scan rejects a carry whose dtype changes.
    eye = jnp.eye(A.shape[0], dtype=A.dtype)

    def db_step(carry, _):
        Y, M = carry
        M_inv = jnp.linalg.inv(M)
        Y_next = 0.5 * Y @ (eye + M_inv)
        M_next = 0.5 * (eye + 0.5 * (M + M_inv))
        return (Y_next, M_next), None

    (Y, _), _ = jax.lax.scan(db_step, (A, A), xs=None, length=num_iters)
    return Y


sqrtm = jax.jit(sqrtm, static_argnames="num_iters")

def logm(A, num_sqrts=30):
    """Matrix logarithm by repeated square roots (inverse scaling and squaring).

    With s = ``num_sqrts`` the method uses

        log(A) = 2**s log(A^{1/2**s}) ≈ 2**s (A^{1/2**s} - I).

    A^{1/2**s} - I is not formed directly, because it would lose accuracy to
    cancellation for large s. Instead the telescoping identity

        A^{1/2} - I = (A^{1/2**s} - I) * prod_{k=2}^{s} (I + A^{1/2**k})

    gives

        log(A) ≈ 2**s (A^{1/2} - I) P^{-1},  P = prod_{k=2}^{s} (I + A^{1/2**k}).

    Every square root is taken with ``sqrtm``.

    Args:
        A: square matrix, shape (n, n). Presumably it has no eigenvalues on
            the closed negative real axis, so the principal square roots
            exist (TODO confirm for the intended inputs).
        num_sqrts: s >= 1, the number of repeated square roots. It is static
            under ``jax.jit``. The default of 30 is the previously
            hard-coded value.

    Returns:
        Approximation of log(A), with the same shape as ``A``.

    Raises:
        ValueError: if ``num_sqrts < 1``.
    """
    if num_sqrts < 1:
        raise ValueError(f"num_sqrts must be a positive integer, got {num_sqrts!r}")
    A = jnp.asarray(A)
    # Identity in A's dtype so the scan carry keeps a fixed dtype even with
    # jax_enable_x64 on and float32/complex64 inputs.
    eye = jnp.eye(A.shape[0], dtype=A.dtype)
    # Float scale: a large Python int such as 2**31 can overflow when JAX
    # converts it to a fixed-width integer.
    scale = 2.0 ** num_sqrts

    A_half = sqrtm(A)   # A^{1/2}
    Z0 = A_half - eye   # A^{1/2} - I
    if num_sqrts == 1:
        # log(A) ≈ 2 (A^{1/2} - I). This path previously returned Z0
        # without the 2**s factor.
        return scale * Z0

    A_quarter = sqrtm(A_half)  # A^{1/4}, i.e. the k = 2 factor

    def accumulate(carry, _):
        root, prod = carry
        next_root = sqrtm(root)
        return (next_root, prod @ (eye + next_root)), None

    # After s - 2 steps, P = prod_{k=2}^{s} (I + A^{1/2**k}).
    (_, P), _ = jax.lax.scan(
        accumulate, (A_quarter, eye + A_quarter), xs=None, length=num_sqrts - 2
    )
    # Z0 @ inv(P) without forming the inverse: (Z0 P^{-1})^T = P^{-T} Z0^T.
    return scale * jnp.linalg.solve(P.T, Z0.T).T


logm = jax.jit(logm, static_argnames="num_sqrts")

In [None]:
def barrier(X):
    """Returns -log det X.

    Computed as -(log(sign) + log|det X|) from ``jnp.linalg.slogdet``, i.e.
    the principal branch of -log det X:

    * positive real determinant (e.g. positive definite X): -log|det X|;
    * exactly singular X (slogdet reports sign 0 and log|det| = -inf): +inf;
    * real X with negative determinant: nan (real log of -1), so values
      outside the domain are not silently finite;
    * complex X: a complex value. The expression is holomorphic, so
      ``jax.grad(barrier, holomorphic=True)`` returns -X^{-T}.
    """
    sign, logabsdet = jnp.linalg.slogdet(X)
    # log det X = log(sign) + log|det X|. The former `sign * logabsdet`
    # matches this only when sign == 1. It gave nan for singular X and a
    # finite wrong value for negative determinants, and for complex X it is
    # not holomorphic, which corrupts holomorphic second derivatives.
    return -(jnp.log(sign) + logabsdet)
def grad_barrier(X):
    """Gradient of the -log det barrier, taken as the negated inverse -X^{-1}.

    The unconstrained gradient of -log det X is -X^{-T}. It coincides with
    -X^{-1} when X is symmetric.
    """
    X_inv = jnp.linalg.inv(X)
    return jnp.negative(X_inv)
def hvp(X, V):
    """Directional derivative of ``grad_barrier`` at X in direction V.

    Forward-mode AD pushes the tangent V through ``grad_barrier``. Only the
    output tangent is returned; the primal -X^{-1} is discarded. Since
    d(-X^{-1})[V] = X^{-1} V X^{-1}, this is the Hessian of -log det applied
    to V.
    """
    _primal_out, tangent_out = jax.jvp(grad_barrier, (X,), (V,))
    return tangent_out


hvp_jit = jax.jit(hvp)

In [52]:
def ope_barrier(Z, X, Y):
    """Log-det barrier built from X^{1/2} log(X^{-1/2} Y X^{-1/2}) X^{1/2}.

    Returns

        -log det(Z + X^{1/2} log(X^{-1/2} Y X^{-1/2}) X^{1/2})
        - log det X - log det Y,

    where the square root and logarithm are taken with ``sqrtm``/``logm``.
    NOTE(review): presumably this is the barrier of the operator-perspective
    cone from the referenced paper; confirm against it.

    Args:
        Z, X, Y: square matrices of the same shape (n, n). Presumably
            Hermitian, with X and Y positive definite (TODO confirm).

    Returns:
        Scalar barrier value. It is complex when the inputs are complex.

    Each log-determinant is evaluated as log(sign) + log|det| from
    ``jnp.linalg.slogdet``, i.e. the principal log det. That expression is
    holomorphic, with derivative tr(M^{-1} dM), which is what
    ``jax.jacobian``/``jax.hessian`` with ``holomorphic=True`` require.
    """

    def neg_logdet(M):
        sign, logabsdet = jnp.linalg.slogdet(M)
        # The earlier `sign * logabsdet` form equals log det in value when
        # det > 0, but it is not holomorphic. For complex inputs that adds
        # a spurious term to the holomorphic Hessian.
        return -(jnp.log(sign) + logabsdet)

    X_half = sqrtm(X)
    X_half_inv = jnp.linalg.inv(X_half)
    perspective = X_half @ logm(X_half_inv @ Y @ X_half_inv) @ X_half
    return neg_logdet(Z + perspective) + neg_logdet(X) + neg_logdet(Y)
# First and second derivatives of ope_barrier with respect to all three
# matrix arguments (Z, X, Y). holomorphic=True treats the complex inputs as
# complex variables. jac returns a tuple of three Jacobians; hess returns a
# tuple of tuples of Hessian blocks.
jac, hess = (
    jax.jacobian(ope_barrier, argnums=(0, 1, 2), holomorphic=True),
    jax.hessian(ope_barrier, argnums=(0, 1, 2), holomorphic=True),
)

# Compiled versions of the value, Jacobian and Hessian.
ope_jit, jac_jit, hess_jit = map(jax.jit, (ope_barrier, jac, hess))

In [None]:
# Benchmark inputs: three random L x L PSD operators (complex, since
# is_real=False), drawn in the order X, Y, Z. No seed is set here, so the
# values presumably differ between runs.
L = 10
X, Y, Z = (random_psd_operator(L, is_real=False) for _ in range(3))

In [70]:
%timeit -n10 -r7 ope_jit(Z, X, Y).block_until_ready()

934 µs ± 91.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [71]:
%timeit -n10 -r7 jac_jit(Z, X, Y)[0].block_until_ready()

2.37 ms ± 202 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [72]:
%timeit -n10 -r7 hess_jit(Z, X, Y)[0][0].block_until_ready()

323 ms ± 10.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
