[SPARK-9005][MLlib]Fix RegressionMetrics computation of explainedVariance #7361

Closed
Changes from 6 commits
@@ -53,14 +53,21 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
     )
     summary
   }
+  private lazy val SSerr = math.pow(summary.normL2(1), 2)
+  private lazy val SStot = summary.variance(0) * (summary.count - 1)
+  private lazy val SSreg = {
+    val yMean = summary.mean(0)
+    predictionAndObservations.map {
+      case (prediction, _) => math.pow(prediction - yMean, 2)
+    }.sum()
+  }

   /**
-   * Returns the explained variance regression score.
-   * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
-   * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
+   * Returns the variance explained by regression.
+   * @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]]
Member commented: Could you put the formula here please?

    */
   def explainedVariance: Double = {
-    1 - summary.variance(1) / summary.variance(0)
+    SSreg / summary.count
Member commented: Before this was really fraction of variance unexplained, and that's what your new reference says, but then shouldn't this be SSreg / SStot?

Contributor (author) replied: I updated the reference to one that covers explained/unexplained variance in the context of regression and that also provides explicit formulas for the calculation. The calculation before this PR doesn't seem to correspond to anything in either reference.

When the regression model is unbiased (e.g. has an intercept term), the sum of squares can be partitioned (SStot = SSreg + SSerr) and the fraction of variance explained (SSreg / SStot) is R^2.

The same reference defines explained variance as the variance of the model's predictions (SSreg / n), which I think is more appropriate given that this method is called explainedVariance, not proportionVarianceExplained (which would also be a bit redundant with r2).
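The SSreg / n definition is easy to check outside Spark. A minimal sketch over plain Scala collections (the object name is hypothetical; the four data points are the ones used in the updated "unbiased predictor" test below):

```scala
object ExplainedVarianceSketch {
  // Predictions and observations from the unbiased-predictor test case.
  val preds = Seq(2.25, -0.25, 1.75, 7.75)
  val obs = Seq(3.0, -0.5, 2.0, 7.0)

  val n = obs.length
  val yMean = obs.sum / n // mean of the observations

  // SSreg: squared deviations of the predictions around the observed mean.
  val ssReg = preds.map(p => math.pow(p - yMean, 2)).sum

  // explainedVariance = SSreg / n, the variance of the model's predictions.
  val explainedVariance = ssReg / n // 8.796875, matching the R check below
}
```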

Member replied: Agree with all that. The proportion is kind of redundant. This is going beyond fixing the formula, and also 'fixing' it to return something more consistent with its name. It's an experimental class so seems legitimate. +1
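For context on the two R^2 forms being debated: the partition SStot = SSreg + SSerr is guaranteed when the predictions come from a least-squares fit with an intercept; it need not hold for arbitrary hand-picked prediction/observation pairs. A Spark-free sketch with a genuine OLS fit (made-up data, hypothetical object name):

```scala
object PartitionSketch {
  // Made-up data points, for illustration only.
  val x = Seq(1.0, 2.0, 3.0, 4.0)
  val y = Seq(2.0, 3.0, 5.0, 6.0)

  val n = x.length
  val xMean = x.sum / n
  val yMean = y.sum / n

  // Ordinary least squares with an intercept: y ≈ a + b * x.
  val b = x.zip(y).map { case (xi, yi) => (xi - xMean) * (yi - yMean) }.sum /
    x.map(xi => math.pow(xi - xMean, 2)).sum
  val a = yMean - b * xMean
  val preds = x.map(xi => a + b * xi)

  val ssTot = y.map(yi => math.pow(yi - yMean, 2)).sum
  val ssReg = preds.map(p => math.pow(p - yMean, 2)).sum
  val ssErr = preds.zip(y).map { case (p, yi) => math.pow(yi - p, 2) }.sum

  // For an intercept-fitted OLS model the partition holds:
  // ssTot == ssReg + ssErr, so 1 - ssErr / ssTot == ssReg / ssTot.
  val r2FromErr = 1.0 - ssErr / ssTot
  val r2FromReg = ssReg / ssTot
}
```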

   }

   /**
@@ -76,23 +83,22 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend
    * expected value of the squared error loss or quadratic loss.
    */
   def meanSquaredError: Double = {
-    val rmse = summary.normL2(1) / math.sqrt(summary.count)
-    rmse * rmse
+    SSerr / summary.count
   }

   /**
    * Returns the root mean squared error, which is defined as the square root of
    * the mean squared error.
    */
   def rootMeanSquaredError: Double = {
-    summary.normL2(1) / math.sqrt(summary.count)
+    math.sqrt(this.meanSquaredError)
   }

   /**
-   * Returns R^2^, the coefficient of determination.
-   * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+   * Returns R^2^, the unadjusted coefficient of determination.
+   * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
    */
   def r2: Double = {
-    1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1))
+    1 - SSerr / SStot
   }
 }
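The r2 change above is a behavior-preserving refactor onto the new lazy vals: summary.normL2(1) squared is exactly SSerr, and the unbiased sample variance times (count - 1) is exactly SStot. A sketch of that identity on the test data, substituting plain collections for the RDD summarizer (hypothetical object name):

```scala
object R2RefactorCheck {
  val preds = Seq(2.25, -0.25, 1.75, 7.75)
  val obs = Seq(3.0, -0.5, 2.0, 7.0)
  val n = obs.length
  val yMean = obs.sum / n

  // L2 norm of the residual column, as the summarizer would report it;
  // its square recovers SSerr.
  val residualNormL2 =
    math.sqrt(preds.zip(obs).map { case (p, y) => math.pow(y - p, 2) }.sum)
  val ssErr = math.pow(residualNormL2, 2)

  // Unbiased sample variance of the observations times (n - 1) recovers SStot.
  val sampleVariance = obs.map(y => math.pow(y - yMean, 2)).sum / (n - 1)
  val ssTot = sampleVariance * (n - 1)

  val r2 = 1.0 - ssErr / ssTot // ≈ 0.9571734, matching the test suite
}
```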
@@ -23,24 +23,85 @@ import org.apache.spark.mllib.util.TestingUtils._

class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {

-  test("regression metrics") {
+  test("regression metrics for unbiased (includes intercept term) predictor") {
     /* Verify results in R:
        preds = c(2.25, -0.25, 1.75, 7.75)
        obs = c(3.0, -0.5, 2.0, 7.0)

        SStot = sum((obs - mean(obs))^2)
        SSreg = sum((preds - mean(obs))^2)
        SSerr = sum((obs - preds)^2)

        explainedVariance = SSreg / length(obs)
        explainedVariance
        > [1] 8.796875
        meanAbsoluteError = mean(abs(preds - obs))
        meanAbsoluteError
        > [1] 0.5
        meanSquaredError = mean((preds - obs)^2)
        meanSquaredError
        > [1] 0.3125
        rmse = sqrt(meanSquaredError)
        rmse
        > [1] 0.559017
        r2 = 1 - SSerr / SStot
        r2
        > [1] 0.9571734
      */
     val predictionAndObservations = sc.parallelize(
       Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2)
     val metrics = new RegressionMetrics(predictionAndObservations)
     assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5,
       "explained variance regression score mismatch")
     assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
     assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch")
     assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5,
       "root mean squared error mismatch")
     assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch")
   }

   test("regression metrics for biased (no intercept term) predictor") {
     /* Verify results in R:
        preds = c(2.5, 0.0, 2.0, 8.0)
        obs = c(3.0, -0.5, 2.0, 7.0)

        SStot = sum((obs - mean(obs))^2)
        SSreg = sum((preds - mean(obs))^2)
        SSerr = sum((obs - preds)^2)

        explainedVariance = SSreg / length(obs)
        explainedVariance
        > [1] 8.859375
        meanAbsoluteError = mean(abs(preds - obs))
        meanAbsoluteError
        > [1] 0.5
        meanSquaredError = mean((preds - obs)^2)
        meanSquaredError
        > [1] 0.375
        rmse = sqrt(meanSquaredError)
        rmse
        > [1] 0.6123724
        r2 = 1 - SSerr / SStot
        r2
        > [1] 0.9486081
      */
     val predictionAndObservations = sc.parallelize(
       Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
     val metrics = new RegressionMetrics(predictionAndObservations)
-    assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
+    assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5,
       "explained variance regression score mismatch")
     assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
     assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
     assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
       "root mean squared error mismatch")
-    assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch")
+    assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch")
   }
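The biased case is also where explainedVariance and r2 visibly diverge: without an intercept the sum-of-squares partition fails, and SSreg + SSerr overshoots SStot. A small Spark-free check on the same numbers (hypothetical object name):

```scala
object BiasedPredictorCheck {
  val preds = Seq(2.5, 0.0, 2.0, 8.0)
  val obs = Seq(3.0, -0.5, 2.0, 7.0)
  val yMean = obs.sum / obs.length // 2.875

  val ssTot = obs.map(y => math.pow(y - yMean, 2)).sum                     // 29.1875
  val ssReg = preds.map(p => math.pow(p - yMean, 2)).sum                   // 35.4375
  val ssErr = preds.zip(obs).map { case (p, y) => math.pow(y - p, 2) }.sum // 1.5

  // The partition does not hold here: 35.4375 + 1.5 != 29.1875.
  val partitionHolds = math.abs(ssTot - (ssReg + ssErr)) < 1e-9 // false

  val explainedVariance = ssReg / obs.length // 8.859375, as in the test
}
```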

   test("regression metrics with complete fitting") {
     val predictionAndObservations = sc.parallelize(
       Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
     val metrics = new RegressionMetrics(predictionAndObservations)
-    assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
+    assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5,
       "explained variance regression score mismatch")
     assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
     assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
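The new expected value for the complete-fitting case follows directly from the definition: with predictions equal to observations, SSreg / n reduces to the population variance of the observations. A one-object sketch (hypothetical name):

```scala
object PerfectFitCheck {
  val obs = Seq(3.0, 0.0, 2.0, 8.0)
  val n = obs.length
  val mean = obs.sum / n // 3.25

  // With a perfect fit, preds == obs, so SSreg / n is just the
  // population variance of the observations.
  val explainedVariance = obs.map(y => math.pow(y - mean, 2)).sum / n // 8.6875
}
```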
python/pyspark/mllib/evaluation.py (1 addition, 1 deletion)
@@ -82,7 +82,7 @@ class RegressionMetrics(JavaModelWrapper):
     ...     (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
     >>> metrics = RegressionMetrics(predictionAndObservations)
     >>> metrics.explainedVariance
-    0.95...
+    8.859...
     >>> metrics.meanAbsoluteError
     0.5...
     >>> metrics.meanSquaredError