[SPARK-14862][ML] Updated Classifiers to not require labelCol metadata #12663
Conversation
…BTClassifier to not require input column metadata
 * and put it in an RDD with strong types.
 * @throws SparkException if any label is not an integer >= 0 and < numClasses
 */
def extractLabeledPoints(
Note this is private[ml] instead of protected within class Classifier, since GBTClassifier does not yet implement Classifier.
Test build #56904 has finished for PR 12663 at commit
}
val maxLabel: Int = maxLabelRow.head.getDouble(0).toInt
val numClasses = maxLabel + 1
require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
I agree we should set a limit here. It might not be clear to someone who receives this error that they can have more than 1000 classes when they set the metadata themselves. Maybe the last sentence could say "For labels containing more than $maxNumClasses, specify the numClasses explicitly in metadata, such as ..."
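The fallback being discussed can be sketched in plain Scala (no Spark dependency). `inferNumClasses` is a hypothetical helper written for illustration, not the actual PR code: it trusts a metadata-supplied class count when present, and otherwise infers `maxLabel + 1`, capped at `maxNumClasses`.

```scala
// Hypothetical sketch of the behavior discussed above: trust numClasses from
// label metadata when present; otherwise infer it as (max label + 1), capped
// at maxNumClasses to catch improperly indexed labels.
def inferNumClasses(
    metadataNumClasses: Option[Int],
    labels: Seq[Double],
    maxNumClasses: Int = 100): Int = {
  metadataNumClasses.getOrElse {
    val numClasses = labels.max.toInt + 1
    require(numClasses <= maxNumClasses,
      s"Classifier inferred $numClasses classes from label values. For more than " +
      s"$maxNumClasses classes, specify numClasses explicitly in the label metadata.")
    numClasses
  }
}
```

With metadata present, inference is skipped entirely; without it, labels `Seq(0.0, 1.0, 3.0)` yield `numClasses = 4`.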
protected def getNumClasses(dataset: Dataset[_], maxNumClasses: Int = 1000): Int = {
  MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
    case Some(n: Int) => n
    case None =>
On one hand, this could cause issues with improperly indexed labels, so it would be nice to issue a warning saying that Spark is trying to infer the number of labels. On the other hand, this is a nice convenience and it might be annoying to see a warning every time it's used. Thoughts?
Logging a warning seems reasonable to me. We could also decrease maxNumClasses to force users to do indexing in iffier situations.
Actually, I'm worried about a warning being annoying, as you suggested. Maybe I'll just do logInfo and decrease maxNumClasses to 100.
I agree that a warning is annoying. However, if we make things too restrictive, we could limit its usefulness. Is this an annoyance for a lot of users and do they typically have less than 100 label classes? I already know it's an annoyance for developers doing small tests :D
Not sure how many users would notice/like/dislike the warning in general. I feel like 100 classes is pretty large for trees, though I know there are plenty of cases exceeding 100.
I'll wait to update this until your reviews are done. Thanks for taking a look!
Test build #56914 has finished for PR 12663 at commit
These changes look good. I left a couple of small comments.
Thanks! Updated now.
Test build #56936 has finished for PR 12663 at commit
Out of curiosity, if or when SPARK-7126 is implemented, do we plan to remove this behavior? Regarding small datasets and cross validation, it is a bit concerning that the model could get trained with an incorrect number of classes, and since it will happen silently, it could create some confusion. However, I think it is reasonable to expect that end users should realize that some splits of their data could be missing label class values, and without explicitly flagging the number of classes, there is no way for the algorithm to know.
I feel like SPARK-7126 should not change this behavior. If a user passes in data which are all integers in [0, ..., some small-ish integer], then I think it's reasonable to assume they are class labels. That is what most ML libraries do, I believe. I bet we could fix the cross validation issue by having CV check to see if the algorithm is a classifier, but I don't think we need to rush on that. @sethah @MLnick Thanks for taking a look, and please let me know if there are other items to address. I'll take your votes on the maxNumClasses value. : )
If SPARK-7126 were implemented, would this patch still be useful? We could just check for absence of metadata and prepend a StringIndexer. I guess my question is this: Can we avoid some of the fuzziness and potentially confusing situations that might be introduced by this patch, by simply implementing SPARK-7126? If we could, then I think it's a tradeoff between complexity, urgency, and correctness. It's hard for me to say without knowing more about how this affects and has affected users and what specifically is driving this change. That said, I think 100 is a reasonable number for
This PR makes the behavior of trees and ensembles consistent with others, such as NaiveBayes. I think it's a pretty common ML use case to have labels already indexed from 0, at least in canned ML datasets. StringIndexer could also behave differently in rarer situations. If classes 0,...,numClasses-1 were present in the labels except for some intermediate class i, then StringIndexer would re-index the classes, whereas this would not. An alternative would be to force users to specify numClasses, but that seems unnecessary to me.
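The contrast described above can be shown with a small plain-Scala sketch (no Spark dependency). Note the re-indexing step is a simplification: the real StringIndexer orders labels by frequency, while sorted order is used here for clarity.

```scala
// Labels where an intermediate class (1) is absent but a larger class (3) occurs.
val labels = Seq(0.0, 2.0, 3.0, 2.0)

// StringIndexer-style re-indexing: distinct labels are compacted to 0..k-1,
// so the gap at class 1 disappears and class 3 becomes class 2.
val reindexMap = labels.distinct.sorted.zipWithIndex.toMap
val reindexed  = labels.map(l => reindexMap(l).toDouble)

// Metadata-free inference keeps the original indices: numClasses = maxLabel + 1,
// so the gap at class 1 is preserved and there are 4 classes, one of them empty.
val inferredNumClasses = labels.max.toInt + 1
```

Re-indexing yields 3 compacted classes, whereas direct inference reports 4 classes with the original label values intact.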
override protected def extractLabeledPoints(dataset: Dataset[_]): RDD[LabeledPoint] = {
  dataset.select(col($(labelCol)).cast(DoubleType), col($(featuresCol))).rdd.map {
    case Row(label: Double, features: Vector) =>
      require(label % 1 == 0 && label >= 0, s"Classifier was given dataset with invalid label" +
Can you add a comment in the code for future developers that the number of classes has already been checked before with a call to getNumClasses?
Better yet, I'll just change this to take and check numClasses.
A quick look at the source code of scikit-learn shows that it always re-indexes, but it uses an efficient NumPy primitive for doing that. I think assuming an index for small integers is an acceptable tradeoff for the users (especially in the binary case). @jkbradley what happens when a class label is missing from the dataset? I presume this is not a cause for concern?
If a class label is missing, that's fine, unless it is the largest label.
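Why only the largest label matters can be sketched in plain Scala; `inferFromLabels` is a hypothetical stand-in for the inference step, not the PR's actual code.

```scala
// numClasses is inferred as maxLabel + 1, so a gap in the middle of the label
// range is harmless, but if the largest true class never occurs in the data,
// the inferred numClasses underestimates the real count.
def inferFromLabels(labels: Seq[Double]): Int = labels.max.toInt + 1

val missingMiddle  = Seq(0.0, 1.0, 3.0) // class 2 absent, class 3 present
val missingLargest = Seq(0.0, 1.0, 2.0) // true class 3 never observed
```

Here `missingMiddle` still infers 4 classes correctly, while `missingLargest` infers only 3 even if the true number is 4.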
Check for overflow in conversion double -> int in getNumClasses; cleanups; unit tests for the above.
Updated!
Test build #57159 has finished for PR 12663 at commit
LGTM
test this please |
Test build #57166 has finished for PR 12663 at commit
Thanks @MLnick @sethah @thunterdb for reviewing!
What changes were proposed in this pull request?
Updated Classifier, DecisionTreeClassifier, RandomForestClassifier, GBTClassifier to not require input column metadata.
This functionality is implemented in a new Classifier.getNumClasses method.
Also
How was this patch tested?