In [None]:
import numpy as np
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, FloatType, DoubleType
from pyspark.sql.functions import col, udf
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [2]:
path = "/home/mahjoubi/Documents/github/world_cup_2018/test/transform/regression_model/2014"

In [3]:
regression_models = ["linear_regression", "decision_tree", "random_forest", "gbt_regressor"]

In [4]:
schema =  StructType([
            StructField("id", DoubleType(), True),
            StructField("label", DoubleType(), True),
            StructField("prediction", DoubleType(), True)])

In [18]:
def get_prediction(x, threshold):
    if float(np.abs(x)) <= threshold:
        return 0.0
    elif x > threshold:
        return 2.0
    elif x < -1.0 * threshold:
        return 1.0
    
udf_prediction = udf(lambda x: get_prediction(x, threshold), DoubleType())

In [19]:
evaluator = MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="label", metricName="accuracy")

In [28]:
dic_data = {}
for threshold in [0.1, 0.25, 0.5, 0.75, 1.0]:
    print("threshold: {0}".format(threshold))
    for regression in ["linear_regression", "decision_tree", "random_forest", "gbt_regressor"]:
        udf_prediction = udf(lambda x: get_prediction(x, threshold), DoubleType())
        transform = (spark.read.csv(os.path.join(path, regression), header=True, schema=schema)
                .withColumnRenamed("prediction", "diff_point")
                .withColumn("prediction", udf_prediction(col("diff_point"))))
        dic_data[regression] = transform.withColumnRenamed("diff_point", regression) #.drop("prediction")
        #     dic_data[regression].show(5)
        print("{0}: {1}, ".format(regression, evaluator.evaluate(transform))),
    print("")
        #     dic_data[regression].groupBy("prediction").count().show()

threshold: 0.1
linear_regression: 0.579150579151,  decision_tree: 0.58832046332,  random_forest: 0.601833976834,  gbt_regressor: 0.616312741313,  
threshold: 0.25
linear_regression: 0.583976833977,  decision_tree: 0.58832046332,  random_forest: 0.598938223938,  gbt_regressor: 0.61583011583,  
threshold: 0.5
linear_regression: 0.560328185328,  decision_tree: 0.585424710425,  random_forest: 0.574806949807,  gbt_regressor: 0.602799227799,  
threshold: 0.75
linear_regression: 0.513513513514,  decision_tree: 0.553088803089,  random_forest: 0.540057915058,  gbt_regressor: 0.583494208494,  
threshold: 1.0
linear_regression: 0.464285714286,  decision_tree: 0.422297297297,  random_forest: 0.492760617761,  gbt_regressor: 0.535714285714,  


In [25]:
keys = dic_data.keys()
data = dic_data[keys[0]].drop("prediction")
for key in keys[1:]:
    data = data.join(dic_data[key].drop("label").drop("prediction"), on="id")

In [26]:
data.show(5)

+------+-----+-------------------+------------------+-------------------+-------------------+
|    id|label|  linear_regression|     decision_tree|      gbt_regressor|      random_forest|
+------+-----+-------------------+------------------+-------------------+-------------------+
|1659.0|  0.0|  2.046165815269463|1.9066666666666667| 1.6222924102693044|  2.117691399663445|
| 295.0|  0.0| 0.2668095739650328|0.5215311004784688| 0.5137104972870641| 0.3967606393942596|
| 888.0|  0.0| 0.2275194770319885|0.5215311004784688|0.17202622632228662|0.19120766749424767|
|1103.0|  0.0| 1.3884214994383401| 0.992867332382311| 1.1842599089851558| 1.2288847218732601|
|1533.0|  0.0|0.33472604811520623|0.5215311004784688| 0.6654867996158518| 0.4784945178222387|
+------+-----+-------------------+------------------+-------------------+-------------------+
only showing top 5 rows



In [36]:
threshold = 0.25
udf_prediction = udf(lambda x: get_prediction(x, threshold), DoubleType())
udf_mean = udf(lambda x, y, z, t: (x + y + z + t) / 4.0, DoubleType())
transform = (data
 .withColumn("diff_points", udf_mean(col("linear_regression"), col("decision_tree"), 
                                     col("gbt_regressor"), col("random_forest")))
 .withColumn("prediction", udf_prediction(col("diff_points"))))

In [37]:
evaluator.evaluate(transform)

0.6023166023166023

In [38]:
transform.groupBy("prediction").count().show()

+----------+-----+
|prediction|count|
+----------+-----+
|       0.0|  289|
|       1.0|  193|
|       2.0| 1590|
+----------+-----+



In [39]:
transform.groupBy("label").count().show()

+-----+-----+
|label|count|
+-----+-----+
|  0.0|  515|
|  1.0|  391|
|  2.0| 1166|
+-----+-----+

