This notebook contains the definitions and training pipelines of the machine learning algorithms applied in the paper.
For each algorithm, a model will be created with the same hyperparameters as those declared in the paper (the default hyperparameters of <code>sklearn</code>).

This model will be trained and tested on the train/test split provided at the beginning of the SeqScout procedure. This way, its scores are as comparable as possible with the values reported in the paper.


Then, a hyperparameter search will be conducted on each model via cross-validation on the training split. The best-performing configuration will be retrained on the full training set and evaluated on the test set, so that it can be compared with the default model.
Each model will be compared on accuracy, weighted precision and weighted recall. Moreover, a confusion matrix will be produced for each best-performing model.
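The tuning protocol described above (a grid search scored by k-fold cross-validation on the training split only, followed by a refit of the winning configuration on the whole training split) can be sketched in plain Python. This is a minimal, library-agnostic illustration under stated assumptions, not the notebook's implementation: in PySpark the same idea is packaged as `CrossValidator` and `ParamGridBuilder` in `pyspark.ml.tuning`. The helpers `k_fold_indices` and `grid_search_cv`, and the toy threshold model, are hypothetical names introduced only for this example.

```python
import itertools
import random


def k_fold_indices(n, k, seed=0):
    """Shuffle row indices 0..n-1 and deal them into k roughly equal folds."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)
    return [idx[i::k] for i in range(k)]


def grid_search_cv(rows, param_grid, fit, score, k=3, seed=0):
    """Return (best_params, best_mean_score) over every combination in param_grid.

    Only the training rows are passed in; the test split is never touched here.
    fit(params, train_rows) -> model; score(model, val_rows) -> float, higher is better.
    """
    folds = k_fold_indices(len(rows), k, seed)
    names = sorted(param_grid)
    best_params, best_score = None, float("-inf")
    for values in itertools.product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        fold_scores = []
        for val_idx in folds:
            held_out = set(val_idx)
            train_rows = [r for j, r in enumerate(rows) if j not in held_out]
            val_rows = [rows[j] for j in val_idx]
            fold_scores.append(score(fit(params, train_rows), val_rows))
        mean_score = sum(fold_scores) / len(fold_scores)
        if mean_score > best_score:  # strict ">" keeps the first config on ties
            best_params, best_score = params, mean_score
    return best_params, best_score


# Hypothetical toy model: predict class 1 when x > threshold (fitting ignores the data).
def fit_threshold(params, train_rows):
    return params["threshold"]


def score_threshold(threshold, val_rows):
    # Fraction of validation rows whose prediction matches the label (accuracy).
    return sum((x > threshold) == bool(y) for x, y in val_rows) / len(val_rows)


toy_rows = [(x, int(x > 5)) for x in range(12)]
best, cv_acc = grid_search_cv(toy_rows, {"threshold": [2, 5, 8]}, fit_threshold, score_threshold)
final_model = fit_threshold(best, toy_rows)  # refit the winner on the full training split
print(best, cv_acc)  # {'threshold': 5} 1.0
```

The key design point is that the folds are carved out of the training rows alone, so the test split is used exactly once, for the final comparison between the default and the tuned model.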

In [13]:
from pyspark.sql import SparkSession
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.mllib.evaluation import MulticlassMetrics
from utils import load_df  # local helper module that loads the saved DataFrames

## Preliminaries and dataset loading

In [9]:
spark = SparkSession.builder.appName("RocketLeagueL").getOrCreate()

path = "./"
train = load_df(path, "encoded_df", spark)
test = load_df(path, "encoded_test", spark)

# Evaluators for the compared metrics; the weighted variants average per-class scores,
# weighting each class by its frequency among the true labels
accuracy = MulticlassClassificationEvaluator(labelCol='class', predictionCol='prediction', metricName='accuracy')
precision = MulticlassClassificationEvaluator(labelCol='class', predictionCol='prediction', metricName='weightedPrecision')
recall = MulticlassClassificationEvaluator(labelCol='class', predictionCol='prediction', metricName='weightedRecall')

## Decision Tree

In [15]:
# sklearn's default max_depth is None (unbounded), but Spark's tree implementation
# only supports maxDepth <= 30, so 30 is the closest available setting
dtc = DecisionTreeClassifier(featuresCol='features', labelCol='class', maxDepth=30)

dtcf = dtc.fit(train)

predictions = dtcf.transform(test)

predictions.select("prediction", "class").show(5)

print("Accuracy: ")
print(accuracy.evaluate(predictions))
print("Weighted precision: ")
print(precision.evaluate(predictions))
print("Weighted recall: ")
print(recall.evaluate(predictions))

+----------+-----+
|prediction|class|
+----------+-----+
|       5.0|    5|
|       6.0|    6|
|       6.0|    6|
|       0.0|    6|
|       6.0|    6|
+----------+-----+
only showing top 5 rows

Accuracy: 
0.7333333333333333
Weighted precision: 
0.8226527149321267
Weighted recall: 
0.7333333333333333
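Accuracy and weighted recall printed above are identical, and this is expected rather than a coincidence: weighting each class's recall $TP_c / n_c$ by its share $n_c / N$ of the true labels collapses the sum to $\sum_c TP_c / N$, which is exactly accuracy. The plain-Python sketch below makes the definitions concrete and builds a confusion matrix along the way. It assumes Spark's weighted metrics follow this label-frequency weighting, with precision 0 for a class that is never predicted; `multiclass_metrics` is a hypothetical helper, and the input is only the five rows from the preview above, not the full test set.

```python
from collections import Counter


def multiclass_metrics(pairs):
    """Return (confusion, accuracy, weighted_precision, weighted_recall).

    pairs: iterable of (prediction, label). Each class's precision and recall is
    weighted by that class's share of the true labels; a class that is never
    predicted gets precision 0 (assumed to mirror Spark's convention).
    """
    pairs = list(pairs)
    n = len(pairs)
    confusion = Counter(pairs)  # (prediction, label) -> count
    label_counts = Counter(label for _, label in pairs)
    pred_counts = Counter(pred for pred, _ in pairs)
    accuracy = sum(confusion[(c, c)] for c in label_counts) / n
    w_precision = w_recall = 0.0
    for c, n_c in label_counts.items():
        tp = confusion[(c, c)]
        precision_c = tp / pred_counts[c] if pred_counts[c] else 0.0
        recall_c = tp / n_c
        w_precision += precision_c * n_c / n
        w_recall += recall_c * n_c / n  # sums to sum(tp) / n, i.e. accuracy
    return confusion, accuracy, w_precision, w_recall


# The five (prediction, class) rows from the preview above, not the full test set.
pairs = [(5, 5), (6, 6), (6, 6), (0, 6), (6, 6)]
confusion, acc, w_prec, w_rec = multiclass_metrics(pairs)
print(acc, round(w_prec, 4), round(w_rec, 4))  # 0.8 1.0 0.8
```

On the notebook's actual predictions, the same confusion matrix can be obtained from the already-imported `MulticlassMetrics`, whose `confusionMatrix()` method works on an RDD of `(prediction, label)` pairs of floats.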
