Reference paper: "Stability is Stable: Connections between Replicability, Privacy, and Adaptive Generalization"
Experiment: vary the value of rho and measure the minimum sample size required for replicability

func getConvergenceSampleNum(hyperparams)
    For sample_num in range(min_subsets_size, max_subsets_size, step):
        For i in range(repeat_num):
            randomly draw a subset of the dataset, whose size = sample_num
            Use algorithm 10 to get a model
        Check whether the models we get are 'replicable' to each other, according to the hyperparams. (see the definition on page 14 of the paper)
        if converged, return current sample_num

func experiment(): 
# vary rho and see the minimum sample size required for replicability. 
# To keep it simple, we can fix the values of the other hyperparameters, such as alpha and beta.
    for rho in range(min_rho, max_rho, step):
        sample_num = getConvergenceSampleNum(rho)
        theoretical_sample_num = getTheoreticalSampleNum(rho) # compute according to algorithm 10 in the paper
        append rho, sample_num and theoretical_sample_num to the lists rhos, sample_nums and theoretical_sample_nums
        print("rho: ", rho, "Sample size: ", sample_num)
    plot(rhos, theoretical_sample_nums) # draw the curve, where x-axis is rho and y-axis is the theoretical sample size
    plot(rhos, sample_nums) # draw the curve, where x-axis is rho and y-axis is the empirical sample size

In [1]:
import itertools

import numpy as np
import pandas as pd
from sklearn.tree import export_text

import Algorithm10 as a10
import config

##### Config Variables #####
np = <module 'numpy' from '/opt/anaconda3/envs/myenv/lib/python3.13/site-packages/numpy/__init__.py'>
dataset_path = ./dataset/Invistico_Airline.csv
model_path = ./models/
max_depth = 3
random_seed = 42
selected_features = ['Class', 'Seat comfort', 'Food and drink', 'Cleanliness', 'satisfaction']
rho = 0.3
alpha = 0.3
beta = 0.2
num_H = 10
m_up_bound = 8954.993893610603
tau_up_bound = 0.03908650337129266
tau = 0.003908650337129266
m = 100
############################


In [2]:
# Load the raw airline-satisfaction dataset from the path configured in
# `config` (the same setting getConvergenceSampleNum passes to
# a10.load_full_dataset) instead of hard-coding a second copy of the path.
data = pd.read_csv(config.dataset_path)
data.head()

Unnamed: 0,satisfaction,Customer Type,Age,Type of Travel,Class,Flight Distance,Seat comfort,Departure/Arrival time convenient,Food and drink,Gate location,...,Online support,Ease of Online booking,On-board service,Leg room service,Baggage handling,Checkin service,Cleanliness,Online boarding,Departure Delay in Minutes,Arrival Delay in Minutes
0,satisfied,Loyal Customer,65,Personal Travel,Eco,265,0,0,0,2,...,2,3,3,0,3,5,3,2,0,0.0
1,satisfied,Loyal Customer,47,Personal Travel,Business,2464,0,0,0,3,...,2,3,4,4,4,2,3,2,310,305.0
2,satisfied,Loyal Customer,15,Personal Travel,Eco,2138,0,0,0,3,...,2,2,3,3,4,4,4,2,0,0.0
3,satisfied,Loyal Customer,60,Personal Travel,Eco,623,0,0,0,3,...,3,1,1,0,1,4,1,3,0,0.0
4,satisfied,Loyal Customer,70,Personal Travel,Eco,354,0,0,0,3,...,4,2,2,0,2,4,2,5,0,0.0


In [3]:
# (number of rows, number of columns) of the raw dataset loaded above
data.shape

(129880, 22)

In [4]:
# Equality test used to decide whether two learner outputs are "the same" hypothesis.
def are_trees_equal(tree1, tree2, strict=True):
    """Return True if two fitted decision trees are considered equal.

    Parameters
    ----------
    tree1, tree2 : fitted decision-tree estimators
        Objects exposing a fitted ``tree_`` attribute and ``get_params()``
        (e.g. ``sklearn.tree.DecisionTreeClassifier``).
    strict : bool, default True
        If True (the original behaviour), the estimators must have identical
        hyperparameters (``get_params()``, which includes ``random_state``)
        AND identical fitted internals, including per-node impurity, sample
        counts and ``value`` arrays. Two trees that encode the same decision
        function but were fitted on different samples are therefore NOT
        equal in strict mode.
        If False, only the decision function is compared: tree topology,
        split features, split thresholds, and the prediction at every leaf
        (argmax of ``value`` when both trees expose ``classes_``, the raw
        leaf ``value`` otherwise). Hyperparameters are ignored.

    Returns
    -------
    bool

    Raises
    ------
    ValueError
        If either tree has not been fitted (has no ``tree_`` attribute).
    """
    # Check that both trees are fitted
    if not hasattr(tree1, 'tree_') or not hasattr(tree2, 'tree_'):
        raise ValueError("Both trees must be fitted before comparison.")

    t1 = tree1.tree_
    t2 = tree2.tree_

    if strict:
        # Compare hyperparameters, then structure, splitting rules and fit statistics
        if tree1.get_params() != tree2.get_params():
            return False
        attributes_to_check = (
            'children_left', 'children_right',
            'feature', 'threshold',
            'impurity', 'n_node_samples', 'weighted_n_node_samples',
            'value',
        )
        return all(np.array_equal(getattr(t1, attr), getattr(t2, attr))
                   for attr in attributes_to_check)

    # Decision-function comparison: same topology and same splitting rules ...
    for attr in ('children_left', 'children_right', 'feature', 'threshold'):
        if not np.array_equal(getattr(t1, attr), getattr(t2, attr)):
            return False

    # ... and the same prediction at every leaf. Topology is identical at this
    # point, so leaves sit at the same node indices in both trees; sklearn marks
    # a leaf with children_left == -1 (TREE_LEAF).
    leaves = np.asarray(t1.children_left) == -1
    leaf_values1 = np.asarray(t1.value)[leaves]
    leaf_values2 = np.asarray(t2.value)[leaves]

    classes1 = getattr(tree1, 'classes_', None)
    classes2 = getattr(tree2, 'classes_', None)
    if classes1 is not None and classes2 is not None:
        if not np.array_equal(classes1, classes2):
            return False
        # Classifiers predict the class with the largest entry of `value`
        # (raw counts or proportions give the same argmax).
        return np.array_equal(np.argmax(leaf_values1, axis=-1),
                              np.argmax(leaf_values2, axis=-1))

    # Without class labels (e.g. regressors) the leaf value itself is the prediction.
    return np.array_equal(leaf_values1, leaf_values2)


In [5]:
def getConvergenceSampleNum(min_subset_size, max_subset_size, repeat_num, rho, sample_size_step=1,
                            return_details=False):
    """Find the smallest sample size at which Algorithm 10 is empirically (1 - rho)-replicable.

    For every candidate size in
    ``range(min_subset_size, max_subset_size + 1, sample_size_step)``:

    1. build the candidate set H with ``a10.build_candidate_trees``;
    2. run ``a10.replicable_learner`` ``repeat_num`` times, with seeds
       ``config.random_seed + i``;
    3. estimate replicability as the fraction of unordered pairs of returned
       trees for which ``are_trees_equal`` is True.

    The search stops at the first size whose estimate is ``>= 1 - rho``.

    Parameters
    ----------
    min_subset_size, max_subset_size : int
        Inclusive bounds of the sample sizes to try.
    repeat_num : int
        Number of learner runs per sample size; must be >= 2 so there is at
        least one pair of outputs to compare.
    rho : float
        Allowed non-replicability; success means agreement >= 1 - rho.
    sample_size_step : int, default 1
        Step between consecutive sample sizes.
    return_details : bool, default False
        If True, also return the dict mapping every evaluated sample size to
        its estimated agreement probability.

    Returns
    -------
    int, or (int, dict) if ``return_details`` is True
        The first replicable sample size, or -1 if none in the range was.

    Raises
    ------
    ValueError
        If ``repeat_num < 2``.
    """
    if repeat_num < 2:
        raise ValueError(f"repeat_num must be at least 2 to compare pairs of runs, got {repeat_num}")

    sample_size_replicability_dict = {}
    prob = None  # stays None if the sample-size range is empty

    X, y = a10.load_full_dataset(config.dataset_path, random_state=config.random_seed)
    for sample_size in range(min_subset_size, max_subset_size + 1, sample_size_step):
        # The candidate hypothesis set H is rebuilt for every sample size.
        H = a10.build_candidate_trees(X, y, sample_size, max_depth=config.max_depth,
                                      num_trees=config.num_H, random_state=config.random_seed)

        replicable_tree_list = []
        for i in range(repeat_num):
            print(f"sample size: {sample_size}, repeat: {i}")
            # NOTE(review): one seed per repeat appears to drive both the sample
            # draw and the learner's internal randomness (the logged `v` tracks
            # the repeat index i rather than the sample size). Replicability is
            # usually defined over runs that SHARE internal randomness but use
            # independent samples -- confirm a10.replicable_learner's seeding
            # matches the definition being tested.
            tree = a10.replicable_learner(X, y, H, sample_size, random_seed=config.random_seed + i)
            replicable_tree_list.append(tree)

        # Estimate Pr[two runs return the same tree] over all unordered pairs.
        pairs = list(itertools.combinations(replicable_tree_list, 2))
        same_tree_count = sum(1 for tree_a, tree_b in pairs if are_trees_equal(tree_a, tree_b))
        prob = same_tree_count / len(pairs)
        sample_size_replicability_dict[sample_size] = prob

        if prob >= 1 - rho:
            print(f"replicable at sample size: {sample_size}, prob: {prob}")
            if return_details:
                return sample_size, sample_size_replicability_dict
            return sample_size

    print(f"not replicable at sample size between {min_subset_size} and {max_subset_size}, prob: {prob}")
    if return_details:
        return -1, sample_size_replicability_dict
    return -1

            
        
    

In [10]:
# Compare the theoretical sample-size bound from config with the smallest
# sample size at which the learner is empirically (1 - rho)-replicable.
print("theoretical sample size: ", config.m_up_bound)

# NOTE: despite its name, `ans_dict` receives an int -- the first replicable
# sample size, or -1 if none in the searched range was replicable.
ans_dict = getConvergenceSampleNum(
    rho=config.rho,
    repeat_num=10,
    min_subset_size=50_000,
    max_subset_size=60_000,
    sample_size_step=1_000,
)

theoretical sample size:  8954.993893610603


100%|██████████| 10/10 [00:00<00:00, 68.15it/s]


sample size: 50000, repeat: 0
OPT error 0.21499999999999997 errors dict_values([0.23746, 0.23746, 0.23746, 0.23746, 0.23746, 0.23746, 0.21499999999999997, 0.23746, 0.23746, 0.23746])
v: 0.229412249360853
sample size: 50000, repeat: 1
OPT error 0.21282 errors dict_values([0.23540000000000005, 0.23540000000000005, 0.23540000000000005, 0.23540000000000005, 0.23540000000000005, 0.23540000000000005, 0.21282, 0.23540000000000005, 0.23540000000000005, 0.23540000000000005])
v: 0.21890782954054774
sample size: 50000, repeat: 2
OPT error 0.21721999999999997 errors dict_values([0.23919999999999997, 0.23919999999999997, 0.23919999999999997, 0.23919999999999997, 0.23919999999999997, 0.23919999999999997, 0.21721999999999997, 0.23919999999999997, 0.23919999999999997, 0.23919999999999997])
v: 0.2716183325741609
sample size: 50000, repeat: 3
OPT error 0.21555999999999997 errors dict_values([0.23802, 0.23802, 0.23802, 0.23802, 0.23802, 0.23802, 0.21555999999999997, 0.23802, 0.23802, 0.23802])
v: 0.25071

100%|██████████| 10/10 [00:00<00:00, 60.39it/s]

sample size: 51000, repeat: 0
OPT error 0.21490196078431367 errors dict_values([0.23733333333333329, 0.23733333333333329, 0.23733333333333329, 0.23733333333333329, 0.23733333333333329, 0.23733333333333329, 0.21490196078431367, 0.23733333333333329, 0.23733333333333329, 0.23733333333333329])
v: 0.2293142101451667
sample size: 51000, repeat: 1





OPT error 0.2131764705882353 errors dict_values([0.2356078431372549, 0.2356078431372549, 0.2356078431372549, 0.2356078431372549, 0.2356078431372549, 0.2356078431372549, 0.2131764705882353, 0.2356078431372549, 0.2356078431372549, 0.2356078431372549])
v: 0.21926430012878304
sample size: 51000, repeat: 2
OPT error 0.21713725490196079 errors dict_values([0.2393725490196078, 0.2393725490196078, 0.2393725490196078, 0.2393725490196078, 0.2393725490196078, 0.2393725490196078, 0.21713725490196079, 0.2393725490196078, 0.2393725490196078, 0.2393725490196078])
v: 0.2715355874761217
sample size: 51000, repeat: 3
OPT error 0.21592156862745093 errors dict_values([0.23831372549019603, 0.23831372549019603, 0.23831372549019603, 0.23831372549019603, 0.23831372549019603, 0.23831372549019603, 0.21592156862745093, 0.23831372549019603, 0.23831372549019603, 0.23831372549019603])
v: 0.25107794658583527
sample size: 51000, repeat: 4
OPT error 0.2127450980392157 errors dict_values([0.23529411764705888, 0.2352941

100%|██████████| 10/10 [00:00<00:00, 68.40it/s]

sample size: 52000, repeat: 0
OPT error 0.2146538461538462 errors dict_values([0.23721153846153842, 0.23721153846153842, 0.23721153846153842, 0.23721153846153842, 0.23721153846153842, 0.23721153846153842, 0.2146538461538462, 0.23721153846153842, 0.23721153846153842, 0.23721153846153842])
v: 0.22906609551469923
sample size: 52000, repeat: 1





OPT error 0.21344230769230765 errors dict_values([0.23598076923076927, 0.23598076923076927, 0.23598076923076927, 0.23598076923076927, 0.23598076923076927, 0.23598076923076927, 0.21344230769230765, 0.23598076923076927, 0.23598076923076927, 0.23598076923076927])
v: 0.2195301372328554
sample size: 52000, repeat: 2
OPT error 0.21686538461538463 errors dict_values([0.23907692307692308, 0.23907692307692308, 0.23907692307692308, 0.23907692307692308, 0.23907692307692308, 0.23907692307692308, 0.21686538461538463, 0.23907692307692308, 0.23907692307692308, 0.23907692307692308])
v: 0.27126371718954556
sample size: 52000, repeat: 3
OPT error 0.21607692307692306 errors dict_values([0.23848076923076922, 0.23848076923076922, 0.23848076923076922, 0.23848076923076922, 0.23848076923076922, 0.23848076923076922, 0.21607692307692306, 0.23848076923076922, 0.23848076923076922, 0.23848076923076922])
v: 0.2512333010353074
sample size: 52000, repeat: 4
OPT error 0.21296153846153842 errors dict_values([0.23550000

100%|██████████| 10/10 [00:00<00:00, 67.27it/s]

sample size: 53000, repeat: 0
OPT error 0.21456603773584904 errors dict_values([0.23715094339622644, 0.23715094339622644, 0.23715094339622644, 0.23715094339622644, 0.23715094339622644, 0.23715094339622644, 0.21456603773584904, 0.23715094339622644, 0.23715094339622644, 0.23715094339622644])
v: 0.22897828709670207
sample size: 53000, repeat: 1
OPT error 0.21411320754716978 errors dict_values([0.23667924528301887, 0.23667924528301887, 0.23667924528301887, 0.23667924528301887, 0.23667924528301887, 0.23667924528301887, 0.21411320754716978, 0.23667924528301887, 0.23667924528301887, 0.23667924528301887])
v: 0.22020103708771752
sample size: 53000, repeat: 2





OPT error 0.21705660377358493 errors dict_values([0.23926415094339626, 0.23926415094339626, 0.23926415094339626, 0.23926415094339626, 0.23926415094339626, 0.23926415094339626, 0.21705660377358493, 0.23926415094339626, 0.23926415094339626, 0.23926415094339626])
v: 0.27145493634774587
sample size: 53000, repeat: 3
OPT error 0.21647169811320754 errors dict_values([0.23892452830188682, 0.23892452830188682, 0.23892452830188682, 0.23892452830188682, 0.23892452830188682, 0.23892452830188682, 0.21647169811320754, 0.23892452830188682, 0.23892452830188682, 0.23892452830188682])
v: 0.2516280760715919
sample size: 53000, repeat: 4
OPT error 0.2135283018867925 errors dict_values([0.23596226415094335, 0.23596226415094335, 0.23596226415094335, 0.23596226415094335, 0.23596226415094335, 0.23596226415094335, 0.2135283018867925, 0.23596226415094335, 0.23596226415094335, 0.23596226415094335])
v: 0.22483179102073053
sample size: 53000, repeat: 5
OPT error 0.21577358490566034 errors dict_values([0.238566037

100%|██████████| 10/10 [00:00<00:00, 67.29it/s]

sample size: 54000, repeat: 0
OPT error 0.21429629629629632 errors dict_values([0.2367407407407407, 0.2367407407407407, 0.2367407407407407, 0.2367407407407407, 0.2367407407407407, 0.2367407407407407, 0.21429629629629632, 0.2367407407407407, 0.2367407407407407, 0.2367407407407407])
v: 0.22870854565714935
sample size: 54000, repeat: 1





OPT error 0.2143518518518519 errors dict_values([0.23701851851851852, 0.23701851851851852, 0.23701851851851852, 0.23701851851851852, 0.23701851851851852, 0.23701851851851852, 0.2143518518518519, 0.23701851851851852, 0.23701851851851852, 0.23701851851851852])
v: 0.22043968139239964
sample size: 54000, repeat: 2
OPT error 0.21718518518518515 errors dict_values([0.23942592592592593, 0.23942592592592593, 0.23942592592592593, 0.23942592592592593, 0.23942592592592593, 0.23942592592592593, 0.21718518518518515, 0.23942592592592593, 0.23942592592592593, 0.23942592592592593])
v: 0.2715835177593461
sample size: 54000, repeat: 3
OPT error 0.21659259259259256 errors dict_values([0.2391481481481481, 0.2391481481481481, 0.2391481481481481, 0.2391481481481481, 0.2391481481481481, 0.2391481481481481, 0.21659259259259256, 0.2391481481481481, 0.2391481481481481, 0.2391481481481481])
v: 0.2517489705509769
sample size: 54000, repeat: 4
OPT error 0.2136296296296296 errors dict_values([0.23596296296296293, 0

100%|██████████| 10/10 [00:00<00:00, 65.62it/s]

sample size: 55000, repeat: 0
OPT error 0.2143454545454545 errors dict_values([0.23672727272727268, 0.23672727272727268, 0.23672727272727268, 0.23672727272727268, 0.23672727272727268, 0.23672727272727268, 0.2143454545454545, 0.23672727272727268, 0.23672727272727268, 0.23672727272727268])
v: 0.22875770390630754
sample size: 55000, repeat: 1





OPT error 0.21410909090909092 errors dict_values([0.23694545454545457, 0.23694545454545457, 0.23694545454545457, 0.23694545454545457, 0.23694545454545457, 0.23694545454545457, 0.21410909090909092, 0.23694545454545457, 0.23694545454545457, 0.23694545454545457])
v: 0.22019692044963865
sample size: 55000, repeat: 2
OPT error 0.2173272727272727 errors dict_values([0.23961818181818184, 0.23961818181818184, 0.23961818181818184, 0.23961818181818184, 0.23961818181818184, 0.23961818181818184, 0.2173272727272727, 0.23961818181818184, 0.23961818181818184, 0.23961818181818184])
v: 0.27172560530143364
sample size: 55000, repeat: 3
OPT error 0.2163090909090909 errors dict_values([0.23883636363636362, 0.23883636363636362, 0.23883636363636362, 0.23883636363636362, 0.23883636363636362, 0.23883636363636362, 0.2163090909090909, 0.23883636363636362, 0.23883636363636362, 0.23883636363636362])
v: 0.25146546886747523
sample size: 55000, repeat: 4
OPT error 0.21372727272727277 errors dict_values([0.2360545454

100%|██████████| 10/10 [00:00<00:00, 56.02it/s]

sample size: 56000, repeat: 0





OPT error 0.21441071428571423 errors dict_values([0.23687499999999995, 0.23687499999999995, 0.23687499999999995, 0.23687499999999995, 0.23687499999999995, 0.23687499999999995, 0.21441071428571423, 0.23687499999999995, 0.23687499999999995, 0.23687499999999995])
v: 0.22882296364656726
sample size: 56000, repeat: 1
OPT error 0.21410714285714283 errors dict_values([0.23680357142857145, 0.23680357142857145, 0.23680357142857145, 0.23680357142857145, 0.23680357142857145, 0.23680357142857145, 0.21410714285714283, 0.23680357142857145, 0.23680357142857145, 0.23680357142857145])
v: 0.22019497239769056
sample size: 56000, repeat: 2
OPT error 0.2169821428571429 errors dict_values([0.23939285714285718, 0.23939285714285718, 0.23939285714285718, 0.23939285714285718, 0.23939285714285718, 0.23939285714285718, 0.2169821428571429, 0.23939285714285718, 0.23939285714285718, 0.23939285714285718])
v: 0.27138047543130384
sample size: 56000, repeat: 3
OPT error 0.21607142857142858 errors dict_values([0.23864285

100%|██████████| 10/10 [00:00<00:00, 65.16it/s]

sample size: 57000, repeat: 0





OPT error 0.2142982456140351 errors dict_values([0.23673684210526313, 0.23673684210526313, 0.23673684210526313, 0.23673684210526313, 0.23673684210526313, 0.23673684210526313, 0.2142982456140351, 0.23673684210526313, 0.23673684210526313, 0.23673684210526313])
v: 0.22871049497488813
sample size: 57000, repeat: 1
OPT error 0.2136842105263158 errors dict_values([0.23635087719298242, 0.23635087719298242, 0.23635087719298242, 0.23635087719298242, 0.23635087719298242, 0.23635087719298242, 0.2136842105263158, 0.23635087719298242, 0.23635087719298242, 0.23635087719298242])
v: 0.21977204006686354
sample size: 57000, repeat: 2
OPT error 0.21738596491228068 errors dict_values([0.23973684210526314, 0.23973684210526314, 0.23973684210526314, 0.23973684210526314, 0.23973684210526314, 0.23973684210526314, 0.21738596491228068, 0.23973684210526314, 0.23973684210526314, 0.23973684210526314])
v: 0.2717842974864416
sample size: 57000, repeat: 3
OPT error 0.21592982456140353 errors dict_values([0.23842105263

100%|██████████| 10/10 [00:00<00:00, 64.33it/s]


sample size: 58000, repeat: 0
OPT error 0.21377586206896548 errors dict_values([0.23627586206896556, 0.23627586206896556, 0.23627586206896556, 0.23627586206896556, 0.23627586206896556, 0.23627586206896556, 0.21377586206896548, 0.23627586206896556, 0.23627586206896556, 0.23627586206896556])
v: 0.2281881114298185
sample size: 58000, repeat: 1
OPT error 0.21393103448275863 errors dict_values([0.2365172413793103, 0.2365172413793103, 0.2365172413793103, 0.2365172413793103, 0.2365172413793103, 0.2365172413793103, 0.21393103448275863, 0.2365172413793103, 0.2365172413793103, 0.2365172413793103])
v: 0.22001886402330637
sample size: 58000, repeat: 2
OPT error 0.21705172413793106 errors dict_values([0.23944827586206896, 0.23944827586206896, 0.23944827586206896, 0.23944827586206896, 0.23944827586206896, 0.23944827586206896, 0.21705172413793106, 0.23944827586206896, 0.23944827586206896, 0.23944827586206896])
v: 0.271450056712092
sample size: 58000, repeat: 3
OPT error 0.21574137931034487 errors dic

100%|██████████| 10/10 [00:00<00:00, 64.42it/s]

sample size: 59000, repeat: 0





OPT error 0.21364406779661016 errors dict_values([0.23625423728813555, 0.23625423728813555, 0.23625423728813555, 0.23625423728813555, 0.23625423728813555, 0.23625423728813555, 0.21364406779661016, 0.23625423728813555, 0.23625423728813555, 0.23625423728813555])
v: 0.2280563171574632
sample size: 59000, repeat: 1
OPT error 0.2137457627118644 errors dict_values([0.23650847457627122, 0.23650847457627122, 0.23650847457627122, 0.23650847457627122, 0.23650847457627122, 0.23650847457627122, 0.2137457627118644, 0.23650847457627122, 0.23650847457627122, 0.23650847457627122])
v: 0.21983359225241214
sample size: 59000, repeat: 2
OPT error 0.2169152542372882 errors dict_values([0.23954237288135594, 0.23954237288135594, 0.23954237288135594, 0.23954237288135594, 0.23954237288135594, 0.23954237288135594, 0.2169152542372882, 0.23954237288135594, 0.23954237288135594, 0.23954237288135594])
v: 0.2713135868114491
sample size: 59000, repeat: 3
OPT error 0.2159491525423729 errors dict_values([0.2386440677966

100%|██████████| 10/10 [00:00<00:00, 60.98it/s]

sample size: 60000, repeat: 0





OPT error 0.21376666666666666 errors dict_values([0.2364666666666667, 0.2364666666666667, 0.2364666666666667, 0.2364666666666667, 0.2364666666666667, 0.2364666666666667, 0.21376666666666666, 0.2364666666666667, 0.2364666666666667, 0.2364666666666667])
v: 0.2281789160275197
sample size: 60000, repeat: 1
OPT error 0.21348333333333336 errors dict_values([0.2360833333333333, 0.2360833333333333, 0.2360833333333333, 0.2360833333333333, 0.2360833333333333, 0.2360833333333333, 0.21348333333333336, 0.2360833333333333, 0.2360833333333333, 0.2360833333333333])
v: 0.2195711628738811
sample size: 60000, repeat: 2
OPT error 0.21696666666666664 errors dict_values([0.23965000000000003, 0.23965000000000003, 0.23965000000000003, 0.23965000000000003, 0.23965000000000003, 0.23965000000000003, 0.21696666666666664, 0.23965000000000003, 0.23965000000000003, 0.23965000000000003])
v: 0.2713649992408276
sample size: 60000, repeat: 3
OPT error 0.21606666666666663 errors dict_values([0.23876666666666668, 0.238766

In [7]:
# Value returned by getConvergenceSampleNum: the first replicable sample size,
# or -1 if none in the searched range was replicable (an int, despite the name).
# NOTE(review): this cell's execution count (In [7]) is lower than the search
# cell's (In [10]), so the value shown here came from an earlier run -- re-run
# top to bottom before trusting it.
ans_dict

-1