## Decision Tree

$$ {Gini = 1- \sum_{i=1}^n（p^2)(c_i)}$$
$$ {Entropy = \sum_{i=1}^n - p(c_i)log_2(p(c_i))}$$

Gini系数和Entropy都是用于判断数据的杂乱程度


In [20]:
import numpy as np

In [21]:
from collections import Counter

In [52]:
def pr(e,es):
    return Counter(es)[e]/len(es)

In [53]:
def gini(elements):
    return 1 - np.sum([pr(e,elements)**2 for e in set(elements)])

In [54]:
def entropy(elements):
    return -1 * np.sum([pr(e,elements) * np.log2(pr(e,elements)) for e in set(elements)]) 

In [70]:
features_1 = ['R','R','Y','R']
features_2 = ['R','R','R','Y']
features_3 = ['Y','Y','R','Y']
features_4 = ['Y','Y','Y','Y']  #当不杂乱的时候，其gini系数及entropy就是0

In [71]:
gini(features_3)

0.375

In [72]:
entropy(features_3)

0.8112781244591328

In [79]:
item_sales = {
    'gender':['F','F','F','F','M','M','M'],
    'income':['20','10','20','20','20','20','10',],
    'family_number':[1,1,2,1,1,1,2],
    'bought':[1,1,1,0,0,0,1],
}

In [80]:
import pandas as pd

In [81]:
dataset = pd.DataFrame.from_dict(item_sales)

In [82]:
dataset

Unnamed: 0,gender,income,family_number,bought
0,F,20,1,1
1,F,10,1,1
2,F,20,2,1
3,F,20,1,0
4,M,20,1,0
5,M,20,1,0
6,M,10,2,1


In [83]:
dataset.gender

0    F
1    F
2    F
3    F
4    M
5    M
6    M
Name: gender, dtype: object

In [85]:
dataset.columns

Index(['gender', 'income', 'family_number', 'bought'], dtype='object')

In [87]:
dataset['gender'] == 'F'

0     True
1     True
2     True
3     True
4    False
5    False
6    False
Name: gender, dtype: bool

In [88]:
dataset[dataset['gender'] == 'F']

Unnamed: 0,gender,income,family_number,bought
0,F,20,1,1
1,F,10,1,1
2,F,20,2,1
3,F,20,1,0


In [89]:
target = 'bought'

In [90]:
dataset[dataset['gender'] == 'F'][target]

0    1
1    1
2    1
3    0
Name: bought, dtype: int64

In [92]:
dataset[dataset['gender'] == 'F'][target].tolist()

[1, 1, 1, 0]

In [94]:
print(dataset[dataset['gender'] == 'F'][target].tolist(),
      dataset[dataset['gender'] == 'M'][target].tolist())
gini(dataset[dataset['gender'] == 'F'][target].tolist())+\
gini(dataset[dataset['gender'] == 'M'][target].tolist())

[1, 1, 1, 0] [0, 0, 1]


0.8194444444444444

In [95]:
dataset[dataset['income'] == '20'][target].tolist()

[1, 1, 0, 0, 0]

In [96]:
print(dataset[dataset['income'] == '20'][target].tolist(),
      dataset[dataset['income'] == '10'][target].tolist())
gini(dataset[dataset['income'] == '20'][target].tolist())+\
gini(dataset[dataset['income'] == '10'][target].tolist())

[1, 1, 0, 0, 0] [1, 1]


0.48

## 决策树模型，让计算机自动构建逐层的if-else模型
## CART Algorithm
### Classification And Regression Tree Algorithm

loss = (m_left)/m * G_left + (m_right)/m * G_right

### 决策树，可以用来做回归

In [114]:
from sklearn.datasets import load_boston

In [115]:
boston = load_boston()

In [116]:
X = boston.data

In [117]:
y = boston.target

In [125]:
from sklearn.tree import DecisionTreeRegressor

In [126]:
tree_clf = DecisionTreeRegressor()

In [127]:
tree_clf.fit(X,y)

DecisionTreeRegressor()

In [128]:
from sklearn.tree import export_graphviz

In [130]:
export_graphviz(tree_clf,out_file='boston.dot',
               feature_names = boston.feature_names,
               rounded = True,
               filled = True)

In [131]:
for line in open('boston.dot'):
    print(line)

digraph Tree {

node [shape=box, style="filled, rounded", color="black", fontname=helvetica] ;

edge [fontname=helvetica] ;

0 [label="RM <= 6.941\nmse = 84.42\nsamples = 506\nvalue = 22.533", fillcolor="#f5ceb2"] ;

1 [label="LSTAT <= 14.4\nmse = 40.273\nsamples = 430\nvalue = 19.934", fillcolor="#f6d5bd"] ;

0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;

2 [label="DIS <= 1.385\nmse = 26.009\nsamples = 255\nvalue = 23.35", fillcolor="#f4ccae"] ;

1 -> 2 ;

3 [label="B <= 339.985\nmse = 78.146\nsamples = 5\nvalue = 45.58", fillcolor="#e88d4c"] ;

2 -> 3 ;

4 [label="mse = 0.0\nsamples = 1\nvalue = 27.9", fillcolor="#f2bf9a"] ;

3 -> 4 ;

5 [label="mse = 0.0\nsamples = 4\nvalue = 50.0", fillcolor="#e58139"] ;

3 -> 5 ;

6 [label="RM <= 6.543\nmse = 14.885\nsamples = 250\nvalue = 22.905", fillcolor="#f5cdb0"] ;

2 -> 6 ;

7 [label="LSTAT <= 7.57\nmse = 8.39\nsamples = 195\nvalue = 21.63", fillcolor="#f5d0b6"] ;

6 -> 7 ;

8 [label="TAX <= 222.5\nmse = 3.015\nsamples = 43\

793 [label="B <= 107.815\nmse = 1.268\nsamples = 11\nvalue = 8.755", fillcolor="#fdf4ee"] ;

781 -> 793 ;

794 [label="B <= 21.85\nmse = 0.499\nsamples = 6\nvalue = 8.067", fillcolor="#fdf6f2"] ;

793 -> 794 ;

795 [label="RM <= 5.926\nmse = 0.047\nsamples = 3\nvalue = 8.6", fillcolor="#fdf5ef"] ;

794 -> 795 ;

796 [label="mse = 0.0\nsamples = 1\nvalue = 8.3", fillcolor="#fdf6f0"] ;

795 -> 796 ;

797 [label="DIS <= 1.858\nmse = 0.002\nsamples = 2\nvalue = 8.75", fillcolor="#fdf4ee"] ;

795 -> 797 ;

798 [label="mse = 0.0\nsamples = 1\nvalue = 8.8", fillcolor="#fdf4ee"] ;

797 -> 798 ;

799 [label="mse = 0.0\nsamples = 1\nvalue = 8.7", fillcolor="#fdf5ef"] ;

797 -> 799 ;

800 [label="NOX <= 0.717\nmse = 0.382\nsamples = 3\nvalue = 7.533", fillcolor="#fef8f4"] ;

794 -> 800 ;

801 [label="B <= 57.76\nmse = 0.01\nsamples = 2\nvalue = 7.1", fillcolor="#fef9f6"] ;

800 -> 801 ;

802 [label="mse = 0.0\nsamples = 1\nvalue = 7.2", fillcolor="#fef9f5"] ;

801 -> 802 ;

803 [label="mse = 0.0\

In [132]:
tree_clf.feature_importances_

array([0.06092278, 0.00063904, 0.00418948, 0.00079233, 0.0407796 ,
       0.57561402, 0.0111776 , 0.07331292, 0.00104338, 0.01265787,
       0.00723475, 0.01428157, 0.19735466])

In [136]:
{n:w for n,w in zip(boston.feature_names,tree_clf.feature_importances_)}

{'CRIM': 0.060922776462794224,
 'ZN': 0.0006390403046970651,
 'INDUS': 0.004189477708651446,
 'CHAS': 0.0007923309855316643,
 'NOX': 0.04077960095223498,
 'RM': 0.5756140202432457,
 'AGE': 0.011177603450119278,
 'DIS': 0.07331291708499542,
 'RAD': 0.0010433820684507197,
 'TAX': 0.012657867822702343,
 'PTRATIO': 0.007234753277597323,
 'B': 0.014281565725812032,
 'LSTAT': 0.19735466391316764}

In [139]:
sorted({n:w for n,w in zip(boston.feature_names,tree_clf.feature_importances_)}.items(),
key = lambda x:x[0])

[('AGE', 0.011177603450119278),
 ('B', 0.014281565725812032),
 ('CHAS', 0.0007923309855316643),
 ('CRIM', 0.060922776462794224),
 ('DIS', 0.07331291708499542),
 ('INDUS', 0.004189477708651446),
 ('LSTAT', 0.19735466391316764),
 ('NOX', 0.04077960095223498),
 ('PTRATIO', 0.007234753277597323),
 ('RAD', 0.0010433820684507197),
 ('RM', 0.5756140202432457),
 ('TAX', 0.012657867822702343),
 ('ZN', 0.0006390403046970651)]

In [142]:
sorted({n:w for n,w in zip(boston.feature_names,tree_clf.feature_importances_)}.items(),
key = lambda x:x[1],reverse = True)

[('RM', 0.5756140202432457),
 ('LSTAT', 0.19735466391316764),
 ('DIS', 0.07331291708499542),
 ('CRIM', 0.060922776462794224),
 ('NOX', 0.04077960095223498),
 ('B', 0.014281565725812032),
 ('TAX', 0.012657867822702343),
 ('AGE', 0.011177603450119278),
 ('PTRATIO', 0.007234753277597323),
 ('INDUS', 0.004189477708651446),
 ('RAD', 0.0010433820684507197),
 ('CHAS', 0.0007923309855316643),
 ('ZN', 0.0006390403046970651)]