In [1]:
import numpy as np
from scipy.stats import norm, entropy, uniform, bernoulli
from scipy.special import expit
import pandas as pd
from functionTree import *

RNG = np.random.default_rng(seed = 0)

var2 = 0.5
marginal_x1_pdf = uniform(-8, 8).pdf # norm(0, 4).pdf  

def cond_mean_x2(x1):
    return x1+2*np.sin(10*x1/(2*np.pi))

# generate data
def rvs(n, irr=10):
    x1 = uniform(-1, 1).rvs(n, RNG)
    x2 = norm.rvs(loc=cond_mean_x2(np.array([0 for _ in x1])), scale=var2**0.5, random_state=RNG)
    x3 = x1 * x2 #uniform(-8, 8).rvs(n, RNG) #norm.rvs(loc=cond_mean_x2(x1+2*x2-1), scale=var2**0.5, random_state=RNG)
    y = bernoulli.rvs(expit(-x1+x2+x3), random_state=RNG)
    
    irr_lst = [norm(0, 4).rvs(n) for _ in range(irr)]
    for each in [x1, x2, x3]:
        irr_lst.append(each)
    irr_columns = ['X' + str(i) for i in range(4, irr+4)]
    rr_columns = ['X1', 'X2', 'X3']
    cols = irr_columns + rr_columns
    df = pd.DataFrame(np.column_stack(irr_lst))
    df.columns = cols
    return df, y     

$Y = \sigma(-X_1 + X_2 + X_3)$, where $X_1 \sim U(-1, 1)$, $X_2 \sim N(0, 1)$, $X_3 = X_1X_2$

In [2]:
predictor, target = rvs(100,3) # sample size= 10, irrelevant variable =0, default relevant variable=3
predictor

Unnamed: 0,X4,X5,X6,X1,X2,X3
0,6.518587,4.052888,2.328284,-0.363038,-0.948386,0.344300
1,-6.158268,1.316741,-0.076888,-0.730213,-0.991024,0.723659
2,-5.638200,-0.062576,0.972643,-0.959026,0.355450,-0.340886
3,-5.200589,-2.030027,2.316268,-0.983472,0.699833,-0.688266
4,2.587911,-0.641523,1.714403,-0.186730,-0.116174,0.021693
...,...,...,...,...,...,...
95,3.115257,-7.031121,-7.433308,-0.042790,-0.035782,0.001531
96,-1.166182,-4.460834,0.295547,-0.851236,-0.200200,0.170417
97,4.626912,6.250948,-0.365407,-0.027371,1.161954,-0.031804
98,-4.348491,-4.578561,-6.542406,-0.110064,-0.906970,0.099825


In [3]:
predictor, target = rvs(1000,2) # sample size= 10, irrelevant variable =0, default relevant variable=3
predictor=predictor[['X3', 'X1']]
predictor

Unnamed: 0,X3,X1
0,-0.582395,-0.416680
1,-0.414775,-0.959782
2,-0.165938,-0.288513
3,0.060995,-0.430974
4,0.098203,-0.174043
...,...,...
995,0.333666,-0.426923
996,0.167024,-0.226028
997,-0.140782,-0.147127
998,0.090536,-0.138772


In [4]:
tree = FunctionTree(option='rp', AIC=True) # use RP tree, AIC criteria. If AIC=False, use BIC
(aic, fmi, num_leaves, rules),tree = tree.fit(predictor, target)  # fmi is not fixed.
aic, num_leaves, rules

71


(1867.580937708543,
 20,
 [[('rp', array([ 0.97337974, -0.22919835]), 0.0669898415880815, '<'),
   ('rp', array([0.79240723, 0.60999244]), -0.44166231553350155, '>'),
   ('rp', array([-0.67966399, -0.7335236 ]), 0.1701419818058276, '<'),
   ('rp', array([-0.98328798, -0.18205697]), 0.028939153500693794, '>'),
   ('rp', array([-0.47332312,  0.88088888]), -0.07952202130441989, '<')],
  [('rp', array([ 0.97337974, -0.22919835]), 0.0669898415880815, '<'),
   ('rp', array([0.79240723, 0.60999244]), -0.44166231553350155, '>'),
   ('rp', array([-0.67966399, -0.7335236 ]), 0.1701419818058276, '<'),
   ('rp', array([-0.98328798, -0.18205697]), 0.028939153500693794, '>'),
   ('rp', array([-0.47332312,  0.88088888]), -0.07952202130441989, '>')],
  [('rp', array([ 0.97337974, -0.22919835]), 0.0669898415880815, '<'),
   ('rp', array([0.79240723, 0.60999244]), -0.44166231553350155, '>'),
   ('rp', array([-0.67966399, -0.7335236 ]), 0.1701419818058276, '<'),
   ('rp', array([-0.98328798, -0.18205697]

In [5]:
partition, loss, risk, ids_lst = tree.predict(predictor, target)  # partition, logloss, least square loss

In [7]:
loss

0.9137958617718838

In [8]:
risk

0.3510000000000001

In [9]:
ids_lst


dict_values([16, 18, 4, 10, 8, 18, 5, 2, 12, 15, 11, 9, 11, 7, 16, 4, 17, 17, 8, 10, 17, 8, 13, 16, 17, 13, 6, 14, 12, 3, 13, 11, 13, 17, 0, 4, 11, 13, 18, 15, 9, 9, 6, 6, 17, 17, 7, 17, 2, 1, 18, 12, 6, 13, 14, 9, 1, 14, 17, 5, 10, 11, 18, 15, 17, 10, 14, 2, 16, 16, 16, 19, 8, 10, 4, 13, 14, 18, 17, 17, 14, 2, 15, 5, 8, 19, 16, 5, 6, 8, 9, 9, 1, 5, 18, 16, 18, 11, 17, 13, 2, 13, 15, 19, 2, 13, 3, 14, 10, 1, 9, 2, 10, 2, 8, 4, 2, 16, 16, 11, 14, 5, 17, 10, 19, 9, 9, 15, 7, 19, 7, 7, 7, 4, 19, 12, 11, 14, 9, 9, 17, 1, 9, 6, 3, 19, 19, 19, 14, 4, 5, 9, 17, 10, 10, 2, 11, 15, 3, 1, 17, 13, 7, 0, 6, 17, 16, 6, 8, 14, 13, 6, 18, 16, 4, 8, 18, 14, 16, 19, 7, 10, 19, 12, 7, 10, 14, 19, 16, 12, 16, 10, 12, 10, 19, 6, 2, 19, 0, 13, 8, 3, 13, 15, 12, 13, 2, 0, 2, 10, 7, 5, 7, 14, 3, 19, 4, 19, 19, 14, 18, 15, 17, 11, 8, 10, 16, 9, 12, 2, 9, 8, 19, 16, 13, 19, 11, 13, 4, 5, 19, 12, 7, 9, 13, 14, 4, 10, 3, 14, 14, 11, 11, 16, 13, 1, 17, 9, 11, 10, 1, 15, 13, 17, 13, 10, 19, 5, 18, 19, 19, 16, 14, 