In [1]:
import numpy as np
import scipy.stats as stats

In [2]:
def vec_kron_sum(Xs: list) -> np.array:
    """Compute the Kronecker vector-sum"""
    if len(Xs) == 1:
        return Xs[0]
    elif len(Xs) == 2:
        return np.kron(Xs[0], np.ones(Xs[1].shape[0])) + np.kron(np.ones(Xs[0].shape[0]), Xs[1])
    else:
        d_slash0 = np.prod([X.shape[0] for X in Xs[1:]])
        return (
            np.kron(Xs[0], np.ones(d_slash0))
            + np.kron(np.ones(Xs[0].shape[0]), vec_kron_sum(Xs[1:]))
        )
    
def kron_sum(Xs: list) -> np.array:
    """Compute the Kronecker sum"""
    if len(Xs) == 1:
        return Xs[0]
    elif len(Xs) == 2:
        return np.kron(Xs[0], np.eye(Xs[1].shape[0])) + np.kron(np.eye(Xs[0].shape[0]), Xs[1])
    else:
        d_slash0 = np.prod([X.shape[0] for X in Xs[1:]])
        return (
            np.kron(Xs[0], np.eye(d_slash0))
            + np.kron(np.eye(Xs[0].shape[0]), kron_sum(Xs[1:]))
        )

In [3]:
# Test dataset
ds = [4, 3, 2]
data = np.arange(np.prod(ds)).reshape(ds).astype(float)
data -= data.mean() + 10
#data = np.random.normal(size=ds)
#data -= 10
#data += 5 * np.arange(ds[-1])

# Derived parameters
L = len(ds)
d_slashes = [np.prod(ds[:ell] + ds[ell + 1:]) for ell in range(L)]

# Fixed parameters
#Psis = [0 * np.ones((ds[ell], ds[ell])) / ds[ell] + np.eye(ds[ell]) for ell in range(L)]
Psis = [stats.wishart.rvs(df=ds[ell], scale=np.ones((ds[ell], ds[ell])) / 1.1 + np.eye(ds[ell])) for ell in range(L)]

# Parameters to estimate, give it good initial guesses
means_init = [(data - data.mean()).mean(axis=tuple([j for j in range(L) if j != ell])) for ell in range(L)]
full_mean_init = data.mean()
means = [_.copy() for _ in means_init]
full_mean = full_mean_init
data

array([[[-21.5, -20.5],
        [-19.5, -18.5],
        [-17.5, -16.5]],

       [[-15.5, -14.5],
        [-13.5, -12.5],
        [-11.5, -10.5]],

       [[ -9.5,  -8.5],
        [ -7.5,  -6.5],
        [ -5.5,  -4.5]],

       [[ -3.5,  -2.5],
        [ -1.5,  -0.5],
        [  0.5,   1.5]]])

In [4]:
# NEW VERSION BASED ON QUADRATIC PROGRAMMING

# Derived parameters for our mean problem
lsum_Psis = [Psis[ell].sum(axis=1) for ell in range(L)]
sum_Psis = [lsum_Psis[ell].sum() for ell in range(L)]
sum_Psis_slashes = [sum([d_slashes[ell] / ds[ell_prime] * sum_Psis[ell] for ell in range(L) if ell != ell_prime]) for ell_prime in range(L)]

# The matrix that needs to be inverted
A = [(d_slashes[ell] * Psis[ell] + sum_Psis_slashes[ell] * np.eye(ds[ell])) for ell in range(L)]
A_inv = [np.linalg.pinv(A[ell]) for ell in range(L)]

# The data contribution
def datatrans(ell, data, Psis):
    # Sum along all axes but ell
    base = data.sum(axis=tuple([i for i in range(L) if i != ell]))
    base = Psis[ell] @ base

    for ell_prime in range(len(Psis)):
        if ell_prime == ell:
            continue
        # Sum along all axes but ell and ell_prime
        to_add = data.sum(axis=tuple([i for i in range(L) if i != ell and i != ell_prime]))
        
        # Multiply by Psi_{ell_prime} and then sum along ell_prime
        if ell_prime < ell:
            to_add = (lsum_Psis[ell_prime] @ to_add)
        else:
            to_add = (lsum_Psis[ell_prime] @ to_add.T)

        base += to_add

    return base

b_bases = [
    datatrans(ell, data, Psis)
    for ell in range(L)
]
max_cycles = 15
print("Start", means, full_mean)
for cycle in range(max_cycles):
    for ell in range(L):
        # Preliminary calculations
        #mean_sum = vec_kron_sum([means[ell_prime] for ell_prime in range(L) if ell != ell_prime]).sum() # Should always be zero
        mean_lsum = (
            vec_kron_sum([means[ell_prime] for ell_prime in range(L) if ell != ell_prime])
            @ vec_kron_sum([lsum_Psis[ell_prime] for ell_prime in range(L) if ell != ell_prime])
        )

        b = (
            d_slashes[ell] * full_mean * lsum_Psis[ell]
            + full_mean * sum_Psis[ell]
            #+ lsum_Psis[ell] * mean_sum # mean_sum should always be zero zero
            + mean_lsum
            - b_bases[ell]
        )
        A_inv_b = A_inv[ell] @ b
        means[ell] = (A_inv_b.sum() / A_inv[ell].sum()) * A_inv[ell].sum(axis=0) - A_inv_b
        
    full_mean = (data.reshape(-1) - vec_kron_sum(means)) @ vec_kron_sum(lsum_Psis) / sum(d_slashes[ell] * sum_Psis[ell] for ell in range(L))
    print(means, full_mean)

Start [array([-9., -3.,  3.,  9.]), array([-2.,  0.,  2.]), array([-0.5,  0.5])] -10.0
[array([-9., -3.,  3.,  9.]), array([-2.00000000e+00,  1.33226763e-15,  2.00000000e+00]), array([-0.5,  0.5])] -10.000000000000002
[array([-9., -3.,  3.,  9.]), array([-2.,  0.,  2.]), array([-0.5,  0.5])] -10.000000000000004
[array([-9., -3.,  3.,  9.]), array([-2.0000000e+00,  4.4408921e-16,  2.0000000e+00]), array([-0.5,  0.5])] -10.000000000000004
[array([-9., -3.,  3.,  9.]), array([-2.0000000e+00,  4.4408921e-16,  2.0000000e+00]), array([-0.5,  0.5])] -10.000000000000002
[array([-9., -3.,  3.,  9.]), array([-2.,  0.,  2.]), array([-0.5,  0.5])] -10.000000000000004
[array([-9., -3.,  3.,  9.]), array([-2.0000000e+00,  4.4408921e-16,  2.0000000e+00]), array([-0.5,  0.5])] -10.000000000000004
[array([-9., -3.,  3.,  9.]), array([-2.0000000e+00,  4.4408921e-16,  2.0000000e+00]), array([-0.5,  0.5])] -10.000000000000002
[array([-9., -3.,  3.,  9.]), array([-2.,  0.,  2.]), array([-0.5,  0.5])] -10.0

In [5]:
data.mean(), full_mean, vec_kron_sum([m for m in means]).mean()

(-10.0, -10.000000000000004, -1.4802973661668753e-16)

In [6]:
data.mean() / full_mean

0.9999999999999997

In [7]:
data - full_mean - vec_kron_sum([m for m in means]).reshape(ds)

array([[[-1.42108547e-14, -1.06581410e-14],
        [-1.42108547e-14, -1.06581410e-14],
        [-1.24344979e-14, -8.88178420e-15]],

       [[ 1.06581410e-14,  1.42108547e-14],
        [ 1.02140518e-14,  1.37667655e-14],
        [ 1.24344979e-14,  1.59872116e-14]],

       [[-3.55271368e-15,  0.00000000e+00],
        [-3.99680289e-15, -4.44089210e-16],
        [-1.77635684e-15,  1.77635684e-15]],

       [[ 1.24344979e-14,  1.59872116e-14],
        [ 1.24344979e-14,  1.59872116e-14],
        [ 1.42108547e-14,  1.77635684e-14]]])

In [8]:
_mean = data.mean(axis=tuple(x for x in range(L) if x != 1)) - full_mean
_mean - _mean.mean(), means[1], (_mean - _mean.mean()) / means[1]

(array([-2.,  0.,  2.]),
 array([-2.0000000e+00,  4.4408921e-16,  2.0000000e+00]),
 array([1., 0., 1.]))

In [9]:
_mean = data.mean(axis=tuple(x for x in range(L) if x != 0)) - full_mean
_mean - _mean.mean(), means[0], (_mean - _mean.mean()) / means[0]

(array([-9., -3.,  3.,  9.]),
 array([-9., -3.,  3.,  9.]),
 array([1., 1., 1., 1.]))