In [220]:
import numpy as np
import scipy.stats as stats
import mean_wrapper as mw
from GmGM import Dataset, GmGM

In [221]:
# Test dataset
ds = [100, 100]

# Seed a single generator and use it for every random draw in this cell so that
# "Restart Kernel -> Run All" regenerates exactly the same synthetic data.
SEED = 0
rng = np.random.default_rng(SEED)
data = rng.normal(size=ds, scale=1)

# Derived parameters
L = len(ds)
# d_slashes[ell] is the product of every axis length except axis ell
d_slashes = [np.prod(ds[:ell] + ds[ell + 1:]) for ell in range(L)]

# Fixed parameters
# Alternative choice (dense constant matrix plus identity):
# Psis = {
#     f'Axis {ell}': 10 * np.ones((ds[ell], ds[ell])) / ds[ell] + np.eye(ds[ell])
#     for ell in range(L)
# }
# Active choice: one Wishart draw per axis, rescaled by the axis length.
Psis = {
    f'Axis {ell}': stats.wishart.rvs(
        df=ds[ell],
        scale=np.ones((ds[ell], ds[ell])) / 1.1 + np.eye(ds[ell]),
        random_state=rng
    ) / ds[ell]
    for ell in range(L)
}
# Alternative choice (identity matrices):
# Psis = {
#     f'Axis {ell}': np.eye(ds[ell])
#     for ell in range(L)
# }

# Transform data to have certain dependencies:
# for each axis in turn, move it to the last position, right-multiply by the
# transposed Cholesky factor of that axis's Psi, and move the axis back.
for ell in range(L):
    chol_factor = np.linalg.cholesky(Psis[f'Axis {ell}'])
    data = np.moveaxis(
        np.moveaxis(data, ell, -1) @ chol_factor.T,
        -1, ell
    )

# Add mean offsets to data: a ramp 0, 1, ..., ds[-1] - 1 broadcast along the
# last axis, followed by a global shift that leaves the overall mean at -10.
data += np.arange(ds[-1])
data -= 10 + data.mean()

# Parameters to estimate, give it good initial guesses:
# the overall mean, plus for each axis the mean of the centred data taken over
# all the other axes.
full_mean_init = data.mean()
centered = data - full_mean_init
means_init = [
    centered.mean(axis=tuple(j for j in range(L) if j != ell))
    for ell in range(L)
]
means = {f'Axis {i}': guess.copy() for i, guess in enumerate(means_init)}
full_mean = full_mean_init
data

array([[-61.33356318, -57.89470627, -57.1793275 , ...,  37.05095412,
         39.00943671,  39.63939288],
       [-58.51222698, -57.51798186, -58.15777283, ...,  38.79837826,
         40.13245665,  39.36783795],
       [-59.95354382, -59.83914472, -55.78431538, ...,  39.08822421,
         39.71725236,  39.60308067],
       ...,
       [-61.12165162, -56.28375389, -57.11666534, ...,  39.27804978,
         40.75777691,  39.46577528],
       [-60.37305027, -57.18973723, -57.78763227, ...,  38.92162927,
         41.8359825 ,  38.38868514],
       [-61.93552314, -58.57922563, -56.55943459, ...,  37.19652209,
         39.31660956,  39.29997415]])

In [222]:
def null_estimator(data):
    """Estimator stub that ignores ``data`` and hands back the notebook-level ``Psis``."""
    fixed_matrices = Psis
    return fixed_matrices

def gmgm_estimator(data):
    """Run ``GmGM`` on ``data`` with this notebook's fixed settings.

    Returns the ``precision_matrices`` attribute of the object ``GmGM`` returns.
    """
    # All tuning knobs in one place; they are forwarded to GmGM unchanged.
    gmgm_options = dict(
        to_keep=0.5,
        random_state=0,
        batch_size=1000,
        verbose=False,
        n_comps=50,
        threshold_method='statistical-significance',
        readonly=True,
    )
    fitted = GmGM(data, **gmgm_options)
    return fitted.precision_matrices
        

# Build the mean_wrapper estimator around gmgm_estimator, passing the initial
# (per-axis means, overall mean) pair and the Psis dictionary.
# NOTE(review): the meaning of each positional argument is defined in
# mean_wrapper, which is not shown here — confirm against that module.
NKS = mw.NoncentralKS(
    gmgm_estimator,
    (means, full_mean),
    Psis,
)

In [223]:
# Wrap the array in a GmGM Dataset whose axes are labelled 'Axis 0', 'Axis 1', ...
axis_names = tuple(f'Axis {i}' for i in range(L))
dataset = Dataset(
    dataset={'data': data},
    structure={'data': axis_names},
)
dataset

Dataset(
	data: ('Axis 0', 'Axis 1')
)
Axes(
	Axis 0: 100
		Prior: None
		Gram: Not calculated
		Eig: Not calculated
	Axis 1: 100
		Prior: None
		Gram: Not calculated
		Eig: Not calculated
)

In [224]:
# Fit the wrapped estimator; verbose=True prints the per-iteration change.
a, b = NKS.fit(dataset, verbose=True)

# Show a[1] followed by the 'Axis 1' entry of a[0].
# NOTE(review): presumably `a` mirrors the (per-axis means, overall mean) pair
# passed to NoncentralKS — confirm against mean_wrapper.
fitted_means, fitted_full_mean = a[0], a[1]
print(fitted_full_mean, fitted_means['Axis 1'])

Iteration: 1 (Change: 112.80700021255379)
Iteration: 2 (Change: 0.0479331974525785)
Iteration: 3 (Change: 0.0008734762622852189)
Converged in 4 iterations
Iteration: 4 (Change: 2.7854614609581658e-05)
-10.000021159816269 [-50.83033846 -48.07441677 -47.18785204 -47.9367314  -45.23810681
 -45.43114633 -42.93327462 -42.04502005 -42.58024214 -40.00059635
 -39.73305306 -38.58192841 -37.5410955  -36.78176906 -35.83393581
 -35.33306763 -32.50571191 -33.24305924 -32.03171906 -31.69359397
 -28.66494651 -28.05405276 -28.65838175 -26.66173044 -27.96784449
 -23.88685245 -23.3894561  -21.78211103 -20.6844571  -19.99110919
 -20.33615546 -18.34477747 -16.48855866 -15.97118583 -16.29917772
 -14.49909041 -12.41745547 -13.5120476  -11.35713943 -10.66646783
  -9.16322527  -8.73990489  -7.97705086  -7.57758035  -4.96450829
  -2.9651595   -3.65684052  -1.95707928  -0.37524086  -0.85284251
   0.20865004   1.71980861   1.68057744   3.57524052   3.59652335
   2.9943771    6.51846832   8.53867584   7.38291154 

In [225]:
# Mean diagonal entry of the Kronecker sum of the per-axis matrices in `b`.
# For a Kronecker sum, tr(A (+) B) = tr(A) * d_B + d_A * tr(B): each axis's trace
# is weighted by the product of the *other* axes' lengths (d_slashes[ell]), and
# the total is divided by the full size prod(ds).
# The earlier form weighted each trace by its own axis length, which gives the
# same number only when every axis has the same length (as with ds = [100, 100]).
# NOTE(review): assumes the Kronecker-sum mean diagonal is the intended quantity — confirm.
sum(
    b[f'Axis {ell}'].toarray().diagonal().sum() * d_slashes[ell]
    for ell in range(L)
) / np.prod(ds)

0.7668465805053711

In [226]:
# Display the 'Axis 0' entry of `b` as a dense array.
b['Axis 0'].toarray()

array([[0.34335104, 0.06771474, 0.0480319 , ..., 0.06942948, 0.00921197,
        0.0232182 ],
       [0.06771474, 0.3247674 , 0.05914727, ..., 0.02119167, 0.01530092,
        0.02660878],
       [0.0480319 , 0.05914727, 0.37487522, ..., 0.05696075, 0.02421396,
        0.00561598],
       ...,
       [0.06942948, 0.02119167, 0.05696075, ..., 0.36693436, 0.00891732,
        0.03361278],
       [0.00921197, 0.01530092, 0.02421396, ..., 0.00891732, 0.44192332,
        0.01611859],
       [0.0232182 , 0.02660878, 0.00561598, ..., 0.03361278, 0.01611859,
        0.3633759 ]], dtype=float32)