In [None]:
import pymc as pm
import numpy as np
import matplotlib.pyplot as plt
import arviz as az
%matplotlib inline

In [None]:
# Apply ArviZ's 'arviz-white' plotting style.
az.style.use('arviz-white')
# Set the IPython Completer option `use_jedi` to False.
%config Completer.use_jedi = False# Don't use the force?
# Seed NumPy's global RNG; the np.random calls that generate X and Y below draw from it.
np.random.seed(1846)

## Data generation

In [None]:
# Covariate matrix: 100 rows by 1000 columns, every entry drawn
# independently from a uniform distribution on [0, 1).
X = np.random.uniform(
    0,    # lower bound
    1.0,  # upper bound
    size=(100, 1000),
)

In [None]:
# Mean function of the response. Only the first five columns of X enter it
# (this has the form of Friedman's benchmark function); every other column
# of X has no effect on Y.
f_x = (
    10 * np.sin(np.pi * X[:, 0] * X[:, 1])
    + 20 * (X[:, 2] - 0.5) ** 2
    + 10 * X[:, 3]
    + 5 * X[:, 4]
)
# Observed response: the mean function plus Gaussian noise with standard deviation 1.
Y = np.random.normal(loc=f_x, scale=1)

In [None]:
# String labels '10', '20', '100', '1000', plus two empty containers
# (a dict and a list) that are not filled in this cell.
names_x = [str(n) for n in (10, 20, 100, 1000)]
idatas_Xs = dict()
VIs_Xs = list()

In [None]:
# Fit one BART model per ensemble size in `trees`, using only the first ten
# columns of X as covariates. For every fit we keep:
#   * the sampling result, stored in `idatas` under the tree count as a string;
#   * the "variable_inclusion" sample statistic averaged over all chains and
#     draws, rescaled to sum to one, appended to `VIs` in the order of `trees`.
idatas = {}
trees = [10, 20, 50, 100, 200]
VIs = []
for m in trees:
    with pm.Model(rng_seeder=678) as model:
        μ = pm.BART('μ', X[:, :10], Y, m=m)
        σ = pm.HalfNormal('σ', 1)
        y = pm.Normal('y', μ, σ, observed=Y)
        idata = pm.sample(chains=4, random_seed=678)

    # The bookkeeping below does not need the model context.
    idatas[str(m)] = idata

    # Variable importance
    inclusion = idata.sample_stats["variable_inclusion"]
    VI = inclusion.stack(samples=("chain", "draw")).mean("samples").values
    VIs.append(VI / VI.sum())

In [None]:
# Plot the normalized variable importance obtained for each number of trees
# and save the figure to 'var-importance.png'.
#
# `trees` and `VIs` are parallel lists (one importance vector per tree count),
# so they are iterated together instead of indexing by position.
fig, ax = plt.subplots()
for n_trees, importance in zip(trees, VIs):
    ax.plot(importance,
            label='trees = {}'.format(n_trees), marker='o', linestyle='dashed')
# Reference level: the share each covariate would have if all ten covariates
# passed to the model were equally important.
ax.axhline(1 / X[:, :10].shape[1], ls="--", color="k")
# Axis labels so the saved image can be read without the notebook.
ax.set_xlabel('covariate index')
ax.set_ylabel('relative importance')
ax.legend()
fig.savefig('var-importance.png')