
[SPARK-14862][ML] Updated Classifiers to not require labelCol metadata #12663

Closed
wants to merge 5 commits into master from jkbradley:trees-no-metadata

Conversation

jkbradley
Member

What changes were proposed in this pull request?

Updated Classifier, DecisionTreeClassifier, RandomForestClassifier, and GBTClassifier so that they no longer require label column (labelCol) metadata.

  • They first check the label column metadata for numClasses.
  • If numClasses is not specified in the metadata, they infer it from the largest label value (up to a limit).

This functionality is implemented in a new Classifier.getNumClasses method.
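
A minimal sketch of that inference, written as a free-standing helper rather than the actual protected method (the name inferNumClasses and the cap default are illustrative; the diff fragments quoted in the review below show the real code):

import org.apache.spark.SparkException
import org.apache.spark.ml.attribute.{Attribute, BinaryAttribute, NominalAttribute}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.max

// Prefer the numClasses recorded in the label column's ML attribute metadata;
// otherwise infer it from the largest label value, capped at maxNumClasses.
def inferNumClasses(df: DataFrame, labelCol: String, maxNumClasses: Int = 100): Int = {
  val fromMetadata: Option[Int] = Attribute.fromStructField(df.schema(labelCol)) match {
    case nominal: NominalAttribute => nominal.getNumValues
    case _: BinaryAttribute => Some(2)
    case _ => None
  }
  fromMetadata.getOrElse {
    val maxLabelRow = df.select(max(df(labelCol)).cast("double")).take(1)
    if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
      throw new SparkException("Classifier was given an empty dataset.")
    }
    val maxLabel = maxLabelRow.head.getDouble(0)
    require((maxLabel + 1).isValidInt,
      s"Max label value $maxLabel cannot be converted to an Int class count.")
    val numClasses = maxLabel.toInt + 1
    require(numClasses <= maxNumClasses,
      s"Inferred $numClasses classes from label values, which exceeds the cap of $maxNumClasses;" +
        " set numClasses in the label column metadata instead.")
    numClasses
  }
}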

Also

  • Updated Classifier.extractLabeledPoints to check label values, and added a second version that takes a numClasses value for validity checking.

How was this patch tested?

  • Unit tests in ClassifierSuite for helper methods
  • Unit tests for DecisionTreeClassifier, RandomForestClassifier, GBTClassifier with toy datasets lacking label metadata
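
To make the "toy datasets lacking label metadata" above concrete, here is a minimal sketch of the user-facing effect, assuming a local SparkSession and the default column names (not code from this PR):

import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("labels-without-metadata").getOrCreate()
import spark.implicits._

// A plain DataFrame: the "label" column carries no ML attribute metadata.
val df = Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (2.0, Vectors.dense(1.0, 1.0))
).toDF("label", "features")

// With this patch the classifier infers numClasses = max(label) + 1 = 3 and fits directly,
// instead of failing because the label metadata is missing.
val model = new DecisionTreeClassifier().fit(df)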

@jkbradley
Member Author

CC: @MLnick @sethah Would one of you have time to help review this? This is a long-time annoyance for spark.ml trees and ensembles. Thanks in advance!

* and put it in an RDD with strong types.
* @throws SparkException if any label is not an integer >= 0 and < numClasses
*/
def extractLabeledPoints(
Member Author

Note this is private[ml] instead of protected within class Classifier since GBTClassifier does not yet implement Classifier.

@SparkQA

SparkQA commented Apr 25, 2016

Test build #56904 has finished for PR 12663 at commit 94e206f.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

}
val maxLabel: Int = maxLabelRow.head.getDouble(0).toInt
val numClasses = maxLabel + 1
require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
Contributor

I agree we should set a limit here. It might not be clear to someone who receives this error that they can have more than 1000 classes when they set the metadata themselves. Maybe the last sentence could say "For labels containing more than $maxNumClasses, specify the numClasses explicitly in metadata, such as ..."
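
For reference, a minimal sketch of one way to specify numClasses explicitly in metadata, by attaching a NominalAttribute to the label column (the toy DataFrame and the class count are illustrative, not from this PR):

import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("explicit-numclasses").getOrCreate()
import spark.implicits._

val df = Seq(0.0, 1.0, 2.0).toDF("label")

// Attach nominal attribute metadata declaring the number of classes; classifiers read
// this instead of inferring from the data, so the inference cap never applies.
val labelMeta = NominalAttribute.defaultAttr.withName("label").withNumValues(3).toMetadata()
val withMeta = df.select(df("label").as("label", labelMeta))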

protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 1000): Int = {
  MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
    case Some(n: Int) => n
    case None =>
Contributor

On one hand, this could cause issues with improperly indexed labels, so it would be nice to issue a warning saying that Spark is trying to infer the number of label classes. On the other hand, this is a nice convenience, and it might be annoying to see a warning every time it's used. Thoughts?

Member Author

Logging a warning seems reasonable to me. We could also decrease maxNumClasses to force users to do indexing in iffier situations.

Member Author
@jkbradley jkbradley Apr 25, 2016

Actually, I'm worried about a warning being annoying, as you suggested. Maybe I'll just do logInfo and decrease maxNumClasses to 100.

Contributor

I agree that a warning is annoying. However, if we make things too restrictive, we could limit its usefulness. Is this an annoyance for a lot of users, and do they typically have fewer than 100 label classes? I already know it's an annoyance for developers doing small tests :D

Member Author

Not sure how many users would notice/like/dislike the warning in general. I feel like 100 classes is pretty large for trees, though I know there are plenty of cases exceeding 100.

@jkbradley
Member Author

I'll wait to update this until your reviews are done. Thanks for taking a look!

@SparkQA

SparkQA commented Apr 25, 2016

Test build #56914 has finished for PR 12663 at commit 657faef.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@sethah
Contributor

sethah commented Apr 25, 2016

These changes look good. I left a couple small comments.

@jkbradley
Member Author

Thanks! Updated now.

@SparkQA

SparkQA commented Apr 25, 2016

Test build #56936 has finished for PR 12663 at commit 07ef697.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@sethah
Contributor

sethah commented Apr 26, 2016

Out of curiosity, if or when SPARK-7126 is implemented, do we plan to remove this behavior?

Regarding small datasets and cross validation, it is a bit concerning that the model could get trained with an incorrect number of classes, and since it will happen silently, it could create some confusion. However, I think it is reasonable to expect that end users should realize that some splits of their data could be missing label class values, and without explicitly flagging the number of classes, there is no way for the algorithm to know.

@jkbradley
Member Author

I feel like SPARK-7126 should not change this behavior. If a user passes in data whose labels are all integers in [0, ..., some small-ish integer], then I think it's reasonable to assume they are class labels. That is what most ML libraries do, I believe.

I bet we could fix the cross validation issue by having CV check to see if the algorithm is a classifier, but I don't think we need to rush on that.

@sethah @MLnick Thanks for taking a look, and please let me know if there are other items to address. I'll take your votes on the maxNumClasses value. : )

@sethah
Contributor

sethah commented Apr 26, 2016

If SPARK-7126 were implemented, would this patch still be useful? We could just check for absence of metadata and prepend a StringIndexer. I guess my question is this: Can we avoid some of the fuzziness and potentially confusing situations that might be introduced by this patch, by simply implementing SPARK-7126? If we could, then I think it's a tradeoff between complexity, urgency, and correctness. It's hard for me to say without knowing more about how this affects and has affected users and what specifically is driving this change.

That said, I think 100 is a reasonable number for maxNumClasses and the updates look good.

@jkbradley
Member Author

This PR makes the behavior of trees and ensembles consistent with others, such as NaiveBayes. I think it's a pretty common ML use case to have labels already indexed from 0, at least in canned ML datasets.

StringIndexer could also behave differently in rarer situations. If classes 0,...,numClasses-1 were present in the labels except for some intermediate class i, then StringIndexer would re-index the classes, whereas this would not.

An alternative would be to force users to specify numClasses, but that seems unnecessary to me.
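
A minimal sketch of that re-indexing difference, assuming a local SparkSession (the label values are illustrative):

import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("reindexing-difference").getOrCreate()
import spark.implicits._

// Class 1 is absent: only labels 0.0 and 2.0 appear.
val df = Seq(0.0, 2.0, 2.0).toDF("label")

// StringIndexer re-indexes by descending frequency (2.0 -> 0.0, 0.0 -> 1.0), so the
// original class ids are lost and only two classes remain.
val reindexed = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(df)
  .transform(df)

// The inference in this patch keeps the original values and reports numClasses = 2 + 1 = 3.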

override protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
  dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
    case Row(label: Double, features: Vector) =>
      require(label % 1 == 0 && label >= 0, s"Classifier was given dataset with invalid label" +
Contributor

Can you add a comment in the code for future developers noting that the number of classes has already been checked by an earlier call to getNumClasses?

Member Author

Better yet, I'll just change this to take and check numClasses.
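
A rough sketch of what that variant could look like, written as a free-standing function rather than the actual protected method (the signature and error message are illustrative):

import org.apache.spark.SparkException
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.DoubleType

// Every label must be an integer in [0, numClasses); otherwise fail fast with a clear error.
def extractLabeledPoints(
    dataset: Dataset[_],
    labelCol: String,
    featuresCol: String,
    numClasses: Int): RDD[LabeledPoint] = {
  dataset.select(col(labelCol).cast(DoubleType), col(featuresCol)).rdd.map {
    case Row(label: Double, features: Vector) =>
      if (label % 1 != 0 || label < 0 || label >= numClasses) {
        throw new SparkException(s"Classifier was given a dataset with invalid label $label;" +
          s" labels must be integers in [0, $numClasses).")
      }
      LabeledPoint(label, features)
  }
}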

@thunterdb
Contributor

A quick look at the scikit-learn source code shows that it always re-indexes, but it uses an efficient NumPy primitive to do so. I think assuming an index for small integers is an acceptable tradeoff for users (especially in the binary case).

@jkbradley what happens when a class label is missing from the dataset? I presume this is not a cause for concern?

@jkbradley
Member Author

If a class label is missing, that's fine, unless it is the largest label.
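
A small worked example of that caveat (plain Scala; the label values are illustrative):

// Inference uses max(label) + 1, so a gap below the largest label is harmless,
// but a missing largest label makes the inferred count too small.
def inferredNumClasses(labels: Seq[Double]): Int = labels.max.toInt + 1

val gapInMiddle    = Seq(0.0, 1.0, 3.0) // class 2 absent, largest label present
val missingLargest = Seq(0.0, 1.0)      // classes 2 and 3 of a true 4-class problem absent

assert(inferredNumClasses(gapInMiddle) == 4)    // still correct
assert(inferredNumClasses(missingLargest) == 2) // undercounts; metadata would be needed here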

Check for overflow in conversion double -> int in getNumClasses.
cleanups
Unit tests for the above
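
A minimal sketch of the overflow check described in the first commit above (the helper name is illustrative):

// Guard the Double -> Int conversion: Double.toInt saturates at Int.MaxValue, so
// validate that the value fits before computing numClasses = maxLabel.toInt + 1.
def checkedNumClasses(maxLabel: Double): Int = {
  require((maxLabel + 1).isValidInt,
    s"Max label value $maxLabel is too large to convert to an Int class count.")
  maxLabel.toInt + 1
}

checkedNumClasses(9.0)  // 10
// checkedNumClasses(3e9) fails the require instead of silently saturating.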
@jkbradley
Member Author

Updated!

@SparkQA

SparkQA commented Apr 27, 2016

Test build #57159 has finished for PR 12663 at commit a3653c7.

  • This patch fails MiMa tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@sethah
Contributor

sethah commented Apr 27, 2016

LGTM

@jkbradley
Member Author

test this please

@SparkQA

SparkQA commented Apr 27, 2016

Test build #57166 has finished for PR 12663 at commit a3653c7.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@jkbradley
Member Author

Thanks @MLnick @sethah @thunterdb for reviewing!
Merging with master

@asfgit asfgit closed this in 4f4721a Apr 28, 2016
@jkbradley jkbradley deleted the trees-no-metadata branch April 29, 2016 19:51