# Fun with trees

Several ideas are worth exploring:
- use JAX to compute gradients of custom loss functions for XGBoost-style use cases
- use JAX to build trees in parallel
- use JAX to build soft relaxations of trees
- speed up inference

Compare accuracy, tree-building speed, inference speed, etc.

In [1]:
import jax
import jax.numpy as jnp
import jax.random as random

from jax.nn import sigmoid

import numpy as np


In [77]:
# Synthetic binary-classification data from a logistic model:
# x ~ standard normal of shape (n, p), y ~ Bernoulli(sigmoid(x @ beta)).

key = random.PRNGKey(20221211)

n = 10_000  # Sample size
p = 10  # Number of predictors (features)

# Every draw uses a fresh subkey; keep the split order fixed so the samples
# stay reproducible.
key, skey = random.split(key)
x = random.normal(skey, (n, p))

key, skey = random.split(key)
beta = random.normal(skey, (p,))

key, skey = random.split(key)
z = sigmoid(x @ beta)  # per-row success probability
y = random.bernoulli(skey, p=z).astype('int32')

In [78]:
# Sketch of a single tree path in JAX: each of the `depth` levels tests one
# (randomly chosen) feature against an offset, and a row is selected (1)
# only when it passes the test at every level.

depth = 5

key, skey = random.split(key)
features = random.categorical(skey, logits=np.ones(p) / p, shape=(depth,))  # feature index per level

key, skey = random.split(key)
sign = 2 * (random.bernoulli(skey, shape=(depth, )) - .5)  # +1 or -1 per level

key, skey = random.split(key)
offset = random.normal(skey, shape=(depth, )) / 10  # threshold per level

# Column i is one-hot at row features[i]; scatter all `depth` ones at once
# (column indices are distinct, so no two writes collide).
selector = jnp.zeros((p, depth)).at[features, jnp.arange(depth)].set(1.)

# Column i holds sign[i] * x[:, features[i]] for every row.
signed_values = x @ (selector * sign)
prediction = (signed_values < offset).all(axis=1).astype('int32')

In [79]:
# Per-row output of the sketch: 1 where a row satisfies all `depth` tests, else 0.
prediction

DeviceArray([0, 0, 0, ..., 0, 0, 0], dtype=int32)

In [80]:
# Fraction of rows selected by the sketched path.
prediction.mean()

DeviceArray(0.0324, dtype=float32)

In [None]:
# TODO: organize the splits in layers (a full tree) instead of the single path sketched above



In [2]:
# JIT tests: a trivially small jitted function, timed below for two input sizes.

@jax.jit
def f(x):
    """Return the elementwise square of ``x``."""
    squared = x * x
    return squared


f(jnp.zeros(10))



DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [4]:
%%timeit
# Per-call time of the tiny jitted f, including allocating the input with jnp.zeros.
f(jnp.zeros(10)).block_until_ready()

90.7 µs ± 644 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [5]:
%%timeit
# Same measurement with a 10x larger input, to see how much of the cost depends on size.
f(jnp.zeros(100)).block_until_ready()

91.6 µs ± 261 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [69]:
from typing import NamedTuple, Any


class DecisionTree(NamedTuple):
    """A binary decision-tree node.

    Internal nodes set ``feature_id``, ``sign``, ``offset``, ``left`` and
    ``right``; leaves set only ``prediction``. A sample goes to ``left``
    when ``sign * x[feature_id] <= offset`` and to ``right`` otherwise.
    """
    feature_id: Any = None  # index of the feature tested at this node
    sign: Any = None  # multiplier applied to the feature before comparing
    offset: Any = None  # threshold compared against sign * feature
    left: Any = None  # subtree taken when the test holds
    right: Any = None  # subtree taken when the test fails
    prediction: Any = None  # leaf value; non-None marks this node as a leaf


def apply_tree(x, tree):
    """Evaluate ``tree`` on a single feature vector ``x``.

    Both children are always evaluated and ``jnp.where`` picks between them,
    so there is no Python branching on the data. That makes the function
    traceable by ``jax.jit`` when ``tree`` is passed as a static argument.

    Returns the selected leaf's ``prediction`` (a scalar array for internal
    nodes, the raw leaf value when ``tree`` itself is a leaf).
    """
    if tree.prediction is None:
        goes_left = tree.sign * x[tree.feature_id] <= tree.offset
        return jnp.where(
            goes_left,
            apply_tree(x, tree.left),
            apply_tree(x, tree.right),
        )
    return tree.prediction


In [70]:
# A depth-one tree (stump): goes left (predict 0) when 1. * x[0] <= 0., else right (predict 1).
tree = DecisionTree(
    feature_id=0,
    sign=1.,
    offset=0.,
    left=DecisionTree(prediction=0),
    right=DecisionTree(prediction=1),
)

In [71]:
# Argument 1 (`tree`) is static: the recursion over nodes is unrolled at trace
# time, and the DecisionTree NamedTuple must therefore be hashable.
jit_apply_tree = jax.jit(apply_tree, static_argnums=1)


In [72]:
# Single test sample for the stump. NOTE: this rebinds `x`, overwriting the
# (n, p) sample matrix from the data-generation cell; re-run that cell before
# reusing the data.
x = jnp.array([19.0])

In [73]:
%%timeit
# Eager (un-jitted) evaluation of the recursive apply_tree on the stump.
# NOTE(review): no .block_until_ready() here, so this may time dispatch rather than completion.
apply_tree(x, tree)

443 µs ± 1.29 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [74]:
%%timeit
# Same evaluation through jax.jit with `tree` as a static argument.
jit_apply_tree(x, tree).block_until_ready()

2.29 µs ± 7.88 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [51]:
# Let's fit a tree on a known dataset (iris), then modify it (e.g. reimplement its prediction in JAX)


In [75]:
# Import the estimator class directly: `from sklearn import tree` would rebind
# the name `tree`, which earlier cells use for the hand-built DecisionTree.
from sklearn.tree import DecisionTreeClassifier
from sklearn.datasets import load_iris

iris = load_iris()
X, y = iris.data, iris.target
# random_state pins sklearn's internal feature permutation, so the fitted tree
# (and its node arrays) is reproducible across runs.
clf = DecisionTreeClassifier(random_state=0)
clf = clf.fit(X, y)

In [104]:
%%timeit
# sklearn's prediction time on the 150-row iris data (baseline).
clf.predict(X)

30.5 µs ± 166 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [76]:
# Accuracy on the same data the tree was fit on (training accuracy, not a held-out estimate).
clf.score(X, y)

1.0

In [112]:
# Child indices per node (entry j = left/right child of node j); -1 marks a leaf.
clf.tree_.children_left, clf.tree_.children_right

(array([ 1, -1,  3,  4,  5, -1, -1,  8, -1, 10, -1, -1, 13, 14, -1, -1, -1],
       dtype=int64),
 array([ 2, -1, 12,  7,  6, -1, -1,  9, -1, 11, -1, -1, 16, 15, -1, -1, -1],
       dtype=int64))

In [138]:
# Apply tree for a sklearn tree: walk a fitted sklearn tree (e.g. `clf.tree_`)
# from the root to a leaf for a single sample.

def apply_tree(x, tree):
    """Return the predicted class index for one sample ``x``.

    ``tree`` must expose the node arrays ``children_left``, ``children_right``,
    ``feature``, ``threshold`` and ``value`` (as sklearn's ``clf.tree_`` does).
    A node whose two children are both -1 is a leaf. At each internal node the
    sample goes left when ``x[feature] <= threshold``.

    Returns ``tree.value[leaf].argmax(axis=1)``, i.e. one argmax per row of
    that leaf's value table. NOTE: this is an index into ``clf.classes_``; it
    equals the class label only when labels are 0..k-1 (as for iris here).
    """
    node = 0
    while tree.children_left[node] != -1 and tree.children_right[node] != -1:
        goes_left = x[tree.feature[node]] <= tree.threshold[node]
        node = tree.children_left[node] if goes_left else tree.children_right[node]
    return tree.value[node].argmax(axis=1)

In [142]:
# Pass the fitted sklearn tree explicitly: `tree` is not bound to `clf.tree_`
# anywhere in this notebook, so relying on it only works with leftover kernel state.
apply_tree(X[-1], clf.tree_)

array([2])

In [141]:
# sklearn's own predictions for every row of X, for comparison with apply_tree.
clf.predict(X)

array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
       2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2])

In [164]:
# Vectorized apply tree for a sklearn tree: branch-free recursion in which both
# subtrees are always evaluated and jnp.where selects the result, so it can be
# traced by jax.jit / jax.vmap (the tree arrays themselves stay concrete).
from functools import partial

def apply_tree(x, tree, i=0):
    """Evaluate the subtree rooted at node ``i`` on one sample ``x``.

    ``tree`` must expose ``children_left``, ``children_right``, ``feature``,
    ``threshold`` and ``value`` (as sklearn's ``clf.tree_`` does); a node whose
    two children are both -1 is a leaf. At a leaf, returns
    ``tree.value[i].argmax(axis=1)``; otherwise returns ``jnp.where`` over
    ``x[feature] <= threshold`` choosing the left or right subtree's result.
    """
    left_child = tree.children_left[i]
    right_child = tree.children_right[i]
    if left_child == -1 and right_child == -1:
        return tree.value[i].argmax(axis=1)

    goes_left = x[tree.feature[i]] <= tree.threshold[i]
    left_result = apply_tree(x, tree, i=left_child)
    right_result = apply_tree(x, tree, i=right_child)
    return jnp.where(goes_left, left_result, right_result)


# Batch over rows with vmap (one sample per entry of the leading axis) and
# compile with jit. The fitted sklearn tree is passed explicitly as
# `clf.tree_` (`tree` is not bound to it in this notebook) and bound via
# partial, so it is a trace-time constant rather than a mapped argument.
jit_apply_tree = jax.jit(jax.vmap(partial(apply_tree, tree=clf.tree_)))

In [161]:
%%timeit
apply_tree(X[-1], tree)

107 µs ± 4.78 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [162]:
%%timeit
jit_apply_tree(X[-1])

3.92 µs ± 350 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [173]:
%%timeit
# Batched jitted prediction for every row of X.
# NOTE(review): X is a NumPy array, so each call presumably includes a host-to-device copy.
jit_apply_tree(X).block_until_ready()

5.15 µs ± 167 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [174]:
%%timeit
# sklearn baseline on the same X, for comparison with the jitted cell above.
clf.predict(X)

44 µs ± 1.55 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [189]:
# Larger benchmark input: the rows of X repeated 10,000 times along axis 0.
X2 = np.tile(X, (10_000, 1))

In [190]:
# Expect (10_000 * len(X), n_features).
X2.shape

(1500000, 4)

In [191]:
%%timeit
# Jitted prediction on the enlarged X2.
# NOTE(review): the first call at this new input shape triggers a recompilation that lands
# inside the timed loops; a warm-up call before timing would exclude it.
jit_apply_tree(X2).block_until_ready();

5.99 ms ± 146 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [192]:
%%timeit
clf.predict(X);

44.8 µs ± 2.92 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
