In [None]:
from collections import Counter
from pathlib import Path

import netCDF4
import numpy as np
import numexpr
import pandas as pd
import xarray as xr

import plot

In [None]:
def index_over_dim(array, reduce_dim, index_array):
    """Pick one element along ``reduce_dim`` of ``array`` at every other position.

    For each position in the dimensions of ``array`` other than ``reduce_dim``,
    the element whose position along ``reduce_dim`` is given by ``index_array``
    at that same position is selected (a per-point "gather"). A typical
    ``index_array`` is the result of ``DataArray.argmin(reduce_dim)``.

    Parameters
    ----------
    array : xr.DataArray
        Source array; must have ``reduce_dim`` among its dims.
    reduce_dim : hashable
        Name of the dimension to index into; it is removed from the result.
    index_array : xr.DataArray
        Integer positions along ``reduce_dim``. Its dims must be exactly the
        dims of ``array`` other than ``reduce_dim``, in any order. Positions
        are matched by order (``.values``), not by coordinate labels.

    Returns
    -------
    xr.DataArray
        Array with ``reduce_dim`` removed and the remaining dims in the order
        they have in ``array``. Every coordinate of ``array`` that does not
        depend on ``reduce_dim`` is kept.

    Raises
    ------
    ValueError
        If ``reduce_dim`` is not a dim of ``array``, or if the dims of
        ``index_array`` are not exactly the non-reduced dims of ``array``.
    """
    dims = tuple(d for d in array.dims if d != reduce_dim)
    if reduce_dim not in array.dims:
        raise ValueError(f"{reduce_dim!r} is not a dimension of array (dims: {array.dims})")
    if set(index_array.dims) != set(dims):
        raise ValueError(
            f"index_array dims {index_array.dims} must be exactly the "
            f"non-reduced dims of array {dims}"
        )

    # Bring index_array into the output dim order. Advanced indexing yields
    # data shaped like the index arrays, and that data is labelled with `dims`
    # below, so the two orders must agree.
    index_array = index_array.transpose(*dims)

    # Open (sparse) grids enumerate positions along each kept dim and
    # broadcast against index_array.values during advanced indexing.
    grids = iter(np.indices(index_array.shape, sparse=True))
    indices = tuple(
        index_array.values if d == reduce_dim else next(grids)
        for d in array.dims
    )

    coords = {n: v for n, v in array.coords.items() if reduce_dim not in v.dims}

    return xr.DataArray(
        array.values[indices],
        dims=dims,
        coords=coords,
    )

In [None]:
# Directory holding the sweep outputs; bic_pth points at the BIC summary of
# the coarse parameter sweep (per the file name).
MODEL_DIR = Path("../models/sfa_mri_cad")
bic_pth = MODEL_DIR / "parameter_sweep_coarse-bics.nc"

In [None]:
# Read the whole dataset into memory and close the netCDF file immediately.
# Calling `.load()` on a bare `open_dataset(...)` leaves the file handle open
# for the life of the kernel.
with xr.open_dataset(str(bic_pth)) as bics_file:
    bics = bics_file.load()
bics

In [None]:
# Keep only the models that have no missing value in any variable, then
# re-index the flat `model` dimension by the swept parameters and expand it
# into an (l_gexp, l_mri, alpha) grid.
bics_nona = bics.sel(model=bics.to_array().notnull().all('variable'))
bics_array = (
    bics_nona
    .set_index(model=['l_gexp', 'l_mri', 'alpha'])
    .unstack('model')
)
# Best (lowest) BIC anywhere on the grid, as a plain Python scalar.
min_bic = bics_array['bic'].min().values.item()
bics_array

In [None]:
# For every remaining grid point, pick the position along `alpha` with the
# lowest BIC. Missing values are first replaced by the overall maximum so
# they are not preferred by argmin.
bic_grid = bics_array['bic']
sel_alpha_idx = bic_grid.fillna(bic_grid.max()).argmin('alpha')
# Colour range: best BIC on the grid up to the BIC of the empty model.
plot.heatmap(
    index_over_dim(bic_grid, 'alpha', sel_alpha_idx),
    zlim=[min_bic, bics.attrs['empty_model_bic']],
)

In [None]:
# `deviance_gexp` at the BIC-minimising alpha of each grid point.
plot.heatmap(
    index_over_dim(
        bics_array['deviance_gexp'],
        reduce_dim='alpha',
        index_array=sel_alpha_idx,
    )
)

In [None]:
# `deviance_mri` at the BIC-minimising alpha of each grid point.
plot.heatmap(
    index_over_dim(
        bics_array['deviance_mri'],
        reduce_dim='alpha',
        index_array=sel_alpha_idx,
    )
)

In [None]:
# Position along `alpha` chosen at each grid point, drawn with a qualitative
# colormap. Cast to float32 as before (presumably plot.heatmap expects a
# float array; confirm). Dims are passed explicitly rather than inferred from
# the coords mapping: that inference depends on the coords' key order and
# breaks if any non-dimension coordinate is present.
plot.heatmap(
    xr.DataArray(
        np.asarray(sel_alpha_idx, dtype='f'),
        dims=sel_alpha_idx.dims,
        coords=sel_alpha_idx.coords,
    ),
    cmap='Accent',
)

In [None]:
# `sparsity_gexp` at the BIC-minimising alpha of each grid point.
plot.heatmap(
    index_over_dim(
        bics_array['sparsity_gexp'],
        reduce_dim='alpha',
        index_array=sel_alpha_idx,
    )
)

In [None]:
# `sparsity_mri` at the BIC-minimising alpha of each grid point.
plot.heatmap(
    index_over_dim(
        bics_array['sparsity_mri'],
        reduce_dim='alpha',
        index_array=sel_alpha_idx,
    )
)