In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import tree,ensemble,metrics

from rule import Rule
from rule_extraction import rule_extract

## 数据准备

In [2]:
# fetch dataset
data = pd.read_csv('./dataset/titanic.csv', usecols = ['Age', 'Fare', 'Survived'])

# drop NA records since most Tree algorithm cannot handle
data.dropna(inplace=True)

# split training/test sets
X_train, X_test, y_train, y_test = train_test_split(data[['Age', 'Fare']], data.Survived, test_size=0.2, random_state=0)

# dataset shape
print(X_train.shape, X_test.shape)


(571, 2) (143, 2)


## 单颗决策树

In [3]:
# 单颗决策树模型
# 模型API参考 http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html

model_tree_clf = tree.DecisionTreeClassifier(criterion='gini',max_depth=3)
model_tree_clf.fit(X_train,y_train)

DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
            max_features=None, max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best')

In [4]:
# model performance on test set
y_pred_test = model_tree_clf.predict(X_test)
print(metrics.confusion_matrix(y_test,y_pred_test))

[[67 12]
 [33 31]]


In [5]:
# 输出所有规则，不加筛选条件
# 输出：list of tuples (rule, recall on 0-class, prec on 0-class, recall on 0-class, prec on 0-class, nb)
    
rule_extract(model_tree_clf,X_test,y_test)

[('Age > 8.5 and Fare <= 51.931251525878906 and Fare > 10.335399627685547',
  (0.40625, 0.4727272727272727, 0.6329113924050633, 0.42735042735042733, 1)),
 ('Age > 29.5 and Fare > 51.931251525878906',
  (0.25, 0.6956521739130435, 0.9113924050632911, 0.5669291338582677, 1)),
 ('Age <= 29.5 and Age > 17.5 and Fare > 51.931251525878906',
  (0.109375, 0.6363636363636364, 0.9493670886075949, 0.5514705882352942, 1)),
 ('Age <= 8.5 and Fare <= 27.825000762939453',
  (0.09375, 1.0, 1.0, 0.5766423357664233, 1)),
 ('Age > 8.5 and Fare <= 10.335399627685547',
  (0.078125, 0.1282051282051282, 0.5696202531645569, 0.32608695652173914, 1)),
 ('Age <= 8.5 and Fare <= 51.931251525878906 and Fare > 27.825000762939453',
  (0.03125, 0.3333333333333333, 0.9493670886075949, 0.5319148936170213, 1)),
 ('Age <= 17.5 and Fare > 51.931251525878906',
  (0.03125, 0.6666666666666666, 0.9873417721518988, 0.5531914893617021, 1))]

In [6]:
# 筛选规则，要求对1类样本的召回率在0.1以上，准确率在0.5以上
rule_extract(model_tree_clf,X_test,y_test,0,0.1,0.5,0,0)

[('Age > 29.5 and Fare > 51.931251525878906',
  (0.25, 0.6956521739130435, 0.9113924050632911, 0.5669291338582677, 1)),
 ('Age <= 29.5 and Age > 17.5 and Fare > 51.931251525878906',
  (0.109375, 0.6363636363636364, 0.9493670886075949, 0.5514705882352942, 1))]

## 随机森林

In [7]:
model_RF_clf = ensemble.RandomForestClassifier(max_depth=3,n_estimators=2)
model_RF_clf.fit(X_train,y_train)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=3, max_features='auto', max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            n_estimators=2, n_jobs=1, oob_score=False, random_state=None,
            verbose=0, warm_start=False)

In [8]:
# model performance on test set
y_pred_test = model_RF_clf.predict(X_test)
print(metrics.confusion_matrix(y_test,y_pred_test))

[[67 12]
 [38 26]]


In [9]:
rule_extract(model_RF_clf,X_test,y_test)

[('Age > 6.5 and Fare <= 116.63749694824219 and Fare > 10.816650390625',
  (0.640625, 0.5394736842105263, 0.5569620253164558, 0.43137254901960786, 1)),
 ('Age > 6.5 and Fare <= 31.137500762939453',
  (0.484375,
   0.33695652173913043,
   0.22784810126582278,
   0.16071428571428573,
   1)),
 ('Age > 29.0 and Fare > 60.38960266113281',
  (0.21875, 0.7368421052631579, 0.9367088607594937, 0.5736434108527132, 1)),
 ('Age <= 29.0 and Fare > 60.38960266113281',
  (0.125, 0.6153846153846154, 0.9367088607594937, 0.5481481481481482, 1)),
 ('Age <= 6.5 and Fare <= 116.63749694824219 and Fare > 10.816650390625',
  (0.109375, 0.7, 0.9620253164556962, 0.5588235294117647, 1)),
 ('Age <= 30.75 and Fare <= 10.816650390625',
  (0.109375, 0.22580645161290322, 0.6962025316455696, 0.40441176470588236, 1)),
 ('Age <= 6.5 and Fare <= 26.950000762939453',
  (0.09375, 1.0, 1.0, 0.5766423357664233, 1)),
 ('Fare > 143.59164428710938',
  (0.078125, 0.5555555555555556, 0.9493670886075949, 0.5434782608695652, 1)),


## BaggingClassifier

In [10]:
model_bagging_clf = ensemble.BaggingClassifier(
                base_estimator=tree.DecisionTreeClassifier(max_depth=3),
                n_estimators=2,
                n_jobs=-1,
                random_state=0)
model_bagging_clf.fit(X_train,y_train)

BaggingClassifier(base_estimator=DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,
            max_features=None, max_leaf_nodes=None,
            min_impurity_split=1e-07, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best'),
         bootstrap=True, bootstrap_features=False, max_features=1.0,
         max_samples=1.0, n_estimators=2, n_jobs=-1, oob_score=False,
         random_state=0, verbose=0, warm_start=False)

In [11]:
# model performance on test set
y_pred_test = model_bagging_clf.predict(X_test)
print(metrics.confusion_matrix(y_test,y_pred_test))

[[73  6]
 [41 23]]


In [12]:
rule_extract(model_bagging_clf,X_test,y_test)

  warn("rule %s reach no samples" % str(rule))


[('Fare <= 44.239601135253906 and Fare > 10.824999809265137',
  (0.46875, 0.5, 0.620253164556962, 0.4336283185840708, 1)),
 ('Fare <= 82.66455078125 and Fare > 19.856250762939453',
  (0.40625, 0.5306122448979592, 0.7088607594936709, 0.47863247863247865, 1)),
 ('Age <= 44.5 and Age > 1.5 and Fare <= 18.375',
  (0.265625, 0.265625, 0.40506329113924056, 0.25396825396825395, 1)),
 ('Age <= 45.0 and Age > 25.5 and Fare > 51.931251525878906',
  (0.203125, 0.8125, 0.9620253164556962, 0.5846153846153846, 1)),
 ('Age > 16.5 and Fare <= 10.824999809265137',
  (0.125, 0.17777777777777778, 0.5316455696202531, 0.3111111111111111, 1)),
 ('Age > 27.0 and Fare > 82.66455078125',
  (0.125, 0.7272727272727273, 0.9620253164556962, 0.562962962962963, 1)),
 ('Age > 45.0 and Fare > 51.931251525878906',
  (0.078125, 0.5, 0.9367088607594937, 0.5362318840579711, 1)),
 ('Age <= 25.5 and Age > 17.5 and Fare > 51.931251525878906',
  (0.078125, 0.625, 0.9620253164556962, 0.5507246376811594, 1)),
 ('Age <= 27.0 and

## 极端随机树

In [13]:
model_extratree_clf = ensemble.ExtraTreesClassifier(max_depth=3,n_estimators=2)
model_extratree_clf.fit(X_train,y_train)

ExtraTreesClassifier(bootstrap=False, class_weight=None, criterion='gini',
           max_depth=3, max_features='auto', max_leaf_nodes=None,
           min_impurity_split=1e-07, min_samples_leaf=1,
           min_samples_split=2, min_weight_fraction_leaf=0.0,
           n_estimators=2, n_jobs=1, oob_score=False, random_state=None,
           verbose=0, warm_start=False)

In [14]:
# model performance on test set
y_pred_test = model_extratree_clf.predict(X_test)
print(metrics.confusion_matrix(y_test,y_pred_test))

[[75  4]
 [58  6]]


In [15]:
rule_extract(model_extratree_clf,X_test,y_test)

  warn("rule %s reach no samples" % str(rule))
  warn("rule %s reach no samples" % str(rule))
  warn("rule %s reach no samples" % str(rule))
  warn("rule %s reach no samples" % str(rule))
  warn("rule %s reach no samples" % str(rule))


[('Age <= 40.01392810243131 and Fare <= 448.03098606177304',
  (0.8125, 0.45217391304347826, 0.20253164556962022, 0.17582417582417584, 1)),
 ('Age <= 72.25456073751852 and Age > 11.725676720469133 and Fare <= 128.86898411733532',
  (0.78125, 0.4132231404958678, 0.10126582278481011, 0.08602150537634409, 1)),
 ('Age <= 11.725676720469133 and Fare <= 128.86898411733532',
  (0.125, 0.6666666666666666, 0.9493670886075949, 0.5555555555555556, 1)),
 ('Age <= 56.77933858288455 and Age > 40.01392810243131 and Fare <= 448.03098606177304',
  (0.109375, 0.3888888888888889, 0.8607594936708861, 0.5, 1)),
 ('Fare <= 212.73718875609006 and Fare > 128.86898411733532',
  (0.09375, 0.6666666666666666, 0.9620253164556962, 0.5547445255474452, 1)),
 ('Age > 56.77933858288455 and Fare <= 448.03098606177304 and Fare > 10.114682476716649',
  (0.078125, 0.5, 0.9367088607594937, 0.5362318840579711, 1)),
 ('Age <= 74.47146435662707 and Age > 72.25456073751852 and Fare <= 128.86898411733532',
  (0.0, 0.0, 0.0, 0.0