In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import tree,ensemble,metrics

from rule import Rule
from rule_extraction import rule_extract

## 数据准备

In [2]:
# fetch dataset
data = pd.read_csv('./dataset/titanic.csv', usecols = ['Age', 'Fare', 'Survived'])

# drop NA records since most Tree algorithm cannot handle
data.dropna(inplace=True)

# split training/test sets
X_train, X_test, y_train, y_test = train_test_split(data[['Age', 'Fare']], data.Survived, test_size=0.2, random_state=0)

# dataset shape
print(X_train.shape, X_test.shape)


(571, 2) (143, 2)


## 单颗决策树

In [3]:
# 单颗决策树模型
# 模型API参考 http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

model_tree_clf = tree.DecisionTreeClassifier(criterion='gini',max_depth=3)
model_tree_clf.fit(X_train,y_train)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
            max_features=None, max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best')

In [4]:
# model performance on test set
y_pred_test = model_tree_clf.predict(X_test)
print(metrics.confusion_matrix(y_test,y_pred_test))

[[67 12]
 [33 31]]


In [5]:
# rule_extract(model_tree_clf,X_test,y_test)

[('Age > 8.5 and Fare <= 51.931251525878906 and Fare > 10.335399627685547',
  ('recall for class 1 is: 0.40625',
   'prec for class 1 is: 0.4727272727272727',
   'recall for class 0 is: 0.6329113924050633',
   'prec for class 0 is: 0.42735042735042733')),
 ('Age > 29.5 and Fare > 51.931251525878906',
  ('recall for class 1 is: 0.25',
   'prec for class 1 is: 0.6956521739130435',
   'recall for class 0 is: 0.9113924050632911',
   'prec for class 0 is: 0.5669291338582677')),
 ('Age <= 29.5 and Age > 17.5 and Fare > 51.931251525878906',
  ('recall for class 1 is: 0.109375',
   'prec for class 1 is: 0.6363636363636364',
   'recall for class 0 is: 0.9493670886075949',
   'prec for class 0 is: 0.5514705882352942')),
 ('Age <= 8.5 and Fare <= 27.825000762939453',
  ('recall for class 1 is: 0.09375',
   'prec for class 1 is: 1.0',
   'recall for class 0 is: 1.0',
   'prec for class 0 is: 0.5766423357664233')),
 ('Age > 8.5 and Fare <= 10.335399627685547',
  ('recall for class 1 is: 0.078125',


In [5]:
rule_extract(model_tree_clf,X_test,y_test,0,-1,-1)

[305, 504, 297]
[72, 724, 102, 655, 97, 681, 151, 291, 310, 377, 615]
[824, 261, 787, 850, 63, 618]
[691, 644, 469, 777, 803, 448]
[423, 177, 292, 889, 595, 637, 456, 811, 861, 15, 239, 394, 795, 346, 9, 512, 830, 403, 345, 666, 772, 446, 570, 312, 251, 812, 417, 658, 339, 506, 11, 616, 885, 626, 702, 657, 142, 430, 706, 583, 66, 266, 620, 450, 21, 860, 98, 84, 874, 518, 327, 701, 607, 308, 123]
[54, 110, 325, 591, 1, 621, 218, 35, 332, 609, 390, 558, 195, 599, 318, 829, 412, 544, 779, 438, 262, 366, 215]
[287, 315, 131, 477, 606, 40, 382, 682, 614, 197, 640, 225, 391, 756, 590, 822, 729, 442, 631, 441, 212, 677, 67, 422, 89, 296, 649, 810, 51, 769, 715, 69, 355, 771, 818, 338, 421, 379, 471]


[('Age > 8.5 and Fare <= 51.931251525878906 and Fare > 10.335399627685547',
  (0.40625, 0.4727272727272727, 0.6329113924050633, 0.42735042735042733, 1)),
 ('Age > 29.5 and Fare > 51.931251525878906',
  (0.25, 0.6956521739130435, 0.9113924050632911, 0.5669291338582677, 1)),
 ('Age <= 29.5 and Age > 17.5 and Fare > 51.931251525878906',
  (0.109375, 0.6363636363636364, 0.9493670886075949, 0.5514705882352942, 1)),
 ('Age <= 8.5 and Fare <= 27.825000762939453',
  (0.09375, 1.0, 1.0, 0.5766423357664233, 1)),
 ('Age > 8.5 and Fare <= 10.335399627685547',
  (0.078125, 0.1282051282051282, 0.5696202531645569, 0.32608695652173914, 1)),
 ('Age <= 17.5 and Fare > 51.931251525878906',
  (0.03125, 0.6666666666666666, 0.9873417721518988, 0.5531914893617021, 1)),
 ('Age <= 8.5 and Fare <= 51.931251525878906 and Fare > 27.825000762939453',
  (0.03125, 0.3333333333333333, 0.9493670886075949, 0.5319148936170213, 1))]

In [10]:
ru = 'Age > 8888.5'
_eval_rule_perf(ru,X_test,y_test)

[]




(0.0, 0.0, 0.0, 0.0)

In [9]:
from warnings import warn

def _eval_rule_perf(rule, X, y):
        """
        衡量每一条单独规则的评价指标，目前支持 0/1 两类样本的precision/recall
       
        Parameters
        ----------
    
        rule : str
            从决策树中提取出的单条规则
    
        X : pandas.DataFrame.
            用来测试的样本的特征集
    
        y : pandas.DataFrame.
            用来测试的样本的y标签
            
        """
        detected_index = list(X.query(rule).index)
        print(detected_index)
        if len(detected_index) <= 0:
            warn("rule %s reach no samples" % str(rule))
            return (0.,0.,0.,0.)
        
        y_detected = y[detected_index]
        true_pos = y_detected[y_detected > 0].count()
        false_pos = y_detected[y_detected == 0].count()

        pos = y[y > 0].count()
        neg = y[y == 0].count()
#        recall_0 = str('recall for class 0 is: '+ str(1- (float(false_pos) /neg)))
#        prec_0 = str('prec for class 0 is: ' + str((neg-false_pos) / (len(y)-y_detected.sum())))
#        recall_1 = str('recall for class 1 is: '+ str(float(true_pos) / pos))
#        prec_1 = str('prec for class 1 is: ' + str(y_detected.mean()))
        recall_0 = (1- (float(false_pos) /neg))
        prec_0 = ((neg-false_pos) / (len(y)-y_detected.sum()))
        recall_1 = (float(true_pos) / pos)
        prec_1 = (y_detected.mean())

        return recall_1, prec_1, recall_0, prec_0

In [18]:
a= rule_extract(model_tree_clf,X_test,y_test,0,0.2,0.1)

## 随机森林

In [6]:
model_RF_clf = ensemble.RandomForestClassifier(max_depth=3,n_estimators=2)
model_RF_clf.fit(X_train,y_train)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=3, max_features='auto', max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=2, n_jobs=1, oob_score=False, random_state=None,
            verbose=0, warm_start=False)

In [7]:
# model performance on test set
y_pred_test = model_RF_clf.predict(X_test)
print(metrics.confusion_matrix(y_test,y_pred_test))

[[68 11]
 [40 24]]


In [8]:
rule_extract(model_RF_clf,X_test,y_test)

[('Fare <= 56.197898864746094 and Fare > 9.416650772094727',
  ('recall for class 1 is: 0.5625',
   'prec for class 1 is: 0.4864864864864865',
   'recall for class 0 is: 0.5189873417721519',
   'prec for class 0 is: 0.38317757009345793')),
 ('Fare > 56.197898864746094',
  ('recall for class 1 is: 0.359375',
   'prec for class 1 is: 0.696969696969697',
   'recall for class 0 is: 0.8734177215189873',
   'prec for class 0 is: 0.575')),
 ('Fare <= 15.625 and Fare > 7.227099895477295',
  ('recall for class 1 is: 0.296875',
   'prec for class 1 is: 0.30158730158730157',
   'recall for class 0 is: 0.44303797468354433',
   'prec for class 0 is: 0.28225806451612906')),
 ('Fare > 75.1146011352539',
  ('recall for class 1 is: 0.28125',
   'prec for class 1 is: 0.72',
   'recall for class 0 is: 0.9113924050632911',
   'prec for class 0 is: 0.576')),
 ('Fare <= 48.20000076293945 and Fare > 16.0',
  ('recall for class 1 is: 0.28125',
   'prec for class 1 is: 0.5294117647058824',
   'recall for class

## BaggingClassifier

In [9]:
model_bagging_clf = ensemble.BaggingClassifier(
                base_estimator=tree.DecisionTreeClassifier(max_depth=3),
                n_estimators=2,
                n_jobs=-1,
                random_state=0)
model_bagging_clf.fit(X_train,y_train)

BaggingClassifier(base_estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
            max_features=None, max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best'),
         bootstrap=True, bootstrap_features=False, max_features=1.0,
         max_samples=1.0, n_estimators=2, n_jobs=-1, oob_score=False,
         random_state=0, verbose=0, warm_start=False)

In [10]:
# model performance on test set
y_pred_test = model_bagging_clf.predict(X_test)
print(metrics.confusion_matrix(y_test,y_pred_test))

[[73  6]
 [41 23]]


In [11]:
rule_extract(model_bagging_clf,X_test,y_test)

[('Fare <= 44.239601135253906 and Fare > 10.824999809265137',
  ('recall for class 1 is: 0.46875',
   'prec for class 1 is: 0.5',
   'recall for class 0 is: 0.620253164556962',
   'prec for class 0 is: 0.4336283185840708')),
 ('Fare <= 82.66455078125 and Fare > 19.856250762939453',
  ('recall for class 1 is: 0.40625',
   'prec for class 1 is: 0.5306122448979592',
   'recall for class 0 is: 0.7088607594936709',
   'prec for class 0 is: 0.47863247863247865')),
 ('Age <= 44.5 and Age > 1.5 and Fare <= 18.375',
  ('recall for class 1 is: 0.265625',
   'prec for class 1 is: 0.265625',
   'recall for class 0 is: 0.40506329113924056',
   'prec for class 0 is: 0.25396825396825395')),
 ('Age <= 45.0 and Age > 25.5 and Fare > 51.931251525878906',
  ('recall for class 1 is: 0.203125',
   'prec for class 1 is: 0.8125',
   'recall for class 0 is: 0.9620253164556962',
   'prec for class 0 is: 0.5846153846153846')),
 ('Age > 16.5 and Fare <= 10.824999809265137',
  ('recall for class 1 is: 0.125',
   

## 极端随机树

In [12]:
model_extratree_clf = ensemble.ExtraTreesClassifier(max_depth=3,n_estimators=2)
model_extratree_clf.fit(X_train,y_train)

ExtraTreesClassifier(bootstrap=False, class_weight=None, criterion='gini',
           max_depth=3, max_features='auto', max_leaf_nodes=None,
           min_impurity_split=1e-07, min_samples_leaf=1,
           min_samples_split=2, min_weight_fraction_leaf=0.0,
           n_estimators=2, n_jobs=1, oob_score=False, random_state=None,
           verbose=0, warm_start=False)

In [13]:
# model performance on test set
y_pred_test = model_extratree_clf.predict(X_test)
print(metrics.confusion_matrix(y_test,y_pred_test))

[[74  5]
 [56  8]]


In [14]:
rule_extract(model_extratree_clf,X_test,y_test)

[('Age <= 77.38182124642407 and Fare <= 105.26056261931572',
  ('recall for class 1 is: 0.875',
   'prec for class 1 is: 0.4307692307692308',
   'recall for class 0 is: 0.06329113924050633',
   'prec for class 0 is: 0.05747126436781609')),
 ('Age > 24.058929195029002 and Fare <= 185.55054721182938 and Fare > 17.874592264134076',
  ('recall for class 1 is: 0.421875',
   'prec for class 1 is: 0.6136363636363636',
   'recall for class 0 is: 0.7848101265822784',
   'prec for class 0 is: 0.5344827586206896')),
 ('Age <= 24.058929195029002 and Fare <= 185.55054721182938 and Fare > 17.874592264134076',
  ('recall for class 1 is: 0.21875',
   'prec for class 1 is: 0.6086956521739131',
   'recall for class 0 is: 0.8860759493670887',
   'prec for class 0 is: 0.5426356589147286')),
 ('Age > 23.579183918583364 and Fare <= 17.874592264134076 and Fare > 3.804449314728431',
  ('recall for class 1 is: 0.203125',
   'prec for class 1 is: 0.2765957446808511',
   'recall for class 0 is: 0.569620253164556