In [1]:
from sklearn.datasets import load_breast_cancer
from collections import Counter
import numpy as np
data = load_breast_cancer()
X,y = data["data"],data["target"]

In [2]:
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(X,y)

In [3]:
#这边为了方便只考虑连续变量了
class CartDT:
    @staticmethod
    def gini(y):
        if type(y) == list or type(y) == np.ndarray:
            y = dict(Counter(y))
        precs = np.array(list(y.values())) / sum(y.values())
        return 1-np.sum(precs**2)

In [4]:
t = [2,3,4]
t = dict(Counter(t))
CartDT.gini(t)

0.6666666666666667

In [5]:
#选择最佳的划分点,即求出fi和fv
def best_split(X,y):
    best_gini = 1e10
    best_fi = None
    best_fv = None
    for fi in range(X.shape[1]):
        #该特征仅有一个取值无法再分
        if len(set(X[:,fi])) == 1:
            continue
        for fv in sorted(set(X[:,fi]))[:-1]:
            y_left = y[X[:,fi] <= fv]
            gini_left = CartDT.gini(y_left)
            y_right = y[X[:,fi] > fv]
            gini_right = CartDT.gini(y_right)
            gini = len(y_left)/len(y)*gini_left + len(y_right)/len(y)*gini_right
#             print(f"fi={fi:.2f} fv={fv:.2f} gini={gini:.2f}")
            if gini < best_gini:
                best_gini = gini
                best_fi = fi
                best_fv = fv
    return best_gini,best_fi,best_fv

In [6]:
best_split(X_test,y_test)

(0.11571227960680289, 7, 0.05102)

In [7]:
def build_tree(X,y):
    #叶子节点的条件1,仅有一个类别
    counts = dict(Counter(y))
    result = max(counts,key=counts.get)
    if len(counts) == 1:
        return {"counts":counts,"result":result}
    
    #叶子节点的条件2，所有特征仅有一个取值
    fcs = [len(Counter(X[:,fi])) for fi in range(X.shape[-1])]
    if sum(fcs) == X.shape[-1]:
        return {"counts":counts,"result":result}
    
    gini,fi,fv = best_split(X,y)
    index_left,index_right = X[:,fi]<=fv,X[:,fi]>fv
    left = build_tree(X[index_left],y[index_left])
    right = build_tree(X[index_right],y[index_right])
    return {"counts":counts,"result":None,"left":left,"right":right,"fi":fi,"fv":fv}

In [8]:
tree = build_tree(X_train,y_train)

In [9]:
tree

{'counts': {0: 157, 1: 269},
 'result': None,
 'left': {'counts': {1: 249, 0: 13},
  'result': None,
  'left': {'counts': {1: 240, 0: 4},
   'result': None,
   'left': {'counts': {1: 240, 0: 2},
    'result': None,
    'left': {'counts': {1: 4, 0: 1},
     'result': None,
     'left': {'counts': {1: 4}, 'result': 1},
     'right': {'counts': {0: 1}, 'result': 0},
     'fi': 1,
     'fv': 18.17},
    'right': {'counts': {1: 236, 0: 1},
     'result': None,
     'left': {'counts': {1: 221}, 'result': 1},
     'right': {'counts': {1: 15, 0: 1},
      'result': None,
      'left': {'counts': {0: 1}, 'result': 0},
      'right': {'counts': {1: 15}, 'result': 1},
      'fi': 21,
      'fv': 33.37},
     'fi': 21,
     'fv': 33.33},
    'fi': 14,
    'fv': 0.00328},
   'right': {'counts': {0: 2}, 'result': 0},
   'fi': 10,
   'fv': 0.6412},
  'right': {'counts': {0: 9, 1: 9},
   'result': None,
   'left': {'counts': {0: 1, 1: 7},
    'result': None,
    'left': {'counts': {1: 7}, 'result': 1}

In [10]:
def predict(X):
    y_pred = []
    for x in X:
        cur = tree
        while cur["result"] == None:
            fi,fv = cur["fi"],cur["fv"]
            cur = cur["left"] if x[fi] <= fv else cur["right"]
        y_pred.append(cur["result"])
    return np.array(y_pred)

In [11]:
np.sum(predict(X_test) == y_test) / len(y_test)

0.9090909090909091

In [12]:
def C(tree):
    leafs = []
    count = 0
    def dfs(tree):
        nonlocal leafs,count
        count += 1
        if tree["result"] != None:
            leafs.append(tree["counts"])
            return
        dfs(tree["left"])
        dfs(tree["right"])
        return
    dfs(tree)
    percs = np.array([sum(leaf.values()) for leaf in leafs])
    percs = percs / percs.sum()
    ginis = np.array([CartDT.gini(leaf) for leaf in leafs])
    c = np.sum(percs * ginis)
    return c,count

In [13]:
c,count = C(tree)

In [14]:
alphas = []
def add_alpha(tree):
    global alphas
    if tree["result"] != None:
        return tree
    gini_one = CartDT.gini(tree["counts"])
    gini_whole,counts = C(tree)
    alpha = (gini_one - gini_whole)/(counts-1)
    alphas.append(alpha)
    tree["alpha"] = alpha
    tree["left"] = add_alpha(tree["left"])
    tree["right"] = add_alpha(tree["right"])
    return tree

In [15]:
tree_alpha = add_alpha(tree)
print(tree_alpha["alpha"])
print(tree_alpha["left"]["alpha"])
print(tree_alpha["left"]["left"]["alpha"])
print(tree_alpha["left"]["left"]["left"]["alpha"])
print(tree_alpha["left"]["left"]["left"]["left"]["alpha"])
print(tree_alpha["left"]["left"]["left"]["left"]["left"]["alpha"])

0.012248393581984452
0.004715634287046211
0.003224939532383764
0.0020490403660952117
0.15999999999999992


KeyError: 'alpha'

In [16]:
subtrees = [tree.copy() for _ in range(len(set(alphas)))]

In [17]:
def inactivity(tree,alpha):
    if tree["result"] != None:
        return tree
    if tree["alpha"] <= alpha:
        tree["result"] = max(tree["counts"],key=tree["counts"].get)
    tree["left"] = inactivity(tree["left"],alpha)
    tree["right"] = inactivity(tree["right"],alpha)
    return tree

In [18]:
for i,alpha in enumerate(sorted(set(alphas))):
    subtrees[i] = inactivity(subtrees[i],alpha)

In [19]:
# 整合所有的东西
class CartDT:
    @staticmethod
    def gini(y):
        if type(y) == list or type(y) == np.ndarray:
            y = dict(Counter(y))
        precs = np.array(list(y.values())) / sum(y.values())
        return 1-np.sum(precs**2)
    
    def best_split(self,X,y):
        best_gini = 1e10
        best_fi = None
        best_fv = None
        for fi in range(X.shape[1]):
            #该特征仅有一个取值无法再分
            if len(set(X[:,fi])) == 1:
                continue
            for fv in sorted(set(X[:,fi]))[:-1]:
                y_left = y[X[:,fi] <= fv]
                gini_left = CartDT.gini(y_left)
                y_right = y[X[:,fi] > fv]
                gini_right = CartDT.gini(y_right)
                gini = len(y_left)/len(y)*gini_left + len(y_right)/len(y)*gini_right
    #             print(f"fi={fi:.2f} fv={fv:.2f} gini={gini:.2f}")
                if gini < best_gini:
                    best_gini = gini
                    best_fi = fi
                    best_fv = fv
        return best_gini,best_fi,best_fv

    def build_tree(self,X,y):
        #叶子节点的条件1,仅有一个类别
        counts = dict(Counter(y))
        result = max(counts,key=counts.get)
        if len(counts) == 1:
            return {"counts":counts,"result":result}

        #叶子节点的条件2，所有特征仅有一个取值
        fcs = [len(Counter(X[:,fi])) for fi in range(X.shape[-1])]
        if sum(fcs) == X.shape[-1]:
            return {"counts":counts,"result":result}

        gini,fi,fv = self.best_split(X,y)
        index_left,index_right = X[:,fi]<=fv,X[:,fi]>fv
        left = self.build_tree(X[index_left],y[index_left])
        right = self.build_tree(X[index_right],y[index_right])
        return {"counts":counts,"result":None,"left":left,"right":right,"fi":fi,"fv":fv}
    
    def fit(self,X,y):
        self.tree = self.build_tree(X,y)
    
    def _C(self,tree):
        leafs = []
        count = 0
        def dfs(tree):
            nonlocal leafs,count
            count += 1
            if tree["result"] != None:
                leafs.append(tree["counts"])
                return
            dfs(tree["left"])
            dfs(tree["right"])
            return
        dfs(tree)
        percs = np.array([sum(leaf.values()) for leaf in leafs])
        percs = percs / percs.sum()
        ginis = np.array([CartDT.gini(leaf) for leaf in leafs])
        c = np.sum(percs * ginis)
        return c,count
    
    def _add_alpha(self,tree):
        if tree["result"] != None:
            return tree
        gini_one = CartDT.gini(tree["counts"])
        gini_whole,counts = self._C(tree)
        alpha = (gini_one - gini_whole)/(counts-1)
        self.alphas.append(alpha)
        tree["alpha"] = alpha
        tree["left"] = self._add_alpha(tree["left"])
        tree["right"] = self._add_alpha(tree["right"])
        return tree
    
    def _inactivity(self,tree,alpha):
        if tree["result"] != None:
            return tree
        if tree["alpha"] <= alpha:
            tree["result"] = max(tree["counts"],key=tree["counts"].get)
        tree["left"] = self._inactivity(tree["left"],alpha)
        tree["right"] = self._inactivity(tree["right"],alpha)
        return tree
    
    def post_pruning(self):
        self.alphas = []
        self.tree = self._add_alpha(self.tree)
        self.subtrees = [self.tree.copy() for _ in range(len(set(self.alphas)))]
        for i,alpha in enumerate(sorted(set(self.alphas))):
            self.subtrees[i] = self._inactivity(self.subtrees[i],alpha)
        
    def _predict(self,X,tree):
        y_pred = []
        for x in X:
            cur = tree
            while cur["result"] == None:
                fi,fv = cur["fi"],cur["fv"]
                cur = cur["left"] if x[fi] <= fv else cur["right"]
            y_pred.append(cur["result"])
        return np.array(y_pred)
    
    def _score(self,X,y,tree):
        return np.sum(self._predict(X,tree)==y) / len(y)
    
    def predict(self,X):
        return self._predict(X,self.tree)
    
    def score(self,X,y):
        return np.sum(self.predict(X)==y) / len(y)

In [20]:
cart_clf = CartDT()
cart_clf.fit(X_train,y_train)

In [21]:
cart_clf.score(X_test,y_test)

0.9090909090909091

In [22]:
cart_clf.post_pruning()

In [23]:
print(cart_clf.alphas)

[0.012248393581984452, 0.004715634287046211, 0.003224939532383764, 0.0020490403660952117, 0.15999999999999992, 0.0021008029340027212, 0.05859375, 0.0625, 0.109375, 0.07999999999999996, 0.2222222222222222, 0.013384889946460435, 0.0621301775147929, 0.047091412742382266, 0.2222222222222222, 0.04750000000000004, 0.005247999999999993, 0.003999739854318396, 0.2222222222222222]


In [24]:
for subtree in cart_clf.subtrees:
    print(cart_clf._score(X_test,y_test,subtree))

0.9090909090909091
0.9090909090909091
0.9090909090909091
0.9090909090909091
0.9090909090909091
0.9090909090909091
0.6153846153846154
0.6153846153846154
0.6153846153846154
0.6153846153846154
0.6153846153846154
0.6153846153846154
0.6153846153846154
0.6153846153846154
0.6153846153846154
0.6153846153846154
0.6153846153846154
