In [41]:
from sklearn.datasets import load_iris
import pandas as pd
import numpy as np

In [None]:
def entropy(probability):
    result = 0
    for p in probability:
        result -= p * np.log2(p)
    return result

def gini(probability):
    result = 1
    for p in probability:
        result -= p**2
    return result

class Node:
    def __init__(self, df, target):
        self.df = df # df belonging to the node.
        self.target = target # Target attribute name.
        self.columns = df.columns.to_list() # Attribues/Columns of the node/df.
        self.columns.remove(target)
        self.left = None # Left child node.
        self.right = None # Right child node.

    def findSplitPoint(self, attr_name):
        df_sort = self.df.sort_values(by=attr_name) # Sort df based on the selected attribute.

        # Initialization.
        split_point = []
        label_pre = df_sort[self.target][0]
        attr_pre = df_sort[attr_name][0]

        # Loop through target labels sorted based on the selected attribute.
        for i, label in enumerate(df_sort[self.target]):
            if i > 0 and df_sort[attr_name].iloc[i] != df_sort[attr_name].iloc[i-1]: # If attribute value has changed.
                attr_pre = df_sort[attr_name].iloc[i-1] # Update previous attribute value.
            if label == label_pre: # Continue if the label didn't change.
                continue

            dif = df_sort[attr_name].iloc[i] - attr_pre # Difference of previous and current attribute value.
            split_point_current = df_sort[attr_name].iloc[i] - dif/2 # The average betwee previous and current attribute value.

            if split_point == [] or split_point_current != split_point[-1]: # If the split point has not already occurred.
                split_point.append(split_point_current) # Store the split point.

            label_pre = label # Update previous label.
            
        return df_sort, split_point


In [182]:
iris = load_iris()
df = pd.DataFrame(data=iris.data, columns=iris.feature_names)
df['target'] = iris.target

node = Node(df, 'target')


data, split_point = node.findSplitPoint(node.columns[0])
print(split_point)

[np.float64(4.85), np.float64(4.95), np.float64(5.05), np.float64(5.15), np.float64(5.35), np.float64(5.45), np.float64(5.55), np.float64(5.65), np.float64(5.75), np.float64(5.85), np.float64(5.95), np.float64(6.05), np.float64(6.15), np.float64(6.25), np.float64(6.35), np.float64(6.45), np.float64(6.65), np.float64(6.75), np.float64(6.85), np.float64(6.95), np.float64(7.05)]


In [28]:
print(1- (3/10)**2 - (7/10)**2)

print(1-(3/9)**2 - (6/9)**2)
print(0.42 - (9/10)*0.4444)

0.4200000000000001
0.4444444444444444
0.020039999999999947


In [34]:
print(gini([3/10,7/10]))
print(gini([3/9,6/9]))
print(gini([0.3,0.7]) - (9/10)*gini([3/9,6/9]))

0.4200000000000001
0.4444444444444444
0.02000000000000013
