In [13]:
import pandas as pd
import numpy as np
import scipy.optimize as opt

In [14]:
class Leaf:
    def __init__(self,value):
        self.value = value

In [15]:
class Node:
    def __init__(self,branches,attribute,threshold):
        self.branches = branches
        self.threshold = threshold
        self.attribute = attribute
        
    def get(self,df):
        return self.branches[0] if df[self.attribute] < self.threshold else self.branches[1]
        

In [16]:
class Tree:
    def __init__(self,root):
        self.root = root
        
    def predict(self,x):
        item = self.root
        while isinstance(item,Node):
            item = item.get(x)
        return item

In [17]:
r=Node([Leaf('young'),Leaf('old')],"age",18)
t=Tree(r)
print(t.predict({"age":2}).value)

young


In [18]:
df=pd.read_csv("iris.csv")

In [19]:
print(t.predict({"age":20}).value)

old


In [45]:
class CART:
    def __init__(self,df,y_name,X_names):
        self.df = df
        self.y_name = y_name
        self.X_names = X_names
        self.tree = None
        
    def create_tree(self):
        root = self._node_or_leaf(self.df)
        self.tree = Tree(root)
        return self.tree
    
    def _gini_impurity(self, df):
        unique, counts = np.unique(df[self.y_name].values, return_counts=True)
        N = df[self.y_name].values.ravel().size
        p = counts/N
        #print(unique)
        #print(p)
        return 1. - np.sum(p**2)
    
    def _opt_fun(self,df,split_name):
        def fun(x):
            split_df = [df[df[split_name]<x],
                        df[df[split_name]>=x]]
            return self._loss(split_df[0]) + self._loss(split_df[1])
        return fun
        
    def _node_or_leaf(self,df,loss_parent=0.99):
        loss_best, split_df, split_threshold, split_name = self._loss_best(df)
        print(f"Computed split:\nloss: {loss_best:.2f} (parent: {loss_parent:.2f})\nattribute: {split_name}\nthreshold: {split_threshold}\ncount: {[len(df_.index) for df_ in split_df]}")
        if loss_best < loss_parent and loss_best > 0.:
            branches = []
            for i in range(2):
                branches.append(self._node_or_leaf(split_df[i],loss_parent=loss_best))
            item = Node(branches,split_name,split_threshold)
            print(f"\n * creating Node({split_name}, {split_threshold})")
        else:
            unique, counts = np.unique(df[self.y_name].values,return_counts=True)
            print(unique, counts)
            sort_ind = np.argsort(-counts)
            value = unique[sort_ind[0]]
            item = Leaf(value)
            print(f"\n * creating Leaf({value}, N={len(df.index)})")
        return item
    
    def _loss_best(self,df):
        loss0 = 10
        for name in self.X_names:
            #split_threshold_ = np.median(df[name].v
            res = opt.minimize_scalar(self._opt_fun(df,name),bounds=(df[name].min(),df[name].max()),method="bounded")
            split_threshold_ = res.x
            split_df_ = [df[df[name]<split_threshold_],
                        df[df[name]>=split_threshold_]]
            #loss = self._loss(split_df_[0]) + self._loss(split_df_[1])
            loss = res.fun
            if loss < loss0:
                loss0 = loss
                split_threshold = split_threshold_
                split_df = split_df_
                split_name = name
                
        #print(loss0)
                
        return loss0, split_df, split_threshold, split_name
    
    def _loss(self,df):
        return self._gini_impurity(df)
            
        
        

In [46]:
df.columns
X_names=["sepal_length","sepal_width","petal_length","petal_width"]
df[X_names]

Unnamed: 0,sepal_length,sepal_width,petal_length,petal_width
0,5.1,3.5,1.4,0.2
1,4.9,3.0,1.4,0.2
2,4.7,3.2,1.3,0.2
3,4.6,3.1,1.5,0.2
4,5.0,3.6,1.4,0.2
...,...,...,...,...
145,6.7,3.0,5.2,2.3
146,6.3,2.5,5.0,1.9
147,6.5,3.0,5.2,2.0
148,6.2,3.4,5.4,2.3


In [47]:
df.iloc[0]

sepal_length       5.1
sepal_width        3.5
petal_length       1.4
petal_width        0.2
species         setosa
Name: 0, dtype: object

In [48]:
c = CART(df,"species",X_names)
c.create_tree()

Computed split:
loss: 0.50 (parent: 0.99)
attribute: petal_length
threshold: 2.3928050106119554
count: [50, 100]
Computed split:
loss: 0.00 (parent: 0.50)
attribute: sepal_length
threshold: 5.799994478069624
count: [49, 1]
['setosa'] [50]

 * creating Leaf(setosa, N=50)
Computed split:
loss: 0.21 (parent: 0.50)
attribute: petal_width
threshold: 1.7120403835448885
count: [54, 46]
Computed split:
loss: 0.18 (parent: 0.21)
attribute: petal_length
threshold: 3.3094496107966327
count: [3, 51]
Computed split:
loss: 0.00 (parent: 0.18)
attribute: sepal_length
threshold: 5.099995237708634
count: [2, 1]
['versicolor'] [3]

 * creating Leaf(versicolor, N=3)
Computed split:
loss: 0.19 (parent: 0.18)
attribute: sepal_width
threshold: 3.1605583854679087
count: [47, 4]
['versicolor' 'virginica'] [46  5]

 * creating Leaf(versicolor, N=51)

 * creating Node(petal_length, 3.3094496107966327)
Computed split:
loss: 0.05 (parent: 0.21)
attribute: sepal_width
threshold: 3.454283080105119
count: [43, 3]
Co

<__main__.Tree at 0x4990040>

In [49]:
c.tree.predict(df.iloc[0]).value

'setosa'