In [20]:
import numpy as np
import math

In [21]:
def gini(counts):
    """Calculate the Gini impurity of a node.

    Args:
        counts: The number of samples in each class. Any array-like of
            shape (K,) is accepted (NumPy array, list or tuple).

    Returns:
        ``1 - sum(p_k ** 2)`` where ``p_k`` is the fraction of samples in
        class ``k``. Returns 0.0 when there are no samples at all.
    """
    # Coerce so that plain lists/tuples work, not only NumPy arrays.
    counts = np.asarray(counts, dtype=float)
    total = counts.sum()
    if total == 0:
        # Empty node: treat as pure instead of dividing by zero.
        return 0.0
    p = counts / total
    return 1.0 - np.sum(p * p)

def entropy(counts):
    """Calculate the Shannon entropy (in bits) of a node.

    Args:
        counts: The number of samples in each class. Any array-like of
            shape (K,) is accepted (NumPy array, list or tuple).

    Returns:
        ``-sum(p_k * log2(p_k))`` over the classes with ``p_k > 0``.
        Returns 0.0 when there are no samples at all.
    """
    # Coerce so that plain lists/tuples work, not only NumPy arrays.
    counts = np.asarray(counts, dtype=float)
    total = counts.sum()
    if total == 0:
        # Empty node: treat as pure instead of computing 0/0 (NaN + warning).
        return 0.0
    p = counts / total
    p = p[p > 0]   # avoid log(0)
    # Adding 0.0 turns the -0.0 produced for a pure node into +0.0.
    return -np.sum(p * np.log2(p)) + 0.0


In [22]:
def best_split_one_feature(x, y, impurity=None):
    """Find the threshold on one feature that minimizes weighted impurity.

    Candidate thresholds are the midpoints between consecutive distinct
    feature values. A sample goes to the left child when ``x <= t`` and to
    the right child otherwise.

    Args:
        x: 1-D array-like of numeric feature values, in any order (the
            function sorts internally).
        y: Class labels, one per entry of ``x``.
        impurity: Optional callable taking an array of per-class counts and
            returning an impurity score. Defaults to ``gini``.

    Returns:
        A tuple ``(best_t, best_after_imp)`` holding the best threshold and
        the sample-weighted impurity of the two children after the split.
        ``(None, np.inf)`` is returned when no split is possible (fewer than
        two samples, or all feature values identical).

    Raises:
        ValueError: If ``x`` and ``y`` have different lengths.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    N = len(x)
    if len(y) != N:
        raise ValueError(
            f"x and y must have the same length, got {N} and {len(y)}"
        )
    if N <= 1:
        return None, np.inf  # no split

    if impurity is None:
        # Resolved at call time so the default follows the notebook's gini.
        impurity = gini

    # Encode labels to 0...K-1
    classes, y_enc = np.unique(y, return_inverse=True)
    K = len(classes)

    # Sort by feature value so neighbouring entries define the candidates;
    # the stable sort keeps results deterministic for tied values.
    order = np.argsort(x, kind="stable")
    x_sorted = x[order]
    y_sorted = y_enc[order]

    parent_counts = np.bincount(y_enc, minlength=K)
    # Running class counts of the left child (the first i sorted samples).
    left_counts = np.zeros_like(parent_counts)

    best_t = None
    best_after_imp = np.inf  # we minimize weighted impurity after
    for i in range(1, N):
        # Sample i-1 moves to the left child before testing the gap (i-1, i).
        left_counts[y_sorted[i - 1]] += 1

        lo, hi = x_sorted[i - 1], x_sorted[i]
        if lo == hi:
            continue  # no threshold can separate equal values

        t = (lo + hi) / 2   # The threshold is the middle point
        if np.isnan(t):
            continue  # NaN feature values cannot be thresholded
        if t >= hi:
            # The midpoint rounded up onto the upper value; fall back to the
            # lower one so that `x <= t` still selects exactly the left child.
            t = lo

        right_counts = parent_counts - left_counts
        n_left, n_right = i, N - i

        weighted_after = (
            (n_left / N) * impurity(left_counts)
            + (n_right / N) * impurity(right_counts)
        )

        # Strict comparison: on ties the smallest threshold wins.
        if weighted_after < best_after_imp:
            best_after_imp = weighted_after
            best_t = t

    return best_t, best_after_imp



# Toy example: five samples of one feature with three class labels.
x = [1, 3, 5, 6, 8]
y = [0, 1, 0, 0, 2]

# Bare last expression so the cell displays (threshold, weighted impurity).
best_split_one_feature(x, y)

(np.float64(7.0), np.float64(0.30000000000000004))