In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score
from sklearn import tree
from sklearn.model_selection import cross_validate
from sklearn.model_selection import GridSearchCV

In [2]:
# 读入文件
train_data_df = pd.read_csv("../data/traindata.csv")
train_label_df = pd.read_csv("../data/trainlabel.txt")
features = train_data_df.columns

# 数值填充
train_data_df.replace('Preschool', 1, inplace=True)
train_data_df.replace('1st-4th', 2, inplace=True)
train_data_df.replace('5th-6th', 3, inplace=True)
train_data_df.replace('7th-8th', 4, inplace=True)
train_data_df.replace('9th', 5, inplace=True)
train_data_df.replace('10th', 6, inplace=True)
train_data_df.replace('11th', 7, inplace=True)
train_data_df.replace('12th', 8, inplace=True)
train_data_df.replace('HS-grad', 9, inplace=True)
train_data_df.replace('Some-college', 10, inplace=True)
train_data_df.replace('Assoc-voc', 11, inplace=True)
train_data_df.replace('Assoc-acdm', 12, inplace=True)
train_data_df.replace('Bachelors', 13, inplace=True)
train_data_df.replace('Prof-school', 14, inplace=True)
train_data_df.replace('Masters', 15, inplace=True)
train_data_df.replace('Doctorate', 16, inplace=True)

# 对非数值信息独热编码
encoder = OneHotEncoder()
encoded_features = encoder.fit_transform(train_data_df[['workclass', 'marital.status','occupation', 'relationship', 'race', 'sex', 'native.country']]).toarray()
encoded_features = pd.concat([pd.DataFrame(encoded_features), train_data_df[['age', 'fnlwgt', 'education', 'education.num', 'capital.gain', 'capital.loss', 'hours.per.week']]], axis=1)
encoded_features

# 数据归一化
encoded_features.columns = encoded_features.columns.astype(str)

# 使用StandardScaler进行Z-Score归一化
zscore_scaler = StandardScaler()
df_zscore_scaled = zscore_scaler.fit_transform(encoded_features)

encoded_features = pd.DataFrame(df_zscore_scaled)

# 相关性分析
correlation_matrix = encoded_features.corr()

# 获取独热编码后的特征名称
encoded_feature_names = encoder.get_feature_names_out(['workclass', 'marital.status','occupation', 'relationship', 'race', 'sex', 'native.country'])
arr = np.array(['age','fnlwgt', 'education', 'education.num', 'capital.gain', 'capital.loss', 'hours.per.week'])
encoded_feature_names = np.append(encoded_feature_names, arr)
for i in range(len(encoded_feature_names)):
    print(i, encoded_feature_names[i])
for i in range(len(correlation_matrix)):
    for j in range(len(correlation_matrix)):
        if i >= j or i == 44: continue
        e = correlation_matrix[i][j]
        if e > 0.6 or e < -0.6: 
            print(i, encoded_feature_names[i], j, encoded_feature_names[j], e)
            if i in encoded_features.columns: encoded_features = encoded_features.drop(i, axis=1)
                
encoded_features

0 workclass_?
1 workclass_Federal-gov
2 workclass_Local-gov
3 workclass_Never-worked
4 workclass_Private
5 workclass_Self-emp-inc
6 workclass_Self-emp-not-inc
7 workclass_State-gov
8 workclass_Without-pay
9 marital.status_Divorced
10 marital.status_Married-AF-spouse
11 marital.status_Married-civ-spouse
12 marital.status_Married-spouse-absent
13 marital.status_Never-married
14 marital.status_Separated
15 marital.status_Widowed
16 occupation_?
17 occupation_Adm-clerical
18 occupation_Armed-Forces
19 occupation_Craft-repair
20 occupation_Exec-managerial
21 occupation_Farming-fishing
22 occupation_Handlers-cleaners
23 occupation_Machine-op-inspct
24 occupation_Other-service
25 occupation_Priv-house-serv
26 occupation_Prof-specialty
27 occupation_Protective-serv
28 occupation_Sales
29 occupation_Tech-support
30 occupation_Transport-moving
31 relationship_Husband
32 relationship_Not-in-family
33 relationship_Other-relative
34 relationship_Own-child
35 relationship_Unmarried
36 relationship_W

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,81,82,83,84,85,86,88,89,90,91
0,-0.175098,3.782449,-0.014813,-1.515578,-0.189737,-0.286956,-0.206136,-0.018738,-0.396990,-0.027321,...,-0.025662,0.340984,-0.045458,-0.022952,2.821506,-0.114762,1.131681,0.38330,-0.214689,-2.148673
1,-0.175098,-0.264379,-0.014813,-1.515578,5.270463,-0.286956,-0.206136,-0.018738,-0.396990,-0.027321,...,-0.025662,0.340984,-0.045458,-0.022952,0.108878,2.706469,1.131681,-0.14506,-0.214689,0.775212
2,-0.175098,-0.264379,-0.014813,-1.515578,-0.189737,3.484849,-0.206136,-0.018738,-0.396990,-0.027321,...,-0.025662,-2.932691,-0.045458,-0.022952,-0.697579,1.440206,-0.416186,-0.14506,3.690793,0.775212
3,-0.175098,-0.264379,-0.014813,0.659814,-0.189737,-0.286956,-0.206136,-0.018738,-0.396990,-0.027321,...,-0.025662,0.340984,-0.045458,-0.022952,0.182192,-0.050070,-0.416186,-0.14506,-0.214689,-0.036978
4,-0.175098,-0.264379,-0.014813,0.659814,-0.189737,-0.286956,-0.206136,-0.018738,-0.396990,-0.027321,...,-0.025662,0.340984,-0.045458,-0.022952,-1.210779,1.111395,-0.029219,-0.14506,-0.214689,-1.255264
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
22787,-0.175098,-0.264379,-0.014813,0.659814,-0.189737,-0.286956,-0.206136,-0.018738,-0.396990,-0.027321,...,-0.025662,0.340984,-0.045458,-0.022952,0.695392,-1.042350,-1.190120,-0.14506,-0.214689,0.775212
22788,-0.175098,-0.264379,-0.014813,-1.515578,-0.189737,3.484849,-0.206136,-0.018738,-0.396990,-0.027321,...,-0.025662,0.340984,-0.045458,-0.022952,-0.917522,-0.674011,-0.416186,-0.14506,-0.214689,-0.036978
22789,-0.175098,-0.264379,-0.014813,0.659814,-0.189737,-0.286956,-0.206136,-0.018738,-0.396990,-0.027321,...,-0.025662,0.340984,-0.045458,-0.022952,-0.184379,1.060678,-0.416186,-0.14506,-0.214689,-0.036978
22790,-0.175098,-0.264379,-0.014813,-1.515578,-0.189737,-0.286956,-0.206136,-0.018738,2.518958,-0.027321,...,-0.025662,0.340984,-0.045458,-0.022952,-0.404322,1.382914,-0.029219,-0.14506,-0.214689,-0.036978


In [3]:
from sklearn.model_selection import cross_val_score

def evaluate(model):

    # 执行交叉验证
    scores = cross_val_score(model, encoded_features, train_label_df, cv=5)  # cv=5 表示将数据集划分为5个折叠

    # 打印每次交叉验证的得分
    print("Cross-validation scores:", scores)

    # 打印平均得分
    print("Average score:", scores.mean())

In [4]:
# baseline
clf = DecisionTreeClassifier(random_state=6)
evaluate(clf)

Cross-validation scores: [0.81114279 0.81838122 0.81505046 0.80320316 0.81658622]
Average score: 0.8128727720922149


In [5]:
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(encoded_features, train_label_df, test_size=0.2, random_state=6)

In [6]:
# 设置参数范围
param_grid = {'criterion': ['gini', 'entropy', 'log_loss'],
              'splitter': ['best', 'random']}

# 创建 GridSearchCV 对象
grid_search = GridSearchCV(estimator=DecisionTreeClassifier(random_state=6), 
                           param_grid=param_grid, 
                           cv=5, 
                           scoring='accuracy')

# 执行参数搜索
grid_search.fit(X_train, y_train)

# 打印所有参数组合的训练结果
results_df = pd.DataFrame(grid_search.cv_results_)
print(results_df[['params', 'mean_test_score']])

# 输出最优参数和对应的评估结果
print("Best parameters:", grid_search.best_params_)
print("Best score:", grid_search.best_score_)

                                            params  mean_test_score
0        {'criterion': 'gini', 'splitter': 'best'}         0.813909
1      {'criterion': 'gini', 'splitter': 'random'}         0.802556
2     {'criterion': 'entropy', 'splitter': 'best'}         0.810234
3   {'criterion': 'entropy', 'splitter': 'random'}         0.807492
4    {'criterion': 'log_loss', 'splitter': 'best'}         0.810234
5  {'criterion': 'log_loss', 'splitter': 'random'}         0.807492
Best parameters: {'criterion': 'gini', 'splitter': 'best'}
Best score: 0.813908786082114


In [7]:
# 创建决策树模型
clf = DecisionTreeClassifier(random_state=6, criterion='gini')

# 训练决策树模型
clf.fit(X_train, y_train)

# 执行后剪枝
path = clf.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas, impurities = path.ccp_alphas, path.impurities

clfs = []
for ccp_alpha in ccp_alphas:
    clf = DecisionTreeClassifier(random_state=42, ccp_alpha=ccp_alpha)
    clf.fit(X_train, y_train)
    clfs.append(clf)

# 计算不同剪枝参数下的准确率
accuracies = [accuracy_score(y_test, clf.predict(X_test)) for clf in clfs]
best_alpha = ccp_alphas[np.argmax(accuracies)]
best_clf = clfs[np.argmax(accuracies)]

print("Best alpha:", best_alpha)
print("Best accuracy:", accuracies[np.argmax(accuracies)])

Best alpha: 0.0001755087974549532
Best accuracy: 0.8613731081377495


In [8]:
feature_importances = best_clf.feature_importances_

# 将特征名称与对应的重要性得分一一对应起来
importances_dict = dict(zip(encoded_feature_names, feature_importances))

sizes = []
labels = []
distances = []
# 打印每个特征和对应的重要性得分
for feature, importance in importances_dict.items():
    if importance > 0: 
        sizes.append(importance)
        labels.append(feature)
        distances.append(0.01/importance)
        print(f"{feature}: {importance}")

workclass_?: 0.002292878195014191
workclass_Federal-gov: 0.002973048782111185
workclass_Private: 0.0021091075553618093
workclass_Self-emp-inc: 0.009586453712542305
occupation_Armed-Forces: 0.013349805374056738
occupation_Craft-repair: 0.003435368620353291
occupation_Exec-managerial: 0.0015585627112692517
occupation_Farming-fishing: 0.0011035498779476671
occupation_Handlers-cleaners: 0.0041636784048882805
occupation_Other-service: 0.00397938835699604
occupation_Prof-specialty: 0.0015142323157715492
occupation_Protective-serv: 0.0007465861473625799
occupation_Sales: 0.003979801111144255
occupation_Tech-support: 0.32124405701719755
relationship_Own-child: 0.06752021293414374
native.country_South: 0.0013728372206275718
native.country_Trinadad&Tobago: 0.051128750110392245
native.country_United-States: 0.007873806420248163
native.country_Vietnam: 0.20671050720874562
native.country_Yugoslavia: 0.20434391157558723
age: 0.05237420719939163
fnlwgt: 0.036639249148847086


In [12]:
# 预测
# 读入文件
test_data_df = pd.read_csv("../data/testdata.csv")
test_data_df.at[9360, 'native.country'] = '?'
features = test_data_df.columns

# 数值填充
test_data_df.replace('Preschool', 1, inplace=True)
test_data_df.replace('1st-4th', 2, inplace=True)
test_data_df.replace('5th-6th', 3, inplace=True)
test_data_df.replace('7th-8th', 4, inplace=True)
test_data_df.replace('9th', 5, inplace=True)
test_data_df.replace('10th', 6, inplace=True)
test_data_df.replace('11th', 7, inplace=True)
test_data_df.replace('12th', 8, inplace=True)
test_data_df.replace('HS-grad', 9, inplace=True)
test_data_df.replace('Some-college', 10, inplace=True)
test_data_df.replace('Assoc-voc', 11, inplace=True)
test_data_df.replace('Assoc-acdm', 12, inplace=True)
test_data_df.replace('Bachelors', 13, inplace=True)
test_data_df.replace('Prof-school', 14, inplace=True)
test_data_df.replace('Masters', 15, inplace=True)
test_data_df.replace('Doctorate', 16, inplace=True)

# 对非数值信息独热编码
encoder = OneHotEncoder()
X_pred = encoder.fit_transform(test_data_df[['workclass', 'marital.status','occupation', 'relationship', 'race', 'sex', 'native.country']]).toarray()
X_pred = pd.concat([pd.DataFrame(X_pred), test_data_df[['age', 'fnlwgt', 'education', 'education.num', 'capital.gain', 'capital.loss', 'hours.per.week']]], axis=1)
X_pred

# 数据归一化
X_pred.columns = X_pred.columns.astype(str)

# 使用StandardScaler进行Z-Score归一化
zscore_scaler = StandardScaler()
df_zscore_scaled = zscore_scaler.fit_transform(X_pred)

X_pred = pd.DataFrame(df_zscore_scaled)

# 相关性分析
correlation_matrix = X_pred.corr()

# 获取独热编码后的特征名称
encoded_feature_names = encoder.get_feature_names_out(['workclass', 'marital.status','occupation', 'relationship', 'race', 'sex', 'native.country'])
arr = np.array(['age', 'fnlwgt', 'education', 'education.num', 'capital.gain', 'capital.loss', 'hours.per.week'])
encoded_feature_names = np.append(encoded_feature_names, arr)
for i in range(len(encoded_feature_names)):
    print(i, encoded_feature_names[i])
for i in range(len(correlation_matrix)):
    for j in range(len(correlation_matrix)):
        if i >= j or i == 44: continue
        e = correlation_matrix[i][j]
        if e > 0.6 or e < -0.6: 
            print(i, encoded_feature_names[i], j, encoded_feature_names[j], e)
            if i in X_pred.columns: X_pred = X_pred.drop(i, axis=1)
                
X_pred

0 workclass_?
1 workclass_Federal-gov
2 workclass_Local-gov
3 workclass_Never-worked
4 workclass_Private
5 workclass_Self-emp-inc
6 workclass_Self-emp-not-inc
7 workclass_State-gov
8 workclass_Without-pay
9 marital.status_Divorced
10 marital.status_Married-AF-spouse
11 marital.status_Married-civ-spouse
12 marital.status_Married-spouse-absent
13 marital.status_Never-married
14 marital.status_Separated
15 marital.status_Widowed
16 occupation_?
17 occupation_Adm-clerical
18 occupation_Armed-Forces
19 occupation_Craft-repair
20 occupation_Exec-managerial
21 occupation_Farming-fishing
22 occupation_Handlers-cleaners
23 occupation_Machine-op-inspct
24 occupation_Other-service
25 occupation_Priv-house-serv
26 occupation_Prof-specialty
27 occupation_Protective-serv
28 occupation_Sales
29 occupation_Tech-support
30 occupation_Transport-moving
31 relationship_Husband
32 relationship_Not-in-family
33 relationship_Other-relative
34 relationship_Own-child
35 relationship_Unmarried
36 relationship_W

Unnamed: 0,1,2,3,4,5,6,7,8,9,10,...,81,82,83,84,85,86,88,89,90,91
0,-0.172409,-0.256716,-0.01431,0.658054,-0.185213,-0.300078,-0.19813,-0.02479,-0.398717,-0.02479,...,-0.020239,0.340884,-0.045293,-0.020239,2.365352,-0.304488,-0.036433,0.234750,-0.221203,1.577408
1,-0.172409,-0.256716,-0.01431,0.658054,-0.185213,-0.300078,-0.19813,-0.02479,-0.398717,-0.02479,...,-0.020239,0.340884,-0.045293,-0.020239,-0.347320,0.344807,1.927770,-0.148041,-0.221203,1.175093
2,-0.172409,-0.256716,-0.01431,0.658054,-0.185213,-0.300078,-0.19813,-0.02479,-0.398717,-0.02479,...,-0.020239,-2.933548,-0.045293,-0.020239,-0.200689,1.148321,-0.429274,-0.148041,-0.221203,-0.031851
3,-0.172409,-0.256716,-0.01431,0.658054,-0.185213,-0.300078,-0.19813,-0.02479,-0.398717,-0.02479,...,-0.020239,0.340884,-0.045293,-0.020239,-0.493951,0.375968,1.142089,-0.148041,-0.221203,0.370464
4,-0.172409,-0.256716,-0.01431,0.658054,-0.185213,-0.300078,-0.19813,-0.02479,-0.398717,-0.02479,...,-0.020239,0.340884,-0.045293,-0.020239,-1.373736,0.213488,-0.036433,-0.148041,-0.221203,-1.962962
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
9764,-0.172409,-0.256716,-0.01431,0.658054,-0.185213,-0.300078,-0.19813,-0.02479,-0.398717,-0.02479,...,-0.020239,0.340884,-0.045293,-0.020239,-0.787212,-0.782908,1.142089,-0.148041,-0.221203,-0.434166
9765,-0.172409,-0.256716,-0.01431,0.658054,-0.185213,-0.300078,-0.19813,-0.02479,-0.398717,-0.02479,...,-0.020239,0.340884,-0.045293,-0.020239,-0.567266,0.338517,-0.429274,-0.148041,-0.221203,-0.031851
9766,5.800159,-0.256716,-0.01431,-1.519632,-0.185213,-0.300078,-0.19813,-0.02479,-0.398717,-0.02479,...,-0.020239,0.340884,-0.045293,-0.020239,-0.200689,-0.052032,1.142089,-0.148041,-0.221203,-0.031851
9767,-0.172409,-0.256716,-0.01431,-1.519632,-0.185213,-0.300078,-0.19813,-0.02479,-0.398717,-0.02479,...,-0.020239,0.340884,-0.045293,-0.020239,-1.373736,-0.249955,-0.429274,-0.148041,-0.221203,-1.641110


In [13]:
best_clf.fit(encoded_features, train_label_df)
y_pred = best_clf.predict(X_pred)
y_pred

array([0, 1, 0, ..., 1, 0, 0])

In [14]:
# 指定要写入的文件路径和文件名
output_file = "../data/decision_tree_label.txt"

# 打开文件并写入预测结果
with open(output_file, "w") as file:
    for prediction in y_pred:
        file.write(str(prediction) + "\n")

print("预测结果已写入文件:", output_file)

预测结果已写入文件: ../data/decision_tree_label.txt
