In [None]:
%matplotlib inline
import pytest
import ipytest
ipytest.autoconfig()

import numpy as np
import arviz as az
from preliz.distributions import Normal, Gamma
from preliz import predictive_explorer

In [None]:
%%ipytest

# Shared evaluation grid: 100 evenly spaced points on [0, 1] (endpoint
# included). Both the model fixture and the custom plot function read it.
x = np.linspace(start=0, stop=1, num=100)

@pytest.fixture
def model():
    """Return a small generative PreliZ model function for the explorer tests.

    The returned function draws a log-scale slope from ``Normal(a_mu, a_sigma)``
    and a noise scale from ``Gamma(mu=2, sigma=c_sigma)``. It then returns one
    draw of ``Normal(slope * x, noise)`` over the module-level grid ``x``.

    The inner function's name, argument names and defaults are kept as-is.
    NOTE(review): ``predictive_explorer`` presumably inspects them (and
    possibly the distribution calls in the source) to build its controls.
    """

    def a_preliz_model(a_mu, a_sigma, c_sigma=1):
        log_slope = Normal(a_mu, a_sigma).rvs()
        noise = Gamma(mu=2, sigma=c_sigma).rvs()
        slope = np.exp(log_slope)  # exponentiate so the slope is positive
        return Normal(slope * x, noise).rvs()

    return a_preliz_model

@pytest.mark.parametrize("iterations, kind_plot", [
    (50, "hist"),
    (10, "kde"),
    (10, "ecdf"),
])
def test_predictive_explorer(model, iterations, kind_plot):
    """Smoke-test ``predictive_explorer`` for each built-in ``kind_plot``.

    Renders the widget, changes the value of its three control widgets, and
    checks that the first output of the output widget carries a PNG image.
    NOTE(review): the three controls presumably map, in order, to the
    ``a_mu``, ``a_sigma`` and ``c_sigma`` arguments of the ``model`` fixture.
    """
    result = predictive_explorer(model, iterations, kind_plot)
    result._ipython_display_()
    slider0, slider1, slider2, plot_data = result.children
    slider0.value = -4
    slider1.value = 0.3
    # Each child is a single widget; ipywidgets widgets are not subscriptable,
    # so the value is set directly (the former ``slider2[2]`` raised TypeError).
    slider2.value = 0.1
    # Fail with a clear message instead of an IndexError if nothing rendered.
    assert plot_data.outputs, "predictive_explorer produced no output"
    assert "image/png" in plot_data.outputs[0]["data"]

def lin_reg(predictions, ax):
    """Custom ``plot_func``: draw predictive samples as black dots over ``x``.

    ``predictions`` is transposed before plotting. Matplotlib plots each
    column of a 2-D ``y`` as its own series against ``x``.
    NOTE(review): presumably ``predictions`` is (n_draws, len(x)), so each
    column after transposing is one draw. Confirm against preliz.
    """
    draws_as_columns = predictions.T
    ax.plot(x, draws_as_columns, "k.")

@pytest.mark.parametrize("iterations", [10])
def test_predictive_explorer_custom_plot(model, iterations):
    """Smoke-test ``predictive_explorer`` with a user-supplied ``plot_func``.

    ``lin_reg`` is the plain module-level function, not a pytest fixture. It
    is therefore referenced directly instead of being requested as a test
    argument; requesting it, or ``iterations``, as a fixture made pytest
    error with "fixture not found". ``iterations`` is supplied via
    ``parametrize``.
    """
    result = predictive_explorer(model, iterations, plot_func=lin_reg)
    result._ipython_display_()
    slider0, slider1, slider2, plot_data = result.children
    slider0.value = -4
    slider1.value = 0.3
    # Each child is a single widget; ipywidgets widgets are not subscriptable,
    # so the value is set directly (the former ``slider2[2]`` raised TypeError).
    slider2.value = 0.1
    # Fail with a clear message instead of an IndexError if nothing rendered.
    assert plot_data.outputs, "predictive_explorer produced no output"
    assert "image/png" in plot_data.outputs[0]["data"]