In [6]:
import numpy as np
import scipy.stats as stats
import mean_wrapper as mw
from GmGM import Dataset

In [7]:
# Test dataset: a deterministic ramp over a (4, 3, 2) tensor, shifted so its grand mean is -10.
SEED = 0  # seeds the Wishart draws below so Psis (and every downstream fit output) is reproducible
rng = np.random.default_rng(SEED)

ds = [4, 3, 2]
data = np.arange(np.prod(ds)).reshape(ds).astype(float)
data -= data.mean() + 10
# Alternative random test dataset (uses the seeded generator so it stays reproducible):
#data = rng.normal(size=ds)
#data -= 10
#data += 5 * np.arange(ds[-1])

# Derived parameters
L = len(ds)
# For each axis ell: product of all dimensions except ds[ell].
# NOTE: relies on ds being a list/tuple (`+` concatenates); a numpy array would add elementwise.
d_slashes = [np.prod(ds[:ell] + ds[ell + 1:]) for ell in range(L)]

# Fixed parameters: one symmetric ds[ell] x ds[ell] matrix per axis, drawn from a Wishart
# distribution (presumably the per-axis precision matrices of the model — confirm in mean_wrapper).
# Alternative: near-identity matrices
# Psis = {
#     f'Axis {ell}': 0 * np.ones((ds[ell], ds[ell])) / ds[ell] + np.eye(ds[ell])
#     for ell in range(L)
# }
Psis = {
    f'Axis {ell}': stats.wishart.rvs(
        df=ds[ell],
        scale=np.ones((ds[ell], ds[ell])) / 1.1 + np.eye(ds[ell]),
        random_state=rng,
    )
    for ell in range(L)
}

# Parameters to estimate, initialised with good guesses: the grand mean, and for each axis
# the mean of the centered data over every other axis.
full_mean_init = data.mean()
data_centered = data - full_mean_init  # computed once rather than once per axis
means_init = [
    data_centered.mean(axis=tuple(j for j in range(L) if j != ell))
    for ell in range(L)
]
# Copies keep means_init pristine if the estimator mutates `means` in place.
means = {f'Axis {ell}': axis_mean.copy() for ell, axis_mean in enumerate(means_init)}
full_mean = full_mean_init
data

array([[[-21.5, -20.5],
        [-19.5, -18.5],
        [-17.5, -16.5]],

       [[-15.5, -14.5],
        [-13.5, -12.5],
        [-11.5, -10.5]],

       [[ -9.5,  -8.5],
        [ -7.5,  -6.5],
        [ -5.5,  -4.5]],

       [[ -3.5,  -2.5],
        [ -1.5,  -0.5],
        [  0.5,   1.5]]])

In [8]:
def estimator(data):
    """Stub estimator: ignore `data` and hand back the fixed `Psis` dict.

    `Psis` is looked up at call time, so this returns whatever the global
    `Psis` refers to when the estimator is invoked (the same dict object,
    not a copy).
    """
    fixed_psis = Psis
    return fixed_psis

# Initial guesses: (per-axis means dict, grand mean).
initial_means = (means, full_mean)
NKS = mw.NoncentralKS(estimator, initial_means, Psis)

In [9]:
# Wrap the tensor in a GmGM Dataset. Axis names follow the same f'Axis {ell}' scheme
# used for the keys of Psis and means.
axis_names = tuple(f'Axis {ell}' for ell in range(L))
dataset = Dataset(
    dataset={'data': data},
    structure={'data': axis_names},
)
dataset

Dataset(
	data: ('Axis 0', 'Axis 1', 'Axis 2')
)
Axes(
	Axis 2: 2
		Prior: None
		Gram: Not calculated
		Eig: Not calculated
	Axis 0: 4
		Prior: None
		Gram: Not calculated
		Eig: Not calculated
	Axis 1: 3
		Prior: None
		Gram: Not calculated
		Eig: Not calculated
)

In [10]:
# Fit the noncentral model. Judging from the displayed output, the result has the shape
# ((per-axis means dict, grand mean), per-axis Psi dict) — confirm against mean_wrapper.
fit_result = NKS.fit(dataset)
fit_result

(({'Axis 0': array([-9., -3.,  3.,  9.]),
   'Axis 1': array([-2.00000000e+00,  2.66453526e-15,  2.00000000e+00]),
   'Axis 2': array([-0.5,  0.5])},
  -10.0),
 {'Axis 0': array([[ 1.42149502,  2.10619886,  0.47247159,  2.37656316],
         [ 2.10619886,  6.64189466,  3.95781846,  7.56378769],
         [ 0.47247159,  3.95781846,  5.82515644,  8.17484838],
         [ 2.37656316,  7.56378769,  8.17484838, 15.35628867]]),
  'Axis 1': array([[0.85418576, 0.35540206, 0.42655192],
         [0.35540206, 2.02058489, 2.25828062],
         [0.42655192, 2.25828062, 2.67279143]]),
  'Axis 2': array([[2.06190564, 0.18812029],
         [0.18812029, 1.37777927]])})