# Customer Churn Prediction with a Decision Tree

Dataset: https://www.kaggle.com/datasets/blastchar/telco-customer-churn
GitHub repository: https://github.com/Liujunyi1999/DecisionTree

### 1. Data Preprocessing

In [124]:
import pandas as pd
df = pd.read_csv('Telco-Customer-Churn.csv')
print(df.head())

   customerID  gender  SeniorCitizen Partner Dependents  tenure PhoneService  \
0  7590-VHVEG  Female              0     Yes         No       1           No   
1  5575-GNVDE    Male              0      No         No      34          Yes   
2  3668-QPYBK    Male              0      No         No       2          Yes   
3  7795-CFOCW    Male              0      No         No      45           No   
4  9237-HQITU  Female              0      No         No       2          Yes   

      MultipleLines InternetService OnlineSecurity  ... DeviceProtection  \
0  No phone service             DSL             No  ...               No   
1                No             DSL            Yes  ...              Yes   
2                No             DSL            Yes  ...               No   
3  No phone service             DSL            Yes  ...              Yes   
4                No     Fiber optic             No  ...               No   

  TechSupport StreamingTV StreamingMovies        Contract Pape

In [125]:
print(df.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7043 entries, 0 to 7042
Data columns (total 21 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   customerID        7043 non-null   object 
 1   gender            7043 non-null   object 
 2   SeniorCitizen     7043 non-null   int64  
 3   Partner           7043 non-null   object 
 4   Dependents        7043 non-null   object 
 5   tenure            7043 non-null   int64  
 6   PhoneService      7043 non-null   object 
 7   MultipleLines     7043 non-null   object 
 8   InternetService   7043 non-null   object 
 9   OnlineSecurity    7043 non-null   object 
 10  OnlineBackup      7043 non-null   object 
 11  DeviceProtection  7043 non-null   object 
 12  TechSupport       7043 non-null   object 
 13  StreamingTV       7043 non-null   object 
 14  StreamingMovies   7043 non-null   object 
 15  Contract          7043 non-null   object 
 16  PaperlessBilling  7043 non-null   object 


Most columns in df have dtype object, but scikit-learn's decision tree requires numeric input, so the following cells convert them to numbers.

In [126]:
# Drop rows with missing values. The later df.info() still reports 7043 rows,
# so nothing is removed at this point.
df = df.dropna()

# Binary Yes/No (or Female/Male) columns become 0/1.
df['gender'] = df['gender'].map({'Female': 0, 'Male': 1})
df['Partner'] = df['Partner'].map({'No': 0, 'Yes': 1})
df['Dependents'] = df['Dependents'].map({'No': 0, 'Yes': 1})
df['PhoneService'] = df['PhoneService'].map({'No': 0, 'Yes': 1})
df['PaperlessBilling'] = df['PaperlessBilling'].map({'No': 0, 'Yes': 1})
df['Churn'] = df['Churn'].map({'No': 0, 'Yes': 1})

# Columns with a third "no service" level are coded 0/1/2.
df['MultipleLines'] = df['MultipleLines'].map({'No': 0, 'Yes': 1, 'No phone service': 2})
df['InternetService'] = df['InternetService'].map({'DSL': 0, 'Fiber optic': 1, 'No': 2})
df['OnlineSecurity'] = df['OnlineSecurity'].map({'No': 0, 'Yes': 1, 'No internet service': 2})
df['OnlineBackup'] = df['OnlineBackup'].map({'No': 0, 'Yes': 1, 'No internet service': 2})
df['DeviceProtection'] = df['DeviceProtection'].map({'No': 0, 'Yes': 1, 'No internet service': 2})
df['TechSupport'] = df['TechSupport'].map({'No': 0, 'Yes': 1, 'No internet service': 2})
df['StreamingTV'] = df['StreamingTV'].map({'No': 0, 'Yes': 1, 'No internet service': 2})
df['StreamingMovies'] = df['StreamingMovies'].map({'No': 0, 'Yes': 1, 'No internet service': 2})

# Contract length and payment method are coded as small integers.
df['Contract'] = df['Contract'].map({'Month-to-month': 0, 'One year': 1, 'Two year': 2})
df['PaymentMethod'] = df['PaymentMethod'].map({'Electronic check': 0, 'Mailed check': 1, 'Bank transfer (automatic)': 2, 'Credit card (automatic)': 3})
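
`Series.map` silently returns NaN for any label that is missing from the mapping dictionary, so a typo in a key would go unnoticed until much later. A minimal sketch of a guard against that, assuming a hypothetical helper `encode_column` and a small made-up frame instead of the real CSV:

```python
import pandas as pd


def encode_column(series: pd.Series, mapping: dict) -> pd.Series:
    """Map category labels to integers and fail loudly on unmapped labels.

    Hypothetical helper for illustration; it is not part of the original notebook.
    """
    encoded = series.map(mapping)
    # Labels that were present in the input but have no entry in the mapping.
    unmapped = series[encoded.isna() & series.notna()].unique()
    if len(unmapped) > 0:
        raise ValueError(f"unmapped values in {series.name!r}: {list(unmapped)}")
    return encoded.astype("int64")


# Made-up rows that mimic the label style of the dataset.
sample = pd.DataFrame({
    "Partner": ["Yes", "No", "No"],
    "MultipleLines": ["No phone service", "No", "Yes"],
})

yes_no = {"No": 0, "Yes": 1}
sample["Partner"] = encode_column(sample["Partner"], yes_no)
sample["MultipleLines"] = encode_column(
    sample["MultipleLines"], {**yes_no, "No phone service": 2}
)
print(sample)
```

Sharing the `yes_no` dictionary also keeps the coding consistent across the many Yes/No columns.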

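The values in TotalCharges are numeric but stored as strings, so the column has to be converted. With errors='coerce', any entry that cannot be parsed becomes NaN: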

In [154]:
df['TotalCharges'] = pd.to_numeric(df['TotalCharges'], errors='coerce')
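
To see what `errors='coerce'` does, the sketch below converts a few made-up strings, one of which is blank. The values are illustrative and not taken from the CSV:

```python
import pandas as pd

# Made-up values; the blank string stands in for an entry that cannot be parsed.
raw = pd.Series(["29.85", "1889.5", " ", "108.15"], name="TotalCharges")

converted = pd.to_numeric(raw, errors="coerce")
n_failed = int(converted.isna().sum())

print(converted.dtype)  # numeric dtype after conversion
print(n_failed)         # number of entries that became NaN
```

Counting the NaN values right after the conversion shows how many rows the mean-filling step that follows will affect.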

Drop the columns that are not needed for this task, such as the customer ID and gender, which are treated here as unrelated to churn.

In [155]:
df.drop(['customerID', 'gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 
         'PaperlessBilling', 'PaymentMethod'], axis=1, inplace=True)

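KeyError: "['customerID', 'gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'PaperlessBilling', 'PaymentMethod'] not found in axis"

This error most likely comes from executing the cell a second time after the columns had already been removed. The 13-column output further down shows that the drop itself succeeded.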
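
One way to make such a cell safe to re-run is `errors='ignore'`, which tells `DataFrame.drop` to skip labels that are not present. A minimal sketch on a made-up frame (`df_demo` and `to_drop` are illustrative names):

```python
import pandas as pd

# Made-up frame holding two of the column names used in the notebook.
df_demo = pd.DataFrame({"customerID": ["a", "b"], "gender": [0, 1], "tenure": [1, 34]})
to_drop = ["customerID", "gender"]

# First call removes both columns.
df_demo = df_demo.drop(columns=to_drop, errors="ignore")
# Second call finds nothing to remove and returns the frame unchanged, without a KeyError.
df_demo = df_demo.drop(columns=to_drop, errors="ignore")

print(list(df_demo.columns))
```

The trade-off is that a misspelled column name is skipped silently as well, so this suits exploratory notebooks better than production code.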

Fill the missing values in TotalCharges with the column mean:

In [160]:
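# Fill only the TotalCharges column; a bare df.fillna(value) would write the
# TotalCharges mean into NaN cells of every other column as well.
df['TotalCharges'] = df['TotalCharges'].fillna(df['TotalCharges'].mean())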

Check whether any NaN values remain in the dataset:

In [161]:
print(df.isnull().sum())

SeniorCitizen       0
tenure              0
InternetService     0
OnlineSecurity      0
OnlineBackup        0
DeviceProtection    0
TechSupport         0
StreamingTV         0
StreamingMovies     0
Contract            0
MonthlyCharges      0
TotalCharges        0
Churn               0
dtype: int64


Check the data types and non-null counts:

In [162]:
print(df.info())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 7043 entries, 0 to 7042
Data columns (total 13 columns):
 #   Column            Non-Null Count  Dtype  
---  ------            --------------  -----  
 0   SeniorCitizen     7043 non-null   int64  
 1   tenure            7043 non-null   int64  
 2   InternetService   7043 non-null   int64  
 3   OnlineSecurity    7043 non-null   int64  
 4   OnlineBackup      7043 non-null   int64  
 5   DeviceProtection  7043 non-null   int64  
 6   TechSupport       7043 non-null   int64  
 7   StreamingTV       7043 non-null   int64  
 8   StreamingMovies   7043 non-null   int64  
 9   Contract          7043 non-null   int64  
 10  MonthlyCharges    7043 non-null   float64
 11  TotalCharges      7043 non-null   float64
 12  Churn             7043 non-null   int64  
dtypes: float64(2), int64(11)
memory usage: 715.4 KB
None


In [163]:
print(df['Churn'])

0       0
1       0
2       1
3       0
4       1
       ..
7038    0
7039    0
7040    0
7041    1
7042    0
Name: Churn, Length: 7043, dtype: int64


### 2. Feature Selection with Correlation Analysis

Compute the correlation between churn and the other features:

In [165]:
corr_matrix = df.corr()
corr_matrix['Churn'].sort_values(ascending=False)

Churn               1.000000
MonthlyCharges      0.193356
SeniorCitizen       0.150889
InternetService    -0.047291
TotalCharges       -0.199428
StreamingTV        -0.205742
StreamingMovies    -0.207256
DeviceProtection   -0.281465
OnlineBackup       -0.291449
TechSupport        -0.329852
OnlineSecurity     -0.332819
tenure             -0.352229
Contract           -0.396713
Name: Churn, dtype: float64
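
The sign of a coefficient only gives the direction of the relationship, so a ranking by strength is easier to read on absolute values. A small sketch that re-ranks the coefficients printed above, copied by hand into a Series:

```python
import pandas as pd

# Correlations with Churn, copied from the output above (Churn itself omitted).
corr_with_churn = pd.Series({
    "MonthlyCharges": 0.193356,
    "SeniorCitizen": 0.150889,
    "InternetService": -0.047291,
    "TotalCharges": -0.199428,
    "StreamingTV": -0.205742,
    "StreamingMovies": -0.207256,
    "DeviceProtection": -0.281465,
    "OnlineBackup": -0.291449,
    "TechSupport": -0.329852,
    "OnlineSecurity": -0.332819,
    "tenure": -0.352229,
    "Contract": -0.396713,
})

# Rank by strength of the linear relationship, regardless of direction.
ranked = corr_with_churn.abs().sort_values(ascending=False)
print(ranked.head(5))
```

Note that `df.corr()` computes Pearson correlation by default, so the value for a column with arbitrary integer codes (for example InternetService coded 0/1/2) depends on how the categories were numbered.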

Select features with recursive feature elimination (RFE):

In [133]:
from sklearn.feature_selection import RFE
from sklearn.tree import DecisionTreeClassifier

X = df.drop(['Churn'], axis=1)
y = df['Churn']

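The code below uses DecisionTreeClassifier() as the estimator for RFE and asks for 5 features. RFE first fits the estimator on all features, removes the feature with the lowest importance (for a decision tree, the smallest `feature_importances_` value), and then refits on the features that remain. This repeats until the requested number of features is left.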

In [139]:
model = DecisionTreeClassifier()
rfe = RFE(model, n_features_to_select=5)
rfe = rfe.fit(X, y)

selected_features = X.columns[rfe.support_]
print(selected_features)

Index(['tenure', 'OnlineSecurity', 'Contract', 'MonthlyCharges',
       'TotalCharges'],
      dtype='object')
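
The elimination loop that RFE performs can be sketched by hand. The version below is a simplification for illustration, not scikit-learn's implementation, and it runs on synthetic data from `make_classification` rather than the churn table. The names `X_demo`, `y_demo`, `remaining`, and `n_keep` are assumptions made for the example:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data: 8 features, 3 of them informative.
X_demo, y_demo = make_classification(
    n_samples=300, n_features=8, n_informative=3, random_state=0
)

remaining = list(range(X_demo.shape[1]))  # column indices still in play
n_keep = 5

while len(remaining) > n_keep:
    # Refit on the surviving columns only.
    tree = DecisionTreeClassifier(random_state=0).fit(X_demo[:, remaining], y_demo)
    # Drop the column with the smallest importance in this round.
    weakest = int(np.argmin(tree.feature_importances_))
    remaining.pop(weakest)

print(remaining)
```

Fixing `random_state` keeps the result reproducible. Without it, a decision tree can break ties differently between runs, so the selected set may change.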


### 3. Building the Training and Test Sets

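Using the selected_features obtained above, split the data into a training set and a test set at a 7:3 ratio. X holds the selected features and y holds the Churn labels.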

In [136]:
from sklearn.model_selection import train_test_split
X = df[selected_features]
y = df['Churn']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
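
As a quick check of what a 7:3 split means for 7043 rows, the sketch below splits placeholder arrays of the same length. `X_demo` and `y_demo` are made-up stand-ins, not the churn data:

```python
import numpy as np
from sklearn.model_selection import train_test_split

n_rows = 7043  # row count reported by df.info() above
X_demo = np.arange(n_rows * 2).reshape(n_rows, 2)  # placeholder features
y_demo = np.arange(n_rows) % 2                     # placeholder labels

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.3, random_state=42
)

print(len(X_tr), len(X_te))
```

With `test_size=0.3`, scikit-learn rounds the test size up, which gives 2113 test rows and 4930 training rows here.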

### 4. Building the Decision Tree Model

In [137]:
model = DecisionTreeClassifier()

Train the model on the training set:

In [146]:
model.fit(X_train, y_train)

Predict on the test set:

In [147]:
y_pred = model.predict(X_test)

### 5. Model Evaluation

Compute the model's accuracy, precision, recall, and F1 score:

In [148]:
from sklearn.metrics import accuracy_score,precision_score,recall_score,f1_score

In [150]:
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)

In [152]:
print('Accuracy:', accuracy)
print('Precision:', precision)
print('Recall:', recall)
print('F1:', f1)

Accuracy: 0.7288215806909607
Precision: 0.5009276437847866
Recall: 0.47038327526132406
F1: 0.48517520215633425
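
To make the four numbers concrete, they can be recomputed from confusion-matrix counts. The notebook never prints a confusion matrix, so the counts below are an assumption: they were inferred backwards from the printed metrics for a test set of 2113 rows (30% of 7043), and they reproduce those metrics.

```python
# Confusion-matrix counts inferred from the printed metrics (an assumption,
# not an output of the notebook). Positive class = churn.
tp, fp, fn, tn = 270, 269, 304, 1270
total = tp + fp + fn + tn  # 2113 test rows

accuracy = (tp + tn) / total            # share of all predictions that are correct
precision = tp / (tp + fp)              # share of predicted churners who really churn
recall = tp / (tp + fn)                 # share of real churners that are found
f1 = 2 * precision * recall / (precision + recall)

# Accuracy of a trivial model that always predicts "no churn" on the same counts.
baseline = (tn + fp) / total

print("Accuracy:", accuracy)
print("Precision:", precision)
print("Recall:", recall)
print("F1:", f1)
print("Always-no-churn baseline:", baseline)
```

Under these inferred counts, the accuracy is only marginally above the always-no-churn baseline, which suggests that precision, recall, and F1 describe this model better than accuracy alone.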
