In [71]:
%matplotlib

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score

Using matplotlib backend: Qt5Agg


In [72]:
# Load the dataset; its first column (PassengerId) becomes the row index
data_file = "data.csv"
data = pd.read_csv(filepath_or_buffer=data_file, index_col=0)

In [73]:
# First five rows: a quick look at the raw columns
data.head(5)

Unnamed: 0_level_0,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
PassengerId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Th...",female,38.0,1,0,PC 17599,71.2833,C85,C
3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [74]:
# Column dtypes and non-null counts; "object" columns are non-numeric (text) variables.
# buf=None prints the summary to stdout.
data.info(buf=None)

<class 'pandas.core.frame.DataFrame'>
Int64Index: 891 entries, 1 to 891
Data columns (total 11 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   Survived  891 non-null    int64  
 1   Pclass    891 non-null    int64  
 2   Name      891 non-null    object 
 3   Sex       891 non-null    object 
 4   Age       714 non-null    float64
 5   SibSp     891 non-null    int64  
 6   Parch     891 non-null    int64  
 7   Ticket    891 non-null    object 
 8   Fare      891 non-null    float64
 9   Cabin     204 non-null    object 
 10  Embarked  889 non-null    object 
dtypes: float64(2), int64(4), object(5)
memory usage: 83.5+ KB


In [75]:
# Preprocess the dataset

# Missing values:
# drop the column with too many missing values (Cabin: 204/891 non-null) and the
# columns treated here as unrelated to the target y (Name, Ticket).
# Reassign instead of inplace=True. errors="ignore" keeps the cell re-runnable:
# a second run would otherwise raise KeyError because the columns are already gone.
data = data.drop(columns=["Cabin", "Name", "Ticket"], errors="ignore")

In [76]:
# Fill the column with many gaps (Age: 714/891 non-null) with its mean. Features that
# miss only one or two values (Embarked: 889/891) simply lose those rows.
# NOTE(review): the mean is computed over the full dataset before the train/test split
# further down, so test-set rows influence the imputed value. Consider imputing after splitting.
data = data.assign(Age=data["Age"].fillna(data["Age"].mean())).dropna()
data.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 889 entries, 1 to 891
Data columns (total 8 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   Survived  889 non-null    int64  
 1   Pclass    889 non-null    int64  
 2   Sex       889 non-null    object 
 3   Age       889 non-null    float64
 4   SibSp     889 non-null    int64  
 5   Parch     889 non-null    int64  
 6   Fare      889 non-null    float64
 7   Embarked  889 non-null    object 
dtypes: float64(2), int64(4), object(2)
memory usage: 62.5+ KB


In [77]:
# Convert categorical variables to numeric variables

# Binary variable: encode Sex as male -> 1, female -> 0.
# The dtype guard makes the cell safe to re-run. Without it, a second run would compare
# the already-encoded ints against "male" and silently turn every row into 0.
if not pd.api.types.is_numeric_dtype(data["Sex"]):
    # map() yields NaN for any unexpected category, so astype("int") raises instead of
    # quietly lumping that value in with "female".
    data["Sex"] = data["Sex"].map({"female": 0, "male": 1}).astype("int")

In [78]:
# Encode the three-category Embarked variable as integer codes.
# Codes follow order of first appearance (labels[k] -> k), and a dict lookup
# replaces the per-row list search.
labels = data["Embarked"].drop_duplicates().tolist()
data["Embarked"] = data["Embarked"].map({label: code for code, label in enumerate(labels)})
data.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 889 entries, 1 to 891
Data columns (total 8 columns):
 #   Column    Non-Null Count  Dtype  
---  ------    --------------  -----  
 0   Survived  889 non-null    int64  
 1   Pclass    889 non-null    int64  
 2   Sex       889 non-null    int64  
 3   Age       889 non-null    float64
 4   SibSp     889 non-null    int64  
 5   Parch     889 non-null    int64  
 6   Fare      889 non-null    float64
 7   Embarked  889 non-null    int64  
dtypes: float64(2), int64(6)
memory usage: 62.5 KB


In [79]:
# Separate the label (Survived) from the feature matrix, then hold out 30% as a test set.
# y is a 1-D Series rather than a one-column DataFrame, the conventional shape for a
# single-target classifier.
X = data.drop(columns="Survived")
y = data["Survived"]

# random_state pins the split so Restart & Run All reproduces the scores below
# (25 matches the seed used for the decision trees).
XTrain, XTest, YTrain, YTest = train_test_split(X, y, test_size=0.3, random_state=25)
XTrain.head()

Unnamed: 0_level_0,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
PassengerId,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
855,2,0,44.0,1,0,26.0,0
838,3,1,29.699118,0,0,8.05,0
478,3,1,29.0,1,0,7.0458,0
248,2,0,24.0,0,2,14.5,0
886,3,0,39.0,0,5,29.125,2


In [80]:
# Give every split a clean 0..n-1 index; the split keeps the shuffled PassengerId labels
XTrain, XTest, YTrain, YTest = (
    part.reset_index(drop=True) for part in (XTrain, XTest, YTrain, YTest)
)

In [81]:
# (number of training rows, number of feature columns)
len(XTrain.index), len(XTrain.columns)

(622, 7)

In [82]:
# Same leading rows as before, now carrying the reset 0..n-1 index
XTrain.head(5)

Unnamed: 0,Pclass,Sex,Age,SibSp,Parch,Fare,Embarked
0,2,0,44.0,1,0,26.0,0
1,3,1,29.699118,0,0,8.05,0
2,3,1,29.0,1,0,7.0458,0
3,2,0,24.0,0,2,14.5,0
4,3,0,39.0,0,5,29.125,2


In [83]:
# Quick baseline: default decision tree, accuracy on the held-out test split
clf = DecisionTreeClassifier(random_state=25).fit(XTrain, YTrain)
score_ = clf.score(XTest, YTest)
score_

0.7752808988764045

In [84]:
# Quick look at mean 10-fold cross-validation accuracy of the same tree on the full dataset
score = cross_val_score(estimator=clf, X=X, y=y, cv=10).mean()
score

0.7739274770173645

In [85]:
# Observe how the model fits at different max_depth values (1..10, entropy criterion).
# tr holds training accuracy on XTrain. te holds mean 10-fold CV accuracy on the full X, y.
# NOTE(review): the "test" curve is CV accuracy over all of X, y, not the held-out XTest.
tr = []
te = []
for depth in range(1, 11):
    clf = DecisionTreeClassifier(random_state=25, max_depth=depth, criterion="entropy")
    clf = clf.fit(XTrain, YTrain)
    tr.append(clf.score(XTrain, YTrain))
    te.append(cross_val_score(clf, X, y, cv=10).mean())
print(max(te))

0.8177860061287026


In [86]:
# Training accuracy vs. CV accuracy for each max_depth from the sweep above.
# The x-range is derived from the collected scores, so it cannot drift out of sync with
# the sweep. Assumes the sweep used max_depth = 1..len(tr).
depths = range(1, len(tr) + 1)
fig, ax = plt.subplots()
ax.plot(depths, tr, color="red", label="train")
ax.plot(depths, te, color="blue", label="test")
ax.set_xticks(depths)
ax.set(xlabel="max_depth", ylabel="accuracy",
       title="Decision tree accuracy vs. max_depth (test = 10-fold CV)")
ax.legend()
plt.show()

In [88]:
# Grid search over decision-tree hyper-parameters (10-fold CV on the training split)

# The grid maps each parameter name to its candidate values. [*range(...)] and
# [*np.linspace(...)] simply materialise the range/array as a plain list of values.
parameters = {"splitter": ("best", "random")
              ,"criterion": ("gini", "entropy")
              ,"max_depth": [*range(1, 10)]
              ,"min_samples_leaf": [*range(1, 50, 5)]
              ,"min_impurity_decrease": [*np.linspace(0, 0.5, 20)]
             }

# 2*2*9*10*20 = 7200 candidates x 10 folds = 72,000 fits. n_jobs=-1 spreads them over
# all CPU cores; the selected parameters do not depend on n_jobs.
clf = DecisionTreeClassifier(random_state=25)
GS = GridSearchCV(clf, parameters, cv=10, n_jobs=-1)
GS.fit(XTrain, YTrain)

GridSearchCV(cv=10, estimator=DecisionTreeClassifier(random_state=25),
             param_grid={'criterion': ('gini', 'entropy'),
                         'max_depth': [1, 2, 3, 4, 5, 6, 7, 8, 9],
                         'min_impurity_decrease': [0.0, 0.02631578947368421,
                                                   0.05263157894736842,
                                                   0.07894736842105263,
                                                   0.10526315789473684,
                                                   0.13157894736842105,
                                                   0.15789473684210525,
                                                   0.18421052631578946,
                                                   0.21052631578947367,
                                                   0.23684210526315788,
                                                   0.2631578947368421,
                                                   0.2894736842105263,
        

In [89]:
# Best hyper-parameter combination: the candidate at best_index_ in the CV results
GS.cv_results_["params"][GS.best_index_]

{'criterion': 'entropy',
 'max_depth': 5,
 'min_impurity_decrease': 0.0,
 'min_samples_leaf': 1,
 'splitter': 'random'}

In [90]:
# Mean cross-validated accuracy of that best candidate
GS.cv_results_["mean_test_score"][GS.best_index_]

0.8312596006144393