# Decision Trees

## Decision Node Class

In [None]:
class DecisionNode:
    """Plain container for one node of a binary decision tree.

    The constructor only stores its arguments; no validation is performed.

    Attributes:
        feature_index: Stored as given (presumably the column tested at
            this node -- confirm against the tree builder).
        threshold: Stored as given (presumably the cut-off compared
            against the tested feature -- confirm against the tree builder).
        left: Stored as given; intended to hold the left child node.
        right: Stored as given; intended to hold the right child node.
        value: Stored as given. ``build_tree`` creates leaves by passing
            only this argument, so it carries the label of a leaf.
    """

    def __init__(self, feature_index=None, threshold=None, left=None, right=None, value=None):
        # Split description.
        self.feature_index, self.threshold = feature_index, threshold
        # Child links.
        self.left, self.right = left, right
        # Leaf payload.
        self.value = value

## Gini Impurity Computation Helper Functions 

In [None]:


def compute_Gini_split(y_left=None, y_right=None):
    """Return the size-weighted Gini impurity of a two-way split.

    Each side's impurity (from ``compute_Gini_impurity``) is weighted by
    the fraction of all samples that landed on that side.

    Args:
        y_left: Labels on the left side of the split. ``None`` is treated
            as an empty side.
        y_right: Labels on the right side of the split. ``None`` is
            treated as an empty side.

    Returns:
        float: The weighted impurity, or ``0.0`` when both sides are empty.
    """
    if y_left is None:
        y_left = []
    if y_right is None:
        y_right = []

    n_left = len(y_left)
    n_right = len(y_right)
    n = n_left + n_right

    # With no samples at all there is nothing to weigh; returning early
    # avoids a ZeroDivisionError in the weights below.
    if n == 0:
        return 0.0

    gini_left = compute_Gini_impurity(y_left)
    gini_right = compute_Gini_impurity(y_right)
    return gini_left * (n_left / n) + gini_right * (n_right / n)

def compute_Gini_impurity(y):
    """Return the Gini impurity of a collection of labels.

    The impurity is ``1 - sum(p_k ** 2)``, where ``p_k`` is the fraction
    of entries in ``y`` equal to the k-th distinct label.

    Args:
        y: A sized iterable of hashable labels.

    Returns:
        float: ``0.0`` for an empty or single-class input, larger values
        for more mixed inputs.
    """
    from collections import Counter

    size = len(y)
    # An empty collection contains no mixed labels; report it as pure
    # rather than as maximally impure.
    if size == 0:
        return 0.0

    occurrences = Counter(y)
    return 1.0 - sum((count / size) ** 2 for count in occurrences.values())
    

## build_tree function 

In [None]:
def build_tree(X=None, y=None, depth=0, max_depth=0, min_samples_split=0):
    """Recursively build a binary decision tree by minimising Gini impurity.

    At each node every feature is tried with the midpoints between its
    consecutive distinct values as thresholds. The (feature, threshold)
    pair with the lowest weighted Gini impurity is kept, samples with
    ``sample[feature] < threshold`` go left and the rest go right.

    Args:
        X: Sequence of samples; each sample is indexable by feature
            position (a 2-D array or a list of rows both work).
        y: Sequence of labels, aligned with ``X``.
        depth: Depth of the node being built (0 for the root).
        max_depth: Depth at which splitting stops. With the default of 0
            any non-pure input immediately becomes a majority-label leaf,
            so callers must pass a positive value to get actual splits.
        min_samples_split: Nodes with fewer samples than this become
            majority-label leaves.

    Returns:
        DecisionNode: A leaf (only ``value`` set) or an internal node
        (``feature_index``, ``threshold``, ``left`` and ``right`` set).

    Raises:
        ValueError: If ``y`` is empty.
    """
    if len(y) == 0:
        raise ValueError("build_tree requires at least one sample")

    # Stopping condition: the node is pure.
    if len(set(y)) == 1:
        return DecisionNode(value=y[0])

    # Stopping condition: too few samples or too deep down the tree.
    if len(y) < min_samples_split or depth >= max_depth:
        return DecisionNode(value=most_common_label(y))

    # len(X[0]) rather than X.shape[1]: recursive calls receive plain
    # lists of rows, which have no .shape attribute.
    num_features = len(X[0])

    best_gini = float('inf')
    best_feature = None
    best_threshold = None
    best_splits = None

    for feature in range(num_features):
        # Only distinct values matter: a midpoint between two equal values
        # reproduces a partition that an earlier threshold already covers.
        distinct_values = sorted({sample[feature] for sample in X})

        for threshold in return_thresholds(distinct_values):
            X_left, y_left = [], []
            X_right, y_right = [], []

            for sample, label in zip(X, y):
                if sample[feature] < threshold:
                    X_left.append(sample)
                    y_left.append(label)
                else:
                    X_right.append(sample)
                    y_right.append(label)

            # A split that leaves one side empty separates nothing.
            if not X_left or not X_right:
                continue

            current_gini_split = compute_Gini_split(y_left, y_right)
            if current_gini_split < best_gini:
                best_gini = current_gini_split
                best_feature = feature
                best_threshold = threshold
                best_splits = (X_left, y_left, X_right, y_right)

    # No usable split (e.g. all samples share identical feature values):
    # fall back to a majority-label leaf.
    if best_splits is None:
        return DecisionNode(value=most_common_label(y))

    X_left, y_left, X_right, y_right = best_splits
    left_child = build_tree(
        X_left, y_left,
        depth=depth + 1,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
    )
    right_child = build_tree(
        X_right, y_right,
        depth=depth + 1,
        max_depth=max_depth,
        min_samples_split=min_samples_split,
    )
    return DecisionNode(
        feature_index=best_feature,
        threshold=best_threshold,
        left=left_child,
        right=right_child,
    )

def return_thresholds(samples_features=None):
    """Return the midpoints between each pair of neighbouring values.

    Args:
        samples_features: Sequence of numeric values, taken in the order
            given (the caller is expected to have sorted them).

    Returns:
        list: One midpoint per adjacent pair, so one element fewer than
        the input; empty when the input has fewer than two values.
    """
    # Pair every value with its successor and average the two.
    neighbours = zip(samples_features, samples_features[1:])
    return [(lower + upper) / 2.0 for lower, upper in neighbours]
    

def most_common_label(y=None):
    """Return the label that occurs most often in ``y``.

    When several labels share the highest count, the one that appears
    first in ``y`` is returned.

    Args:
        y: Iterable of hashable labels; must contain at least one element.

    Returns:
        The most frequent label.
    """
    tally = {}
    for item in y:
        if item in tally:
            tally[item] += 1
        else:
            tally[item] = 1

    # max() keeps the first key among equals, and dict keys iterate in
    # insertion order, so ties go to the earliest-seen label.
    return max(tally, key=lambda item: tally[item])