In [None]:
%matplotlib inline
import sys

import arviz as az
import numpy as np
import pymc as pm

if not sys.warnoptions:
    import warnings
warnings.simplefilter("ignore")

from forest import *

## Test 1

### Create Models

In [None]:
# Two parameterisations of the eight-schools example model, keyed by short names.
d1 = az.load_arviz_data("centered_eight")
d2 = az.load_arviz_data("non_centered_eight")
dict_cmp = {"mA": d1, "mB": d2}

### Build Dashboard

In [None]:
# Reference dashboard for both models; `dashboard_forest` is presumably provided by `from forest import *`.
dashboard_forest(dict_cmp)

## Forest Plot Components (Bokeh renderers)

In [None]:
# Bokeh forest plot of model "mA"; kept in `fp` so its renderers can be inspected below.
fp = az.plot_forest(
    data=dict_cmp["mA"],
    backend="bokeh",
    kind="forestplot",
    combined=True,
    hdi_prob=0.9,
    figsize=(9, 9),
)

In [None]:
# Raw `_property_values` of the first plot in `fp` (private Bokeh internals; for inspection only).
fp[0][0].__dict__["_property_values"]

In [None]:
# Raw `_property_values` of the data source behind renderer 0
# (private Bokeh internals; for inspection only).
(
    fp[0][0]
    .__dict__["_property_values"]["renderers"][0]
    .__dict__["_property_values"]["data_source"]
    .__dict__["_property_values"]
)

In [None]:
# Raw `_property_values` of the data source behind renderer 1
# (private Bokeh internals; for inspection only).
(
    fp[0][0]
    .__dict__["_property_values"]["renderers"][1]
    .__dict__["_property_values"]["data_source"]
    .__dict__["_property_values"]
)

In [None]:
# Raw `_property_values` of the glyph of renderer 2
# (private Bokeh internals; for inspection only).
(
    fp[0][0]
    .__dict__["_property_values"]["renderers"][2]
    .__dict__["_property_values"]["glyph"]
    .__dict__["_property_values"]
)

In [None]:
# Raw `_property_values` of the data source behind renderer 3
# (private Bokeh internals; for inspection only).
(
    fp[0][0]
    .__dict__["_property_values"]["renderers"][3]
    .__dict__["_property_values"]["data_source"]
    .__dict__["_property_values"]
)

In [None]:
# Raw `_property_values` of the data source behind renderer 4
# (private Bokeh internals; for inspection only).
(
    fp[0][0]
    .__dict__["_property_values"]["renderers"][4]
    .__dict__["_property_values"]["data_source"]
    .__dict__["_property_values"]
)

In [None]:
# Raw `_property_values` of the glyph of renderer 5
# (private Bokeh internals; for inspection only).
(
    fp[0][0]
    .__dict__["_property_values"]["renderers"][5]
    .__dict__["_property_values"]["glyph"]
    .__dict__["_property_values"]
)

## Ridge Plot

In [None]:
# Ridge-plot prototype for model "mA" (default backend).
# TODO (from the original notes): kind, var_names, ridgeplot_truncate and
# ridgeplot_quantiles still need to change; the dashboard below drives the
# ridge settings from widgets instead.
az.plot_forest(
    data=dict_cmp["mA"],
    kind="ridgeplot",
    var_names=["theta"],
    combined=True,
    colors="white",
    figsize=(9, 7),
    ridgeplot_overlap=0.7,
    ridgeplot_truncate=False,
    ridgeplot_quantiles=[0.25, 0.5, 0.75],
)

In [None]:
# Stand-in for the (low, high) value of the dashboard's quantile RangeSlider.
ridge_quant = 0.25, 0.75

In [None]:
# Mutable list copy of the (low, high) bounds, as the dashboard does with the slider value.
temp_quant = [*ridge_quant]

In [None]:
# Display the list form of `ridge_quant`.
temp_quant

In [None]:
# Half the sum of `temp_quant`: the midpoint when it holds exactly the two bounds.
sum(temp_quant) / 2

In [None]:
# Copy rather than alias: appending to `quant_ls` must not also grow `temp_quant`.
quant_ls = list(temp_quant)

In [None]:
# Build a new list instead of appending in place, so re-running this cell does
# not keep growing `quant_ls` (or `temp_quant`, which it may alias).
quant_ls = temp_quant + [sum(temp_quant) / 2]

In [None]:
# Show the quantile list after the midpoint has been added.
quant_ls

## Linking Coordinates to Data Variables

In [None]:
# Display the non-centered model's InferenceData for comparison.
d2

In [None]:
# Names of the posterior data variables of the centered model.
list(d1.posterior.data_vars)

In [None]:
# The raw xarray Variable behind the posterior "theta".
d1.posterior.variables["theta"]

In [None]:
# Shape of the posterior "mu" samples.
d1.posterior.variables["mu"].shape

In [None]:
# Number of values in a single (first axis 0, second axis 0) slice of "mu".
d1.posterior.variables["mu"][0, 0].size

In [None]:
# Number of values in a single (first axis 0, second axis 0) slice of "tau".
d1.posterior.variables["tau"][0, 0].size

In [None]:
# Shape of the posterior "theta" samples.
d1.posterior.variables["theta"].shape

In [None]:
# The single (first axis 0, second axis 0) slice of "mu".
d1.posterior.variables["mu"][0, 0]

In [None]:
variables_selection = pn.widgets.Select(
    value=None, options=list(d1.posterior.data_vars.variables), name="Data Variables"
)
coord_selection = pn.widgets.Select(
    value=None, options=[""], name="Coordinates Variables"
)


@pn.depends(variables_selection.param.value)
def update_coords(variables_selection):
    """Refresh ``coord_selection`` for the chosen posterior variable of ``d1``.

    A variable with more than one value per [0][0] slice gets the labels of its
    first extra dimension as options: the index labels if that dimension is
    indexed, positions otherwise. Scalar variables (and no selection) get the
    single placeholder "". Returns the updated ``coord_selection`` widget so
    Panel can render it.
    """
    options = [""]
    if variables_selection is not None:
        variable = d1.posterior.data_vars.variables[variables_selection]
        # [0][0] fixes the first two axes (chain and draw in ArviZ posteriors),
        # so size > 1 means there is at least one extra dimension.
        if variable[0][0].size > 1:
            # Use the variable's own dimension instead of a hard-coded "school".
            coord_dim = variable.dims[2]
            index = d1.posterior.indexes.get(coord_dim)
            if index is not None:
                options = list(index)
            else:
                options = list(range(variable.shape[2]))
    coord_selection.options = options
    return coord_selection

In [None]:
# Render the data-variable selector beside the coordinate selector it updates.
display(pn.Row(variables_selection, update_coords))

In [None]:
import param

In [None]:
class ModelVar(param.Parameterized):
    """Linked selectors: model -> posterior data variable -> coordinate label.

    Choosing a model refreshes the available data variables; choosing a data
    variable (or switching model while keeping the same variable name)
    refreshes the coordinate labels.
    """

    # Models to choose from, captured from the notebook's `dict_cmp` at class creation.
    idatas_cmp = dict_cmp
    default_model = list(idatas_cmp.keys())[0]
    model = param.Selector(list(idatas_cmp.keys()), default=default_model)
    data_variable = param.Selector(
        list(idatas_cmp[default_model].posterior.data_vars.variables)
    )
    # A single "" placeholder means "no coordinate" (scalar variables).
    coor_variable = param.Selector(objects=[""], default="")

    def _coordinate_options(self):
        """Return coordinate labels for the selected data variable, or [""] if none."""
        posterior = self.idatas_cmp[self.model].posterior
        variable = posterior.data_vars.variables[self.data_variable]
        # [0][0] fixes the first two axes (chain and draw in ArviZ posteriors);
        # size > 1 therefore means there is at least one extra dimension.
        if variable[0][0].size <= 1:
            return [""]
        # Use the variable's own dimension instead of a hard-coded "school".
        coord_dim = variable.dims[2]
        index = posterior.indexes.get(coord_dim)
        if index is None:
            # Un-indexed dimension: fall back to positional labels.
            return list(range(variable.shape[2]))
        return list(index)

    @param.depends("model", watch=True)
    def _update_data_variables(self):
        data_variables = list(self.idatas_cmp[self.model].posterior.data_vars.variables)
        self.param["data_variable"].objects = data_variables
        if self.data_variable not in data_variables:
            # Changing the value triggers _update_coordinates through its watcher.
            self.data_variable = data_variables[0]
        else:
            # Same variable name in the new model: its coordinates may still differ.
            self._update_coordinates()

    @param.depends("data_variable", watch=True)
    def _update_coordinates(self):
        coor_variables = self._coordinate_options()
        self.param["coor_variable"].objects = coor_variables
        if self.coor_variable not in coor_variables:
            self.coor_variable = coor_variables[0]


c = ModelVar()
pn.Row(c)

In [None]:
class ModelVar(param.Parameterized):
    """Linked selectors: model -> posterior data variable -> coordinate label.

    Choosing a model refreshes the available data variables; choosing a data
    variable (or switching model while keeping the same variable name)
    refreshes the coordinate labels.
    """

    # Models to choose from, captured from the notebook's `dict_cmp` at class
    # creation. Without this class attribute the lines below raise NameError
    # on a fresh kernel, and `self.idatas_cmp` is undefined for plain instances.
    idatas_cmp = dict_cmp
    default_model = list(idatas_cmp.keys())[0]
    model = param.Selector(list(idatas_cmp.keys()), default=default_model)
    data_variable = param.Selector(
        list(idatas_cmp[default_model].posterior.data_vars.variables)
    )
    # A single "" placeholder means "no coordinate" (scalar variables).
    coor_variable = param.Selector(objects=[""], default="")

    def _coordinate_options(self):
        """Return coordinate labels for the selected data variable, or [""] if none."""
        posterior = self.idatas_cmp[self.model].posterior
        variable = posterior.data_vars.variables[self.data_variable]
        # [0][0] fixes the first two axes (chain and draw in ArviZ posteriors);
        # size > 1 therefore means there is at least one extra dimension.
        if variable[0][0].size <= 1:
            return [""]
        # Use the variable's own dimension instead of a hard-coded "school".
        coord_dim = variable.dims[2]
        index = posterior.indexes.get(coord_dim)
        if index is None:
            # Un-indexed dimension: fall back to positional labels.
            return list(range(variable.shape[2]))
        return list(index)

    @param.depends("model", watch=True)
    def _update_data_variables(self):
        data_variables = list(self.idatas_cmp[self.model].posterior.data_vars.variables)
        self.param["data_variable"].objects = data_variables
        if self.data_variable not in data_variables:
            # Changing the value triggers _update_coordinates through its watcher.
            self.data_variable = data_variables[0]
        else:
            # Same variable name in the new model: its coordinates may still differ.
            self._update_coordinates()

    @param.depends("data_variable", watch=True)
    def _update_coordinates(self):
        coor_variables = self._coordinate_options()
        self.param["coor_variable"].objects = coor_variables
        if self.coor_variable not in coor_variables:
            self.coor_variable = coor_variables[0]


class ForestDashboard(ModelVar):
    """Panel dashboard comparing forest and ridge plots across several models.

    Parameters
    ----------
    idatas_cmp : dict
        Mapping of model name to InferenceData (like ``dict_cmp`` in this
        notebook). The keys populate the model selector and label the plots.
    """

    def __init__(self, idatas_cmp) -> None:
        # Parameterized subclasses must run the base initializer; otherwise
        # param's per-instance state for the inherited parameters is never set up.
        super().__init__()
        self.idatas_cmp = idatas_cmp

    def _selected_data(self, model_names):
        """Return the InferenceData objects for ``model_names``, in the same order."""
        return [self.idatas_cmp[name] for name in model_names]

    @staticmethod
    def _ridge_quantiles(bounds):
        """Turn a (low, high) slider range into sorted ridgeplot quantiles.

        The middle quantile is the median (0.5) when it lies strictly inside the
        range, otherwise the midpoint of the range. Duplicates (e.g. when
        low == high) are collapsed.
        """
        low, high = sorted(bounds)
        middle = 0.5 if low < 0.5 < high else (low + high) / 2
        return sorted({low, middle, high})

    def dashboard_forest(self):
        """Build and display the dashboard.

        A model multi-select and an HDI-probability slider are shared by two
        tabs: a Bokeh forest plot, and a Bokeh ridge plot with its own
        truncate / quantile-range / overlap controls.
        """
        model_names = list(self.idatas_cmp.keys())

        # define the widgets
        multi_select = pn.widgets.MultiSelect(
            name="ModelSelect",
            options=model_names,
            # Preselect the first model instead of assuming an "mA" key exists.
            value=model_names[:1],
        )
        # ArviZ rejects hdi_prob <= 0, so the slider starts one step above 0.
        thre_slider = pn.widgets.FloatSlider(
            name="HDI Probability", start=0.05, end=1, step=0.05, value=0.7, width=200
        )
        truncate_checkbox = pn.widgets.Checkbox(name="Ridgeplot Truncate")
        ridge_quant = pn.widgets.RangeSlider(
            name="Ridgeplot Quantiles",
            start=0,
            end=1,
            value=(0.25, 0.75),
            step=0.01,
            width=200,
        )
        op_slider = pn.widgets.FloatSlider(
            name="Ridgeplot Overlap", start=0, end=1, step=0.05, value=0.7, width=200
        )

        @pn.depends(multi_select.param.value, thre_slider.param.value)
        def get_forest_plot(selected_models, hdi_prob):
            """Return the Bokeh forest-plot figure for the selected models."""
            if not selected_models:
                return pn.pane.Markdown("Select at least one model.")
            forest_plt = az.plot_forest(
                self._selected_data(selected_models),
                model_names=selected_models,
                kind="forestplot",
                hdi_prob=hdi_prob,
                backend="bokeh",
                figsize=(9, 9),
                show=False,
                combined=True,
                colors="cycle",
            )
            return forest_plt[0][0]

        @pn.depends(
            multi_select.param.value,
            thre_slider.param.value,
            truncate_checkbox.param.value,
            ridge_quant.param.value,
            op_slider.param.value,
        )
        def get_ridge_plot(selected_models, hdi_prob, truncate, quantile_range, overlap):
            """Return the Bokeh ridge-plot figure for the selected models."""
            if not selected_models:
                return pn.pane.Markdown("Select at least one model.")
            ridge_plt = az.plot_forest(
                self._selected_data(selected_models),
                model_names=selected_models,
                kind="ridgeplot",
                hdi_prob=hdi_prob,
                ridgeplot_truncate=truncate,
                ridgeplot_quantiles=self._ridge_quantiles(quantile_range),
                ridgeplot_overlap=overlap,
                backend="bokeh",
                figsize=(9, 9),
                show=False,
                combined=True,
                colors="white",
            )
            return ridge_plt[0][0]

        plot_result_1 = pn.Row(get_forest_plot)
        plot_result_2 = pn.Column(
            pn.Row(truncate_checkbox),
            pn.Row(ridge_quant, op_slider),
            get_ridge_plot,
        )

        # show up
        display(
            pn.Column(
                pn.Row(multi_select),
                thre_slider,
                pn.Tabs(
                    ("Forest_Plot", plot_result_1),
                    ("Ridge_Plot", plot_result_2),
                ),
            ).servable(),
        )

In [None]:
# Dashboard over both models in `dict_cmp`.
forest_dashboard = ForestDashboard(dict_cmp)

In [None]:
# Build the widgets and plots and display them.
forest_dashboard.dashboard_forest()