In [17]:
import numpy as np
import sklearn.datasets

def entropy(y: np.ndarray):
    """Return the Shannon entropy (natural log, i.e. in nats) of the values in ``y``.

    The probability of each distinct value is estimated by its relative
    frequency in ``y``.
    """
    _, counts = np.unique(y, return_counts=True)
    # Every count is at least 1, so no probability is 0 and the log is finite.
    probs = counts / len(y)
    weighted_log = probs * np.log(probs)
    return -np.sum(weighted_log)

def gini(y: np.ndarray):
    """Return the Gini impurity ``sum(p * (1 - p))`` of the values in ``y``.

    ``p`` is the relative frequency of each distinct value in ``y``.
    """
    _, counts = np.unique(y, return_counts=True)
    probs = counts / len(y)
    complement = 1 - probs
    return np.sum(probs * complement)

def var(y: np.ndarray):
    """Return the variance of ``y`` as computed by ``np.var`` with its defaults.

    Used as the impurity measure selected by the ``'var'`` criterion.
    """
    variance = np.var(y)
    return variance

def split(X: np.ndarray, y: np.ndarray, threshold: float) -> tuple[np.ndarray, np.ndarray]:
    """Partition ``y`` according to how ``X`` compares with ``threshold``.

    Returns a pair ``(y_above, y_at_or_below)``:

    * first element  -- entries of ``y`` where ``X > threshold`` (more
      precisely, where ``X <= threshold`` is false);
    * second element -- entries of ``y`` where ``X <= threshold``.

    Note the order: the "greater than" part comes first.
    """
    at_or_below = X <= threshold
    above = np.invert(at_or_below)
    y_above = y[above]
    y_at_or_below = y[at_or_below]
    return y_above, y_at_or_below
    

def tree_split(X, y, criterion):
    """Exhaustively search for the single-feature threshold split with the lowest impurity.

    Every value ``X[j, i]`` is tried as a threshold on feature ``i``. The
    targets are partitioned with ``split`` and the candidate is scored by the
    size-weighted sum of the chosen criterion over the two parts. The first
    candidate (in feature-major, then row order) with the strictly lowest
    score wins.

    Candidates that leave one part empty (e.g. the threshold is the column
    maximum) do not separate the data and are skipped; this also avoids
    evaluating the criterion on an empty array.

    Parameters
    ----------
    X : np.ndarray
        2-D array indexed as ``X[sample, feature]``.
    y : np.ndarray
        Targets, one entry per row of ``X``.
    criterion : str
        ``'entropy'``, ``'gini'`` or ``'var'``.

    Returns
    -------
    list[int] or None
        ``[feature_index, sample_index]`` of the best split; the threshold is
        ``X[sample_index, feature_index]``. ``None`` when no candidate puts
        at least one sample on each side (empty ``X``, a single sample, or
        all features constant).

    Raises
    ------
    KeyError
        If ``criterion`` is not one of the supported names.
    """
    criter_dict = {'entropy': entropy, 'gini': gini, 'var': var}
    H = criter_dict[criterion]

    n_samples = X.shape[0]
    n_features = X.shape[1]

    search_idx = None
    min_score = np.inf
    for i in range(n_features):
        column = X[:, i]  # loop-invariant for the inner loop
        for j in range(n_samples):
            threshold = column[j]
            # split() returns the "> threshold" part first, then "<= threshold".
            above, below = split(column, y, threshold)

            n_above = above.shape[0]
            n_below = below.shape[0]
            if n_above == 0 or n_below == 0:
                # Degenerate candidate: nothing is separated.
                continue

            score = H(above) * n_above / n_samples + \
                    H(below) * n_below / n_samples

            if score < min_score:
                min_score = score
                search_idx = [i, j]
    return search_idx
                

In [21]:
# Synthetic classification data: 100 samples with 20 features, all informative.
# A fixed random_state keeps the data -- and every result derived from it --
# identical across kernel restarts.
X, y = sklearn.datasets.make_classification(
    n_samples=100,
    n_features=20,
    n_informative=20,
    n_redundant=0,
    n_clusters_per_class=2,
    random_state=0,
)
print(X,y)

[[-3.98582635  0.75764539  3.25288892 ...  2.91146758  4.57137698
   2.15453668]
 [-1.30624563  2.41190754  1.65850954 ... -4.69231805 -0.02520994
   0.2523196 ]
 [-1.30663851 -3.58519116  2.0271222  ...  1.02812272 -1.63835972
   0.21451276]
 ...
 [ 0.2770626   0.00925033 -0.3426235  ...  1.79838569 -1.80086054
   2.54411797]
 [ 3.14466674  2.51669085  3.18947342 ... -2.64706015 -2.18573452
   3.43117551]
 [-1.6955388  -0.29775785 -2.76471767 ...  1.18399405 -3.31532382
   0.53611124]] [1 1 1 0 1 1 0 0 0 0 1 1 1 1 0 0 0 1 0 0 0 1 0 1 1 0 1 1 0 0 0 1 0 1 1 0 1
 1 0 1 1 0 0 0 0 1 1 0 1 0 1 0 1 0 1 1 0 0 0 1 1 0 1 1 0 1 0 1 0 0 1 0 0 1
 1 0 1 1 0 0 1 0 1 0 0 1 0 0 0 0 1 1 0 0 1 1 1 1 1 0]


In [22]:
# Best entropy split on the generated data; the result is [feature_index, row_index].
# NOTE(review): the saved output [79, 5] cannot come from the current code
# (the first entry must be < 20 features) -- it looks stale; re-run from a fresh kernel.
tree_split(X, y, criterion='entropy')

[79, 5]