Skip to content

Commit

Permalink
[SPARK-7833] [ML] Add python wrapper for RegressionEvaluator
Browse files Browse the repository at this point in the history
Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #6365 from harsha2010/SPARK-7833 and squashes the following commits:

923f288 [Ram Sriharsha] cleanup
7623b7d [Ram Sriharsha] python style fix
9743f83 [Ram Sriharsha] [SPARK-7833][ml] Add python wrapper for RegressionEvaluator

(cherry picked from commit 65c696e)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
  • Loading branch information
Ram Sriharsha authored and mengxr committed May 24, 2015
1 parent b06389c commit 16a6da5
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 4 deletions.
Expand Up @@ -31,14 +31,14 @@ import org.apache.spark.sql.types.DoubleType
* Evaluator for regression, which expects two input columns: prediction and label.
*/
@AlphaComponent
class RegressionEvaluator(override val uid: String)
final class RegressionEvaluator(override val uid: String)
extends Evaluator with HasPredictionCol with HasLabelCol {

def this() = this(Identifiable.randomUID("regEval"))

/**
* Param for metric name in evaluation. Supports mse, rmse, r2, mae as valid metric names.
* @group param
*/
val metricName: Param[String] = {
val allowedParams = ParamValidators.inArray(Array("mse", "rmse", "r2", "mae"))
Expand Down
Expand Up @@ -39,6 +39,7 @@ class RegressionEvaluatorSuite extends FunSuite with MLlibTestSparkContext {
val dataset = sqlContext.createDataFrame(
sc.parallelize(LinearDataGenerator.generateLinearInput(
6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))

/**
* Using the following R code to load the data, train the model and evaluate metrics.
*
Expand Down
68 changes: 66 additions & 2 deletions python/pyspark/ml/evaluation.py
Expand Up @@ -19,11 +19,11 @@

from pyspark.ml.wrapper import JavaWrapper
from pyspark.ml.param import Param, Params
from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
from pyspark.ml.util import keyword_only
from pyspark.mllib.common import inherit_doc

__all__ = ['Evaluator', 'BinaryClassificationEvaluator']
__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator']


@inherit_doc
Expand Down Expand Up @@ -148,6 +148,70 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label",
return self._set(**kwargs)


@inherit_doc
class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
    """
    Evaluator for Regression, which expects two input
    columns: prediction and label.

    The metric to compute is selected through :py:attr:`metricName`
    (mse|rmse|r2|mae); when it is not set, the default "rmse" is used.

    >>> scoreAndLabels = [(-28.98343821, -27.0), (20.21491975, 21.5),
    ...   (-25.98418959, -22.0), (30.69731842, 33.0), (74.69283752, 71.0)]
    >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"])
    ...
    >>> evaluator = RegressionEvaluator(predictionCol="raw")
    >>> evaluator.evaluate(dataset)
    2.842...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "r2"})
    0.993...
    >>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
    2.649...
    """

    # Class-level placeholder so that the param shows up in the generated doc;
    # every instance replaces it with a Param owned by that instance in __init__.
    metricName = Param(Params._dummy(), "metricName",
                       "metric name in evaluation (mse|rmse|r2|mae)")

    @keyword_only
    def __init__(self, predictionCol="prediction", labelCol="label",
                 metricName="rmse"):
        """
        __init__(self, predictionCol="prediction", labelCol="label", \
                 metricName="rmse")
        """
        super(RegressionEvaluator, self).__init__()
        # Build the Java-side RegressionEvaluator object, passing this wrapper's uid.
        self._java_obj = self._new_java_obj(
            "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid)
        #: param for metric name in evaluation (mse|rmse|r2|mae)
        self.metricName = Param(self, "metricName",
                                "metric name in evaluation (mse|rmse|r2|mae)")
        # Defaults are registered first; the keyword arguments captured by
        # @keyword_only are then applied as explicitly set values.
        self._setDefault(predictionCol="prediction", labelCol="label",
                         metricName="rmse")
        self._set(**self.__init__._input_kwargs)

    def setMetricName(self, value):
        """
        Sets the value of :py:attr:`metricName`.

        Returns this evaluator so that calls can be chained.
        """
        target = self.metricName
        self._paramMap[target] = value
        return self

    def getMetricName(self):
        """
        Gets the value of metricName or its default value.
        """
        target = self.metricName
        return self.getOrDefault(target)

    @keyword_only
    def setParams(self, predictionCol="prediction", labelCol="label",
                  metricName="rmse"):
        """
        setParams(self, predictionCol="prediction", labelCol="label",
                  metricName="rmse")
        Sets params for regression evaluator.
        """
        # Only the keyword arguments captured by @keyword_only are forwarded.
        supplied = self.setParams._input_kwargs
        return self._set(**supplied)

if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
Expand Down

0 comments on commit 16a6da5

Please sign in to comment.