Skip to content
This repository has been archived by the owner on Sep 20, 2022. It is now read-only.

Commit

Permalink
c2b9578 replay with gib
Browse files Browse the repository at this point in the history
  • Loading branch information
DrRacket committed Oct 18, 2017
1 parent 36ce3a3 commit ba8621c
Show file tree
Hide file tree
Showing 10 changed files with 250 additions and 69 deletions.
2 changes: 2 additions & 0 deletions core/src/main/java/hivemall/HivemallConstants.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
*/
package hivemall;


public final class HivemallConstants {

public static final String VERSION = "0.4.2-rc.2";
Expand All @@ -35,6 +36,7 @@ public final class HivemallConstants {
public static final String BIGINT_TYPE_NAME = "bigint";
public static final String FLOAT_TYPE_NAME = "float";
public static final String DOUBLE_TYPE_NAME = "double";
public static final String DECIMAL_TYPE_NAME = "decimal";
public static final String STRING_TYPE_NAME = "string";
public static final String DATE_TYPE_NAME = "date";
public static final String DATETIME_TYPE_NAME = "datetime";
Expand Down
122 changes: 91 additions & 31 deletions core/src/main/java/hivemall/evaluation/BinaryResponsesMeasures.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
*/
package hivemall.evaluation;

import hivemall.utils.lang.Preconditions;
import hivemall.utils.math.MathUtils;

import java.util.List;

import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;

/**
Expand All @@ -40,19 +44,25 @@ private BinaryResponsesMeasures() {}
* @return nDCG
*/
public static double nDCG(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) {
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
Preconditions.checkArgument(recommendSize > 0);

double dcg = 0.d;
double idcg = IDCG(Math.min(recommendSize, groundTruth.size()));

for (int i = 0, n = recommendSize; i < n; i++) {
final int k = Math.min(rankedList.size(), recommendSize);
for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (!groundTruth.contains(item_id)) {
continue;
}
int rank = i + 1;
dcg += Math.log(2) / Math.log(rank + 1);
dcg += 1.d / MathUtils.log2(rank + 1);
}

final double idcg = IDCG(Math.min(groundTruth.size(), k));
if (idcg == 0.d) {
return 0.d;
}
return dcg / idcg;
}

Expand All @@ -62,10 +72,12 @@ public static double nDCG(@Nonnull final List<?> rankedList,
* @param n the number of positive items
* @return ideal DCG
*/
public static double IDCG(final int n) {
public static double IDCG(@Nonnegative final int n) {
Preconditions.checkArgument(n >= 0);

double idcg = 0.d;
for (int i = 0; i < n; i++) {
idcg += Math.log(2) / Math.log(i + 2);
idcg += 1.d / MathUtils.log2(i + 2);
}
return idcg;
}
Expand All @@ -79,8 +91,26 @@ public static double IDCG(final int n) {
* @return Precision
*/
public static double Precision(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) {
return (double) countTruePositive(rankedList, groundTruth, recommendSize) / recommendSize;
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
if (rankedList.isEmpty()) {
if (groundTruth.isEmpty()) {
return 1.d;
}
return 0.d;
}

Preconditions.checkArgument(recommendSize > 0); // can be zero when groundTruth is empty

int nTruePositive = 0;
final int k = Math.min(rankedList.size(), recommendSize);
for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
nTruePositive++;
}
}

return ((double) nTruePositive) / k;
}

/**
Expand All @@ -92,8 +122,15 @@ public static double Precision(@Nonnull final List<?> rankedList,
* @return Recall
*/
public static double Recall(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnull final int recommendSize) {
return (double) countTruePositive(rankedList, groundTruth, recommendSize)
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
if (groundTruth.isEmpty()) {
if (rankedList.isEmpty()) {
return 1.d;
}
return 0.d;
}

return ((double) TruePositives(rankedList, groundTruth, recommendSize))
/ groundTruth.size();
}

Expand All @@ -105,11 +142,14 @@ public static double Recall(@Nonnull final List<?> rankedList,
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return number of true positives
*/
public static int countTruePositive(final List<?> rankedList, final List<?> groundTruth,
final int recommendSize) {
public static int TruePositives(final List<?> rankedList, final List<?> groundTruth,
@Nonnegative final int recommendSize) {
Preconditions.checkArgument(recommendSize > 0);

int nTruePositive = 0;

for (int i = 0, n = recommendSize; i < n; i++) {
final int k = Math.min(rankedList.size(), recommendSize);
for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
nTruePositive++;
Expand All @@ -120,48 +160,65 @@ public static int countTruePositive(final List<?> rankedList, final List<?> grou
}

/**
* Computes Mean Reciprocal Rank (MRR)
* Computes Reciprocal Rank
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return MRR
* @return Reciprocal Rank
* @link https://en.wikipedia.org/wiki/Mean_reciprocal_rank
*/
public static double MRR(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnull final int recommendSize) {
for (int i = 0, n = recommendSize; i < n; i++) {
public static double ReciprocalRank(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
Preconditions.checkArgument(recommendSize > 0);

final int k = Math.min(rankedList.size(), recommendSize);
for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
return 1.0 / (i + 1.0);
return 1.d / (i + 1);
}
}

return 0.0;
return 0.d;
}

/**
* Computes Mean Average Precision (MAP)
* Computes Average Precision (AP)
*
* @param rankedList a list of ranked item IDs (first item is highest-ranked)
* @param groundTruth a collection of positive/correct item IDs
* @param recommendSize top-`recommendSize` items in `rankedList` are recommended
* @return MAP
* @return AveragePrecision
*/
public static double MAP(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnull final int recommendSize) {
public static double AveragePrecision(@Nonnull final List<?> rankedList,
@Nonnull final List<?> groundTruth, @Nonnegative final int recommendSize) {
Preconditions.checkArgument(recommendSize > 0);

if (groundTruth.isEmpty()) {
if (rankedList.isEmpty()) {
return 1.d;
}
return 0.d;
}

int nTruePositive = 0;
double sumPrecision = 0.0;
double sumPrecision = 0.d;

// accumulate precision@1 to @recommendSize
for (int i = 0, n = recommendSize; i < n; i++) {
final int k = Math.min(rankedList.size(), recommendSize);
for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
nTruePositive++;
sumPrecision += nTruePositive / (i + 1.0);
sumPrecision += nTruePositive / (i + 1.d);
}
}

return sumPrecision / groundTruth.size();
if (nTruePositive == 0) {
return 0.d;
}
return sumPrecision / nTruePositive;
}

/**
Expand All @@ -173,11 +230,14 @@ public static double MAP(@Nonnull final List<?> rankedList, @Nonnull final List<
* @return AUC
*/
public static double AUC(@Nonnull final List<?> rankedList, @Nonnull final List<?> groundTruth,
@Nonnull final int recommendSize) {
@Nonnegative final int recommendSize) {
Preconditions.checkArgument(recommendSize > 0);

int nTruePositive = 0, nCorrectPairs = 0;

// count # of pairs of items that are ranked in the correct order (i.e. TP > FP)
for (int i = 0, n = recommendSize; i < n; i++) {
final int k = Math.min(rankedList.size(), recommendSize);
for (int i = 0; i < k; i++) {
Object item_id = rankedList.get(i);
if (groundTruth.contains(item_id)) {
// # of true positives which are ranked higher position than i-th recommended item
Expand All @@ -197,7 +257,7 @@ public static double AUC(@Nonnull final List<?> rankedList, @Nonnull final List<
}

// AUC can equivalently be calculated by counting the portion of correctly ordered pairs
return (double) nCorrectPairs / nPairs;
return ((double) nCorrectPairs) / nPairs;
}

}
2 changes: 1 addition & 1 deletion core/src/main/java/hivemall/evaluation/MAPUDAF.java
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ void merge(double o_sum, long o_count) {

void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList,
@Nonnull int recommendSize) {
sum += BinaryResponsesMeasures.MAP(recommendList, truthList, recommendSize);
sum += BinaryResponsesMeasures.AveragePrecision(recommendList, truthList, recommendSize);
count++;
}
}
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/hivemall/evaluation/MRRUDAF.java
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ void merge(double o_sum, long o_count) {

void iterate(@Nonnull List<?> recommendList, @Nonnull List<?> truthList,
@Nonnull int recommendSize) {
sum += BinaryResponsesMeasures.MRR(recommendList, truthList, recommendSize);
sum += BinaryResponsesMeasures.ReciprocalRank(recommendList, truthList, recommendSize);
count++;
}
}
Expand Down
32 changes: 23 additions & 9 deletions core/src/main/java/hivemall/evaluation/NDCGUDAF.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/
package hivemall.evaluation;

import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableDoubleObjectInspector;
import static org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory.writableLongObjectInspector;
import hivemall.utils.hadoop.HiveUtils;

import java.util.ArrayList;
Expand All @@ -38,10 +40,11 @@
import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StructField;
import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableDoubleObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.WritableIntObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
Expand Down Expand Up @@ -120,8 +123,8 @@ public ObjectInspector init(Mode mode, ObjectInspector[] parameters) throws Hive
}

private static StructObjectInspector internalMergeOI() {
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
List<String> fieldNames = new ArrayList<>();
List<ObjectInspector> fieldOIs = new ArrayList<>();

fieldNames.add("sum");
fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
Expand Down Expand Up @@ -180,20 +183,31 @@ public void iterate(@SuppressWarnings("deprecation") AggregationBuffer agg,
StructObjectInspector sOI = (StructObjectInspector) recommendListOI.getListElementObjectInspector();
List<?> fieldRefList = sOI.getAllStructFieldRefs();
StructField relScoreField = (StructField) fieldRefList.get(0);
WritableDoubleObjectInspector relScoreFieldOI = (WritableDoubleObjectInspector) relScoreField.getFieldObjectInspector();
PrimitiveObjectInspector relScoreFieldOI = HiveUtils.asDoubleCompatibleOI(relScoreField.getFieldObjectInspector());
for (int i = 0, n = recommendList.size(); i < n; i++) {
Object structObj = recommendList.get(i);
List<Object> fieldList = sOI.getStructFieldsDataAsList(structObj);
double relScore = (double) relScoreFieldOI.get(fieldList.get(0));
Object field0 = fieldList.get(0);
if (field0 == null) {
throw new UDFArgumentException("Field 0 of a struct field is null: "
+ fieldList);
}
double relScore = PrimitiveObjectInspectorUtils.getDouble(field0,
relScoreFieldOI);
recommendRelScoreList.add(relScore);
}

// Create a ordered list of relevance scores for truth items
List<Double> truthRelScoreList = new ArrayList<Double>();
WritableDoubleObjectInspector truthRelScoreOI = (WritableDoubleObjectInspector) truthListOI.getListElementObjectInspector();
PrimitiveObjectInspector truthRelScoreOI = HiveUtils.asDoubleCompatibleOI(truthListOI.getListElementObjectInspector());
for (int i = 0, n = truthList.size(); i < n; i++) {
Object relScoreObj = truthList.get(i);
double relScore = (double) truthRelScoreOI.get(relScoreObj);
if (relScoreObj == null) {
throw new UDFArgumentException("Found null in the ground truth: "
+ truthList);
}
double relScore = PrimitiveObjectInspectorUtils.getDouble(relScoreObj,
truthRelScoreOI);
truthRelScoreList.add(relScore);
}

Expand Down Expand Up @@ -224,8 +238,8 @@ public void merge(@SuppressWarnings("deprecation") AggregationBuffer agg, Object

Object sumObj = internalMergeOI.getStructFieldData(partial, sumField);
Object countObj = internalMergeOI.getStructFieldData(partial, countField);
double sum = PrimitiveObjectInspectorFactory.writableDoubleObjectInspector.get(sumObj);
long count = PrimitiveObjectInspectorFactory.writableLongObjectInspector.get(countObj);
double sum = writableDoubleObjectInspector.get(sumObj);
long count = writableLongObjectInspector.get(countObj);

NDCGAggregationBuffer myAggr = (NDCGAggregationBuffer) agg;
myAggr.merge(sum, count);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ public ObjectInspector init(Mode mode, ObjectInspector[] argOIs) throws HiveExce
|| (argOIs.length == 3 && HiveUtils.isConstString(argOIs[2]));

if (sortByKey) {
this.valueOI = HiveUtils.asPrimitiveObjectInspector(argOIs[0]);
this.valueOI = argOIs[0];
this.keyOI = HiveUtils.asPrimitiveObjectInspector(argOIs[1]);
} else {
// sort values by value itself
Expand Down
Loading

0 comments on commit ba8621c

Please sign in to comment.