In [1]:
import numpy as np
from scipy import optimize, stats, linalg
import seaborn as sns
import pymc as pm
import matplotlib.pyplot as plt
import pandas as pd

In [3]:
# Scratch random inputs; `X` is reassigned before the descent loop further down.
X, x = np.random.randn(10, 10), np.random.randn(10)

In [2]:
def cov_to_corr(cov):
    """Convert a covariance matrix into the matching correlation matrix.

    ``corr[i, j] = cov[i, j] / std[j] / std[i]`` with ``std = sqrt(diag(cov))``.

    Parameters
    ----------
    cov : array_like, shape (N, N)
        Covariance matrix. Integer (or list) input is accepted and promoted to
        float; the previous in-place ``/=`` raised for integer arrays.

    Returns
    -------
    numpy.ndarray, shape (N, N)
        A new array; ``cov`` itself is never modified.
    """
    cov = np.asarray(cov)
    stds = np.sqrt(np.diag(cov))
    # Same division order as before (columns, then rows) so float results are unchanged.
    return cov / stds[None, :] / stds[:, None]

def generate_covariance(N, n_vals_non_id, vals_non_id_log_dist, diag_log_dist, lamda_id, eigvec_dist,
                        random_state=None):
    """Draw a random SPD covariance: low-rank subspace + isotropic rest, then rescaled.

    A random ``n_vals_non_id``-dimensional subspace with orthonormal basis ``Q`` gets
    eigenvalues ``vals = exp(vals_non_id_log_dist.rvs(n_vals_non_id))``. Its orthogonal
    complement gets eigenvalue ``lamda_id``. Rows and columns are then multiplied
    by ``extra_std = sqrt(exp(diag_log_dist.rvs(N)))``.

    Parameters
    ----------
    N : int
        Dimension of the covariance matrix.
    n_vals_non_id : int
        Number of non-isotropic directions (assumed ``<= N``).
    vals_non_id_log_dist, diag_log_dist, eigvec_dist : frozen scipy.stats distributions
        Sources of the log-eigenvalues, the log diagonal rescaling, and the raw
        subspace entries respectively.
    lamda_id : float
        Eigenvalue of the complement of the subspace (before rescaling).
    random_state : None, int or numpy.random.Generator, optional
        ``None`` (default) keeps the original behaviour: the distributions' own
        random state and NumPy's global generator are used. Anything else is passed
        to ``np.random.default_rng`` and makes the draw reproducible.

    Returns
    -------
    cov : numpy.ndarray, shape (N, N)
    vals : numpy.ndarray, shape (n_vals_non_id,)
        Ascending subspace eigenvalues *before* the diagonal rescaling. They are
        therefore generally not eigenvalues of the returned ``cov``.
    """
    rng = None if random_state is None else np.random.default_rng(random_state)
    # Global-state fallback reproduces the original random stream when unseeded.
    standard_normal = np.random.standard_normal if rng is None else rng.standard_normal

    x = eigvec_dist.rvs((N, n_vals_non_id), random_state=rng)
    # Random sign flips, so one-sided eigvec_dist choices still give sign-symmetric entries.
    x = x * np.sign(standard_normal((N, n_vals_non_id)))
    Q, _ = linalg.qr(x, mode='economic')

    vals = np.exp(vals_non_id_log_dist.rvs(n_vals_non_id, random_state=rng))
    vals.sort()

    # (Q * vals) @ Q.T equals Q @ diag(vals) @ Q.T without building the dense diagonal.
    cov = (Q * vals) @ Q.T + lamda_id * (np.eye(N) - Q @ Q.T)

    extra_std = np.sqrt(np.exp(diag_log_dist.rvs(N, random_state=rng)))
    cov *= extra_std[:, None]
    cov *= extra_std[None, :]

    return cov, vals

def draw_samples(cov, n_draws, random_state=None):
    """Draw Gaussian samples around a random mean, plus the log-density gradients.

    Parameters
    ----------
    cov : array_like, shape (N, N)
        Symmetric positive-definite covariance.
    n_draws : int
        Number of samples.
    random_state : None, int or numpy.random.Generator, optional
        ``None`` (default) uses NumPy's global generator, as before. Anything else
        is passed to ``np.random.default_rng`` for reproducible draws.

    Returns
    -------
    draws : numpy.ndarray, shape (n_draws, N)
        Always 2-D, also for ``n_draws == 1`` (``rvs`` would squeeze that axis).
    grads : numpy.ndarray, shape (n_draws, N)
        ``-cov^{-1} (x - mean)`` for each draw ``x``, i.e. the gradient of the
        Gaussian log density.
    mean : numpy.ndarray, shape (N,)
        Standard-normal random mean used for the draws.
    """
    cov = np.asarray(cov)
    n_dim = len(cov)
    rng = None if random_state is None else np.random.default_rng(random_state)
    standard_normal = np.random.standard_normal if rng is None else rng.standard_normal

    mean = standard_normal(n_dim)
    dist = stats.multivariate_normal(mean=mean, cov=cov)
    draws = np.reshape(dist.rvs(n_draws, random_state=rng), (n_draws, n_dim))

    # Solve instead of inverting. `assume_a="pos"` replaces `sym_pos=True`,
    # which SciPy deprecated in 1.9 and removed in 1.11.
    grads = -linalg.solve(cov, (draws - mean).T, assume_a="pos").T
    return draws, grads, mean

In [3]:
from scipy.sparse import linalg as slinalg

# Seed NumPy's global generator so Restart & Run All is reproducible.
SEED = 0
np.random.seed(SEED)

# 1000-dim target covariance with 20 non-isotropic directions.
cov, _ = generate_covariance(1000, 20, stats.norm(0, 2), stats.norm(0, 2), 1, stats.norm(0, 1))

draws, grads, mean = draw_samples(cov, 80)

# Center draws and gradients across samples.
draws = draws - draws.mean(axis=0, keepdims=True)
grads = grads - grads.mean(axis=0, keepdims=True)

n_draws, N = draws.shape

# Per-coordinate scale sqrt(std(draws) / std(grads)). Draws are divided by it and
# gradients multiplied, so fitted matrices live in these rescaled coordinates.
stds = np.sqrt(draws.std(0) / grads.std(0))

draws_ = draws / stds
grads_ = grads * stds


In [4]:
# Squared affine-invariant distance between diag(std(draws) / std(grads)) and the
# true covariance: sum of squared logs of their generalized eigenvalues.
np.sum(
    np.log(
        linalg.eigvalsh(
            np.diag(np.std(draws, 0) / np.std(grads, 0)),
            cov,
        )
    ) ** 2
)

59.95140402785537

In [5]:
# Weight of the -log det(corr(X)) regulariser used by the cost/gradient functions.
alpha = 10 ** 3

def logdet(X):
    """Return log(det(X)) for a symmetric positive-definite matrix ``X``.

    Computed as the sum of the logs of the ``linalg.eigvalsh`` eigenvalues. A
    non-positive eigenvalue gives ``nan`` / ``-inf`` (with a NumPy warning).
    """
    return np.sum(np.log(linalg.eigvalsh(X)))

def cost(X):
    """Objective for the rescaled covariance ``X`` (SPD, shape (N, N)).

    cost(X) = tr(G X G^T) + tr(D X^{-1} D^T) - alpha * log det(corr(X)).
    ``G = grads_`` and ``D = draws_`` are the rescaled, centred gradients and draws
    (module globals). ``corr(X)`` is the correlation matrix of ``X``.

    Raises ``numpy.linalg.LinAlgError`` if ``X`` is not positive definite. Previously
    the result for such input was ``nan``.
    """
    stds = np.sqrt(np.diag(X))
    corr = X / stds[None, :] / stds[:, None]
    # tr(A B A^T) == sum((A @ B) * A): avoids forming the (n_draws, n_draws) product.
    grad_term = np.sum((grads_ @ X) * grads_)
    # tr(D X^{-1} D^T) via an SPD solve instead of an explicit inverse
    # (same assume_a="pos" choice as the JAX cost).
    draw_term = np.sum(draws_ * linalg.solve(X, draws_.T, assume_a="pos").T)
    return grad_term + draw_term - alpha * logdet(corr)

def r_grad(X):
    """Riemannian gradient of `cost` at SPD ``X`` (affine-invariant metric: X grad X).

    Returns the (N, N) matrix
        X G^T G X - D^T D + alpha * (X diag(1 / diag(X)) X - X),
    where ``G = grads_`` and ``D = draws_`` (module globals).
    """
    diag = np.diag(X)
    # X G^T G X as (X G^T)(G X): two thin products instead of an N x N x N matmul.
    curvature = (X @ grads_.T) @ (grads_ @ X)
    # X / diag == X @ diag(1 / diag) without building the dense diagonal.
    return curvature - draws_.T @ draws_ + alpha * ((X / diag) @ X - X)

In [71]:
# Expected (n_draws, N): one rescaled gradient per draw.
np.shape(grads_)

(80, 1000)

In [88]:
# Benchmark: projecting the (n_draws, N) rescaled gradients onto the columns of U.
# NOTE(review): `U` is only assigned in a later cell, so this fails on a fresh kernel.
%timeit grads_ @ U

27 µs ± 539 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [80]:
# Fortran-ordered transposes, to compare matmul speed by memory layout.
# np.asfortranarray makes no copy if its input is already F-contiguous.
U_tr = np.asfortranarray(U.T)
grads_tr = np.asfortranarray(grads_.T)

In [81]:
# Benchmark the transposed product (result is the transpose of grads_ @ U).
%timeit U_tr @ grads_tr

36.8 µs ± 1.37 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [91]:
# C-contiguous gradients times F-contiguous U: another layout combination to time.
grads__ = np.ascontiguousarray(grads_)
U__ = np.asfortranarray(U)

In [92]:
# Benchmark C-order @ F-order.
%timeit grads__ @ U__

26.8 µs ± 559 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [95]:
# Length-N all-ones test vector for matvec timings.
v = np.full(N, 1.0)

In [97]:
# Benchmark a matvec with draws_ in its default layout.
%timeit draws_ @ v

10.4 µs ± 904 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [99]:
# Fortran-ordered version of draws_, for the layout comparison in the next cell.
draws__ = np.asfortranarray(draws_)

In [100]:
# Same matvec with the F-ordered copy.
%timeit draws__ @ v

13.2 µs ± 183 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [102]:
# Baseline: cost of allocating a 0-d array from interpreted Python.
%timeit np.array(1)

286 ns ± 4.33 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [103]:
import numba

In [107]:
# Same allocation, one million times inside a numba-jitted loop.
@numba.njit
def alloc_stuff():
    """Allocate a 0-d array one million times (timing experiment; returns None)."""
    for _ in range(1_000_000):
        # NOTE(review): result is unused; numba may optimise the allocation away. Verify before trusting the timing.
        x = np.array(1)

# Warm-up call so JIT compilation is not included in the timing below.
alloc_stuff()

%timeit alloc_stuff()

19.6 ms ± 526 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [108]:
def cost_fact(U, s):
    """`cost` for the factored form X = I + U diag(expm1(s)) U^T, without forming X.

    Assumes the columns of ``U`` are orthonormal (TODO confirm for every caller).
    Then X^{-1} = I + U diag(expm1(-s)) U^T and log det X = sum(s). The constant
    terms ||grads_||^2 + ||draws_||^2 of the dense cost are dropped, since they do
    not depend on (U, s).
    """
    grad_proj = grads_ @ U
    draw_proj = draws_ @ U
    scale = np.expm1(s)
    inv_scale = np.expm1(-s)

    data_term = np.sum(grad_proj ** 2 * scale) + np.sum(draw_proj ** 2 * inv_scale)

    diag_of_X = np.sum(U ** 2 * scale, axis=1) + 1
    # -alpha * log det(corr(X)) = -alpha * (log det X - sum(log diag X))
    penalty = -alpha * (np.sum(s) - np.sum(np.log(diag_of_X)))
    return data_term + penalty


def r_grad_action(U, s, v):
    """Apply the Riemannian gradient at X = I + U diag(expm1(s)) U^T to a vector ``v``.

    Evaluates X G^T G X v - D^T D v - alpha X v + alpha X diag(1/diag(X)) X v,
    with G = ``grads_`` and D = ``draws_``. Only products with thin matrices are
    used; the dense (N, N) X is never formed.
    """
    scale = np.expm1(s)
    diag_of_X = np.sum(U ** 2 * scale, axis=1) + 1

    def apply_X(w):
        # X w = w + U diag(expm1(s)) U^T w
        return w + U @ (scale * (U.T @ w))

    Xv = apply_X(v)
    curvature = apply_X(grads_.T @ (grads_ @ Xv))
    data = draws_.T @ (draws_ @ v)
    return curvature - data - alpha * Xv + alpha * apply_X(Xv / diag_of_X)

In [109]:
# Keep only the first 4 columns of the current U for the X_mult timing below.
# NOTE(review): relies on U from another cell; re-running is a no-op once U has <= 4 columns.
U = U[..., :4]

In [110]:
l_m1 = np.expm1(s[:4])


def X_mult(u):
    """Return X @ u for X = I + U diag(l_m1) U^T (reads the globals U and l_m1).

    NOTE(review): a later cell redefines X_mult with the full-length l_m1.
    """
    projected = U.T @ u
    return u + U @ (l_m1 * projected)

In [111]:
# Length-N zero vector used to time X_mult.
u = np.zeros(N, dtype=float)

In [112]:
# Cost of one matrix-free X @ u with the 4-column U.
%timeit X_mult(u)

5.79 µs ± 61.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [120]:
call_count = 0  # number of retraction_action calls, i.e. matvecs requested by eigsh


def retraction_action(U, s, v, step_size):
    """Apply (X - step_size * grad) to ``v``, with X = I + U diag(expm1(s)) U^T.

    ``grad`` is the operator applied by ``r_grad_action``. Every call increments the
    global ``call_count``, so the number of matvecs an eigensolver needed can be
    inspected afterwards.
    """
    global call_count
    call_count = call_count + 1
    scale = np.expm1(s)
    Xv = v + U @ (scale * (U.T @ v))
    return Xv - step_size * r_grad_action(U, s, v)

In [8]:
import pymanopt
import jax.numpy as jnp
import jax.scipy.linalg as jlinalg

In [9]:
man = pymanopt.manifolds.SymmetricPositiveDefinite(N)


@pymanopt.function.jax(man)
def cost_manopt(X):
    """JAX version of the dense cost, for pymanopt on the SPD manifold.

    tr(G X G^T) + tr(D X^{-1} D^T) - alpha * log det(corr(X)), where the
    log-determinant is 2 * sum(log(diag(cholesky(corr(X))))).
    """
    stds = jnp.sqrt(jnp.diag(X))
    corr = X / stds[None, :] / stds[:, None]
    grad_term = jnp.trace(grads_ @ X @ grads_.T)
    draw_term = jnp.trace(draws_ @ jlinalg.solve(X, draws_.T, assume_a="pos"))
    chol_corr = jlinalg.cholesky(corr)
    return grad_term + draw_term - 2 * alpha * jnp.log(jnp.diag(chol_corr)).sum()

In [None]:
optimizer = pymanopt.optimizers.TrustRegions()
problem = pymanopt.Problem(man, cost_manopt)

# NOTE(review): this cell has no execution count and the next cell raised NameError
# for `result`, so it apparently never finished for N = 1000. TODO confirm.
result = optimizer.run(problem, initial_point=np.identity(N))

In [114]:
# Map the fitted matrix back from the rescaled coordinates (draws_ = draws / stds).
X_trust = result.point
cov_trust = X_trust * stds[None, :] * stds[:, None]

# Squared affine-invariant distance between the fitted and the true covariance.
np.sum(np.log(linalg.eigvalsh(cov_trust, cov)) ** 2)

NameError: name 'result' is not defined

In [11]:
stepsize = 1e-4  # fixed step along the negative Riemannian gradient
n_keep = 4       # eigenpairs kept when re-projecting X onto "identity + low rank"

X_trace = []
grad_trace = []
norm_trace = []

X = np.eye(N)

for it in range(10):
    X_trace.append(X)
    grad = r_grad(X)
    grad_trace.append(grad)

    # Squared Riemannian norm tr(X^{-1} g X^{-1} g), with a single inverse per step.
    X_inv_grad = linalg.inv(X) @ grad
    norm = np.trace(X_inv_grad @ X_inv_grad)

    X = X - stepsize * grad

    # Keep the extreme eigenpairs ("BE": half from each end of the spectrum) and
    # reset every other eigenvalue to 1: X <- I + V diag(vals - 1) V^T.
    vals, vecs = slinalg.eigsh(slinalg.aslinearoperator(X), k=n_keep, which="BE")
    X = (vecs * (vals - 1)) @ vecs.T + np.eye(N)

    norm_trace.append(norm)
    if norm < 1e-3:
        break

In [121]:
%%time

stepsize = 1e-4
k = 15

cost_trace = []
U_trace = []
s_trace = []

U = np.eye(N)[:, :k]
s = np.zeros(k)

for iter in range(1000):
    U_trace.append(U)
    s_trace.append(s)
    cost_trace.append(cost_fact(U, s))

    retract_operator = slinalg.LinearOperator((N, N), matvec=lambda v: retraction_action(U, s, v, stepsize))
    s, U = slinalg.eigsh(retract_operator, k=k, which="BE", v0=U.mean(1))
    #s, U = slinalg.eigsh(retract_operator, k=k, which="BE")
    s = np.log(s)

X = (U @ np.diag(np.expm1(s)) @ U.T + np.eye(N))

CPU times: user 7.74 s, sys: 134 ms, total: 7.87 s
Wall time: 7.89 s


In [125]:
# Total retraction_action calls (matvecs) since call_count was last reset.
call_count

32076

In [126]:
import warnings

# Factors of the final X = I + U diag(expm1(s)) U^T and its inverse
# X^{-1} = I + U diag(expm1(-s)) U^T (assumes orthonormal columns of U, as returned by eigsh).
l_m1 = np.expm1(s)
l_inv_m1 = np.expm1(-s)


def X_mult(u):
    """Return X @ u (reads the globals U and l_m1)."""
    return u + U @ (l_m1 * (U.T @ u))


def X_inv_mult(u):
    """Return X^{-1} @ u (reads the globals U and l_inv_m1)."""
    return u + U @ (l_inv_m1 * (U.T @ u))


# X^{-1} preconditions CG: for a small step, the retraction operator is close to X.
X_inv_operator = slinalg.LinearOperator((N, N), matvec=X_inv_mult)


def OPinv(x):
    """Solve ``retract_operator @ y = x`` by preconditioned CG (eigsh shift-invert, sigma=0).

    Warns instead of silently returning an unconverged iterate when CG reports failure.
    """
    y, info = slinalg.cg(retract_operator, x, M=X_inv_operator)
    if info != 0:
        warnings.warn(f"CG did not converge in OPinv (info={info})", RuntimeWarning)
    return y

In [133]:
#%%timeit
# Shift-invert variant (sigma=0, which="LM"): eigenvalues nearest 0, using OPinv for
# the inverse action. The result is discarded; this cell only compares against "BE".
_ = slinalg.eigsh(retract_operator, k=k, which="LM", sigma=0, OPinv=slinalg.LinearOperator((N, N), matvec=OPinv))

In [156]:
# Reset the matvec counter before a single eigsh solve.
call_count = 0

In [157]:
# One "BE" solve warm-started at v0 = U.mean(1), to count its matvecs.
_ = slinalg.eigsh(retract_operator, k=k, which="BE", v0=U.mean(1))

In [158]:
# Matvecs used by that single solve.
call_count

31

In [18]:
%%timeit
# Time the shift-invert variant, warm-started like the "BE" solves.
_ = slinalg.eigsh(retract_operator, k=k, which="LM", sigma=0, OPinv=slinalg.LinearOperator((N, N), matvec=OPinv), v0=U.mean(1))

211 ms ± 5.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [20]:
# Warm-started vs. cold-started "BE" solves.
%timeit slinalg.eigsh(retract_operator, k=k, which="BE", v0=U.mean(1))
%timeit slinalg.eigsh(retract_operator, k=k, which="BE")

6.07 ms ± 108 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
7.62 ms ± 144 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
from petsc4py import PETSc
from slepc4py import SLEPc

In [12]:
# NOTE(review): `n` is not referenced by any visible cell; candidate for removal.
n = 30

In [25]:
class MatrixFreeB:
    """Context for a PETSc ``PYTHON``-type matrix that applies a NumPy matvec.

    Attached with ``Mat.setPythonContext``. PETSc then calls :meth:`mult`
    to compute ``y <- A x``.
    """

    def __init__(self, matvec):
        # matvec: callable taking a 1-D NumPy array and returning the product A @ x.
        self.matvec = matvec

    def mult(self, mat, x, y):
        """Write ``matvec(x)`` into ``y`` in place. ``mat`` (the PETSc Mat) is unused."""
        y.array[:] = self.matvec(x.array)


# Matrix-free PETSc Mat whose action is the retraction operator at the current U, s.
B = PETSc.Mat().create()


# Build the matrix "context"
# (the lambda reads the globals U, s and stepsize each time PETSc applies B)
Bctx = MatrixFreeB(matvec=lambda v: retraction_action(U, s, v, stepsize))

# Set up B
# B is N x N. NOTE(review): the original comment referred to an `A` that does not exist in this notebook.
B.setSizes([N, N])

B.setType(B.Type.PYTHON)
B.setPythonContext(Bctx)
B.setUp()

<petsc4py.PETSc.Mat at 0x7faee12093a0>

In [56]:
# NOTE(review): leftover IPython `??` introspection; remove from the final notebook.
SLEPc.MFN??

Init signature: SLEPc.MFN(self, /, *args, **kwargs)
Docstring:      MFN
File:           ~/mambaforge/envs/pymc-dev/lib/python3.10/site-packages/slepc4py/lib/SLEPc.cpython-310-x86_64-linux-gnu.so
Type:           type
Subclasses:     


In [51]:
#%%timeit
E = SLEPc.EPS(); E.create()

E.setOperators(B)
E.setProblemType(SLEPc.EPS.ProblemType.HEP)
#E.setFromOptions()
E.setDimensions(15 // 2)
E.setWhichEigenpairs(SLEPc.EPS.Which.LARGEST_MAGNITUDE)
#E.setType(SLEPc.EPS.Type.LANCZOS)

%time E.solve()

CPU times: user 11 ms, sys: 0 ns, total: 11 ms
Wall time: 10.3 ms


In [52]:
# Solver configuration and statistics, kept for interactive inspection.
its = E.getIterationNumber()       # iterations of the method
eps_type = E.getType()             # solution method
nev, ncv, mpd = E.getDimensions()  # requested eigenvalues and subspace dimensions
tol, maxit = E.getTolerances()     # stopping condition
nconv = E.getConverged()           # number of converged eigenpairs

if nconv > 0:
    # Work vectors that receive the real and imaginary parts of each eigenvector.
    vr, _ = B.getVecs()
    vi, _ = B.getVecs()
    for i in range(nconv):
        # Named eig_val rather than `k`, so the global eigenpair count k used by the
        # eigsh cells is not overwritten by an eigenvalue.
        eig_val = E.getEigenpair(i, vr, vi)
        error = E.computeError(i)
        if eig_val.imag != 0.0:
            print(" %9f%+9f j %12g" % (eig_val.real, eig_val.imag, error))
        else:
            print(" %12f      %12g" % (eig_val.real, error))

     2.635141       3.52461e-16
     2.172873       6.93207e-16
     1.536547       2.92576e-13
     1.315439       1.17337e-09
     1.234448       1.24338e-15
     1.195038       4.28531e-15
     1.157507        1.6329e-13
     1.132428       2.66745e-13


1.1324283825817105

In [47]:
# Eigenvalues exp(s) of the final factored X, for comparison with the SLEPc values.
np.exp(s)

array([0.51267073, 0.58954816, 0.65282969, 0.71208998, 0.73835182,
       0.78755364, 0.79940022, 1.13242838, 1.1575072 , 1.19503755,
       1.23444835, 1.31543928, 1.53654668, 2.17287252, 2.6351411 ])