New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-9005][MLlib]Fix RegressionMetrics computation of explainedVariance #7361
Changes from 6 commits
4c4e56f
c235de0
bde9761
db8605a
08a0e1b
1a3d098
f1112fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -53,14 +53,21 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend | |
) | ||
summary | ||
} | ||
private lazy val SSerr = math.pow(summary.normL2(1), 2) | ||
private lazy val SStot = summary.variance(0) * (summary.count - 1) | ||
private lazy val SSreg = { | ||
val yMean = summary.mean(0) | ||
predictionAndObservations.map { | ||
case (prediction, _) => math.pow(prediction - yMean, 2) | ||
}.sum() | ||
} | ||
|
||
/** | ||
* Returns the explained variance regression score. | ||
* explainedVariance = 1 - variance(y - \hat{y}) / variance(y) | ||
* Reference: [[http://en.wikipedia.org/wiki/Explained_variation]] | ||
* Returns the variance explained by regression. | ||
* @see [[https://en.wikipedia.org/wiki/Fraction_of_variance_unexplained]] | ||
*/ | ||
def explainedVariance: Double = { | ||
1 - summary.variance(1) / summary.variance(0) | ||
SSreg / summary.count | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Before this was really fraction of variance unexplained, and that's what your new reference says, but then shouldn't this be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I updated the reference to one which is about explained/unexplained variance in the context of regression and which also provides explicit formulas for calculation. The calculation before this PR doesn't seem to correspond to anything on either reference. When the regression model is unbiased (e.g. has an intercept term), the sum of squares can be partitioned (SStot = SSreg + SSerr) and the fraction of variance explained (SSreg / SStot) is R^2. The same reference defines explained variance as the variance of the model's predictions (SSreg / n), which I think is more appropriate given that this method is called There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Agree with all that. The proportion is kind of redundant. This is going beyond fixing the formula, and also 'fixing' it to return something more consistent with its name. It's an experimental class so seems legitimate. +1 |
||
} | ||
|
||
/** | ||
|
@@ -76,23 +83,22 @@ class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extend | |
* expected value of the squared error loss or quadratic loss. | ||
*/ | ||
def meanSquaredError: Double = { | ||
val rmse = summary.normL2(1) / math.sqrt(summary.count) | ||
rmse * rmse | ||
SSerr / summary.count | ||
} | ||
|
||
/** | ||
* Returns the root mean squared error, which is defined as the square root of | ||
* the mean squared error. | ||
*/ | ||
def rootMeanSquaredError: Double = { | ||
summary.normL2(1) / math.sqrt(summary.count) | ||
math.sqrt(this.meanSquaredError) | ||
} | ||
|
||
/** | ||
* Returns R^2^, the coefficient of determination. | ||
* Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] | ||
* Returns R^2^, the unadjusted coefficient of determination. | ||
* @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] | ||
*/ | ||
def r2: Double = { | ||
1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1)) | ||
1 - SSerr / SStot | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you put the formula here please?