In [1]:
import numpy as np
from scipy.stats import norm, entropy, uniform, bernoulli
from scipy.special import expit
import pandas as pd
from functionTree import rpTree, honestTree, kdTree, classifcationTree
from baseline_models import forward_selection
from multiprocess import Pool
from test_tree_variable_selection import variable_sel

# Shared NumPy generator (seed 0); rvs() below passes it as random_state.
RNG = np.random.default_rng(0)

# var2 is not referenced by any of the visible cells.
var2 = 0.5
# scipy parameterises uniform as (loc, scale), so the support here is
# [loc, loc + scale] = [-8, 0]. Alternative kept from the original: norm(0, 4).pdf
marginal_x1_pdf = uniform(loc=-8, scale=8).pdf

def cond_mean_x2(x1):
    """Return ``x1 + 2*sin(10*x1 / (2*pi))``.

    Works for scalars and, through NumPy broadcasting, for arrays.
    """
    angle = 10*x1/(2*np.pi)
    oscillation = 2*np.sin(angle)
    return x1 + oscillation

# generate data
def rvs(n, irr=100, random_state=None):
    """Simulate a binary-response data set with 3 relevant and ``irr`` irrelevant predictors.

    The relevant predictors X1, X2, X3 are independent N(0, 1) draws and the
    response is sampled as Bernoulli(expit(X1 + X2 + X3)). The irrelevant
    predictors are independent draws from ``uniform(-2, 2)``.

    Every draw goes through a single generator so the whole data set is
    reproducible. Previously only ``y`` used the seeded ``RNG``; the predictors
    came from NumPy's unseeded global state.

    Parameters
    ----------
    n : int
        Number of samples (rows).
    irr : int, default 100
        Number of irrelevant predictor columns.
    random_state : None, int or numpy.random.Generator, default None
        Source of randomness. ``None`` uses the module-level ``RNG``; anything
        else is normalised with ``np.random.default_rng``.

    Returns
    -------
    df : pandas.DataFrame
        Shape ``(n, irr + 3)``. Columns are the irrelevant predictors
        ``X4 .. X{irr+3}`` followed by ``X1, X2, X3``.
    y : numpy.ndarray
        Length-``n`` array of 0/1 responses.
    """
    # Normalise once: passing a raw int to each .rvs() call would reseed every
    # draw identically and make x1, x2 and x3 equal.
    rng = RNG if random_state is None else np.random.default_rng(random_state)

    std_normal = norm(0, 1)
    x1 = std_normal.rvs(size=n, random_state=rng)
    x2 = std_normal.rvs(size=n, random_state=rng)
    x3 = std_normal.rvs(size=n, random_state=rng)

    y = bernoulli.rvs(expit(x1 + x2 + x3), random_state=rng)

    # NOTE(review): scipy's uniform(loc, scale) has support [loc, loc + scale],
    # so these columns lie in [-2, 0], not [-2, 2] -- confirm that is intended.
    noise = uniform(-2, 2)
    irr_lst = [noise.rvs(n, random_state=rng) for _ in range(irr)]

    # Column order matches the original: irrelevant columns first, then X1..X3.
    irr_columns = ['X' + str(i) for i in range(4, irr + 4)]
    rr_columns = ['X1', 'X2', 'X3']
    df = pd.DataFrame(np.column_stack(irr_lst + [x1, x2, x3]),
                      columns=irr_columns + rr_columns)
    return df, y

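$Y \sim \mathrm{Bernoulli}\big(\sigma(X_1 + X_2 + X_3)\big)$, where $X_1, X_2, X_3 \overset{\text{iid}}{\sim} N(0, 1)$, as implemented in `rvs` above. (An earlier design of this experiment used $Y = \sigma(-X_1 + X_2 + X_3)$, where $X_1 \sim U(-1, 1)$, $X_2 \sim N(0, 1)$, $X_3 = X_1X_2$.)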

In [2]:
predictor, target = rvs(50,100) # sample size=50, irrelevant variables=100, relevant variables=3 (X1, X2, X3)

In [7]:
from multiprocess import Pool
from test_tree_variable_selection import variable_sel

In [8]:
from copy import deepcopy

# One instance per tree flavour; each is deep-copied before use below.
model_dic = {'rp': rpTree(),
             'kd': kdTree(),
             'classification': classifcationTree(),
             'honest': honestTree()}

# Run the worker pool as a context manager so its processes are released once
# every model has been evaluated (the original cell never closed the pool).
# `pool` stays bound afterwards but is terminated; create a new Pool to reuse it.
with Pool() as pool:
    for model, prototype in model_dic.items():
        # Presumably copied so that fitting does not mutate the instance in model_dic.
        tree = deepcopy(prototype)
        best_subset, best_aic = variable_sel(predictor, target, pool, tree, tree_rep=1, best_subset=[])
        print(model, best_aic, best_subset)

rp 65.93271959882478 ['X56', 'X47', 'X27']
kd 74.8988519906905 ['X10', 'X78', 'X57']
classification 19.51219512195122 ['X51', 'X1', 'X21', 'X49', 'X10', 'X42', 'X65']
honest 7.142857142857143 ['X48', 'X1', 'X27']


In [3]:
predictor, target = rvs(20,2) # sample size=20, irrelevant variables=2, relevant variables=3 (X1, X2, X3)
# Keep only two of the relevant predictors; this overwrites the full frame.
predictor=predictor[['X3', 'X1']]
predictor

Unnamed: 0,X3,X1
0,-0.303522,-0.41668
1,-0.119997,-0.959782
2,0.164086,-0.288513
3,0.08837,-0.430974
4,0.113221,-0.174043
5,-0.22332,-0.46784
6,-0.045943,-0.186756
7,0.001177,-0.00299
8,0.506169,-0.649445
9,-0.176859,-0.828979


In [None]:
[[('rp', array([1.]), -0.9391980171054539, '>'), ('rp', array([-1.]), 1.0133026640524179, '<')], [('rp', array([1.]), -0.9391980171054539, '>'), ('rp', array([-1.]), 1.0133026640524179, '>')], [('rp', array([1.]), -0.9391980171054539, '<')]] 

[[(-0.6330815498654132,), (-1.0084711815850993,), (-0.8916807470437896,), (-0.9613225392100532,), (0.09141215988698681,), (-0.484781797858888,), (-0.646139820022364,), (-0.7182985075217854,), (-0.9580060289809842,), (-0.7266495997636335,), (0.04824067209320693,), (-0.34413025405923825,), (-0.8432861638592476,), (-0.9081677355838189,), (-0.9523869220462747,), (-0.26876989504621007,), (-0.5665955297319988,), (-0.24732980445068597,), (-0.9261476126236919,), (-0.9746193576764148,), (-0.13926977893204018,), (-0.7081124697370448,), (-0.4821883135844167,), (-0.7901206746796751,), (-0.7063693436816536,), (-0.20385967518602544,), (-0.7296955274136548,), (-0.9820631049983647,), (-0.8875448797108171,), (-0.8555651738703535,), (-0.5439881304660874,), (-0.6723296757779247,), (-0.9438156908812536,), (-0.7499764799890496,), (-0.7824724394524504,), (-0.9137093152366838,), (-0.2518829657274605,), (-0.060051786541031604,), (-0.9322286496165807,), (-0.20668517541554976,), (-0.9944731107681268,), (-0.2013532012383441,), (-0.8989926272614222,), (-0.6106888548240046,), (-0.5210614068833006,), (-0.6026600305543002,), (-0.69503591674329,), (-0.6773966189203249,), (-0.4178221354496857,), (-0.442975150582711,)], 
[(-1.3192589041407303,), (-1.4115055264191427,), (-2.0260059083230084,), (-2.85356464284673,), (-1.348832352685009,), (-1.5832324982583157,), (-1.7012111356715005,), (-1.451027522628173,), (-1.295575252243692,), (-2.3621424879996757,), (-1.7229399498358464,), (-2.090363620754941,), (-1.51074669976758,), (-1.43645193395446,), (-1.260288764222354,), (-1.2389499537203512,), (-1.050578503851126,), (-1.5137694562804134,), (-1.1681861817070147,), (-1.0349697607799377,), (-1.1857879632367418,), (-1.344555057145801,), (-1.2997424834608087,), (-1.3976902370724134,), (-1.0181341465197364,), (-1.067145705041755,), (-1.4137455987808603,), (-1.0840741223767045,), (-1.5330965384618613,), (-2.116436179967007,), (-1.1010820051074708,), (-1.3568516455075865,), (-1.663741916981891,), (-1.5780285323575758,), (-1.245401374494392,), (-2.6060457419564536,), (-1.0912681419738333,), (-1.5561912148708705,), (-1.061621490747472,), (-1.488862541539231,), (-1.1636079327064066,), (-1.1287726678367913,), (-1.2002333195529677,), (-1.4256815934863796,), (-1.3910706266450479,), (-1.019031367329893,), (-1.1295583479664844,), (-1.7498140631521122,), (-1.7198867610489952,), (-1.6936182615532573,)]]

In [6]:
# Print the column labels of the predictor frame, one per line.
for column_name in predictor.columns:
    print(column_name)

X3
X1


In [4]:
from functionTree import RPTree, HonestTree, KDTree, classifcationTree

In [6]:
# Display the 0/1 response vector returned by rvs.
target

array([1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0,
       1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 1,
       0, 0, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0,
       0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0,
       1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0,
       0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 0, 1, 1,
       0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1,
       0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0,
       1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0,
       0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1,
       0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1,
       1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1,
       0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 0, 0,
       1, 1, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1,

In [4]:
# NOTE(review): FunctionTree is not imported in any visible cell -- confirm it is
# available after Restart Kernel -> Run All.
tree = FunctionTree(option='rp', AIC=True) # use RP tree, AIC criteria. If AIC=False, use BIC
# fit returns an (aic, fmi, num_leaves, rules) tuple plus a tree object, which
# is rebound to `tree` here.
(aic, fmi, num_leaves, rules),tree = tree.fit(predictor, target)  # fmi is not fixed.
# Show the fit summary (fmi is left out of the display).
aic, num_leaves, rules

71


(1867.580937708543,
 20,
 [[('rp', array([ 0.97337974, -0.22919835]), 0.0669898415880815, '<'),
   ('rp', array([0.79240723, 0.60999244]), -0.44166231553350155, '>'),
   ('rp', array([-0.67966399, -0.7335236 ]), 0.1701419818058276, '<'),
   ('rp', array([-0.98328798, -0.18205697]), 0.028939153500693794, '>'),
   ('rp', array([-0.47332312,  0.88088888]), -0.07952202130441989, '<')],
  [('rp', array([ 0.97337974, -0.22919835]), 0.0669898415880815, '<'),
   ('rp', array([0.79240723, 0.60999244]), -0.44166231553350155, '>'),
   ('rp', array([-0.67966399, -0.7335236 ]), 0.1701419818058276, '<'),
   ('rp', array([-0.98328798, -0.18205697]), 0.028939153500693794, '>'),
   ('rp', array([-0.47332312,  0.88088888]), -0.07952202130441989, '>')],
  [('rp', array([ 0.97337974, -0.22919835]), 0.0669898415880815, '<'),
   ('rp', array([0.79240723, 0.60999244]), -0.44166231553350155, '>'),
   ('rp', array([-0.67966399, -0.7335236 ]), 0.1701419818058276, '<'),
   ('rp', array([-0.98328798, -0.18205697]

In [5]:
partition, loss, risk, ids_lst = tree.predict(predictor, target)  # partition, logloss, least square loss, ids_lst (looks like one leaf index per sample -- TODO confirm)

In [7]:
# Second value returned by tree.predict (the log loss, per the comment on that call).
loss

0.9137958617718838

In [8]:
# Third value returned by tree.predict (the least-squares loss, per the comment on that call).
risk

0.3510000000000001

In [9]:
# NOTE(review): this dumps the whole dict_values object into the output;
# consider displaying only a slice, e.g. list(ids_lst)[:10].
ids_lst


dict_values([16, 18, 4, 10, 8, 18, 5, 2, 12, 15, 11, 9, 11, 7, 16, 4, 17, 17, 8, 10, 17, 8, 13, 16, 17, 13, 6, 14, 12, 3, 13, 11, 13, 17, 0, 4, 11, 13, 18, 15, 9, 9, 6, 6, 17, 17, 7, 17, 2, 1, 18, 12, 6, 13, 14, 9, 1, 14, 17, 5, 10, 11, 18, 15, 17, 10, 14, 2, 16, 16, 16, 19, 8, 10, 4, 13, 14, 18, 17, 17, 14, 2, 15, 5, 8, 19, 16, 5, 6, 8, 9, 9, 1, 5, 18, 16, 18, 11, 17, 13, 2, 13, 15, 19, 2, 13, 3, 14, 10, 1, 9, 2, 10, 2, 8, 4, 2, 16, 16, 11, 14, 5, 17, 10, 19, 9, 9, 15, 7, 19, 7, 7, 7, 4, 19, 12, 11, 14, 9, 9, 17, 1, 9, 6, 3, 19, 19, 19, 14, 4, 5, 9, 17, 10, 10, 2, 11, 15, 3, 1, 17, 13, 7, 0, 6, 17, 16, 6, 8, 14, 13, 6, 18, 16, 4, 8, 18, 14, 16, 19, 7, 10, 19, 12, 7, 10, 14, 19, 16, 12, 16, 10, 12, 10, 19, 6, 2, 19, 0, 13, 8, 3, 13, 15, 12, 13, 2, 0, 2, 10, 7, 5, 7, 14, 3, 19, 4, 19, 19, 14, 18, 15, 17, 11, 8, 10, 16, 9, 12, 2, 9, 8, 19, 16, 13, 19, 11, 13, 4, 5, 19, 12, 7, 9, 13, 14, 4, 10, 3, 14, 14, 11, 11, 16, 13, 1, 17, 9, 11, 10, 1, 15, 13, 17, 13, 10, 19, 5, 18, 19, 19, 16, 14, 

In [1]:
from functionTree import HonestTree

# NOTE(review): per the saved output this call raises
# "AttributeError: 'list' object has no attribute 'values'". fit apparently
# needs an input exposing .values (e.g. a pandas DataFrame), not a plain list.
# Fix the input or remove this debugging cell so Run All completes.
HonestTree().fit([],[])

classification


AttributeError: 'list' object has no attribute 'values'