This code will run the model on a dataset with 100 trees, then sort the trees by their likelihood on the test set. Then the trees will be added to an ensemble one by one and have their performance tested for accuracy on both test and training sets. this list of accuracies (actually 2 lists) will be reported and plotted.

We are testing the hypothesis that the code eventually finds 1 good tree and that is the main part of the results improving by getting more trees.

If this is true, we expect the tree with highest likelihood on the test set alone to give the best likelihood, and more trees to decrease likelihood. If it is false, we expect a different pattern, for example a logarithmic increase just like if we increase the number of trees.

For this test, we will use a forest of 100 trees, with trees of depth 5.

In [86]:
#libraries and imports here
from experiments.prep import get_data
from sklearn.model_selection import KFold
import numpy as np
from gefs.trees import RandomForest
from time import time
from gefs.nodes import evaluate, eval_root, eval_root_class

In [100]:
#functions here
def good_tree_test(dataset_name, nr_trees, depth, seed = 0):
    """
    Runs a good tree test, to see if there is one very good tree that pulls accuracy up, with specified parameters. 
    Dataset_name: The dataset to use for the experiment
    nr_trees: number of trees in the forest
    depth: depth set as maximum for each tree in the forest
    seed: the seed for random numbers
    """
    
    start = time()
    
    #open dataset & preprep data
    data, n_cat = get_data(dataset_name, "data/")
    print(f"{time()-start} Data preprocessing done after ")
    
    
    #split into train (90%) and test (10%)
    splitter = KFold(10, random_state = seed, shuffle = True)
    splitted = splitter.split(data) #creates a generator that splits stuff
    train_index, test_index = next(splitted) #takes the first item of the generator, the first fold
    
    #training data
    data_train = data[train_index]
    X_train, y_train = data_train[:, :-1], data_train[:, -1] 
    
    #testing data
    data_test = data[test_index]
    correct_class_test = data_test[:, -1].copy()
    #data_test[:, -1] = np.nan
    
    print(len(data_test))
    
    #copy of training data to be used as testing data (needs to be seperate because class is removed)
    data_train_no_class = data[train_index]
    correct_class_train = data_train_no_class[:, -1].copy()
    #data_train_no_class[:, -1] = np.nan
    
    #make the model
    print(f"{time()-start} model creation started")
    #set a seed
    np.random.seed(seed)
    model = RandomForest(n_estimators=nr_trees, ncat=n_cat, max_depth = depth)
    model.fit(X_train, y_train)
    print(f"{time()-start} model created")
    
    #transform it to a pc
    circuit = model.topc(learnspn=np.Inf)
    print(f"{time()-start} model circuited")
    
    #get individual trees out of the model
    #print(circuit.root)
    trees = circuit.root.children#[child for child in circuit.root.children]
    
    #test each tree to get likelyhood on test set:
    ll_test = []
    
    for tree in trees:
        class_var = len(data_test[0])-1
        n_classes = len(np.unique(data_test[:,-1]))
        naive = False
        #print(data_test[:,-1])
        #print(class_var, n_classes)
        ll_test_new = eval_root_class(tree, data_test, class_var, n_classes, naive)#eval_root(tree, data_test)
        
        #print(mul(ll_test_new))
        print(ll_test_new)
        #print(sum(ll_test_new))
        ll_test.append(ll_test_new)

    
    #sort them by likelyhood on test set
    
    #add trees to a new forest one by one and test its performance (accuracy)
    
    #return performance
    return acc_test, acc_train

def mul(iterable):
    """multiplies all values in the iterable and gives the result"""
    result = 1
    for x in iterable:
        result *= x
    return result

In [101]:
#actual executing code here
dataset = "diabetes.csv"
nr_trees = 100
depth = 5
seed = 0

acc_test, acc_train = good_tree_test(dataset, nr_trees, depth, seed)


0.01795816421508789 Data preprocessing done after 
77
0.020946264266967773 model creation started


 40%|████████████████████████████████                                                | 40/100 [00:00<00:00, 391.96it/s]

0.2903289794921875 model created


100%|███████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 333.93it/s]


0.5927846431732178 model circuited
[[[         -inf  -30.89719332]
  [         -inf  -30.20404617]]

 [[         -inf  -39.64727157]
  [         -inf  -40.14370843]]

 [[         -inf  -32.50319361]
  [         -inf  -32.72633711]]

 [[ -27.52860634          -inf]
  [ -24.54292459          -inf]]

 [[ -37.30770303          -inf]
  [ -52.91497323          -inf]]

 [[ -31.29121847          -inf]
  [ -29.79729352          -inf]]

 [[ -44.65439701          -inf]
  [ -26.74454187          -inf]]

 [[ -57.55719055          -inf]
  [ -39.64733542          -inf]]

 [[ -27.02632661          -inf]
  [ -24.04064487          -inf]]

 [[ -27.10504491          -inf]
  [ -24.11936316          -inf]]

 [[ -29.86472991          -inf]
  [ -28.37080496          -inf]]

 [[ -27.24999144          -inf]
  [ -25.75606649          -inf]]

 [[ -28.04739717          -inf]
  [ -26.55347222          -inf]]

 [[ -69.17620658          -inf]
  [ -66.19052483          -inf]]

 [[ -34.98159585          -inf]
  [ -31.9

[[[-32.66302823         -inf]
  [-33.1738538          -inf]]

 [[        -inf -37.84031739]
  [        -inf -53.04212256]]

 [[-30.3101011          -inf]
  [-32.38954177         -inf]]

 [[-45.7469751          -inf]
  [-27.07943427         -inf]]

 [[        -inf -30.06558097]
  [        -inf -34.14311743]]

 [[        -inf -31.10172582]
  [        -inf -31.68364734]]

 [[-46.91167499         -inf]
  [-28.24413416         -inf]]

 [[-60.45691059         -inf]
  [-41.78936976         -inf]]

 [[-45.74487679         -inf]
  [-27.07733597         -inf]]

 [[-43.12353528         -inf]
  [-25.87403748         -inf]]

 [[        -inf -31.00537654]
  [        -inf -28.56787223]]

 [[        -inf -26.70069667]
  [        -inf -24.26319236]]

 [[        -inf -28.50156106]
  [        -inf -26.06405675]]

 [[-61.27254985         -inf]
  [-42.60500902         -inf]]

 [[-60.15717093         -inf]
  [-42.90767313         -inf]]

 [[        -inf -29.228071  ]
  [        -inf -33.30560746]]

 [[     

[[[ -28.56719719          -inf]
  [ -27.36109901          -inf]]

 [[         -inf          -inf]
  [         -inf          -inf]]

 [[ -31.0741887           -inf]
  [ -32.81715788          -inf]]

 [[ -30.73078747          -inf]
  [ -27.39146568          -inf]]

 [[ -31.00108677          -inf]
  [ -32.74405596          -inf]]

 [[ -29.14696122          -inf]
  [ -30.8899304           -inf]]

 [[ -31.11994158          -inf]
  [ -27.7806198           -inf]]

 [[ -41.37110231          -inf]
  [ -38.03178053          -inf]]

 [[ -30.49805161          -inf]
  [ -27.15872983          -inf]]

 [[ -30.74345002          -inf]
  [ -27.40412823          -inf]]

 [[ -29.84593381          -inf]
  [ -29.04156101          -inf]]

 [[ -26.45227558          -inf]
  [ -25.24617739          -inf]]

 [[ -29.17026911          -inf]
  [ -27.96417092          -inf]]

 [[ -41.75618806          -inf]
  [ -38.41686628          -inf]]

 [[ -35.70583922          -inf]
  [ -34.90146642          -inf]]

 [[       

[[[-29.64572294         -inf]
  [-27.87643638         -inf]]

 [[        -inf         -inf]
  [        -inf         -inf]]

 [[        -inf -29.95253672]
  [        -inf -32.28329256]]

 [[-31.25175712         -inf]
  [-28.46226668         -inf]]

 [[        -inf -31.41715569]
  [        -inf -33.74791153]]

 [[        -inf -30.22868031]
  [        -inf -32.55943615]]

 [[-30.04269616         -inf]
  [-27.25320571         -inf]]

 [[-43.020453           -inf]
  [-40.23096255         -inf]]

 [[-28.93857425         -inf]
  [-26.1490838          -inf]]

 [[-28.81734903         -inf]
  [-26.02785858         -inf]]

 [[-30.7372696          -inf]
  [-27.94777916         -inf]]

 [[-27.12123642         -inf]
  [-25.35194987         -inf]]

 [[-29.33716485         -inf]
  [-26.5476744          -inf]]

 [[-38.20354144         -inf]
  [-35.41405099         -inf]]

 [[-33.64711212         -inf]
  [-30.85762167         -inf]]

 [[        -inf -29.44795433]
  [        -inf -31.77871017]]

 [[-30.7

[[[ -30.44765774          -inf]
  [ -26.65780261          -inf]]

 [[         -inf          -inf]
  [         -inf          -inf]]

 [[         -inf  -35.55472605]
  [         -inf  -38.08045423]]

 [[ -33.21082736          -inf]
  [ -29.42097223          -inf]]

 [[         -inf  -28.95542892]
  [         -inf  -29.360894  ]]

 [[         -inf  -30.44369092]
  [         -inf  -33.53473242]]

 [[ -30.96401255          -inf]
  [ -27.17415742          -inf]]

 [[ -52.21166931          -inf]
  [ -48.42181418          -inf]]

 [[ -49.34440336          -inf]
  [ -32.16159694          -inf]]

 [[ -29.54965664          -inf]
  [ -25.75980151          -inf]]

 [[ -31.19267729          -inf]
  [ -27.40282217          -inf]]

 [[ -30.46710146          -inf]
  [ -30.51472951          -inf]]

 [[ -30.28482057          -inf]
  [ -30.33244862          -inf]]

 [[ -40.08021879          -inf]
  [ -36.29036366          -inf]]

 [[ -35.75733402          -inf]
  [ -31.9674789           -inf]]

 [[       

[[[-28.47008362         -inf]
  [-25.98517789         -inf]]

 [[        -inf         -inf]
  [        -inf         -inf]]

 [[-45.04107456         -inf]
  [-29.14612234         -inf]]

 [[-46.76405284         -inf]
  [-28.25719439         -inf]]

 [[        -inf -31.67752402]
  [        -inf -34.56789546]]

 [[        -inf -29.23816976]
  [        -inf -32.1285412 ]]

 [[-46.98486616         -inf]
  [-28.47800771         -inf]]

 [[-61.78872496         -inf]
  [-43.28186651         -inf]]

 [[-46.37819993         -inf]
  [-27.87134148         -inf]]

 [[-45.61348198         -inf]
  [-28.66247716         -inf]]

 [[        -inf -30.71260417]
  [        -inf -29.08846014]]

 [[        -inf -29.24156798]
  [        -inf -27.61742395]]

 [[        -inf -29.62091194]
  [        -inf -27.99676791]]

 [[-55.80117353         -inf]
  [-37.29431508         -inf]]

 [[        -inf -38.09152228]
  [        -inf -36.46737825]]

 [[        -inf -29.60469848]
  [        -inf -32.49506993]]

 [[     

[[[ -46.29621377          -inf]
  [ -28.27601058          -inf]]

 [[         -inf          -inf]
  [         -inf          -inf]]

 [[ -33.27215348          -inf]
  [ -31.13208776          -inf]]

 [[ -28.94660434          -inf]
  [ -44.95933959          -inf]]

 [[         -inf  -33.61172848]
  [         -inf  -50.06629644]]

 [[         -inf  -30.74287297]
  [         -inf  -28.88115495]]

 [[ -45.14213407          -inf]
  [ -27.12193087          -inf]]

 [[ -54.3673089           -inf]
  [ -36.34710571          -inf]]

 [[ -44.1367633           -inf]
  [ -26.11656011          -inf]]

 [[ -45.1629046           -inf]
  [ -27.14270141          -inf]]

 [[         -inf  -29.70168657]
  [         -inf  -27.83996854]]

 [[         -inf  -26.88160479]
  [         -inf  -25.01988677]]

 [[         -inf  -28.32161468]
  [         -inf  -26.45989666]]

 [[ -51.98792166          -inf]
  [ -35.60746166          -inf]]

 [[         -inf -317.01417294]
  [         -inf -333.02690819]]

 [[       

[[[         -inf  -46.21563389]
  [         -inf  -29.10428643]]

 [[         -inf  -45.72763245]
  [         -inf  -45.57348178]]

 [[         -inf          -inf]
  [         -inf          -inf]]

 [[ -28.26210385          -inf]
  [ -25.1907343           -inf]]

 [[ -32.28243901          -inf]
  [ -30.51665509          -inf]]

 [[ -28.54516267          -inf]
  [ -44.15243286          -inf]]

 [[ -29.6658538           -inf]
  [ -26.59448425          -inf]]

 [[ -37.47273354          -inf]
  [ -34.40136398          -inf]]

 [[ -27.6776675           -inf]
  [ -24.60629795          -inf]]

 [[ -28.04396651          -inf]
  [ -24.97259695          -inf]]

 [[ -32.28658931          -inf]
  [ -30.52080539          -inf]]

 [[ -26.50882199          -inf]
  [ -24.74303807          -inf]]

 [[ -28.68120371          -inf]
  [ -26.91541979          -inf]]

 [[ -76.12162717          -inf]
  [ -59.82120988          -inf]]

 [[ -36.8658226           -inf]
  [ -33.79445305          -inf]]

 [[       

[[[-45.57134299         -inf]
  [-27.24497291         -inf]]

 [[        -inf         -inf]
  [        -inf         -inf]]

 [[-31.05670136         -inf]
  [-32.76144932         -inf]]

 [[-47.47717414         -inf]
  [-29.15080406         -inf]]

 [[-31.13567341         -inf]
  [-47.99570645         -inf]]

 [[-29.16207074         -inf]
  [-30.8668187          -inf]]

 [[-45.59432705         -inf]
  [-27.26795697         -inf]]

 [[-79.55844992         -inf]
  [-61.23207985         -inf]]

 [[-27.62975692         -inf]
  [-25.87025835         -inf]]

 [[-44.55137252         -inf]
  [-26.22500244         -inf]]

 [[-30.45766957         -inf]
  [-28.60528559         -inf]]

 [[-27.25912473         -inf]
  [-25.49962617         -inf]]

 [[-27.71159739         -inf]
  [-25.95209882         -inf]]

 [[-52.97806661         -inf]
  [-34.65169654         -inf]]

 [[-39.64420609         -inf]
  [-37.7918221          -inf]]

 [[        -inf -29.98077889]
  [        -inf -32.72161868]]

 [[-27.3

  [        -inf -32.1239958 ]]]
[[[-29.34822021         -inf]
  [-25.74408223         -inf]]

 [[-32.45194997         -inf]
  [-33.55056209         -inf]]

 [[-31.17828194         -inf]
  [-33.66318813         -inf]]

 [[-27.3438786          -inf]
  [-25.66212008         -inf]]

 [[-32.29310803         -inf]
  [-33.39172015         -inf]]

 [[-30.72663869         -inf]
  [-33.21154488         -inf]]

 [[-31.12528583         -inf]
  [-27.52114785         -inf]]

 [[-39.79538676         -inf]
  [-36.19124877         -inf]]

 [[-28.11382287         -inf]
  [-24.50968489         -inf]]

 [[-28.54881946         -inf]
  [-24.94468148         -inf]]

 [[-29.43983274         -inf]
  [-25.83569475         -inf]]

 [[-26.11032479         -inf]
  [-24.42856627         -inf]]

 [[-32.66459445         -inf]
  [-29.06045646         -inf]]

 [[-40.7896932          -inf]
  [-37.18555521         -inf]]

 [[-38.68316969         -inf]
  [-37.00919352         -inf]]

 [[        -inf -28.75533775]
  [     

NameError: name 'acc_test' is not defined

In [56]:
mul([1,2,3,4])

24

1.7976931348623157e+308