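Reference paper: *Stability is Stable: Connections between Replicability, Privacy, and Adaptive Generalization* (Bun et al., STOC 2023)
Experiment: vary the replicability parameter ρ (rho) and measure the minimum sample size required for replicability.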

func getConvergenceSampleNum(hyperparams)
    for sample_num in range(min_subset_size, max_subset_size, step):
        for _ in range(repeat_num):
            randomly draw a subset of the dataset of size sample_num
            use Algorithm 10 to get a model
        check whether the models obtained are 'replicable' with respect to each other, according to the hyperparams (see the definition on page 14 of the paper)
        if converged, return the current sample_num

func experiment(): 
# Vary rho and measure the minimum sample size required for replicability.
# To keep it simple, fix the values of the other hyperparameters, such as alpha and beta.
    rhos, sample_nums, theoretical_sample_nums = [], [], []
    for rho in range(min_rho, max_rho, step):
        sample_num = getConvergenceSampleNum(rho)
        theoretical_sample_num = getTheoreticalSampleNum(rho) # computed according to Algorithm 10 in the paper
        print("rho: ", rho, "Sample size: ", sample_num)
        rhos.append(rho); sample_nums.append(sample_num); theoretical_sample_nums.append(theoretical_sample_num)
    plot(rhos, theoretical_sample_nums) # draw the curve: x-axis is rho, y-axis is the theoretical sample size
    plot(rhos, sample_nums) # draw the curve: x-axis is rho, y-axis is the empirical minimum sample size

In [1]:
from itertools import combinations

import numpy as np
import pandas as pd
from sklearn.tree import export_text

import config
import Algorithm10 as a10

##### Config Variables #####
np = <module 'numpy' from '/opt/anaconda3/envs/learn/lib/python3.12/site-packages/numpy/__init__.py'>
dataset_path = ./dataset/Invistico_Airline.csv
model_path = ./models/
max_depth = 3
random_seed = 42
selected_features = ['Class', 'Seat comfort', 'Food and drink', 'Cleanliness', 'satisfaction']
rho = 0.9
alpha = 0.3
beta = 0.2
num_H = 10
m = 100
m_up_bound = 31.537459194175696
tau_up_bound = 0.11725951011387799
tau = 1.1725951011387799e-05
############################


In [2]:
# Load the raw dataset from the path configured in the `config` module, so this
# exploratory view and the experiment (which passes config.dataset_path to
# a10.load_full_dataset) always read the same file.
data = pd.read_csv(config.dataset_path)
data.head()

Unnamed: 0,satisfaction,Customer Type,Age,Type of Travel,Class,Flight Distance,Seat comfort,Departure/Arrival time convenient,Food and drink,Gate location,...,Online support,Ease of Online booking,On-board service,Leg room service,Baggage handling,Checkin service,Cleanliness,Online boarding,Departure Delay in Minutes,Arrival Delay in Minutes
0,satisfied,Loyal Customer,65,Personal Travel,Eco,265,0,0,0,2,...,2,3,3,0,3,5,3,2,0,0.0
1,satisfied,Loyal Customer,47,Personal Travel,Business,2464,0,0,0,3,...,2,3,4,4,4,2,3,2,310,305.0
2,satisfied,Loyal Customer,15,Personal Travel,Eco,2138,0,0,0,3,...,2,2,3,3,4,4,4,2,0,0.0
3,satisfied,Loyal Customer,60,Personal Travel,Eco,623,0,0,0,3,...,3,1,1,0,1,4,1,3,0,0.0
4,satisfied,Loyal Customer,70,Personal Travel,Eco,354,0,0,0,3,...,4,2,2,0,2,4,2,5,0,0.0


In [3]:
# Dimensions of the raw dataset as (rows, columns).
data.shape

(129880, 22)

In [4]:
# function to check if two decision trees are equal
def are_trees_equal(tree1, tree2):
    # Check that both trees are fitted
    if not hasattr(tree1, 'tree_') or not hasattr(tree2, 'tree_'):
        raise ValueError("Both trees must be fitted before comparison.")

    # Compare parameters
    if tree1.get_params() != tree2.get_params():
        return False

    t1 = tree1.tree_
    t2 = tree2.tree_

    # Compare structure and splitting rules
    attributes_to_check = [
        'children_left', 'children_right',
        'feature', 'threshold',
        'impurity', 'n_node_samples', 'weighted_n_node_samples',
        'value'
    ]

    for attr in attributes_to_check:
        if not np.array_equal(getattr(t1, attr), getattr(t2, attr)):
            return False

    return True


In [5]:
def getConvergenceSampleNum(min_subset_size, max_subset_size, repeat_num, rho, sample_size_step=1, return_probs=False):
    """Find the smallest sample size at which the learner looks empirically rho-replicable.

    For each size in ``range(min_subset_size, max_subset_size + 1, sample_size_step)``:
    build candidate trees ``H`` with ``a10.build_candidate_trees``, run
    ``a10.replicable_learner`` ``repeat_num`` times (run ``i`` uses seed
    ``config.random_seed + i``), and estimate the probability that two runs
    return the same tree as the fraction of agreeing pairs (``are_trees_equal``).
    The search stops at the first size whose agreement rate is >= ``1 - rho``.

    Args:
        min_subset_size: first sample size to try.
        max_subset_size: last sample size to try (inclusive).
        repeat_num: number of learner runs per sample size; must be >= 2 so at
            least one pair of runs can be compared.
        rho: replicability parameter; a size is accepted when the pairwise
            agreement rate is >= 1 - rho.
        sample_size_step: step between consecutive sample sizes.
        return_probs: if True, also return the ``{sample_size: agreement_rate}``
            dict for every size evaluated (e.g. for plotting).

    Returns:
        The first accepted sample size, or -1 if no size in the range is
        accepted (including when the range is empty). When ``return_probs`` is
        True, a ``(sample_size, probs)`` tuple instead.

    Raises:
        ValueError: if ``repeat_num < 2`` (the pair count would be zero).
    """
    if repeat_num < 2:
        raise ValueError(f"repeat_num must be at least 2 to compare runs, got {repeat_num}")

    probs = {}

    def _result(size):
        # Keep the default return type (a bare int) for existing callers.
        return (size, probs) if return_probs else size

    prob = None  # stays None if the sample-size range is empty
    X, y = a10.load_full_dataset(config.dataset_path, random_state=config.random_seed)
    for sample_size in range(min_subset_size, max_subset_size + 1, sample_size_step):
        # Candidate trees are rebuilt for every sample size, always with the same seed.
        H = a10.build_candidate_trees(X, y, sample_size, max_depth=config.max_depth, num_trees=config.num_H, random_state=config.random_seed)

        # NOTE(review): rho-replicability in the reference paper compares two runs that
        # share internal randomness but see independent samples. Here every repeat gets
        # a different seed -- confirm what `random_seed` controls inside
        # a10.replicable_learner (sampling, internal randomness, or both).
        trees = []
        for i in range(repeat_num):
            print(f"sample size: {sample_size}, repeat: {i}")
            trees.append(a10.replicable_learner(X, y, H, sample_size, random_seed=config.random_seed + i))

        # Empirical Pr[two runs output the same tree] = fraction of agreeing pairs.
        pairs = list(combinations(trees, 2))
        prob = sum(are_trees_equal(a, b) for a, b in pairs) / len(pairs)
        probs[sample_size] = prob
        if prob >= 1 - rho:
            print(f"replicable at sample size: {sample_size}, prob: {prob}")
            return _result(sample_size)

    print(f"not replicable at sample size between {min_subset_size} and {max_subset_size}, prob: {prob}")
    return _result(-1)

            
        
    

In [13]:
# Compare the theoretical bound with the empirical search: sizes 10..100 in
# steps of 10, 10 learner runs per size, accepted at agreement rate >= 1 - config.rho.
# NOTE: despite its name, `ans_dict` holds an int (the sample size, or -1 if none).
print("theoretical sample size: ", config.m_up_bound)
ans_dict = getConvergenceSampleNum(
    min_subset_size=10,
    max_subset_size=100,
    repeat_num=10,
    rho=config.rho,
    sample_size_step=10,
)

theoretical sample size:  31.537459194175696


100%|██████████| 10/10 [00:00<00:00, 167.40it/s]

sample size: 10, repeat: 0





OPT error 0.19999999999999996 errors dict_values([0.19999999999999996, 0.5, 0.5, 0.4, 0.30000000000000004, 0.7, 0.4, 0.5, 0.4, 0.5])
v: 0.26142859029269766
sample size: 10, repeat: 1
OPT error 0.0 errors dict_values([0.4, 0.0, 0.19999999999999996, 0.7, 0.19999999999999996, 0.5, 0.09999999999999998, 0.30000000000000004, 0.5, 0.30000000000000004])
v: 0.0037002121061974124
sample size: 10, repeat: 2
OPT error 0.0 errors dict_values([0.19999999999999996, 0.4, 0.0, 0.09999999999999998, 0.19999999999999996, 0.19999999999999996, 0.30000000000000004, 0.5, 0.19999999999999996, 0.4])
v: 0.039245789718678016
sample size: 10, repeat: 3
OPT error 0.0 errors dict_values([0.19999999999999996, 0.8, 0.4, 0.0, 0.5, 0.30000000000000004, 0.19999999999999996, 0.5, 0.6, 0.4])
v: 0.02612535442814468
sample size: 10, repeat: 4
OPT error 0.0 errors dict_values([0.09999999999999998, 0.19999999999999996, 0.30000000000000004, 0.4, 0.0, 0.7, 0.30000000000000004, 0.30000000000000004, 0.19999999999999996, 0.30000000

100%|██████████| 10/10 [00:00<00:00, 143.18it/s]


sample size: 20, repeat: 0
OPT error 0.09999999999999998 errors dict_values([0.09999999999999998, 0.5, 0.5, 0.35, 0.44999999999999996, 0.30000000000000004, 0.25, 0.25, 0.5, 0.35])
v: 0.16142859029269768
sample size: 20, repeat: 1
OPT error 0.050000000000000044 errors dict_values([0.5, 0.050000000000000044, 0.25, 0.44999999999999996, 0.55, 0.44999999999999996, 0.25, 0.30000000000000004, 0.5, 0.25])
v: 0.05370021210619746
sample size: 20, repeat: 2
OPT error 0.050000000000000044 errors dict_values([0.35, 0.35, 0.050000000000000044, 0.25, 0.44999999999999996, 0.35, 0.050000000000000044, 0.09999999999999998, 0.30000000000000004, 0.25])
v: 0.08924578971867805
sample size: 20, repeat: 3
OPT error 0.0 errors dict_values([0.30000000000000004, 0.4, 0.35, 0.0, 0.30000000000000004, 0.4, 0.35, 0.30000000000000004, 0.44999999999999996, 0.35])
v: 0.02612535442814468
sample size: 20, repeat: 4
OPT error 0.09999999999999998 errors dict_values([0.4, 0.30000000000000004, 0.4, 0.35, 0.09999999999999998, 

100%|██████████| 10/10 [00:00<00:00, 111.46it/s]


sample size: 30, repeat: 0
OPT error 0.1333333333333333 errors dict_values([0.1333333333333333, 0.2666666666666667, 0.2666666666666667, 0.5, 0.4, 0.3666666666666667, 0.30000000000000004, 0.2666666666666667, 0.3666666666666667, 0.30000000000000004])
v: 0.194761923626031
sample size: 30, repeat: 1
OPT error 0.1333333333333333 errors dict_values([0.43333333333333335, 0.1333333333333333, 0.1333333333333333, 0.4666666666666667, 0.4666666666666667, 0.33333333333333337, 0.16666666666666663, 0.23333333333333328, 0.33333333333333337, 0.2666666666666667])
v: 0.13703354543953072
sample size: 30, repeat: 2
OPT error 0.06666666666666665 errors dict_values([0.5666666666666667, 0.06666666666666665, 0.06666666666666665, 0.3666666666666667, 0.5, 0.33333333333333337, 0.09999999999999998, 0.09999999999999998, 0.2666666666666667, 0.33333333333333337])
v: 0.10591245638534466
sample size: 30, repeat: 3
OPT error 0.1333333333333333 errors dict_values([0.43333333333333335, 0.33333333333333337, 0.3333333333333

100%|██████████| 10/10 [00:00<00:00, 160.87it/s]


sample size: 40, repeat: 0
OPT error 0.17500000000000004 errors dict_values([0.17500000000000004, 0.35, 0.17500000000000004, 0.5, 0.375, 0.35, 0.25, 0.25, 0.4, 0.275])
v: 0.23642859029269775
sample size: 40, repeat: 1
OPT error 0.19999999999999996 errors dict_values([0.525, 0.19999999999999996, 0.19999999999999996, 0.35, 0.275, 0.25, 0.19999999999999996, 0.30000000000000004, 0.32499999999999996, 0.22499999999999998])
v: 0.20370021210619738
sample size: 40, repeat: 2
OPT error 0.050000000000000044 errors dict_values([0.6, 0.19999999999999996, 0.050000000000000044, 0.35, 0.44999999999999996, 0.22499999999999998, 0.07499999999999996, 0.19999999999999996, 0.22499999999999998, 0.30000000000000004])
v: 0.08924578971867805
sample size: 40, repeat: 3
OPT error 0.19999999999999996 errors dict_values([0.6, 0.375, 0.35, 0.19999999999999996, 0.35, 0.35, 0.30000000000000004, 0.30000000000000004, 0.42500000000000004, 0.275])
v: 0.22612535442814463
sample size: 40, repeat: 4
OPT error 0.125 errors di

100%|██████████| 10/10 [00:00<00:00, 141.99it/s]

sample size: 50, repeat: 0
OPT error 0.21999999999999997 errors dict_values([0.28, 0.28, 0.33999999999999997, 0.21999999999999997, 0.30000000000000004, 0.33999999999999997, 0.26, 0.24, 0.4, 0.28])
v: 0.2814285902926977
sample size: 50, repeat: 1
OPT error 0.16000000000000003 errors dict_values([0.4, 0.16000000000000003, 0.30000000000000004, 0.19999999999999996, 0.30000000000000004, 0.21999999999999997, 0.18000000000000005, 0.18000000000000005, 0.36, 0.24])
v: 0.16370021210619745
sample size: 50, repeat: 2
OPT error 0.09999999999999998 errors dict_values([0.48, 0.09999999999999998, 0.16000000000000003, 0.19999999999999996, 0.31999999999999995, 0.24, 0.09999999999999998, 0.09999999999999998, 0.36, 0.30000000000000004])
v: 0.139245789718678
sample size: 50, repeat: 3
OPT error 0.19999999999999996 errors dict_values([0.31999999999999995, 0.30000000000000004, 0.4, 0.19999999999999996, 0.31999999999999995, 0.36, 0.28, 0.33999999999999997, 0.4, 0.30000000000000004])
v: 0.22612535442814463
sam




OPT error 0.14 errors dict_values([0.38, 0.28, 0.31999999999999995, 0.21999999999999997, 0.33999999999999997, 0.26, 0.26, 0.21999999999999997, 0.31999999999999995, 0.14])
v: 0.16339137108895052


100%|██████████| 10/10 [00:00<00:00, 110.03it/s]


sample size: 60, repeat: 0
OPT error 0.2666666666666667 errors dict_values([0.2833333333333333, 0.2833333333333333, 0.2833333333333333, 0.2666666666666667, 0.4, 0.35, 0.2833333333333333, 0.2666666666666667, 0.35, 0.33333333333333337])
v: 0.3280952569593644
sample size: 60, repeat: 1
OPT error 0.18333333333333335 errors dict_values([0.3666666666666667, 0.18333333333333335, 0.18333333333333335, 0.25, 0.33333333333333337, 0.3833333333333333, 0.18333333333333335, 0.19999999999999996, 0.3666666666666667, 0.23333333333333328])
v: 0.18703354543953077
sample size: 60, repeat: 2
OPT error 0.1333333333333333 errors dict_values([0.4666666666666667, 0.15000000000000002, 0.1333333333333333, 0.31666666666666665, 0.30000000000000004, 0.23333333333333328, 0.1333333333333333, 0.18333333333333335, 0.2833333333333333, 0.33333333333333337])
v: 0.1725791230520113
sample size: 60, repeat: 3
OPT error 0.16666666666666663 errors dict_values([0.33333333333333337, 0.30000000000000004, 0.30000000000000004, 0.166

100%|██████████| 10/10 [00:00<00:00, 169.79it/s]


sample size: 70, repeat: 0
OPT error 0.22857142857142854 errors dict_values([0.27142857142857146, 0.2857142857142857, 0.30000000000000004, 0.22857142857142854, 0.2857142857142857, 0.2857142857142857, 0.30000000000000004, 0.24285714285714288, 0.37142857142857144, 0.2857142857142857])
v: 0.29000001886412624
sample size: 70, repeat: 1
OPT error 0.18571428571428572 errors dict_values([0.3142857142857143, 0.18571428571428572, 0.18571428571428572, 0.2571428571428571, 0.19999999999999996, 0.2857142857142857, 0.18571428571428572, 0.19999999999999996, 0.3142857142857143, 0.2571428571428571])
v: 0.18941449782048314
sample size: 70, repeat: 2
OPT error 0.15714285714285714 errors dict_values([0.34285714285714286, 0.17142857142857137, 0.15714285714285714, 0.2857142857142857, 0.30000000000000004, 0.19999999999999996, 0.15714285714285714, 0.2142857142857143, 0.2142857142857143, 0.3142857142857143])
v: 0.19638864686153515
sample size: 70, repeat: 3
OPT error 0.18571428571428572 errors dict_values([0.3

In [14]:
# Smallest sample size getConvergenceSampleNum found replicable (-1 if none in range).
ans_dict

70