[HIVEMALL-78] Implement AUC UDAF for binary classification #52
@takuti Comments attached.
Assert.assertEquals(0.83333, agg.get(), 1e-5);
}
}
}
add a blank line
ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg;

double score = HiveUtils.getDouble(parameters[0], scoreOI);
double label = HiveUtils.getDouble(parameters[1], labelOI);
if (parameters[0] == null) {
    continue;
}
if (parameters[1] == null) { // kept separate from the check above for debugging (e.g., attaching breakpoints)
    continue;
}
int label = PrimitiveObjectInspectorUtils.getInt(parameters[1], labelOI);
if (label == -1) {
    label = 0;
} else if (label != 0 && label != 1) {
    throw new UDFArgumentException("label MUST be 0/1 or -1/1: " + label);
}
Correction to my snippet above: use return; instead of continue;
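The label handling suggested in this thread can be tried outside Hive with a minimal sketch like the one below. The class and method names (`LabelNormalizer`, `normalize`) are hypothetical, and `IllegalArgumentException` stands in for Hive's `UDFArgumentException` so that the snippet compiles without Hive on the classpath; it is an illustration of the rule, not the code that was merged.

```java
/** Hypothetical helper illustrating the 0/1 and -1/1 label convention. */
final class LabelNormalizer {
    private LabelNormalizer() {}

    /** Maps -1 to 0, passes 0 and 1 through, and rejects every other value. */
    static int normalize(int label) {
        if (label == -1) {
            return 0; // treat -1 as the negative class
        }
        if (label != 0 && label != 1) {
            // Stand-in for UDFArgumentException in a real Hive UDAF
            throw new IllegalArgumentException("label MUST be 0/1 or -1/1: " + label);
        }
        return label;
    }

    public static void main(String[] args) {
        System.out.println(normalize(-1));
        System.out.println(normalize(1));
    }
}
```

After normalization the aggregation only has to distinguish 1 (positive) from 0 (negative), whichever of the two label conventions the input table uses.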
// initialize input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
this.scoreOI = (PrimitiveObjectInspector) parameters[0];
this.labelOI = (PrimitiveObjectInspector) parameters[1];
this.scoreOI = HiveUtils.asDoubleCompatibleOI(parameters[0]);
this.labelOI = HiveUtils.asIntegerOI(parameters[1]);
if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) {
throw new UDFArgumentTypeException(0,
"The first argument `array rankItems` is invalid form: " + typeInfo[0]);
if (HiveUtils.isNumberTypeInfo(typeInfo[0]) && HiveUtils.isNumberTypeInfo(typeInfo[1])) {
&& HiveUtils.isIntegerTypeInfo(typeInfo[1])
Thanks. @takuti Added comments again.
Better to validate the score value.
return res / (tp * fp); // scale
}

void iterate(@Nonnull double score, @Nonnull int label) {
Remove @Nonnull for primitives (they can never be null).
docs/gitbook/eval/auc.md
Outdated
# Area Under the ROC Curve

[ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) and Area Under the ROC Curve (AUC) are widely-used metric for binary (i.e., positive or negative) classification problems such as [Logistic Regression](../binaryclass/a9a_lr.md).
Fix the link: it should point to ../binaryclass/a9a_lr.html
distribute by floor(prob / 0.2)
) t;
Add a note explaining what floor(prob / 0.2) means: it distributes the AUC computation into 5 bins.
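To make the binning concrete, here is a small sketch of the key that `floor(prob / 0.2)` produces for a few probabilities. The class name `ProbBins` and the sample values are made up for illustration; in Hive, rows that share the same `distribute by` key are sent to the same reducer, so each key corresponds to one bin of the partial AUC computation.

```java
/** Hypothetical illustration of the key computed by floor(prob / 0.2). */
final class ProbBins {
    private ProbBins() {}

    /** Returns the bin index for a probability, mirroring floor(prob / 0.2). */
    static int bin(double prob) {
        return (int) Math.floor(prob / 0.2);
    }

    public static void main(String[] args) {
        double[] samples = {0.05, 0.5, 0.7, 0.95}; // made-up probabilities
        for (double p : samples) {
            System.out.println(p + " -> bin " + bin(p));
        }
    }
}
```

Probabilities in [0, 1) fall into keys 0 through 4, which gives the 5 bins mentioned above. A probability of exactly 1.0 would get the key 5, so it forms a small extra group of its own.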
return;
}

double score = HiveUtils.getDouble(parameters[0], scoreOI);
if (score < 0.0d || score > 1.0d) {
    throw new UDFArgumentException("score value MUST be in range [0,1]: " + score);
}
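A standalone version of this range check might look like the sketch below. The names (`ScoreCheck`, `requireUnitInterval`) are hypothetical and `IllegalArgumentException` again stands in for `UDFArgumentException`. One deliberate difference from the suggestion above: the condition is written in negated form, because every comparison with `NaN` evaluates to false and `score < 0.0d || score > 1.0d` would therefore let a `NaN` score through.

```java
/** Hypothetical helper that rejects scores outside [0, 1], including NaN. */
final class ScoreCheck {
    private ScoreCheck() {}

    static double requireUnitInterval(double score) {
        // Negated form: NaN fails both comparisons, so it is rejected as well.
        if (!(score >= 0.0d && score <= 1.0d)) {
            // Stand-in for UDFArgumentException in a real Hive UDAF
            throw new IllegalArgumentException("score value MUST be in range [0,1]: " + score);
        }
        return score;
    }

    public static void main(String[] args) {
        System.out.println(requireUnitInterval(0.25));
    }
}
```

Whether `NaN` should be rejected or silently skipped is a design choice for the UDAF; the sketch only shows one way to make that behavior explicit.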
LGTM. Merged. Thank you so much for a big contribution!
What changes were proposed in this pull request?
In addition to the current auc(array, array) for ranking (myui/hivemall#326), this patch supports auc(double, double) for binary classification.
What type of PR is it?
Feature
What is the Jira issue?
https://issues.apache.org/jira/browse/HIVEMALL-78
How was this patch tested?
Created a unit test for the UDAF, and it passed:
Moreover, I ran manual tests with the following queries:
Both showed AUC=0.83333. This result is the same as scikit-learn's roc_auc_score():
How to use this feature?
See the queries above. Input data needs to be ordered by score in descending order.
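To show why the descending order matters, here is a minimal single-pass sketch of AUC over rows that are already sorted by score. It is not the UDAF from this patch: the class name `AucSketch`, the trapezoid bookkeeping and the toy data are assumptions for illustration, and the only detail taken from the diff is the final scaling by `tp * fp`.

```java
/** Hypothetical single-pass AUC over rows sorted by score in descending order. */
final class AucSketch {
    private AucSketch() {}

    /** labels must be 0 or 1; scores must already be in descending order. */
    static double auc(double[] scores, int[] labels) {
        double area = 0.0;
        long tp = 0, fp = 0;         // positives / negatives seen so far
        long tpPrev = 0, fpPrev = 0; // counts when the current score group started
        double prevScore = Double.POSITIVE_INFINITY;
        for (int i = 0; i < scores.length; i++) {
            if (scores[i] != prevScore) {
                // Close the previous group of tied scores with one trapezoid.
                area += trapezoid(fpPrev, fp, tpPrev, tp);
                prevScore = scores[i];
                fpPrev = fp;
                tpPrev = tp;
            }
            if (labels[i] == 1) {
                tp++;
            } else {
                fp++;
            }
        }
        area += trapezoid(fpPrev, fp, tpPrev, tp); // last group
        if (tp == 0 || fp == 0) {
            throw new IllegalArgumentException("AUC is undefined without both classes");
        }
        return area / ((double) tp * fp); // scale to [0, 1]
    }

    private static double trapezoid(long x1, long x2, long y1, long y2) {
        return (x2 - x1) * (y1 + y2) / 2.0;
    }

    public static void main(String[] args) {
        // Toy data (an assumption, not taken from the PR), sorted by score descending.
        double[] scores = {0.8, 0.7, 0.5, 0.3, 0.2};
        int[] labels = {1, 1, 0, 1, 0};
        System.out.println(auc(scores, labels)); // 5 of 6 positive/negative pairs are ranked correctly
    }
}
```

Because the counters only move forward, the curve is traced correctly only if rows arrive from the highest score to the lowest; feeding unsorted rows would produce a different, meaningless area.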