This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

[HIVEMALL-78] Implement AUC UDAF for binary classification #52

Closed
wants to merge 11 commits into from

@takuti takuti commented Feb 28, 2017

What changes were proposed in this pull request?

In addition to the existing auc(array, array) for ranking (myui/hivemall#326), this patch adds auc(double, double) for binary classification.

What type of PR is it?

Feature

What is the Jira issue?

https://issues.apache.org/jira/browse/HIVEMALL-78

How was this patch tested?

Added a unit test for the UDAF; it passes:

$ mvn -Dtest=hivemall.evaluation.AUCUDAFTest test

I also ran manual tests with the following two queries:

with data as (
  select 0.5 as prob, 0 as label
  union all
  select 0.3 as prob, 1 as label
  union all
  select 0.2 as prob, 0 as label
  union all
  select 0.8 as prob, 1 as label
  union all
  select 0.7 as prob, 1 as label
), data_ordered as (
  select prob, label
  from data
  order by prob desc
)
select auc(prob, label)
from (
  select prob, label
  from data_ordered
  distribute by floor(prob / 0.2)
) t;

with data as (
  select 0.5 as prob, 0 as label
  union all
  select 0.3 as prob, 1 as label
  union all
  select 0.2 as prob, 0 as label
  union all
  select 0.8 as prob, 1 as label
  union all
  select 0.7 as prob, 1 as label
), data_ordered as (
  select prob, label
  from data
  order by prob desc
)
select auc(prob, label)
from data_ordered;

Both queries returned AUC = 0.83333, which matches scikit-learn's roc_auc_score():

>>> from sklearn.metrics import roc_auc_score
>>> roc_auc_score([0,1,0,1,1],[0.5,0.3,0.2,0.8,0.7])
0.83333333333333326
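
As a dependency-free cross-check of that number, AUC can be read as the fraction of (positive, negative) pairs in which the positive example receives the higher score. The sketch below is an illustration only: `pairwise_auc` is a hypothetical helper name, not part of Hivemall or scikit-learn, and counting ties as half a win is an assumption of this sketch.

```python
def pairwise_auc(labels, scores):
    """AUC as the share of (positive, negative) pairs ranked correctly.

    Hypothetical helper for illustration; ties count as half a win.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")

    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Same data as the manual test: 5 of the 6 (positive, negative) pairs are ordered correctly.
auc = pairwise_auc([0, 1, 0, 1, 1], [0.5, 0.3, 0.2, 0.8, 0.7])
print(round(auc, 5))  # 0.83333
```

Only the pair (0.3 positive, 0.5 negative) is ranked the wrong way round, which gives 5/6.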

How to use this feature?

See the queries above. Input rows must be sorted by score in descending order.
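
To illustrate why the ordering matters, here is a minimal single-pass sketch that walks rows sorted by descending score, keeps running true-positive and false-positive counts, and accumulates trapezoids under the ROC curve. It is not the UDAF's actual code: `streaming_auc` and its variable names are assumptions made for this example, and only the final scaling by `tp * fp` mirrors a line visible in the diff context of this PR.

```python
def streaming_auc(rows):
    """One-pass AUC over (score, label) rows sorted by score, descending.

    Illustrative sketch only; labels are expected to be 0 or 1.
    """
    tp = fp = 0            # running counts of positives / negatives seen
    tp_prev = fp_prev = 0  # counts at the previous distinct score
    area = 0.0             # unscaled area under the ROC curve
    prev_score = None

    for score, label in rows:
        if prev_score is not None and score > prev_score:
            raise ValueError("rows must be sorted by score in descending order")
        if score != prev_score:
            # Close the trapezoid spanned since the previous distinct score.
            area += (fp - fp_prev) * (tp + tp_prev) / 2.0
            tp_prev, fp_prev = tp, fp
            prev_score = score
        if label == 1:
            tp += 1
        else:
            fp += 1

    area += (fp - fp_prev) * (tp + tp_prev) / 2.0  # last trapezoid
    if tp == 0 or fp == 0:
        raise ValueError("need at least one positive and one negative example")
    return area / (tp * fp)  # scale into [0, 1]


rows = sorted([(0.5, 0), (0.3, 1), (0.2, 0), (0.8, 1), (0.7, 1)], reverse=True)
print(round(streaming_auc(rows), 5))  # 0.83333
```

Because the running counts are only meaningful when higher scores arrive first, unsorted input is rejected rather than silently producing a wrong area.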

@myui myui left a comment

@takuti Comments attached.

Assert.assertEquals(0.83333, agg.get(), 1e-5);
}
}
}

add a blank line

ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg;

double score = HiveUtils.getDouble(parameters[0], scoreOI);
double label = HiveUtils.getDouble(parameters[1], labelOI);

if(parameters[0] == null) {
   continue;
}
if(parameters[1] == null) { // separate ^ for debugging (e.g., attaching breakpoints)
   continue;
}

int label = PrimitiveObjectInspectorUtils.getInt(parameters[1], labelOI);
if(label == -1) {
   label = 0;
} else if (label != 0 && label != 1) {
   throw new UDFArgumentException("label MUST be 0/1 or -1/1: " + label);
}

https://github.com/apache/hive/blob/master/serde/src/java/org/apache/hadoop/hive/serde2/objectinspector/primitive/PrimitiveObjectInspectorUtils.java#L634
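
The label handling suggested above (accept 0/1, map -1 to 0, reject everything else) can be sketched outside Hive as follows. `normalize_label` is a hypothetical name used only for this illustration; the error text is borrowed from the suggested Java snippet.

```python
def normalize_label(label):
    """Map a -1/1 or 0/1 label onto 0/1; reject any other value.

    Hypothetical helper mirroring the review suggestion, not Hivemall code.
    """
    if label == -1:
        return 0
    if label in (0, 1):
        return label
    raise ValueError("label MUST be 0/1 or -1/1: %r" % (label,))


print([normalize_label(y) for y in (-1, 1, 0, 1)])  # [0, 1, 0, 1]
```

Failing fast on an unexpected label keeps a mislabeled column from being silently counted as a negative class.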

takuti (PR author) replied:

continue; -> return;

// initialize input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
this.scoreOI = (PrimitiveObjectInspector) parameters[0];
this.labelOI = (PrimitiveObjectInspector) parameters[1];

this.scoreOI = HiveUtils.asDoubleCompatibleOI(parameters[0]);
this.labelOI = HiveUtils.asIntegerOI(parameters[1]);

if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo())) {
throw new UDFArgumentTypeException(0,
"The first argument `array rankItems` is invalid form: " + typeInfo[0]);
if (HiveUtils.isNumberTypeInfo(typeInfo[0]) && HiveUtils.isNumberTypeInfo(typeInfo[1])) {

&& HiveUtils.isIntegerTypeInfo(typeInfo[1])

@myui myui left a comment

Thanks. @takuti Added comments again.

It would be better to validate the score value.

return res / (tp * fp); // scale
}

void iterate(@Nonnull double score, @Nonnull int label) {

Remove @Nonnull from the primitive parameters (primitives can never be null).


# Area Under the ROC Curve

[ROC curve](https://en.wikipedia.org/wiki/Receiver_operating_characteristic) and Area Under the ROC Curve (AUC) are widely-used metric for binary (i.e., positive or negative) classification problems such as [Logistic Regression](../binaryclass/a9a_lr.md).

Fix the link so that it points to ../binaryclass/a9a_lr.html.

distribute by floor(prob / 0.2)
) t;
```

Add a note explaining what floor(prob / 0.2) means: it distributes the AUC computation into 5 bins.
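
As a rough illustration of that binning, the sketch below computes the same key in Python and groups the sample rows by it. In Hive, `distribute by` sends rows with equal keys to the same reducer; here a dictionary stands in for the reducers. `bin_key` and `groups` are hypothetical names, and the `width` default of 0.2 is taken from the query.

```python
import math
from collections import defaultdict


def bin_key(prob, width=0.2):
    """Partition key analogous to floor(prob / 0.2) in the query (illustrative)."""
    return math.floor(prob / width)


rows = [(0.5, 0), (0.3, 1), (0.2, 0), (0.8, 1), (0.7, 1)]

# Group rows by key, the way rows with the same key would share a reducer.
groups = defaultdict(list)
for prob, label in rows:
    groups[bin_key(prob)].append((prob, label))

for key in sorted(groups, reverse=True):
    print(key, groups[key])
```

Scores in [0, 1) map to keys 0 through 4, which gives the five bins; in this sketch a score of exactly 1.0 maps to key 5.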

return;
}

double score = HiveUtils.getDouble(parameters[0], scoreOI);

if(score < 0.0d || score > 1.0d) {
throw new UDFArgumentException("score value MUST be in range [0,1]: " + score);
}

@asfgit asfgit closed this in 97bc912 Feb 28, 2017

myui commented Feb 28, 2017

LGTM. Merged. Thank you so much for a big contribution!

coveralls commented Feb 28, 2017

Coverage increased (+0.4%) to 36.739% when pulling 681e7fb on takuti:auc into 19d472b on apache:master.

@takuti takuti deleted the auc branch March 1, 2017 05:47