In [None]:
%matplotlib inline
import pytest
import ipytest
ipytest.autoconfig()

import numpy as np
import arviz as az
import pymc as pm
import pandas as pd
import bambi as bmb

from preliz.distributions import Normal, Gamma
from preliz import predictive_explorer

In [None]:
%%ipytest

# Shared evaluation grid: 100 evenly spaced points on [0, 1], read by the
# PreliZ model fixture and by the ``lin_reg`` plotting helper in this cell.
x = np.linspace(0.0, 1.0, num=100)

@pytest.fixture
def model():
    """Provide a PreliZ-only generative model for ``predictive_explorer``.

    The returned function draws a log-slope from Normal(a_mu, a_sigma) and a
    noise scale from Gamma(mu=2, sigma=c_sigma). It then returns one draw of
    Normal(exp(log_slope) * x, noise) over the module-level grid ``x``.

    NOTE(review): presumably ``predictive_explorer`` inspects the parameter
    names and defaults (``a_mu``, ``a_sigma``, ``c_sigma=1``) to build its
    controls, so keep them stable.
    """

    def a_preliz_model(a_mu, a_sigma, c_sigma=1):
        log_slope = Normal(a_mu, a_sigma).rvs()
        noise = Gamma(mu=2, sigma=c_sigma).rvs()
        return Normal(np.exp(log_slope) * x, noise).rvs()

    return a_preliz_model

@pytest.mark.parametrize(
    ("n_samples", "kind"),
    [(50, "hist"), (10, "kde"), (10, "ecdf")],
)
def test_predictive_explorer(model, n_samples, kind):
    """Smoke-test ``predictive_explorer`` on the PreliZ model for each plot kind.

    Nothing is asserted: the test passes as long as the call does not raise.
    """
    predictive_explorer(model, n_samples, kind)

def lin_reg(predictions, ax):
    """Custom ``plot_func``: draw each row of ``predictions`` as black dots over ``x``.

    ``ax.plot`` plots each column of a 2-D y against ``x``, so the array is
    transposed to turn rows into columns.
    NOTE(review): assumes ``predictions`` is (n_draws, len(x)) — TODO confirm.
    """
    draws_as_columns = predictions.T
    ax.plot(x, draws_as_columns, "k.")

def test_predictive_explorer_custom_plot(model):
    """Smoke-test ``predictive_explorer`` with the user-supplied ``lin_reg`` plot."""
    n_samples = 50
    predictive_explorer(model, n_samples, plot_func=lin_reg)
    

In [None]:
%%ipytest

@pytest.fixture
def model():
    """Provide a factory that builds a small PyMC model for ``predictive_explorer``.

    The model has ``a ~ Normal(a_mu, 1)`` and ``b ~ HalfNormal(b_sigma)``.
    It also has an observed ``c ~ Normal(a, b)`` over 100 zeros.

    NOTE(review): presumably ``predictive_explorer`` exposes the factory's
    parameter names/defaults (``a_mu``, ``b_sigma=1``) as controls — keep them stable.
    """

    def a_pymc_model(a_mu, b_sigma=1):
        with pm.Model() as pymc_model:
            mean = pm.Normal("a", a_mu, 1)
            spread = pm.HalfNormal("b", b_sigma)
            # Created inside the model context; the returned tensor is not needed.
            pm.Normal("c", mean, spread, observed=[0] * 100)
        return pymc_model

    return a_pymc_model
        
        
@pytest.mark.parametrize(
    ("n_samples", "kind"),
    [(50, "hist"), (10, "kde"), (10, "ecdf")],
)
def test_predictive_explorer(model, n_samples, kind):
    """Smoke-test ``predictive_explorer`` on the PyMC model for each plot kind.

    The engine is selected explicitly with ``engine="pymc"``. The test passes as
    long as the call does not raise.
    """
    predictive_explorer(model, n_samples, kind, engine="pymc")

In [None]:
%%ipytest

# Synthetic regression data for the Bambi model fixture in this cell.
# A seeded generator keeps the data, and so the test, reproducible across
# kernel restarts.
DATA_SEED = 1234
N_OBS = 117

data_rng = np.random.default_rng(DATA_SEED)
data = pd.DataFrame(
    {
        "y": data_rng.normal(size=N_OBS),
        "x": data_rng.normal(size=N_OBS),
    }
)

@pytest.fixture
def model():
    """Provide a factory that builds a Bambi regression ``y ~ x`` on ``data``.

    Only the intercept prior is customised: ``Normal(mu=a_mu, sigma=b_sigma)``.

    NOTE(review): presumably ``predictive_explorer`` exposes the factory's
    parameter names/defaults (``a_mu``, ``b_sigma=1``) as controls — keep them stable.
    """

    def a_bambi_model(a_mu, b_sigma=1):
        intercept_prior = bmb.Prior("Normal", mu=a_mu, sigma=b_sigma)
        return bmb.Model("y ~ x", data, priors={"Intercept": intercept_prior})

    return a_bambi_model
    
@pytest.mark.parametrize(
    ("n_samples", "kind"),
    [(50, "hist"), (10, "kde"), (10, "ecdf")],
)
def test_predictive_explorer(model, n_samples, kind):
    """Smoke-test ``predictive_explorer`` on the Bambi model for each plot kind.

    The test passes as long as the call does not raise.
    """
    # NOTE(review): unlike the PyMC test, no ``engine`` is passed here, so this
    # relies on predictive_explorer's default engine handling Bambi models — confirm.
    predictive_explorer(model, n_samples, kind)