In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler,StringIndexer
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.sql.functions import col

In [2]:
spark = SparkSession.builder.appName("TennisRandomForest").getOrCreate()

In [3]:
df = spark.read.csv("./game_based_df.csv", header=True, inferSchema=True).sample(fraction=0.5, seed=1234)

In [4]:
indexer = StringIndexer(inputCol="SetWinP", outputCol="label")
df = indexer.fit(df).transform(df)

In [5]:
feature_columns = ['G_P1GmWon', 'G_P2GmWon', 'G_Ser', 'G_P1Ace', 'G_P2Ace', 'G_P1Wn', 
                   'G_P2Wn', 'G_P1Df', 'G_P2Df', 'G_P1UE', 'G_P2UE', 'G_P1NP', 'G_P2NP', 
                   'G_P1NPW', 'G_P2NPW', 'G_P1BP', 'G_P2BP', 'G_P1BPWon', 'G_P2BPWon', 
                   'G_P1FW', 'G_P1BW', 'G_P2FW', 'G_P2BW', 'G_P1SerW', 'G_P2SerW', 
                   'G_avg_SerSp']
label_column = "label"

In [6]:
# 数据预处理：将特征列合并为一个特征向量
assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')
df = assembler.transform(df)

# 将数据切分为训练集和测试集
train_data, test_data = df.randomSplit([0.8, 0.2], seed=1234)

# 初始化随机森林分类器
rf = RandomForestClassifier(labelCol=label_column, featuresCol='features')

# 使用二分类评估器
evaluator = BinaryClassificationEvaluator(labelCol=label_column)

# 设置交叉验证的参数网格
param_grid = ParamGridBuilder().addGrid(rf.numTrees, [10, 20]).addGrid(rf.maxDepth, [5, 10]).build()

# 配置交叉验证
crossval = CrossValidator(estimator=rf, 
                          estimatorParamMaps=param_grid,
                          evaluator=evaluator,
                          numFolds=3)  # 3折交叉验证


In [7]:
# 训练模型
cv_model = crossval.fit(train_data)

# 使用交叉验证后的最佳模型进行预测
predictions = cv_model.transform(test_data)

In [11]:
# 评估模型
auc = evaluator.evaluate(predictions)
print(f"AUC: {auc}")

# 打印最佳模型的参数
best_model = cv_model.bestModel
print("Best Model Parameters:")
print(f"Number of Trees: {best_model.getNumTrees}")
print(f"Max Depth: {best_model.getMaxDepth}")
print(f"Max Bins: {best_model.getMaxBins}")

AUC: 0.8345238066364642
Best Model Parameters:
Number of Trees: 20
Max Depth: <bound method _DecisionTreeParams.getMaxDepth of RandomForestClassificationModel: uid=RandomForestClassifier_f331438830f7, numTrees=20, numClasses=2, numFeatures=26>
Max Bins: <bound method _DecisionTreeParams.getMaxBins of RandomForestClassificationModel: uid=RandomForestClassifier_f331438830f7, numTrees=20, numClasses=2, numFeatures=26>
