[SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label #12467
@BenFradet Thanks for this PR. I think we can not make
@yanboliang thanks for your input, I reverted the affected commits
@BenFradet I'm very sorry that I did not notice that #10355 had been merged. Please ignore my last comments; they are no longer valid after #10355. Would you mind adding RFormula support back? Thanks!
@yanboliang yup, no problem.
Test build #56068 has finished for PR 12467 at commit
Test build #56220 has finished for PR 12467 at commit
@@ -290,4 +291,18 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
    val newModel = testDefaultReadWrite(model)
    checkModelData(model, newModel)
  }

  test("should support all NumericType labels") {
Can we use `MLTestingUtils.checkNumericTypes` to test this? It will eliminate some redundant code.
It'd work except for the expected exception when dealing with a dataframe containing string labels, because the label column gets indexed by `RFormula`'s `fit`.
Consequently, an exception is thrown by `StringIndexer`.
What I could do is add a `validateSchema` to `RFormula` (called at the beginning of the `fit` method) checking that the label column is of numeric type; then I could use `MLTestingUtils.checkNumericTypes`.
What do you think?
Or simply just:
val schema = dataset.schema
SchemaUtils.checkNumericType(schema, $(labelCol))
at the beginning of `RFormula`'s `fit` method.
After reviewing the suite, I don't think the same tests apply since RFormula also accepts string labels.
Consequently, I think it's best as is.
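To illustrate why a numeric-only check does not fit here, the following is a minimal sketch, assuming `spark-sql` is on the classpath. `labelIsNumeric` is a hypothetical helper written for this illustration, not a Spark API: it accepts every numeric label type but rejects a string label, which `RFormula` would otherwise handle by indexing.

```scala
import org.apache.spark.sql.types._

object LabelTypeCheckSketch {
  // Hypothetical helper (not part of Spark): true when the label column
  // has a type that extends NumericType.
  def labelIsNumeric(schema: StructType, labelCol: String): Boolean =
    schema(labelCol).dataType.isInstanceOf[NumericType]

  // Builds a one-column schema whose "label" field has the given type.
  def schemaWithLabel(labelType: DataType): StructType =
    StructType(Seq(StructField("label", labelType)))

  def main(args: Array[String]): Unit = {
    val numericTypes: Seq[DataType] = Seq(
      ByteType, ShortType, IntegerType, LongType,
      FloatType, DoubleType, DecimalType(10, 0))

    // Every numeric type passes the check.
    numericTypes.foreach { t =>
      assert(labelIsNumeric(schemaWithLabel(t), "label"))
    }

    // A string label fails a numeric-only check, even though RFormula
    // accepts string labels, so such a check cannot be applied in fit.
    assert(!labelIsNumeric(schemaWithLabel(StringType), "label"))
  }
}
```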
@BenFradet Sorry for the late response.
I'm OK with what you have done here for the issue mentioned above. Could you add more tests for the `RFormulaModel` equality check? Here you have checked `resolvedFormula`, which is produced by `RFormulaParser` rather than the entire `RFormula`. It would be better to also check the equality of the `pipelineModel` of `RFormulaModel`.
Sure, thanks for your input.
Test build #56258 has finished for PR 12467 at commit
@@ -254,8 +254,8 @@ class RFormulaModel private[feature](
     val columnNames = schema.map(_.name)
     require(!columnNames.contains($(featuresCol)), "Features column already exists.")
     require(
-      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
-      "Label column already exists and is not of type DoubleType.")
+      !columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType],
Should the `||` not be `&&`?
I don't think so, no. What do you think, @yanboliang?
+1 @BenFradet, it should be `||`.
e.g. before this PR this works (and I don't believe it's supposed to?).
scala> val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
original: org.apache.spark.sql.DataFrame = [x: int, y: int]
scala> formula.fit(original).transform(original).show
+---+---+--------+-----+
| x| y|features|label|
+---+---+--------+-----+
| 0| 1| [0.0]| 1.0|
| 2| 2| [2.0]| 2.0|
+---+---+--------+-----+
And to make it clear that this check is not actually being performed:
scala> val original = sqlContext.createDataFrame(Seq((0, Seq(1)), (2, Seq(2)))).toDF("x", "y")
original: org.apache.spark.sql.DataFrame = [x: int, y: array<int>]
scala> formula.fit(original).transform(original).show
java.lang.IllegalArgumentException: Unsupported type for label: ArrayType(IntegerType,false)
at org.apache.spark.ml.feature.RFormulaModel.transformLabel(RFormula.scala:246)
at org.apache.spark.ml.feature.RFormulaModel.transform(RFormula.scala:211)
... 48 elided
... so it's catching it, but at L244 not here.
ah I see now, never mind.
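For readers following the `||` versus `&&` question, here is a minimal sketch of the guard from the diff, assuming `spark-sql` is on the classpath. `labelColumnOk` is a hypothetical name used only for this illustration; the boolean expression mirrors the changed line. The guard passes when the label column does not exist yet, or when it exists and is numeric. With `&&`, the first operand is false whenever the column exists, so the guard would reject every dataset that already carries a label column.

```scala
import org.apache.spark.sql.types._

object LabelColumnGuardSketch {
  // Hypothetical wrapper around the condition shown in the diff:
  // OK if there is no label column yet, or if it is already numeric.
  def labelColumnOk(schema: StructType, labelCol: String): Boolean = {
    val columnNames = schema.map(_.name)
    !columnNames.contains(labelCol) ||
      schema(labelCol).dataType.isInstanceOf[NumericType]
  }

  def main(args: Array[String]): Unit = {
    val noLabel = StructType(Seq(
      StructField("x", IntegerType),
      StructField("y", IntegerType)))
    val intLabel = StructType(Seq(
      StructField("x", IntegerType),
      StructField("label", IntegerType)))
    val stringLabel = StructType(Seq(
      StructField("x", IntegerType),
      StructField("label", StringType)))

    assert(labelColumnOk(noLabel, "label"))      // column absent: accepted
    assert(labelColumnOk(intLabel, "label"))     // present and numeric: accepted
    assert(!labelColumnOk(stringLabel, "label")) // present, not numeric: rejected
  }
}
```

In the REPL sessions above, the input columns are named `x` and `y`, so no `label` column exists and the guard passes trivially. That is why the unsupported array type is only caught later, in `transformLabel`.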
This LGTM. @yanboliang anything further?
Looks good overall; I have one last inline comment. After that, it should be ready to go.
Test build #57359 has finished for PR 12467 at commit
    assert(expected.resolvedFormula.label === actual.resolvedFormula.label)
    assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms)
    assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
  }
@yanboliang is this what you had in mind?
Yes.
Ping @yanboliang
Sorry for the late response. This LGTM; the conflicts should be resolved. Thanks!
@yanboliang thanks a lot!
Test build #58502 has finished for PR 12467 at commit
Test build #58503 has finished for PR 12467 at commit
pinging @MLnick
Merged to master and branch-2.0. Thanks!
…other numeric types for label

## What changes were proposed in this pull request?

Made ChiSqSelector and RFormula accept all numeric types for label

## How was this patch tested?

Unit tests

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #12467 from BenFradet/SPARK-13961.

(cherry picked from commit 31f1aeb)
Signed-off-by: Nick Pentreath <nick.pentreath@gmail.com>
## What changes were proposed in this pull request?

Made ChiSqSelector and RFormula accept all numeric types for label

## How was this patch tested?

Unit tests
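
As a usage illustration of the behaviour described above, here is a minimal sketch, assuming Spark 2.0 or later with `spark-mllib` on the classpath and a local master. The object name, app name, and the formula `y ~ x` are choices made for this example. It fits an `RFormula` on a DataFrame whose label source column is an integer and checks that the produced `label` column is a double, matching the REPL output quoted in the review thread.

```scala
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.DoubleType

object IntegerLabelSketch {
  def main(args: Array[String]): Unit = {
    // Local session for illustration only.
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("rformula-integer-label")
      .getOrCreate()
    try {
      // Both columns are integers; "y" is used as the label source.
      val df = spark.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")

      val formula = new RFormula().setFormula("y ~ x")
      val out = formula.fit(df).transform(df)
      out.show()

      // The numeric label is cast to double in the output "label" column.
      assert(out.schema("label").dataType == DoubleType)
    } finally {
      spark.stop()
    }
  }
}
```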