In [313]:
import numpy as np
import scipy.stats as stats
import mean_wrapper as mw
from GmGM import Dataset, GmGM

In [314]:
# Test dataset
ds = [100, 100]
data = np.random.normal(size=ds, scale=1)

# Derived parameters
L = len(ds)
d_slashes = [np.prod(ds[:ell] + ds[ell + 1:]) for ell in range(L)]

# Fixed parameters
# Covs = {
#     f'Axis {ell}': 10 * np.ones((ds[ell], ds[ell])) / ds[ell] + np.eye(ds[ell])
#     for ell in range(L)
# }
# Chose df=2d_ell to make the covariance matrix "sufficiently invertible"
# and divided by df because Wishart gets upscaled by df, want to counteract.
Covs = {
    f'Axis {ell}': stats.wishart.rvs(
        df=ds[ell] * 2,
        scale=np.ones((ds[ell], ds[ell])) / 1.1 + np.eye(ds[ell])
    ) / (2 * ds[ell])
    for ell in range(L)
}
# Covs = {
#     f'Axis {ell}': np.eye(ds[ell])
#     for ell in range(L)
# }
Psis = {
    axis: np.linalg.inv(Covs[axis])
    for axis in Covs
}

# Transform data to have certain dependencies
for ell in range(L):
    data = np.moveaxis(
        np.moveaxis(
            data,
            ell, -1
        ) @ np.linalg.inv(np.linalg.cholesky(Covs[f'Axis {ell}']).T),
        -1, ell
    )

# Add mean offsets to data
data += np.arange(ds[-1])
data -= 10 + data.mean()

# Parameters to estimate, give it good initial guesses
means_init = [(data - data.mean()).mean(axis=tuple([j for j in range(L) if j != ell])) for ell in range(L)]
full_mean_init = data.mean()
means = {f'Axis {i}': _.copy() for i, _ in enumerate(means_init)}
full_mean = full_mean_init
data

array([[-59.41826235, -57.79358728, -58.27086317, ...,  38.42265286,
         39.01341914,  40.97583882],
       [-59.47035279, -58.9253894 , -58.0594765 , ...,  36.174564  ,
         41.04321186,  39.4609283 ],
       [-60.31361482, -58.60932466, -56.01347907, ...,  37.15885879,
         34.24271625,  38.95957979],
       ...,
       [-59.96653271, -56.89025871, -53.43715324, ...,  37.70747188,
         36.16821051,  29.3732596 ],
       [-60.15534426, -58.48743617, -57.57112255, ...,  36.11743618,
         38.90546436,  39.7900084 ],
       [-59.75016642, -58.69930887, -57.11622452, ...,  34.21847019,
         41.54752031,  38.13697547]])

In [315]:
def null_estimator(data):
    return Psis

def gmgm_estimator(data):
    output = GmGM(
        data,
        to_keep=0.5,
        random_state=0,
        batch_size=1000,
        verbose=False,
        n_comps=50,
        threshold_method='statistical-significance',
        readonly=True
    )
    return output.precision_matrices
        

NKS = mw.NoncentralKS(gmgm_estimator, (means, full_mean), Psis)

In [316]:
dataset = Dataset(
    dataset={'data': data},
    structure={'data': tuple([f'Axis {i}' for i in range(L)])},
)
dataset

Dataset(
	data: ('Axis 0', 'Axis 1')
)
Axes(
	Axis 0: 100
		Prior: None
		Gram: Not calculated
		Eig: Not calculated
	Axis 1: 100
		Prior: None
		Gram: Not calculated
		Eig: Not calculated
)

In [317]:
a, b = NKS.fit(dataset, verbose=True)
print(a[1], a[0]['Axis 1'])

Iteration: 1 (Change: 36.75808332319635)
Iteration: 2 (Change: 0.502480285741767)
Iteration: 3 (Change: 0.008695968482999049)
Iteration: 4 (Change: 0.0001177014000971781)
Converged in 5 iterations
Iteration: 5 (Change: 2.583334843653122e-05)
-10.000024646408352 [-49.51075036 -48.52810066 -47.43998385 -46.53401782 -45.3225614
 -44.43136768 -43.39835753 -42.54916851 -41.670686   -40.25599695
 -39.28766817 -38.41115715 -37.62853957 -36.50068327 -35.57445114
 -34.503688   -33.76714935 -32.46075297 -31.66462838 -30.68058803
 -29.43097779 -28.28151837 -27.54196972 -26.39830193 -25.30266681
 -24.64024843 -23.67630196 -22.78259296 -21.37358717 -20.24137553
 -19.4025297  -18.56169822 -17.39427742 -16.56260075 -15.62114587
 -14.65243404 -13.48464052 -12.30093937 -11.54999109 -10.43461604
  -9.36521567  -8.37040171  -7.7742786   -6.50235815  -5.27691332
  -4.65177845  -3.37587161  -2.35280725  -1.44889238  -0.47209707
   0.55444706   1.24921638   2.98568436   3.83848912   4.53045697
   5.28588981

In [318]:
(b['Axis 0'].toarray().diagonal().sum() * ds[0] + ds[1] * b['Axis 1'].toarray().diagonal().sum()) / (ds[0] * ds[1])

0.8207617950439453

In [319]:
b['Axis 1'].toarray()

array([[0.1873    , 0.09501915, 0.10362364, ..., 0.04516349, 0.01806848,
        0.01211174],
       [0.09501915, 0.29144   , 0.01464269, ..., 0.03406863, 0.01119631,
        0.03076433],
       [0.10362364, 0.01464269, 0.23606384, ..., 0.06947554, 0.0131749 ,
        0.02226917],
       ...,
       [0.04516349, 0.03406863, 0.06947554, ..., 0.46768895, 0.01402486,
        0.0165386 ],
       [0.01806848, 0.01119631, 0.0131749 , ..., 0.01402486, 0.52952176,
        0.03224508],
       [0.01211174, 0.03076433, 0.02226917, ..., 0.0165386 , 0.03224508,
        0.5324196 ]], dtype=float32)

In [320]:
Psis['Axis 1']

array([[ 2.30184936e+00, -7.73840634e-02, -1.66836604e-01, ...,
         1.73989710e-01,  1.77917233e-01, -2.39465835e-01],
       [-7.73840634e-02,  1.82995773e+00, -8.26740002e-03, ...,
        -8.85227677e-02,  4.70354092e-02, -1.24658370e-01],
       [-1.66836604e-01, -8.26740002e-03,  2.03422258e+00, ...,
        -3.70255706e-01, -2.02809472e-01,  7.43807052e-02],
       ...,
       [ 1.73989710e-01, -8.85227677e-02, -3.70255706e-01, ...,
         1.91799648e+00,  1.09758503e-03,  1.44581913e-01],
       [ 1.77917233e-01,  4.70354092e-02, -2.02809472e-01, ...,
         1.09758503e-03,  2.04003367e+00,  1.14778046e-01],
       [-2.39465835e-01, -1.24658370e-01,  7.43807052e-02, ...,
         1.44581913e-01,  1.14778046e-01,  2.07144467e+00]])