In [1]:
import numpy as np
import rqcopt as oc
import sys
sys.path.append('..')
%load_ext autoreload
%autoreload 2
import jax
from jax import config
config.update("jax_enable_x64", True)

In [2]:
# Problem setup: Kitaev Liouvillians on N sites (local dimension d, rate gamma)
# and their exponentials for the time step tau.
from opentn.transformations import create_kitaev_liouvillians, exp_operator_dt, factorize_psd, super2choi

d, N, gamma = 2, 4, 1
tau = 4
dim = d**N
Lvec, Lvec_odd, Lvec_even, Lnn = create_kitaev_liouvillians(N=N, d=d, gamma=gamma)

# The odd layer is exponentiated for only tau/2: it appears twice in the
# odd-even-odd layer sequence used later (presumably a symmetric splitting).
superops_exp = [
    exp_operator_dt(op, dt, 'jax')
    for op, dt in zip([Lvec, Lvec_odd, Lvec_even], [tau, tau/2, tau])
]
exp_Lvec, exp_Lvec_odd, exp_Lvec_even = superops_exp

# Factorize the (PSD) Choi matrix of each layer, discarding contributions below tol.
tol = 1e-12
X1 = factorize_psd(psd=super2choi(exp_Lvec_odd), tol=tol)
X2 = factorize_psd(psd=super2choi(exp_Lvec_even), tol=tol)

In [3]:
# Sanity check: print ||Im(S)|| for each exponentiated superoperator S.
# All three are real, so (as the original note says) the Xi factors and the
# gradients built from them are expected to have zero imaginary part too.
for op in (exp_Lvec, exp_Lvec_odd, exp_Lvec_even):
    print(np.linalg.norm(op.imag))

0.0
0.0
0.0


In [4]:
# Next we want the factors X in (out, k, in) layout rather than (out, in, k).
# Step one: determine k, the numerical rank of X, for each layer.
k1, k2 = (np.linalg.matrix_rank(X) for X in (X1, X2))

def split_matrix_svd(A, chi_max=2, eps=1e-9):
    """
    Split a matrix by singular value decomposition, A ~= u @ diag(s) @ v,
    keeping only the largest singular values.

    Parameters
    ----------
    A : 2D array
        Matrix to decompose.
    chi_max : int or None
        Maximum number of singular values to keep. ``None`` means no cap, so
        only the ``eps`` cutoff applies.
    eps : float
        Singular values ``<= eps`` are discarded.

    Returns
    -------
    u : array, shape (m, chi)
        Kept left singular vectors (columns).
    s : array, shape (chi,)
        Kept singular values, in descending order.
    v : array, shape (chi, n)
        Kept right singular vectors (rows), i.e. the truncated V^dagger.
    """
    assert A.ndim == 2, "A must be a 2D matrix"
    u, s, v = np.linalg.svd(A, full_matrices=False)
    # np.linalg.svd returns the singular values sorted in descending order, so
    # keeping the largest ones is a plain leading slice (no argsort needed).
    n_above = int(np.sum(s > eps))
    chi_keep = n_above if chi_max is None else min(chi_max, n_above)
    assert chi_keep >= 1, "no singular value above eps (or chi_max < 1)"
    return u[:, :chi_keep], s[:chi_keep], v[:chi_keep, :]


def factorize_psd_truncated(psd, chi_max=2, eps=1e-9):
    """
    Factorize a PSD matrix as ``psd ~= x @ x.conj().T`` using a truncated SVD.

    Only the (at most ``chi_max``) singular values above ``eps`` are kept, so the
    returned ``x`` has shape ``(psd.shape[0], chi)``. This assumes ``psd`` is
    Hermitian positive semidefinite. In that case the left singular vectors are
    eigenvectors and the singular values are the eigenvalues.
    """
    x, s, _ = split_matrix_svd(psd, chi_max, eps)
    # Scale column j by sqrt(s[j]); equivalent to x @ np.diag(np.sqrt(s))
    # without building a dense diagonal matrix.
    return x * np.sqrt(s)


# Redo each layer's factorization with the truncated SVD, keeping as many
# singular values as the rank of the corresponding factor X. Then check that the
# new factor reproduces the same Gram matrix X X^dagger (prints one bool per layer).
xs_truncated = []
for X, op in zip((X1, X2), (exp_Lvec_odd, exp_Lvec_even)):
    k = np.linalg.matrix_rank(X)
    x_truncated = factorize_psd_truncated(psd=super2choi(op), chi_max=k)
    C_trnc = x_truncated @ x_truncated.conj().T
    print(np.allclose(C_trnc, X @ X.conj().T))
    xs_truncated.append(x_truncated)
x1_truncated, x2_truncated = xs_truncated

True
True


In [5]:
# Transform the factors X constructed above so they satisfy the orthogonality
# (isometry) condition.
def choi2ortho(x:np.ndarray):
    """
    Reshape a factor ``x`` of a Choi matrix into its orthogonal (isometric) form.

    ``x`` has shape ``(dim*dim, k)``, its rows indexing the (out, in) pair, so it
    is viewed as an ``(out, in, k)`` tensor. The result swaps the last two axes,
    giving ``(out, k, in)``, and flattens it to a ``(dim*k, dim)`` matrix.

    Raises
    ------
    ValueError
        If ``x`` is not 2D or ``x.shape[0]`` is not a perfect square.
    """
    if x.ndim != 2:
        raise ValueError(f"expected a 2D matrix, got shape {x.shape}")
    rows, k = x.shape
    # round() instead of int() truncation, so a sqrt result that lands slightly
    # below an integer cannot be floored to the wrong dimension
    dim = int(round(np.sqrt(rows)))
    if dim * dim != rows:
        raise ValueError(f"x.shape[0]={rows} is not a perfect square (dim*dim)")
    x = np.reshape(x, [dim, dim, k])
    return x.swapaxes(1, 2).reshape([dim * k, dim])

# Check that each reshaped factor is an isometry: x_ortho^dagger x_ortho = I.
for x_truncated in [x1_truncated, x2_truncated]:
    print(x_truncated.shape)
    x_ortho = choi2ortho(x_truncated)
    print(x_ortho.shape)
    # Compare against the identity of x_ortho's own column dimension rather than
    # relying on the global `dim` from the setup cell.
    print(np.allclose(x_ortho.conj().T @ x_ortho, np.eye(x_ortho.shape[1])))

(256, 4)
(64, 16)
True
(256, 2)
(32, 16)
True


In [73]:
# Goal: build the Riemannian gradient (and later the Hessian) as flat vectors.
#
# Step 1: the cost function. Two possible models:
#   option 1 (used here): compose superoperators with @            -> model_Ys
#   option 2: compose Choi matrices with the Choi composition function -> "model cs"
# The initial guess is the odd-even-odd layer sequence.
xi_init = [x1_truncated, x2_truncated, x1_truncated]
from opentn.optimization import frobenius_norm, model_Ys, compute_loss
print(compute_loss(xi=xi_init, loss_fn=frobenius_norm, model=model_Ys, exact=exp_Lvec))


def f(xi):
    """
    Cost as a function of the list of factors xi only.

    Uses the frobenius_norm loss between model_Ys(xi) and the exact target
    exp_Lvec, which is fixed here rather than passed in.
    """
    return frobenius_norm(model_Ys(xi), exp_Lvec)


print(f(xi_init))
# The retraction is left for later; for now focus on the gradient and Hessian.
# The Euclidean gradient is obtained with jax.

from opentn.transformations import vectorize

# The Riemannian gradient is the Euclidean gradient projected onto the tangent space.
def project(X, Z):
    """
    Project Z onto the tangent space of the manifold at X.

    Computes Z - X sym(X^dagger Z), where sym(A) = (A + A^dagger) / 2.
    """
    XhZ = X.conj().T @ Z
    return Z - X @ ((XhZ + XhZ.conj().T) / 2)

def rgrad_f(xi):
    """
    Riemannian gradient of f at xi, as a list aligned with xi.

    Each entry is that factor's Euclidean gradient (from jax.grad) projected
    onto the tangent space at the factor.
    """
    euclid_grads = jax.grad(f)(xi)
    return list(map(project, xi, euclid_grads))

def rgrad_f_vec(xi):
    """
    Riemannian gradient of f at xi, flattened into a single 1D vector.

    Each factor's gradient is vectorized with `vectorize`, and the pieces are
    stacked in the order of xi.
    """
    pieces = [vectorize(g) for g in rgrad_f(xi)]
    stacked = np.vstack(pieces)
    return stacked.reshape(-1)

print(rgrad_f_vec(xi_init).shape)

# Next: the Hessian. As a first attempt, build it from the Riemannian metric below.
def metric(delta1, delta2, X):
    """
    Riemannian metric between tangent vectors delta1 and delta2 at X on the manifold:

        g_X(delta1, delta2) = Re tr(delta1^dagger (I - X X^dagger / 2) delta2)

    From https://arxiv.org/abs/2112.05176 eq. 24.

    Evaluated without forming the (dim x dim) matrix I - X X^dagger / 2, using
    tr(A^dagger B) = sum(conj(A) * B). Only array methods and operators are used
    (no np.trace), so it also works on jax arrays, e.g. inside jax.grad.
    """
    gamma_delta2 = delta2 - 0.5 * (X @ (X.conj().T @ delta2))
    return (delta1.conj() * gamma_delta2).sum().real

def hvp(xi, v):
    """
    Hessian-vector product of f along direction v, for the first factor xi[0].

    The first factor is perturbed along v while all other factors are held fixed.
    A single forward-mode jax.jvp through rgrad_f gives the directional
    derivative of the Riemannian gradient field, which is then projected onto
    the tangent space at xi[0].

    For a submanifold with the Euclidean (embedded) metric, this projected
    derivative is the Riemannian Hessian applied to v. NOTE(review): for the
    metric of arXiv:2112.05176 eq. 24 (see `metric`), extra correction terms may
    be needed; confirm before using this in a trust-region solver.

    Cf. https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#hessian-vector-products-with-grad-of-grad

    Args:
        xi: list of factor matrices.
        v: direction for xi[0]; must have the same shape and dtype as xi[0].
    Returns:
        Array shaped like xi[0].
    """
    xi = list(xi)
    # perturb only the first factor: zero tangents for every other factor
    tangents = [v] + [np.zeros_like(x) for x in xi[1:]]
    _, grad_dot = jax.jvp(rgrad_f, (xi,), (tangents,))
    return project(xi[0], grad_dot[0])



# v = np.zeros_like(xi_init[0].reshape(-1))
# v[0] = 1
# v = v.reshape(xi_init[0].shape)
# hvp(xi_init, v=v)

0.09591767023235143
0.09591767023235143
(2560,)


In [81]:
def hessian_riem(xi, v=None):
    """
    Riemannian Hessian of f at xi applied to direction v.

    v is a list aligned with xi. The Riemannian gradient field rgrad_f is
    differentiated in forward mode (jax.jvp) along v, and each component is
    projected onto the tangent space at the corresponding factor.

    v defaults to xi itself, the direction the original version used.
    NOTE(review): xi is generally not a tangent vector at xi, so pass an
    explicit tangent direction for a meaningful Hessian-vector product.

    Returns a list aligned with xi.
    """
    xi = list(xi)
    v = xi if v is None else list(v)
    # Use the *tangent* output of jvp (the directional derivative of the
    # gradient), not the primal output, which is just rgrad_f(xi) itself.
    _, grad_dot = jax.jvp(rgrad_f, (xi,), (v,))
    return [project(X, Z) for X, Z in zip(xi, grad_dot)]

def hessian_riem_vec(xi):
    """
    Flatten hessian_riem(xi) into a single 1D vector.

    Each factor is vectorized with `vectorize`, and the factors are stacked in order.
    """
    blocks = [vectorize(h) for h in hessian_riem(xi)]
    return np.vstack(blocks).reshape(-1)



(2560,)

(256, 4)
(256, 2)
(256, 4)


In [42]:
# Metric between the Riemannian gradient of the first factor and a direction v,
# at xi_init[0]. v is defined here so the cell runs on a fresh kernel: a one-hot
# direction shaped like xi_init[0] (the same construction as the commented-out
# snippet in the gradient cell).
# NOTE(review): the recorded output came from an earlier session's `v`; re-run to refresh it.
v = np.zeros_like(xi_init[0].reshape(-1))
v[0] = 1
v = v.reshape(xi_init[0].shape)
metric(rgrad_f(xi_init)[0], v, xi_init[0])

-7.311165756268884

In [29]:
# Metric between the Euclidean (jax) gradients of the first and last factors,
# evaluated at the base point xi_init[0].
grads = jax.grad(f)(xi_init)
metric(grads[0], grads[-1], xi_init[0])

1.7305042292865611

In [24]:
# Quick check of the flattening pattern: stacking (n, 1) column vectors of
# different lengths with np.vstack and then reshaping to 1D concatenates them.
a, b, c = (np.arange(n).reshape((n, 1)) for n in (5, 3, 2))
np.vstack([a, b, c]).reshape(-1)

array([0, 1, 2, 3, 4, 0, 1, 2, 0, 1])

In [None]:
# TODO: placeholder call with no arguments.
# NOTE(review): the arguments rqcopt expects are not visible here (presumably the
# cost, its Riemannian gradient/Hessian, a retraction and an initial point).
# Confirm against the rqcopt API; as written this cell will likely fail on Run All.
oc.riemannian_trust_region_optimize()

In [None]:
# TODO: placeholder call with no arguments, presumably the retraction intended
# for the list of factors. NOTE(review): confirm its expected arguments in rqcopt;
# as written this cell will likely fail on Run All.
oc.retract_unitary_list()