Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-13961][ML] spark.ml ChiSqSelector and RFormula should support other numeric types for label #12467

Closed
wants to merge 6 commits into from

Conversation

BenFradet
Copy link
Contributor

What changes were proposed in this pull request?

Made ChiSqSelector and RFormula accept all numeric types for label

How was this patch tested?

Unit tests

@yanboliang
Copy link
Contributor

@BenFradet Thanks for this PR. I think we can not make RFormula supporting other numeric types for label as your proposal. If the label column already exists, it must be type of DoubleType, otherwise it will cause the downstream model can not recognize the label column when validateAndTransformSchema.

@BenFradet
Copy link
Contributor Author

@yanboliang thanks for your input, I reverted the affected commits

@yanboliang
Copy link
Contributor

yanboliang commented Apr 18, 2016

@BenFradet I'm really very sorry that I did not notice the #10355 has been merged. Please ignore my last comments because it's not valid after #10355. Would you mind to add RFormula support back? Thanks!

@BenFradet
Copy link
Contributor Author

@yanboliang yup, no problem.

@SparkQA
Copy link

SparkQA commented Apr 18, 2016

Test build #56068 has finished for PR 12467 at commit f05c217.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented Apr 19, 2016

Test build #56220 has finished for PR 12467 at commit c9edad0.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -290,4 +291,18 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
val newModel = testDefaultReadWrite(model)
checkModelData(model, newModel)
}

test("should support all NumericType labels") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use MLTestingUtils.checkNumericTypes to test this? It will eliminate some redundant code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It'd work expect for the expected exception when dealing with a dataframe containing string labels because the label column gets indexed by RFormula's fit.
Consequently, an exception is thrown by StringIndexer.

What I could do is add a validateSchema to RFormula (called at the beginiing of the fit method) checking that the label column is of numeric type, then I could use MLTestingUtils.checkNumericTypes.

What do you think?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or simply just:

val schema = dataset.schema
SchemaUtils.checkNumericType(schema, $(labelCol))

at the beginning of RFormula's fit method.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After reviewing the suite, I don't think the same tests apply since RFormula also accepts string labels.

Consequently, I think it's best as is.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BenFradet Sorry for late response.
I'm OK with what you have done here for the issue mentioned above. Could you add more tests for RFormulaModel equality check? Here you have checked resolvedFormula which is produced by RFormulaParser rather than the entire RFormula. It's better also check the equality of pipelineModel of RFormulaModel.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, thanks for your input.

@SparkQA
Copy link

SparkQA commented Apr 19, 2016

Test build #56258 has finished for PR 12467 at commit 4cb27cf.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@@ -254,8 +254,8 @@ class RFormulaModel private[feature](
val columnNames = schema.map(_.name)
require(!columnNames.contains($(featuresCol)), "Features column already exists.")
require(
!columnNames.contains($(labelCol)) || schema($(labelCol)).dataType == DoubleType,
"Label column already exists and is not of type DoubleType.")
!columnNames.contains($(labelCol)) || schema($(labelCol)).dataType.isInstanceOf[NumericType],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the || not be &&?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so no. What do you think @yanboliang ?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 @BenFradet It should be ||

Copy link
Contributor

@MLnick MLnick Apr 22, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

e.g. before this PR this works (and I don't believe it's supposed to?).

scala> val original = sqlContext.createDataFrame(Seq((0, 1), (2, 2))).toDF("x", "y")
original: org.apache.spark.sql.DataFrame = [x: int, y: int]

scala> formula.fit(original).transform(original).show
+---+---+--------+-----+
|  x|  y|features|label|
+---+---+--------+-----+
|  0|  1|   [0.0]|  1.0|
|  2|  2|   [2.0]|  2.0|
+---+---+--------+-----+

And to make it clear that this check is not actually being performed:

scala> val original = sqlContext.createDataFrame(Seq((0, Seq(1)), (2, Seq(2)))).toDF("x", "y")
original: org.apache.spark.sql.DataFrame = [x: int, y: array<int>]

scala> formula.fit(original).transform(original).show
java.lang.IllegalArgumentException: Unsupported type for label: ArrayType(IntegerType,false)
  at org.apache.spark.ml.feature.RFormulaModel.transformLabel(RFormula.scala:246)
  at org.apache.spark.ml.feature.RFormulaModel.transform(RFormula.scala:211)
  ... 48 elided

... so it's catching it, but at L244 not here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah I see now, never mind.

@MLnick
Copy link
Contributor

MLnick commented Apr 27, 2016

This LGTM. @yanboliang anything further?

@yanboliang
Copy link
Contributor

Look good overall, I have my last inline comment. After that, it should be ready to go.

@SparkQA
Copy link

SparkQA commented Apr 29, 2016

Test build #57359 has finished for PR 12467 at commit 79b0f9d.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

assert(expected.resolvedFormula.label === actual.resolvedFormula.label)
assert(expected.resolvedFormula.terms === actual.resolvedFormula.terms)
assert(expected.resolvedFormula.hasIntercept === actual.resolvedFormula.hasIntercept)
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yanboliang is this what you had in mind?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

@BenFradet
Copy link
Contributor Author

Ping @yanboliang

@yanboliang
Copy link
Contributor

yanboliang commented May 12, 2016

Sorry for late response. This LGTM and the conflicts should be resolved. Thanks!

@BenFradet
Copy link
Contributor Author

@yanboliang thanks a lot!
Will rebase soon.

@SparkQA
Copy link

SparkQA commented May 12, 2016

Test build #58502 has finished for PR 12467 at commit 23f80ee.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA
Copy link

SparkQA commented May 12, 2016

Test build #58503 has finished for PR 12467 at commit 3786ef9.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@BenFradet
Copy link
Contributor Author

pinging @MLnick

@MLnick
Copy link
Contributor

MLnick commented May 13, 2016

Merged to master and branch-2.0. Thanks!

asfgit pushed a commit that referenced this pull request May 13, 2016
…other numeric types for label

## What changes were proposed in this pull request?

Made ChiSqSelector and RFormula accept all numeric types for label

## How was this patch tested?

Unit tests

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #12467 from BenFradet/SPARK-13961.
(cherry picked from commit 31f1aeb)

Signed-off-by: Nick Pentreath <nick.pentreath@gmail.com>
@asfgit asfgit closed this in 31f1aeb May 13, 2016
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
4 participants