Skip to content
Permalink
Browse files
Merge pull request #326 from takuti/ranking-measures
Implement additional ranking measures
  • Loading branch information
myui committed Sep 6, 2016
2 parents a4729ea + 8b199f4 commit 450ef8deda9a8343f13c4e71d5edd82da96b10a9
Showing 11 changed files with 1,526 additions and 34 deletions.
@@ -0,0 +1,228 @@
/*
* Hivemall: Hive scalable Machine Learning Library
*
* Copyright (C) 2015 Makoto YUI
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package hivemall.evaluation;

import hivemall.utils.hadoop.HiveUtils;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer;
import org.apache.hadoop.hive.serde2.io.DoubleWritable;
import org.apache.hadoop.hive.serde2.objectinspector.*;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.io.LongWritable;

import javax.annotation.Nonnull;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

@Description(
name = "auc",
value = "_FUNC_(array rankItems, array correctItems [, const int recommendSize = rankItems.size])"
+ " - Returns AUC")
public final class AUCUDAF extends AbstractGenericUDAFResolver {

// prevent instantiation
private AUCUDAF() {}

@Override
public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
if (typeInfo.length != 2 && typeInfo.length != 3) {
throw new UDFArgumentTypeException(typeInfo.length - 1,
"_FUNC_ takes two or three arguments");
}

ListTypeInfo arg1type = HiveUtils.asListTypeInfo(typeInfo[0]);
if (!HiveUtils.isPrimitiveTypeInfo(arg1type.getListElementTypeInfo()) &&
!HiveUtils.isStructTypeInfo(arg1type.getListElementTypeInfo())) {
throw new UDFArgumentTypeException(0,
"The first argument `array rankItems` is invalid form: " + typeInfo[0]);
}
ListTypeInfo arg2type = HiveUtils.asListTypeInfo(typeInfo[1]);
if (!HiveUtils.isPrimitiveTypeInfo(arg2type.getListElementTypeInfo())) {
throw new UDFArgumentTypeException(1,
"The first argument `array rankItems` is invalid form: " + typeInfo[1]);
}

return new Evaluator();
}

public static class Evaluator extends GenericUDAFEvaluator {

private ListObjectInspector recommendListOI;
private ListObjectInspector truthListOI;
private WritableIntObjectInspector recommendSizeOI;

private StructObjectInspector internalMergeOI;
private StructField countField;
private StructField sumField;

public Evaluator() {}

@Override
public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws HiveException {
assert (parameters.length == 2 || parameters.length == 3) : parameters.length;
super.init(mode, parameters);

// initialize input
if (mode == Mode.PARTIAL1 || mode == Mode.COMPLETE) {// from original data
this.recommendListOI = (ListObjectInspector) parameters[0];
this.truthListOI = (ListObjectInspector) parameters[1];
if (parameters.length == 3) {
this.recommendSizeOI = (WritableIntObjectInspector) parameters[2];
}
} else {// from partial aggregation
StructObjectInspector soi = (StructObjectInspector) parameters[0];
this.internalMergeOI = soi;
this.countField = soi.getStructFieldRef("count");
this.sumField = soi.getStructFieldRef("sum");
}

// initialize output
final ObjectInspector outputOI;
if (mode == Mode.PARTIAL1 || mode == Mode.PARTIAL2) {// terminatePartial
outputOI = internalMergeOI();
} else {// terminate
outputOI = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
}
return outputOI;
}

private static StructObjectInspector internalMergeOI() {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();

fieldNames.add("sum");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
fieldNames.add("count");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableLongObjectInspector);

return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
}

@Override
public AggregationBuffer getNewAggregationBuffer() throws HiveException {
AggregationBuffer myAggr = new AUCAggregationBuffer();
reset(myAggr);
return myAggr;
}

@Override
public void reset(AggregationBuffer agg) throws HiveException {
AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg;
myAggr.reset();
}

@Override
public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg;

List<?> recommendList = recommendListOI.getList(parameters[0]);
if (recommendList == null) {
recommendList = Collections.emptyList();
}
List<?> truthList = truthListOI.getList(parameters[1]);
if (truthList == null) {
return;
}

int recommendSize = recommendList.size();
if (parameters.length == 3) {
recommendSize = recommendSizeOI.get(parameters[2]);
}
if (recommendSize < 0 || recommendSize > recommendList.size()) {
throw new UDFArgumentException(
"The third argument `int recommendSize` must be in [0, " + recommendList.size() + "]");
}

myAggr.iterate(recommendList, truthList, recommendSize);
}

@Override
public Object terminatePartial(AggregationBuffer agg) throws HiveException {
AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg;

Object[] partialResult = new Object[2];
partialResult[0] = new DoubleWritable(myAggr.sum);
partialResult[1] = new LongWritable(myAggr.count);
return partialResult;
}

@Override
public void merge(AggregationBuffer agg, Object partial) throws HiveException {
if (partial == null) {
return;
}

Object sumObj = internalMergeOI.getStructFieldData(partial, sumField);
Object countObj = internalMergeOI.getStructFieldData(partial, countField);
double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj);
long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj);

AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg;
myAggr.merge(sum, count);
}

@Override
public DoubleWritable terminate(AggregationBuffer agg) throws HiveException {
AUCAggregationBuffer myAggr = (AUCAggregationBuffer) agg;
double result = myAggr.get();
return new DoubleWritable(result);
}

}

public static class AUCAggregationBuffer implements AggregationBuffer {

double sum;
long count;

public AUCAggregationBuffer() {}

void reset() {
this.sum = 0.d;
this.count = 0;
}

void merge(double o_sum, long o_count) {
sum += o_sum;
count += o_count;
}

double get() {
if (count == 0) {
return 0.d;
}
return sum / count;
}

void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList, @Nonnull int recommendSize) {
sum += BinaryResponsesMeasures.AUC(recommendList, truthList, recommendSize);
count++;
}
}

}
@@ -23,19 +23,31 @@
import javax.annotation.Nonnull;

/**
* Utility class of various measures.
*
* See http://recsyswiki.com/wiki/Discounted_Cumulative_Gain
* Binary responses measures for item recommendation (i.e. ranking problems)
*
* References:
* B. McFee and G. R. Lanckriet. "Metric Learning to Rank" ICML 2010.
* MyMediaLite http://mymedialite.net/
* LibRec http://www.librec.net/
*/
public final class BinaryResponsesMeasures {

private BinaryResponsesMeasures() {}

public static double nDCG(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth) {
/**
* Computes binary nDCG (i.e. relevance score is 0 or 1)
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return nDCG
*/
public static double nDCG(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnull final int recommendSize) {
double dcg = 0.d;
double idcg = IDCG(groundTruth.size());
double idcg = IDCG(Math.min(recommendSize, groundTruth.size()));

for (int i = 0, n = rankedList.size(); i < n; i++) {
for (int i = 0, n = recommendSize; i < n; i++) {
Object item_id = rankedList.get(i);
if (!groundTruth.contains(item_id)) {
continue;
@@ -61,4 +73,127 @@ public static double IDCG(final int n) {
return idcg;
}

/**
* Computes Precision@`recommendSize`
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return Precision
*/
public static double Precision(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnull final int recommendSize) {
return (double) countTruePositive(rankedList, groundTruth, recommendSize) / recommendSize;
}

/**
* Computes Recall@`recommendSize`
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return Recall
*/
public static double Recall(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnull final int recommendSize) {
return (double) countTruePositive(rankedList, groundTruth, recommendSize) / groundTruth.size();
}

/**
* Counts the number of true positives
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return number of true positives
*/
public static int countTruePositive(final List<?> rankedList, final List<?> groundTruth, final int recommendSize) {
int nTruePositive = 0;

for (int i = 0, n = recommendSize; i < n; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
nTruePositive++;
}
}

return nTruePositive;
}

/**
* Computes Mean Reciprocal Rank (MRR)
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return MRR
*/
public static double MRR(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnull final int recommendSize) {
for (int i = 0, n = recommendSize; i < n; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
return 1.0 / (i + 1.0);
}
}

return 0.0;
}

/**
* Computes Mean Average Precision (MAP)
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return MAP
*/
public static double MAP(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnull final int recommendSize) {
int nTruePositive = 0;
double sumPrecision = 0.0;

// accumulate precision@1 to @recommendSize
for (int i = 0, n = recommendSize; i < n; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
nTruePositive++;
sumPrecision += nTruePositive / (i + 1.0);
}
}

return sumPrecision / groundTruth.size();
}

/**
* Computes the area under the ROC curve (AUC)
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return AUC
*/
public static double AUC(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnull final int recommendSize) {
int nTruePositive = 0, nCorrectPairs = 0;

// count # of pairs of items that are ranked in the correct order (i.e. TP > FP)
for (int i = 0, n = recommendSize; i < n; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
// # of true positives which are ranked higher position than i-th recommended item
nTruePositive++;
} else {
// for each FP item, # of correct ordered <TP, FP> pairs equals to # of TPs at i-th position
nCorrectPairs += nTruePositive;
}
}

// # of all possible <TP, FP> pairs
int nPairs = nTruePositive * (recommendSize - nTruePositive);

// AUC can equivalently be calculated by counting the portion of correctly ordered pairs
return (double) nCorrectPairs / nPairs;
}

}

0 comments on commit 450ef8d

Please sign in to comment.