In [None]:
from collections import Counter
from pathlib import Path

from IPython.display import display, Markdown
import matplotlib.pyplot
import netCDF4
import numpy as np
import numexpr
import pandas as pd
import seaborn
import xarray as xr

import plot

In [None]:
def index_over_dim(array, reduce_dim, index_array):
    """Pick one element along ``reduce_dim`` for every remaining position.

    Given integer positions along ``reduce_dim`` (for example the result of
    ``array.argmin(reduce_dim)`` on this or a related array), return the
    values of ``array`` found at those positions.

    Parameters
    ----------
    array : xr.DataArray
        Array to select from.
    reduce_dim : str
        Name of the dimension of ``array`` that is indexed away.
    index_array : xr.DataArray
        Integer positions along ``reduce_dim``. It must have exactly one
        dimension fewer than ``array``, and every dimension of ``array``
        other than ``reduce_dim`` must be one of its dimensions (the order
        may differ from ``array``).

    Returns
    -------
    xr.DataArray
        The selected values, with the dimensions (and dimension order) of
        ``index_array`` and those coordinates of ``array`` that do not
        depend on ``reduce_dim``.

    Raises
    ------
    AssertionError
        If ``index_array`` does not have one dimension fewer than ``array``.
    ValueError
        If a non-reduced dimension of ``array`` is missing from
        ``index_array``.
    """
    assert len(array.shape) == len(index_array.shape) + 1, (
        "index_array must have exactly one dimension fewer than array"
    )

    # One integer grid per dimension of index_array; grids[j][pos] == pos[j],
    # so each grid addresses "the same position" along that dimension.
    grids = np.indices(index_array.shape)
    positions = index_array.values

    # Build one index array per axis of `array` (in array.dims order): the
    # requested positions for the reduced axis, the identity grid otherwise.
    indexer = tuple(
        positions if dim == reduce_dim else grids[index_array.dims.index(dim)]
        for dim in array.dims
    )

    coords = {
        name: coord
        for name, coord in array.coords.items()
        if reduce_dim not in coord.dims
    }

    # All index arrays share index_array's shape, so the fancy-indexed result
    # is laid out like index_array. Label it with index_array.dims rather
    # than with array.dims minus reduce_dim, which is only correct when both
    # happen to list the dimensions in the same order.
    return xr.DataArray(
        array.values[indexer],
        dims=index_array.dims,
        coords=coords,
    )

In [None]:
# Open the NetCDF file as an xarray Dataset. xr.open_dataset reads lazily and
# no .load() is called here, so the file stays open behind `data_ds`.
# NOTE(review): `data_ds` is not referenced by any other cell visible in this
# section — confirm it is still needed.
data_ds = xr.open_dataset('../data/processed/concat-data.nc')

In [None]:
# Locations of two NetCDF files under ../models/sfa_mri_cad. Judging by the
# file names these are presumably the coarse parameter-sweep output and the
# BICs computed from it — TODO confirm. Only `bic_pth` is read by the cells
# visible in this section.
sweep_pth = Path("../models/sfa_mri_cad/parameter_sweep_coarse.nc")
bic_pth = Path("../models/sfa_mri_cad/parameter_sweep_coarse-bics.nc")

In [None]:
# Read the BIC dataset fully into memory. The context manager closes the
# underlying NetCDF file handle as soon as the data has been loaded, instead
# of leaving it open for the lifetime of the kernel.
with xr.open_dataset(str(bic_pth)) as bic_file:
    bics = bic_file.load()
# Bare last expression: rich display of the loaded dataset.
bics

In [None]:
# Keep only the models whose `n_iter` is below 9998.
# NOTE(review): 9998 presumably sits just under the fitting routine's
# iteration cap, i.e. this drops runs that did not converge — confirm.
below_iter_cutoff = bics['n_iter'] < 9998
bics_sel = bics.sel(model=below_iter_cutoff)

# Count the NaN entries across all data variables and keep only the models
# for which that count is zero.
missing_per_model = bics_sel.to_array().isnull().sum('variable')
bics_nona = bics_sel.sel(model=(missing_per_model == 0))

# Turn the flat `model` axis into one dimension per sweep parameter.
sweep_params = ['l_gexp', 'l_mri', 'alpha', 'k']
bics_array = bics_nona.set_index(model=sweep_params).unstack('model')

# Global minimum BIC as a plain Python scalar; used below as the lower
# colour limit of the BIC heatmaps.
min_bic = bics_array['bic'].min().values.item()

Heatmaps of BIC over the parameters
===================================

Best (lowest) BIC over all values of k and alpha, shown for each combination of the lambda penalties of the two data types (`l_gexp` and `l_mri`).

In [None]:
# Replace missing BICs by the largest observed BIC so that a cell without a
# value can never win the minimum, then take the best BIC over alpha and k
# for every (l_gexp, l_mri) pair.
bic_all = bics_array['bic']
best_bic_per_penalty = bic_all.fillna(bic_all.max()).min(['alpha', 'k'])

# Colour scale runs from the overall best BIC to the BIC stored in the
# dataset attribute `empty_model_bic`.
plot.heatmap(
    best_bic_per_penalty,
    zlabel='BIC',
    zlim=[min_bic, bics.attrs['empty_model_bic']],
)

In [None]:
# For every k: find, per (l_gexp, l_mri) cell, the alpha with the lowest BIC
# and draw heatmaps of the BIC, the chosen alpha, and the other statistics
# evaluated at that alpha.

# (data variable, extra keyword arguments for plot.heatmap), in plot order.
per_k_panels = [
    ('deviance_gexp', dict(zlabel='Deviance (gexp)')),
    ('deviance_mri', dict(zlabel='Deviance (MRI)')),
    ('dof_gexp', dict(zlabel='Degrees of Freedom (gexp)')),
    ('dof_mri', dict(zlabel='Degrees of Freedom (MRI)')),
    ('sparsity_gexp', dict(zlim=[0, 1], cmap='Greys', zlabel='Sparsity (gexp)')),
    ('sparsity_mri', dict(zlim=[0, 1], cmap='Greys', zlabel='Sparsity (MRI)')),
]

for k in bics_array['k'].values:
    display(Markdown(f"### k={k} ###"))

    # drop=True removes the now-scalar `k` coordinate in the same step
    # (replaces the deprecated Dataset.drop('k')).
    ba = bics_array.sel(k=k, drop=True)

    # Missing BICs are replaced by the largest BIC so they are never picked
    # as the minimum; alpha_idx holds integer positions along `alpha`.
    alpha_idx = ba['bic'].fillna(ba['bic'].max()).argmin('alpha')

    plot.heatmap(
        index_over_dim(ba['bic'], 'alpha', alpha_idx),
        zlim=[min_bic, bics.attrs['empty_model_bic']],
        zlabel='BIC',
    )

    # Look up the alpha values on the raw NumPy arrays: xarray refuses to
    # index a DataArray with an unlabeled multi-dimensional NumPy array.
    # Dimensions are passed explicitly instead of being inferred from coords.
    best_alpha = xr.DataArray(
        ba.coords['alpha'].values[alpha_idx.values],
        dims=alpha_idx.dims,
        coords=alpha_idx.coords,
    )
    plot.heatmap(best_alpha, cmap='inferno', zlabel='alpha')

    for var_name, heatmap_kwargs in per_k_panels:
        plot.heatmap(
            index_over_dim(ba[var_name], 'alpha', alpha_idx),
            **heatmap_kwargs,
        )

In [None]:
# Lowest BIC for each k, i.e. the minimum over alpha and both lambda
# penalties, drawn as a line plot on a single wide axis.
best_bic_per_k = bics_array['bic'].min(['alpha', 'l_gexp', 'l_mri'])
with plot.subplots(1, 1, figsize=(10, 3)) as (fig, ax):
    plot.lines(best_bic_per_k, ax=ax)