This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

Commit

Merge 810f540 into c53b9ff
takuti committed Apr 11, 2017
2 parents c53b9ff + 810f540 commit 5f29328
Showing 3 changed files with 333 additions and 59 deletions.
225 changes: 175 additions & 50 deletions core/src/main/java/hivemall/evaluation/AUCUDAF.java
@@ -23,6 +23,10 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.HashMap;
import java.util.SortedMap;
import java.util.TreeMap;

import javax.annotation.Nonnull;

@@ -36,6 +40,7 @@
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AbstractAggregationBuffer;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
@@ -86,12 +91,17 @@ public static class ClassificationEvaluator extends GenericUDAFEvaluator {
private PrimitiveObjectInspector labelOI;

private StructObjectInspector internalMergeOI;
private StructField aField;
private StructField scorePrevField;
private StructField indexScoreField;
private StructField areaField;
private StructField fpField;
private StructField tpField;
private StructField fpPrevField;
private StructField tpPrevField;
private StructField areaPartialMapField;
private StructField fpPartialMapField;
private StructField tpPartialMapField;
private StructField fpPrevPartialMapField;
private StructField tpPrevPartialMapField;

public ClassificationEvaluator() {}

@@ -107,12 +117,17 @@ public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws Hive
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
this.internalMergeOI = soi;
this.aField = soi.getStructFieldRef("a");
this.scorePrevField = soi.getStructFieldRef("scorePrev");
this.indexScoreField = soi.getStructFieldRef("indexScore");
this.areaField = soi.getStructFieldRef("area");
this.fpField = soi.getStructFieldRef("fp");
this.tpField = soi.getStructFieldRef("tp");
this.fpPrevField = soi.getStructFieldRef("fpPrev");
this.tpPrevField = soi.getStructFieldRef("tpPrev");
this.areaPartialMapField = soi.getStructFieldRef("areaPartialMap");
this.fpPartialMapField = soi.getStructFieldRef("fpPartialMap");
this.tpPartialMapField = soi.getStructFieldRef("tpPartialMap");
this.fpPrevPartialMapField = soi.getStructFieldRef("fpPrevPartialMap");
this.tpPrevPartialMapField = soi.getStructFieldRef("tpPrevPartialMap");
}

// initialize output
@@ -129,9 +144,9 @@ private static StructObjectInspector internalMergeOI() {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

fieldNames.add("a");
fieldNames.add("indexScore");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
fieldNames.add("scorePrev");
fieldNames.add("area");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
fieldNames.add("fp");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);
@@ -142,6 +157,36 @@
fieldNames.add("tpPrev");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);

MapObjectInspector areaPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector);
fieldNames.add("areaPartialMap");
fieldOIs.add(areaPartialMapOI);

MapObjectInspector fpPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
PrimitiveObjectInspectorFactory.javaLongObjectInspector);
fieldNames.add("fpPartialMap");
fieldOIs.add(fpPartialMapOI);

MapObjectInspector tpPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
PrimitiveObjectInspectorFactory.javaLongObjectInspector);
fieldNames.add("tpPartialMap");
fieldOIs.add(tpPartialMapOI);

MapObjectInspector fpPrevPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
PrimitiveObjectInspectorFactory.javaLongObjectInspector);
fieldNames.add("fpPrevPartialMap");
fieldOIs.add(fpPrevPartialMapOI);

MapObjectInspector tpPrevPartialMapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
PrimitiveObjectInspectorFactory.javaLongObjectInspector);
fieldNames.add("tpPrevPartialMap");
fieldOIs.add(tpPrevPartialMapOI);

return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}

@@ -188,13 +233,19 @@ public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveExcep
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg;

Object[] partialResult = new Object[6];
partialResult[0] = new DoubleWritable(myAggr.a);
partialResult[1] = new DoubleWritable(myAggr.scorePrev);
Object[] partialResult = new Object[11];
partialResult[0] = new DoubleWritable(myAggr.indexScore);
partialResult[1] = new DoubleWritable(myAggr.area);
partialResult[2] = new LongWritable(myAggr.fp);
partialResult[3] = new LongWritable(myAggr.tp);
partialResult[4] = new LongWritable(myAggr.fpPrev);
partialResult[5] = new LongWritable(myAggr.tpPrev);
partialResult[6] = myAggr.areaPartialMap;
partialResult[7] = myAggr.fpPartialMap;
partialResult[8] = myAggr.tpPartialMap;
partialResult[9] = myAggr.fpPrevPartialMap;
partialResult[10] = myAggr.tpPrevPartialMap;

return partialResult;
}

@@ -204,21 +255,53 @@ public void merge(AggregationBuffer agg, Object partial) throws HiveException {
return;
}

Object aObj = internalMergeOI.getStructFieldData(partial, aField);
Object scorePrevObj = internalMergeOI.getStructFieldData(partial, scorePrevField);
Object indexScoreObj = internalMergeOI.getStructFieldData(partial, indexScoreField);
Object areaObj = internalMergeOI.getStructFieldData(partial, areaField);
Object fpObj = internalMergeOI.getStructFieldData(partial, fpField);
Object tpObj = internalMergeOI.getStructFieldData(partial, tpField);
Object fpPrevObj = internalMergeOI.getStructFieldData(partial, fpPrevField);
Object tpPrevObj = internalMergeOI.getStructFieldData(partial, tpPrevField);
double a = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(aObj);
double scorePrev = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(scorePrevObj);
Object areaPartialMapObj = internalMergeOI.getStructFieldData(partial, areaPartialMapField);
Object fpPartialMapObj = internalMergeOI.getStructFieldData(partial, fpPartialMapField);
Object tpPartialMapObj = internalMergeOI.getStructFieldData(partial, tpPartialMapField);
Object fpPrevPartialMapObj = internalMergeOI.getStructFieldData(partial, fpPrevPartialMapField);
Object tpPrevPartialMapObj = internalMergeOI.getStructFieldData(partial, tpPrevPartialMapField);

double indexScore = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(indexScoreObj);
double area = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(areaObj);
long fp = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(fpObj);
long tp = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpObj);
long fpPrev = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(fpPrevObj);
long tpPrev = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(tpPrevObj);

Map<Double, Double> areaPartialMap = (Map<Double, Double>) ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector).getMap(
HiveUtils.castLazyBinaryObject(areaPartialMapObj));

Map<Double, Long> fpPartialMap = (Map<Double, Long>) ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
PrimitiveObjectInspectorFactory.javaLongObjectInspector).getMap(
HiveUtils.castLazyBinaryObject(fpPartialMapObj));

Map<Double, Long> tpPartialMap = (Map<Double, Long>) ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
PrimitiveObjectInspectorFactory.javaLongObjectInspector).getMap(
HiveUtils.castLazyBinaryObject(tpPartialMapObj));

Map<Double, Long> fpPrevPartialMap = (Map<Double, Long>) ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
PrimitiveObjectInspectorFactory.javaLongObjectInspector).getMap(
HiveUtils.castLazyBinaryObject(fpPrevPartialMapObj));

Map<Double, Long> tpPrevPartialMap = (Map<Double, Long>) ObjectInspectorFactory.getStandardMapObjectInspector(
PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
PrimitiveObjectInspectorFactory.javaLongObjectInspector).getMap(
HiveUtils.castLazyBinaryObject(tpPrevPartialMapObj));

ClassificationAUCAggregationBuffer myAggr = (ClassificationAUCAggregationBuffer) agg;
myAggr.merge(a, scorePrev, fp, tp, fpPrev, tpPrev);
myAggr.merge(indexScore, area, fp, tp, fpPrev, tpPrev,
areaPartialMap, fpPartialMap, tpPartialMap, fpPrevPartialMap, tpPrevPartialMap);
}

@Override
@@ -232,67 +315,109 @@ public DoubleWritable terminate(AggregationBuffer agg) throws HiveException {

public static class ClassificationAUCAggregationBuffer extends AbstractAggregationBuffer {

double a, scorePrev;
double area, scorePrev, indexScore;
long fp, tp, fpPrev, tpPrev;
Map<Double, Double> areaPartialMap;
Map<Double, Long> fpPartialMap, tpPartialMap, fpPrevPartialMap, tpPrevPartialMap;

public ClassificationAUCAggregationBuffer() {
super();
}

void reset() {
this.a = 0.d;
this.area = 0.d;
this.scorePrev = Double.POSITIVE_INFINITY;
this.indexScore = 0.d;
this.fp = 0;
this.tp = 0;
this.fpPrev = 0;
this.tpPrev = 0;
this.areaPartialMap = new HashMap<Double, Double>();
this.fpPartialMap = new HashMap<Double, Long>();
this.tpPartialMap = new HashMap<Double, Long>();
this.fpPrevPartialMap = new HashMap<Double, Long>();
this.tpPrevPartialMap = new HashMap<Double, Long>();
}

void merge(double o_a, double o_scorePrev, long o_fp, long o_tp, long o_fpPrev,
long o_tpPrev) {
// compute the latest, not scaled AUC
a += trapezoidArea(fp, fpPrev, tp, tpPrev);
o_a += trapezoidArea(o_fp, o_fpPrev, o_tp, o_tpPrev);
void merge(double o_indexScore, double o_area, long o_fp, long o_tp, long o_fpPrev, long o_tpPrev,
Map<Double, Double> o_areaPartialMap,
Map<Double, Long> o_fpPartialMap, Map<Double, Long> o_tpPartialMap,
Map<Double, Long> o_fpPrevPartialMap, Map<Double, Long> o_tpPrevPartialMap) {

// merge past partial results
areaPartialMap.putAll(o_areaPartialMap);
fpPartialMap.putAll(o_fpPartialMap);
tpPartialMap.putAll(o_tpPartialMap);
fpPrevPartialMap.putAll(o_fpPrevPartialMap);
tpPrevPartialMap.putAll(o_tpPrevPartialMap);

// finalize source AUC computation
o_area += trapezoidArea(o_fp, o_fpPrev, o_tp, o_tpPrev);

// store source results
areaPartialMap.put(o_indexScore, o_area);
fpPartialMap.put(o_indexScore, o_fp);
tpPartialMap.put(o_indexScore, o_tp);
fpPrevPartialMap.put(o_indexScore, o_fpPrev);
tpPrevPartialMap.put(o_indexScore, o_tpPrev);
}

// sum up the partial areas
a += o_a;
if (scorePrev >= o_scorePrev) { // self is left-side
// adjust combined area by adding missing rectangle
a += trapezoidArea(fp + o_fp, fp, tp, tp);

// combine TP/FP counts; left-side curve should be base
fp += o_fp;
tp += o_tp;
fpPrev = fp + o_fpPrev;
tpPrev = tp + o_tpPrev;
} else { // self is right-side
a = a + trapezoidArea(fp + o_fp, o_fp, o_tp, o_tp);

fp += o_fp;
tp += o_tp;
fpPrev += o_fp;
tpPrev += o_tp;
}
double get() throws HiveException {
// store self results
areaPartialMap.put(indexScore, area);
fpPartialMap.put(indexScore, fp);
tpPartialMap.put(indexScore, tp);
fpPrevPartialMap.put(indexScore, fpPrev);
tpPrevPartialMap.put(indexScore, tpPrev);

SortedMap<Double, Double> areaPartialSortedMap = new TreeMap<Double, Double>(Collections.reverseOrder());
areaPartialSortedMap.putAll(areaPartialMap);

// initialize with leftmost partial result
double firstKey = areaPartialSortedMap.firstKey();
double res = areaPartialSortedMap.get(firstKey);
long fpAccum = fpPartialMap.get(firstKey);
long tpAccum = tpPartialMap.get(firstKey);
long fpPrevAccum = fpPrevPartialMap.get(firstKey);
long tpPrevAccum = tpPrevPartialMap.get(firstKey);

// Merge from left (larger score) to right (smaller score)
for (double k : areaPartialSortedMap.keySet()) {
if (k == firstKey) { // variables are already initialized with the leftmost partial result
continue;
}

// set current appropriate `scorePrev`
scorePrev = Math.min(scorePrev, o_scorePrev);
// sum up partial area
res += areaPartialSortedMap.get(k);

// subtract so that get() works correctly
a -= trapezoidArea(fp, fpPrev, tp, tpPrev);
}
// adjust combined area by adding missing rectangle
res += trapezoidArea(0, fpPartialMap.get(k), tpAccum, tpAccum);

double get() throws HiveException {
if (tp == 0 || fp == 0) {
// sum up (prev) TP/FP count
fpPrevAccum = fpAccum + fpPrevPartialMap.get(k);
tpPrevAccum = tpAccum + tpPrevPartialMap.get(k);
fpAccum = fpAccum + fpPartialMap.get(k);
tpAccum = tpAccum + tpPartialMap.get(k);
}

if (tpAccum == 0 || fpAccum == 0) {
throw new HiveException(
"AUC score is not defined because there is only one class in `label`.");
}
double res = a + trapezoidArea(fp, fpPrev, tp, tpPrev);
return res / (tp * fp); // scale

// finalize by adding a trapezoid based on the last tp/fp counts
res += trapezoidArea(fpAccum, fpPrevAccum, tpAccum, tpPrevAccum);

return res / (tpAccum * fpAccum); // scale
}

void iterate(double score, int label) {
if (score != scorePrev) {
a += trapezoidArea(fp, fpPrev, tp, tpPrev); // under (fp, tp)-(fpPrev, tpPrev)
if (scorePrev == Double.POSITIVE_INFINITY) {
// store maximum score as an index
indexScore = score;
}
area += trapezoidArea(fp, fpPrev, tp, tpPrev); // under (fp, tp)-(fpPrev, tpPrev)
scorePrev = score;
fpPrev = fp;
tpPrev = tp;
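For reference, the new merge strategy no longer stitches two running curves together directly; each partition's partial result is stored under its maximum score (indexScore), and get() replays the partials from the largest-score segment to the smallest, adding the rectangle that sits under the accumulated true-positive count whenever the next segment's false positives are appended. The following self-contained sketch (illustrative only, not part of this commit; class and variable names are made up, and the standard trapezoid rule is assumed for trapezoidArea) mirrors that combination logic with plain java.util maps:

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;

public class PartialAucMergeSketch {

    // Area under the ROC segment from (fpPrev, tpPrev) to (fp, tp),
    // assuming the usual trapezoid rule for AUCUDAF's trapezoidArea helper.
    static double trapezoidArea(long fp, long fpPrev, long tp, long tpPrev) {
        double base = Math.abs(fp - fpPrev);
        double height = (tp + tpPrev) / 2.d;
        return base * height;
    }

    // Combine per-partition partial results, each keyed by the partition's maximum
    // score, into a single scaled AUC; mirrors ClassificationAUCAggregationBuffer#get().
    static double mergeAuc(Map<Double, Double> areaByMaxScore,
            Map<Double, Long> fpByMaxScore, Map<Double, Long> tpByMaxScore,
            Map<Double, Long> fpPrevByMaxScore, Map<Double, Long> tpPrevByMaxScore) {
        SortedMap<Double, Double> sorted = new TreeMap<Double, Double>(Collections.reverseOrder());
        sorted.putAll(areaByMaxScore);

        // start from the leftmost partial ROC curve (largest scores)
        double firstKey = sorted.firstKey();
        double area = sorted.get(firstKey);
        long fp = fpByMaxScore.get(firstKey);
        long tp = tpByMaxScore.get(firstKey);
        long fpPrev = fpPrevByMaxScore.get(firstKey);
        long tpPrev = tpPrevByMaxScore.get(firstKey);

        for (double k : sorted.keySet()) {
            if (k == firstKey) {
                continue;
            }
            area += sorted.get(k); // partial area of this segment
            // rectangle between the accumulated curve and the appended segment
            area += trapezoidArea(0, fpByMaxScore.get(k), tp, tp);
            fpPrev = fp + fpPrevByMaxScore.get(k);
            tpPrev = tp + tpPrevByMaxScore.get(k);
            fp += fpByMaxScore.get(k);
            tp += tpByMaxScore.get(k);
        }
        if (tp == 0 || fp == 0) {
            throw new IllegalStateException("AUC is undefined: only one class in `label`");
        }
        area += trapezoidArea(fp, fpPrev, tp, tpPrev); // close the last segment
        return area / (tp * fp); // scale to [0, 1]
    }

    public static void main(String[] args) {
        // Hypothetical two-partition input; keys (0.9 and 0.4) are each partition's max score.
        Map<Double, Double> area = new HashMap<Double, Double>();
        Map<Double, Long> fp = new HashMap<Double, Long>();
        Map<Double, Long> tp = new HashMap<Double, Long>();
        Map<Double, Long> fpPrev = new HashMap<Double, Long>();
        Map<Double, Long> tpPrev = new HashMap<Double, Long>();

        // left partition saw scores 0.9(+), 0.8(+), 0.7(-)
        area.put(0.9, 0.d); fp.put(0.9, 1L); tp.put(0.9, 2L); fpPrev.put(0.9, 0L); tpPrev.put(0.9, 2L);
        // right partition saw scores 0.4(+), 0.3(-)
        area.put(0.4, 0.d); fp.put(0.4, 1L); tp.put(0.4, 1L); fpPrev.put(0.4, 0L); tpPrev.put(0.4, 1L);

        System.out.println(mergeAuc(area, fp, tp, fpPrev, tpPrev));
    }
}

Running it on the hypothetical two-partition input in main() prints about 0.8333, which matches the exact AUC of the pooled data (5 of the 6 positive/negative pairs are ranked correctly).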
11 changes: 11 additions & 0 deletions core/src/main/java/hivemall/utils/hadoop/HiveUtils.java
@@ -49,6 +49,8 @@
import org.apache.hadoop.hive.serde2.lazy.LazyInteger;
import org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe;
import org.apache.hadoop.hive.serde2.lazy.LazyString;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryArray;
import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryMap;
import org.apache.hadoop.hive.serde2.objectinspector.ConstantObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
@@ -987,4 +989,13 @@ public static LazySimpleSerDe getLineSerde(@Nonnull final PrimitiveObjectInspect
serde.initialize(conf, tbl);
return serde;
}

public static Object castLazyBinaryObject(@Nonnull final Object obj) {
if (obj instanceof LazyBinaryMap) {
return ((LazyBinaryMap) obj).getMap();
} else if (obj instanceof LazyBinaryArray) {
return ((LazyBinaryArray) obj).getList();
}
return obj;
}
}
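The new castLazyBinaryObject() helper exists because merge() reads its partial map fields through standard map object inspectors: when the partial aggregate has been serialized between tasks (presumably via LazyBinarySerDe), the field arrives as a LazyBinaryMap or LazyBinaryArray and must be materialized first, while in-process partials are already plain Java collections. A minimal sketch of that unwrap-then-read pattern, assuming hive-serde is on the classpath (the class and method names here are illustrative, not part of the commit):

import java.util.Map;

import org.apache.hadoop.hive.serde2.lazybinary.LazyBinaryMap;
import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

public class LazyBinaryUnwrapSketch {

    // Same idea as HiveUtils.castLazyBinaryObject, restricted to the map case.
    static Object unwrap(Object obj) {
        if (obj instanceof LazyBinaryMap) {
            return ((LazyBinaryMap) obj).getMap(); // materialize the lazily deserialized map
        }
        return obj; // already a plain Map that a standard map OI can read
    }

    // Read one score -> counter map out of a partial-aggregation struct field.
    @SuppressWarnings("unchecked")
    static Map<Double, Long> readCounterMap(Object partialField) {
        MapObjectInspector mapOI = ObjectInspectorFactory.getStandardMapObjectInspector(
            PrimitiveObjectInspectorFactory.javaDoubleObjectInspector,
            PrimitiveObjectInspectorFactory.javaLongObjectInspector);
        return (Map<Double, Long>) mapOI.getMap(unwrap(partialField));
    }
}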