This is the notebook in which we prototyped the algorithm.  It is left in our repository for completeness, but does not correspond to any of the experiments we use in the paper.

In [1]:
import numpy as np
from scipy import linalg

# Riemannian Building Blocks

In [2]:
# Formulas here came from:
# "Conic geometric optimisation on the manifold of positive definite matrices"
# By Sra & Hosseini

def pd_gradient(
    grad: np.ndarray,
    X: np.ndarray
) -> np.ndarray:
    """
    Transform Euclidian gradient `grad` at point `X`
    into the Riemannian gradient in the relevant geometry.

    This assumes the gradient is symmetric, which it is in our case.

    Otherwise, need to change `grad` to `(grad + grad.T) / 2`.
    """

    return X @ grad @ X

def pd_retraction(
    X: np.ndarray,
    grad: np.ndarray,
) -> np.ndarray:
    """
    Perform a retraction step on the manifold.
    """

    return X @ linalg.expm(linalg.inv(X) @ grad)

def pd_rectraction_definition(
    X: np.ndarray,
    grad: np.ndarray,
) -> np.ndarray:
    """
    Perform a retraction step on the manifold, using the definition.
    """
    X_sqrt = linalg.sqrtm(X)
    X_sqrt_inv = linalg.inv(X_sqrt)
    return X_sqrt @ linalg.expm(X_sqrt_inv @ grad @ X_sqrt_inv) @ X_sqrt

def pd_grad_retraction(
    X: np.ndarray,
    grad: np.ndarray,
) -> np.ndarray:
    """
    Transform gradient and then retract in one go

    Assumes gradient is symmetric, otherwise must symmetrize
    """

    return X @ linalg.expm(grad @ X)

Below, we check if everything works as intended.

In [3]:
# Generate random positive definite matrix
n = 5
X = np.random.rand(n, n)
S = X @ X.T

# Generate random symmetric matrix
grad = np.random.rand(n, n)
grad = (grad + grad.T) / 2

# Transform gradient
grad_pd = pd_gradient(grad, S)

# Perform retraction
S_new = pd_retraction(S, grad_pd)
S_new_def = pd_rectraction_definition(S, grad_pd)
S_new_grad = pd_grad_retraction(S, grad)

assert np.allclose(S_new, S_new_def), \
    "Uh-oh, retraction methods don't match!"

assert np.allclose(S_new, S_new_grad), \
    "Uh-oh, retraction methods don't match!"

# Simultaneous Diagonalization

In [4]:
def sim_diag(
    X: np.ndarray,
    Y: np.ndarray
) -> tuple[
    np.ndarray,
    np.ndarray
]:
    """
    Simultaneously diagonalize two positive definite matrices by congruence
    (Technically only X must be positive definite, the other must be symmetric)

    Returns P, D such that X = PP^T and Y = PDP^T
    """

    X_sqrt = linalg.sqrtm(X)
    X_sqrt_inv = linalg.inv(X_sqrt)
    S = X_sqrt_inv @ Y @ X_sqrt_inv
    D, V = linalg.eigh(S)
    P = X_sqrt @ V

    return P, D

Test if this works as intended.

In [5]:
# Generate two random positive definite matrices
n = 5
X = np.random.rand(n, 3 * n)
Y = np.random.rand(n, 3 * n)
X = X @ X.T
Y = Y @ Y.T

# Simultaneously diagonalize
P, D = sim_diag(X, Y)

# Check that the matrices are indeed diagonalized
assert np.allclose(P @ P.T, X), \
    "X not diagonalized"
assert np.allclose(P @ np.diag(D) @ P.T, Y), \
    "Y not diagonalized"

# Strong Product Model Building Blocks

In [6]:
def blockwise_trace_ks(
    Lam: np.ndarray,
    D: np.ndarray
) -> np.ndarray:
    """
    Computes tr_d2[(Lam kronsum D)^-1]

    Lam, D are diagonal matrices
    """

    internal = 1 / (Lam[:, None] + D[None, :])
    return internal.sum(axis=1)

def stridewise_trace_ks(
    Lam: np.ndarray,
    D: np.ndarray
) -> np.ndarray:
    """
    Computes tr^d1[(Lam kronsum D)^-1]

    Lam, D are diagonal matrices
    """

    internal = 1 / (Lam[:, None] + D[None, :])
    return internal.sum(axis=0)

def stridewise_trace_mult(
    Lam: np.ndarray,
    D: np.ndarray
) -> np.ndarray:
    """
    Computes tr^d1[(Lam kronsum D)^-1 * (Lam kronprod I)]

    Lam, D are diagonal matrices
    """

    internal = Lam[:, None] / (Lam[:, None] * D[None, :])
    return internal.sum(axis=0)

def vec_kron_sum(Xs: list) -> np.array:
    """Compute the Kronecker vector-sum"""
    if len(Xs) == 1:
        return Xs[0]
    elif len(Xs) == 2:
        return np.kron(Xs[0], np.ones(Xs[1].shape[0])) + np.kron(np.ones(Xs[0].shape[0]), Xs[1])
    else:
        d_slash0 = np.prod([X.shape[0] for X in Xs[1:]])
        return (
            np.kron(Xs[0], np.ones(d_slash0))
            + np.kron(np.ones(Xs[0].shape[0]), vec_kron_sum(Xs[1:]))
        )

In [7]:
def NLL(
    Psi_1: np.ndarray,
    Theta: np.ndarray,
    Psi_2w: np.ndarray,
    S_2: np.ndarray,
    Data: np.ndarray
) -> float:
    Lam, _ = linalg.eigh(Psi_1)
    _, D = sim_diag(Theta, Psi_2w)
    _, detTheta = np.linalg.slogdet(Theta)

    if Lam.min() <= 0 or D.min() <= 0:
        # Don't allow non-positive-definite matrices
        return np.inf

    logdets = - np.log(vec_kron_sum([Lam, D])).sum() - Psi_1.shape[0] * detTheta
    traces = np.trace(Psi_2w @ S_2) + np.trace(Psi_1 @ Data @ Theta @ Data.T)

    return logdets + traces


In [8]:
def armijo(
    new_value: float,
    old_value: float,
    eta: float,
    beta: float,
    grad_norm: tuple[np.ndarray]
) -> float:
    """
    Armijo line search
    """
    return new_value <= old_value - eta * beta * grad_norm

In [9]:
def gradients(
    X: np.ndarray,
    S_2: np.ndarray,
    Psi_1: np.ndarray,
    V: np.ndarray,
    Lam: np.ndarray,
    Theta: np.ndarray,
    P: np.ndarray,
    Psi_2w: np.ndarray,
    D: np.ndarray,
    rho_psi_1: float,
    rho_psi_2w: float,
    rho_theta: float,
) -> tuple[
    np.ndarray,
    np.ndarray,
    np.ndarray,
]:
    """
    Computes G * Gamma, where G is the Euclidean gradient at Gamma,
    and Gamma is each of our three parameters.
    """

    P_inv = linalg.inv(P)

    Psi_1_core = blockwise_trace_ks(Lam, D)
    Psi_2w_core = stridewise_trace_ks(Lam, D)
    Theta_core = stridewise_trace_mult(Lam, D)

    XTheta = X @ Theta
    XtPsi = X.T @ Psi_1

    log_Psi_1 = (V * np.log(Lam)) @ V.T
    log_Psi_2w = linalg.logm(Psi_2w)
    log_Theta = linalg.logm(Theta)

    Psi_1_grad = - (V * Psi_1_core) @ V.T + XTheta @ X.T + 2 * rho_psi_1 * np.linalg.inv(Psi_1) @ log_Psi_1
    #assert (np.allclose(- (V * Psi_1_core) @ V.T, (- (V * Psi_1_core) @ V.T).T))
    #assert (np.allclose(XTheta @ X.T , (XTheta @ X.T).T))
    #assert (np.allclose(2 * rho_psi_1 * np.linalg.inv(Psi_1) @ log_Psi_1 , (2 * rho_psi_1 * np.linalg.inv(Psi_1) @ log_Psi_1).T))
    Psi_2w_grad = - (P_inv.T * Psi_2w_core) @ P_inv + S_2 + 2 * rho_psi_2w * np.linalg.inv(Psi_2w) @ log_Psi_2w
    Theta_grad = - (P_inv.T * Theta_core) @ P_inv + XtPsi @ X + 2 * rho_theta * np.linalg.inv(Theta) @ log_Theta

    Psi_1_grad = (Psi_1_grad + Psi_1_grad.T) / 2
    Psi_2w_grad = (Psi_2w_grad + Psi_2w_grad.T) / 2
    Theta_grad = (Theta_grad + Theta_grad.T) / 2

    return Psi_1_grad, Theta_grad, Psi_2w_grad


def gradients_shifted(
    X: np.ndarray,
    S_2: np.ndarray,
    Psi_1: np.ndarray,
    V: np.ndarray,
    Lam: np.ndarray,
    Theta: np.ndarray,
    P: np.ndarray,
    Psi_2w: np.ndarray,
    D: np.ndarray,
    rho_psi_1: float,
    rho_psi_2w: float,
    rho_theta: float,
) -> tuple[
    np.ndarray,
    np.ndarray,
    np.ndarray,
]:
    """
    Computes G * Gamma, where G is the Euclidean gradient at Gamma,
    and Gamma is each of our three parameters.
    """

    P_inv = linalg.inv(P)

    Psi_1_core = blockwise_trace_ks(Lam, D) * Lam
    Psi_2w_core = stridewise_trace_ks(Lam, D) * D
    Theta_core = stridewise_trace_mult(Lam, D)

    XTheta = X @ Theta
    XtPsi = X.T @ Psi_1

    log_Psi_1 = (V * np.log(Lam)) @ V.T
    log_Psi_2w = linalg.logm(Psi_2w)
    log_Theta = linalg.logm(Theta)

    Psi_1_grad = - (V * Psi_1_core) @ V.T + XTheta @ XtPsi + 2 * rho_psi_1 * log_Psi_1
    Psi_2w_grad = - (P_inv.T * Psi_2w_core) @ P.T + S_2 @ Psi_2w + 2 * rho_psi_2w * log_Psi_2w
    Theta_grad = - (P_inv.T * Theta_core) @ P.T + XtPsi @ XTheta + 2 * rho_theta * log_Theta

    return Psi_1_grad, Theta_grad, Psi_2w_grad


In [10]:
# Generate three random positive definite matrices
n = 5
samples = 10
X = np.random.rand(n, samples * n)
Y = np.random.rand(n + 1, samples * n)
Z = np.random.rand(n + 1, samples * n)
X = X @ X.T / (n * samples)
Y = Y @ Y.T / (n * samples)
Z = Z @ Z.T / (n * samples)

# Eigendecompose the first matrix
Lam, V = linalg.eigh(X)

# Simultaneously diagonalize the other two
P, D = sim_diag(Y, Z)

# Generate random data matrix
Data = np.random.rand(n, n+1)
S_2 = Data.T @ Data

# Check that the gradients_shifted function runs
A, B, C = gradients_shifted(
    Data,
    S_2,
    X,
    V,
    Lam,
    Y,
    P,
    Z,
    D,
    0.1,
    0.1,
    0.1
)
print(A.shape, B.shape, C.shape)
A

(5, 5) (6, 6) (6, 6)


array([[1.91995063, 3.56222315, 3.36246415, 3.46802192, 3.32214367],
       [4.26044885, 4.07600713, 4.72198365, 4.92783979, 4.82463725],
       [3.08362713, 3.61089228, 2.347799  , 3.5378608 , 3.39123785],
       [3.56097982, 4.22641289, 3.96495407, 3.12076385, 3.99606477],
       [4.00400513, 4.85399415, 4.46478607, 4.68464317, 3.5382007 ]])

In [11]:
lr = 0.01
print(X @ linalg.expm(lr * A))
print(Y @ linalg.expm(lr * B))
print(Z @ linalg.expm(lr * C))

[[0.35409367 0.30435395 0.27052285 0.29553409 0.2962426 ]
 [0.30435395 0.44671222 0.34719354 0.36117632 0.33003236]
 [0.27052285 0.34719354 0.410154   0.32956301 0.32977825]
 [0.29553409 0.36117632 0.32956301 0.4295315  0.33540936]
 [0.2962426  0.33003236 0.32977825 0.33540936 0.40088303]]
[[0.48460813 0.36847968 0.40350571 0.37347465 0.35519743 0.37829382]
 [0.36847968 0.4122708  0.3685833  0.3263359  0.3329526  0.34391307]
 [0.40350571 0.3685833  0.47696326 0.37101104 0.38411322 0.35740317]
 [0.37347465 0.3263359  0.37101104 0.39794199 0.35675886 0.32969898]
 [0.35519743 0.3329526  0.38411322 0.35675886 0.40907015 0.33658295]
 [0.37829382 0.34391307 0.35740317 0.32969898 0.33658295 0.42906397]]
[[0.34980305 0.28150628 0.2717028  0.3126787  0.26431591 0.22925258]
 [0.28150628 0.37233941 0.30506463 0.33195329 0.2623392  0.27647295]
 [0.2717028  0.30506463 0.38312323 0.33973025 0.28020184 0.28565458]
 [0.3126787  0.33195329 0.33973025 0.44165262 0.32379263 0.31789422]
 [0.26431591 0.262

In [12]:
# Generate three random positive definite matrices
n = 100
samples = n + 2
# X = np.random.rand(n, samples * n)
# Y = np.random.rand(n + 1, samples * n)
# Z = np.random.rand(n + 1, samples * n)
# X = X @ X.T# / (n * samples)
# Y = Y @ Y.T# / (n * samples)
# Z = Z @ Z.T# / (n * samples)
X = np.eye(n)
Y = np.eye(samples)
Z = 2 * np.eye(samples)

# Generate random data matrix
Data = np.random.rand(n, samples)
S_2 = Data.T @ Data

for i in range(100):
    # Eigendecompose the first matrix
    Lam, V = linalg.eigh(X)

    # Simultaneously diagonalize the other two
    P, D = sim_diag(Y, Z)

    # Check that the gradients_shifted function runs
    A, B, C = gradients_shifted(
        Data,
        S_2,
        X,
        V,
        Lam,
        Y,
        P,
        Z,
        D,
        0.00001,
        0.00001,
        0.00001
    )
    # A, B, C = gradients(
    #     Data,
    #     S_2,
    #     X,
    #     V,
    #     Lam,
    #     Y,
    #     P,
    #     Z,
    #     D,
    #     1,
    #     1,
    #     1
    # )
    # A @= linalg.inv(X)
    # B @= linalg.inv(Y)
    # C @= linalg.inv(Z)

    # print('----')
    # test, _ = np.linalg.eigh(X)
    # print(test)
    # test, _ = np.linalg.eigh(Y)
    # print(test)
    # test, _ = np.linalg.eigh(Z)
    # print(test)
    # print('----')

    # print(A @ np.linalg.inv(X))
    # print(A_)


    old_NLL = NLL(X, Y, Z, S_2, Data)
    line_search_init = 1
    decrease = 0.5
    beta = 0.5
    lr = line_search_init

    old_X = X
    old_Y = Y
    old_Z = Z

    X = old_X @ linalg.expm(-lr * A)
    Y = old_Y @ linalg.expm(-lr * B)
    Z = old_Z @ linalg.expm(-lr * C)

    grad_norm = (
        np.trace(np.linalg.matrix_power(A @ np.linalg.inv(old_X), 2))
        + np.trace(np.linalg.matrix_power(B @ np.linalg.inv(old_Y), 2))
        + np.trace(np.linalg.matrix_power(C @ np.linalg.inv(old_Z), 2))
    )

    old_value = NLL(old_X, old_Y, old_Z, S_2, Data)

    try:
        new_value = NLL(X, Y, Z, S_2, Data)
    except:
        new_value = np.inf

    converged = False
    while True:
        if new_value < old_value - lr * beta * grad_norm:
            # Check if all matrices are positive definite
            X_posdef = (np.linalg.eigh(X)[0] > 0).all()
            Y_posdef = (np.linalg.eigh(Y)[0] > 0).all()
            Z_posdef = (np.linalg.eigh(Z)[0] > 1).all()
            if X_posdef and Y_posdef and Z_posdef:
                break
        lr *= decrease
        X = old_X @ linalg.expm(-lr * A)
        Y = old_Y @ linalg.expm(-lr * B)
        Z = old_Z @ linalg.expm(-lr * C)
        try:
            new_value = NLL(X, Y, Z, S_2, Data)
        except:
            new_value = np.inf
        if lr < 1e-10:
            X = old_X
            Y = old_Y
            Z = old_Z
            converged = True
            break
    
    if converged:
        print(old_value)
        print("Converged")
        break
    else:
        #print(lr)
        print(new_value)

    #print(X)
    #print(Y)
    # print(Z)

    # Always stays in the positive definite cone
    # print((np.linalg.eigh(X)[0] > 0).all())
    # print((np.linalg.eigh(Y)[0] > 0).all())
    # print((np.linalg.eigh(Z)[0] > 0).all())

-4543.588094440915
-4715.743394162065
-4795.956639188187
-4815.388394380641
-4816.594078067642
-4817.196608608265
-4817.271907557166
-4817.290731730178
-4817.300143740756
-4817.302496734592
-4817.302496734592
Converged


In [13]:
print(X, Y, Z)
print(np.linalg.cond(X), np.linalg.cond(Y), np.linalg.cond(Z))

[[ 1.00117788 -0.00248379 -0.0028727  ... -0.00280856 -0.00278121
  -0.00254958]
 [-0.00248379  1.00101376 -0.00270991 ... -0.00274895 -0.00269723
  -0.00291966]
 [-0.0028727  -0.00270991  1.00090429 ... -0.00290925 -0.00282982
  -0.00275604]
 ...
 [-0.00280856 -0.00274895 -0.00290925 ...  1.00093895 -0.00292706
  -0.00289389]
 [-0.00278121 -0.00269723 -0.00282982 ... -0.00292706  1.00074864
  -0.00274208]
 [-0.00254958 -0.00291966 -0.00275604 ... -0.00289389 -0.00274208
   1.00107035]] [[ 1.00350699 -0.00283769 -0.00298779 ... -0.00320679 -0.00312044
  -0.00273313]
 [-0.00283769  1.00305106 -0.00294913 ... -0.00319252 -0.00283918
  -0.00272951]
 [-0.00298779 -0.00294913  1.00269614 ... -0.00371704 -0.00317598
  -0.00298648]
 ...
 [-0.00320679 -0.00319252 -0.00371704 ...  1.00250227 -0.00330761
  -0.00304322]
 [-0.00312044 -0.00283918 -0.00317598 ... -0.00330761  1.00342546
  -0.00278505]
 [-0.00273313 -0.00272951 -0.00298648 ... -0.00304322 -0.00278505
   1.00361626]] [[ 2.00583282 -0

# Final test

In [14]:
# Check if our `strong-product-model.py` implementation is correct
%load_ext autoreload
%autoreload 2
from strong_product_model import strong_product_model

In [15]:
#n = 500
#samples = n + 2
#Data = np.random.rand(n, samples)

# Using same data from test implementation for comparison

strong_product_model(
    data_matrix=Data,
    rho_rows=0.0001,
    rho_cols_within_rows=0.0001,
    rho_cols_between_rows=0.0001,
    verbose=True,
)

Iteration 1: -3987.9682313719522
Iteration 2: -4656.532210287609
Iteration 3: -4788.189373038928
Iteration 4: -4819.460535525576
Iteration 5: -4823.323119813105
Iteration 6: -4823.564181939269
Iteration 7: -4823.5943119249605
Iteration 8: -4823.609376723326
Iteration 9: -4823.616909073877
Iteration 10: -4823.617379844452
Iteration 11: -4823.617615229698
Iteration 12: -4823.617644652852
Iteration 13: -4823.617659364429
Iteration 14: -4823.617666720207
Iteration 15: -4823.617670398108
Iteration 16: -4823.617672237058
Iteration 17: -4823.617673156536
Iteration 18: -4823.617673271465
Iteration 19: -4823.6176733289385
Iteration 20: -4823.617673357666
Iteration 21: -4823.617673372028
Iteration 22: -4823.6176733729335
Iteration 23: -4823.617673373041
Iteration 24: -4823.617673373088 (converged)


{'rows': array([[ 1.00155615, -0.0024746 , -0.00289182, ..., -0.00282123,
         -0.00279646, -0.00254694],
        [-0.0024746 ,  1.00139461, -0.00270154, ..., -0.00274495,
         -0.00269344, -0.00294576],
        [-0.00289182, -0.00270154,  1.00129928, ..., -0.00290883,
         -0.00282696, -0.00275234],
        ...,
        [-0.00282123, -0.00274495, -0.00290883, ...,  1.00133718,
         -0.00293454, -0.00290484],
        [-0.00279646, -0.00269344, -0.00282696, ..., -0.00293454,
          1.00111537, -0.00274269],
        [-0.00254694, -0.00294576, -0.00275234, ..., -0.00290484,
         -0.00274269,  1.00145735]]),
 'cols_within_rows': array([[ 1.00415054, -0.00284068, -0.0029784 , ..., -0.00320875,
         -0.00315195, -0.00273758],
        [-0.00284068,  1.00365597, -0.00293123, ..., -0.0031882 ,
         -0.00283743, -0.0027296 ],
        [-0.0029784 , -0.00293123,  1.00332352, ..., -0.00373466,
         -0.00318064, -0.00298594],
        ...,
        [-0.00320875, -0.0