In [11]:
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import classification_report
from sklearn.pipeline import Pipeline

# 指定第3列的数据类型为字符串类型，其余列根据需要设置
dtype_spec = {3: 'str'}

# 读取数据并指定数据类型
df = pd.read_csv('./ad.data', header=None, dtype=dtype_spec, low_memory=False)

# 特征和标签列的设置
explanatory_variable_columns = set(df.columns.values)
explanatory_variable_columns.remove(len(df.columns.values)-1)
response_variable_column = df[len(df.columns.values)-1]  # 最后一列描述类

# 将标签转换为二进制
y = [1 if e == 'ad.' else 0 for e in response_variable_column]
X = df[list(explanatory_variable_columns)].copy()

# 将空值替换为 -1
X.replace(to_replace=' *?', value=-1, regex=True, inplace=True)

# 拆分数据集
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

# 创建Pipeline
pipeline = Pipeline([
    ('clf', DecisionTreeClassifier(criterion='entropy'))
])

# 定义参数网格
parameters = {
    'clf__max_depth': (150, 155, 160),
    'clf__min_samples_split': (2, 3),
    'clf__min_samples_leaf': (1, 2, 3)
}

# 网格搜索
grid_search = GridSearchCV(pipeline, parameters, n_jobs=-1, verbose=1, scoring='f1')
grid_search.fit(X_train, y_train)

# 打印最佳参数和得分
best_parameters = grid_search.best_estimator_.get_params()
print('Best score: %0.3f' % grid_search.best_score_)
print('Best parameters set:')
for param_name in sorted(parameters.keys()):
    print('\t%s: %r' % (param_name, best_parameters[param_name]))

# 在测试集上进行预测并打印分类报告
predictions = grid_search.predict(X_test)
print(classification_report(y_test, predictions))


Fitting 5 folds for each of 18 candidates, totalling 90 fits
Best score: 0.869
Best parameters set:
	clf__max_depth: 155
	clf__min_samples_leaf: 1
	clf__min_samples_split: 3
              precision    recall  f1-score   support

           0       0.97      0.99      0.98       682
           1       0.93      0.85      0.89       138

    accuracy                           0.96       820
   macro avg       0.95      0.92      0.93       820
weighted avg       0.96      0.96      0.96       820

