In [None]:
# 随机森林使用有放回的采样方法
# 特征可以使用全部特征或者是平方根的特征个数

In [541]:
import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score

In [542]:
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

In [543]:
clf = RandomForestClassifier()
clf.fit(X_train, y_train)

RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=None, max_features='auto', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=1, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=10, n_jobs=1,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False)

In [544]:
# 参数详解：
# min_sample_leaf：分裂后每个子结点包含的最小样本数量
# min_sample_split：分裂之前每个结点包含的最小样本数量
# max_features：默认情况为特征总数的开平方
# n_estimators：树分类器的数量，默认情况下为10
# criterion：模型情况下选择gini
# max_depth：树的最大深度

## 交叉验证

In [545]:
# KFold：K折交叉验证
# RepeatedKFold：使用p次K折交叉验证
# ShuffleSplit：数据打乱后进行划分
# StratifiedKFold：训练集与原始数据分布一致
# StratifiedShuffleSplit：划分后的训练集数据分不一致，但不一定与原始数据一致

In [546]:
# """KFold"""
# from sklearn.model_selection import KFold
# kf = KFold(n_splits=5)
# for train_index, test_index in kf.split(X):
#     X_train, y_train = X[train_index], y[train_index]
#     X_test, y_test = X[test_index], y[test_index]

In [547]:
# """RepeatedKFold"""
# from sklearn.model_selection import RepeatedKFold
# rkf = RepeatedKFold(n_splits=5, n_repeats=2, random_state=42)
# for train_index, test_index in rkf.split(X):
#     print(train_index, test_index)

In [548]:
# """ShuffleSplit"""
# from sklearn.model_selection import ShuffleSplit
# sp = ShuffleSplit(n_splits=5, random_state=24, test_size=0.2)
# for train_index, test_index in sp.split(X):
#     print(train_index, test_index)

In [549]:
# """StratifiedKFold"""
# from sklearn.model_selection import StratifiedKFold
# sk = StratifiedKFold(n_splits=5, shuffle=True)
# for train_index, test_index in sk.split(X, y):
#     X_train, y_train = X[train_index], y[train_index]
#     X_test, y_test = X[test_index], y[test_index]

In [550]:
# """StratifiedShuffleSplit"""
# from sklearn.model_selection import StratifiedShuffleSplit
# ss = StratifiedShuffleSplit(n_splits=5, random_state=24)
# for train_index, test_index in ss.split(X, y):
#     X_train, y_train = X[train_index], y[train_index]
#     X_test, y_test = X[test_index], y[test_index]

## 交叉验证进行模型评估

In [None]:
cross_value_score：cv取整数的时候默认进行KFold
cross_validate：可以通过列表或者字典的方法进行传入
cross_val_predict：与cross_value_score原理相同，返回交叉验证之后的输出值

In [554]:
"""cross_value_score"""
from sklearn.model_selection import cross_val_score
scores = cross_val_score(clf, iris.data, iris.target, cv=5)
# scores = cross_val_score(clf, iris.data, iris.target, cv=5, scoring='f1_micro')
print(scores)

[0.96666667 0.96666667 0.93333333 0.96666667 1.        ]


In [559]:
"""cross_validate"""
from sklearn.model_selection import cross_validate
scoring = ['precision_macro', 'recall_macro', 'f1_micro']
scores = cross_validate(clf, iris.data, iris.target, cv=5, scoring=scoring)
"""返回模型训练时间，评估时间，实用结果"""
print(scores.keys())
print(scores['test_f1_micro'])

dict_keys(['fit_time', 'score_time', 'test_precision_macro', 'train_precision_macro', 'test_recall_macro', 'train_recall_macro', 'test_f1_micro', 'train_f1_micro'])
[0.96666667 0.96666667 0.93333333 0.9        1.        ]


In [562]:
"""cross_val_predict"""
from sklearn.model_selection import cross_val_predict
predict = cross_val_predict(clf, iris.data, iris.target, cv=5)
print(predict)
print(accuracy_score(predict, iris.target))

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 1 1 1
 1 1 1 2 1 1 1 1 1 2 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 2 2 2 2 2 2 1 2 2 2 2
 2 2 2 2 2 2 2 2 1 2 2 2 2 2 2 2 2 2 1 2 2 2 1 2 2 2 2 2 2 2 2 2 2 2 2 2 2
 2 2]
0.9533333333333334
