[SPARK-31768][ML] add getMetrics in Evaluators
### What changes were proposed in this pull request?
Add getMetrics in Evaluators to return the corresponding Metrics instance, so users can use it to retrieve any of the metric scores. For example:

```
    val trainer = new LinearRegression
    val model = trainer.fit(dataset)
    val predictions = model.transform(dataset)

    val evaluator = new RegressionEvaluator()

    val metrics = evaluator.getMetrics(predictions)
    val rmse = metrics.rootMeanSquaredError
    val r2 = metrics.r2
    val mae = metrics.meanAbsoluteError
    val variance = metrics.explainedVariance
```

### Why are the changes needed?
Currently, Evaluator.evaluate gives access to only one metric, but most users need multiple metrics. This PR adds getMetrics to all the Evaluators, so users can obtain an instance of the corresponding Metrics class and read any of the metrics they want.
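To make the benefit concrete, here is a minimal standalone Python sketch of the pattern getMetrics enables: compute the residual statistics once, then expose many regression scores from the same object. The class name and the explained-variance formula (the common 1 - Var(residuals)/Var(labels) definition) are illustrative assumptions, not Spark's implementation.

```python
import math

class RegressionMetricsSketch:
    """Illustrative stand-in for a RegressionMetrics-style object:
    one pass over (label, prediction) pairs exposes many scores."""

    def __init__(self, labels, predictions):
        n = len(labels)
        residuals = [y - p for y, p in zip(labels, predictions)]
        mean_y = sum(labels) / n
        ss_res = sum(r * r for r in residuals)          # residual sum of squares
        ss_tot = sum((y - mean_y) ** 2 for y in labels)  # total sum of squares
        mean_r = sum(residuals) / n
        var_res = sum((r - mean_r) ** 2 for r in residuals) / n
        var_y = ss_tot / n
        self.meanAbsoluteError = sum(abs(r) for r in residuals) / n
        self.meanSquaredError = ss_res / n
        self.rootMeanSquaredError = math.sqrt(ss_res / n)
        self.r2 = 1.0 - ss_res / ss_tot
        # One common definition; Spark's exact formula may differ.
        self.explainedVariance = 1.0 - var_res / var_y

m = RegressionMetricsSketch([3.0, -0.5, 2.0, 7.0], [2.5, 0.0, 2.0, 8.0])
print(m.rootMeanSquaredError)  # ≈ 0.612
print(m.meanAbsoluteError)     # 0.5
```

With evaluate alone, each of these scores would require a separate pass over the predictions; a single metrics object amortizes that cost.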

### Does this PR introduce _any_ user-facing change?
Yes. Add getMetrics in Evaluators.
For example:
```
  /**
   * Get a RegressionMetrics, which can be used to get any of the regression
   * metrics such as rootMeanSquaredError, meanSquaredError, etc.
   *
   * @param dataset a dataset that contains labels/observations and predictions.
   * @return RegressionMetrics
   */
  @Since("3.1.0")
  def getMetrics(dataset: Dataset[_]): RegressionMetrics
```

### How was this patch tested?
Added new unit tests.

Closes #28590 from huaxingao/getMetrics.

Authored-by: Huaxin Gao <huaxing@us.ibm.com>
Signed-off-by: Sean Owen <srowen@gmail.com>
huaxingao authored and srowen committed May 24, 2020
1 parent cf7463f commit d0fe433
Showing 14 changed files with 905 additions and 591 deletions.
@@ -98,6 +98,24 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va

  @Since("2.0.0")
  override def evaluate(dataset: Dataset[_]): Double = {
    val metrics = getMetrics(dataset)
    val metric = $(metricName) match {
      case "areaUnderROC" => metrics.areaUnderROC()
      case "areaUnderPR" => metrics.areaUnderPR()
    }
    metrics.unpersist()
    metric
  }

  /**
   * Get a BinaryClassificationMetrics, which can be used to get binary classification
   * metrics such as areaUnderROC and areaUnderPR.
   *
   * @param dataset a dataset that contains labels/observations and predictions.
   * @return BinaryClassificationMetrics
   */
  @Since("3.1.0")
  def getMetrics(dataset: Dataset[_]): BinaryClassificationMetrics = {
    val schema = dataset.schema
    SchemaUtils.checkColumnTypes(schema, $(rawPredictionCol), Seq(DoubleType, new VectorUDT))
    SchemaUtils.checkNumericType(schema, $(labelCol))
@@ -119,13 +137,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va
      case Row(rawPrediction: Double, label: Double, weight: Double) =>
        (rawPrediction, label, weight)
    }
    val metrics = new BinaryClassificationMetrics(scoreAndLabelsWithWeights, $(numBins))
    val metric = $(metricName) match {
      case "areaUnderROC" => metrics.areaUnderROC()
      case "areaUnderPR" => metrics.areaUnderPR()
    }
    metrics.unpersist()
    metric
    new BinaryClassificationMetrics(scoreAndLabelsWithWeights, $(numBins))
  }

@Since("1.5.0")
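As a standalone illustration of what a BinaryClassificationMetrics-style object can compute from raw (score, label) pairs, here is a hedged Python sketch of areaUnderROC using the pairwise-comparison (Mann-Whitney) formulation. It is not Spark's implementation, which additionally supports instance weights and curve binning via numBins.

```python
def area_under_roc(scores_and_labels):
    """Fraction of (positive, negative) pairs where the positive example
    receives the higher score (ties count half): equivalent to ROC AUC.
    O(n_pos * n_neg) sketch; real implementations sort once instead."""
    pos = [s for s, y in scores_and_labels if y == 1.0]
    neg = [s for s, y in scores_and_labels if y == 0.0]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0
               for p in pos for q in neg)
    return wins / (len(pos) * len(neg))

pairs = [(0.9, 1.0), (0.6, 0.0), (0.4, 1.0), (0.2, 0.0)]
print(area_under_roc(pairs))  # 0.75
```

One misranked pair out of four yields 0.75; a perfect ranking yields 1.0.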
