In [805]:
import numpy as np
import scipy.stats as stats

In [806]:
def vec_kron_sum(Xs: list) -> np.array:
    """Compute the Kronecker vector-sum"""
    if len(Xs) == 1:
        return Xs[0]
    elif len(Xs) == 2:
        return np.kron(Xs[0], np.ones(Xs[1].shape[0])) + np.kron(np.ones(Xs[0].shape[0]), Xs[1])
    else:
        d_slash0 = np.prod([X.shape[0] for X in Xs[1:]])
        return (
            np.kron(Xs[0], np.ones(d_slash0))
            + np.kron(np.ones(Xs[0].shape[0]), vec_kron_sum(Xs[1:]))
        )
    
def kron_sum(Xs: list) -> np.array:
    """Compute the Kronecker sum"""
    if len(Xs) == 1:
        return Xs[0]
    elif len(Xs) == 2:
        return np.kron(Xs[0], np.eye(Xs[1].shape[0])) + np.kron(np.eye(Xs[0].shape[0]), Xs[1])
    else:
        d_slash0 = np.prod([X.shape[0] for X in Xs[1:]])
        return (
            np.kron(Xs[0], np.eye(d_slash0))
            + np.kron(np.eye(Xs[0].shape[0]), kron_sum(Xs[1:]))
        )

In [807]:
# Test dataset
ds = [4, 3, 2]
data = np.arange(np.prod(ds)).reshape(ds).astype(float)
data -= data.mean() + 10
#data = np.random.normal(size=ds)
#data -= 10
#data += 5 * np.arange(ds[-1])

# Derived parameters
L = len(ds)
d_slashes = [np.prod(ds[:ell] + ds[ell + 1:]) for ell in range(L)]

# Fixed parameters
#Psis = [0 * np.ones((ds[ell], ds[ell])) / ds[ell] + np.eye(ds[ell]) for ell in range(L)]
Psis = [stats.wishart.rvs(df=ds[ell], scale=np.ones((ds[ell], ds[ell])) / 1.1 + np.eye(ds[ell])) for ell in range(L)]

# Parameters to estimate, give it good initial guesses
means_init = [(data - data.mean()).mean(axis=tuple([j for j in range(L) if j != ell])) for ell in range(L)]
full_mean_init = data.mean()
means = [_.copy() for _ in means_init]
full_mean = full_mean_init
data

array([[[-21.5, -20.5],
        [-19.5, -18.5],
        [-17.5, -16.5]],

       [[-15.5, -14.5],
        [-13.5, -12.5],
        [-11.5, -10.5]],

       [[ -9.5,  -8.5],
        [ -7.5,  -6.5],
        [ -5.5,  -4.5]],

       [[ -3.5,  -2.5],
        [ -1.5,  -0.5],
        [  0.5,   1.5]]])

In [808]:
# NEW VERSION BASED ON QUADRATIC PROGRAMMING

# Derived parameters for our mean problem
lsum_Psis = [Psis[ell].sum(axis=1) for ell in range(L)]
sum_Psis = [lsum_Psis[ell].sum() for ell in range(L)]
sum_Psis_slashes = [sum([d_slashes[ell] / ds[ell_prime] * sum_Psis[ell] for ell in range(L) if ell != ell_prime]) for ell_prime in range(L)]

# The matrix that needs to be inverted
A = [(d_slashes[ell] * Psis[ell] + sum_Psis_slashes[ell] * np.eye(ds[ell])) for ell in range(L)]
A_inv = [np.linalg.pinv(A[ell]) for ell in range(L)]

# The data contribution
def datatrans(ell, data, Psis):
    # Sum along all axes but ell
    base = data.sum(axis=tuple([i for i in range(L) if i != ell]))
    base = Psis[ell] @ base

    for ell_prime in range(len(Psis)):
        if ell_prime == ell:
            continue
        # Sum along all axes but ell and ell_prime
        to_add = data.sum(axis=tuple([i for i in range(L) if i != ell and i != ell_prime]))
        
        # Multiply by Psi_{ell_prime} and then sum along ell_prime
        if ell_prime < ell:
            to_add = (lsum_Psis[ell_prime] @ to_add)
        else:
            to_add = (lsum_Psis[ell_prime] @ to_add.T)

        base += to_add

    return base

b_bases = [
    datatrans(ell, data, Psis)
    for ell in range(L)
]
max_cycles = 15
print("Start", means, full_mean)
for cycle in range(max_cycles):
    for ell in range(L):
        # Preliminary calculations
        #mean_sum = vec_kron_sum([means[ell_prime] for ell_prime in range(L) if ell != ell_prime]).sum() # Should always be zero
        mean_lsum = (
            vec_kron_sum([means[ell_prime] for ell_prime in range(L) if ell != ell_prime])
            @ vec_kron_sum([lsum_Psis[ell_prime] for ell_prime in range(L) if ell != ell_prime])
        )
        # Shouild b be negative?  Everything seems to work better if so - prove it
        b = -(
            d_slashes[ell] * full_mean * lsum_Psis[ell]
            + full_mean * sum_Psis[ell]
            #+ lsum_Psis[ell] * mean_sum # mean_sum should always be zero zero
            + mean_lsum
            - b_bases[ell]
        )
        print(b)
        A_inv_b = A_inv[ell] @ b
        means[ell] = A_inv_b - (A_inv_b.sum() / A_inv[ell].sum()) * A_inv[ell].sum(axis=0)
        
    full_mean = (data.reshape(-1) - vec_kron_sum(means)) @ vec_kron_sum(lsum_Psis) / sum(d_slashes[ell] * sum_Psis[ell] for ell in range(L))
    print(means, full_mean)

Start [array([-9., -3.,  3.,  9.]), array([-2.,  0.,  2.]), array([-0.5,  0.5])] -10.0
[-139.23291432  645.86193346 1079.67931572 1462.62043438]
[-3428.67797197 -2787.27861353 -2171.76834398]
[-5141.28880415 -4638.9867035 ]
[array([-9., -3.,  3.,  9.]), array([-2.00000000e+00,  7.10542736e-15,  2.00000000e+00]), array([-0.5,  0.5])] -10.000000000000002
[-139.23291432  645.86193346 1079.67931572 1462.62043438]
[-3428.67797197 -2787.27861353 -2171.76834398]
[-5141.28880415 -4638.9867035 ]
[array([-9., -3.,  3.,  9.]), array([-2.00000000e+00,  7.10542736e-15,  2.00000000e+00]), array([-0.5,  0.5])] -10.000000000000005
[-139.23291432  645.86193346 1079.67931572 1462.62043438]
[-3428.67797197 -2787.27861353 -2171.76834398]
[-5141.28880415 -4638.9867035 ]
[array([-9., -3.,  3.,  9.]), array([-2.00000000e+00,  5.32907052e-15,  2.00000000e+00]), array([-0.5,  0.5])] -10.0
[-139.23291432  645.86193346 1079.67931572 1462.62043438]
[-3428.67797197 -2787.27861353 -2171.76834398]
[-5141.28880415 -4

In [809]:
data.mean(), full_mean, vec_kron_sum([m for m in means]).mean()

(-10.0, -10.0, 4.440892098500626e-16)

In [810]:
data.mean() / full_mean

1.0

In [811]:
data - full_mean

array([[[-11.5, -10.5],
        [ -9.5,  -8.5],
        [ -7.5,  -6.5]],

       [[ -5.5,  -4.5],
        [ -3.5,  -2.5],
        [ -1.5,  -0.5]],

       [[  0.5,   1.5],
        [  2.5,   3.5],
        [  4.5,   5.5]],

       [[  6.5,   7.5],
        [  8.5,   9.5],
        [ 10.5,  11.5]]])

In [812]:
data - full_mean - vec_kron_sum([m for m in means]).reshape(ds)

array([[[ 1.77635684e-14,  1.77635684e-14],
        [ 1.06581410e-14,  1.06581410e-14],
        [ 1.50990331e-14,  1.50990331e-14]],

       [[-7.99360578e-15, -7.99360578e-15],
        [-1.42108547e-14, -1.42108547e-14],
        [-9.76996262e-15, -9.76996262e-15]],

       [[ 2.66453526e-15,  2.66453526e-15],
        [-3.55271368e-15, -3.55271368e-15],
        [ 8.88178420e-16,  8.88178420e-16]],

       [[-2.66453526e-15, -2.66453526e-15],
        [-8.88178420e-15, -8.88178420e-15],
        [-3.55271368e-15, -3.55271368e-15]]])

In [813]:
(vec_kron_sum([m for m in means])).reshape(ds)

array([[[-11.5, -10.5],
        [ -9.5,  -8.5],
        [ -7.5,  -6.5]],

       [[ -5.5,  -4.5],
        [ -3.5,  -2.5],
        [ -1.5,  -0.5]],

       [[  0.5,   1.5],
        [  2.5,   3.5],
        [  4.5,   5.5]],

       [[  6.5,   7.5],
        [  8.5,   9.5],
        [ 10.5,  11.5]]])

In [814]:
_mean = data.mean(axis=tuple(x for x in range(L) if x != 1)) - full_mean
_mean - _mean.mean(), means[1], (_mean - _mean.mean()) / means[1]

(array([-2.,  0.,  2.]),
 array([-2.00000000e+00,  5.32907052e-15,  2.00000000e+00]),
 array([1., 0., 1.]))

In [815]:
_mean = data.mean(axis=tuple(x for x in range(L) if x != 0)) - full_mean
_mean - _mean.mean(), means[0], (_mean - _mean.mean()) / means[0]

(array([-9., -3.,  3.,  9.]),
 array([-9., -3.,  3.,  9.]),
 array([1., 1., 1., 1.]))

Todo: these ratios (for KS-decomp independent data) always seem to depend only on the axis sizes!  They do not depend on the variance.

Note that for KS-decomposable data, the estimated mean should always be the data itself, since it is possible to set the difference from the mean to be zero!

Hence I have likely got the math wrong somewhere, but it's close...

Removing d_slashes from the calculation of sum_slashes makes the equation work (well, ratio is -1, i.e. there is another problem)

In [816]:
left = data.reshape(-1) - vec_kron_sum(means) - full_mean
center = kron_sum(Psis)
left_init = data.reshape(-1) - vec_kron_sum(means_init) - full_mean_init
left @ center @ left.T, left_init @ center @ left_init.T

(4.275938550164428e-26, 0.0)

In [817]:
term1 = (1/2) * vec_kron_sum(means) @ center @ vec_kron_sum(means).T
term2 = (full_mean - data.reshape(-1)) @ kron_sum(Psis) @ vec_kron_sum(means).T
term1_init = (1/2) * vec_kron_sum(means_init) @ center @ vec_kron_sum(means_init).T
term2_init = (full_mean_init - data.reshape(-1)) @ kron_sum(Psis) @ vec_kron_sum(means_init).T
term1 + term2, term1_init + term2_init

(-9241.551295646817, -9241.551295646817)

In [818]:
X = np.kron(np.eye(ds[0]), np.ones(d_slashes[0]).reshape(-1, 1))

In [819]:
term1 = (1/2) * means[0].T @ A[0] @ means[0]
b_ = (full_mean + vec_kron_sum([np.ones(ds[0]), means[1]]) - data.reshape(-1)) @ kron_sum(Psis) @ X
term1_init = (1/2) * means_init[0].T @ A[0] @ means_init[0]
b_init = (full_mean_init + vec_kron_sum([np.ones(ds[0]), means_init[1]]) - data.reshape(-1)) @ kron_sum(Psis) @ X
term1 + b_ @ means[0], term1_init + b_init @ means_init[0], b_, b_init

ValueError: operands could not be broadcast together with shapes (12,) (24,) 

In [None]:
b2 = d_slashes[0] * full_mean * Psis[0].sum(axis=0) + full_mean * Psis[1].sum() - datatrans(0, data, Psis)
b3 = (full_mean + np.kron(np.ones(ds[0]), means[1]) - data.reshape(-1)) @ kron_sum(Psis) @ X
b4 = (
    (full_mean + np.kron(np.ones(ds[0]), means[1]) - data.reshape(-1))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1)) + np.kron(np.eye(ds[0]), Psis[1].sum(axis=1))).T
)
b5 = (
    d_slashes[0] * full_mean * Psis[0].sum(axis=0)
    + d_slashes[1] / ds[0] * full_mean * Psis[1].sum()
    - data.reshape(-1) @ kron_sum(Psis) @ X
    + Psis[0].sum(axis=0) * vec_kron_sum(means[1:]).sum() # 0
    + Psis[1].sum(axis=0) @ means[1] # 0
)
b2, b3, b4, b5

(array([ -4.36903294,  -3.36042458, -16.7221565 ]),
 array([ -3.32813157,  -2.31952321, -15.68125513]),
 array([ -3.32813157,  -2.31952321, -15.68125513]),
 array([ -3.32813157,  -2.31952321, -15.68125513]))

In [None]:
holy = (
    (full_mean + vec_kron_sum([np.ones(ds[0]), means[1]]) - data.reshape(-1))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1)) + np.kron(np.eye(ds[0]), Psis[1].sum(axis=1))).T
)
carp = (
    (full_mean * np.ones(ds[0] * ds[1]))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1)) + np.kron(np.eye(ds[0]), Psis[1].sum(axis=1))).T
    + (vec_kron_sum([np.ones(ds[0]), means[1]]))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1)) + np.kron(np.eye(ds[0]), Psis[1].sum(axis=1))).T
    - (data.reshape(-1))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1)) + np.kron(np.eye(ds[0]), Psis[1].sum(axis=1))).T
)
gawd = (
    (full_mean * np.ones(ds[0] * ds[1]))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1)) + np.kron(np.eye(ds[0]), Psis[1].sum(axis=1))).T
)
dammit = (
    (full_mean * np.ones(ds[0] * ds[1]))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1))).T
    + (full_mean * np.ones(ds[0] * ds[1]))
    @ (np.kron(np.eye(ds[0]), Psis[1].sum(axis=1))).T
)
fawk = (
    d_slashes[0] * full_mean * Psis[0].sum(axis=0)
    + d_slashes[1] / ds[0] * full_mean * Psis[1].sum()
)
aaaah = (
    - (data.reshape(-1))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1)) + np.kron(np.eye(ds[0]), Psis[1].sum(axis=1))).T
)
pain = (
    - data.reshape(-1) @ kron_sum(Psis) @ X
)
punishment = (
    (np.kron(np.ones(ds[0]), means[1]))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1)) + np.kron(np.eye(ds[0]), Psis[1].sum(axis=1))).T
)
trick = (
    + Psis[0].sum(axis=0) * vec_kron_sum(means[1:]).sum() # 0
    + Psis[1].sum(axis=0) @ means[1] # 0
)
holy, carp, gawd, dammit, fawk, aaaah, pain, punishment, trick

(array([22.64735165, 12.47145146, 20.95868305]),
 array([22.64735165, 12.47145146, 20.95868305]),
 array([-259.7548322 , -147.90974672, -366.39938172]),
 array([-259.7548322 , -147.90974672, -366.39938172]),
 array([-259.7548322 , -147.90974672, -366.39938172]),
 array([255.38579925, 144.54932214, 349.67722523]),
 array([255.38579925, 144.54932214, 349.67722523]),
 array([1.04090137, 1.04090137, 1.04090137]),
 array([1.04090137, 1.04090137, 1.04090137]))

In [None]:
eternal = (
    (vec_kron_sum([np.ones(ds[0]), means[1]]))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1)) + np.kron(np.eye(ds[0]), Psis[1].sum(axis=1))).T
)
damnation = (
    + Psis[0].sum(axis=0) * vec_kron_sum(means[1:]).sum() # 0
    + Psis[1].sum(axis=0) @ means[1] # 0
)
please = (
    (vec_kron_sum([np.ones(ds[0]), means[1]]))
    @ (np.kron(Psis[0], np.ones(ds[1]).reshape(1, -1))
)
eternal, damnation

SyntaxError: invalid syntax (2388457648.py, line 13)