Reference paper: *Stability is Stable*
Experiment: vary the value of rho and measure the minimum sample size required for replicability.

func getConvergenceSampleNum(hyperparams)
    for sample_num in range(min_subset_size, max_subset_size, step):
        repeat repeat_num times:
            randomly draw a subset of the dataset of size sample_num
            use Algorithm 10 to learn a model from that subset
        check whether the models obtained are 'replicable' with respect to each other, according to the hyperparams (see the definition on page 14 of the paper)
        if converged, return the current sample_num

func experiment():
# Vary rho and observe the minimum sample size required for replicability.
# For simplicity, fix the values of the other hyperparameters, such as alpha and beta.
    for rho in range(min_rho, max_rho, step):
        sample_num = getConvergenceSampleNum(rho)
        theoretical_sample_num = getTheoreticalSampleNum(rho) # compute according to Algorithm 10 in the paper
        append rho, sample_num and theoretical_sample_num to the lists rhos, sample_nums and theoretical_sample_nums
        print("rho: ", rho, "Sample size: ", sample_num)
    plot(rhos, theoretical_sample_nums) # draw the curve, where the x-axis is rho and the y-axis is the theoretical sample size
    plot(rhos, sample_nums) # draw the curve, where the x-axis is rho and the y-axis is the empirical minimum sample size

In [9]:
import pandas as pd
import numpy as np
import config
import Algorithm10 as a10
from sklearn.tree import export_text

In [10]:
# Load the raw Invistico airline CSV and preview its first rows.
# (The experiment below reloads the data itself via a10.load_full_dataset.)
AIRLINE_CSV = 'dataset/Invistico_Airline.csv'
data = pd.read_csv(AIRLINE_CSV)
data.head()

Unnamed: 0,satisfaction,Customer Type,Age,Type of Travel,Class,Flight Distance,Seat comfort,Departure/Arrival time convenient,Food and drink,Gate location,...,Online support,Ease of Online booking,On-board service,Leg room service,Baggage handling,Checkin service,Cleanliness,Online boarding,Departure Delay in Minutes,Arrival Delay in Minutes
0,satisfied,Loyal Customer,65,Personal Travel,Eco,265,0,0,0,2,...,2,3,3,0,3,5,3,2,0,0.0
1,satisfied,Loyal Customer,47,Personal Travel,Business,2464,0,0,0,3,...,2,3,4,4,4,2,3,2,310,305.0
2,satisfied,Loyal Customer,15,Personal Travel,Eco,2138,0,0,0,3,...,2,2,3,3,4,4,4,2,0,0.0
3,satisfied,Loyal Customer,60,Personal Travel,Eco,623,0,0,0,3,...,3,1,1,0,1,4,1,3,0,0.0
4,satisfied,Loyal Customer,70,Personal Travel,Eco,354,0,0,0,3,...,4,2,2,0,2,4,2,5,0,0.0


In [11]:
# Dimensions of the raw frame as (rows, columns).
len(data.index), len(data.columns)

(129880, 22)

In [12]:
# function to check if two decision trees are equal
def are_trees_equal(tree1, tree2):
    """Report whether two fitted decision trees are exactly identical.

    Two trees are considered identical when their estimator parameters
    (``get_params()``) match and every compared array of their low-level
    ``tree_`` object is element-wise equal (``np.array_equal``).

    The comparison is exact: float arrays such as ``threshold`` and
    ``impurity`` must match bit-for-bit. Because ``n_node_samples``,
    ``weighted_n_node_samples`` and ``value`` are also compared, trees with the
    same split rules can still compare unequal if those per-node arrays differ.

    Args:
        tree1, tree2: fitted tree estimators exposing ``tree_`` and ``get_params()``.

    Returns:
        bool: True if parameters and all compared ``tree_`` arrays match.

    Raises:
        ValueError: if either argument has no ``tree_`` attribute (not fitted).
    """
    if not all(hasattr(candidate, 'tree_') for candidate in (tree1, tree2)):
        raise ValueError("Both trees must be fitted before comparison.")

    # Different hyper-parameters => treat as different trees, no need to look deeper.
    if tree1.get_params() != tree2.get_params():
        return False

    # Node-level arrays describing topology, split rules and per-node statistics.
    compared_fields = (
        'children_left', 'children_right',
        'feature', 'threshold',
        'impurity', 'n_node_samples', 'weighted_n_node_samples',
        'value',
    )
    left_struct, right_struct = tree1.tree_, tree2.tree_
    return all(
        np.array_equal(getattr(left_struct, field), getattr(right_struct, field))
        for field in compared_fields
    )


In [13]:
def _iter_tree_pairs(trees):
    """Yield ``(i, j, same)`` for every index pair ``i < j`` of ``trees``.

    ``same`` is ``are_trees_equal(trees[i], trees[j])``.
    """
    for i in range(len(trees)):
        for j in range(i + 1, len(trees)):
            yield i, j, are_trees_equal(trees[i], trees[j])


def getConvergenceSampleNum(min_subset_size, max_subset_size, repeat_num, rho, sample_size_step=1,
                            show_diff_trees=False, return_probs=False):
    """Search for the smallest subset size at which Algorithm 10 looks replicable.

    For each size in ``range(min_subset_size, max_subset_size + 1, sample_size_step)``,
    a candidate tree set ``H`` is built once and ``a10.replicable_learner`` is run
    ``repeat_num`` times (run ``i`` uses seed ``config.random_seed + i``). The first
    tree returned by each run is kept, and the fraction of run pairs whose kept trees
    are identical (per ``are_trees_equal``) is the empirical replicability probability.
    The search stops at the first size whose probability is at least ``1 - rho``.

    Args:
        min_subset_size: first subset size to try.
        max_subset_size: last subset size to try (inclusive).
        repeat_num: number of learner runs per size; must be >= 2 so pairs exist.
        rho: replicability parameter; a size passes when probability >= 1 - rho.
        sample_size_step: increment between tried sizes.
        show_diff_trees: if True, print ``export_text`` of trees that differ within a run.
        return_probs: if True, return ``(result, probs)`` where ``probs`` maps every
            tried size to its probability.

    Returns:
        The first passing sample size, or -1 if no size in the range passes
        (wrapped as ``(result, probs)`` when ``return_probs`` is True).

    Raises:
        ValueError: if ``repeat_num < 2`` or the size range is empty.
    """
    # Validate before any expensive work: with < 2 runs there are no pairs to compare
    # (previously a ZeroDivisionError), and an empty range left `prob` unbound.
    if repeat_num < 2:
        raise ValueError(f"repeat_num must be at least 2 to compare runs, got {repeat_num}")
    sample_sizes = range(min_subset_size, max_subset_size + 1, sample_size_step)
    if len(sample_sizes) == 0:
        raise ValueError(
            f"empty sample size range: min={min_subset_size}, max={max_subset_size}, "
            f"step={sample_size_step}"
        )

    sample_size_replicablity_dict = {}

    X, y = a10.load_full_dataset(config.dataset_path, random_state=config.random_seed)
    for sample_size in sample_sizes:
        # Candidate tree set H for this subset size (same seed for every size).
        H = a10.build_candidate_trees(X, y, sample_size, max_depth=config.max_depth,
                                      num_trees=config.num_H, random_state=config.random_seed)
        replicable_tree_list = []
        for i in range(repeat_num):
            print(f"sample size: {sample_size}, repeat: {i}")
            # NOTE(review): presumably replicable_learner draws the size-`sample_size`
            # subset internally. Each repeat passes a *different* seed; if that seed also
            # drives Algorithm 10's internal randomness (the logged `v - OPT error` offset
            # depends only on the repeat index, suggesting it does), runs do not share
            # randomness as the standard replicability definition assumes — confirm in a10.
            res_trees = a10.replicable_learner(X, y, H, sample_size, random_seed=config.random_seed + i)
            print(f"number of res_trees: {len(res_trees)}")
            # Diagnostic: are the candidate trees returned by this single run identical?
            for a, b, same in _iter_tree_pairs(res_trees):
                if same:
                    print(f"tree {a} and tree {b} are the same")
                else:
                    print(f"tree {a} and tree {b} are different")
                    if show_diff_trees:
                        # export_text returns a string; it must be printed to be seen.
                        print(export_text(res_trees[a]))
                        print(export_text(res_trees[b]))
            # Keep the first returned tree as this run's output.
            # NOTE(review): confirm this matches how Algorithm 10 selects among candidates.
            replicable_tree_list.append(res_trees[0])

        # Empirical probability that two independent runs output the same tree.
        num_runs = len(replicable_tree_list)
        num_pairs = num_runs * (num_runs - 1) // 2
        same_tree_count = sum(same for _, _, same in _iter_tree_pairs(replicable_tree_list))
        prob = same_tree_count / num_pairs
        sample_size_replicablity_dict[sample_size] = prob
        if prob >= 1 - rho:
            print(f"replicable at sample size: {sample_size}, prob: {prob}")
            return (sample_size, sample_size_replicablity_dict) if return_probs else sample_size

    print(f"not replicable at sample size between {min_subset_size} and {max_subset_size}, prob: {prob}")
    return (-1, sample_size_replicablity_dict) if return_probs else -1

            
        
    

In [14]:
# Empirical search grid: subset sizes 100..1000 (step 100), 10 learner runs per size.
MIN_SUBSET_SIZE = 100
MAX_SUBSET_SIZE = 1000
SUBSET_STEP = 100
REPEAT_NUM = 10

# Show the theoretical bound first so it can be compared with the empirical result.
print("theoretical sample size: ", config.m_up_bound)
ans_dict = getConvergenceSampleNum(
    min_subset_size=MIN_SUBSET_SIZE,
    max_subset_size=MAX_SUBSET_SIZE,
    repeat_num=REPEAT_NUM,
    rho=config.rho,
    sample_size_step=SUBSET_STEP,
)

theoretical sample size:  311347.6151224051


100%|██████████| 3/3 [00:00<00:00, 142.31it/s]

sample size: 100, repeat: 0
OPT error 0.27 errors dict_values([0.27, 0.31000000000000005, 0.28])
v: 0.32649109409973104
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 100, repeat: 1
OPT error 0.25 errors dict_values([0.28, 0.28, 0.25])
v: 0.2542531672891165
number of res_trees: 1
sample size: 100, repeat: 2
OPT error 0.17000000000000004 errors dict_values([0.32999999999999996, 0.26, 0.17000000000000004])
v: 0.20800454778602231
number of res_trees: 1
sample size: 100, repeat: 3
OPT error 0.26 errors dict_values([0.31999999999999995, 0.35, 0.26])
v: 0.284561455891711
number of res_trees: 1
sample size: 100, repeat: 4
OPT error 0.21999999999999997 errors dict_values([0.27, 0.30000000000000004, 0.21999999999999997])
v: 0.22789700662890505
number of res_trees: 1
sample size: 100, repeat: 5
OPT error 0.22999999999999998 errors dict_values([0.31000000000000005, 0.31000000000000005, 0.22999999999999998])
v: 0.




OPT error 0.19999999999999996 errors dict_values([0.19999999999999996, 0.25, 0.20999999999999996])
v: 0.22140299846068692
number of res_trees: 2
tree 0 and tree 1 are different


100%|██████████| 3/3 [00:00<00:00, 133.43it/s]


sample size: 200, repeat: 0
OPT error 0.24 errors dict_values([0.24, 0.30000000000000004, 0.31000000000000005])
v: 0.296491094099731
number of res_trees: 1
sample size: 200, repeat: 1
OPT error 0.22999999999999998 errors dict_values([0.24, 0.22999999999999998, 0.265])
v: 0.23425316728911655
number of res_trees: 1
sample size: 200, repeat: 2
OPT error 0.20999999999999996 errors dict_values([0.21499999999999997, 0.28, 0.20999999999999996])
v: 0.24800454778602224
number of res_trees: 2
tree 0 and tree 1 are different
sample size: 200, repeat: 3
OPT error 0.255 errors dict_values([0.255, 0.29000000000000004, 0.27])
v: 0.279561455891711
number of res_trees: 2
tree 0 and tree 1 are different
sample size: 200, repeat: 4
OPT error 0.25 errors dict_values([0.25, 0.29000000000000004, 0.26])
v: 0.2578970066289051
number of res_trees: 1
sample size: 200, repeat: 5
OPT error 0.19499999999999995 errors dict_values([0.19499999999999995, 0.28500000000000003, 0.25])
v: 0.22655820580695715
number of res

100%|██████████| 3/3 [00:00<00:00, 147.54it/s]


sample size: 300, repeat: 0
OPT error 0.22333333333333338 errors dict_values([0.23333333333333328, 0.31666666666666665, 0.22333333333333338])
v: 0.2798244274330644
number of res_trees: 2
tree 0 and tree 1 are different
sample size: 300, repeat: 1
OPT error 0.24 errors dict_values([0.2766666666666666, 0.26, 0.24])
v: 0.24425316728911656
number of res_trees: 1
sample size: 300, repeat: 2
OPT error 0.20666666666666667 errors dict_values([0.23333333333333328, 0.2766666666666666, 0.20666666666666667])
v: 0.24467121445268894
number of res_trees: 2
tree 0 and tree 1 are different
sample size: 300, repeat: 3
OPT error 0.23333333333333328 errors dict_values([0.2566666666666667, 0.29666666666666663, 0.23333333333333328])
v: 0.2578947892250443
number of res_trees: 2
tree 0 and tree 1 are different
sample size: 300, repeat: 4
OPT error 0.21999999999999997 errors dict_values([0.2466666666666667, 0.30000000000000004, 0.21999999999999997])
v: 0.22789700662890505
number of res_trees: 1
sample size: 30

100%|██████████| 3/3 [00:00<00:00, 156.79it/s]

sample size: 400, repeat: 0
OPT error 0.21999999999999997 errors dict_values([0.21999999999999997, 0.30000000000000004, 0.21999999999999997])
v: 0.276491094099731
number of res_trees: 2
tree 0 and tree 1 are different
sample size: 400, repeat: 1





OPT error 0.245 errors dict_values([0.25, 0.265, 0.245])
v: 0.24925316728911656
number of res_trees: 1
sample size: 400, repeat: 2
OPT error 0.21250000000000002 errors dict_values([0.21750000000000003, 0.29500000000000004, 0.21250000000000002])
v: 0.25050454778602227
number of res_trees: 2
tree 0 and tree 1 are different
sample size: 400, repeat: 3
OPT error 0.22499999999999998 errors dict_values([0.24750000000000005, 0.3075, 0.22499999999999998])
v: 0.24956145589171094
number of res_trees: 2
tree 0 and tree 1 are different
sample size: 400, repeat: 4
OPT error 0.235 errors dict_values([0.245, 0.31000000000000005, 0.235])
v: 0.24289700662890507
number of res_trees: 1
sample size: 400, repeat: 5
OPT error 0.21499999999999997 errors dict_values([0.21999999999999997, 0.3225, 0.21499999999999997])
v: 0.24655820580695717
number of res_trees: 2
tree 0 and tree 1 are different
sample size: 400, repeat: 6
OPT error 0.235 errors dict_values([0.255, 0.3375, 0.235])
v: 0.28281143989393775
number 

100%|██████████| 3/3 [00:00<00:00, 156.84it/s]


sample size: 500, repeat: 0
OPT error 0.23399999999999999 errors dict_values([0.252, 0.23399999999999999, 0.238])
v: 0.290491094099731
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 500, repeat: 1
OPT error 0.20599999999999996 errors dict_values([0.262, 0.20599999999999996, 0.236])
v: 0.21025316728911653
number of res_trees: 1
sample size: 500, repeat: 2
OPT error 0.22399999999999998 errors dict_values([0.244, 0.22799999999999998, 0.22399999999999998])
v: 0.2620045477860222
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 500, repeat: 3
OPT error 0.20999999999999996 errors dict_values([0.23399999999999999, 0.20999999999999996, 0.21199999999999997])
v: 0.23456145589171093
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 500, repeat: 4
OPT e

100%|██████████| 3/3 [00:00<00:00, 140.96it/s]


sample size: 600, repeat: 0
OPT error 0.22166666666666668 errors dict_values([0.2283333333333334, 0.22166666666666668, 0.2283333333333334])
v: 0.2781577607663977
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 600, repeat: 1
OPT error 0.20333333333333337 errors dict_values([0.2366666666666667, 0.20333333333333337, 0.2366666666666667])
v: 0.20758650062244993
number of res_trees: 1
sample size: 600, repeat: 2
OPT error 0.21833333333333338 errors dict_values([0.21833333333333338, 0.21833333333333338, 0.21833333333333338])
v: 0.2563378811193556
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 600, repeat: 3
OPT error 0.20499999999999996 errors dict_values([0.21333333333333337, 0.20499999999999996, 0.21333333333333337])
v: 0.22956145589171092
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are differen

100%|██████████| 3/3 [00:00<00:00, 138.54it/s]


sample size: 700, repeat: 0
OPT error 0.2242857142857143 errors dict_values([0.22571428571428576, 0.2242857142857143, 0.22571428571428576])
v: 0.28077680838544533
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 700, repeat: 1
OPT error 0.20714285714285718 errors dict_values([0.2371428571428571, 0.20714285714285718, 0.2371428571428571])
v: 0.21139602443197375
number of res_trees: 1
sample size: 700, repeat: 2
OPT error 0.2171428571428572 errors dict_values([0.22142857142857142, 0.2171428571428572, 0.22142857142857142])
v: 0.25514740492887944
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 700, repeat: 3
OPT error 0.20714285714285718 errors dict_values([0.2171428571428572, 0.20714285714285718, 0.2171428571428572])
v: 0.23170431303456815
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different


100%|██████████| 3/3 [00:00<00:00, 133.05it/s]

sample size: 800, repeat: 0
OPT error 0.21999999999999997 errors dict_values([0.22375, 0.21999999999999997, 0.22375])
v: 0.276491094099731
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 800, repeat: 1
OPT error 0.20875 errors dict_values([0.235, 0.20875, 0.235])
v: 0.21300316728911656
number of res_trees: 1
sample size: 800, repeat: 2
OPT error 0.21125000000000005 errors dict_values([0.22750000000000004, 0.21125000000000005, 0.22750000000000004])
v: 0.24925454778602232
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 800, repeat: 3





OPT error 0.20750000000000002 errors dict_values([0.21625000000000005, 0.20750000000000002, 0.21625000000000005])
v: 0.23206145589171098
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 800, repeat: 4
OPT error 0.21375 errors dict_values([0.24375000000000002, 0.21375, 0.24375000000000002])
v: 0.22164700662890507
number of res_trees: 1
sample size: 800, repeat: 5
OPT error 0.18625000000000003 errors dict_values([0.21125000000000005, 0.18625000000000003, 0.21125000000000005])
v: 0.21780820580695723
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 800, repeat: 6
OPT error 0.21499999999999997 errors dict_values([0.21499999999999997, 0.21875, 0.21499999999999997])
v: 0.26281143989393774
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 800, repeat

100%|██████████| 3/3 [00:00<00:00, 142.38it/s]

sample size: 900, repeat: 0
OPT error 0.21777777777777774 errors dict_values([0.2222222222222222, 0.21777777777777774, 0.2222222222222222])
v: 0.27426887187750876
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 900, repeat: 1
OPT error 0.20666666666666667 errors dict_values([0.23111111111111116, 0.20666666666666667, 0.23111111111111116])
v: 0.21091983395578323
number of res_trees: 1
sample size: 900, repeat: 2
OPT error 0.2155555555555555 errors dict_values([0.22333333333333338, 0.2155555555555555, 0.22333333333333338])
v: 0.25356010334157775
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 900, repeat: 3
OPT error 0.20777777777777773 errors dict_values([0.21111111111111114, 0.20777777777777773, 0.21111111111111114])
v: 0.2323392336694887
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are differe


100%|██████████| 3/3 [00:00<00:00, 138.22it/s]

sample size: 1000, repeat: 0
OPT error 0.21699999999999997 errors dict_values([0.22399999999999998, 0.21699999999999997, 0.22399999999999998])
v: 0.273491094099731
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 1000, repeat: 1
OPT error 0.21099999999999997 errors dict_values([0.23199999999999998, 0.21099999999999997, 0.23199999999999998])
v: 0.21525316728911653
number of res_trees: 1
sample size: 1000, repeat: 2
OPT error 0.21299999999999997 errors dict_values([0.22199999999999998, 0.21299999999999997, 0.22199999999999998])
v: 0.2510045477860222
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are different
tree 1 and tree 2 are different
sample size: 1000, repeat: 3
OPT error 0.21699999999999997 errors dict_values([0.21799999999999997, 0.21699999999999997, 0.21799999999999997])
v: 0.24156145589171094
number of res_trees: 3
tree 0 and tree 1 are different
tree 0 and tree 2 are d




In [15]:
# Despite its name, this holds the value returned by getConvergenceSampleNum:
# the first sample size that met the replicability threshold, or -1 if none did.
ans_dict

-1