In [891]:
import numpy as np
import scipy.stats as stats

In [892]:
def vec_kron_sum(Xs: list) -> np.array:
    """Compute the Kronecker vector-sum"""
    if len(Xs) == 1:
        return Xs[0]
    elif len(Xs) == 2:
        return np.kron(Xs[0], np.ones(Xs[1].shape[0])) + np.kron(np.ones(Xs[0].shape[0]), Xs[1])
    else:
        d_slash0 = np.prod([X.shape[0] for X in Xs[1:]])
        return (
            np.kron(Xs[0], np.ones(d_slash0))
            + np.kron(np.ones(Xs[0].shape[0]), vec_kron_sum(Xs[1:]))
        )

In [893]:
# Test dataset
ds = [4, 3]
#data = np.arange(np.prod(ds)).reshape(ds).astype(float)
#data -= data.mean() + 10
data = np.random.normal(size=ds)
data -= 10
data += 5 * np.arange(ds[-1])

# Derived parameters
L = len(ds)
d_slashes = [np.prod(ds[:ell] + ds[ell + 1:]) for ell in range(L)]

# Fixed parameters
#Psis = [0 * np.ones((ds[ell], ds[ell])) / ds[ell] + np.eye(ds[ell]) for ell in range(L)]
Psis = [stats.wishart.rvs(df=ds[ell], scale=np.ones((ds[ell], ds[ell])) / 1.1 + np.eye(ds[ell])) for ell in range(L)]

# Parameters to estimate
means = [np.zeros((d,)) for d in ds]
full_mean = 0
data

array([[-10.38502422,  -3.73570157,  -0.13257653],
       [ -9.17681933,  -4.86155745,  -0.21677802],
       [ -9.3473275 ,  -5.08042632,   1.15247353],
       [ -9.17186065,  -3.47267159,  -0.09709937]])

In [894]:
# # Derived parameters for our mean problem
# lsum_Psis = [Psis[ell].sum(axis=1) for ell in range(L)]
# sum_Psis = [lsum_Psis[ell].sum() for ell in range(L)]
# sum_Psis_slashes = [sum([d_slashes[ell] * sum_Psis[ell] for ell in range(L) if ell != ell_prime]) for ell_prime in range(L)]

# # The matrix that needs to be inverted
# trans = [d_slashes[ell] * Psis[ell] + sum_Psis_slashes[ell] * np.eye(ds[ell]) for ell in range(L)]
# inv_trans = [np.linalg.pinv(trans[ell]) for ell in range(L)]

# # The data contribution
# def datatrans(ell, data, Psis):
#     # Sum along all axes but ell
#     base = data.sum(axis=tuple([i for i in range(L) if i != ell]))
#     base = Psis[ell] @ base

#     for ell_prime in range(len(Psis)):
#         if ell_prime == ell:
#             continue
#         # Sum along all axes but ell and ell_prime
#         to_add = data.sum(axis=tuple([i for i in range(L) if i != ell and i != ell_prime]))
        
#         # Multiply by Psi_{ell_prime} and then sum along ell_prime
#         if ell_prime < ell:
#             to_add = (Psis[ell_prime] @ to_add).sum(axis=0)
#         else:
#             to_add = (Psis[ell_prime] @ to_add.T).sum(axis=0)

#         base += to_add

#     return base

# # The off-axis mean contribution
# def get_off_axis_contribution(ell_prime, means, lsum_Psis):
#     # These two values should always be zero when the mean constraint is satisfied
#     #mean_sums = [means[ell].sum() for ell in range(L)]
#     #mean_slash_sum = sum([mean_sums[ell] for ell in range(L) if ell != ell_prime])

#     mean_mahalanobis = (
#         vec_kron_sum([means[ell] for ell in range(L) if ell != ell_prime])
#         @ vec_kron_sum([lsum_Psis[ell] for ell in range(L) if ell != ell_prime])
#     )
#     #return mean_slash_sum * lsum_Psis[ell_prime] + mean_mahalanobis
#     return mean_mahalanobis

# # Do a cycle of coordinate descent
# max_cycles = 20
# for cycle in range(max_cycles):
#     for ell in range(L):
#         # Means
#         data_contribution = datatrans(ell, data - full_mean, Psis)
#         off_axis_contribution = get_off_axis_contribution(ell, means, lsum_Psis)
#         base = data_contribution - (off_axis_contribution).sum(axis=0)
#         means[ell] = (base) @ inv_trans[ell]

#         # Handle constraint
#         full_mean += means[ell].mean() # Offload excess values to full_mean
#         means[ell] -= means[ell].mean() # Then set to zero as required by the constraint
#         print(means[ell])

#     full_mean = (data.reshape(-1) - vec_kron_sum(means)) @ vec_kron_sum(lsum_Psis) / sum(d_slashes[ell] * sum_Psis[ell] for ell in range(L))
#     print('----')

In [895]:
# NEW VERSION BASED ON QUADRATIC PROGRAMMING

# Derived parameters for our mean problem
lsum_Psis = [Psis[ell].sum(axis=1) for ell in range(L)]
sum_Psis = [lsum_Psis[ell].sum() for ell in range(L)]
sum_Psis_slashes = [sum([d_slashes[ell] * sum_Psis[ell] for ell in range(L) if ell != ell_prime]) for ell_prime in range(L)]

# The matrix that needs to be inverted
trans = [d_slashes[ell] * Psis[ell] + sum_Psis_slashes[ell] * np.eye(ds[ell]) for ell in range(L)]
inv_trans = [np.linalg.pinv(trans[ell]) for ell in range(L)]

# The data contribution
def datatrans(ell, data, Psis):
    # Sum along all axes but ell
    base = data.sum(axis=tuple([i for i in range(L) if i != ell]))
    base = Psis[ell] @ base

    for ell_prime in range(len(Psis)):
        if ell_prime == ell:
            continue
        # Sum along all axes but ell and ell_prime
        to_add = data.sum(axis=tuple([i for i in range(L) if i != ell and i != ell_prime]))
        
        # Multiply by Psi_{ell_prime} and then sum along ell_prime
        if ell_prime < ell:
            to_add = (Psis[ell_prime] @ to_add).sum(axis=0)
        else:
            to_add = (Psis[ell_prime] @ to_add.T).sum(axis=0)

        base += to_add

    return base

b_bases = [
    datatrans(ell, data, Psis)
    for ell in range(L)
]
max_cycles = 15
for cycle in range(max_cycles):
    for ell in range(L):
        # Preliminary calculations
        mean_sum = vec_kron_sum([lsum_Psis[ell_prime] for ell_prime in range(L) if ell != ell_prime]).sum()
        mean_lsum = (
            vec_kron_sum([means[ell_prime] for ell_prime in range(L) if ell != ell_prime])
            @ vec_kron_sum([lsum_Psis[ell_prime] for ell_prime in range(L) if ell != ell_prime])
        )
        A_inv = inv_trans[ell]
        b = (
            d_slashes[ell] * full_mean * lsum_Psis[ell]
            + full_mean * sum_Psis[ell]
            + lsum_Psis[ell] * mean_sum
            + mean_lsum
            - b_bases[ell]
        )
        A_inv_b = A_inv @ b
        means[ell] = A_inv_b - (A_inv_b.sum() / A_inv.sum()) * A_inv.sum(axis=0)
        
    print(means)
    full_mean = (data.reshape(-1) - vec_kron_sum(means)) @ vec_kron_sum(lsum_Psis) / sum(d_slashes[ell] * sum_Psis[ell] for ell in range(L))

[array([ 0.01463114, -0.15557951,  0.64749921, -0.50655083]), array([ 2.89153174,  2.10054215, -4.9920739 ])]
[array([ 0.01241248, -0.01488281,  0.30569244, -0.30322211]), array([ 2.80486085,  0.15972287, -2.96458372])]
[array([ 0.01264521, -0.02964131,  0.34154656, -0.32455045]), array([ 2.81395227,  0.36330681, -3.17725908])]
[array([ 0.01262079, -0.0280932 ,  0.33778561, -0.3223132 ]), array([ 2.81299862,  0.3419517 , -3.15495031])]
[array([ 0.01262335, -0.02825559,  0.33818012, -0.32254788]), array([ 2.81309865,  0.34419176, -3.15729041])]
[array([ 0.01262309, -0.02823856,  0.33813874, -0.32252326]), array([ 2.81308816,  0.34395679, -3.15704495])]
[array([ 0.01262311, -0.02824035,  0.33814308, -0.32252584]), array([ 2.81308926,  0.34398143, -3.15707069])]
[array([ 0.01262311, -0.02824016,  0.33814262, -0.32252557]), array([ 2.81308915,  0.34397885, -3.15706799])]
[array([ 0.01262311, -0.02824018,  0.33814267, -0.3225256 ]), array([ 2.81308916,  0.34397912, -3.15706828])]
[array([ 0

In [896]:
data.mean(), full_mean, vec_kron_sum([m for m in means]).mean()

(-4.543780751536123, -5.66582949203678, -1.8503717077085943e-16)

In [897]:
data.mean() / full_mean

0.8019621412755756

In [898]:
data - full_mean

array([[-4.71919473,  1.93012792,  5.53325296],
       [-3.51098983,  0.80427204,  5.44905147],
       [-3.68149801,  0.58540318,  6.81830303],
       [-3.50603116,  2.1931579 ,  5.56873013]])

In [899]:
data - full_mean - vec_kron_sum([m for m in means]).reshape(ds)

array([[-7.544907  ,  1.57352571,  8.6776981 ],
       [-6.29583881,  0.48853312,  8.6343599 ],
       [-6.83272983, -0.09671858,  9.63722861],
       [-5.99659472,  2.1717044 ,  9.04832397]])

In [900]:
(vec_kron_sum([m for m in means])).reshape(ds)

array([[ 2.82571227,  0.3566022 , -3.14444514],
       [ 2.78484898,  0.31573892, -3.18530843],
       [ 3.15123182,  0.68212176, -2.81892559],
       [ 2.49056356,  0.0214535 , -3.47959385]])

In [901]:
data.mean(axis=tuple(x for x in range(L) if x != 1)) - full_mean, means[1]

(array([-3.85442843,  1.37824026,  5.8423344 ]),
 array([ 2.81308916,  0.34397909, -3.15706825]))

In [902]:
data.mean(axis=tuple(x for x in range(L) if x != 0)) - full_mean, means[0]

(array([0.91472872, 0.91411123, 1.24073607, 1.41861895]),
 array([ 0.01262311, -0.02824018,  0.33814266, -0.3225256 ]))