# In [4]:
import math

class DecisionTree:
    """Greedy binary decision tree that partitions data by information gain.

    At every node each feature value present in the data is tried as a
    threshold (``value <= threshold`` goes left, everything else goes right)
    and the split with the largest entropy-based information gain is kept.
    The two halves are then split recursively until ``max_depth`` levels of
    splits have been made or no split improves purity.

    The fitted result is a nested list: an internal node is the two-element
    list ``[left, right]`` and a leaf is the list of examples that reached it.
    """

    def __init__(self, criterion='entropy', max_depth=5):
        """Store the hyper-parameters; no work is done until ``fit`` is called.

        Args:
            criterion: Split-quality measure. Only ``'entropy'`` is supported.
            max_depth: Maximum number of nested split levels.
        """
        self.name = 'Decision tree model'
        self.criterion = criterion
        self.max_depth = max_depth
        self.tree = None  # nested partition of X, populated by fit()

    def fit(self, X, y):
        """Build the tree for examples ``X`` with class labels ``y``.

        Args:
            X: Sequence of examples, each a sequence of comparable feature
                values. Every example must have the same number of features.
            y: Sequence of hashable class labels, one per example.

        Raises:
            ValueError: If the criterion is not ``'entropy'`` or if ``X`` and
                ``y`` differ in length.
        """
        if self.criterion != 'entropy':
            raise ValueError(
                "Unsupported criterion {!r}; only 'entropy' is implemented"
                .format(self.criterion)
            )
        if len(X) != len(y):
            raise ValueError('X and y must contain the same number of examples')
        # The tree is greedy: each node keeps the split with the largest
        # information gain without looking ahead.
        self.tree = self.get_sublists(X, y, 0)

    def get_sublists(self, X, y, depth_counter):
        """Recursively split ``X`` on the best (feature, threshold) pair.

        Args:
            X: Examples reaching this node.
            y: Class labels aligned with ``X``.
            depth_counter: Number of split levels already above this node.

        Returns:
            ``[left, right]`` when a useful split exists, where each side is
            either a further nested pair or a leaf. A leaf is returned as the
            list of examples when the node is empty, has a single example, is
            already pure, or no candidate split has positive gain.
        """
        total = len(X)
        parent_entropy = self.get_entropy(self._class_distribution(y))
        if total < 2 or parent_entropy == 0:
            return list(X)

        best_information_gain = 0.0
        best_split = None  # (feature index, threshold value)
        tried = set()  # identical (feature, value) candidates are scored once

        for example in X:  # every example is tried as the pivot row
            for label, threshold in enumerate(example):
                if (label, threshold) in tried:
                    continue
                tried.add((label, threshold))

                # Class labels falling on either side of the threshold.
                below, above = [], []
                for row, class_ in zip(X, y):
                    (below if row[label] <= threshold else above).append(class_)
                if not below or not above:
                    continue  # a one-sided split separates nothing

                # Information gain: parent entropy minus the size-weighted
                # entropy of the two children.
                children_entropy = (
                    len(below) / total
                    * self.get_entropy(self._class_distribution(below))
                    + len(above) / total
                    * self.get_entropy(self._class_distribution(above))
                )
                information_gain = parent_entropy - children_entropy

                if information_gain > best_information_gain:
                    best_information_gain = information_gain
                    best_split = (label, threshold)

        if best_split is None:
            return list(X)

        best_label, best_pivot = best_split
        list_1_X, list_1_y, list_2_X, list_2_y = self.build_sublists(
            X, best_pivot, best_label, y
        )
        depth_counter += 1

        if depth_counter < self.max_depth:
            list_1_X = self.get_sublists(list_1_X, list_1_y, depth_counter)
            list_2_X = self.get_sublists(list_2_X, list_2_y, depth_counter)

        return [list_1_X, list_2_X]

    def build_sublists(self, X, pivot, label, y=None):
        """Partition ``X`` (and ``y``) around a threshold on one feature.

        Args:
            X: Examples to partition.
            pivot: Threshold *value*; examples with ``X[i][label] <= pivot``
                go into the first list, the rest into the second.
            label: Index of the feature to compare.
            y: Optional class labels aligned with ``X``. When omitted, the
                two returned label lists are empty.

        Returns:
            ``(list_1_X, list_1_y, list_2_X, list_2_y)``.
        """
        list_1_X, list_1_y = [], []
        list_2_X, list_2_y = [], []
        for index, example in enumerate(X):
            if example[label] <= pivot:
                target_X, target_y = list_1_X, list_1_y
            else:
                target_X, target_y = list_2_X, list_2_y
            target_X.append(example)
            if y is not None:
                target_y.append(y[index])

        return list_1_X, list_1_y, list_2_X, list_2_y

    def get_entropy(self, distribution_of_classes):
        """Return the Shannon entropy (in bits) of a class distribution.

        Args:
            distribution_of_classes: Dict mapping each class label to its
                count, plus a ``'total'`` key holding the sum of the counts.

        Returns:
            ``-sum(p * log2(p))`` over classes with a non-zero count, or 0
            when the distribution is empty.
        """
        total = distribution_of_classes['total']
        if not total:
            return 0.0  # an empty side carries no uncertainty

        entropy = 0.0
        for class_, count in distribution_of_classes.items():
            if class_ == 'total' or not count:
                continue  # p * log2(p) tends to 0 as p tends to 0
            probability_of_class = count / total
            entropy -= probability_of_class * math.log2(probability_of_class)

        return entropy

    @staticmethod
    def _class_distribution(labels):
        """Count labels into the ``{class: count, 'total': n}`` form used by get_entropy."""
        # NOTE: a class label equal to the string 'total' would collide with
        # the bookkeeping key.
        distribution = {'total': 0}
        for class_ in labels:
            distribution[class_] = distribution.get(class_, 0) + 1
            distribution['total'] += 1
        return distribution