diff --git a/core/src/main/java/hivemall/HivemallConstants.java b/core/src/main/java/hivemall/HivemallConstants.java index 0eb9febbd..67bb228f6 100644 --- a/core/src/main/java/hivemall/HivemallConstants.java +++ b/core/src/main/java/hivemall/HivemallConstants.java @@ -18,6 +18,7 @@ */ package hivemall; + public final class HivemallConstants { public static final String VERSION = "0.4.2-rc.2"; @@ -35,6 +36,7 @@ public final class HivemallConstants { public static final String BIGINT_TYPE_NAME = "bigint"; public static final String FLOAT_TYPE_NAME = "float"; public static final String DOUBLE_TYPE_NAME = "double"; + public static final String DECIMAL_TYPE_NAME = "decimal"; public static final String STRING_TYPE_NAME = "string"; public static final String DATE_TYPE_NAME = "date"; public static final String DATETIME_TYPE_NAME = "datetime"; diff --git a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java index 81cf07585..7c218497e 100644 --- a/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java +++ b/core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java @@ -18,8 +18,12 @@ */ package hivemall.evaluation; +import hivemall.utils.lang.Preconditions; +import hivemall.utils.math.MathUtils; + import java.util.List; +import javax.annotation.Nonnegative; import javax.annotation.Nonnull; /** @@ -40,19 +44,25 @@ private BinaryResponsesMeasures() {} * @return nDCG */ public static double nDCG(@Nonnull final List rankedList, - @Nonnull final List groundTruth, @Nonnull final int recommendSize) { + @Nonnull final List groundTruth, @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize > 0); + double dcg = 0.d; - double idcg = IDCG(Math.min(recommendSize, groundTruth.size())); - for (int i = 0, n = recommendSize; i < n; i++) { + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { Object item_id = rankedList.get(i); if (!groundTruth.contains(item_id)) { continue; } int rank = i + 1; - dcg += Math.log(2) / Math.log(rank + 1); + dcg += 1.d / MathUtils.log2(rank + 1); } + final double idcg = IDCG(Math.min(groundTruth.size(), k)); + if (idcg == 0.d) { + return 0.d; + } return dcg / idcg; } @@ -62,10 +72,12 @@ public static double nDCG(@Nonnull final List rankedList, * @param n the number of positive items * @return ideal DCG */ - public static double IDCG(final int n) { + public static double IDCG(@Nonnegative final int n) { + Preconditions.checkArgument(n >= 0); + double idcg = 0.d; for (int i = 0; i < n; i++) { - idcg += Math.log(2) / Math.log(i + 2); + idcg += 1.d / MathUtils.log2(i + 2); } return idcg; } @@ -79,8 +91,26 @@ public static double IDCG(final int n) { * @return Precision */ public static double Precision(@Nonnull final List rankedList, - @Nonnull final List groundTruth, @Nonnull final int recommendSize) { - return (double) countTruePositive(rankedList, groundTruth, recommendSize) / recommendSize; + @Nonnull final List groundTruth, @Nonnegative final int recommendSize) { + if (rankedList.isEmpty()) { + if (groundTruth.isEmpty()) { + return 1.d; + } + return 0.d; + } + + Preconditions.checkArgument(recommendSize > 0); // can be zero when groundTruth is empty + + int nTruePositive = 0; + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { + Object item_id = rankedList.get(i); + if (groundTruth.contains(item_id)) { + nTruePositive++; + } + } + + return ((double) nTruePositive) / k; } /** @@ -92,8 +122,15 @@ public static double Precision(@Nonnull final List rankedList, * @return Recall */ public static double Recall(@Nonnull final List rankedList, - @Nonnull final List groundTruth, @Nonnull final int recommendSize) { - return (double) countTruePositive(rankedList, groundTruth, recommendSize) + @Nonnull final List groundTruth, @Nonnegative final int recommendSize) { + if (groundTruth.isEmpty()) { + if (rankedList.isEmpty()) { + return 1.d; + } + return 0.d; + } + + return ((double) TruePositives(rankedList, groundTruth, recommendSize)) / groundTruth.size(); } @@ -105,11 +142,14 @@ public static double Recall(@Nonnull final List rankedList, * @param recommendSize top-`recommendSize` items in `rankedList` are recommended * @return number of true positives */ - public static int countTruePositive(final List rankedList, final List groundTruth, - final int recommendSize) { + public static int TruePositives(final List rankedList, final List groundTruth, + @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize > 0); + int nTruePositive = 0; - for (int i = 0, n = recommendSize; i < n; i++) { + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { nTruePositive++; @@ -120,48 +160,65 @@ public static int countTruePositive(final List rankedList, final List grou } /** - * Computes Mean Reciprocal Rank (MRR) + * Computes Reciprocal Rank * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended - * @return MRR + * @return Reciprocal Rank + * @link https://en.wikipedia.org/wiki/Mean_reciprocal_rank */ - public static double MRR(@Nonnull final List rankedList, @Nonnull final List groundTruth, - @Nonnull final int recommendSize) { - for (int i = 0, n = recommendSize; i < n; i++) { + public static double ReciprocalRank(@Nonnull final List rankedList, + @Nonnull final List groundTruth, @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize > 0); + + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { - return 1.0 / (i + 1.0); + return 1.d / (i + 1); } } - return 0.0; + return 0.d; } /** - * Computes Mean Average Precision (MAP) + * Computes Average Precision (AP) * * @param rankedList a list of ranked item IDs (first item is highest-ranked) * @param groundTruth a collection of positive/correct item IDs * @param recommendSize top-`recommendSize` items in `rankedList` are recommended - * @return MAP + * @return AveragePrecision */ - public static double MAP(@Nonnull final List rankedList, @Nonnull final List groundTruth, - @Nonnull final int recommendSize) { + public static double AveragePrecision(@Nonnull final List rankedList, + @Nonnull final List groundTruth, @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize > 0); + + if (groundTruth.isEmpty()) { + if (rankedList.isEmpty()) { + return 1.d; + } + return 0.d; + } + int nTruePositive = 0; - double sumPrecision = 0.0; + double sumPrecision = 0.d; // accumulate precision@1 to @recommendSize - for (int i = 0, n = recommendSize; i < n; i++) { + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { nTruePositive++; - sumPrecision += nTruePositive / (i + 1.0); + sumPrecision += nTruePositive / (i + 1.d); } } - return sumPrecision / groundTruth.size(); + if (nTruePositive == 0) { + return 0.d; + } + return sumPrecision / nTruePositive; } /** @@ -173,11 +230,14 @@ public static double MAP(@Nonnull final List rankedList, @Nonnull final List< * @return AUC */ public static double AUC(@Nonnull final List rankedList, @Nonnull final List groundTruth, - @Nonnull final int recommendSize) { + @Nonnegative final int recommendSize) { + Preconditions.checkArgument(recommendSize > 0); + int nTruePositive = 0, nCorrectPairs = 0; // count # of pairs of items that are ranked in the correct order (i.e. TP > FP) - for (int i = 0, n = recommendSize; i < n; i++) { + final int k = Math.min(rankedList.size(), recommendSize); + for (int i = 0; i < k; i++) { Object item_id = rankedList.get(i); if (groundTruth.contains(item_id)) { // # of true positives which are ranked higher position than i-th recommended item @@ -197,7 +257,7 @@ public static double AUC(@Nonnull final List rankedList, @Nonnull final List< } // AUC can equivalently be calculated by counting the portion of correctly ordered pairs - return (double) nCorrectPairs / nPairs; + return ((double) nCorrectPairs) / nPairs; } } diff --git a/core/src/main/java/hivemall/evaluation/MAPUDAF.java b/core/src/main/java/hivemall/evaluation/MAPUDAF.java index cac6de5a2..38786847e 100644 --- a/core/src/main/java/hivemall/evaluation/MAPUDAF.java +++ b/core/src/main/java/hivemall/evaluation/MAPUDAF.java @@ -235,7 +235,7 @@ void merge(double o_sum, long o_count) { void iterate(@Nonnull List recommendList, @Nonnull List truthList, @Nonnull int recommendSize) { - sum += BinaryResponsesMeasures.MAP(recommendList, truthList, recommendSize); + sum += BinaryResponsesMeasures.AveragePrecision(recommendList, truthList, recommendSize); count++; } } diff --git a/core/src/main/java/hivemall/evaluation/MRRUDAF.java b/core/src/main/java/hivemall/evaluation/MRRUDAF.java index 41a236d24..f5aba3ba4 100644 --- a/core/src/main/java/hivemall/evaluation/MRRUDAF.java +++ b/core/src/main/java/hivemall/evaluation/MRRUDAF.java @@ -235,7 +235,7 @@ void merge(double o_sum, long o_count) { void iterate(@Nonnull List recommendList, @Nonnull List truthList, @Nonnull int recommendSize) { - sum += BinaryResponsesMeasures.MRR(recommendList, truthList, recommendSize); + sum += BinaryResponsesMeasures.ReciprocalRank(recommendList, truthList, recommendSize); count++; } } diff --git a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java index f50d27a51..f1ba8320d 100644 --- a/core/src/main/java/hivemall/evaluation/NDCGUDAF.java +++ b/core/src/main/java/hivemall/evaluation/NDCGUDAF.java @@ -18,6 +18,8 @@ */ package hivemall.evaluation; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableDoubleObjectInspector; +import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableLongObjectInspector; import hivemall.utils.hadoop.HiveUtils; import java.util.ArrayList; @@ -38,10 +40,11 @@ import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory; +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructField; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory; -import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils; import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector; import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo; import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo; @@ -120,8 +123,8 @@ public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws Hive } private static StructObjectInspector internalMergeOI() { - ArrayList fieldNames = new ArrayList(); - ArrayList fieldOIs = new ArrayList(); + List fieldNames = new ArrayList<>(); + List fieldOIs = new ArrayList<>(); fieldNames.add("sum"); fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector); @@ -180,20 +183,31 @@ public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg, StructObjectInspector sOI = (StructObjectInspector) recommendListOI.getListElementObjectInspector(); List fieldRefList = sOI.getAllStructFieldRefs(); StructField relScoreField = (StructField) fieldRefList.get(0); - WritableDoubleObjectInspector relScoreFieldOI = (WritableDoubleObjectInspector) relScoreField.getFieldObjectInspector(); + PrimitiveObjectInspector relScoreFieldOI = HiveUtils.asDoubleCompatibleOI(relScoreField.getFieldObjectInspector()); for (int i = 0, n = recommendList.size(); i < n; i++) { Object structObj = recommendList.get(i); List fieldList = sOI.getStructFieldsDataAsList(structObj); - double relScore = (double) relScoreFieldOI.get(fieldList.get(0)); + Object field0 = fieldList.get(0); + if (field0 == null) { + throw new UDFArgumentException("Field 0 of a struct field is null: " + + fieldList); + } + double relScore = PrimitiveObjectInspectorUtils.getDouble(field0, + relScoreFieldOI); recommendRelScoreList.add(relScore); } // Create a ordered list of relevance scores for truth items List truthRelScoreList = new ArrayList(); - WritableDoubleObjectInspector truthRelScoreOI = (WritableDoubleObjectInspector) truthListOI.getListElementObjectInspector(); + PrimitiveObjectInspector truthRelScoreOI = HiveUtils.asDoubleCompatibleOI(truthListOI.getListElementObjectInspector()); for (int i = 0, n = truthList.size(); i < n; i++) { Object relScoreObj = truthList.get(i); - double relScore = (double) truthRelScoreOI.get(relScoreObj); + if (relScoreObj == null) { + throw new UDFArgumentException("Found null in the ground truth: " + + truthList); + } + double relScore = PrimitiveObjectInspectorUtils.getDouble(relScoreObj, + truthRelScoreOI); truthRelScoreList.add(relScore); } @@ -224,8 +238,8 @@ public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object Object sumObj = internalMergeOI.getStructFieldData(partial, sumField); Object countObj = internalMergeOI.getStructFieldData(partial, countField); - double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj); - long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj); + double sum = writableDoubleObjectInspector.get(sumObj); + long count = writableLongObjectInspector.get(countObj); NDCGAggregationBuffer myAggr = (NDCGAggregationBuffer) agg; myAggr.merge(sum, count); diff --git a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java index e88a16c76..52c521c57 100644 --- a/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java +++ b/core/src/main/java/hivemall/tools/list/UDAFToOrderedList.java @@ -207,7 +207,7 @@ public ObjectInspector init(Mode mode, ObjectInspector[] argOIs) throws HiveExce || (argOIs.length == 3 && HiveUtils.isConstString(argOIs[2])); if (sortByKey) { - this.valueOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]); + this.valueOI = argOIs[0]; this.keyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]); } else { // sort values by value itself diff --git a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java index 8fba349c4..b8b344c6a 100644 --- a/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java +++ b/core/src/main/java/hivemall/utils/hadoop/HiveUtils.java @@ -21,6 +21,7 @@ import static hivemall.HivemallConstants.BIGINT_TYPE_NAME; import static hivemall.HivemallConstants.BINARY_TYPE_NAME; import static hivemall.HivemallConstants.BOOLEAN_TYPE_NAME; +import static hivemall.HivemallConstants.DECIMAL_TYPE_NAME; import static hivemall.HivemallConstants.DOUBLE_TYPE_NAME; import static hivemall.HivemallConstants.FLOAT_TYPE_NAME; import static hivemall.HivemallConstants.INT_TYPE_NAME; @@ -47,6 +48,7 @@ import org.apache.hadoop.hive.serde2.SerDeException; import org.apache.hadoop.hive.serde2.io.ByteWritable; import org.apache.hadoop.hive.serde2.io.DoubleWritable; +import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable; import org.apache.hadoop.hive.serde2.io.ShortWritable; import org.apache.hadoop.hive.serde2.lazy.ByteArrayRef; import org.apache.hadoop.hive.serde2.lazy.LazyDouble; @@ -265,6 +267,7 @@ public static boolean isNumberOI(@Nonnull final ObjectInspector argOI) { case LONG: case FLOAT: case DOUBLE: + case DECIMAL: case BYTE: //case TIMESTAMP: return true; @@ -357,6 +360,7 @@ public static boolean isNumberTypeInfo(@Nonnull TypeInfo typeInfo) { case LONG: case FLOAT: case DOUBLE: + case DECIMAL: return true; default: return false; @@ -404,6 +408,7 @@ public static boolean isFloatingPointTypeInfo(@Nonnull TypeInfo typeInfo) { switch (((PrimitiveTypeInfo) typeInfo).getPrimitiveCategory()) { case DOUBLE: case FLOAT: + case DECIMAL: return true; default: return false; @@ -630,6 +635,9 @@ public static float getAsConstFloat(@Nonnull final ObjectInspector numberOI) } else if (TINYINT_TYPE_NAME.equals(typeName)) { ByteWritable v = getConstValue(numberOI); return v.get(); + } else if (DECIMAL_TYPE_NAME.equals(typeName)) { + HiveDecimalWritable v = getConstValue(numberOI); + return v.getHiveDecimal().floatValue(); } throw new UDFArgumentException("Unexpected argument type to cast as double: " + TypeInfoUtils.getTypeInfoFromObjectInspector(numberOI)); @@ -656,6 +664,9 @@ public static double getAsConstDouble(@Nonnull final ObjectInspector numberOI) } else if (TINYINT_TYPE_NAME.equals(typeName)) { ByteWritable v = getConstValue(numberOI); return v.get(); + } else if (DECIMAL_TYPE_NAME.equals(typeName)) { + HiveDecimalWritable v = getConstValue(numberOI); + return v.getHiveDecimal().doubleValue(); } throw new UDFArgumentException("Unexpected argument type to cast as double: " + TypeInfoUtils.getTypeInfoFromObjectInspector(numberOI)); @@ -923,10 +934,10 @@ public static PrimitiveObjectInspector asIntCompatibleOI(@Nonnull final ObjectIn case LONG: case FLOAT: case DOUBLE: + case DECIMAL: case BOOLEAN: case BYTE: case STRING: - case DECIMAL: break; default: throw new UDFArgumentTypeException(0, "Unxpected type '" + argOI.getTypeName() @@ -951,9 +962,9 @@ public static PrimitiveObjectInspector asLongCompatibleOI(@Nonnull final ObjectI case BOOLEAN: case FLOAT: case DOUBLE: + case DECIMAL: case STRING: case TIMESTAMP: - case DECIMAL: break; default: throw new UDFArgumentTypeException(0, "Unxpected type '" + argOI.getTypeName() @@ -998,6 +1009,7 @@ public static PrimitiveObjectInspector asDoubleCompatibleOI(@Nonnull final Objec case LONG: case FLOAT: case DOUBLE: + case DECIMAL: case STRING: case TIMESTAMP: break; @@ -1020,6 +1032,7 @@ public static PrimitiveObjectInspector asFloatingPointOI(@Nonnull final ObjectIn switch (oi.getPrimitiveCategory()) { case FLOAT: case DOUBLE: + case DECIMAL: break; default: throw new UDFArgumentTypeException(0, @@ -1044,6 +1057,7 @@ public static PrimitiveObjectInspector asNumberOI(@Nonnull final ObjectInspector case LONG: case FLOAT: case DOUBLE: + case DECIMAL: break; default: throw new UDFArgumentTypeException(0, diff --git a/core/src/main/java/hivemall/utils/math/MathUtils.java b/core/src/main/java/hivemall/utils/math/MathUtils.java index 6162adb10..ee533dcd3 100644 --- a/core/src/main/java/hivemall/utils/math/MathUtils.java +++ b/core/src/main/java/hivemall/utils/math/MathUtils.java @@ -43,6 +43,7 @@ import org.apache.commons.math3.special.Gamma; public final class MathUtils { + private static final double LOG2 = Math.log(2); private MathUtils() {} @@ -246,6 +247,10 @@ public static double log(final double n, final int base) { return Math.log(n) / Math.log(base); } + public static double log2(final double n) { + return Math.log(n) / LOG2; + } + public static int floorDiv(final int x, final int y) { int r = x / y; // if the signs are different and modulo not zero, round down diff --git a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java index 9f8a04ee1..5e8f253ad 100644 --- a/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java +++ b/core/src/test/java/hivemall/evaluation/BinaryResponsesMeasuresTest.java @@ -18,8 +18,8 @@ */ package hivemall.evaluation; -import java.util.Collections; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.junit.Assert; @@ -39,6 +39,18 @@ public void testNDCG() { Assert.assertEquals(0.6131471927654585d, actual, 0.0001d); } + @Test + public void testNDCG2() { + List rankedList = Arrays.asList(3, 2, 1, 6); + List groundTruth = Arrays.asList(1); + + double actual = BinaryResponsesMeasures.nDCG(rankedList, groundTruth, 2); + Assert.assertEquals(0.d, actual, 0.0001d); + + actual = BinaryResponsesMeasures.nDCG(rankedList, groundTruth, 3); + Assert.assertEquals(0.5d, actual, 0.0001d); + } + @Test public void testRecall() { List rankedList = Arrays.asList(1, 3, 2, 6); @@ -51,6 +63,16 @@ public void testRecall() { Assert.assertEquals(0.3333333333333333d, actual, 0.0001d); } + @Test + public void testRecallEmpty() { + Assert.assertEquals(1.d, + BinaryResponsesMeasures.Recall(Collections.emptyList(), Collections.emptyList(), 2), + 0.d); + + Assert.assertEquals(0.d, + BinaryResponsesMeasures.Recall(Arrays.asList(1, 3, 2), Collections.emptyList(), 2), 0.d); + } + @Test public void testPrecision() { List rankedList = Arrays.asList(1, 3, 2, 6); @@ -65,32 +87,91 @@ public void testPrecision() { } @Test - public void testMRR() { + public void testPrecisionEmpty() { + Assert.assertEquals(1.d, + BinaryResponsesMeasures.Precision(Collections.emptyList(), Collections.emptyList(), 2), + 0.d); + + Assert.assertEquals(0.d, + BinaryResponsesMeasures.Precision(Arrays.asList(1, 3, 2), Collections.emptyList(), 2), + 0.d); + } + + @Test + public void testRR() { List rankedList = Arrays.asList(1, 3, 2, 6); List groundTruth = Arrays.asList(1, 2, 4); - double actual = BinaryResponsesMeasures.MRR(rankedList, groundTruth, rankedList.size()); + double actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth, + rankedList.size()); Assert.assertEquals(1.0d, actual, 0.0001d); Collections.reverse(rankedList); - actual = BinaryResponsesMeasures.MRR(rankedList, groundTruth, rankedList.size()); + actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth, rankedList.size()); Assert.assertEquals(0.5d, actual, 0.0001d); - actual = BinaryResponsesMeasures.MRR(rankedList, groundTruth, 1); + actual = BinaryResponsesMeasures.ReciprocalRank(rankedList, groundTruth, 1); Assert.assertEquals(0.0d, actual, 0.0001d); } @Test - public void testMAP() { + public void testAP() { List rankedList = Arrays.asList(1, 3, 2, 6); List groundTruth = Arrays.asList(1, 2, 4); - double actual = BinaryResponsesMeasures.MAP(rankedList, groundTruth, rankedList.size()); - Assert.assertEquals(0.5555555555555555d, actual, 0.0001d); + double actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, + rankedList.size()); + Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 3.0), actual, 0.0001d); + + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 4); + Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 3.0), actual, 0.0001d); - actual = BinaryResponsesMeasures.MAP(rankedList, groundTruth, 2); - Assert.assertEquals(0.3333333333333333d, actual, 0.0001d); + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 3); + Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 3.0), actual, 0.0001d); + + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2); + Assert.assertEquals(1.0 / 1.0 * (1.0 / 1.0), actual, 0.0001d); + + rankedList = Arrays.asList(3, 1, 2, 6); + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2); + Assert.assertEquals(1.0 / 1.0 * (1.0 / 2.0), actual, 0.0001d); + + groundTruth = Arrays.asList(1, 2, 3); + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2); + Assert.assertEquals(1.0 / 2.0 * (1.0 / 1.0 + 2.0 / 2.0), actual, 0.0001d); + + rankedList = Arrays.asList(3, 1); + groundTruth = Arrays.asList(1, 2); + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 2); + Assert.assertEquals(1.0 / 1.0 * (1.0 / 2.0), actual, 0.0001d); + } + + @Test + public void testAPString() { + List rankedList = Arrays.asList("a", "b", "c", "d", "e", "f", "g"); + List groundTruth = Arrays.asList("a", "x", "x", "d", "x", "x"); + + double actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 6); + Assert.assertEquals(0.75d, actual, 0.0001d); + } + + @Test + public void testAPString10() { + List rankedList = Arrays.asList("a", "b", "c", "d", "e", "f", "g", "h", "i", "j"); + List groundTruth = Arrays.asList("a", "x", "c", "x", "e", "f"); + + double actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 10); + Assert.assertEquals(1.0 / 4.0 * (1.0 / 1.0 + 2.0 / 3.0 + 3.0 / 5.0 + 4.0 / 6.0), actual, + 0.0001d); + + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 5); + Assert.assertEquals(1.0 / 3.0 * (1.0 / 1.0 + 2.0 / 3.0 + 3.0 / 5.0), actual, 0.0001d); + + groundTruth = Arrays.asList("a", "x", "c", "x", "e", "f", "x", "x", "x", "x"); + actual = BinaryResponsesMeasures.AveragePrecision(rankedList, groundTruth, 10); + Assert.assertEquals(1.0 / 4.0 * (1.0 / 1.0 + 2.0 / 3.0 + 3.0 / 5.0 + 4.0 / 6.0), actual, + 0.0001d); } @Test diff --git a/docs/gitbook/eval/rank.md b/docs/gitbook/eval/rank.md index 207418e57..30d82e563 100644 --- a/docs/gitbook/eval/rank.md +++ b/docs/gitbook/eval/rank.md @@ -83,7 +83,8 @@ with truth as ( rec as ( select userid, - map_values(to_ordered_map(score, itemid, true)) as rec, + -- map_values(to_ordered_map(score, itemid, true)) as rec, + to_ordered_list(itemid, score, '-reverse') as rec, cast(count(itemid) as int) as max_k from dummy_rec group by userid @@ -222,7 +223,7 @@ While the binary response setting simply considers positive-only ranked list of Unlike separated `dummy_truth` and `dummy_rec` table in the binary setting, we assume the following single table named `dummy_recrel` which contains item-$$\mathrm{rel}_n$$ pairs: -| userid | itemid | score
(predicted) | rel
(expected) | +| userid | itemid | score
(predicted) | relscore
(expected) | | :-: | :-: | :-: | :-: | | 1 | 1 | 10.0 | 5.0 | | 1 | 3 | 8.0 | 2.0 | @@ -244,27 +245,31 @@ The function `ndcg()` can take non-binary `truth` values as the second argument: ```sql with truth as ( - select userid, map_keys(to_ordered_map(relscore, itemid, true)) as truth - from dummy_recrel - group by userid + select + userid, + to_ordered_list(relscore, '-reverse') as truth + from + dummy_recrel + group by + userid ), rec as ( select userid, - map_values ( - to_ordered_map(score, struct(relscore, itemid), true) - ) as rec, - cast(count(itemid) as int) as max_k - from dummy_recrel - group by userid + to_ordered_list(struct(relscore, itemid), score, "-reverse") as rec, + count(itemid) as max_k + from + dummy_recrel + group by + userid ) select -- top-2 recommendation ndcg(t1.rec, t2.truth, 2), -- => 0.8128912838590544 - -- top-3 recommendation ndcg(t1.rec, t2.truth, 3) -- => 0.9187707805346093 -from rec t1 -join truth t2 on (t1.userid = t2.userid) +from + rec t1 + join truth t2 on (t1.userid = t2.userid) ; ```