In [1]:
import numpy as np
import scipy.stats as stats
import mean_wrapper as mw
from GmGM import Dataset, GmGM

In [2]:
# Test dataset
# Draw from a seeded, local generator so that Restart & Run All regenerates
# exactly the same data every time (the global np.random state is untouched).
SEED = 0
rng = np.random.default_rng(SEED)

ds = [100, 100]  # size of each axis of the data array
# Alternative deterministic dataset, kept for experimentation:
#data = np.arange(np.prod(ds)).reshape(ds).astype(float)
#data -= data.mean() + 10
data = rng.normal(size=tuple(ds))  # i.i.d. standard-normal noise
data += np.arange(ds[-1])          # add 0, 1, 2, ... broadcast along the last axis
data -= 10 + data.mean()           # shift so the overall mean is -10

# Derived parameters
L = len(ds)  # number of axes
# For each axis, the product of the sizes of all the *other* axes.
d_slashes = [np.prod(ds[:ell] + ds[ell + 1:]) for ell in range(L)]

# Fixed parameters
# Alternative choices for Psis, kept for experimentation:
# Psis = {
#     f'Axis {ell}': 0 * np.ones((ds[ell], ds[ell])) / ds[ell] + np.eye(ds[ell])
#     for ell in range(L)
# }
# Psis = {
#     f'Axis {ell}': stats.wishart.rvs(
#         df=ds[ell],
#         scale=np.ones((ds[ell], ds[ell])) / 1.1 + np.eye(ds[ell])
#     )
#     for ell in range(L)
# }
# Current choice: half the identity matrix for every axis.
Psis = {
    f'Axis {ell}': np.eye(ds[ell]) / 2
    for ell in range(L)
}

# Parameters to estimate, give it good initial guesses
grand_mean = data.mean()
centered = data - grand_mean
# Initial guess for axis `ell`: average the centered data over every other axis.
means_init = [
    centered.mean(axis=tuple(j for j in range(L) if j != ell))
    for ell in range(L)
]
full_mean_init = grand_mean
# Copy so that later updates to `means` do not modify `means_init`.
means = {f'Axis {i}': axis_mean.copy() for i, axis_mean in enumerate(means_init)}
full_mean = full_mean_init
data

array([[-58.65324695, -58.8786048 , -57.45131409, ...,  36.67527566,
         35.92263661,  38.23170968],
       [-57.60140253, -60.49758969, -55.90592984, ...,  37.86799517,
         37.87879526,  40.55281186],
       [-58.59807607, -58.36343558, -57.24132604, ...,  38.33453366,
         38.58483794,  42.43846071],
       ...,
       [-61.94421016, -57.68349066, -57.73913815, ...,  38.84343661,
         39.58780226,  40.53957398],
       [-58.7840125 , -58.09439295, -56.64092186, ...,  37.24521965,
         38.42705689,  38.64528061],
       [-57.56114089, -56.86855379, -56.74806664, ...,  37.45919375,
         38.60477659,  39.1750386 ]])

In [3]:
def null_estimator(data):
    """Baseline estimator that ignores ``data`` and returns the fixed ``Psis``.

    Note that the module-level ``Psis`` dict itself is returned, not a copy.
    """
    return Psis

def gmgm_estimator(data):
    """Run ``GmGM`` on ``data`` with a fixed configuration.

    Parameters
    ----------
    data
        Passed unchanged as the first positional argument of ``GmGM``.

    Returns
    -------
    The ``precision_matrices`` attribute of the object ``GmGM`` returns.
    """
    # Hyperparameters are pinned here so every call uses the same settings.
    gmgm_options = dict(
        to_keep=0.5,
        random_state=0,
        batch_size=1000,
        verbose=False,
        n_comps=50,
        threshold_method='statistical-significance',
        readonly=True,
    )
    fitted = GmGM(data, **gmgm_options)
    return fitted.precision_matrices
        

# Build the NoncentralKS wrapper from the GmGM-based estimator, the initial
# mean guesses (per-axis means dict, overall mean) and the Psis matrices.
# NOTE(review): the meaning of each positional argument is defined in
# mean_wrapper -- confirm there.
NKS = mw.NoncentralKS(
    gmgm_estimator,
    (means, full_mean),
    Psis,
)

In [4]:
# Wrap the raw array in a Dataset, labelling its axes 'Axis 0', 'Axis 1', ...
axis_labels = tuple(f'Axis {i}' for i in range(L))
dataset = Dataset(
    dataset={'data': data},
    structure={'data': axis_labels},
)
dataset

Dataset(
	data: ('Axis 0', 'Axis 1')
)
Axes(
	Axis 1: 100
		Prior: None
		Gram: Not calculated
		Eig: Not calculated
	Axis 0: 100
		Prior: None
		Gram: Not calculated
		Eig: Not calculated
)

In [5]:
# Fit on the dataset. `a` and `b` keep their names because later cells use them.
a, b = NKS.fit(dataset, verbose=True)

# `a` is indexed below as (per-axis values keyed by axis name, a single value).
fitted_axis_values = a[0]
fitted_overall_value = a[1]
print(fitted_overall_value, fitted_axis_values['Axis 1'])

# `b` maps axis names to objects exposing .toarray(); show 'Axis 0' densely.
print(b['Axis 0'].toarray())

Iteration: 1 (Change: 0.0825608283420819)
Iteration: 2 (Change: 0.001066696435467241)
Converged in 3 iterations
Iteration: 3 (Change: 3.7336270686879036e-05)
-9.999965291574318 [-49.40894029 -48.49414514 -47.6103787  -46.46841316 -45.3690133
 -44.54098409 -43.55846779 -42.49040637 -41.66780265 -40.38136107
 -39.54504059 -38.56317478 -37.36952791 -36.4378125  -35.28165559
 -34.21258297 -33.48786563 -32.4745594  -31.4998457  -30.66666208
 -29.49572801 -28.39495031 -27.69546714 -26.39141887 -25.46535823
 -24.46292048 -23.45900213 -22.40793424 -21.45720014 -20.50141449
 -19.41722025 -18.48456599 -17.54199316 -16.4502485  -15.43740651
 -14.46409903 -13.48619317 -12.58550998 -11.67175509 -10.56430879
  -9.35746787  -8.71581403  -7.48148542  -6.57729285  -5.54106915
  -4.32998414  -3.4872884   -2.4225331   -1.64670358  -0.4508701
   0.42334531   1.51080854   2.66811325   3.51588963   4.5430995
   5.52712035   6.54318385   7.45971031   8.54866533   9.45311503
  10.53762969  11.36788991  12.745

In [6]:
# Display the 'Axis 1' entry of `b` (second value returned by NKS.fit) as a dense array.
b['Axis 1'].toarray()

array([[0.28619018, 0.00096461, 0.01379754, ..., 0.03659574, 0.0115982 ,
        0.02690401],
       [0.00096461, 0.2546686 , 0.03930949, ..., 0.01427553, 0.00460409,
        0.00694059],
       [0.01379754, 0.03930949, 0.27948195, ..., 0.04728876, 0.00329339,
        0.01195691],
       ...,
       [0.03659574, 0.01427553, 0.04728876, ..., 0.27319497, 0.00693736,
        0.01640781],
       [0.0115982 , 0.00460409, 0.00329339, ..., 0.00693736, 0.27186447,
        0.0341961 ],
       [0.02690401, 0.00694059, 0.01195691, ..., 0.01640781, 0.0341961 ,
        0.24892157]], dtype=float32)